fix typo.

This commit is contained in:
dujing 2023-06-16 15:53:39 +08:00
parent 909480da52
commit 6f8207267e
11 changed files with 77 additions and 60 deletions

View File

@ -9,17 +9,17 @@ Comparison among different backbones. FRRs with FAR fixed at once per hour:
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
Next, we use CTC loss to train the model, with DS_TCN and FSMN.
and we use CTC prefix beam search to decode and detect keywords,
the detection is either in non-streaming or streaming fashion.
Next, we use CTC loss to train the model, with DS_TCN and FSMN,
and we use CTC prefix beam search to decode and detect keywords;
the detection is done in either non-streaming or streaming fashion.
Since the FAR is pretty low when using CTC loss,
Since the FAR is pretty low when using CTC loss,
the following results are FRRs with FAR fixed at once per 12 hours:
Comparison between Max-pooling and CTC loss.
The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch).
Comparison between Max-pooling and CTC loss.
The CTC model is fine-tuned from a base model trained on WenetSpeech (23 epochs).
FRRs with FAR fixed at once per 12 hours
| model | loss | hi_xiaowen | nihao_wenwen |
|-----------------------|-------------|------------|--------------|
@ -27,8 +27,8 @@ FRRs with FAR fixed at once per 12 hours
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 |
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch)
and FSMN(modelscope released, xiaoyunxiaoyun model).
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch)
and FSMN(modelscope released, xiaoyunxiaoyun model).
FRRs with FAR fixed at once per 12 hours:
| model | params(K) | hi_xiaowen | nihao_wenwen |
@ -36,8 +36,8 @@ FRRs with FAR fixed at once per 12 hours:
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 |
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 |
Comparison Between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours:
Comparison Between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours:
| model | stream | hi_xiaowen | nihao_wenwen |
|-----------------------|-------------|------------|--------------|
@ -46,12 +46,12 @@ FRRs with FAR fixed at once per 12 hours:
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |
Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame),
Note: when using CTC prefix beam search to detect keywords in the streaming case (detect in each frame),
we record the probability of a keyword in a decoding path once the keyword appears in this path.
Actually, the probability will increase over time, so we record a lower value of the probability,
which result in a higher False Rejection Rate in Detection Error Tradeoff result.
The actual FRR will be lower than the DET curve gives in a given threshold.
which results in a higher False Rejection Rate in the Detection Error Tradeoff result.
The actual FRR will be lower than what the DET curve gives at a given threshold.
Now, the model with CTC loss may not get the best performance,
but it's more robust compared with the classification model using CE/Max-pooling loss.
Now, the model with CTC loss may not get the best performance,
but it's more robust compared with the classification model using CE/Max-pooling loss.
For more results of the FSMN-CTC KWS model, see [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).

View File

@ -15,8 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse, logging
import json, re
import argparse
import logging
import json
import re
symbol_str = '[!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
@ -189,7 +191,7 @@ if __name__ == '__main__':
arr = line.strip().split(maxsplit=1)
key = arr[0]
tokens = None
if token_table!=None and lexicon_table!=None :
if token_table is not None and lexicon_table is not None :
if len(arr) < 2: # for some utterence, no text
txt = [1] # the <blank>/sil is indexed by 1
tokens = ["sil"]
@ -201,7 +203,7 @@ if __name__ == '__main__':
wav = wav_table[key]
assert key in duration_table
duration = duration_table[key]
if tokens == None:
if tokens is None:
line = dict(key=key, txt=txt, duration=duration, wav=wav)
else:
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)

View File

@ -14,8 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse, logging, glob
import json, re, os, numpy as np
import argparse
import logging
import glob
import json
import re
import os
import numpy as np
import matplotlib.pyplot as plt
import pypinyin # for Chinese Character
@ -205,8 +210,7 @@ if __name__ == '__main__':
num_false_reject = 0
num_true_detect = 0
# transverse the all keyword_table
for key, confi in keyword_filler_table[keyword][
'keyword_table'].items():
for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
if confi < threshold:
num_false_reject += 1
else:
@ -234,4 +238,4 @@ if __name__ == '__main__':
det_curve_path = args.det_curve_path
else:
det_curve_path = os.path.join(stats_dir, 'det.png')
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse
import copy
import logging
import os, sys, math
import os
import sys
import math
import torch
import yaml
@ -138,7 +140,7 @@ def main():
keywords_strset = {'<blk>'}
keywords_tokenmap = {'<blk>': 0}
for keyword in keywords_list:
strs, indexes = query_token_set(keyword, token_table,lexicon_table)
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = indexes
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
@ -165,11 +167,11 @@ def main():
for i in range(len(keys)):
key = keys[i]
score = logits[i][:lengths[i]]
hyps = ctc_prefix_beam_search(score, lengths[i],
keywords_idxset)
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
hit_keyword = None
hit_score = 1.0
start = 0; end = 0
start = 0
end = 0
for one_hyp in hyps:
prefix_ids = one_hyp[0]
# path_score = one_hyp[1]
@ -181,7 +183,7 @@ def main():
if offset != -1:
hit_keyword = word
start = prefix_nodes[offset]['frame']
end = prefix_nodes[offset+len(lab)-1]['frame']
end = prefix_nodes[offset + len(lab) - 1]['frame']
for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob']
break
@ -193,7 +195,7 @@ def main():
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score))
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
f"duration {end-start}, score {hit_score}, Activated.")
f"duration {end - start}, score {hit_score}, Activated.")
else:
fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")

