diff --git a/examples/hi_xiaowen/s0/README.md b/examples/hi_xiaowen/s0/README.md index 9477005..f776416 100644 --- a/examples/hi_xiaowen/s0/README.md +++ b/examples/hi_xiaowen/s0/README.md @@ -9,17 +9,17 @@ Comparison among different backbones. FRRs with FAR fixed at once per hour: | MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 | | MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 | -Next, we use CTC loss to train the model, with DS_TCN and FSMN. -and we use CTC prefix beam search to decode and detect keywords, -the detection is either in non-streaming or streaming fashion. +Next, we use CTC loss to train the model, with DS_TCN and FSMN. +and we use CTC prefix beam search to decode and detect keywords, +the detection is either in non-streaming or streaming fashion. -Since the FAR is pretty low when using CTC loss, +Since the FAR is pretty low when using CTC loss, the follow result is FRRs with FAR fixed at once per 12 hours: -Comparison between Max-pooling and CTC loss. -The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch). +Comparison between Max-pooling and CTC loss. +The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch). FRRs with FAR fixed at once per 12 hours - + | model | loss | hi_xiaowen | nihao_wenwen | |-----------------------|-------------|------------|--------------| @@ -27,8 +27,8 @@ FRRs with FAR fixed at once per 12 hours | DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | -Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch) -and FSMN(modelscope released, xiaoyunxiaoyun model). +Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch) +and FSMN(modelscope released, xiaoyunxiaoyun model). FRRs with FAR fixed at once per 12 hours: | model | params(K) | hi_xiaowen | nihao_wenwen | @@ -36,8 +36,8 @@ FRRs with FAR fixed at once per 12 hours: | DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | | FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | -Comparison Between stream_score_ctc and score_ctc. 
-FRRs with FAR fixed at once per 12 hours: +Comparison Between stream_score_ctc and score_ctc. +FRRs with FAR fixed at once per 12 hours: | model | stream | hi_xiaowen | nihao_wenwen | |-----------------------|-------------|------------|--------------| @@ -46,12 +46,12 @@ FRRs with FAR fixed at once per 12 hours: | FSMN(spec_aug) | no | 0.031012 | 0.022460 | | FSMN(spec_aug) | yes | 0.115215 | 0.020205 | -Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame), +Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame), we record the probability of a keyword in a decoding path once the keyword appears in this path. Actually the probability will increase through the time, so we record a lower value of probability, -which result in a higher False Rejection Rate in Detection Error Tradeoff result. -The actual FRR will be lower than the DET curve gives in a given threshold. +which results in a higher False Rejection Rate in the Detection Error Tradeoff result. +The actual FRR will be lower than what the DET curve gives at a given threshold. -Now, the model with CTC loss may not get the best performance, -but it's more robust compared with the classification model using CE/Max-pooling loss. +Now, the model with CTC loss may not get the best performance, +but it's more robust compared with the classification model using CE/Max-pooling loss. For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary). \ No newline at end of file diff --git a/tools/make_list.py b/tools/make_list.py index 60928c2..da76823 100755 --- a/tools/make_list.py +++ b/tools/make_list.py @@ -15,8 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import argparse, logging -import json, re +import argparse +import logging +import json +import re symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+' @@ -189,7 +191,7 @@ if __name__ == '__main__': arr = line.strip().split(maxsplit=1) key = arr[0] tokens = None - if token_table!=None and lexicon_table!=None : + if token_table is not None and lexicon_table is not None : if len(arr) < 2: # for some utterence, no text txt = [1] # the /sil is indexed by 1 tokens = ["sil"] @@ -201,7 +203,7 @@ if __name__ == '__main__': wav = wav_table[key] assert key in duration_table duration = duration_table[key] - if tokens == None: + if tokens is None: line = dict(key=key, txt=txt, duration=duration, wav=wav) else: line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav) diff --git a/wekws/bin/compute_det_ctc.py b/wekws/bin/compute_det_ctc.py index fcc82bc..14f3d86 100644 --- a/wekws/bin/compute_det_ctc.py +++ b/wekws/bin/compute_det_ctc.py @@ -14,8 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import argparse, logging, glob -import json, re, os, numpy as np +import argparse +import logging +import glob +import json +import re +import os +import numpy as np import matplotlib.pyplot as plt import pypinyin # for Chinese Character @@ -205,8 +210,7 @@ if __name__ == '__main__': num_false_reject = 0 num_true_detect = 0 # transverse the all keyword_table - for key, confi in keyword_filler_table[keyword][ - 'keyword_table'].items(): + for key, confi in keyword_filler_table[keyword]['keyword_table'].items(): if confi < threshold: num_false_reject += 1 else: @@ -234,4 +238,4 @@ if __name__ == '__main__': det_curve_path = args.det_curve_path else: det_curve_path = os.path.join(stats_dir, 'det.png') - plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step) \ No newline at end of file + plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step) diff --git a/wekws/bin/score_ctc.py b/wekws/bin/score_ctc.py index 2af3df9..ef442b8 100644 --- a/wekws/bin/score_ctc.py +++ b/wekws/bin/score_ctc.py @@ -19,7 +19,9 @@ from __future__ import print_function import argparse import copy import logging -import os, sys, math +import os +import sys +import math import torch import yaml @@ -138,7 +140,7 @@ def main(): keywords_strset = {''} keywords_tokenmap = {'': 0} for keyword in keywords_list: - strs, indexes = query_token_set(keyword, token_table,lexicon_table) + strs, indexes = query_token_set(keyword, token_table, lexicon_table) keywords_token[keyword] = {} keywords_token[keyword]['token_id'] = indexes keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) @@ -165,11 +167,11 @@ def main(): for i in range(len(keys)): key = keys[i] score = logits[i][:lengths[i]] - hyps = ctc_prefix_beam_search(score, lengths[i], - keywords_idxset) + hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset) hit_keyword = None hit_score = 1.0 - start = 0; end = 0 + start = 0 + end = 0 for one_hyp in hyps: prefix_ids = one_hyp[0] # 
path_score = one_hyp[1] @@ -181,7 +183,7 @@ def main(): if offset != -1: hit_keyword = word start = prefix_nodes[offset]['frame'] - end = prefix_nodes[offset+len(lab)-1]['frame'] + end = prefix_nodes[offset + len(lab) - 1]['frame'] for idx in range(offset, offset + len(lab)): hit_score *= prefix_nodes[idx]['prob'] break @@ -193,7 +195,7 @@ def main(): fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score)) logging.info( f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " - f"duration {end-start}, score {hit_score}, Activated.") + f"duration {end - start}, score {hit_score}, Activated.") else: fout.write('{} rejected\n'.format(key)) logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.") diff --git a/wekws/bin/stream_kws_ctc.py b/wekws/bin/stream_kws_ctc.py index 1eb3aa3..869b880 100644 --- a/wekws/bin/stream_kws_ctc.py +++ b/wekws/bin/stream_kws_ctc.py @@ -15,9 +15,11 @@ from __future__ import print_function import argparse -import struct, wave +import struct +import wave import logging -import os, math +import os +import math import numpy as np import torchaudio.compliance.kaldi as kaldi @@ -183,13 +185,13 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size) class KeyWordSpotter(torch.nn.Module): def __init__(self, ckpt_path, config_path, token_path, lexicon_path, threshold, min_frames=5, max_frames=250, interval_frames=50, - score_beam = 3, path_beam = 20, + score_beam=3, path_beam=20, gpu=-1, is_jit_model=False,): super().__init__() os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) with open(config_path, 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) - dataset_conf=configs['dataset_conf'] + dataset_conf = configs['dataset_conf'] # feature related self.sample_rate = 16000 @@ -198,7 +200,7 @@ class KeyWordSpotter(torch.nn.Module): self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms self.frame_shift = 
dataset_conf['feature_extraction_conf']['frame_shift'] # in ms self.downsampling = dataset_conf.get('frame_skip', 1) - self.resolution = self.frame_shift/1000 * self.downsampling # in second + self.resolution = self.frame_shift / 1000 * self.downsampling # in second # fsmn splice operation self.context_expansion = dataset_conf.get('context_expansion', False) self.left_context = 0 @@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module): for i in range(0, len(wave), 2): value = struct.unpack(' self.right_context, "make sure each chunk feat length is large than right context." # pad feats with remained feature from last chunk - if self.feature_remained == None: # first chunk + if self.feature_remained is None: # first chunk # pad first frame at the beginning, replicate just support last dimension, so we do transpose. feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T else: - feats_pad = torch.cat((self.feature_remained, feats) ) + feats_pad = torch.cat((self.feature_remained, feats)) ctx_frm = feats.shape[0] - self.right_context ctx_win = (self.left_context + self.right_context + 1) ctx_dim = feats.shape[1] * ctx_win feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) for i in range(ctx_frm): - feats_ctx[i] = torch.cat(tuple(feats_pad[i: i+ctx_win])).unsqueeze(0) + feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0) # update feature remained, and feats self.feature_remained = feats[-self.left_context:] @@ -322,7 +325,7 @@ class KeyWordSpotter(torch.nn.Module): if self.downsampling > 1: feats = feats[self.feature_context_offset::self.downsampling, :] complement = feats.size(1) % self.downsampling - self.feature_context_offset = complement if complement==0 else self.downsampling-complement + self.feature_context_offset = complement if complement == 0 else self.downsampling-complement return feats def decode_keywords(self, t, probs): @@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module): def execute_detection(self, 
t): absolute_time = t + self.total_frames hit_keyword = None - start = 0; end = 0 + start = 0 + end = 0 # hyps for detection hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps] @@ -385,8 +389,8 @@ class KeyWordSpotter(torch.nn.Module): self.result = { "state": 1 if self.activated else 0, "keyword": hit_keyword if self.activated else None, - "start": start*self.resolution if self.activated else None, - "end": end*self.resolution if self.activated else None, + "start": start * self.resolution if self.activated else None, + "end": end * self.resolution if self.activated else None, "score": self.hit_score if self.activated else None } diff --git a/wekws/bin/stream_score_ctc.py b/wekws/bin/stream_score_ctc.py index 8ea692f..954f8eb 100644 --- a/wekws/bin/stream_score_ctc.py +++ b/wekws/bin/stream_score_ctc.py @@ -19,7 +19,9 @@ from __future__ import print_function import argparse import copy import logging -import os, sys, math +import os +import sys +import math import torch import yaml @@ -197,7 +199,8 @@ def main(): hit_keyword = None activated = False hit_score = 1.0 - start = 0; end = 0 + start = 0 + end = 0 # 2. 
CTC beam search step by step for t in range(0, maxlen): @@ -298,7 +301,7 @@ def main(): break duration = end - start - if hit_keyword is not None : + if hit_keyword is not None: if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames: activated = True fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score)) diff --git a/wekws/dataset/processor.py b/wekws/dataset/processor.py index f1d9073..ca8a9db 100644 --- a/wekws/dataset/processor.py +++ b/wekws/dataset/processor.py @@ -353,7 +353,7 @@ def padding(data): if isinstance(sample[0]['label'], int): padded_labels = torch.tensor([sample[i]['label'] for i in order], - dtype=torch.int32) + dtype=torch.int32) label_lengths = torch.tensor([1 for i in order], dtype=torch.int32) else: diff --git a/wekws/model/fsmn.py b/wekws/model/fsmn.py index c5f57be..903d456 100644 --- a/wekws/model/fsmn.py +++ b/wekws/model/fsmn.py @@ -221,19 +221,18 @@ class FSMNBlock(nn.Module): x = torch.unsqueeze(input, 1) x_per = x.permute(0, 3, 2, 1) - if in_cache is None or len(in_cache)==0 or in_cache[0]==None: - x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride+self.rorder*self.rstride, 0]) - else : + if in_cache is None or len(in_cache) == 0 or in_cache[0] is None: + x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0]) + else: in_cache = in_cache.to(x_per.device) x_pad = torch.cat((in_cache, x_per), dim=2) - in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride+self.rorder*self.rstride):, :] + in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :] y_left = x_pad[:, :, :-self.rorder * self.rstride, :] y_left = self.quant(y_left) y_left = self.conv_left(y_left) y_left = self.dequant(y_left) - out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left + out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left - #out = out[:, :, 
:-self.rorder*self.rstride, :] if self.conv_right is not None: # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :] diff --git a/wekws/model/kws_model.py b/wekws/model/kws_model.py index b67a8d3..2c89269 100644 --- a/wekws/model/kws_model.py +++ b/wekws/model/kws_model.py @@ -64,7 +64,7 @@ class KWSModel(nn.Module): def forward( self, x: torch.Tensor, - in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) ) -> Tuple[torch.Tensor, torch.Tensor]: if self.global_cmvn is not None: x = self.global_cmvn(x) diff --git a/wekws/model/loss.py b/wekws/model/loss.py index b272630..46004a4 100644 --- a/wekws/model/loss.py +++ b/wekws/model/loss.py @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch, math, sys +import torch +import math +import sys import torch.nn.functional as F from collections import defaultdict -from typing import List, Optional, Tuple +from typing import List, Tuple from wekws.utils.mask import padding_mask @@ -453,4 +455,4 @@ class Calculator: return result def keys(self): - return list(self.data.keys()) \ No newline at end of file + return list(self.data.keys()) diff --git a/wekws/utils/cmvn.py b/wekws/utils/cmvn.py index 4280679..b28fad3 100644 --- a/wekws/utils/cmvn.py +++ b/wekws/utils/cmvn.py @@ -14,7 +14,8 @@ # limitations under the License. import json -import math,re +import math +import re import numpy as np @@ -88,4 +89,4 @@ def load_kaldi_cmvn(cmvn_file): cmvn = np.array([means, variance]) cmvn = np.tile(cmvn, (1, copy_times)) - return cmvn \ No newline at end of file + return cmvn