fix flake8, update training script and README, give pretrained ckpt.

This commit is contained in:
dujing 2023-07-24 16:39:13 +08:00
parent 45f0522f19
commit ea6a0f5cda
13 changed files with 299 additions and 159 deletions

View File

@ -1,4 +1,6 @@
Comparison among different backbones. FRRs with FAR fixed at once per hour:
Comparison among different backbones,
all models use Max-Pooling loss.
FRRs with FAR fixed at once per hour:
| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
|-----------------------|-----------|-----------|------------|--------------|
@ -9,32 +11,35 @@ Comparison among different backbones. FRRs with FAR fixed at once per hour:
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
Next, we use CTC loss to train the model, with DS_TCN and FSMN.
Next, we use CTC loss to train the model, with DS_TCN and FSMN backbones.
and we use CTC prefix beam search to decode and detect keywords,
the detection is either in non-streaming or streaming fashion.
Since the FAR is pretty low when using CTC loss,
the follow result is FRRs with FAR fixed at once per 12 hours:
the follow results are FRRs with FAR fixed at once per 12 hours:
Comparison between Max-pooling and CTC loss.
The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch).
The CTC model is fine-tuned with base model pretrained on WenetSpeech(23 epoch, not converged).
FRRs with FAR fixed at once per 12 hours
| model | loss | hi_xiaowen | nihao_wenwen |
|-----------------------|-------------|------------|--------------|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 |
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 |
| model | loss | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|------------|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | [dstcn-maxpooling](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn/files) |
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch)
and FSMN(modelscope released, xiaoyunxiaoyun model).
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch, not converged)
and FSMN(Pretained with modelscope released xiaoyunxiaoyun model, fully converged).
FRRs with FAR fixed at once per 12 hours:
| model | params(K) | hi_xiaowen | nihao_wenwen |
|-----------------------|-------------|------------|--------------|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 |
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 |
| model | params(K) | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|-------------------------------------------------------------------------------|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | [fsmn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_fsmn_ctc/files) |
Now, the DSTCN model with CTC loss may not get the best performance, because the
pretraining phase is not sufficiently converged. We recommend you use pretrained
FSMN model as initial checkpoint to train your own model.
Comparison Between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours:
@ -52,6 +57,6 @@ Actually the probability will increase through the time, so we record a lower va
which result in a higher False Rejection Rate in Detection Error Tradeoff result.
The actual FRR will be lower than the DET curve gives in a given threshold.
Now, the model with CTC loss may not get the best performance,
but it's more robust compared with the classification model using CE/Max-pooling loss.
For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).
On some small data KWS tasks, we believe the FSMN-CTC model is more robust
compared with the classification model using CE/Max-pooling loss.
For more infomation and results of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).

View File

@ -194,12 +194,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
--num_workers 8 \
--keywords 嗨小问,你好问问 \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
python wekws/bin/compute_det_ctc.py \
--keywords 嗨小问,你好问问 \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--test_data data/test/data.list \
--window_shift $window_shift \
--step 0.001 \

View File

@ -145,12 +145,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
--num_workers 8 \
--keywords 嗨小问,你好问问 \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
python wekws/bin/compute_det_ctc.py \
--keywords 嗨小问,你好问问 \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--test_data data/test/data.list \
--window_shift $window_shift \
--step 0.001 \
@ -161,14 +161,13 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# NOTE: FSMN now is not support export to jit, beacuse of nn.Sequential with tuple input
# This issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
# jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
# For now, FSMN can not export to JITScript
# python wekws/bin/export_jit.py \
# --config $dir/config.yaml \
# --checkpoint $score_checkpoint \
# --jit_model $dir/$jit_model
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
python wekws/bin/export_onnx.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \

View File

