Align Stream FSMN and Non-Stream FSMN, both in feature extraction and model forward.
This commit is contained in:
parent
6f8207267e
commit
ceb59e87a6
@ -203,7 +203,9 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
--test_data data/test/data.list \
|
||||
--window_shift $window_shift \
|
||||
--step 0.001 \
|
||||
--score_file $result_dir/score.txt
|
||||
--score_file $result_dir/score.txt \
|
||||
--token_file data/tokens.txt \
|
||||
--lexicon_file data/lexicon.txt
|
||||
fi
|
||||
|
||||
|
||||
|
||||
@ -154,17 +154,21 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
--test_data data/test/data.list \
|
||||
--window_shift $window_shift \
|
||||
--step 0.001 \
|
||||
--score_file $result_dir/score.txt
|
||||
--score_file $result_dir/score.txt \
|
||||
--token_file data/tokens.txt \
|
||||
--lexicon_file data/lexicon.txt
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
|
||||
# NOTE: FSMN now is not support export to jit, beacuse of nn.Sequential with tuple input
|
||||
# This issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
|
||||
# jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
|
||||
# python wekws/bin/export_jit.py \
|
||||
# --config $dir/config.yaml \
|
||||
# --checkpoint $score_checkpoint \
|
||||
# --jit_model $dir/$jit_model
|
||||
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
|
||||
python wekws/bin/export_jit.py \
|
||||
--config $dir/config.yaml \
|
||||
--checkpoint $score_checkpoint \
|
||||
--jit_model $dir/$jit_model
|
||||
python wekws/bin/export_onnx.py \
|
||||
--config $dir/config.yaml \
|
||||
--checkpoint $score_checkpoint \
|
||||
|
||||
@ -23,6 +23,7 @@ import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import pypinyin # for Chinese Character
|
||||
from tools.make_list import query_token_set, read_lexicon, read_token
|
||||
|
||||
def split_mixed_label(input_str):
|
||||
tokens = []
|
||||
@ -43,7 +44,7 @@ def space_mixed_label(input_str):
|
||||
space_str = ''.join(f'{sub} ' for sub in splits)
|
||||
return space_str.strip()
|
||||
|
||||
def load_label_and_score(keywords_list, label_file, score_file):
|
||||
def load_label_and_score(keywords_list, label_file, score_file, true_keywords):
|
||||
score_table = {}
|
||||
with open(score_file, 'r', encoding='utf8') as fin:
|
||||
# read score file and store in table
|
||||
@ -52,10 +53,11 @@ def load_label_and_score(keywords_list, label_file, score_file):
|
||||
key = arr[0]
|
||||
is_detected = arr[1]
|
||||
if is_detected == 'detected':
|
||||
keyword=true_keywords[arr[2]]
|
||||
if key not in score_table:
|
||||
score_table.update({
|
||||
key: {
|
||||
'kw': space_mixed_label(arr[2]),
|
||||
'kw': space_mixed_label(keyword),
|
||||
'confi': float(arr[3])
|
||||
}
|
||||
})
|
||||
@ -72,6 +74,7 @@ def load_label_and_score(keywords_list, label_file, score_file):
|
||||
# build empty structure for keyword-filler infos
|
||||
keyword_filler_table = {}
|
||||
for keyword in keywords_list:
|
||||
keyword = true_keywords[keyword]
|
||||
keyword = space_mixed_label(keyword)
|
||||
keyword_filler_table[keyword] = {}
|
||||
keyword_filler_table[keyword]['keyword_table'] = {}
|
||||
@ -93,6 +96,7 @@ def load_label_and_score(keywords_list, label_file, score_file):
|
||||
assert key in score_table
|
||||
|
||||
for keyword in keywords_list:
|
||||
keyword = true_keywords[keyword]
|
||||
keyword = space_mixed_label(keyword)
|
||||
keyword_regstr_lrblk = ' ' + keyword + ' '
|
||||
if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
|
||||
@ -154,6 +158,8 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='compute det curve')
|
||||
parser.add_argument('--test_data', required=True, help='label file')
|
||||
parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)')
|
||||
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
|
||||
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
|
||||
parser.add_argument('--score_file', required=True, help='score file')
|
||||
parser.add_argument('--step', type=float, default=0.01,
|
||||
help='threshold step')
|
||||
@ -183,9 +189,18 @@ if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
window_shift = args.window_shift
|
||||
keywords_list = args.keywords.strip().split(',')
|
||||
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file)
|
||||
|
||||
token_table = read_token(args.token_file)
|
||||
lexicon_table = read_lexicon(args.lexicon_file)
|
||||
true_keywords = {}
|
||||
for keyword in keywords_list:
|
||||
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
||||
true_keywords[keyword] = ''.join(strs)
|
||||
|
||||
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file, true_keywords)
|
||||
|
||||
for keyword in keywords_list:
|
||||
keyword = true_keywords[keyword]
|
||||
keyword = space_mixed_label(keyword)
|
||||
keyword_dur = keyword_filler_table[keyword]['keyword_duration']
|
||||
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
|
||||
|
||||
@ -35,7 +35,10 @@ from tools.make_list import query_token_set, read_lexicon, read_token
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='detect keywords online.')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--wav_path', required=True, help='test wave path.')
|
||||
parser.add_argument('--wav_path', required=False, default=None, help='test wave path.')
|
||||
parser.add_argument('--wav_scp', required=False, default=None, help='test wave scp.')
|
||||
parser.add_argument('--result_file', required=False, default=None, help='test result.')
|
||||
|
||||
parser.add_argument('--gpu',
|
||||
type=int,
|
||||
default=-1,
|
||||
@ -209,7 +212,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
self.left_context = dataset_conf['context_expansion_conf']['left']
|
||||
self.right_context = dataset_conf['context_expansion_conf']['right']
|
||||
self.feature_remained = None
|
||||
self.feature_context_offset = 0 # after downsample, offset exist.
|
||||
self.feats_ctx_offset = 0 # after downsample, offset exist.
|
||||
|
||||
|
||||
# model related
|
||||
@ -289,6 +292,9 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
|
||||
wave = np.array(data)
|
||||
wave = np.append(self.wave_remained, wave)
|
||||
if wave.size < (self.frame_length * self.sample_rate / 1000) * self.right_context :
|
||||
self.wave_remained = wave
|
||||
return None
|
||||
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
||||
wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension
|
||||
feats = kaldi.fbank(wave_tensor,
|
||||
@ -312,7 +318,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
else:
|
||||
feats_pad = torch.cat((self.feature_remained, feats))
|
||||
|
||||
ctx_frm = feats.shape[0] - self.right_context
|
||||
ctx_frm = feats_pad.shape[0] - (self.right_context+self.right_context)
|
||||
ctx_win = (self.left_context + self.right_context + 1)
|
||||
ctx_dim = feats.shape[1] * ctx_win
|
||||
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
|
||||
@ -320,12 +326,13 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
|
||||
|
||||
# update feature remained, and feats
|
||||
self.feature_remained = feats[-self.left_context:]
|
||||
self.feature_remained = feats[-(self.left_context+self.right_context):]
|
||||
feats = feats_ctx.to(self.device)
|
||||
if self.downsampling > 1:
|
||||
feats = feats[self.feature_context_offset::self.downsampling, :]
|
||||
complement = feats.size(1) % self.downsampling
|
||||
self.feature_context_offset = complement if complement == 0 else self.downsampling-complement
|
||||
last_remainder = 0 if self.feats_ctx_offset==0 else self.downsampling-self.feats_ctx_offset
|
||||
remainder = (feats.size(0)+last_remainder) % self.downsampling
|
||||
feats = feats[self.feats_ctx_offset::self.downsampling, :]
|
||||
self.feats_ctx_offset = remainder if remainder == 0 else self.downsampling-remainder
|
||||
return feats
|
||||
|
||||
def decode_keywords(self, t, probs):
|
||||
@ -396,11 +403,14 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
|
||||
def forward(self, wave_chunk):
|
||||
feature = self.accept_wave(wave_chunk)
|
||||
if feature is None or feature.size(0) < 1:
|
||||
return self.result
|
||||
feature = feature.unsqueeze(0) # add a batch dimension
|
||||
logits, self.in_cache = self.model(feature, self.in_cache)
|
||||
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
||||
probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search
|
||||
for (t, prob) in enumerate(probs):
|
||||
t *= self.downsampling
|
||||
self.decode_keywords(t, prob)
|
||||
self.execute_detection(t)
|
||||
|
||||
@ -409,7 +419,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
# since a chunk include about 30 frames, once activated, we can jump the latter frames.
|
||||
# TODO: there should give another method to update result, avoiding self.result being cleared.
|
||||
break
|
||||
self.total_frames += len(probs) # update frame offset
|
||||
self.total_frames += len(probs) * self.downsampling # update frame offset
|
||||
return self.result
|
||||
|
||||
def reset(self):
|
||||
@ -417,6 +427,15 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
self.activated = False
|
||||
self.hit_score = 1.0
|
||||
|
||||
def reset_all(self):
|
||||
self.reset()
|
||||
self.wave_remained = np.array([])
|
||||
self.feature_remained = None
|
||||
self.feats_ctx_offset = 0 # after downsample, offset exist.
|
||||
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
self.total_frames = 0 # frame offset, for absolute time
|
||||
self.result = {}
|
||||
|
||||
def demo():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
@ -435,22 +454,60 @@ def demo():
|
||||
args.gpu,
|
||||
args.jit_model)
|
||||
|
||||
# actually this could be done in __init__ method, we pull it outside for changing keywords.
|
||||
# actually this could be done in __init__ method, we pull it outside for changing keywords more freely.
|
||||
kws.set_keywords(args.keywords)
|
||||
|
||||
# Caution: input WAV should be standard 16k, 16 bits, 1 channel
|
||||
# In demo we read wave in non-streaming fashion.
|
||||
with wave.open(args.wav_path, 'rb') as fin:
|
||||
assert fin.getnchannels() == 1
|
||||
wav = fin.readframes(fin.getnframes())
|
||||
if args.wav_path:
|
||||
# Caution: input WAV should be standard 16k, 16 bits, 1 channel
|
||||
# In demo we read wave in non-streaming fashion.
|
||||
with wave.open(args.wav_path, 'rb') as fin:
|
||||
assert fin.getnchannels() == 1
|
||||
wav = fin.readframes(fin.getnframes())
|
||||
|
||||
# We inference every 0.3 seconds, in streaming fashion.
|
||||
interval = int(0.3 * 16000) * 2
|
||||
for i in range(0, len(wav), interval):
|
||||
chunk_wav = wav[i: min(i + interval, len(wav))]
|
||||
result = kws.forward(chunk_wav)
|
||||
print(result)
|
||||
# We inference every 0.3 seconds, in streaming fashion.
|
||||
interval = int(0.3 * 16000) * 2
|
||||
for i in range(0, len(wav), interval):
|
||||
chunk_wav = wav[i: min(i + interval, len(wav))]
|
||||
result = kws.forward(chunk_wav)
|
||||
print(result)
|
||||
|
||||
fout = None
|
||||
if args.result_file:
|
||||
fout = open(args.result_file, 'w', encoding='utf-8')
|
||||
|
||||
if args.wav_scp:
|
||||
with open(args.wav_scp, 'r') as fscp:
|
||||
for line in fscp:
|
||||
line = line.strip().split()
|
||||
assert len(line) == 2, f"The scp should be in kaldi format: \"utt_name wav_path\", but got {line}"
|
||||
|
||||
utt_name, wav_path = line[0], line[1]
|
||||
with wave.open(wav_path, 'rb') as fin:
|
||||
assert fin.getnchannels() == 1
|
||||
wav = fin.readframes(fin.getnframes())
|
||||
|
||||
kws.reset_all()
|
||||
activated = False
|
||||
|
||||
# We inference every 0.3 seconds, in streaming fashion.
|
||||
interval = int(0.3 * 16000) * 2
|
||||
for i in range(0, len(wav), interval):
|
||||
chunk_wav = wav[i: min(i + interval, len(wav))]
|
||||
result = kws.forward(chunk_wav)
|
||||
if 'state' in result and result['state'] == 1:
|
||||
activated = True
|
||||
if fout:
|
||||
hit_keyword = result['keyword']
|
||||
hit_score = result['score']
|
||||
fout.write('{} detected {} {:.3f}\n'.format(utt_name, hit_keyword, hit_score))
|
||||
|
||||
if not activated:
|
||||
if fout:
|
||||
fout.write('{} rejected\n'.format(utt_name))
|
||||
|
||||
|
||||
if fout:
|
||||
fout.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
demo()
|
||||
|
||||
@ -44,11 +44,11 @@ def get_args():
|
||||
help='gpu id for this rank, -1 for cpu')
|
||||
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
||||
parser.add_argument('--batch_size',
|
||||
default=16,
|
||||
default=1,
|
||||
type=int,
|
||||
help='batch size for inference')
|
||||
parser.add_argument('--num_workers',
|
||||
default=0,
|
||||
default=1,
|
||||
type=int,
|
||||
help='num of subprocess workers for reading')
|
||||
parser.add_argument('--pin_memory',
|
||||
@ -131,6 +131,8 @@ def main():
|
||||
test_conf['feature_extraction_conf']['dither'] = 0.0
|
||||
test_conf['batch_conf']['batch_size'] = args.batch_size
|
||||
|
||||
downsampling_factor = test_conf.get('frame_skip', 1)
|
||||
|
||||
test_dataset = Dataset(args.test_data, test_conf)
|
||||
test_data_loader = DataLoader(test_dataset,
|
||||
batch_size=None,
|
||||
@ -205,6 +207,7 @@ def main():
|
||||
# 2. CTC beam search step by step
|
||||
for t in range(0, maxlen):
|
||||
probs = ctc_probs[t] # (vocab_size,)
|
||||
t *= downsampling_factor # the real time
|
||||
# key: prefix, value (pb, pnb), default value(-inf, -inf)
|
||||
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
|
||||
|
||||
|
||||
@ -157,8 +157,12 @@ def main():
|
||||
# Try to export the model by script, if fails, we should refine
|
||||
# the code to satisfy the script export requirements
|
||||
if rank == 0:
|
||||
script_model = torch.jit.script(model)
|
||||
script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||
pass
|
||||
# TODO: for now streaming FSMN do not support export to JITScript,
|
||||
# TODO: because there is nn.Sequential with Tuple input in current FSMN modules.
|
||||
# the issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
|
||||
# script_model = torch.jit.script(model)
|
||||
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||
executor = Executor()
|
||||
# If specify checkpoint, load some info from checkpoint
|
||||
if args.checkpoint is not None:
|
||||
|
||||
@ -377,6 +377,8 @@ def add_reverb(data, reverb_source, aug_prob):
|
||||
rir_io = io.BytesIO(rir_data)
|
||||
_, rir_audio = wavfile.read(rir_io)
|
||||
rir_audio = rir_audio.astype(np.float32)
|
||||
if len(rir_audio.shape) > 1:
|
||||
rir_audio = rir_audio[:, 0]
|
||||
rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2))
|
||||
out_audio = signal.convolve(audio, rir_audio,
|
||||
mode='full')[:audio_len]
|
||||
@ -405,6 +407,8 @@ def add_noise(data, noise_source, aug_prob):
|
||||
snr_range = [0, 15]
|
||||
_, noise_audio = wavfile.read(io.BytesIO(noise_data))
|
||||
noise_audio = noise_audio.astype(np.float32)
|
||||
if len(noise_audio.shape) > 1:
|
||||
noise_audio = noise_audio[:, 0]
|
||||
if noise_audio.shape[0] > audio_len:
|
||||
start = random.randint(0, noise_audio.shape[0] - audio_len)
|
||||
noise_audio = noise_audio[start:start + audio_len]
|
||||
|
||||
@ -39,11 +39,12 @@ class LinearTransform(nn.Module):
|
||||
self.quant = torch.quantization.QuantStub()
|
||||
self.dequant = torch.quantization.DeQuantStub()
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self,
|
||||
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||
if isinstance(input, tuple):
|
||||
input, in_cache = input
|
||||
else:
|
||||
in_cache = None
|
||||
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||
output = self.quant(input)
|
||||
output = self.linear(output)
|
||||
output = self.dequant(output)
|
||||
@ -102,11 +103,12 @@ class AffineTransform(nn.Module):
|
||||
self.quant = torch.quantization.QuantStub()
|
||||
self.dequant = torch.quantization.DeQuantStub()
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self,
|
||||
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||
if isinstance(input, tuple):
|
||||
input, in_cache = input
|
||||
else:
|
||||
in_cache = None
|
||||
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||
output = self.quant(input)
|
||||
output = self.linear(output)
|
||||
output = self.dequant(output)
|
||||
@ -213,15 +215,16 @@ class FSMNBlock(nn.Module):
|
||||
self.quant = torch.quantization.QuantStub()
|
||||
self.dequant = torch.quantization.DeQuantStub()
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self,
|
||||
input: Tuple[torch.Tensor, torch.Tensor] ):
|
||||
if isinstance(input, tuple):
|
||||
input, in_cache = input
|
||||
else :
|
||||
in_cache = None
|
||||
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||
x = torch.unsqueeze(input, 1)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
|
||||
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
|
||||
if in_cache is None or len(in_cache) == 0 :
|
||||
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
|
||||
else:
|
||||
in_cache = in_cache.to(x_per.device)
|
||||
@ -339,11 +342,12 @@ class RectifiedLinear(nn.Module):
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self,
|
||||
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||
if isinstance(input, tuple):
|
||||
input, in_cache = input
|
||||
else :
|
||||
in_cache = None
|
||||
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||
out = self.relu(input)
|
||||
# out = self.dropout(out)
|
||||
return (out, in_cache)
|
||||
@ -456,7 +460,7 @@ class FSMN(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
in_cache #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -464,8 +468,8 @@ class FSMN(nn.Module):
|
||||
in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
|
||||
"""
|
||||
|
||||
if in_cache is None or len(in_cache) == 0 or in_cache[0] == None:
|
||||
in_cache = [None for _ in range(len(self.fsmn))]
|
||||
if in_cache is None or len(in_cache) == 0 :
|
||||
in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) for _ in range(len(self.fsmn))]
|
||||
input = (input, in_cache)
|
||||
x1 = self.in_linear1(input)
|
||||
x2 = self.in_linear2(x1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user