Fix flake8 errors, update the training script and README, and provide pretrained checkpoints.

dujing 2023-07-24 16:39:13 +08:00
parent 45f0522f19
commit ea6a0f5cda
13 changed files with 299 additions and 159 deletions

View File

@ -1,4 +1,6 @@
Comparison among different backbones. FRRs with FAR fixed at once per hour: Comparison among different backbones,
all models use Max-Pooling loss.
FRRs with FAR fixed at once per hour:
| model | params(K) | epoch | hi_xiaowen | nihao_wenwen | | model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
|-----------------------|-----------|-----------|------------|--------------| |-----------------------|-----------|-----------|------------|--------------|
@ -9,32 +11,35 @@ Comparison among different backbones. FRRs with FAR fixed at once per hour:
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 | | MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 | | MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
Next, we use CTC loss to train the model, with DS_TCN and FSMN. Next, we use CTC loss to train the model, with DS_TCN and FSMN backbones.
and we use CTC prefix beam search to decode and detect keywords, and we use CTC prefix beam search to decode and detect keywords,
the detection is either in non-streaming or streaming fashion. the detection is either in non-streaming or streaming fashion.
Since the FAR is pretty low when using CTC loss, Since the FAR is pretty low when using CTC loss,
the follow result is FRRs with FAR fixed at once per 12 hours: the following results are FRRs with FAR fixed at once per 12 hours:
Comparison between Max-pooling and CTC loss. Comparison between Max-pooling and CTC loss.
The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch). The CTC model is fine-tuned with base model pretrained on WenetSpeech(23 epoch, not converged).
FRRs with FAR fixed at once per 12 hours FRRs with FAR fixed at once per 12 hours
| model | loss | hi_xiaowen | nihao_wenwen | model ckpt |
| model | loss | hi_xiaowen | nihao_wenwen | |-----------------------|-------------|------------|--------------|------------|
|-----------------------|-------------|------------|--------------| | DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | [dstcn-maxpooling](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn/files) |
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | | DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 |
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch) Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch, not converged)
and FSMN(modelscope released, xiaoyunxiaoyun model). and FSMN(Pretrained with modelscope released xiaoyunxiaoyun model, fully converged).
FRRs with FAR fixed at once per 12 hours: FRRs with FAR fixed at once per 12 hours:
| model | params(K) | hi_xiaowen | nihao_wenwen | | model | params(K) | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------| |-----------------------|-------------|------------|--------------|-------------------------------------------------------------------------------|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | | DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | | FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | [fsmn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_fsmn_ctc/files) |
At present, the DS_TCN model trained with CTC loss may not reach its best performance,
because its pretraining phase had not sufficiently converged. We recommend using the
pretrained FSMN model as the initial checkpoint when training your own model.
Comparison Between stream_score_ctc and score_ctc. Comparison Between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours: FRRs with FAR fixed at once per 12 hours:
@ -52,6 +57,6 @@ Actually the probability will increase through the time, so we record a lower va
which result in a higher False Rejection Rate in Detection Error Tradeoff result. which result in a higher False Rejection Rate in Detection Error Tradeoff result.
The actual FRR will be lower than the DET curve gives in a given threshold. The actual FRR will be lower than the DET curve gives in a given threshold.
Now, the model with CTC loss may not get the best performance, On some small data KWS tasks, we believe the FSMN-CTC model is more robust
but it's more robust compared with the classification model using CE/Max-pooling loss. compared with the classification model using CE/Max-pooling loss.
For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary). For more information and results of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).
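The README tables report the FRR at a fixed FAR budget (once per hour, or once per 12 hours). As a rough illustration of how such an operating point can be read off a set of DET points, here is a minimal sketch. The helper name `frr_at_fixed_far` and the numbers are hypothetical and are not taken from the repository or from the results above.

```python
from typing import List, Optional, Tuple

# (threshold, false alarms per hour, false rejection rate)
DetPoint = Tuple[float, float, float]


def frr_at_fixed_far(points: List[DetPoint],
                     max_far_per_hour: float) -> Optional[DetPoint]:
    """Return the point with the lowest FRR whose FAR fits the budget."""
    allowed = [p for p in points if p[1] <= max_far_per_hour]
    if not allowed:
        return None  # no threshold satisfies the false-alarm budget
    return min(allowed, key=lambda p: p[2])


# Made-up DET points, for illustration only.
points = [
    (0.1, 2.0, 0.01),
    (0.5, 0.5, 0.03),
    (0.9, 0.05, 0.08),
]

once_per_hour = frr_at_fixed_far(points, 1.0)
once_per_12h = frr_at_fixed_far(points, 1.0 / 12)
print(once_per_hour, once_per_12h)
```

A tighter budget (once per 12 hours) forces a higher threshold, which is why the FRRs in the CTC tables are not directly comparable with the once-per-hour numbers in the first table.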

View File

@ -194,12 +194,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--checkpoint $score_checkpoint \ --checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \ --score_file $result_dir/score.txt \
--num_workers 8 \ --num_workers 8 \
--keywords 嗨小问,你好问问 \ --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--token_file data/tokens.txt \ --token_file data/tokens.txt \
--lexicon_file data/lexicon.txt --lexicon_file data/lexicon.txt
python wekws/bin/compute_det_ctc.py \ python wekws/bin/compute_det_ctc.py \
--keywords 嗨小问,你好问问 \ --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--test_data data/test/data.list \ --test_data data/test/data.list \
--window_shift $window_shift \ --window_shift $window_shift \
--step 0.001 \ --step 0.001 \
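The scripts now pass the keywords as `\uXXXX` escapes, and the Python side restores them with `encode('utf-8').decode('unicode_escape')`. The sketch below shows that round trip in isolation. `decode_keywords` mirrors the expression used in this commit, while `decode_keywords_safe` is a hypothetical variant that is not part of the commit and only illustrates one way to also accept raw Chinese input.

```python
from typing import List


def decode_keywords(arg: str) -> List[str]:
    # Same expression as in the commit: turn literal "\u55e8" into the character.
    decoded = arg.encode('utf-8').decode('unicode_escape')
    return decoded.strip().split(',')


def decode_keywords_safe(arg: str) -> List[str]:
    # Hypothetical guard: only unescape when the argument is pure ASCII,
    # so raw non-ASCII keywords pass through unchanged.
    if arg.isascii():
        arg = arg.encode('ascii').decode('unicode_escape')
    return arg.strip().split(',')


escaped = r"\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee"
keywords = decode_keywords(escaped)
print(keywords)

# Raw Chinese is NOT preserved by the plain expression: unicode_escape
# reads the UTF-8 bytes as Latin-1, which yields mojibake.
mangled = decode_keywords("嗨小问")
preserved = decode_keywords_safe("嗨小问")
```

With the plain expression, callers should pass escaped keywords only, as the updated run scripts do.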

View File

@ -145,12 +145,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--checkpoint $score_checkpoint \ --checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \ --score_file $result_dir/score.txt \
--num_workers 8 \ --num_workers 8 \
--keywords 嗨小问,你好问问 \ --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--token_file data/tokens.txt \ --token_file data/tokens.txt \
--lexicon_file data/lexicon.txt --lexicon_file data/lexicon.txt
python wekws/bin/compute_det_ctc.py \ python wekws/bin/compute_det_ctc.py \
--keywords 嗨小问,你好问问 \ --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--test_data data/test/data.list \ --test_data data/test/data.list \
--window_shift $window_shift \ --window_shift $window_shift \
--step 0.001 \ --step 0.001 \
@ -161,14 +161,13 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# NOTE: FSMN now is not support export to jit, beacuse of nn.Sequential with tuple input jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
# This issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450 onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
# jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') # For now, FSMN can not export to JITScript
# python wekws/bin/export_jit.py \ # python wekws/bin/export_jit.py \
# --config $dir/config.yaml \ # --config $dir/config.yaml \
# --checkpoint $score_checkpoint \ # --checkpoint $score_checkpoint \
# --jit_model $dir/$jit_model # --jit_model $dir/$jit_model
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
python wekws/bin/export_onnx.py \ python wekws/bin/export_onnx.py \
--config $dir/config.yaml \ --config $dir/config.yaml \
--checkpoint $score_checkpoint \ --checkpoint $score_checkpoint \

View File

@ -45,7 +45,8 @@ def query_token_set(txt, symbol_table, lexicon_table):
tokens_str = tokens_str + ('!sil', ) tokens_str = tokens_str + ('!sil', )
elif part == '<blk>' or part == '<blank>': elif part == '<blk>' or part == '<blank>':
tokens_str = tokens_str + ('<blk>', ) tokens_str = tokens_str + ('<blk>', )
elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>': elif part == '(noise)' or part == 'noise)' or \
part == '(noise' or part == '<noise>':
tokens_str = tokens_str + ('<GBG>', ) tokens_str = tokens_str + ('<GBG>', )
elif part in symbol_table: elif part in symbol_table:
tokens_str = tokens_str + (part, ) tokens_str = tokens_str + (part, )
@ -75,11 +76,11 @@ def query_token_set(txt, symbol_table, lexicon_table):
if '<GBG>' in symbol_table: if '<GBG>' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['<GBG>'], ) tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
logging.info( logging.info(
f'\'{ch}\' is not in token set, replace with <GBG>') f'{ch} is not in token set, replace with <GBG>')
else: else:
tokens_idx = tokens_idx + (symbol_table['<blk>'], ) tokens_idx = tokens_idx + (symbol_table['<blk>'], )
logging.info( logging.info(
f'\'{ch}\' is not in token set, replace with <blk>') f'{ch} is not in token set, replace with <blk>')
return tokens_str, tokens_idx return tokens_str, tokens_idx
@ -94,7 +95,8 @@ def query_token_list(txt, symbol_table, lexicon_table):
tokens_str.append('!sil') tokens_str.append('!sil')
elif part == '<blk>' or part == '<blank>': elif part == '<blk>' or part == '<blank>':
tokens_str.append('<blk>') tokens_str.append('<blk>')
elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>': elif part == '(noise)' or part == 'noise)' or \
part == '(noise' or part == '<noise>':
tokens_str.append('<GBG>') tokens_str.append('<GBG>')
elif part in symbol_table: elif part in symbol_table:
tokens_str.append(part) tokens_str.append(part)
@ -124,11 +126,11 @@ def query_token_list(txt, symbol_table, lexicon_table):
if '<GBG>' in symbol_table: if '<GBG>' in symbol_table:
tokens_idx.append(symbol_table['<GBG>']) tokens_idx.append(symbol_table['<GBG>'])
logging.info( logging.info(
f'\'{ch}\' is not in token set, replace with <GBG>') f'{ch} is not in token set, replace with <GBG>')
else: else:
tokens_idx.append(symbol_table['<blk>']) tokens_idx.append(symbol_table['<blk>'])
logging.info( logging.info(
f'\'{ch}\' is not in token set, replace with <blk>') f'{ch} is not in token set, replace with <blk>')
return tokens_str, tokens_idx return tokens_str, tokens_idx
@ -160,8 +162,10 @@ if __name__ == '__main__':
parser.add_argument('text_file', help='text file') parser.add_argument('text_file', help='text file')
parser.add_argument('duration_file', help='duration file') parser.add_argument('duration_file', help='duration file')
parser.add_argument('output_file', help='output list file') parser.add_argument('output_file', help='output list file')
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') parser.add_argument('--token_file', type=str, default=None,
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
args = parser.parse_args() args = parser.parse_args()
wav_table = {} wav_table = {}
@ -196,7 +200,9 @@ if __name__ == '__main__':
txt = [1] # the <blank>/sil is indexed by 1 txt = [1] # the <blank>/sil is indexed by 1
tokens = ["sil"] tokens = ["sil"]
else: else:
tokens, txt = query_token_list(arr[1], token_table, lexicon_table) tokens, txt = query_token_list(arr[1],
token_table,
lexicon_table)
else: else:
txt = int(arr[1]) txt = int(arr[1])
assert key in wav_table assert key in wav_table
@ -206,7 +212,8 @@ if __name__ == '__main__':
if tokens is None: if tokens is None:
line = dict(key=key, txt=txt, duration=duration, wav=wav) line = dict(key=key, txt=txt, duration=duration, wav=wav)
else: else:
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav) line = dict(key=key, tok=tokens, txt=txt,
duration=duration, wav=wav)
json_line = json.dumps(line, ensure_ascii=False) json_line = json.dumps(line, ensure_ascii=False)
fout.write(json_line + '\n') fout.write(json_line + '\n')
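For reference, the record written per utterance is a single JSON line. The following sketch reproduces only the serialization step with made-up values (the key, token ids, duration and path are placeholders, not real data) to show what `ensure_ascii=False` changes.

```python
import json

# Placeholder record; field names follow the dict(...) call in the diff.
line = dict(key='utt_001',
            tok=['嗨', '小', '问'],
            txt=[10, 11, 12],          # hypothetical token ids
            duration=1.5,
            wav='/path/to/utt_001.wav')

json_line = json.dumps(line, ensure_ascii=False)
escaped_line = json.dumps(line)  # default: non-ASCII becomes \uXXXX

print(json_line)
restored = json.loads(json_line)
```

Both forms load back to the same dictionary; `ensure_ascii=False` only keeps the list file human-readable, which is why the output file should be opened with a UTF-8 encoding.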

View File

@ -157,18 +157,23 @@ def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='compute det curve') parser = argparse.ArgumentParser(description='compute det curve')
parser.add_argument('--test_data', required=True, help='label file') parser.add_argument('--test_data', required=True, help='label file')
parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)') parser.add_argument('--keywords', type=str, default=None,
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') help='keywords, split with comma(,)')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_file', required=True, help='score file') parser.add_argument('--score_file', required=True, help='score file')
parser.add_argument('--step', type=float, default=0.01, parser.add_argument('--step', type=float, default=0.01,
help='threshold step') help='threshold step')
parser.add_argument('--window_shift', type=int, default=50, parser.add_argument('--window_shift', type=int, default=50,
help='window_shift is used to skip the frames after triggered') help='window_shift is used to '
'skip the frames after triggered')
parser.add_argument('--stats_dir', parser.add_argument('--stats_dir',
required=False, required=False,
default=None, default=None,
help='false reject/alarm stats dir, default in score_file') help='false reject/alarm stats dir, '
'default in score_file')
parser.add_argument('--det_curve_path', parser.add_argument('--det_curve_path',
required=False, required=False,
default=None, default=None,
@ -188,7 +193,11 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
window_shift = args.window_shift window_shift = args.window_shift
keywords_list = args.keywords.strip().split(',') logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords.strip().split(',')
token_table = read_token(args.token_file) token_table = read_token(args.token_file)
lexicon_table = read_lexicon(args.lexicon_file) lexicon_table = read_lexicon(args.lexicon_file)
@ -197,7 +206,8 @@ if __name__ == '__main__':
strs, indexes = query_token_set(keyword, token_table, lexicon_table) strs, indexes = query_token_set(keyword, token_table, lexicon_table)
true_keywords[keyword] = ''.join(strs) true_keywords[keyword] = ''.join(strs)
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file, true_keywords) keyword_filler_table = load_label_and_score(
keywords_list, args.test_data, args.score_file, true_keywords)
for keyword in keywords_list: for keyword in keywords_list:
keyword = true_keywords[keyword] keyword = true_keywords[keyword]
@ -206,8 +216,10 @@ if __name__ == '__main__':
keyword_num = len(keyword_filler_table[keyword]['keyword_table']) keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
filler_dur = keyword_filler_table[keyword]['filler_duration'] filler_dur = keyword_filler_table[keyword]['filler_duration']
filler_num = len(keyword_filler_table[keyword]['filler_table']) filler_num = len(keyword_filler_table[keyword]['filler_table'])
assert keyword_num > 0, 'Can\'t compute det for {} without positive sample' assert keyword_num > 0, \
assert filler_num > 0, 'Can\'t compute det for {} without negative sample' 'Can\'t compute det for {} without positive sample'
assert filler_num > 0, \
'Can\'t compute det for {} without negative sample'
logging.info('Computing det for {}'.format(keyword)) logging.info('Computing det for {}'.format(keyword))
logging.info(' Keyword duration: {} Hours, wave number: {}'.format( logging.info(' Keyword duration: {} Hours, wave number: {}'.format(
@ -218,14 +230,16 @@ if __name__ == '__main__':
stats_dir = args.stats_dir stats_dir = args.stats_dir
else: else:
stats_dir = os.path.dirname(args.score_file) stats_dir = os.path.dirname(args.score_file)
stats_file = os.path.join(stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt') stats_file = os.path.join(
stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
with open(stats_file, 'w', encoding='utf8') as fout: with open(stats_file, 'w', encoding='utf8') as fout:
threshold = 0.0 threshold = 0.0
while threshold <= 1.0: while threshold <= 1.0:
num_false_reject = 0 num_false_reject = 0
num_true_detect = 0 num_true_detect = 0
# traverse the whole keyword_table # traverse the whole keyword_table
for key, confi in keyword_filler_table[keyword]['keyword_table'].items(): for key, confi in \
keyword_filler_table[keyword]['keyword_table'].items():
if confi < threshold: if confi < threshold:
num_false_reject += 1 num_false_reject += 1
else: else:
@ -253,4 +267,5 @@ if __name__ == '__main__':
det_curve_path = args.det_curve_path det_curve_path = args.det_curve_path
else: else:
det_curve_path = os.path.join(stats_dir, 'det.png') det_curve_path = os.path.join(stats_dir, 'det.png')
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step) plot_det(stats_dir, det_curve_path,
args.xlim, args.x_step, args.ylim, args.y_step)
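The threshold sweep in this script can be summarized with a simplified sketch. It assumes one confidence score per utterance and ignores `window_shift`, so it is not the script's exact logic; `sweep_det` and the sample scores are hypothetical.

```python
from typing import List, Tuple


def sweep_det(keyword_scores: List[float],
              filler_scores: List[float],
              filler_hours: float,
              step: float = 0.1) -> List[Tuple[float, float, float]]:
    """Return (threshold, false alarms per hour, FRR) for each threshold."""
    rows = []
    n_steps = int(round(1.0 / step))
    for k in range(n_steps + 1):
        # Index-based thresholds avoid float drift from repeated addition.
        threshold = k * step
        num_false_reject = sum(1 for s in keyword_scores if s < threshold)
        num_false_alarm = sum(1 for s in filler_scores if s >= threshold)
        frr = num_false_reject / len(keyword_scores)
        far_per_hour = num_false_alarm / filler_hours
        rows.append((threshold, far_per_hour, frr))
    return rows


# Made-up scores, for illustration only.
rows = sweep_det(keyword_scores=[0.9, 0.8, 0.35, 0.95],
                 filler_scores=[0.1, 0.2, 0.05, 0.6],
                 filler_hours=2.0,
                 step=0.5)
for row in rows:
    print(row)
```

Raising the threshold trades false alarms for false rejections, which is the curve that `plot_det` draws from the stats files.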

View File

@ -42,7 +42,7 @@ def main():
feature_dim = configs['model']['input_dim'] feature_dim = configs['model']['input_dim']
model = init_model(configs['model']) model = init_model(configs['model'])
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc': if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
# if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search # if we use ctc_loss, the logits need to be converted into probs
model.forward = model.forward_softmax model.forward = model.forward_softmax
print(model) print(model)

View File

@ -65,9 +65,12 @@ def get_args():
action='store_true', action='store_true',
default=False, default=False,
help='Use pinned memory buffers used for reading') help='Use pinned memory buffers used for reading')
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)') parser.add_argument('--keywords', type=str, default=None,
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') help='the keywords, split with comma(,)')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
args = parser.parse_args() args = parser.parse_args()
return args return args
@ -133,7 +136,9 @@ def main():
lexicon_table = read_lexicon(args.lexicon_file) lexicon_table = read_lexicon(args.lexicon_file)
# 4. parse keywords tokens # 4. parse keywords tokens
assert args.keywords is not None, 'at least one keyword is needed' assert args.keywords is not None, 'at least one keyword is needed'
keywords_str = args.keywords logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords_str.strip().replace(' ', '').split(',') keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {} keywords_token = {}
keywords_idxset = {0} keywords_idxset = {0}
@ -167,7 +172,9 @@ def main():
for i in range(len(keys)): for i in range(len(keys)):
key = keys[i] key = keys[i]
score = logits[i][:lengths[i]] score = logits[i][:lengths[i]]
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset) hyps = ctc_prefix_beam_search(score,
lengths[i],
keywords_idxset)
hit_keyword = None hit_keyword = None
hit_score = 1.0 hit_score = 1.0
start = 0 start = 0
@ -192,10 +199,13 @@ def main():
break break
if hit_keyword is not None: if hit_keyword is not None:
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score)) fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score))
logging.info( logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"duration {end - start}, score {hit_score}, Activated.") f"in {key} from {start} to {end} frame. "
f"duration {end - start}, "
f"score {hit_score}, Activated.")
else: else:
fout.write('{} rejected\n'.format(key)) fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.") logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
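The score file written here has two line shapes, `<key> detected <keyword> <score>` and `<key> rejected`. A minimal reader for that format might look like the sketch below; `parse_score_line` is a hypothetical helper and not the parser used by `compute_det_ctc.py`.

```python
from typing import Optional, Tuple


def parse_score_line(line: str) -> Tuple[str, Optional[str], Optional[float]]:
    """Parse one score-file line into (key, keyword, score)."""
    arr = line.strip().split()
    key = arr[0]
    if len(arr) >= 4 and arr[1] == 'detected':
        # Join the middle fields in case the keyword itself contains spaces.
        keyword = ' '.join(arr[2:-1])
        return key, keyword, float(arr[-1])
    return key, None, None


hit = parse_score_line('utt1 detected 嗨小问 0.987\n')
miss = parse_score_line('utt2 rejected\n')
print(hit, miss)
```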

View File

@ -16,7 +16,8 @@ from __future__ import print_function
import argparse import argparse
import struct import struct
import wave #import wave
import librosa
import logging import logging
import os import os
import math import math
@ -35,9 +36,12 @@ from tools.make_list import query_token_set, read_lexicon, read_token
def get_args(): def get_args():
parser = argparse.ArgumentParser(description='detect keywords online.') parser = argparse.ArgumentParser(description='detect keywords online.')
parser.add_argument('--config', required=True, help='config file') parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--wav_path', required=False, default=None, help='test wave path.') parser.add_argument('--wav_path', required=False,
parser.add_argument('--wav_scp', required=False, default=None, help='test wave scp.') default=None, help='test wave path.')
parser.add_argument('--result_file', required=False, default=None, help='test result.') parser.add_argument('--wav_scp', required=False,
default=None, help='test wave scp.')
parser.add_argument('--result_file', required=False,
default=None, help='test result.')
parser.add_argument('--gpu', parser.add_argument('--gpu',
type=int, type=int,
@ -48,21 +52,27 @@ def get_args():
action='store_true', action='store_true',
default=False, default=False,
help='Use pinned memory buffers used for reading') help='Use pinned memory buffers used for reading')
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)') parser.add_argument('--keywords', type=str, default=None,
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') help='the keywords, split with comma(,)')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_beam_size', parser.add_argument('--score_beam_size',
default=3, default=3,
type=int, type=int,
help='The first prune beam, filter out those frames with low scores.') help='The first prune beam, '
'filter out those frames with low scores.')
parser.add_argument('--path_beam_size', parser.add_argument('--path_beam_size',
default=20, default=20,
type=int, type=int,
help='The second prune beam, keep only path_beam_size candidates.') help='The second prune beam, '
'keep only path_beam_size candidates.')
parser.add_argument('--threshold', parser.add_argument('--threshold',
type=float, type=float,
default=0.0, default=0.0,
help='The threshold of kws. If ctc_search probs exceed this value,' help='The threshold of kws. '
'If ctc_search probs exceed this value,'
'the keyword will be activated.') 'the keyword will be activated.')
parser.add_argument('--min_frames', parser.add_argument('--min_frames',
default=5, default=5,
@ -98,16 +108,22 @@ def is_sublist(main_list, check_list):
else: else:
return -1 return -1
def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size): def ctc_prefix_beam_search(t, probs,
cur_hyps,
keywords_idxset,
score_beam_size):
''' '''
:param t: the time in frame :param t: the time in frame
:param probs: the probability in t_th frame, (vocab_size, ) :param probs: the probability in t_th frame, (vocab_size, )
:param cur_hyps: list of tuples. [(tuple(), (1.0, 0.0, []))] :param cur_hyps: list of tuples. [(tuple(), (1.0, 0.0, []))]
in tuple, 1st is prefix id, 2nd include p_blank, p_non_blank, and path nodes list. in tuple, 1st is prefix id, 2nd include p_blank,
in path nodes list, each node is a dict of {token=idx, frame=t, prob=ps} p_non_blank, and path nodes list.
in path nodes list, each node is
a dict of {token=idx, frame=t, prob=ps}
:param keywords_idxset: the index of keywords in token.txt :param keywords_idxset: the index of keywords in token.txt
:param score_beam_size: the probability threshold, to filter out those frames with low probs. :param score_beam_size: the probability threshold,
to filter out those frames with low probs.
:return: :return:
next_hyps: the hypothesis depend on current hyp and current frame. next_hyps: the hypothesis depend on current hyp and current frame.
''' '''
@ -170,7 +186,8 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size)
if ps > nodes[-1]['prob']: # update frame and prob if ps > nodes[-1]['prob']: # update frame and prob
# nodes[-1]['prob'] = ps # nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t # nodes[-1]['frame'] = t
nodes.pop() # to avoid change other beam which has this node. nodes.pop()
# to avoid change other beam which has this node.
nodes.append(dict(token=s, frame=t, prob=ps)) nodes.append(dict(token=s, frame=t, prob=ps))
else: else:
nodes = cur_nodes.copy() nodes = cur_nodes.copy()
@ -199,11 +216,14 @@ class KeyWordSpotter(torch.nn.Module):
# feature related # feature related
self.sample_rate = 16000 self.sample_rate = 16000
self.wave_remained = np.array([]) self.wave_remained = np.array([])
self.num_mel_bins = dataset_conf['feature_extraction_conf']['num_mel_bins'] self.num_mel_bins = dataset_conf[
self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms 'feature_extraction_conf']['num_mel_bins']
self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms self.frame_length = dataset_conf[
'feature_extraction_conf']['frame_length'] # in ms
self.frame_shift = dataset_conf[
'feature_extraction_conf']['frame_shift'] # in ms
self.downsampling = dataset_conf.get('frame_skip', 1) self.downsampling = dataset_conf.get('frame_skip', 1)
self.resolution = self.frame_shift / 1000 # in second self.resolution = self.frame_shift / 1000 # in second
# fsmn splice operation # fsmn splice operation
self.context_expansion = dataset_conf.get('context_expansion', False) self.context_expansion = dataset_conf.get('context_expansion', False)
self.left_context = 0 self.left_context = 0
@ -231,9 +251,11 @@ class KeyWordSpotter(torch.nn.Module):
self.model.eval() self.model.eval()
logging.info(f'model {ckpt_path} loaded.') logging.info(f'model {ckpt_path} loaded.')
self.token_table = read_token(token_path) self.token_table = read_token(token_path)
logging.info(f'tokens {token_path} with {len(self.token_table)} units loaded.') logging.info(f'tokens {token_path} with '
f'{len(self.token_table)} units loaded.')
self.lexicon_table = read_lexicon(lexicon_path) self.lexicon_table = read_lexicon(lexicon_path)
logging.info(f'lexicons {lexicon_path} with {len(self.lexicon_table)} units loaded.') logging.info(f'lexicons {lexicon_path} with '
f'{len(self.lexicon_table)} units loaded.')
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float) self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
@ -257,7 +279,9 @@ class KeyWordSpotter(torch.nn.Module):
def set_keywords(self, keywords): def set_keywords(self, keywords):
# 4. parse keywords tokens # 4. parse keywords tokens
assert keywords is not None, 'at least one keyword is needed, multiple keywords should be splitted with comma(,)' assert keywords is not None, \
'at least one keyword is needed, ' \
'multiple keywords should be splitted with comma(,)'
keywords_str = keywords keywords_str = keywords
keywords_list = keywords_str.strip().replace(' ', '').split(',') keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {} keywords_token = {}
@ -265,7 +289,8 @@ class KeyWordSpotter(torch.nn.Module):
keywords_strset = {'<blk>'} keywords_strset = {'<blk>'}
keywords_tokenmap = {'<blk>': 0} keywords_tokenmap = {'<blk>': 0}
for keyword in keywords_list: for keyword in keywords_list:
strs, indexes = query_token_set(keyword, self.token_table, self.lexicon_table) strs, indexes = query_token_set(
keyword, self.token_table, self.lexicon_table)
keywords_token[keyword] = {} keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = indexes keywords_token[keyword]['token_id'] = indexes
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
@ -284,16 +309,20 @@ class KeyWordSpotter(torch.nn.Module):
self.keywords_token = keywords_token self.keywords_token = keywords_token
def accept_wave(self, wave): def accept_wave(self, wave):
assert isinstance(wave, bytes), "please make sure the input format is bytes(raw PCM)" assert isinstance(wave, bytes), \
"please make sure the input format is bytes(raw PCM)"
# convert bytes into float32 # convert bytes into float32
data = [] data = []
for i in range(0, len(wave), 2): for i in range(0, len(wave), 2):
value = struct.unpack('<h', wave[i:i + 2])[0] value = struct.unpack('<h', wave[i:i + 2])[0]
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input data.append(value)
# here we don't divide 32768.0,
# because kaldi.fbank accept original input
wave = np.array(data) wave = np.array(data)
wave = np.append(self.wave_remained, wave) wave = np.append(self.wave_remained, wave)
if wave.size < (self.frame_length * self.sample_rate / 1000) * self.right_context : if wave.size < (self.frame_length * self.sample_rate / 1000) \
* self.right_context :
self.wave_remained = wave self.wave_remained = wave
return None return None
wave_tensor = torch.from_numpy(wave).float().to(self.device) wave_tensor = torch.from_numpy(wave).float().to(self.device)
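`accept_wave` converts raw 16-bit little-endian PCM bytes to integers one sample at a time. As an aside, the same conversion can be done with a single `struct.unpack` call. The sketch below compares the two on a tiny buffer; the function names are hypothetical, and neither version scales by 32768, matching the comment in the diff.

```python
import struct
from typing import List


def pcm16_to_ints_loop(wave: bytes) -> List[int]:
    # Per-sample loop, as in accept_wave.
    data = []
    for i in range(0, len(wave), 2):
        data.append(struct.unpack('<h', wave[i:i + 2])[0])
    return data


def pcm16_to_ints(wave: bytes) -> List[int]:
    # One call: '<Nh' unpacks N little-endian signed 16-bit samples.
    count = len(wave) // 2
    return list(struct.unpack('<%dh' % count, wave[:count * 2]))


pcm = struct.pack('<4h', 0, 1, -1, 32767)
samples = pcm16_to_ints(pcm)
print(samples)
```

The single-call form drops a trailing odd byte instead of raising, which may or may not be the desired behavior for a streaming input.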
@ -311,29 +340,37 @@ class KeyWordSpotter(torch.nn.Module):
self.wave_remained = wave[feat_len * frame_shift:] self.wave_remained = wave[feat_len * frame_shift:]
if self.context_expansion: if self.context_expansion:
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context." assert feat_len > self.right_context, \
"make sure each chunk feat length is large than right context."
# pad feats with remained feature from last chunk # pad feats with remained feature from last chunk
if self.feature_remained is None: # first chunk if self.feature_remained is None: # first chunk
# pad first frame at the beginning, replicate just support last dimension, so we do transpose. # pad first frame at the beginning,
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T # replicate just support last dimension, so we do transpose.
feats_pad = F.pad(
feats.T, (self.left_context, 0), mode='replicate').T
else: else:
feats_pad = torch.cat((self.feature_remained, feats)) feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats_pad.shape[0] - (self.right_context+self.right_context) ctx_frm = feats_pad.shape[0] - \
(self.right_context+self.right_context)
ctx_win = (self.left_context + self.right_context + 1) ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
for i in range(ctx_frm): for i in range(ctx_frm):
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0) feats_ctx[i] = torch.cat(
tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
# update feature remained, and feats # update feature remained, and feats
self.feature_remained = feats[-(self.left_context+self.right_context):] self.feature_remained = \
feats[-(self.left_context+self.right_context):]
feats = feats_ctx.to(self.device) feats = feats_ctx.to(self.device)
if self.downsampling > 1: if self.downsampling > 1:
last_remainder = 0 if self.feats_ctx_offset==0 else self.downsampling-self.feats_ctx_offset last_remainder = 0 if self.feats_ctx_offset==0 \
else self.downsampling-self.feats_ctx_offset
remainder = (feats.size(0)+last_remainder) % self.downsampling remainder = (feats.size(0)+last_remainder) % self.downsampling
feats = feats[self.feats_ctx_offset::self.downsampling, :] feats = feats[self.feats_ctx_offset::self.downsampling, :]
self.feats_ctx_offset = remainder if remainder == 0 else self.downsampling-remainder self.feats_ctx_offset = remainder \
if remainder == 0 else self.downsampling-remainder
return feats return feats
def decode_keywords(self, t, probs): def decode_keywords(self, t, probs):
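The downsampling branch in the hunk above carries a frame offset from one chunk to the next, so that the strided selection lines up across chunk boundaries. A minimal pure-Python sketch of that bookkeeping follows; the names `ChunkDownsampler` and `downsample_stream` and the list-based frames are illustrative assumptions, not part of this commit.

```python
class ChunkDownsampler:
    """Keep every `factor`-th frame across a stream of chunks."""

    def __init__(self, factor):
        self.factor = factor
        self.offset = 0  # index of the first frame to keep in the next chunk

    def __call__(self, frames):
        # Trailing frames of the previous chunk, counted from its last
        # kept frame (inclusive).
        last_remainder = 0 if self.offset == 0 else self.factor - self.offset
        remainder = (len(frames) + last_remainder) % self.factor
        kept = frames[self.offset::self.factor]
        self.offset = 0 if remainder == 0 else self.factor - remainder
        return kept


def downsample_stream(chunks, factor):
    """Downsample chunk by chunk and concatenate the kept frames."""
    ds = ChunkDownsampler(factor)
    out = []
    for chunk in chunks:
        out.extend(ds(chunk))
    return out
```

Under these assumptions, feeding the chunks one at a time should select the same frames as slicing the whole utterance with `[::factor]`, even when a chunk is shorter than the stride.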
@ -344,7 +381,8 @@ class KeyWordSpotter(torch.nn.Module):
self.cur_hyps, self.cur_hyps,
self.keywords_idxset, self.keywords_idxset,
self.score_beam) self.score_beam)
# update cur_hyps. note: the hyps is sort by path score(pnb+pb), not the keywords' probabilities. # update cur_hyps. note: hyps are sorted by path score(pnb+pb),
# not the keywords' probabilities.
cur_hyps = next_hyps[:self.path_beam] cur_hyps = next_hyps[:self.path_beam]
self.cur_hyps = cur_hyps self.cur_hyps = cur_hyps
@ -381,27 +419,36 @@ class KeyWordSpotter(torch.nn.Module):
if hit_keyword is not None: if hit_keyword is not None:
if self.hit_score >= self.threshold and \ if self.hit_score >= self.threshold and \
self.min_frames <= duration <= self.max_frames \ self.min_frames <= duration <= self.max_frames \
and (self.last_active_pos==-1 or end-self.last_active_pos >= self.interval_frames): and (self.last_active_pos==-1 or
end-self.last_active_pos >= self.interval_frames):
self.activated = True self.activated = True
self.last_active_pos = end self.last_active_pos = end
logging.info( logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "
f"duration {duration}, score {self.hit_score}, Activated.") f"duration {duration}, score {self.hit_score}, Activated.")
elif self.last_active_pos>0 and end-self.last_active_pos < self.interval_frames: elif self.last_active_pos>0 and \
end-self.last_active_pos < self.interval_frames:
logging.info( logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " f"Frame {absolute_time} detect {hit_keyword} "
f"but interval {end-self.last_active_pos} is lower than {self.interval_frames}, Deactivated. ") f"from {start} to {end} frame. "
f"but interval {end-self.last_active_pos} "
f"is lower than {self.interval_frames}, Deactivated. ")
elif self.hit_score < self.threshold: elif self.hit_score < self.threshold:
logging.info( logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " f"Frame {absolute_time} detect {hit_keyword} "
f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ") f"from {start} to {end} frame. "
f"but {self.hit_score} "
f"is lower than {self.threshold}, Deactivated. ")
elif self.min_frames > duration or duration > self.max_frames: elif self.min_frames > duration or duration > self.max_frames:
logging.info( logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " f"Frame {absolute_time} detect {hit_keyword} "
f"but {duration} beyond range({self.min_frames}~{self.max_frames}), Deactivated. ") f"from {start} to {end} frame. "
f"but {duration} beyond range"
f"({self.min_frames}~{self.max_frames}), Deactivated. ")
self.result = { self.result = {
"state": 1 if self.activated else 0, "state": 1 if self.activated else 0,
@ -418,7 +465,7 @@ class KeyWordSpotter(torch.nn.Module):
feature = feature.unsqueeze(0) # add a batch dimension feature = feature.unsqueeze(0) # add a batch dimension
logits, self.in_cache = self.model(feature, self.in_cache) logits, self.in_cache = self.model(feature, self.in_cache)
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size) probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search probs = probs[0].cpu() # remove batch dimension
for (t, prob) in enumerate(probs): for (t, prob) in enumerate(probs):
t *= self.downsampling t *= self.downsampling
self.decode_keywords(t, prob) self.decode_keywords(t, prob)
@ -426,10 +473,14 @@ class KeyWordSpotter(torch.nn.Module):
if self.activated: if self.activated:
self.reset() self.reset()
# since a chunk include about 30 frames, once activated, we can jump the latter frames. # since a chunk includes about 30 frames,
# TODO: there should give another method to update result, avoiding self.result being cleared. # once activated, we can skip the remaining frames.
# TODO: provide another way to update the result,
# so that self.result is not cleared.
break break
self.total_frames += len(probs) * self.downsampling # update frame offset
# update frame offset
self.total_frames += len(probs) * self.downsampling
return self.result return self.result
def reset(self): def reset(self):
@ -465,15 +516,20 @@ def demo():
args.gpu, args.gpu,
args.jit_model) args.jit_model)
# actually this could be done in __init__ method, we pull it outside for changing keywords more freely. # actually this could be done in __init__ method,
# we pull it outside for changing keywords more freely.
kws.set_keywords(args.keywords) kws.set_keywords(args.keywords)
if args.wav_path: if args.wav_path:
# Caution: input WAV should be standard 16k, 16 bits, 1 channel # Caution: input WAV should be standard 16k, 16 bits, 1 channel
# In demo we read wave in non-streaming fashion. # In demo we read wave in non-streaming fashion.
with wave.open(args.wav_path, 'rb') as fin: # with wave.open(args.wav_path, 'rb') as fin:
assert fin.getnchannels() == 1 # assert fin.getnchannels() == 1
wav = fin.readframes(fin.getnframes()) # wav = fin.readframes(fin.getnframes())
y, _ = librosa.load(args.wav_path, sr=16000, mono=True)
# NOTE: model supports 16k sample_rate
wav = (y * (1 << 15)).astype("int16").tobytes()
# We inference every 0.3 seconds, in streaming fashion. # We inference every 0.3 seconds, in streaming fashion.
interval = int(0.3 * 16000) * 2 interval = int(0.3 * 16000) * 2
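The demo now decodes audio as floats and rescales them to 16-bit PCM bytes before streaming. The sketch below reproduces that conversion in pure Python. The function name `float_to_pcm16` is hypothetical, and the clipping step is an added safeguard that the commit does not include: a sample of exactly 1.0 scaled by `1 << 15` gives 32768, which is outside the int16 range.

```python
import struct


def float_to_pcm16(samples):
    """Convert floats in [-1.0, 1.0] to little-endian 16-bit PCM bytes."""
    ints = []
    for x in samples:
        v = int(x * (1 << 15))            # same scaling factor as the demo
        v = max(-32768, min(32767, v))    # clip (assumption, not in the commit)
        ints.append(v)
    return struct.pack('<%dh' % len(ints), *ints)


pcm = float_to_pcm16([0.0, 0.5, -1.0, 1.0])

# 0.3 s of 16 kHz mono audio at 2 bytes per sample, as in `interval` above.
interval = int(0.3 * 16000) * 2
```

Each sample takes two bytes, which is why the demo multiplies the sample count by 2 when slicing the byte string into 0.3-second pieces.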
@ -490,12 +546,18 @@ def demo():
with open(args.wav_scp, 'r') as fscp: with open(args.wav_scp, 'r') as fscp:
for line in fscp: for line in fscp:
line = line.strip().split() line = line.strip().split()
assert len(line) == 2, f"The scp should be in kaldi format: \"utt_name wav_path\", but got {line}" assert len(line) == 2, \
f"The scp should be in kaldi format: " \
f"\"utt_name wav_path\", but got {line}"
utt_name, wav_path = line[0], line[1] utt_name, wav_path = line[0], line[1]
with wave.open(wav_path, 'rb') as fin: # with wave.open(wav_path, 'rb') as fin:
assert fin.getnchannels() == 1 # assert fin.getnchannels() == 1
wav = fin.readframes(fin.getnframes()) # wav = fin.readframes(fin.getnframes())
y, _ = librosa.load(wav_path, sr=16000, mono=True)
# NOTE: model supports 16k sample_rate
wav = (y * (1 << 15)).astype("int16").tobytes()
kws.reset_all() kws.reset_all()
activated = False activated = False
@ -510,7 +572,8 @@ def demo():
if fout: if fout:
hit_keyword = result['keyword'] hit_keyword = result['keyword']
hit_score = result['score'] hit_score = result['score']
fout.write('{} detected {} {:.3f}\n'.format(utt_name, hit_keyword, hit_score)) fout.write('{} detected {} {:.3f}\n'.format(
utt_name, hit_keyword, hit_score))
if not activated: if not activated:
if fout: if fout:


@ -66,30 +66,36 @@ def get_args():
action='store_true', action='store_true',
default=False, default=False,
help='Use pinned memory buffers used for reading') help='Use pinned memory buffers used for reading')
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)') parser.add_argument('--keywords', type=str, default=None,
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') help='the keywords, split with comma(,)')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_beam_size', parser.add_argument('--score_beam_size',
default=3, default=3,
type=int, type=int,
help='The first prune beam, filter out those frames with low scores.') help='The first prune beam, '
'filter out those frames with low scores.')
parser.add_argument('--path_beam_size', parser.add_argument('--path_beam_size',
default=20, default=20,
type=int, type=int,
help='The second prune beam, keep only path_beam_size candidates.') help='The second prune beam, '
'keep only path_beam_size candidates.')
parser.add_argument('--threshold', parser.add_argument('--threshold',
type=float, type=float,
default=0.0, default=0.0,
help='The threshold of kws. If ctc_search probs exceed this value,' help='The threshold of kws. '
'If ctc_search probs exceed this value, '
'the keyword will be activated.') 'the keyword will be activated.')
parser.add_argument('--min_frames', parser.add_argument('--min_frames',
default=5, default=5,
type=int, type=int,
help='The min frames of keyword\'s duration.') help='The min frames of keyword duration.')
parser.add_argument('--max_frames', parser.add_argument('--max_frames',
default=250, default=250,
type=int, type=int,
help='The max frames of keyword\'s duration.') help='The max frames of keyword duration.')
args = parser.parse_args() args = parser.parse_args()
return args return args
@ -158,7 +164,9 @@ def main():
lexicon_table = read_lexicon(args.lexicon_file) lexicon_table = read_lexicon(args.lexicon_file)
# 4. parse keywords tokens # 4. parse keywords tokens
assert args.keywords is not None, 'at least one keyword is needed' assert args.keywords is not None, 'at least one keyword is needed'
keywords_str = args.keywords logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords_str.strip().replace(' ', '').split(',') keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {} keywords_token = {}
keywords_idxset = {0} keywords_idxset = {0}
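The new keyword parsing passes the `--keywords` string through the `unicode_escape` codec, so keywords can be given as ASCII `\uXXXX` escapes. The sketch below mirrors those lines in a hypothetical helper, `parse_keywords`, and runs it on two inputs. One consequence worth noting, as a reading of the code rather than documented behavior: keywords typed as literal Chinese characters are mangled, because their UTF-8 bytes are decoded as Latin-1.

```python
def parse_keywords(keywords_arg):
    """Decode backslash-u escapes, remove spaces, split on commas."""
    keywords_str = keywords_arg.encode('utf-8').decode('unicode_escape')
    return keywords_str.strip().replace(' ', '').split(',')


# Escaped input decodes to the intended characters.
escaped = parse_keywords('\\u55e8\\u5c0f\\u95ee, \\u4f60\\u597d\\u95ee\\u95ee')

# Literal Chinese input does not survive the round trip.
literal = parse_keywords('你好')
```

If literal input needs to be supported as well, one option would be to apply the decoding only when the argument contains a backslash; that is a design suggestion, not something this commit does.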
@ -217,7 +225,8 @@ def main():
# filter prob score that is too small # filter prob score that is too small
filter_probs = [] filter_probs = []
filter_index = [] filter_index = []
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()): for prob, idx in zip(
top_k_probs.tolist(), top_k_index.tolist()):
if keywords_idxset is not None: if keywords_idxset is not None:
if prob > 0.05 and idx in keywords_idxset: if prob > 0.05 and idx in keywords_idxset:
filter_probs.append(prob) filter_probs.append(prob)
@ -246,7 +255,8 @@ def main():
n_pb, n_pnb, nodes = next_hyps[prefix] n_pb, n_pnb, nodes = next_hyps[prefix]
n_pnb = n_pnb + pnb * ps n_pnb = n_pnb + pnb * ps
nodes = cur_nodes.copy() nodes = cur_nodes.copy()
if ps > nodes[-1]['prob']: # update frame and prob # update frame and prob
if ps > nodes[-1]['prob']:
nodes[-1]['prob'] = ps nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t nodes[-1]['frame'] = t
next_hyps[prefix] = (n_pb, n_pnb, nodes) next_hyps[prefix] = (n_pb, n_pnb, nodes)
@ -257,32 +267,37 @@ def main():
n_pb, n_pnb, nodes = next_hyps[n_prefix] n_pb, n_pnb, nodes = next_hyps[n_prefix]
n_pnb = n_pnb + pb * ps n_pnb = n_pnb + pb * ps
nodes = cur_nodes.copy() nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t, nodes.append(dict(
prob=ps)) # to record token prob token=s, frame=t, prob=ps))
next_hyps[n_prefix] = (n_pb, n_pnb, nodes) next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
else: else:
n_prefix = prefix + (s,) n_prefix = prefix + (s,)
n_pb, n_pnb, nodes = next_hyps[n_prefix] n_pb, n_pnb, nodes = next_hyps[n_prefix]
if nodes: if nodes:
if ps > nodes[-1]['prob']: # update frame and prob # update frame and prob
if ps > nodes[-1]['prob']:
# nodes[-1]['prob'] = ps # nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t # nodes[-1]['frame'] = t
nodes.pop() # to avoid change other beam which has this node. # don't change beams sharing this node.
nodes.append(dict(token=s, frame=t, prob=ps)) nodes.pop()
nodes.append(dict(
token=s, frame=t, prob=ps))
else: else:
nodes = cur_nodes.copy() nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t, nodes.append(dict(
prob=ps)) # to record token prob token=s, frame=t, prob=ps))
n_pnb = n_pnb + pb * ps + pnb * ps n_pnb = n_pnb + pb * ps + pnb * ps
next_hyps[n_prefix] = (n_pb, n_pnb, nodes) next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
# 2.2 Second beam prune # 2.2 Second beam prune
next_hyps = sorted( next_hyps = sorted(
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True) next_hyps.items(),
key=lambda x: (x[1][0] + x[1][1]), reverse=True)
cur_hyps = next_hyps[:args.path_beam_size] cur_hyps = next_hyps[:args.path_beam_size]
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps] hyps = [(y[0], y[1][0] + y[1][1], y[1][2])
for y in cur_hyps]
for one_hyp in hyps: for one_hyp in hyps:
prefix_ids = one_hyp[0] prefix_ids = one_hyp[0]
@ -295,7 +310,8 @@ def main():
if offset != -1: if offset != -1:
hit_keyword = word hit_keyword = word
start = prefix_nodes[offset]['frame'] start = prefix_nodes[offset]['frame']
end = prefix_nodes[offset + len(lab) - 1]['frame'] end = prefix_nodes[
offset + len(lab) - 1]['frame']
for idx in range(offset, offset + len(lab)): for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob'] hit_score *= prefix_nodes[idx]['prob']
break break
@ -305,25 +321,35 @@ def main():
duration = end - start duration = end - start
if hit_keyword is not None: if hit_keyword is not None:
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames: if hit_score >= args.threshold and \
args.min_frames <= duration <= args.max_frames:
activated = True activated = True
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score)) fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score))
logging.info( logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"duration {duration}, score {hit_score} Activated.") f"in {key} from {start} to {end} frame. "
f"duration {duration}, "
f"score {hit_score} Activated.")
# clear the ctc_prefix buffer, and clear hit_keyword # clear the ctc_prefix buffer, and hit_keyword
cur_hyps = [(tuple(), (1.0, 0.0, []))] cur_hyps = [(tuple(), (1.0, 0.0, []))]
hit_keyword = None hit_keyword = None
hit_score = 1.0 hit_score = 1.0
elif hit_score < args.threshold: elif hit_score < args.threshold:
logging.info( logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"but {hit_score} less than {args.threshold}, Deactivated. ") f"in {key} from {start} to {end} frame. "
elif args.min_frames > duration or duration > args.max_frames: f"but {hit_score} less than "
f"{args.threshold}, Deactivated. ")
elif args.min_frames > duration \
or duration > args.max_frames:
logging.info( logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"but {duration} beyond range({args.min_frames}~{args.max_frames}), Deactivated. ") f"in {key} from {start} to {end} frame. "
f"but {duration} beyond "
f"range({args.min_frames}~{args.max_frames}), "
f"Deactivated. ")
if not activated: if not activated:
fout.write('{} rejected\n'.format(key)) fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.") logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
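Both the streaming spotter and this scoring script accept a candidate keyword only if its score and duration pass the configured limits. The sketch below isolates that decision as a pure function. The names `gate_detection` and `keyword_score` and the returned reason strings are hypothetical; the default values are copied from the argparse defaults in this file. The streaming class additionally enforces a minimum interval between activations, which is left out here.

```python
def keyword_score(node_probs):
    """Multiply the per-token probabilities of the matched keyword."""
    score = 1.0
    for p in node_probs:
        score *= p
    return score


def gate_detection(score, start, end, threshold=0.0,
                   min_frames=5, max_frames=250):
    """Return (activated, reason) for one keyword candidate."""
    duration = end - start
    if score >= threshold and min_frames <= duration <= max_frames:
        return True, 'activated'
    if score < threshold:
        return False, 'score below threshold'
    return False, 'duration out of range'
```

Because the score is a product of token probabilities, longer keywords tend to score lower at the same per-token confidence, which is worth keeping in mind when choosing `--threshold`.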


@ -159,8 +159,12 @@ def main():
if rank == 0: if rank == 0:
pass pass
# TODO: for now streaming FSMN do not support export to JITScript, # TODO: for now streaming FSMN do not support export to JITScript,
# TODO: because there is nn.Sequential with Tuple input in current FSMN modules. # TODO: because there is nn.Sequential with Tuple input
# the issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450 # in current FSMN modules.
# the issue is in https://stackoverflow.com/questions/75714299/
# pytorch-jit-script-error-when-sequential-container-
# takes-a-tuple-input/76553450#76553450
# script_model = torch.jit.script(model) # script_model = torch.jit.script(model)
# script_model.save(os.path.join(args.model_dir, 'init.zip')) # script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor() executor = Executor()


@ -225,16 +225,19 @@ class FSMNBlock(nn.Module):
x_per = x.permute(0, 3, 2, 1) x_per = x.permute(0, 3, 2, 1)
if in_cache is None or len(in_cache) == 0 : if in_cache is None or len(in_cache) == 0 :
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0]) x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride
+ self.rorder * self.rstride, 0])
else: else:
in_cache = in_cache.to(x_per.device) in_cache = in_cache.to(x_per.device)
x_pad = torch.cat((in_cache, x_per), dim=2) x_pad = torch.cat((in_cache, x_per), dim=2)
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :] in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride
+ self.rorder * self.rstride):, :]
y_left = x_pad[:, :, :-self.rorder * self.rstride, :] y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
y_left = self.quant(y_left) y_left = self.quant(y_left)
y_left = self.conv_left(y_left) y_left = self.conv_left(y_left)
y_left = self.dequant(y_left) y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left out = x_pad[:, :, (self.lorder - 1) * self.lstride:
-self.rorder * self.rstride, :] + y_left
if self.conv_right is not None: if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
@ -253,7 +256,8 @@ class FSMNBlock(nn.Module):
def to_kaldi_net(self): def to_kaldi_net(self):
re_str = '' re_str = ''
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim) re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % ( re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \
'<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride) 1, self.lorder, self.rorder, self.lstride, self.rstride)
# print(self.conv_left.weight,self.conv_right.weight) # print(self.conv_left.weight,self.conv_right.weight)
@ -441,7 +445,8 @@ class FSMN(nn.Module):
self.output_affine_dim = output_affine_dim self.output_affine_dim = output_affine_dim
self.output_dim = output_dim self.output_dim = output_dim
self.padding = (self.lorder-1) * self.lstride + self.rorder * self.rstride self.padding = (self.lorder-1) * self.lstride \
+ self.rorder * self.rstride
self.in_linear1 = AffineTransform(input_dim, input_affine_dim) self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
@ -469,7 +474,8 @@ class FSMN(nn.Module):
""" """
if in_cache is None or len(in_cache) == 0 : if in_cache is None or len(in_cache) == 0 :
in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) for _ in range(len(self.fsmn))] in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float)
for _ in range(len(self.fsmn))]
input = (input, in_cache) input = (input, in_cache)
x1 = self.in_linear1(input) x1 = self.in_linear1(input)
x2 = self.in_linear2(x1) x2 = self.in_linear2(x1)
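The FSMN hunks above size both the left padding and the streaming cache as `(lorder - 1) * lstride + rorder * rstride` frames. The sketch below shows that arithmetic and the pad-or-concatenate step on plain Python lists. The helper names `fsmn_context` and `pad_or_concat` are hypothetical, and the real module works on 4-D tensors rather than lists.

```python
def fsmn_context(lorder, rorder, lstride=1, rstride=1):
    """Frames of context an FSMN block keeps: left taps plus right lookahead."""
    return (lorder - 1) * lstride + rorder * rstride


def pad_or_concat(chunk, cache, context):
    """First chunk: left-pad with zeros. Later chunks: prepend the cache."""
    if not cache:
        padded = [0.0] * context + chunk
    else:
        padded = cache + chunk
    new_cache = padded[-context:]  # assumes context > 0
    return padded, new_cache
```

With this scheme every chunk after the first sees real history in place of zero padding, so the cache length has to match the padding length exactly.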


@ -34,7 +34,7 @@ class KWSModel(nn.Module):
"""Our model consists of five parts: """Our model consists of five parts:
1. global_cmvn: Optional, (idim, idim) 1. global_cmvn: Optional, (idim, idim)
2. preprocessing: feature dimension projection, (idim, hdim) 2. preprocessing: feature dimension projection, (idim, hdim)
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim) 3. backbone: backbone of the whole network, (hdim, hdim)
4. classifier: output layer or classifier of KWS model, (hdim, odim) 4. classifier: output layer or classifier of KWS model, (hdim, odim)
5. activation: 5. activation:
nn.Sigmoid for wakeup word nn.Sigmoid for wakeup word
@ -76,7 +76,8 @@ class KWSModel(nn.Module):
def forward_softmax(self, def forward_softmax(self,
x: torch.Tensor, x: torch.Tensor,
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) in_cache: torch.Tensor = torch.zeros(
0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None: if self.global_cmvn is not None:
x = self.global_cmvn(x) x = self.global_cmvn(x)
@ -196,7 +197,8 @@ def init_model(configs):
classifier = LinearClassifier(hidden_dim, output_dim) classifier = LinearClassifier(hidden_dim, output_dim)
activation = nn.Sigmoid() activation = nn.Sigmoid()
# Here we add a possible "activation_type", one can choose to use other activation function. # Here we add a possible "activation_type",
# one can choose to use other activation function.
# We use nn.Identity just for CTC loss # We use nn.Identity just for CTC loss
if "activation" in configs: if "activation" in configs:
activation_type = configs["activation"]["type"] activation_type = configs["activation"]["type"]


@ -188,7 +188,8 @@ def criterion(type: str,
loss, acc = max_pooling_loss(logits, target, lengths, min_duration) loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
return loss, acc return loss, acc
elif type == 'ctc': elif type == 'ctc':
loss, acc = ctc_loss(logits, target, lengths, target_lengths, validation) loss, acc = ctc_loss(
logits, target, lengths, target_lengths, validation)
return loss, acc return loss, acc
else: else:
exit(1) exit(1)
@ -281,7 +282,8 @@ def ctc_prefix_beam_search(
if ps > nodes[-1]['prob']: # update frame and prob if ps > nodes[-1]['prob']: # update frame and prob
# nodes[-1]['prob'] = ps # nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t # nodes[-1]['frame'] = t
nodes.pop() # to avoid change other beam which has this node. # avoid changing beams that share this node.
nodes.pop()
nodes.append(dict(token=s, frame=t, prob=ps)) nodes.append(dict(token=s, frame=t, prob=ps))
else: else:
nodes = cur_nodes.copy() nodes = cur_nodes.copy()
@ -429,7 +431,8 @@ class Calculator:
break break
else: # shouldn't reach here else: # shouldn't reach here
print( print(
'this should not happen , i = {i} , j = {j} , error = {error}' 'this should not happen, '
'i = {i} , j = {j} , error = {error}'
.format(i=i, j=j, error=self.space[i][j]['error'])) .format(i=i, j=j, error=self.space[i][j]['error']))
return result return result