@ -45,7 +45,8 @@ def query_token_set(txt, symbol_table, lexicon_table):
tokens_str = tokens_str + ('!sil', )
elif part == '<blk>' or part == '<blank>':
tokens_str = tokens_str + ('<blk>', )
elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
elif part == '(noise)' or part == 'noise)' or \
part == '(noise' or part == '<noise>':
tokens_str = tokens_str + ('<GBG>', )
elif part in symbol_table:
tokens_str = tokens_str + (part, )
@ -75,11 +76,11 @@ def query_token_set(txt, symbol_table, lexicon_table):
if '<GBG>' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
logging.info(
f'\'{ch}\' is not in token set, replace with <GBG>')
f'{ch} is not in token set, replace with <GBG>')
else:
tokens_idx = tokens_idx + (symbol_table['<blk>'], )
logging.info(
f'\'{ch}\' is not in token set, replace with <blk>')
f'{ch} is not in token set, replace with <blk>')
return tokens_str, tokens_idx
@ -94,7 +95,8 @@ def query_token_list(txt, symbol_table, lexicon_table):
tokens_str.append('!sil')
elif part == '<blk>' or part == '<blank>':
tokens_str.append('<blk>')
elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
elif part == '(noise)' or part == 'noise)' or \
part == '(noise' or part == '<noise>':
tokens_str.append('<GBG>')
elif part in symbol_table:
tokens_str.append(part)
@ -124,11 +126,11 @@ def query_token_list(txt, symbol_table, lexicon_table):
if '<GBG>' in symbol_table:
tokens_idx.append(symbol_table['<GBG>'])
logging.info(
f'\'{ch}\' is not in token set, replace with <GBG>')
f'{ch} is not in token set, replace with <GBG>')
else:
tokens_idx.append(symbol_table['<blk>'])
logging.info(
f'\'{ch}\' is not in token set, replace with <blk>')
f'{ch} is not in token set, replace with <blk>')
return tokens_str, tokens_idx
@ -160,8 +162,10 @@ if __name__ == '__main__':
parser.add_argument('text_file', help='text file')
parser.add_argument('duration_file', help='duration file')
parser.add_argument('output_file', help='output list file')
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
args = parser.parse_args()
wav_table = {}
@ -196,7 +200,9 @@ if __name__ == '__main__':
txt = [1] # the <blank>/sil is indexed by 1
tokens = ["sil"]
else:
tokens, txt = query_token_list(arr[1], token_table, lexicon_table)
tokens, txt = query_token_list(arr[1],
token_table,
lexicon_table)
else:
txt = int(arr[1])
assert key in wav_table
@ -206,7 +212,8 @@ if __name__ == '__main__':
if tokens is None:
line = dict(key=key, txt=txt, duration=duration, wav=wav)
else:
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)
line = dict(key=key, tok=tokens, txt=txt,
duration=duration, wav=wav)
json_line = json.dumps(line, ensure_ascii=False)
fout.write(json_line + '\n')

View File

