From ea6a0f5cda46de17cd7f9a09396e5ff0d6618709 Mon Sep 17 00:00:00 2001 From: dujing Date: Mon, 24 Jul 2023 16:39:13 +0800 Subject: [PATCH] fix flake8, update training script and README, give pretrained ckpt. --- examples/hi_xiaowen/s0/README.md | 41 +++--- examples/hi_xiaowen/s0/run_ctc.sh | 4 +- examples/hi_xiaowen/s0/run_fsmn_ctc.sh | 11 +- tools/make_list.py | 27 ++-- wekws/bin/compute_det_ctc.py | 39 ++++-- wekws/bin/export_onnx.py | 2 +- wekws/bin/score_ctc.py | 26 ++-- wekws/bin/stream_kws_ctc.py | 177 +++++++++++++++++-------- wekws/bin/stream_score_ctc.py | 88 +++++++----- wekws/bin/train.py | 8 +- wekws/model/fsmn.py | 18 ++- wekws/model/kws_model.py | 8 +- wekws/model/loss.py | 9 +- 13 files changed, 299 insertions(+), 159 deletions(-) diff --git a/examples/hi_xiaowen/s0/README.md b/examples/hi_xiaowen/s0/README.md index f776416..4f2f7d0 100644 --- a/examples/hi_xiaowen/s0/README.md +++ b/examples/hi_xiaowen/s0/README.md @@ -1,4 +1,6 @@ -Comparison among different backbones. FRRs with FAR fixed at once per hour: +Comparison among different backbones, +all models use Max-Pooling loss. +FRRs with FAR fixed at once per hour: | model | params(K) | epoch | hi_xiaowen | nihao_wenwen | |-----------------------|-----------|-----------|------------|--------------| @@ -9,32 +11,35 @@ Comparison among different backbones. FRRs with FAR fixed at once per hour: | MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 | | MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 | -Next, we use CTC loss to train the model, with DS_TCN and FSMN. +Next, we use CTC loss to train the model, with DS_TCN and FSMN backbones. and we use CTC prefix beam search to decode and detect keywords, the detection is either in non-streaming or streaming fashion. Since the FAR is pretty low when using CTC loss, -the follow result is FRRs with FAR fixed at once per 12 hours: +the follow results are FRRs with FAR fixed at once per 12 hours: Comparison between Max-pooling and CTC loss. 
-The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch). +The CTC model is fine-tuned with base model pretrained on WenetSpeech(23 epoch, not converged). FRRs with FAR fixed at once per 12 hours - -| model | loss | hi_xiaowen | nihao_wenwen | -|-----------------------|-------------|------------|--------------| -| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | -| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | +| model | loss | hi_xiaowen | nihao_wenwen | model ckpt | +|-----------------------|-------------|------------|--------------|------------| +| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | [dstcn-maxpooling](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn/files) | +| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) | -Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch) -and FSMN(modelscope released, xiaoyunxiaoyun model). +Comparison between DS_TCN(Pretrained with WenetSpeech, 23 epoch, not converged) +and FSMN(Pretrained with modelscope released xiaoyunxiaoyun model, fully converged). 
FRRs with FAR fixed at once per 12 hours: -| model | params(K) | hi_xiaowen | nihao_wenwen | -|-----------------------|-------------|------------|--------------| -| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | -| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | +| model | params(K) | hi_xiaowen | nihao_wenwen | model ckpt | +|-----------------------|-------------|------------|--------------|-------------------------------------------------------------------------------| +| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) | +| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | [fsmn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_fsmn_ctc/files) | + +Now, the DS_TCN model with CTC loss may not get the best performance, because the +pretraining phase is not sufficiently converged. We recommend using the pretrained +FSMN model as the initial checkpoint to train your own model. Comparison Between stream_score_ctc and score_ctc. FRRs with FAR fixed at once per 12 hours: @@ -52,6 +57,6 @@ Actually the probability will increase through the time, so we record a lower va which result in a higher False Rejection Rate in Detection Error Tradeoff result. The actual FRR will be lower than the DET curve gives in a given threshold. -Now, the model with CTC loss may not get the best performance, -but it's more robust compared with the classification model using CE/Max-pooling loss. -For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary). \ No newline at end of file +On some small data KWS tasks, we believe the FSMN-CTC model is more robust +compared with the classification model using CE/Max-pooling loss. +For more information and results of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary). 
\ No newline at end of file diff --git a/examples/hi_xiaowen/s0/run_ctc.sh b/examples/hi_xiaowen/s0/run_ctc.sh index 7cc1daf..647d6ae 100644 --- a/examples/hi_xiaowen/s0/run_ctc.sh +++ b/examples/hi_xiaowen/s0/run_ctc.sh @@ -194,12 +194,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --checkpoint $score_checkpoint \ --score_file $result_dir/score.txt \ --num_workers 8 \ - --keywords 嗨小问,你好问问 \ + --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \ --token_file data/tokens.txt \ --lexicon_file data/lexicon.txt python wekws/bin/compute_det_ctc.py \ - --keywords 嗨小问,你好问问 \ + --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \ --test_data data/test/data.list \ --window_shift $window_shift \ --step 0.001 \ diff --git a/examples/hi_xiaowen/s0/run_fsmn_ctc.sh b/examples/hi_xiaowen/s0/run_fsmn_ctc.sh index 743e349..18b1d08 100644 --- a/examples/hi_xiaowen/s0/run_fsmn_ctc.sh +++ b/examples/hi_xiaowen/s0/run_fsmn_ctc.sh @@ -145,12 +145,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --checkpoint $score_checkpoint \ --score_file $result_dir/score.txt \ --num_workers 8 \ - --keywords 嗨小问,你好问问 \ + --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \ --token_file data/tokens.txt \ --lexicon_file data/lexicon.txt python wekws/bin/compute_det_ctc.py \ - --keywords 嗨小问,你好问问 \ + --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \ --test_data data/test/data.list \ --window_shift $window_shift \ --step 0.001 \ @@ -161,14 +161,13 @@ fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then -# NOTE: FSMN now is not support export to jit, beacuse of nn.Sequential with tuple input -# This issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450 -# jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') + jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') + onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g') + # 
For now, FSMN can not export to JITScript # python wekws/bin/export_jit.py \ # --config $dir/config.yaml \ # --checkpoint $score_checkpoint \ # --jit_model $dir/$jit_model - onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g') python wekws/bin/export_onnx.py \ --config $dir/config.yaml \ --checkpoint $score_checkpoint \ diff --git a/tools/make_list.py b/tools/make_list.py index da76823..76b55cf 100755 --- a/tools/make_list.py +++ b/tools/make_list.py @@ -45,7 +45,8 @@ def query_token_set(txt, symbol_table, lexicon_table): tokens_str = tokens_str + ('!sil', ) elif part == '' or part == '': tokens_str = tokens_str + ('', ) - elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '': + elif part == '(noise)' or part == 'noise)' or \ + part == '(noise' or part == '': tokens_str = tokens_str + ('', ) elif part in symbol_table: tokens_str = tokens_str + (part, ) @@ -75,11 +76,11 @@ def query_token_set(txt, symbol_table, lexicon_table): if '' in symbol_table: tokens_idx = tokens_idx + (symbol_table[''], ) logging.info( - f'\'{ch}\' is not in token set, replace with ') + f'{ch} is not in token set, replace with ') else: tokens_idx = tokens_idx + (symbol_table[''], ) logging.info( - f'\'{ch}\' is not in token set, replace with ') + f'{ch} is not in token set, replace with ') return tokens_str, tokens_idx @@ -94,7 +95,8 @@ def query_token_list(txt, symbol_table, lexicon_table): tokens_str.append('!sil') elif part == '' or part == '': tokens_str.append('') - elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '': + elif part == '(noise)' or part == 'noise)' or \ + part == '(noise' or part == '': tokens_str.append('') elif part in symbol_table: tokens_str.append(part) @@ -124,11 +126,11 @@ def query_token_list(txt, symbol_table, lexicon_table): if '' in symbol_table: tokens_idx.append(symbol_table['']) logging.info( - f'\'{ch}\' is not in token set, replace with ') + f'{ch} is not in token set, replace with ') else: 
tokens_idx.append(symbol_table['']) logging.info( - f'\'{ch}\' is not in token set, replace with ') + f'{ch} is not in token set, replace with ') return tokens_str, tokens_idx @@ -160,8 +162,10 @@ if __name__ == '__main__': parser.add_argument('text_file', help='text file') parser.add_argument('duration_file', help='duration file') parser.add_argument('output_file', help='output list file') - parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') + parser.add_argument('--token_file', type=str, default=None, + help='the path of tokens.txt') + parser.add_argument('--lexicon_file', type=str, default=None, + help='the path of lexicon.txt') args = parser.parse_args() wav_table = {} @@ -196,7 +200,9 @@ if __name__ == '__main__': txt = [1] # the /sil is indexed by 1 tokens = ["sil"] else: - tokens, txt = query_token_list(arr[1], token_table, lexicon_table) + tokens, txt = query_token_list(arr[1], + token_table, + lexicon_table) else: txt = int(arr[1]) assert key in wav_table @@ -206,7 +212,8 @@ if __name__ == '__main__': if tokens is None: line = dict(key=key, txt=txt, duration=duration, wav=wav) else: - line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav) + line = dict(key=key, tok=tokens, txt=txt, + duration=duration, wav=wav) json_line = json.dumps(line, ensure_ascii=False) fout.write(json_line + '\n') diff --git a/wekws/bin/compute_det_ctc.py b/wekws/bin/compute_det_ctc.py index 2c31c7f..2ff459f 100644 --- a/wekws/bin/compute_det_ctc.py +++ b/wekws/bin/compute_det_ctc.py @@ -157,18 +157,23 @@ def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5): if __name__ == '__main__': parser = argparse.ArgumentParser(description='compute det curve') parser.add_argument('--test_data', required=True, help='label file') - parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)') - 
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') + parser.add_argument('--keywords', type=str, default=None, + help='keywords, split with comma(,)') + parser.add_argument('--token_file', type=str, default=None, + help='the path of tokens.txt') + parser.add_argument('--lexicon_file', type=str, default=None, + help='the path of lexicon.txt') parser.add_argument('--score_file', required=True, help='score file') parser.add_argument('--step', type=float, default=0.01, help='threshold step') parser.add_argument('--window_shift', type=int, default=50, - help='window_shift is used to skip the frames after triggered') + help='window_shift is used to ' + 'skip the frames after triggered') parser.add_argument('--stats_dir', required=False, default=None, - help='false reject/alarm stats dir, default in score_file') + help='false reject/alarm stats dir, ' + 'default in score_file') parser.add_argument('--det_curve_path', required=False, default=None, @@ -188,7 +193,11 @@ if __name__ == '__main__': args = parser.parse_args() window_shift = args.window_shift - keywords_list = args.keywords.strip().split(',') + logging.info(f"keywords is {args.keywords}, " + f"Chinese is converted into Unicode.") + + keywords = args.keywords.encode('utf-8').decode('unicode_escape') + keywords_list = keywords.strip().split(',') token_table = read_token(args.token_file) lexicon_table = read_lexicon(args.lexicon_file) @@ -197,7 +206,8 @@ if __name__ == '__main__': strs, indexes = query_token_set(keyword, token_table, lexicon_table) true_keywords[keyword] = ''.join(strs) - keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file, true_keywords) + keyword_filler_table = load_label_and_score( + keywords_list, args.test_data, args.score_file, true_keywords) for keyword in keywords_list: keyword = true_keywords[keyword] @@ -206,8 
+216,10 @@ if __name__ == '__main__': keyword_num = len(keyword_filler_table[keyword]['keyword_table']) filler_dur = keyword_filler_table[keyword]['filler_duration'] filler_num = len(keyword_filler_table[keyword]['filler_table']) - assert keyword_num > 0, 'Can\'t compute det for {} without positive sample' - assert filler_num > 0, 'Can\'t compute det for {} without negative sample' + assert keyword_num > 0, \ + 'Can\'t compute det for {} without positive sample' + assert filler_num > 0, \ + 'Can\'t compute det for {} without negative sample' logging.info('Computing det for {}'.format(keyword)) logging.info(' Keyword duration: {} Hours, wave number: {}'.format( @@ -218,14 +230,16 @@ if __name__ == '__main__': stats_dir = args.stats_dir else: stats_dir = os.path.dirname(args.score_file) - stats_file = os.path.join(stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt') + stats_file = os.path.join( + stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt') with open(stats_file, 'w', encoding='utf8') as fout: threshold = 0.0 while threshold <= 1.0: num_false_reject = 0 num_true_detect = 0 # transverse the all keyword_table - for key, confi in keyword_filler_table[keyword]['keyword_table'].items(): + for key, confi in \ + keyword_filler_table[keyword]['keyword_table'].items(): if confi < threshold: num_false_reject += 1 else: @@ -253,4 +267,5 @@ if __name__ == '__main__': det_curve_path = args.det_curve_path else: det_curve_path = os.path.join(stats_dir, 'det.png') - plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step) + plot_det(stats_dir, det_curve_path, + args.xlim, args.x_step, args.ylim, args.y_step) diff --git a/wekws/bin/export_onnx.py b/wekws/bin/export_onnx.py index e45e4b2..ad1f101 100644 --- a/wekws/bin/export_onnx.py +++ b/wekws/bin/export_onnx.py @@ -42,7 +42,7 @@ def main(): feature_dim = configs['model']['input_dim'] model = init_model(configs['model']) if configs['training_config'].get('criterion', 'max_pooling') == 
'ctc': - # if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search + # if we use ctc_loss, the logits need to be convert into probs model.forward = model.forward_softmax print(model) diff --git a/wekws/bin/score_ctc.py b/wekws/bin/score_ctc.py index ef442b8..9be9da1 100644 --- a/wekws/bin/score_ctc.py +++ b/wekws/bin/score_ctc.py @@ -65,9 +65,12 @@ def get_args(): action='store_true', default=False, help='Use pinned memory buffers used for reading') - parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)') - parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') + parser.add_argument('--keywords', type=str, default=None, + help='the keywords, split with comma(,)') + parser.add_argument('--token_file', type=str, default=None, + help='the path of tokens.txt') + parser.add_argument('--lexicon_file', type=str, default=None, + help='the path of lexicon.txt') args = parser.parse_args() return args @@ -133,7 +136,9 @@ def main(): lexicon_table = read_lexicon(args.lexicon_file) # 4. 
parse keywords tokens assert args.keywords is not None, 'at least one keyword is needed' - keywords_str = args.keywords + logging.info(f"keywords is {args.keywords}, " + f"Chinese is converted into Unicode.") + keywords_str = args.keywords.encode('utf-8').decode('unicode_escape') keywords_list = keywords_str.strip().replace(' ', '').split(',') keywords_token = {} keywords_idxset = {0} @@ -167,7 +172,9 @@ def main(): for i in range(len(keys)): key = keys[i] score = logits[i][:lengths[i]] - hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset) + hyps = ctc_prefix_beam_search(score, + lengths[i], + keywords_idxset) hit_keyword = None hit_score = 1.0 start = 0 @@ -192,10 +199,13 @@ def main(): break if hit_keyword is not None: - fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score)) + fout.write('{} detected {} {:.3f}\n'.format( + key, hit_keyword, hit_score)) logging.info( - f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " - f"duration {end - start}, score {hit_score}, Activated.") + f"batch:{batch_idx}_{i} detect {hit_keyword} " + f"in {key} from {start} to {end} frame. 
" + f"duration {end - start}, " + f"score {hit_score}, Activated.") else: fout.write('{} rejected\n'.format(key)) logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.") diff --git a/wekws/bin/stream_kws_ctc.py b/wekws/bin/stream_kws_ctc.py index ad5d626..b20a13e 100644 --- a/wekws/bin/stream_kws_ctc.py +++ b/wekws/bin/stream_kws_ctc.py @@ -16,7 +16,8 @@ from __future__ import print_function import argparse import struct -import wave +#import wave +import librosa import logging import os import math @@ -35,9 +36,12 @@ from tools.make_list import query_token_set, read_lexicon, read_token def get_args(): parser = argparse.ArgumentParser(description='detect keywords online.') parser.add_argument('--config', required=True, help='config file') - parser.add_argument('--wav_path', required=False, default=None, help='test wave path.') - parser.add_argument('--wav_scp', required=False, default=None, help='test wave scp.') - parser.add_argument('--result_file', required=False, default=None, help='test result.') + parser.add_argument('--wav_path', required=False, + default=None, help='test wave path.') + parser.add_argument('--wav_scp', required=False, + default=None, help='test wave scp.') + parser.add_argument('--result_file', required=False, + default=None, help='test result.') parser.add_argument('--gpu', type=int, @@ -48,21 +52,27 @@ def get_args(): action='store_true', default=False, help='Use pinned memory buffers used for reading') - parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)') - parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') + parser.add_argument('--keywords', type=str, default=None, + help='the keywords, split with comma(,)') + parser.add_argument('--token_file', type=str, default=None, + help='the path of tokens.txt') + parser.add_argument('--lexicon_file', type=str, 
default=None, + help='the path of lexicon.txt') parser.add_argument('--score_beam_size', default=3, type=int, - help='The first prune beam, filter out those frames with low scores.') + help='The first prune beam, ' + 'filter out those frames with low scores.') parser.add_argument('--path_beam_size', default=20, type=int, - help='The second prune beam, keep only path_beam_size candidates.') + help='The second prune beam, ' + 'keep only path_beam_size candidates.') parser.add_argument('--threshold', type=float, default=0.0, - help='The threshold of kws. If ctc_search probs exceed this value,' + help='The threshold of kws. ' + 'If ctc_search probs exceed this value,' 'the keyword will be activated.') parser.add_argument('--min_frames', default=5, @@ -98,16 +108,22 @@ def is_sublist(main_list, check_list): else: return -1 -def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size): +def ctc_prefix_beam_search(t, probs, + cur_hyps, + keywords_idxset, + score_beam_size): ''' :param t: the time in frame :param probs: the probability in t_th frame, (vocab_size, ) :param cur_hyps: list of tuples. [(tuple(), (1.0, 0.0, []))] - in tuple, 1st is prefix id, 2nd include p_blank, p_non_blank, and path nodes list. - in path nodes list, each node is a dict of {token=idx, frame=t, prob=ps} + in tuple, 1st is prefix id, 2nd include p_blank, + p_non_blank, and path nodes list. + in path nodes list, each node is + a dict of {token=idx, frame=t, prob=ps} :param keywords_idxset: the index of keywords in token.txt - :param score_beam_size: the probability threshold, to filter out those frames with low probs. + :param score_beam_size: the probability threshold, + to filter out those frames with low probs. :return: next_hyps: the hypothesis depend on current hyp and current frame. 
''' @@ -170,7 +186,8 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size) if ps > nodes[-1]['prob']: # update frame and prob # nodes[-1]['prob'] = ps # nodes[-1]['frame'] = t - nodes.pop() # to avoid change other beam which has this node. + nodes.pop() + # to avoid change other beam which has this node. nodes.append(dict(token=s, frame=t, prob=ps)) else: nodes = cur_nodes.copy() @@ -199,11 +216,14 @@ class KeyWordSpotter(torch.nn.Module): # feature related self.sample_rate = 16000 self.wave_remained = np.array([]) - self.num_mel_bins = dataset_conf['feature_extraction_conf']['num_mel_bins'] - self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms - self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms + self.num_mel_bins = dataset_conf[ + 'feature_extraction_conf']['num_mel_bins'] + self.frame_length = dataset_conf[ + 'feature_extraction_conf']['frame_length'] # in ms + self.frame_shift = dataset_conf[ + 'feature_extraction_conf']['frame_shift'] # in ms self.downsampling = dataset_conf.get('frame_skip', 1) - self.resolution = self.frame_shift / 1000 # in second + self.resolution = self.frame_shift / 1000 # in second # fsmn splice operation self.context_expansion = dataset_conf.get('context_expansion', False) self.left_context = 0 @@ -231,9 +251,11 @@ class KeyWordSpotter(torch.nn.Module): self.model.eval() logging.info(f'model {ckpt_path} loaded.') self.token_table = read_token(token_path) - logging.info(f'tokens {token_path} with {len(self.token_table)} units loaded.') + logging.info(f'tokens {token_path} with ' + f'{len(self.token_table)} units loaded.') self.lexicon_table = read_lexicon(lexicon_path) - logging.info(f'lexicons {lexicon_path} with {len(self.lexicon_table)} units loaded.') + logging.info(f'lexicons {lexicon_path} with ' + f'{len(self.lexicon_table)} units loaded.') self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float) @@ -257,7 +279,9 @@ class 
KeyWordSpotter(torch.nn.Module): def set_keywords(self, keywords): # 4. parse keywords tokens - assert keywords is not None, 'at least one keyword is needed, multiple keywords should be splitted with comma(,)' + assert keywords is not None, \ + 'at least one keyword is needed, ' \ + 'multiple keywords should be splitted with comma(,)' keywords_str = keywords keywords_list = keywords_str.strip().replace(' ', '').split(',') keywords_token = {} @@ -265,7 +289,8 @@ class KeyWordSpotter(torch.nn.Module): keywords_strset = {''} keywords_tokenmap = {'': 0} for keyword in keywords_list: - strs, indexes = query_token_set(keyword, self.token_table, self.lexicon_table) + strs, indexes = query_token_set( + keyword, self.token_table, self.lexicon_table) keywords_token[keyword] = {} keywords_token[keyword]['token_id'] = indexes keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) @@ -284,16 +309,20 @@ class KeyWordSpotter(torch.nn.Module): self.keywords_token = keywords_token def accept_wave(self, wave): - assert isinstance(wave, bytes), "please make sure the input format is bytes(raw PCM)" + assert isinstance(wave, bytes), \ + "please make sure the input format is bytes(raw PCM)" # convert bytes into float32 data = [] for i in range(0, len(wave), 2): value = struct.unpack(' self.right_context, "make sure each chunk feat length is large than right context." + assert feat_len > self.right_context, \ + "make sure each chunk feat length is large than right context." # pad feats with remained feature from last chunk if self.feature_remained is None: # first chunk - # pad first frame at the beginning, replicate just support last dimension, so we do transpose. - feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T + # pad first frame at the beginning, + # replicate just support last dimension, so we do transpose. 
+ feats_pad = F.pad( + feats.T, (self.left_context, 0), mode='replicate').T else: feats_pad = torch.cat((self.feature_remained, feats)) - ctx_frm = feats_pad.shape[0] - (self.right_context+self.right_context) + ctx_frm = feats_pad.shape[0] - \ + (self.right_context+self.right_context) ctx_win = (self.left_context + self.right_context + 1) ctx_dim = feats.shape[1] * ctx_win feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) for i in range(ctx_frm): - feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0) + feats_ctx[i] = torch.cat( + tuple(feats_pad[i: i + ctx_win])).unsqueeze(0) # update feature remained, and feats - self.feature_remained = feats[-(self.left_context+self.right_context):] + self.feature_remained = \ + feats[-(self.left_context+self.right_context):] feats = feats_ctx.to(self.device) if self.downsampling > 1: - last_remainder = 0 if self.feats_ctx_offset==0 else self.downsampling-self.feats_ctx_offset + last_remainder = 0 if self.feats_ctx_offset==0 \ + else self.downsampling-self.feats_ctx_offset remainder = (feats.size(0)+last_remainder) % self.downsampling feats = feats[self.feats_ctx_offset::self.downsampling, :] - self.feats_ctx_offset = remainder if remainder == 0 else self.downsampling-remainder + self.feats_ctx_offset = remainder \ + if remainder == 0 else self.downsampling-remainder return feats def decode_keywords(self, t, probs): @@ -344,7 +381,8 @@ class KeyWordSpotter(torch.nn.Module): self.cur_hyps, self.keywords_idxset, self.score_beam) - # update cur_hyps. note: the hyps is sort by path score(pnb+pb), not the keywords' probabilities. + # update cur_hyps. note: the hyps is sort by path score(pnb+pb), + # not the keywords' probabilities. 
cur_hyps = next_hyps[:self.path_beam] self.cur_hyps = cur_hyps @@ -381,27 +419,36 @@ class KeyWordSpotter(torch.nn.Module): if hit_keyword is not None: if self.hit_score >= self.threshold and \ self.min_frames <= duration <= self.max_frames \ - and (self.last_active_pos==-1 or end-self.last_active_pos >= self.interval_frames): + and (self.last_active_pos==-1 or + end-self.last_active_pos >= self.interval_frames): self.activated = True self.last_active_pos = end logging.info( - f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " + f"Frame {absolute_time} detect {hit_keyword} " + f"from {start} to {end} frame. " f"duration {duration}, score {self.hit_score}, Activated.") - elif self.last_active_pos>0 and end-self.last_active_pos < self.interval_frames: + elif self.last_active_pos>0 and \ + end-self.last_active_pos < self.interval_frames: logging.info( - f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " - f"but interval {end-self.last_active_pos} is lower than {self.interval_frames}, Deactivated. ") + f"Frame {absolute_time} detect {hit_keyword} " + f"from {start} to {end} frame. " + f"but interval {end-self.last_active_pos} " + f"is lower than {self.interval_frames}, Deactivated. ") elif self.hit_score < self.threshold: logging.info( - f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " - f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ") + f"Frame {absolute_time} detect {hit_keyword} " + f"from {start} to {end} frame. " + f"but {self.hit_score} " + f"is lower than {self.threshold}, Deactivated. ") elif self.min_frames > duration or duration > self.max_frames: logging.info( - f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " - f"but {duration} beyond range({self.min_frames}~{self.max_frames}), Deactivated. ") + f"Frame {absolute_time} detect {hit_keyword} " + f"from {start} to {end} frame. 
" + f"but {duration} beyond range" + f"({self.min_frames}~{self.max_frames}), Deactivated. ") self.result = { "state": 1 if self.activated else 0, @@ -418,7 +465,7 @@ class KeyWordSpotter(torch.nn.Module): feature = feature.unsqueeze(0) # add a batch dimension logits, self.in_cache = self.model(feature, self.in_cache) probs = logits.softmax(2) # (batch_size, maxlen, vocab_size) - probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search + probs = probs[0].cpu() # remove batch dimension for (t, prob) in enumerate(probs): t *= self.downsampling self.decode_keywords(t, prob) @@ -426,10 +473,14 @@ class KeyWordSpotter(torch.nn.Module): if self.activated: self.reset() - # since a chunk include about 30 frames, once activated, we can jump the latter frames. - # TODO: there should give another method to update result, avoiding self.result being cleared. + # since a chunk include about 30 frames, + # once activated, we can jump the latter frames. + # TODO: there should give another method to update result, + # avoiding self.result being cleared. break - self.total_frames += len(probs) * self.downsampling # update frame offset + + # update frame offset + self.total_frames += len(probs) * self.downsampling return self.result def reset(self): @@ -465,15 +516,20 @@ def demo(): args.gpu, args.jit_model) - # actually this could be done in __init__ method, we pull it outside for changing keywords more freely. + # actually this could be done in __init__ method, + # we pull it outside for changing keywords more freely. kws.set_keywords(args.keywords) if args.wav_path: # Caution: input WAV should be standard 16k, 16 bits, 1 channel # In demo we read wave in non-streaming fashion. 
- with wave.open(args.wav_path, 'rb') as fin: - assert fin.getnchannels() == 1 - wav = fin.readframes(fin.getnframes()) + # with wave.open(args.wav_path, 'rb') as fin: + # assert fin.getnchannels() == 1 + # wav = fin.readframes(fin.getnframes()) + + y, _ = librosa.load(args.wav_path, sr=16000, mono=True) + # NOTE: model supports 16k sample_rate + wav = (y * (1 << 15)).astype("int16").tobytes() # We inference every 0.3 seconds, in streaming fashion. interval = int(0.3 * 16000) * 2 @@ -490,12 +546,18 @@ def demo(): with open(args.wav_scp, 'r') as fscp: for line in fscp: line = line.strip().split() - assert len(line) == 2, f"The scp should be in kaldi format: \"utt_name wav_path\", but got {line}" + assert len(line) == 2, \ + f"The scp should be in kaldi format: " \ + f"\"utt_name wav_path\", but got {line}" utt_name, wav_path = line[0], line[1] - with wave.open(wav_path, 'rb') as fin: - assert fin.getnchannels() == 1 - wav = fin.readframes(fin.getnframes()) + # with wave.open(args.wav_path, 'rb') as fin: + # assert fin.getnchannels() == 1 + # wav = fin.readframes(fin.getnframes()) + + y, _ = librosa.load(args.wav_path, sr=16000, mono=True) + # NOTE: model supports 16k sample_rate + wav = (y * (1 << 15)).astype("int16").tobytes() kws.reset_all() activated = False @@ -510,7 +572,8 @@ def demo(): if fout: hit_keyword = result['keyword'] hit_score = result['score'] - fout.write('{} detected {} {:.3f}\n'.format(utt_name, hit_keyword, hit_score)) + fout.write('{} detected {} {:.3f}\n'.format( + utt_name, hit_keyword, hit_score)) if not activated: if fout: diff --git a/wekws/bin/stream_score_ctc.py b/wekws/bin/stream_score_ctc.py index fab510f..c03e66b 100644 --- a/wekws/bin/stream_score_ctc.py +++ b/wekws/bin/stream_score_ctc.py @@ -66,30 +66,36 @@ def get_args(): action='store_true', default=False, help='Use pinned memory buffers used for reading') - parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)') - 
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') - parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') + parser.add_argument('--keywords', type=str, default=None, + help='the keywords, split with comma(,)') + parser.add_argument('--token_file', type=str, default=None, + help='the path of tokens.txt') + parser.add_argument('--lexicon_file', type=str, default=None, + help='the path of lexicon.txt') parser.add_argument('--score_beam_size', default=3, type=int, - help='The first prune beam, filter out those frames with low scores.') + help='The first prune beam, f' + 'ilter out those frames with low scores.') parser.add_argument('--path_beam_size', default=20, type=int, - help='The second prune beam, keep only path_beam_size candidates.') + help='The second prune beam, ' + 'keep only path_beam_size candidates.') parser.add_argument('--threshold', type=float, default=0.0, - help='The threshold of kws. If ctc_search probs exceed this value,' + help='The threshold of kws. ' + 'If ctc_search probs exceed this value,' 'the keyword will be activated.') parser.add_argument('--min_frames', default=5, type=int, - help='The min frames of keyword\'s duration.') + help='The min frames of keyword duration.') parser.add_argument('--max_frames', default=250, type=int, - help='The max frames of keyword\'s duration.') + help='The max frames of keyword duration.') args = parser.parse_args() return args @@ -158,7 +164,9 @@ def main(): lexicon_table = read_lexicon(args.lexicon_file) # 4. 
parse keywords tokens assert args.keywords is not None, 'at least one keyword is needed' - keywords_str = args.keywords + logging.info(f"keywords is {args.keywords}, " + f"Chinese is converted into Unicode.") + keywords_str = args.keywords.encode('utf-8').decode('unicode_escape') keywords_list = keywords_str.strip().replace(' ', '').split(',') keywords_token = {} keywords_idxset = {0} @@ -217,7 +225,8 @@ def main(): # filter prob score that is too small filter_probs = [] filter_index = [] - for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()): + for prob, idx in zip( + top_k_probs.tolist(), top_k_index.tolist()): if keywords_idxset is not None: if prob > 0.05 and idx in keywords_idxset: filter_probs.append(prob) @@ -246,7 +255,8 @@ def main(): n_pb, n_pnb, nodes = next_hyps[prefix] n_pnb = n_pnb + pnb * ps nodes = cur_nodes.copy() - if ps > nodes[-1]['prob']: # update frame and prob + # update frame and prob + if ps > nodes[-1]['prob']: nodes[-1]['prob'] = ps nodes[-1]['frame'] = t next_hyps[prefix] = (n_pb, n_pnb, nodes) @@ -257,32 +267,37 @@ def main(): n_pb, n_pnb, nodes = next_hyps[n_prefix] n_pnb = n_pnb + pb * ps nodes = cur_nodes.copy() - nodes.append(dict(token=s, frame=t, - prob=ps)) # to record token prob + nodes.append(dict( + token=s, frame=t, prob=ps)) next_hyps[n_prefix] = (n_pb, n_pnb, nodes) else: n_prefix = prefix + (s,) n_pb, n_pnb, nodes = next_hyps[n_prefix] if nodes: - if ps > nodes[-1]['prob']: # update frame and prob + # update frame and prob + if ps > nodes[-1]['prob']: # nodes[-1]['prob'] = ps # nodes[-1]['frame'] = t - nodes.pop() # to avoid change other beam which has this node. - nodes.append(dict(token=s, frame=t, prob=ps)) + # avoid change other beam has this node. 
+ nodes.pop() + nodes.append(dict( + token=s, frame=t, prob=ps)) else: nodes = cur_nodes.copy() - nodes.append(dict(token=s, frame=t, - prob=ps)) # to record token prob + nodes.append(dict( + token=s, frame=t, prob=ps)) n_pnb = n_pnb + pb * ps + pnb * ps next_hyps[n_prefix] = (n_pb, n_pnb, nodes) # 2.2 Second beam prune next_hyps = sorted( - next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True) + next_hyps.items(), + key=lambda x: (x[1][0] + x[1][1]), reverse=True) cur_hyps = next_hyps[:args.path_beam_size] - hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps] + hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) + for y in cur_hyps] for one_hyp in hyps: prefix_ids = one_hyp[0] @@ -295,7 +310,8 @@ def main(): if offset != -1: hit_keyword = word start = prefix_nodes[offset]['frame'] - end = prefix_nodes[offset + len(lab) - 1]['frame'] + end = prefix_nodes[ + offset + len(lab) - 1]['frame'] for idx in range(offset, offset + len(lab)): hit_score *= prefix_nodes[idx]['prob'] break @@ -305,25 +321,35 @@ def main(): duration = end - start if hit_keyword is not None: - if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames: + if hit_score >= args.threshold and \ + args.min_frames <= duration <= args.max_frames: activated = True - fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score)) + fout.write('{} detected {} {:.3f}\n'.format( + key, hit_keyword, hit_score)) logging.info( - f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " - f"duration {duration}, score {hit_score} Activated.") + f"batch:{batch_idx}_{i} detect {hit_keyword} " + f"in {key} from {start} to {end} frame. 
" + f"duration {duration}, s" + f"core {hit_score} Activated.") - # clear the ctc_prefix buffer, and clear hit_keyword + # clear the ctc_prefix buffer, and hit_keyword cur_hyps = [(tuple(), (1.0, 0.0, []))] hit_keyword = None hit_score = 1.0 elif hit_score < args.threshold: logging.info( - f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " - f"but {hit_score} less than {args.threshold}, Deactivated. ") - elif args.min_frames > duration or duration > args.max_frames: + f"batch:{batch_idx}_{i} detect {hit_keyword} " + f"in {key} from {start} to {end} frame. " + f"but {hit_score} less than " + f"{args.threshold}, Deactivated. ") + elif args.min_frames > duration \ + or duration > args.max_frames: logging.info( - f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " - f"but {duration} beyond range({args.min_frames}~{args.max_frames}), Deactivated. ") + f"batch:{batch_idx}_{i} detect {hit_keyword} " + f"in {key} from {start} to {end} frame. " + f"but {duration} beyond " + f"range({args.min_frames}~{args.max_frames}), " + f"Deactivated. ") if not activated: fout.write('{} rejected\n'.format(key)) logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.") diff --git a/wekws/bin/train.py b/wekws/bin/train.py index b81ad9f..c88ea1f 100644 --- a/wekws/bin/train.py +++ b/wekws/bin/train.py @@ -159,8 +159,12 @@ def main(): if rank == 0: pass # TODO: for now streaming FSMN do not support export to JITScript, - # TODO: because there is nn.Sequential with Tuple input in current FSMN modules. - # the issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450 + # TODO: because there is nn.Sequential with Tuple input + # in current FSMN modules. 
+ # the issue is in https://stackoverflow.com/questions/75714299/ + # pytorch-jit-script-error-when-sequential-container- + # takes-a-tuple-input/76553450#76553450 + # script_model = torch.jit.script(model) # script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() diff --git a/wekws/model/fsmn.py b/wekws/model/fsmn.py index 86ce41f..2b9b74d 100644 --- a/wekws/model/fsmn.py +++ b/wekws/model/fsmn.py @@ -225,16 +225,19 @@ class FSMNBlock(nn.Module): x_per = x.permute(0, 3, 2, 1) if in_cache is None or len(in_cache) == 0 : - x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0]) + x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + + self.rorder * self.rstride, 0]) else: in_cache = in_cache.to(x_per.device) x_pad = torch.cat((in_cache, x_per), dim=2) - in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :] + in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + + self.rorder * self.rstride):, :] y_left = x_pad[:, :, :-self.rorder * self.rstride, :] y_left = self.quant(y_left) y_left = self.conv_left(y_left) y_left = self.dequant(y_left) - out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left + out = x_pad[:, :, (self.lorder - 1) * self.lstride: + -self.rorder * self.rstride, :] + y_left if self.conv_right is not None: # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) @@ -253,7 +256,8 @@ class FSMNBlock(nn.Module): def to_kaldi_net(self): re_str = '' re_str += ' %d %d\n' % (self.dim, self.dim) - re_str += ' %d %d %d %d %d 0\n' % ( + re_str += ' %d %d %d ' \ + ' %d %d 0\n' % ( 1, self.lorder, self.rorder, self.lstride, self.rstride) # print(self.conv_left.weight,self.conv_right.weight) @@ -441,7 +445,8 @@ class FSMN(nn.Module): self.output_affine_dim = output_affine_dim self.output_dim = output_dim - self.padding = (self.lorder-1) * self.lstride + self.rorder * self.rstride + self.padding = 
(self.lorder-1) * self.lstride \ + + self.rorder * self.rstride self.in_linear1 = AffineTransform(input_dim, input_affine_dim) self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) @@ -469,7 +474,8 @@ class FSMN(nn.Module): """ if in_cache is None or len(in_cache) == 0 : - in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) for _ in range(len(self.fsmn))] + in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) + for _ in range(len(self.fsmn))] input = (input, in_cache) x1 = self.in_linear1(input) x2 = self.in_linear2(x1) diff --git a/wekws/model/kws_model.py b/wekws/model/kws_model.py index 2c89269..349690b 100644 --- a/wekws/model/kws_model.py +++ b/wekws/model/kws_model.py @@ -34,7 +34,7 @@ class KWSModel(nn.Module): """Our model consists of four parts: 1. global_cmvn: Optional, (idim, idim) 2. preprocessing: feature dimention projection, (idim, hdim) - 3. backbone: backbone or feature extractor of the whole network, (hdim, hdim) + 3. backbone: backbone of the whole network, (hdim, hdim) 4. classifier: output layer or classifier of KWS model, (hdim, odim) 5. activation: nn.Sigmoid for wakeup word @@ -76,7 +76,8 @@ class KWSModel(nn.Module): def forward_softmax(self, x: torch.Tensor, - in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + in_cache: torch.Tensor = torch.zeros( + 0, 0, 0, dtype=torch.float) ) -> Tuple[torch.Tensor, torch.Tensor]: if self.global_cmvn is not None: x = self.global_cmvn(x) @@ -196,7 +197,8 @@ def init_model(configs): classifier = LinearClassifier(hidden_dim, output_dim) activation = nn.Sigmoid() - # Here we add a possible "activation_type", one can choose to use other activation function. + # Here we add a possible "activation_type", + # one can choose to use other activation function. 
# We use nn.Identity just for CTC loss if "activation" in configs: activation_type = configs["activation"]["type"] diff --git a/wekws/model/loss.py b/wekws/model/loss.py index 46004a4..fef17d6 100644 --- a/wekws/model/loss.py +++ b/wekws/model/loss.py @@ -188,7 +188,8 @@ def criterion(type: str, loss, acc = max_pooling_loss(logits, target, lengths, min_duration) return loss, acc elif type == 'ctc': - loss, acc = ctc_loss(logits, target, lengths, target_lengths, validation) + loss, acc = ctc_loss( + logits, target, lengths, target_lengths, validation) return loss, acc else: exit(1) @@ -281,7 +282,8 @@ def ctc_prefix_beam_search( if ps > nodes[-1]['prob']: # update frame and prob # nodes[-1]['prob'] = ps # nodes[-1]['frame'] = t - nodes.pop() # to avoid change other beam which has this node. + # avoid change other beam which has this node. + nodes.pop() nodes.append(dict(token=s, frame=t, prob=ps)) else: nodes = cur_nodes.copy() @@ -429,7 +431,8 @@ class Calculator: break else: # shouldn't reach here print( - 'this should not happen , i = {i} , j = {j} , error = {error}' + 'this should not happen, ' + 'i = {i} , j = {j} , error = {error}' .format(i=i, j=j, error=self.space[i][j]['error'])) return result