View File

@ -15,9 +15,11 @@
from __future__ import print_function
import argparse
import struct, wave
import struct
import wave
import logging
import os, math
import os
import math
import numpy as np
import torchaudio.compliance.kaldi as kaldi
@ -183,13 +185,13 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size)
class KeyWordSpotter(torch.nn.Module):
def __init__(self, ckpt_path, config_path, token_path, lexicon_path,
threshold, min_frames=5, max_frames=250, interval_frames=50,
score_beam = 3, path_beam = 20,
score_beam=3, path_beam=20,
gpu=-1, is_jit_model=False,):
super().__init__()
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
with open(config_path, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
dataset_conf=configs['dataset_conf']
dataset_conf = configs['dataset_conf']
# feature related
self.sample_rate = 16000
@ -198,7 +200,7 @@ class KeyWordSpotter(torch.nn.Module):
self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms
self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms
self.downsampling = dataset_conf.get('frame_skip', 1)
self.resolution = self.frame_shift/1000 * self.downsampling # in second
self.resolution = self.frame_shift / 1000 * self.downsampling # in second
# fsmn splice operation
self.context_expansion = dataset_conf.get('context_expansion', False)
self.left_context = 0
@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module):
for i in range(0, len(wave), 2):
value = struct.unpack('<h', wave[i:i + 2])[0]
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
wave = np.array(data)
wave = np.append(self.wave_remained, wave)
wave_tensor = torch.from_numpy(wave).float().to(self.device)
@ -303,18 +306,18 @@ class KeyWordSpotter(torch.nn.Module):
if self.context_expansion:
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
# pad feats with remained feature from last chunk
if self.feature_remained == None: # first chunk
if self.feature_remained is None: # first chunk
# pad first frame at the beginning, replicate just support last dimension, so we do transpose.
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
else:
feats_pad = torch.cat((self.feature_remained, feats) )
feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats.shape[0] - self.right_context
ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
for i in range(ctx_frm):
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i+ctx_win])).unsqueeze(0)
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
# update feature remained, and feats
self.feature_remained = feats[-self.left_context:]
@ -322,7 +325,7 @@ class KeyWordSpotter(torch.nn.Module):
if self.downsampling > 1:
feats = feats[self.feature_context_offset::self.downsampling, :]
complement = feats.size(1) % self.downsampling
self.feature_context_offset = complement if complement==0 else self.downsampling-complement
self.feature_context_offset = complement if complement == 0 else self.downsampling-complement
return feats
def decode_keywords(self, t, probs):
@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module):
def execute_detection(self, t):
absolute_time = t + self.total_frames
hit_keyword = None
start = 0; end = 0
start = 0
end = 0
# hyps for detection
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
@ -385,8 +389,8 @@ class KeyWordSpotter(torch.nn.Module):
self.result = {
"state": 1 if self.activated else 0,
"keyword": hit_keyword if self.activated else None,
"start": start*self.resolution if self.activated else None,
"end": end*self.resolution if self.activated else None,
"start": start * self.resolution if self.activated else None,
"end": end * self.resolution if self.activated else None,
"score": self.hit_score if self.activated else None
}

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse
import copy
import logging
import os, sys, math
import os
import sys
import math
import torch
import yaml
@ -197,7 +199,8 @@ def main():
hit_keyword = None
activated = False
hit_score = 1.0
start = 0; end = 0
start = 0
end = 0
# 2. CTC beam search step by step
for t in range(0, maxlen):
@ -298,7 +301,7 @@ def main():
break
duration = end - start
if hit_keyword is not None :
if hit_keyword is not None:
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames:
activated = True
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score))

View File

@ -353,7 +353,7 @@ def padding(data):
if isinstance(sample[0]['label'], int):
padded_labels = torch.tensor([sample[i]['label'] for i in order],
dtype=torch.int32)
dtype=torch.int32)
label_lengths = torch.tensor([1 for i in order],
dtype=torch.int32)
else:

View File

@ -221,19 +221,18 @@ class FSMNBlock(nn.Module):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1)
if in_cache is None or len(in_cache)==0 or in_cache[0]==None:
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride+self.rorder*self.rstride, 0])
else :
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
else:
in_cache = in_cache.to(x_per.device)
x_pad = torch.cat((in_cache, x_per), dim=2)
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride+self.rorder*self.rstride):, :]
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :]
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
y_left = self.quant(y_left)
y_left = self.conv_left(y_left)
y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
#out = out[:, :, :-self.rorder*self.rstride, :]
if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]

View File

@ -64,7 +64,7 @@ class KWSModel(nn.Module):
def forward(
self,
x: torch.Tensor,
in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None:
x = self.global_cmvn(x)

View File

@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch, math, sys
import torch
import math
import sys
import torch.nn.functional as F
from collections import defaultdict
from typing import List, Optional, Tuple
from typing import List, Tuple
from wekws.utils.mask import padding_mask
@ -453,4 +455,4 @@ class Calculator:
return result
def keys(self):
return list(self.data.keys())
return list(self.data.keys())

View File

@ -14,7 +14,8 @@
# limitations under the License.
import json
import math,re
import math
import re
import numpy as np
@ -88,4 +89,4 @@ def load_kaldi_cmvn(cmvn_file):
cmvn = np.array([means, variance])
cmvn = np.tile(cmvn, (1, copy_times))
return cmvn
return cmvn