From 0ac7a417b1a8c09baf55f6554741d47269090b2e Mon Sep 17 00:00:00 2001 From: dujing Date: Sat, 3 Jun 2023 14:24:37 +0800 Subject: [PATCH] Add streaming detection of CTC model. Add CTC model onnx export. Add CTC model's result in README; For now CTC model runtime is not supported yet. --- examples/hi_xiaowen/s0/README.md | 49 +++- examples/hi_xiaowen/s0/run_ctc.sh | 9 +- examples/hi_xiaowen/s0/run_fsmn_ctc.sh | 7 +- wekws/bin/export_onnx.py | 3 + wekws/bin/score_ctc.py | 15 +- wekws/bin/stream_score_ctc.py | 331 +++++++++++++++++++++++++ wekws/model/kws_model.py | 13 + wekws/model/loss.py | 7 +- 8 files changed, 421 insertions(+), 13 deletions(-) create mode 100644 wekws/bin/stream_score_ctc.py diff --git a/examples/hi_xiaowen/s0/README.md b/examples/hi_xiaowen/s0/README.md index 964c81d..9477005 100644 --- a/examples/hi_xiaowen/s0/README.md +++ b/examples/hi_xiaowen/s0/README.md @@ -1,4 +1,4 @@ -FRRs with FAR fixed at once per hour: +Comparison among different backbones. FRRs with FAR fixed at once per hour: | model | params(K) | epoch | hi_xiaowen | nihao_wenwen | |-----------------------|-----------|-----------|------------|--------------| @@ -8,3 +8,50 @@ FRRs with FAR fixed at once per hour: | DS_TCN(spec_aug) | 287 | 80(avg30) | 0.008176 | 0.005075 | | MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 | | MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 | + +Next, we use CTC loss to train the model, with DS_TCN and FSMN. +and we use CTC prefix beam search to decode and detect keywords, +the detection is either in non-streaming or streaming fashion. + +Since the FAR is pretty low when using CTC loss, +the follow result is FRRs with FAR fixed at once per 12 hours: + +Comparison between Max-pooling and CTC loss. +The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch). +FRRs with FAR fixed at once per 12 hours + + +| model | loss | hi_xiaowen | nihao_wenwen | +|-----------------------|-------------|------------|--------------| +| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | +| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | + + +Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch) +and FSMN(modelscope released, xiaoyunxiaoyun model). +FRRs with FAR fixed at once per 12 hours: + +| model | params(K) | hi_xiaowen | nihao_wenwen | +|-----------------------|-------------|------------|--------------| +| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | +| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | + +Comparison Between stream_score_ctc and score_ctc. +FRRs with FAR fixed at once per 12 hours: + +| model | stream | hi_xiaowen | nihao_wenwen | +|-----------------------|-------------|------------|--------------| +| DS_TCN(spec_aug) | no | 0.056574 | 0.056856 | +| DS_TCN(spec_aug) | yes | 0.132694 | 0.057044 | +| FSMN(spec_aug) | no | 0.031012 | 0.022460 | +| FSMN(spec_aug) | yes | 0.115215 | 0.020205 | + +Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame), +we record the probability of a keyword in a decoding path once the keyword appears in this path. +Actually the probability will increase through the time, so we record a lower value of probability, +which result in a higher False Rejection Rate in Detection Error Tradeoff result. +The actual FRR will be lower than the DET curve gives in a given threshold. + +Now, the model with CTC loss may not get the best performance, +but it's more robust compared with the classification model using CE/Max-pooling loss. +For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary). \ No newline at end of file diff --git a/examples/hi_xiaowen/s0/run_ctc.sh b/examples/hi_xiaowen/s0/run_ctc.sh index ee1f34d..8f2d694 100644 --- a/examples/hi_xiaowen/s0/run_ctc.sh +++ b/examples/hi_xiaowen/s0/run_ctc.sh @@ -153,7 +153,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ! -d mobvoi_kws_transcription ] ;then git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git fi - checkpoint=mobvoi_kws_transcription/23.pt # this ckpt may not be the best. + checkpoint=mobvoi_kws_transcription/23.pt # this ckpt may not converge well. fi torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ @@ -181,7 +181,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi result_dir=$dir/test_$(basename $score_checkpoint) mkdir -p $result_dir - python wekws/bin/score_ctc.py \ + stream=true # we detect keyword online with ctc_prefix_beam_search + score_prefix="" + if $stream ; then + score_prefix=stream_ + fi + python wekws/bin/${score_prefix}score_ctc.py \ --config $dir/config.yaml \ --test_data data/test/data.list \ --gpu 0 \ diff --git a/examples/hi_xiaowen/s0/run_fsmn_ctc.sh b/examples/hi_xiaowen/s0/run_fsmn_ctc.sh index 39c6bff..79774a9 100644 --- a/examples/hi_xiaowen/s0/run_fsmn_ctc.sh +++ b/examples/hi_xiaowen/s0/run_fsmn_ctc.sh @@ -132,7 +132,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi result_dir=$dir/test_$(basename $score_checkpoint) mkdir -p $result_dir - python wekws/bin/score_ctc.py \ + stream=true # we detect keyword online with ctc_prefix_beam_search + score_prefix="" + if $stream ; then + score_prefix=stream_ + fi + python wekws/bin/${score_prefix}score_ctc.py \ --config $dir/config.yaml \ --test_data data/test/data.list \ --gpu 0 \ diff --git a/wekws/bin/export_onnx.py b/wekws/bin/export_onnx.py index 75eb5b0..4850d19 100644 --- a/wekws/bin/export_onnx.py +++ b/wekws/bin/export_onnx.py @@ -41,6 +41,9 @@ def main(): configs = yaml.load(fin, Loader=yaml.FullLoader) feature_dim = configs['model']['input_dim'] model = init_model(configs['model']) + if configs['training_config']['criterion'] == 'ctc': + # if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search + model.forward = model.forward_softmax print(model) load_checkpoint(model, args.checkpoint) diff --git a/wekws/bin/score_ctc.py b/wekws/bin/score_ctc.py index 161f388..2af3df9 100644 --- a/wekws/bin/score_ctc.py +++ b/wekws/bin/score_ctc.py @@ -169,7 +169,7 @@ def main(): keywords_idxset) hit_keyword = None hit_score = 1.0 - # start = 0; end = 0 + start = 0; end = 0 for one_hyp in hyps: prefix_ids = one_hyp[0] # path_score = one_hyp[1] @@ -180,8 +180,8 @@ def main(): offset = is_sublist(prefix_ids, lab) if offset != -1: hit_keyword = word - # start = prefix_nodes[offset]['frame'] - # end = prefix_nodes[offset+len(lab)-1]['frame'] + start = prefix_nodes[offset]['frame'] + end = prefix_nodes[offset+len(lab)-1]['frame'] for idx in range(offset, offset + len(lab)): hit_score *= prefix_nodes[idx]['prob'] break @@ -190,12 +190,13 @@ def main(): break if hit_keyword is not None: - # fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\ - # .format(key, start*0.03, end*0.03, hit_keyword, hit_score)) - fout.write('{} detected {} {:.3f}\n'.format( - key, hit_keyword, hit_score)) + fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score)) + logging.info( + f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " + f"duration {end-start}, score {hit_score}, Activated.") else: fout.write('{} rejected\n'.format(key)) + logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.") if batch_idx % 10 == 0: print('Progress batch {}'.format(batch_idx)) diff --git a/wekws/bin/stream_score_ctc.py b/wekws/bin/stream_score_ctc.py new file mode 100644 index 0000000..8ea692f --- /dev/null +++ b/wekws/bin/stream_score_ctc.py @@ -0,0 +1,331 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# 2022 Shaoqing Yu(954793264@qq.com) +# 2023 Jing Du(thuduj12@163.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import copy +import logging +import os, sys, math + +import torch +import yaml +from collections import defaultdict +from torch.utils.data import DataLoader + +from wekws.dataset.dataset import Dataset +from wekws.model.kws_model import init_model +from wekws.utils.checkpoint import load_checkpoint +from tools.make_list import query_token_set, read_lexicon, read_token + + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--batch_size', + default=16, + type=int, + help='batch size for inference') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--score_file', + required=True, + help='output score file') + parser.add_argument('--jit_model', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)') + parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') + parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') + parser.add_argument('--score_beam_size', + default=3, + type=int, + help='The first prune beam, filter out those frames with low scores.') + parser.add_argument('--path_beam_size', + default=20, + type=int, + help='The second prune beam, keep only path_beam_size candidates.') + parser.add_argument('--threshold', + type=float, + default=0.0, + help='The threshold of kws. If ctc_search probs exceed this value,' + 'the keyword will be activated.') + parser.add_argument('--min_frames', + default=5, + type=int, + help='The min frames of keyword\'s duration.') + parser.add_argument('--max_frames', + default=250, + type=int, + help='The max frames of keyword\'s duration.') + + args = parser.parse_args() + return args + + +def is_sublist(main_list, check_list): + if len(main_list) < len(check_list): + return -1 + + if len(main_list) == len(check_list): + return 0 if main_list == check_list else -1 + + for i in range(len(main_list) - len(check_list)): + if main_list[i] == check_list[0]: + for j in range(len(check_list)): + if main_list[i + j] != check_list[j]: + break + else: + return i + else: + return -1 + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + test_conf = copy.deepcopy(configs['dataset_conf']) + test_conf['filter_conf']['max_length'] = 102400 + test_conf['filter_conf']['min_length'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['shuffle'] = False + test_conf['feature_extraction_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_size'] = args.batch_size + + test_dataset = Dataset(args.test_data, test_conf) + test_data_loader = DataLoader(test_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + + if args.jit_model: + model = torch.jit.load(args.checkpoint) + # For script model, only cpu is supported. + device = torch.device('cpu') + else: + # Init asr model from configs + model = init_model(configs['model']) + load_checkpoint(model, args.checkpoint) + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + model.eval() + score_abs_path = os.path.abspath(args.score_file) + + token_table = read_token(args.token_file) + lexicon_table = read_lexicon(args.lexicon_file) + # 4. parse keywords tokens + assert args.keywords is not None, 'at least one keyword is needed' + keywords_str = args.keywords + keywords_list = keywords_str.strip().replace(' ', '').split(',') + keywords_token = {} + keywords_idxset = {0} + keywords_strset = {''} + keywords_tokenmap = {'': 0} + for keyword in keywords_list: + strs, indexes = query_token_set(keyword, token_table, lexicon_table) + keywords_token[keyword] = {} + keywords_token[keyword]['token_id'] = indexes + keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) + for i in indexes) + [keywords_strset.add(i) for i in strs] + [keywords_idxset.add(i) for i in indexes] + for txt, idx in zip(strs, indexes): + if keywords_tokenmap.get(txt, None) is None: + keywords_tokenmap[txt] = idx + + token_print = '' + for txt, idx in keywords_tokenmap.items(): + token_print += f'{txt}({idx}) ' + logging.info(f'Token set is: {token_print}') + + with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout: + for batch_idx, batch in enumerate(test_data_loader): + keys, feats, target, lengths, target_lengths = batch + feats = feats.to(device) + lengths = lengths.to(device) + logits, _ = model(feats) + logits = logits.softmax(2) # (batch_size, maxlen, vocab_size) + logits = logits.cpu() + for i in range(len(keys)): + key = keys[i] + score = logits[i][:lengths[i]] + # hyps = ctc_prefix_beam_search(score, lengths[i], + # keywords_idxset) + maxlen = score.size(0) + ctc_probs = score + cur_hyps = [(tuple(), (1.0, 0.0, []))] + + hit_keyword = None + activated = False + hit_score = 1.0 + start = 0; end = 0 + + # 2. CTC beam search step by step + for t in range(0, maxlen): + probs = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (0.0, 0.0, [])) + + # 2.1 First beam prune: select topk best + top_k_probs, top_k_index = probs.topk(args.score_beam_size) + + # filter prob score that is too small + filter_probs = [] + filter_index = [] + for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()): + if keywords_idxset is not None: + if prob > 0.05 and idx in keywords_idxset: + filter_probs.append(prob) + filter_index.append(idx) + else: + if prob > 0.05: + filter_probs.append(prob) + filter_index.append(idx) + + if len(filter_index) == 0: + continue + + for s in filter_index: + ps = probs[s].item() + + for prefix, (pb, pnb, cur_nodes) in cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == 0: # blank + n_pb, n_pnb, nodes = next_hyps[prefix] + n_pb = n_pb + pb * ps + pnb * ps + nodes = cur_nodes.copy() + next_hyps[prefix] = (n_pb, n_pnb, nodes) + elif s == last: + if not math.isclose(pnb, 0.0, abs_tol=0.000001): + # Update *ss -> *s; + n_pb, n_pnb, nodes = next_hyps[prefix] + n_pnb = n_pnb + pnb * ps + nodes = cur_nodes.copy() + if ps > nodes[-1]['prob']: # update frame and prob + nodes[-1]['prob'] = ps + nodes[-1]['frame'] = t + next_hyps[prefix] = (n_pb, n_pnb, nodes) + + if not math.isclose(pb, 0.0, abs_tol=0.000001): + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s,) + n_pb, n_pnb, nodes = next_hyps[n_prefix] + n_pnb = n_pnb + pb * ps + nodes = cur_nodes.copy() + nodes.append(dict(token=s, frame=t, + prob=ps)) # to record token prob + next_hyps[n_prefix] = (n_pb, n_pnb, nodes) + else: + n_prefix = prefix + (s,) + n_pb, n_pnb, nodes = next_hyps[n_prefix] + if nodes: + if ps > nodes[-1]['prob']: # update frame and prob + # nodes[-1]['prob'] = ps + # nodes[-1]['frame'] = t + nodes.pop() # to avoid change other beam which has this node. + nodes.append(dict(token=s, frame=t, prob=ps)) + else: + nodes = cur_nodes.copy() + nodes.append(dict(token=s, frame=t, + prob=ps)) # to record token prob + n_pnb = n_pnb + pb * ps + pnb * ps + next_hyps[n_prefix] = (n_pb, n_pnb, nodes) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True) + + cur_hyps = next_hyps[:args.path_beam_size] + + hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps] + + for one_hyp in hyps: + prefix_ids = one_hyp[0] + # path_score = one_hyp[1] + prefix_nodes = one_hyp[2] + assert len(prefix_ids) == len(prefix_nodes) + for word in keywords_token.keys(): + lab = keywords_token[word]['token_id'] + offset = is_sublist(prefix_ids, lab) + if offset != -1: + hit_keyword = word + start = prefix_nodes[offset]['frame'] + end = prefix_nodes[offset + len(lab) - 1]['frame'] + for idx in range(offset, offset + len(lab)): + hit_score *= prefix_nodes[idx]['prob'] + break + if hit_keyword is not None: + hit_score = math.sqrt(hit_score) + break + + duration = end - start + if hit_keyword is not None : + if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames: + activated = True + fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score)) + logging.info( + f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " + f"duration {duration}, score {hit_score} Activated.") + + # clear the ctc_prefix buffer, and clear hit_keyword + cur_hyps = [(tuple(), (1.0, 0.0, []))] + hit_keyword = None + hit_score = 1.0 + elif hit_score < args.threshold: + logging.info( + f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " + f"but {hit_score} less than {args.threshold}, Deactivated. ") + elif args.min_frames > duration or duration > args.max_frames: + logging.info( + f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " + f"but {duration} beyond range({args.min_frames}~{args.max_frames}), Deactivated. ") + if not activated: + fout.write('{} rejected\n'.format(key)) + logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.") + + if batch_idx % 10 == 0: + print('Progress batch {}'.format(batch_idx)) + sys.stdout.flush() + + +if __name__ == '__main__': + main() diff --git a/wekws/model/kws_model.py b/wekws/model/kws_model.py index d4d7b5c..2c89269 100644 --- a/wekws/model/kws_model.py +++ b/wekws/model/kws_model.py @@ -74,6 +74,19 @@ class KWSModel(nn.Module): x = self.activation(x) return x, out_cache + def forward_softmax(self, + x: torch.Tensor, + in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_cmvn is not None: + x = self.global_cmvn(x) + x = self.preprocessing(x) + x, out_cache = self.backbone(x, in_cache) + x = self.classifier(x) + x = self.activation(x) + x = x.softmax(2) + return x, out_cache + def fuse_modules(self): self.preprocessing.fuse_modules() self.backbone.fuse_modules() diff --git a/wekws/model/loss.py b/wekws/model/loss.py index ce80da0..b272630 100644 --- a/wekws/model/loss.py +++ b/wekws/model/loss.py @@ -1,4 +1,5 @@ # Copyright (c) 2021 Binbin Zhang +# 2023 Jing Du # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -276,8 +277,10 @@ def ctc_prefix_beam_search( n_pb, n_pnb, nodes = next_hyps[n_prefix] if nodes: if ps > nodes[-1]['prob']: # update frame and prob - nodes[-1]['prob'] = ps - nodes[-1]['frame'] = t + # nodes[-1]['prob'] = ps + # nodes[-1]['frame'] = t + nodes.pop() # to avoid change other beam which has this node. + nodes.append(dict(token=s, frame=t, prob=ps)) else: nodes = cur_nodes.copy() nodes.append(dict(token=s, frame=t,