Align Stream FSMN and Non-Stream FSMN, both in feature extraction and model forward.
This commit is contained in:
parent
6f8207267e
commit
ceb59e87a6
@ -203,7 +203,9 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
--window_shift $window_shift \
|
--window_shift $window_shift \
|
||||||
--step 0.001 \
|
--step 0.001 \
|
||||||
--score_file $result_dir/score.txt
|
--score_file $result_dir/score.txt \
|
||||||
|
--token_file data/tokens.txt \
|
||||||
|
--lexicon_file data/lexicon.txt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -154,17 +154,21 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
--window_shift $window_shift \
|
--window_shift $window_shift \
|
||||||
--step 0.001 \
|
--step 0.001 \
|
||||||
--score_file $result_dir/score.txt
|
--score_file $result_dir/score.txt \
|
||||||
|
--token_file data/tokens.txt \
|
||||||
|
--lexicon_file data/lexicon.txt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||||
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
|
# NOTE: FSMN now is not support export to jit, beacuse of nn.Sequential with tuple input
|
||||||
|
# This issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
|
||||||
|
# jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
|
||||||
|
# python wekws/bin/export_jit.py \
|
||||||
|
# --config $dir/config.yaml \
|
||||||
|
# --checkpoint $score_checkpoint \
|
||||||
|
# --jit_model $dir/$jit_model
|
||||||
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
|
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
|
||||||
python wekws/bin/export_jit.py \
|
|
||||||
--config $dir/config.yaml \
|
|
||||||
--checkpoint $score_checkpoint \
|
|
||||||
--jit_model $dir/$jit_model
|
|
||||||
python wekws/bin/export_onnx.py \
|
python wekws/bin/export_onnx.py \
|
||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
|
|||||||
@ -23,6 +23,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pypinyin # for Chinese Character
|
import pypinyin # for Chinese Character
|
||||||
|
from tools.make_list import query_token_set, read_lexicon, read_token
|
||||||
|
|
||||||
def split_mixed_label(input_str):
|
def split_mixed_label(input_str):
|
||||||
tokens = []
|
tokens = []
|
||||||
@ -43,7 +44,7 @@ def space_mixed_label(input_str):
|
|||||||
space_str = ''.join(f'{sub} ' for sub in splits)
|
space_str = ''.join(f'{sub} ' for sub in splits)
|
||||||
return space_str.strip()
|
return space_str.strip()
|
||||||
|
|
||||||
def load_label_and_score(keywords_list, label_file, score_file):
|
def load_label_and_score(keywords_list, label_file, score_file, true_keywords):
|
||||||
score_table = {}
|
score_table = {}
|
||||||
with open(score_file, 'r', encoding='utf8') as fin:
|
with open(score_file, 'r', encoding='utf8') as fin:
|
||||||
# read score file and store in table
|
# read score file and store in table
|
||||||
@ -52,10 +53,11 @@ def load_label_and_score(keywords_list, label_file, score_file):
|
|||||||
key = arr[0]
|
key = arr[0]
|
||||||
is_detected = arr[1]
|
is_detected = arr[1]
|
||||||
if is_detected == 'detected':
|
if is_detected == 'detected':
|
||||||
|
keyword=true_keywords[arr[2]]
|
||||||
if key not in score_table:
|
if key not in score_table:
|
||||||
score_table.update({
|
score_table.update({
|
||||||
key: {
|
key: {
|
||||||
'kw': space_mixed_label(arr[2]),
|
'kw': space_mixed_label(keyword),
|
||||||
'confi': float(arr[3])
|
'confi': float(arr[3])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -72,6 +74,7 @@ def load_label_and_score(keywords_list, label_file, score_file):
|
|||||||
# build empty structure for keyword-filler infos
|
# build empty structure for keyword-filler infos
|
||||||
keyword_filler_table = {}
|
keyword_filler_table = {}
|
||||||
for keyword in keywords_list:
|
for keyword in keywords_list:
|
||||||
|
keyword = true_keywords[keyword]
|
||||||
keyword = space_mixed_label(keyword)
|
keyword = space_mixed_label(keyword)
|
||||||
keyword_filler_table[keyword] = {}
|
keyword_filler_table[keyword] = {}
|
||||||
keyword_filler_table[keyword]['keyword_table'] = {}
|
keyword_filler_table[keyword]['keyword_table'] = {}
|
||||||
@ -93,6 +96,7 @@ def load_label_and_score(keywords_list, label_file, score_file):
|
|||||||
assert key in score_table
|
assert key in score_table
|
||||||
|
|
||||||
for keyword in keywords_list:
|
for keyword in keywords_list:
|
||||||
|
keyword = true_keywords[keyword]
|
||||||
keyword = space_mixed_label(keyword)
|
keyword = space_mixed_label(keyword)
|
||||||
keyword_regstr_lrblk = ' ' + keyword + ' '
|
keyword_regstr_lrblk = ' ' + keyword + ' '
|
||||||
if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
|
if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
|
||||||
@ -154,6 +158,8 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser(description='compute det curve')
|
parser = argparse.ArgumentParser(description='compute det curve')
|
||||||
parser.add_argument('--test_data', required=True, help='label file')
|
parser.add_argument('--test_data', required=True, help='label file')
|
||||||
parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)')
|
parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)')
|
||||||
|
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
|
||||||
|
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
|
||||||
parser.add_argument('--score_file', required=True, help='score file')
|
parser.add_argument('--score_file', required=True, help='score file')
|
||||||
parser.add_argument('--step', type=float, default=0.01,
|
parser.add_argument('--step', type=float, default=0.01,
|
||||||
help='threshold step')
|
help='threshold step')
|
||||||
@ -183,9 +189,18 @@ if __name__ == '__main__':
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
window_shift = args.window_shift
|
window_shift = args.window_shift
|
||||||
keywords_list = args.keywords.strip().split(',')
|
keywords_list = args.keywords.strip().split(',')
|
||||||
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file)
|
|
||||||
|
token_table = read_token(args.token_file)
|
||||||
|
lexicon_table = read_lexicon(args.lexicon_file)
|
||||||
|
true_keywords = {}
|
||||||
|
for keyword in keywords_list:
|
||||||
|
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
||||||
|
true_keywords[keyword] = ''.join(strs)
|
||||||
|
|
||||||
|
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file, true_keywords)
|
||||||
|
|
||||||
for keyword in keywords_list:
|
for keyword in keywords_list:
|
||||||
|
keyword = true_keywords[keyword]
|
||||||
keyword = space_mixed_label(keyword)
|
keyword = space_mixed_label(keyword)
|
||||||
keyword_dur = keyword_filler_table[keyword]['keyword_duration']
|
keyword_dur = keyword_filler_table[keyword]['keyword_duration']
|
||||||
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
|
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
|
||||||
|
|||||||
@ -35,7 +35,10 @@ from tools.make_list import query_token_set, read_lexicon, read_token
|
|||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser(description='detect keywords online.')
|
parser = argparse.ArgumentParser(description='detect keywords online.')
|
||||||
parser.add_argument('--config', required=True, help='config file')
|
parser.add_argument('--config', required=True, help='config file')
|
||||||
parser.add_argument('--wav_path', required=True, help='test wave path.')
|
parser.add_argument('--wav_path', required=False, default=None, help='test wave path.')
|
||||||
|
parser.add_argument('--wav_scp', required=False, default=None, help='test wave scp.')
|
||||||
|
parser.add_argument('--result_file', required=False, default=None, help='test result.')
|
||||||
|
|
||||||
parser.add_argument('--gpu',
|
parser.add_argument('--gpu',
|
||||||
type=int,
|
type=int,
|
||||||
default=-1,
|
default=-1,
|
||||||
@ -209,7 +212,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.left_context = dataset_conf['context_expansion_conf']['left']
|
self.left_context = dataset_conf['context_expansion_conf']['left']
|
||||||
self.right_context = dataset_conf['context_expansion_conf']['right']
|
self.right_context = dataset_conf['context_expansion_conf']['right']
|
||||||
self.feature_remained = None
|
self.feature_remained = None
|
||||||
self.feature_context_offset = 0 # after downsample, offset exist.
|
self.feats_ctx_offset = 0 # after downsample, offset exist.
|
||||||
|
|
||||||
|
|
||||||
# model related
|
# model related
|
||||||
@ -289,6 +292,9 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
|
|
||||||
wave = np.array(data)
|
wave = np.array(data)
|
||||||
wave = np.append(self.wave_remained, wave)
|
wave = np.append(self.wave_remained, wave)
|
||||||
|
if wave.size < (self.frame_length * self.sample_rate / 1000) * self.right_context :
|
||||||
|
self.wave_remained = wave
|
||||||
|
return None
|
||||||
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
||||||
wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension
|
wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension
|
||||||
feats = kaldi.fbank(wave_tensor,
|
feats = kaldi.fbank(wave_tensor,
|
||||||
@ -312,7 +318,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
feats_pad = torch.cat((self.feature_remained, feats))
|
feats_pad = torch.cat((self.feature_remained, feats))
|
||||||
|
|
||||||
ctx_frm = feats.shape[0] - self.right_context
|
ctx_frm = feats_pad.shape[0] - (self.right_context+self.right_context)
|
||||||
ctx_win = (self.left_context + self.right_context + 1)
|
ctx_win = (self.left_context + self.right_context + 1)
|
||||||
ctx_dim = feats.shape[1] * ctx_win
|
ctx_dim = feats.shape[1] * ctx_win
|
||||||
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
|
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
|
||||||
@ -320,12 +326,13 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
|
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
|
||||||
|
|
||||||
# update feature remained, and feats
|
# update feature remained, and feats
|
||||||
self.feature_remained = feats[-self.left_context:]
|
self.feature_remained = feats[-(self.left_context+self.right_context):]
|
||||||
feats = feats_ctx.to(self.device)
|
feats = feats_ctx.to(self.device)
|
||||||
if self.downsampling > 1:
|
if self.downsampling > 1:
|
||||||
feats = feats[self.feature_context_offset::self.downsampling, :]
|
last_remainder = 0 if self.feats_ctx_offset==0 else self.downsampling-self.feats_ctx_offset
|
||||||
complement = feats.size(1) % self.downsampling
|
remainder = (feats.size(0)+last_remainder) % self.downsampling
|
||||||
self.feature_context_offset = complement if complement == 0 else self.downsampling-complement
|
feats = feats[self.feats_ctx_offset::self.downsampling, :]
|
||||||
|
self.feats_ctx_offset = remainder if remainder == 0 else self.downsampling-remainder
|
||||||
return feats
|
return feats
|
||||||
|
|
||||||
def decode_keywords(self, t, probs):
|
def decode_keywords(self, t, probs):
|
||||||
@ -396,11 +403,14 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, wave_chunk):
|
def forward(self, wave_chunk):
|
||||||
feature = self.accept_wave(wave_chunk)
|
feature = self.accept_wave(wave_chunk)
|
||||||
|
if feature is None or feature.size(0) < 1:
|
||||||
|
return self.result
|
||||||
feature = feature.unsqueeze(0) # add a batch dimension
|
feature = feature.unsqueeze(0) # add a batch dimension
|
||||||
logits, self.in_cache = self.model(feature, self.in_cache)
|
logits, self.in_cache = self.model(feature, self.in_cache)
|
||||||
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
||||||
probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search
|
probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search
|
||||||
for (t, prob) in enumerate(probs):
|
for (t, prob) in enumerate(probs):
|
||||||
|
t *= self.downsampling
|
||||||
self.decode_keywords(t, prob)
|
self.decode_keywords(t, prob)
|
||||||
self.execute_detection(t)
|
self.execute_detection(t)
|
||||||
|
|
||||||
@ -409,7 +419,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
# since a chunk include about 30 frames, once activated, we can jump the latter frames.
|
# since a chunk include about 30 frames, once activated, we can jump the latter frames.
|
||||||
# TODO: there should give another method to update result, avoiding self.result being cleared.
|
# TODO: there should give another method to update result, avoiding self.result being cleared.
|
||||||
break
|
break
|
||||||
self.total_frames += len(probs) # update frame offset
|
self.total_frames += len(probs) * self.downsampling # update frame offset
|
||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@ -417,6 +427,15 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.activated = False
|
self.activated = False
|
||||||
self.hit_score = 1.0
|
self.hit_score = 1.0
|
||||||
|
|
||||||
|
def reset_all(self):
|
||||||
|
self.reset()
|
||||||
|
self.wave_remained = np.array([])
|
||||||
|
self.feature_remained = None
|
||||||
|
self.feats_ctx_offset = 0 # after downsample, offset exist.
|
||||||
|
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||||
|
self.total_frames = 0 # frame offset, for absolute time
|
||||||
|
self.result = {}
|
||||||
|
|
||||||
def demo():
|
def demo():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
logging.basicConfig(level=logging.DEBUG,
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
@ -435,22 +454,60 @@ def demo():
|
|||||||
args.gpu,
|
args.gpu,
|
||||||
args.jit_model)
|
args.jit_model)
|
||||||
|
|
||||||
# actually this could be done in __init__ method, we pull it outside for changing keywords.
|
# actually this could be done in __init__ method, we pull it outside for changing keywords more freely.
|
||||||
kws.set_keywords(args.keywords)
|
kws.set_keywords(args.keywords)
|
||||||
|
|
||||||
# Caution: input WAV should be standard 16k, 16 bits, 1 channel
|
if args.wav_path:
|
||||||
# In demo we read wave in non-streaming fashion.
|
# Caution: input WAV should be standard 16k, 16 bits, 1 channel
|
||||||
with wave.open(args.wav_path, 'rb') as fin:
|
# In demo we read wave in non-streaming fashion.
|
||||||
assert fin.getnchannels() == 1
|
with wave.open(args.wav_path, 'rb') as fin:
|
||||||
wav = fin.readframes(fin.getnframes())
|
assert fin.getnchannels() == 1
|
||||||
|
wav = fin.readframes(fin.getnframes())
|
||||||
|
|
||||||
# We inference every 0.3 seconds, in streaming fashion.
|
# We inference every 0.3 seconds, in streaming fashion.
|
||||||
interval = int(0.3 * 16000) * 2
|
interval = int(0.3 * 16000) * 2
|
||||||
for i in range(0, len(wav), interval):
|
for i in range(0, len(wav), interval):
|
||||||
chunk_wav = wav[i: min(i + interval, len(wav))]
|
chunk_wav = wav[i: min(i + interval, len(wav))]
|
||||||
result = kws.forward(chunk_wav)
|
result = kws.forward(chunk_wav)
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
|
fout = None
|
||||||
|
if args.result_file:
|
||||||
|
fout = open(args.result_file, 'w', encoding='utf-8')
|
||||||
|
|
||||||
|
if args.wav_scp:
|
||||||
|
with open(args.wav_scp, 'r') as fscp:
|
||||||
|
for line in fscp:
|
||||||
|
line = line.strip().split()
|
||||||
|
assert len(line) == 2, f"The scp should be in kaldi format: \"utt_name wav_path\", but got {line}"
|
||||||
|
|
||||||
|
utt_name, wav_path = line[0], line[1]
|
||||||
|
with wave.open(wav_path, 'rb') as fin:
|
||||||
|
assert fin.getnchannels() == 1
|
||||||
|
wav = fin.readframes(fin.getnframes())
|
||||||
|
|
||||||
|
kws.reset_all()
|
||||||
|
activated = False
|
||||||
|
|
||||||
|
# We inference every 0.3 seconds, in streaming fashion.
|
||||||
|
interval = int(0.3 * 16000) * 2
|
||||||
|
for i in range(0, len(wav), interval):
|
||||||
|
chunk_wav = wav[i: min(i + interval, len(wav))]
|
||||||
|
result = kws.forward(chunk_wav)
|
||||||
|
if 'state' in result and result['state'] == 1:
|
||||||
|
activated = True
|
||||||
|
if fout:
|
||||||
|
hit_keyword = result['keyword']
|
||||||
|
hit_score = result['score']
|
||||||
|
fout.write('{} detected {} {:.3f}\n'.format(utt_name, hit_keyword, hit_score))
|
||||||
|
|
||||||
|
if not activated:
|
||||||
|
if fout:
|
||||||
|
fout.write('{} rejected\n'.format(utt_name))
|
||||||
|
|
||||||
|
|
||||||
|
if fout:
|
||||||
|
fout.close()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
demo()
|
demo()
|
||||||
|
|||||||
@ -44,11 +44,11 @@ def get_args():
|
|||||||
help='gpu id for this rank, -1 for cpu')
|
help='gpu id for this rank, -1 for cpu')
|
||||||
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
||||||
parser.add_argument('--batch_size',
|
parser.add_argument('--batch_size',
|
||||||
default=16,
|
default=1,
|
||||||
type=int,
|
type=int,
|
||||||
help='batch size for inference')
|
help='batch size for inference')
|
||||||
parser.add_argument('--num_workers',
|
parser.add_argument('--num_workers',
|
||||||
default=0,
|
default=1,
|
||||||
type=int,
|
type=int,
|
||||||
help='num of subprocess workers for reading')
|
help='num of subprocess workers for reading')
|
||||||
parser.add_argument('--pin_memory',
|
parser.add_argument('--pin_memory',
|
||||||
@ -131,6 +131,8 @@ def main():
|
|||||||
test_conf['feature_extraction_conf']['dither'] = 0.0
|
test_conf['feature_extraction_conf']['dither'] = 0.0
|
||||||
test_conf['batch_conf']['batch_size'] = args.batch_size
|
test_conf['batch_conf']['batch_size'] = args.batch_size
|
||||||
|
|
||||||
|
downsampling_factor = test_conf.get('frame_skip', 1)
|
||||||
|
|
||||||
test_dataset = Dataset(args.test_data, test_conf)
|
test_dataset = Dataset(args.test_data, test_conf)
|
||||||
test_data_loader = DataLoader(test_dataset,
|
test_data_loader = DataLoader(test_dataset,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
@ -205,6 +207,7 @@ def main():
|
|||||||
# 2. CTC beam search step by step
|
# 2. CTC beam search step by step
|
||||||
for t in range(0, maxlen):
|
for t in range(0, maxlen):
|
||||||
probs = ctc_probs[t] # (vocab_size,)
|
probs = ctc_probs[t] # (vocab_size,)
|
||||||
|
t *= downsampling_factor # the real time
|
||||||
# key: prefix, value (pb, pnb), default value(-inf, -inf)
|
# key: prefix, value (pb, pnb), default value(-inf, -inf)
|
||||||
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
|
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
|
||||||
|
|
||||||
|
|||||||
@ -157,8 +157,12 @@ def main():
|
|||||||
# Try to export the model by script, if fails, we should refine
|
# Try to export the model by script, if fails, we should refine
|
||||||
# the code to satisfy the script export requirements
|
# the code to satisfy the script export requirements
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
script_model = torch.jit.script(model)
|
pass
|
||||||
script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
# TODO: for now streaming FSMN do not support export to JITScript,
|
||||||
|
# TODO: because there is nn.Sequential with Tuple input in current FSMN modules.
|
||||||
|
# the issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
|
||||||
|
# script_model = torch.jit.script(model)
|
||||||
|
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||||
executor = Executor()
|
executor = Executor()
|
||||||
# If specify checkpoint, load some info from checkpoint
|
# If specify checkpoint, load some info from checkpoint
|
||||||
if args.checkpoint is not None:
|
if args.checkpoint is not None:
|
||||||
|
|||||||
@ -377,6 +377,8 @@ def add_reverb(data, reverb_source, aug_prob):
|
|||||||
rir_io = io.BytesIO(rir_data)
|
rir_io = io.BytesIO(rir_data)
|
||||||
_, rir_audio = wavfile.read(rir_io)
|
_, rir_audio = wavfile.read(rir_io)
|
||||||
rir_audio = rir_audio.astype(np.float32)
|
rir_audio = rir_audio.astype(np.float32)
|
||||||
|
if len(rir_audio.shape) > 1:
|
||||||
|
rir_audio = rir_audio[:, 0]
|
||||||
rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2))
|
rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2))
|
||||||
out_audio = signal.convolve(audio, rir_audio,
|
out_audio = signal.convolve(audio, rir_audio,
|
||||||
mode='full')[:audio_len]
|
mode='full')[:audio_len]
|
||||||
@ -405,6 +407,8 @@ def add_noise(data, noise_source, aug_prob):
|
|||||||
snr_range = [0, 15]
|
snr_range = [0, 15]
|
||||||
_, noise_audio = wavfile.read(io.BytesIO(noise_data))
|
_, noise_audio = wavfile.read(io.BytesIO(noise_data))
|
||||||
noise_audio = noise_audio.astype(np.float32)
|
noise_audio = noise_audio.astype(np.float32)
|
||||||
|
if len(noise_audio.shape) > 1:
|
||||||
|
noise_audio = noise_audio[:, 0]
|
||||||
if noise_audio.shape[0] > audio_len:
|
if noise_audio.shape[0] > audio_len:
|
||||||
start = random.randint(0, noise_audio.shape[0] - audio_len)
|
start = random.randint(0, noise_audio.shape[0] - audio_len)
|
||||||
noise_audio = noise_audio[start:start + audio_len]
|
noise_audio = noise_audio[start:start + audio_len]
|
||||||
|
|||||||
@ -39,11 +39,12 @@ class LinearTransform(nn.Module):
|
|||||||
self.quant = torch.quantization.QuantStub()
|
self.quant = torch.quantization.QuantStub()
|
||||||
self.dequant = torch.quantization.DeQuantStub()
|
self.dequant = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self,
|
||||||
|
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||||
if isinstance(input, tuple):
|
if isinstance(input, tuple):
|
||||||
input, in_cache = input
|
input, in_cache = input
|
||||||
else:
|
else:
|
||||||
in_cache = None
|
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||||
output = self.quant(input)
|
output = self.quant(input)
|
||||||
output = self.linear(output)
|
output = self.linear(output)
|
||||||
output = self.dequant(output)
|
output = self.dequant(output)
|
||||||
@ -102,11 +103,12 @@ class AffineTransform(nn.Module):
|
|||||||
self.quant = torch.quantization.QuantStub()
|
self.quant = torch.quantization.QuantStub()
|
||||||
self.dequant = torch.quantization.DeQuantStub()
|
self.dequant = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self,
|
||||||
|
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||||
if isinstance(input, tuple):
|
if isinstance(input, tuple):
|
||||||
input, in_cache = input
|
input, in_cache = input
|
||||||
else:
|
else:
|
||||||
in_cache = None
|
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||||
output = self.quant(input)
|
output = self.quant(input)
|
||||||
output = self.linear(output)
|
output = self.linear(output)
|
||||||
output = self.dequant(output)
|
output = self.dequant(output)
|
||||||
@ -213,15 +215,16 @@ class FSMNBlock(nn.Module):
|
|||||||
self.quant = torch.quantization.QuantStub()
|
self.quant = torch.quantization.QuantStub()
|
||||||
self.dequant = torch.quantization.DeQuantStub()
|
self.dequant = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self,
|
||||||
|
input: Tuple[torch.Tensor, torch.Tensor] ):
|
||||||
if isinstance(input, tuple):
|
if isinstance(input, tuple):
|
||||||
input, in_cache = input
|
input, in_cache = input
|
||||||
else :
|
else :
|
||||||
in_cache = None
|
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||||
x = torch.unsqueeze(input, 1)
|
x = torch.unsqueeze(input, 1)
|
||||||
x_per = x.permute(0, 3, 2, 1)
|
x_per = x.permute(0, 3, 2, 1)
|
||||||
|
|
||||||
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
|
if in_cache is None or len(in_cache) == 0 :
|
||||||
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
|
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
|
||||||
else:
|
else:
|
||||||
in_cache = in_cache.to(x_per.device)
|
in_cache = in_cache.to(x_per.device)
|
||||||
@ -339,11 +342,12 @@ class RectifiedLinear(nn.Module):
|
|||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.dropout = nn.Dropout(0.1)
|
self.dropout = nn.Dropout(0.1)
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self,
|
||||||
|
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||||
if isinstance(input, tuple):
|
if isinstance(input, tuple):
|
||||||
input, in_cache = input
|
input, in_cache = input
|
||||||
else :
|
else :
|
||||||
in_cache = None
|
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||||
out = self.relu(input)
|
out = self.relu(input)
|
||||||
# out = self.dropout(out)
|
# out = self.dropout(out)
|
||||||
return (out, in_cache)
|
return (out, in_cache)
|
||||||
@ -456,7 +460,7 @@ class FSMN(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
in_cache #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -464,8 +468,8 @@ class FSMN(nn.Module):
|
|||||||
in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
|
in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if in_cache is None or len(in_cache) == 0 or in_cache[0] == None:
|
if in_cache is None or len(in_cache) == 0 :
|
||||||
in_cache = [None for _ in range(len(self.fsmn))]
|
in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) for _ in range(len(self.fsmn))]
|
||||||
input = (input, in_cache)
|
input = (input, in_cache)
|
||||||
x1 = self.in_linear1(input)
|
x1 = self.in_linear1(input)
|
||||||
x2 = self.in_linear2(x1)
|
x2 = self.in_linear2(x1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user