fix typo.
This commit is contained in:
parent
909480da52
commit
6f8207267e
@ -9,17 +9,17 @@ Comparison among different backbones. FRRs with FAR fixed at once per hour:
|
||||
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
|
||||
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
|
||||
|
||||
Next, we use CTC loss to train the model, with DS_TCN and FSMN.
|
||||
and we use CTC prefix beam search to decode and detect keywords,
|
||||
the detection is either in non-streaming or streaming fashion.
|
||||
Next, we use CTC loss to train the model, with DS_TCN and FSMN.
|
||||
and we use CTC prefix beam search to decode and detect keywords,
|
||||
the detection is either in non-streaming or streaming fashion.
|
||||
|
||||
Since the FAR is pretty low when using CTC loss,
|
||||
Since the FAR is pretty low when using CTC loss,
|
||||
the follow result is FRRs with FAR fixed at once per 12 hours:
|
||||
|
||||
Comparison between Max-pooling and CTC loss.
|
||||
The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch).
|
||||
Comparison between Max-pooling and CTC loss.
|
||||
The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch).
|
||||
FRRs with FAR fixed at once per 12 hours
|
||||
|
||||
|
||||
|
||||
| model | loss | hi_xiaowen | nihao_wenwen |
|
||||
|-----------------------|-------------|------------|--------------|
|
||||
@ -27,8 +27,8 @@ FRRs with FAR fixed at once per 12 hours
|
||||
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 |
|
||||
|
||||
|
||||
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch)
|
||||
and FSMN(modelscope released, xiaoyunxiaoyun model).
|
||||
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch)
|
||||
and FSMN(modelscope released, xiaoyunxiaoyun model).
|
||||
FRRs with FAR fixed at once per 12 hours:
|
||||
|
||||
| model | params(K) | hi_xiaowen | nihao_wenwen |
|
||||
@ -36,8 +36,8 @@ FRRs with FAR fixed at once per 12 hours:
|
||||
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 |
|
||||
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 |
|
||||
|
||||
Comparison Between stream_score_ctc and score_ctc.
|
||||
FRRs with FAR fixed at once per 12 hours:
|
||||
Comparison Between stream_score_ctc and score_ctc.
|
||||
FRRs with FAR fixed at once per 12 hours:
|
||||
|
||||
| model | stream | hi_xiaowen | nihao_wenwen |
|
||||
|-----------------------|-------------|------------|--------------|
|
||||
@ -46,12 +46,12 @@ FRRs with FAR fixed at once per 12 hours:
|
||||
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
|
||||
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |
|
||||
|
||||
Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame),
|
||||
Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame),
|
||||
we record the probability of a keyword in a decoding path once the keyword appears in this path.
|
||||
Actually the probability will increase through the time, so we record a lower value of probability,
|
||||
which result in a higher False Rejection Rate in Detection Error Tradeoff result.
|
||||
The actual FRR will be lower than the DET curve gives in a given threshold.
|
||||
which result in a higher False Rejection Rate in Detection Error Tradeoff result.
|
||||
The actual FRR will be lower than the DET curve gives in a given threshold.
|
||||
|
||||
Now, the model with CTC loss may not get the best performance,
|
||||
but it's more robust compared with the classification model using CE/Max-pooling loss.
|
||||
Now, the model with CTC loss may not get the best performance,
|
||||
but it's more robust compared with the classification model using CE/Max-pooling loss.
|
||||
For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).
|
||||
@ -15,8 +15,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse, logging
|
||||
import json, re
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
|
||||
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
|
||||
|
||||
@ -189,7 +191,7 @@ if __name__ == '__main__':
|
||||
arr = line.strip().split(maxsplit=1)
|
||||
key = arr[0]
|
||||
tokens = None
|
||||
if token_table!=None and lexicon_table!=None :
|
||||
if token_table is not None and lexicon_table is not None :
|
||||
if len(arr) < 2: # for some utterence, no text
|
||||
txt = [1] # the <blank>/sil is indexed by 1
|
||||
tokens = ["sil"]
|
||||
@ -201,7 +203,7 @@ if __name__ == '__main__':
|
||||
wav = wav_table[key]
|
||||
assert key in duration_table
|
||||
duration = duration_table[key]
|
||||
if tokens == None:
|
||||
if tokens is None:
|
||||
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
||||
else:
|
||||
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)
|
||||
|
||||
@ -14,8 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse, logging, glob
|
||||
import json, re, os, numpy as np
|
||||
import argparse
|
||||
import logging
|
||||
import glob
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import pypinyin # for Chinese Character
|
||||
|
||||
@ -205,8 +210,7 @@ if __name__ == '__main__':
|
||||
num_false_reject = 0
|
||||
num_true_detect = 0
|
||||
# transverse the all keyword_table
|
||||
for key, confi in keyword_filler_table[keyword][
|
||||
'keyword_table'].items():
|
||||
for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
|
||||
if confi < threshold:
|
||||
num_false_reject += 1
|
||||
else:
|
||||
@ -234,4 +238,4 @@ if __name__ == '__main__':
|
||||
det_curve_path = args.det_curve_path
|
||||
else:
|
||||
det_curve_path = os.path.join(stats_dir, 'det.png')
|
||||
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)
|
||||
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)
|
||||
|
||||
@ -19,7 +19,9 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os, sys, math
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@ -138,7 +140,7 @@ def main():
|
||||
keywords_strset = {'<blk>'}
|
||||
keywords_tokenmap = {'<blk>': 0}
|
||||
for keyword in keywords_list:
|
||||
strs, indexes = query_token_set(keyword, token_table,lexicon_table)
|
||||
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
||||
keywords_token[keyword] = {}
|
||||
keywords_token[keyword]['token_id'] = indexes
|
||||
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
||||
@ -165,11 +167,11 @@ def main():
|
||||
for i in range(len(keys)):
|
||||
key = keys[i]
|
||||
score = logits[i][:lengths[i]]
|
||||
hyps = ctc_prefix_beam_search(score, lengths[i],
|
||||
keywords_idxset)
|
||||
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
|
||||
hit_keyword = None
|
||||
hit_score = 1.0
|
||||
start = 0; end = 0
|
||||
start = 0
|
||||
end = 0
|
||||
for one_hyp in hyps:
|
||||
prefix_ids = one_hyp[0]
|
||||
# path_score = one_hyp[1]
|
||||
@ -181,7 +183,7 @@ def main():
|
||||
if offset != -1:
|
||||
hit_keyword = word
|
||||
start = prefix_nodes[offset]['frame']
|
||||
end = prefix_nodes[offset+len(lab)-1]['frame']
|
||||
end = prefix_nodes[offset + len(lab) - 1]['frame']
|
||||
for idx in range(offset, offset + len(lab)):
|
||||
hit_score *= prefix_nodes[idx]['prob']
|
||||
break
|
||||
@ -193,7 +195,7 @@ def main():
|
||||
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score))
|
||||
logging.info(
|
||||
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
||||
f"duration {end-start}, score {hit_score}, Activated.")
|
||||
f"duration {end - start}, score {hit_score}, Activated.")
|
||||
else:
|
||||
fout.write('{} rejected\n'.format(key))
|
||||
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
||||
|
||||
@ -15,9 +15,11 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import struct, wave
|
||||
import struct
|
||||
import wave
|
||||
import logging
|
||||
import os, math
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
@ -183,13 +185,13 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size)
|
||||
class KeyWordSpotter(torch.nn.Module):
|
||||
def __init__(self, ckpt_path, config_path, token_path, lexicon_path,
|
||||
threshold, min_frames=5, max_frames=250, interval_frames=50,
|
||||
score_beam = 3, path_beam = 20,
|
||||
score_beam=3, path_beam=20,
|
||||
gpu=-1, is_jit_model=False,):
|
||||
super().__init__()
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
|
||||
with open(config_path, 'r') as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
dataset_conf=configs['dataset_conf']
|
||||
dataset_conf = configs['dataset_conf']
|
||||
|
||||
# feature related
|
||||
self.sample_rate = 16000
|
||||
@ -198,7 +200,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms
|
||||
self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms
|
||||
self.downsampling = dataset_conf.get('frame_skip', 1)
|
||||
self.resolution = self.frame_shift/1000 * self.downsampling # in second
|
||||
self.resolution = self.frame_shift / 1000 * self.downsampling # in second
|
||||
# fsmn splice operation
|
||||
self.context_expansion = dataset_conf.get('context_expansion', False)
|
||||
self.left_context = 0
|
||||
@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
for i in range(0, len(wave), 2):
|
||||
value = struct.unpack('<h', wave[i:i + 2])[0]
|
||||
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
|
||||
|
||||
wave = np.array(data)
|
||||
wave = np.append(self.wave_remained, wave)
|
||||
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
||||
@ -303,18 +306,18 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
if self.context_expansion:
|
||||
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
|
||||
# pad feats with remained feature from last chunk
|
||||
if self.feature_remained == None: # first chunk
|
||||
if self.feature_remained is None: # first chunk
|
||||
# pad first frame at the beginning, replicate just support last dimension, so we do transpose.
|
||||
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
|
||||
else:
|
||||
feats_pad = torch.cat((self.feature_remained, feats) )
|
||||
feats_pad = torch.cat((self.feature_remained, feats))
|
||||
|
||||
ctx_frm = feats.shape[0] - self.right_context
|
||||
ctx_win = (self.left_context + self.right_context + 1)
|
||||
ctx_dim = feats.shape[1] * ctx_win
|
||||
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
|
||||
for i in range(ctx_frm):
|
||||
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i+ctx_win])).unsqueeze(0)
|
||||
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
|
||||
|
||||
# update feature remained, and feats
|
||||
self.feature_remained = feats[-self.left_context:]
|
||||
@ -322,7 +325,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
if self.downsampling > 1:
|
||||
feats = feats[self.feature_context_offset::self.downsampling, :]
|
||||
complement = feats.size(1) % self.downsampling
|
||||
self.feature_context_offset = complement if complement==0 else self.downsampling-complement
|
||||
self.feature_context_offset = complement if complement == 0 else self.downsampling-complement
|
||||
return feats
|
||||
|
||||
def decode_keywords(self, t, probs):
|
||||
@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
def execute_detection(self, t):
|
||||
absolute_time = t + self.total_frames
|
||||
hit_keyword = None
|
||||
start = 0; end = 0
|
||||
start = 0
|
||||
end = 0
|
||||
|
||||
# hyps for detection
|
||||
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
|
||||
@ -385,8 +389,8 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
self.result = {
|
||||
"state": 1 if self.activated else 0,
|
||||
"keyword": hit_keyword if self.activated else None,
|
||||
"start": start*self.resolution if self.activated else None,
|
||||
"end": end*self.resolution if self.activated else None,
|
||||
"start": start * self.resolution if self.activated else None,
|
||||
"end": end * self.resolution if self.activated else None,
|
||||
"score": self.hit_score if self.activated else None
|
||||
}
|
||||
|
||||
|
||||
@ -19,7 +19,9 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os, sys, math
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@ -197,7 +199,8 @@ def main():
|
||||
hit_keyword = None
|
||||
activated = False
|
||||
hit_score = 1.0
|
||||
start = 0; end = 0
|
||||
start = 0
|
||||
end = 0
|
||||
|
||||
# 2. CTC beam search step by step
|
||||
for t in range(0, maxlen):
|
||||
@ -298,7 +301,7 @@ def main():
|
||||
break
|
||||
|
||||
duration = end - start
|
||||
if hit_keyword is not None :
|
||||
if hit_keyword is not None:
|
||||
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames:
|
||||
activated = True
|
||||
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score))
|
||||
|
||||
@ -353,7 +353,7 @@ def padding(data):
|
||||
|
||||
if isinstance(sample[0]['label'], int):
|
||||
padded_labels = torch.tensor([sample[i]['label'] for i in order],
|
||||
dtype=torch.int32)
|
||||
dtype=torch.int32)
|
||||
label_lengths = torch.tensor([1 for i in order],
|
||||
dtype=torch.int32)
|
||||
else:
|
||||
|
||||
@ -221,19 +221,18 @@ class FSMNBlock(nn.Module):
|
||||
x = torch.unsqueeze(input, 1)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
|
||||
if in_cache is None or len(in_cache)==0 or in_cache[0]==None:
|
||||
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride+self.rorder*self.rstride, 0])
|
||||
else :
|
||||
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
|
||||
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
|
||||
else:
|
||||
in_cache = in_cache.to(x_per.device)
|
||||
x_pad = torch.cat((in_cache, x_per), dim=2)
|
||||
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride+self.rorder*self.rstride):, :]
|
||||
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :]
|
||||
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
|
||||
y_left = self.quant(y_left)
|
||||
y_left = self.conv_left(y_left)
|
||||
y_left = self.dequant(y_left)
|
||||
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
|
||||
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
|
||||
|
||||
#out = out[:, :, :-self.rorder*self.rstride, :]
|
||||
if self.conv_right is not None:
|
||||
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
||||
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]
|
||||
|
||||
@ -64,7 +64,7 @@ class KWSModel(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.global_cmvn is not None:
|
||||
x = self.global_cmvn(x)
|
||||
|
||||
@ -13,10 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch, math, sys
|
||||
import torch
|
||||
import math
|
||||
import sys
|
||||
import torch.nn.functional as F
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
from wekws.utils.mask import padding_mask
|
||||
|
||||
@ -453,4 +455,4 @@ class Calculator:
|
||||
return result
|
||||
|
||||
def keys(self):
|
||||
return list(self.data.keys())
|
||||
return list(self.data.keys())
|
||||
|
||||
@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import math,re
|
||||
import math
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -88,4 +89,4 @@ def load_kaldi_cmvn(cmvn_file):
|
||||
cmvn = np.array([means, variance])
|
||||
cmvn = np.tile(cmvn, (1, copy_times))
|
||||
|
||||
return cmvn
|
||||
return cmvn
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user