@ -157,18 +157,23 @@ def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='compute det curve')
parser.add_argument('--test_data', required=True, help='label file')
parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
parser.add_argument('--keywords', type=str, default=None,
help='keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_file', required=True, help='score file')
parser.add_argument('--step', type=float, default=0.01,
help='threshold step')
parser.add_argument('--window_shift', type=int, default=50,
help='window_shift is used to skip the frames after triggered')
help='window_shift is used to '
'skip the frames after triggered')
parser.add_argument('--stats_dir',
required=False,
default=None,
help='false reject/alarm stats dir, default in score_file')
help='false reject/alarm stats dir, '
'default in score_file')
parser.add_argument('--det_curve_path',
required=False,
default=None,
@ -188,7 +193,11 @@ if __name__ == '__main__':
args = parser.parse_args()
window_shift = args.window_shift
keywords_list = args.keywords.strip().split(',')
logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords.strip().split(',')
token_table = read_token(args.token_file)
lexicon_table = read_lexicon(args.lexicon_file)
@ -197,7 +206,8 @@ if __name__ == '__main__':
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
true_keywords[keyword] = ''.join(strs)
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file, true_keywords)
keyword_filler_table = load_label_and_score(
keywords_list, args.test_data, args.score_file, true_keywords)
for keyword in keywords_list:
keyword = true_keywords[keyword]
@ -206,8 +216,10 @@ if __name__ == '__main__':
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
filler_dur = keyword_filler_table[keyword]['filler_duration']
filler_num = len(keyword_filler_table[keyword]['filler_table'])
assert keyword_num > 0, 'Can\'t compute det for {} without positive sample'
assert filler_num > 0, 'Can\'t compute det for {} without negative sample'
assert keyword_num > 0, \
'Can\'t compute det for {} without positive sample'
assert filler_num > 0, \
'Can\'t compute det for {} without negative sample'
logging.info('Computing det for {}'.format(keyword))
logging.info(' Keyword duration: {} Hours, wave number: {}'.format(
@ -218,14 +230,16 @@ if __name__ == '__main__':
stats_dir = args.stats_dir
else:
stats_dir = os.path.dirname(args.score_file)
stats_file = os.path.join(stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
stats_file = os.path.join(
stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
with open(stats_file, 'w', encoding='utf8') as fout:
threshold = 0.0
while threshold <= 1.0:
num_false_reject = 0
num_true_detect = 0
# transverse the all keyword_table
for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
for key, confi in \
keyword_filler_table[keyword]['keyword_table'].items():
if confi < threshold:
num_false_reject += 1
else:
@ -253,4 +267,5 @@ if __name__ == '__main__':
det_curve_path = args.det_curve_path
else:
det_curve_path = os.path.join(stats_dir, 'det.png')
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)
plot_det(stats_dir, det_curve_path,
args.xlim, args.x_step, args.ylim, args.y_step)

View File

@ -42,7 +42,7 @@ def main():
feature_dim = configs['model']['input_dim']
model = init_model(configs['model'])
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
# if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search
# if we use ctc_loss, the logits need to be convert into probs
model.forward = model.forward_softmax
print(model)

View File

@ -65,9 +65,12 @@ def get_args():
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
parser.add_argument('--keywords', type=str, default=None,
help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
args = parser.parse_args()
return args
@ -133,7 +136,9 @@ def main():
lexicon_table = read_lexicon(args.lexicon_file)
# 4. parse keywords tokens
assert args.keywords is not None, 'at least one keyword is needed'
keywords_str = args.keywords
logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {}
keywords_idxset = {0}
@ -167,7 +172,9 @@ def main():
for i in range(len(keys)):
key = keys[i]
score = logits[i][:lengths[i]]
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
hyps = ctc_prefix_beam_search(score,
lengths[i],
keywords_idxset)
hit_keyword = None
hit_score = 1.0
start = 0
@ -192,10 +199,13 @@ def main():
break
if hit_keyword is not None:
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score))
fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score))
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
f"duration {end - start}, score {hit_score}, Activated.")
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"duration {end - start}, "
f"score {hit_score}, Activated.")
else:
fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")

View File

@ -16,7 +16,8 @@ from __future__ import print_function
import argparse
import struct
import wave
#import wave
import librosa
import logging
import os
import math
@ -35,9 +36,12 @@ from tools.make_list import query_token_set, read_lexicon, read_token
def get_args():
parser = argparse.ArgumentParser(description='detect keywords online.')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--wav_path', required=False, default=None, help='test wave path.')
parser.add_argument('--wav_scp', required=False, default=None, help='test wave scp.')
parser.add_argument('--result_file', required=False, default=None, help='test result.')
parser.add_argument('--wav_path', required=False,
default=None, help='test wave path.')
parser.add_argument('--wav_scp', required=False,
default=None, help='test wave scp.')
parser.add_argument('--result_file', required=False,
default=None, help='test result.')
parser.add_argument('--gpu',
type=int,
@ -48,21 +52,27 @@ def get_args():
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
parser.add_argument('--keywords', type=str, default=None,
help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_beam_size',
default=3,
type=int,
help='The first prune beam, filter out those frames with low scores.')
help='The first prune beam, '
'filter out those frames with low scores.')
parser.add_argument('--path_beam_size',
default=20,
type=int,
help='The second prune beam, keep only path_beam_size candidates.')
help='The second prune beam, '
'keep only path_beam_size candidates.')
parser.add_argument('--threshold',
type=float,
default=0.0,
help='The threshold of kws. If ctc_search probs exceed this value,'
help='The threshold of kws. '
'If ctc_search probs exceed this value,'
'the keyword will be activated.')
parser.add_argument('--min_frames',
default=5,
@ -98,16 +108,22 @@ def is_sublist(main_list, check_list):
else:
return -1
def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size):
def ctc_prefix_beam_search(t, probs,
cur_hyps,
keywords_idxset,
score_beam_size):
'''
:param t: the time in frame
:param probs: the probability in t_th frame, (vocab_size, )
:param cur_hyps: list of tuples. [(tuple(), (1.0, 0.0, []))]
in tuple, 1st is prefix id, 2nd include p_blank, p_non_blank, and path nodes list.
in path nodes list, each node is a dict of {token=idx, frame=t, prob=ps}
in tuple, 1st is prefix id, 2nd include p_blank,
p_non_blank, and path nodes list.
in path nodes list, each node is
a dict of {token=idx, frame=t, prob=ps}
:param keywords_idxset: the index of keywords in token.txt
:param score_beam_size: the probability threshold, to filter out those frames with low probs.
:param score_beam_size: the probability threshold,
to filter out those frames with low probs.
:return:
next_hyps: the hypothesis depend on current hyp and current frame.
'''
@ -170,7 +186,8 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size)
if ps > nodes[-1]['prob']: # update frame and prob
# nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t
nodes.pop() # to avoid change other beam which has this node.
nodes.pop()
# to avoid change other beam which has this node.
nodes.append(dict(token=s, frame=t, prob=ps))
else:
nodes = cur_nodes.copy()
@ -199,11 +216,14 @@ class KeyWordSpotter(torch.nn.Module):
# feature related
self.sample_rate = 16000
self.wave_remained = np.array([])
self.num_mel_bins = dataset_conf['feature_extraction_conf']['num_mel_bins']
self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms
self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms
self.num_mel_bins = dataset_conf[
'feature_extraction_conf']['num_mel_bins']
self.frame_length = dataset_conf[
'feature_extraction_conf']['frame_length'] # in ms
self.frame_shift = dataset_conf[
'feature_extraction_conf']['frame_shift'] # in ms
self.downsampling = dataset_conf.get('frame_skip', 1)
self.resolution = self.frame_shift / 1000 # in second
self.resolution = self.frame_shift / 1000 # in second
# fsmn splice operation
self.context_expansion = dataset_conf.get('context_expansion', False)
self.left_context = 0
@ -231,9 +251,11 @@ class KeyWordSpotter(torch.nn.Module):
self.model.eval()
logging.info(f'model {ckpt_path} loaded.')
self.token_table = read_token(token_path)
logging.info(f'tokens {token_path} with {len(self.token_table)} units loaded.')
logging.info(f'tokens {token_path} with '
f'{len(self.token_table)} units loaded.')
self.lexicon_table = read_lexicon(lexicon_path)
logging.info(f'lexicons {lexicon_path} with {len(self.lexicon_table)} units loaded.')
logging.info(f'lexicons {lexicon_path} with '
f'{len(self.lexicon_table)} units loaded.')
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
@ -257,7 +279,9 @@ class KeyWordSpotter(torch.nn.Module):
def set_keywords(self, keywords):
# 4. parse keywords tokens
assert keywords is not None, 'at least one keyword is needed, multiple keywords should be splitted with comma(,)'
assert keywords is not None, \
'at least one keyword is needed, ' \
'multiple keywords should be splitted with comma(,)'
keywords_str = keywords
keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {}
@ -265,7 +289,8 @@ class KeyWordSpotter(torch.nn.Module):
keywords_strset = {'<blk>'}
keywords_tokenmap = {'<blk>': 0}
for keyword in keywords_list:
strs, indexes = query_token_set(keyword, self.token_table, self.lexicon_table)
strs, indexes = query_token_set(
keyword, self.token_table, self.lexicon_table)
keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = indexes
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
@ -284,16 +309,20 @@ class KeyWordSpotter(torch.nn.Module):
self.keywords_token = keywords_token
def accept_wave(self, wave):
assert isinstance(wave, bytes), "please make sure the input format is bytes(raw PCM)"
assert isinstance(wave, bytes), \
"please make sure the input format is bytes(raw PCM)"
# convert bytes into float32
data = []
for i in range(0, len(wave), 2):
value = struct.unpack('<h', wave[i:i + 2])[0]
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
data.append(value)
# here we don't divide 32768.0,
# because kaldi.fbank accept original input
wave = np.array(data)
wave = np.append(self.wave_remained, wave)
if wave.size < (self.frame_length * self.sample_rate / 1000) * self.right_context :
if wave.size < (self.frame_length * self.sample_rate / 1000) \
* self.right_context :
self.wave_remained = wave
return None
wave_tensor = torch.from_numpy(wave).float().to(self.device)
@ -311,29 +340,37 @@ class KeyWordSpotter(torch.nn.Module):
self.wave_remained = wave[feat_len * frame_shift:]
if self.context_expansion:
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
assert feat_len > self.right_context, \
"make sure each chunk feat length is large than right context."
# pad feats with remained feature from last chunk
if self.feature_remained is None: # first chunk
# pad first frame at the beginning, replicate just support last dimension, so we do transpose.
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
# pad first frame at the beginning,
# replicate just support last dimension, so we do transpose.
feats_pad = F.pad(
feats.T, (self.left_context, 0), mode='replicate').T
else:
feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats_pad.shape[0] - (self.right_context+self.right_context)
ctx_frm = feats_pad.shape[0] - \
(self.right_context+self.right_context)
ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
for i in range(ctx_frm):
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
feats_ctx[i] = torch.cat(
tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
# update feature remained, and feats
self.feature_remained = feats[-(self.left_context+self.right_context):]
self.feature_remained = \
feats[-(self.left_context+self.right_context):]
feats = feats_ctx.to(self.device)
if self.downsampling > 1:
last_remainder = 0 if self.feats_ctx_offset==0 else self.downsampling-self.feats_ctx_offset
last_remainder = 0 if self.feats_ctx_offset==0 \
else self.downsampling-self.feats_ctx_offset
remainder = (feats.size(0)+last_remainder) % self.downsampling
feats = feats[self.feats_ctx_offset::self.downsampling, :]
self.feats_ctx_offset = remainder if remainder == 0 else self.downsampling-remainder
self.feats_ctx_offset = remainder \
if remainder == 0 else self.downsampling-remainder
return feats
def decode_keywords(self, t, probs):
@ -344,7 +381,8 @@ class KeyWordSpotter(torch.nn.Module):
self.cur_hyps,
self.keywords_idxset,
self.score_beam)
# update cur_hyps. note: the hyps is sort by path score(pnb+pb), not the keywords' probabilities.
# update cur_hyps. note: the hyps is sort by path score(pnb+pb),
# not the keywords' probabilities.
cur_hyps = next_hyps[:self.path_beam]
self.cur_hyps = cur_hyps
@ -381,27 +419,36 @@ class KeyWordSpotter(torch.nn.Module):
if hit_keyword is not None:
if self.hit_score >= self.threshold and \
self.min_frames <= duration <= self.max_frames \
and (self.last_active_pos==-1 or end-self.last_active_pos >= self.interval_frames):
and (self.last_active_pos==-1 or
end-self.last_active_pos >= self.interval_frames):
self.activated = True
self.last_active_pos = end
logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "
f"duration {duration}, score {self.hit_score}, Activated.")
elif self.last_active_pos>0 and end-self.last_active_pos < self.interval_frames:
elif self.last_active_pos>0 and \
end-self.last_active_pos < self.interval_frames:
logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
f"but interval {end-self.last_active_pos} is lower than {self.interval_frames}, Deactivated. ")
f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "
f"but interval {end-self.last_active_pos} "
f"is lower than {self.interval_frames}, Deactivated. ")
elif self.hit_score < self.threshold:
logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ")
f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "
f"but {self.hit_score} "
f"is lower than {self.threshold}, Deactivated. ")
elif self.min_frames > duration or duration > self.max_frames:
logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
f"but {duration} beyond range({self.min_frames}~{self.max_frames}), Deactivated. ")
f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "
f"but {duration} beyond range"
f"({self.min_frames}~{self.max_frames}), Deactivated. ")
self.result = {
"state": 1 if self.activated else 0,
@ -418,7 +465,7 @@ class KeyWordSpotter(torch.nn.Module):
feature = feature.unsqueeze(0) # add a batch dimension
logits, self.in_cache = self.model(feature, self.in_cache)
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search
probs = probs[0].cpu() # remove batch dimension
for (t, prob) in enumerate(probs):
t *= self.downsampling
self.decode_keywords(t, prob)
@ -426,10 +473,14 @@ class KeyWordSpotter(torch.nn.Module):
if self.activated:
self.reset()
# since a chunk include about 30 frames, once activated, we can jump the latter frames.
# TODO: there should give another method to update result, avoiding self.result being cleared.
# since a chunk include about 30 frames,
# once activated, we can jump the latter frames.
# TODO: there should give another method to update result,
# avoiding self.result being cleared.
break
self.total_frames += len(probs) * self.downsampling # update frame offset
# update frame offset
self.total_frames += len(probs) * self.downsampling
return self.result
def reset(self):
@ -465,15 +516,20 @@ def demo():
args.gpu,
args.jit_model)
# actually this could be done in __init__ method, we pull it outside for changing keywords more freely.
# actually this could be done in __init__ method,
# we pull it outside for changing keywords more freely.
kws.set_keywords(args.keywords)
if args.wav_path:
# Caution: input WAV should be standard 16k, 16 bits, 1 channel
# In demo we read wave in non-streaming fashion.
with wave.open(args.wav_path, 'rb') as fin:
assert fin.getnchannels() == 1
wav = fin.readframes(fin.getnframes())
# with wave.open(args.wav_path, 'rb') as fin:
# assert fin.getnchannels() == 1
# wav = fin.readframes(fin.getnframes())
y, _ = librosa.load(args.wav_path, sr=16000, mono=True)
# NOTE: model supports 16k sample_rate
wav = (y * (1 << 15)).astype("int16").tobytes()
# We inference every 0.3 seconds, in streaming fashion.
interval = int(0.3 * 16000) * 2
@ -490,12 +546,18 @@ def demo():
with open(args.wav_scp, 'r') as fscp:
for line in fscp:
line = line.strip().split()
assert len(line) == 2, f"The scp should be in kaldi format: \"utt_name wav_path\", but got {line}"
assert len(line) == 2, \
f"The scp should be in kaldi format: " \
f"\"utt_name wav_path\", but got {line}"
utt_name, wav_path = line[0], line[1]
with wave.open(wav_path, 'rb') as fin:
assert fin.getnchannels() == 1
wav = fin.readframes(fin.getnframes())
# with wave.open(args.wav_path, 'rb') as fin:
# assert fin.getnchannels() == 1
# wav = fin.readframes(fin.getnframes())
y, _ = librosa.load(args.wav_path, sr=16000, mono=True)
# NOTE: model supports 16k sample_rate
wav = (y * (1 << 15)).astype("int16").tobytes()
kws.reset_all()
activated = False
@ -510,7 +572,8 @@ def demo():
if fout:
hit_keyword = result['keyword']
hit_score = result['score']
fout.write('{} detected {} {:.3f}\n'.format(utt_name, hit_keyword, hit_score))
fout.write('{} detected {} {:.3f}\n'.format(
utt_name, hit_keyword, hit_score))
if not activated:
if fout:

View File

@ -66,30 +66,36 @@ def get_args():
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
parser.add_argument('--keywords', type=str, default=None,
help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_beam_size',
default=3,
type=int,
help='The first prune beam, filter out those frames with low scores.')
help='The first prune beam, f'
'ilter out those frames with low scores.')
parser.add_argument('--path_beam_size',
default=20,
type=int,
help='The second prune beam, keep only path_beam_size candidates.')
help='The second prune beam, '
'keep only path_beam_size candidates.')
parser.add_argument('--threshold',
type=float,
default=0.0,
help='The threshold of kws. If ctc_search probs exceed this value,'
help='The threshold of kws. '
'If ctc_search probs exceed this value,'
'the keyword will be activated.')
parser.add_argument('--min_frames',
default=5,
type=int,
help='The min frames of keyword\'s duration.')
help='The min frames of keyword duration.')
parser.add_argument('--max_frames',
default=250,
type=int,
help='The max frames of keyword\'s duration.')
help='The max frames of keyword duration.')
args = parser.parse_args()
return args
@ -158,7 +164,9 @@ def main():
lexicon_table = read_lexicon(args.lexicon_file)
# 4. parse keywords tokens
assert args.keywords is not None, 'at least one keyword is needed'
keywords_str = args.keywords
logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {}
keywords_idxset = {0}
@ -217,7 +225,8 @@ def main():
# filter prob score that is too small
filter_probs = []
filter_index = []
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
for prob, idx in zip(
top_k_probs.tolist(), top_k_index.tolist()):
if keywords_idxset is not None:
if prob > 0.05 and idx in keywords_idxset:
filter_probs.append(prob)
@ -246,7 +255,8 @@ def main():
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pnb = n_pnb + pnb * ps
nodes = cur_nodes.copy()
if ps > nodes[-1]['prob']: # update frame and prob
# update frame and prob
if ps > nodes[-1]['prob']:
nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t
next_hyps[prefix] = (n_pb, n_pnb, nodes)
@ -257,32 +267,37 @@ def main():
n_pb, n_pnb, nodes = next_hyps[n_prefix]
n_pnb = n_pnb + pb * ps
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
nodes.append(dict(
token=s, frame=t, prob=ps))
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
else:
n_prefix = prefix + (s,)
n_pb, n_pnb, nodes = next_hyps[n_prefix]
if nodes:
if ps > nodes[-1]['prob']: # update frame and prob
# update frame and prob
if ps > nodes[-1]['prob']:
# nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t
nodes.pop() # to avoid change other beam which has this node.
nodes.append(dict(token=s, frame=t, prob=ps))
# avoid change other beam has this node.
nodes.pop()
nodes.append(dict(
token=s, frame=t, prob=ps))
else:
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
nodes.append(dict(
token=s, frame=t, prob=ps))
n_pnb = n_pnb + pb * ps + pnb * ps
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)
next_hyps.items(),
key=lambda x: (x[1][0] + x[1][1]), reverse=True)
cur_hyps = next_hyps[:args.path_beam_size]
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
hyps = [(y[0], y[1][0] + y[1][1], y[1][2])
for y in cur_hyps]
for one_hyp in hyps:
prefix_ids = one_hyp[0]
@ -295,7 +310,8 @@ def main():
if offset != -1:
hit_keyword = word
start = prefix_nodes[offset]['frame']
end = prefix_nodes[offset + len(lab) - 1]['frame']
end = prefix_nodes[
offset + len(lab) - 1]['frame']
for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob']
break
@ -305,25 +321,35 @@ def main():
duration = end - start
if hit_keyword is not None:
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames:
if hit_score >= args.threshold and \
args.min_frames <= duration <= args.max_frames:
activated = True
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score))
fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score))
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
f"duration {duration}, score {hit_score} Activated.")
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"duration {duration}, s"
f"core {hit_score} Activated.")
# clear the ctc_prefix buffer, and clear hit_keyword
# clear the ctc_prefix buffer, and hit_keyword
cur_hyps = [(tuple(), (1.0, 0.0, []))]
hit_keyword = None
hit_score = 1.0
elif hit_score < args.threshold:
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
f"but {hit_score} less than {args.threshold}, Deactivated. ")
elif args.min_frames > duration or duration > args.max_frames:
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"but {hit_score} less than "
f"{args.threshold}, Deactivated. ")
elif args.min_frames > duration \
or duration > args.max_frames:
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
f"but {duration} beyond range({args.min_frames}~{args.max_frames}), Deactivated. ")
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"but {duration} beyond "
f"range({args.min_frames}~{args.max_frames}), "
f"Deactivated. ")
if not activated:
fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")

View File

@ -159,8 +159,12 @@ def main():
if rank == 0:
pass
# TODO: for now streaming FSMN do not support export to JITScript,
# TODO: because there is nn.Sequential with Tuple input in current FSMN modules.
# the issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
# TODO: because there is nn.Sequential with Tuple input
# in current FSMN modules.
# the issue is in https://stackoverflow.com/questions/75714299/
# pytorch-jit-script-error-when-sequential-container-
# takes-a-tuple-input/76553450#76553450
# script_model = torch.jit.script(model)
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor()

View File

@ -225,16 +225,19 @@ class FSMNBlock(nn.Module):
x_per = x.permute(0, 3, 2, 1)
if in_cache is None or len(in_cache) == 0 :
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride
+ self.rorder * self.rstride, 0])
else:
in_cache = in_cache.to(x_per.device)
x_pad = torch.cat((in_cache, x_per), dim=2)
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :]
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride
+ self.rorder * self.rstride):, :]
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
y_left = self.quant(y_left)
y_left = self.conv_left(y_left)
y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
out = x_pad[:, :, (self.lorder - 1) * self.lstride:
-self.rorder * self.rstride, :] + y_left
if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
@ -253,7 +256,8 @@ class FSMNBlock(nn.Module):
def to_kaldi_net(self):
re_str = ''
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \
'<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride)
# print(self.conv_left.weight,self.conv_right.weight)
@ -441,7 +445,8 @@ class FSMN(nn.Module):
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
self.padding = (self.lorder-1) * self.lstride + self.rorder * self.rstride
self.padding = (self.lorder-1) * self.lstride \
+ self.rorder * self.rstride
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
@ -469,7 +474,8 @@ class FSMN(nn.Module):
"""
if in_cache is None or len(in_cache) == 0 :
in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) for _ in range(len(self.fsmn))]
in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float)
for _ in range(len(self.fsmn))]
input = (input, in_cache)
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)

View File

@ -34,7 +34,7 @@ class KWSModel(nn.Module):
"""Our model consists of four parts:
1. global_cmvn: Optional, (idim, idim)
2. preprocessing: feature dimention projection, (idim, hdim)
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
3. backbone: backbone of the whole network, (hdim, hdim)
4. classifier: output layer or classifier of KWS model, (hdim, odim)
5. activation:
nn.Sigmoid for wakeup word
@ -76,7 +76,8 @@ class KWSModel(nn.Module):
def forward_softmax(self,
x: torch.Tensor,
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
in_cache: torch.Tensor = torch.zeros(
0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None:
x = self.global_cmvn(x)
@ -196,7 +197,8 @@ def init_model(configs):
classifier = LinearClassifier(hidden_dim, output_dim)
activation = nn.Sigmoid()
# Here we add a possible "activation_type", one can choose to use other activation function.
# Here we add a possible "activation_type",
# one can choose to use other activation function.
# We use nn.Identity just for CTC loss
if "activation" in configs:
activation_type = configs["activation"]["type"]

View File

@ -188,7 +188,8 @@ def criterion(type: str,
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
return loss, acc
elif type == 'ctc':
loss, acc = ctc_loss(logits, target, lengths, target_lengths, validation)
loss, acc = ctc_loss(
logits, target, lengths, target_lengths, validation)
return loss, acc
else:
exit(1)
@ -281,7 +282,8 @@ def ctc_prefix_beam_search(
if ps > nodes[-1]['prob']: # update frame and prob
# nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t
nodes.pop() # to avoid change other beam which has this node.
# avoid change other beam which has this node.
nodes.pop()
nodes.append(dict(token=s, frame=t, prob=ps))
else:
nodes = cur_nodes.copy()
@ -429,7 +431,8 @@ class Calculator:
break
else: # shouldn't reach here
print(
'this should not happen , i = {i} , j = {j} , error = {error}'
'this should not happen, '
'i = {i} , j = {j} , error = {error}'
.format(i=i, j=j, error=self.space[i][j]['error']))
return result