fix typo.

This commit is contained in:
dujing 2023-06-16 15:53:39 +08:00
parent 909480da52
commit 6f8207267e
11 changed files with 77 additions and 60 deletions

View File

@ -15,8 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse, logging
import json, re
import argparse
import logging
import json
import re
symbol_str = '[!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
@ -189,7 +191,7 @@ if __name__ == '__main__':
arr = line.strip().split(maxsplit=1)
key = arr[0]
tokens = None
if token_table!=None and lexicon_table!=None :
if token_table is not None and lexicon_table is not None :
if len(arr) < 2: # for some utterances, there is no text
txt = [1] # the <blank>/sil is indexed by 1
tokens = ["sil"]
@ -201,7 +203,7 @@ if __name__ == '__main__':
wav = wav_table[key]
assert key in duration_table
duration = duration_table[key]
if tokens == None:
if tokens is None:
line = dict(key=key, txt=txt, duration=duration, wav=wav)
else:
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)

View File

@ -14,8 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse, logging, glob
import json, re, os, numpy as np
import argparse
import logging
import glob
import json
import re
import os
import numpy as np
import matplotlib.pyplot as plt
import pypinyin # for Chinese Character
@ -205,8 +210,7 @@ if __name__ == '__main__':
num_false_reject = 0
num_true_detect = 0
# traverse the whole keyword_table
for key, confi in keyword_filler_table[keyword][
'keyword_table'].items():
for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
if confi < threshold:
num_false_reject += 1
else:

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse
import copy
import logging
import os, sys, math
import os
import sys
import math
import torch
import yaml
@ -138,7 +140,7 @@ def main():
keywords_strset = {'<blk>'}
keywords_tokenmap = {'<blk>': 0}
for keyword in keywords_list:
strs, indexes = query_token_set(keyword, token_table,lexicon_table)
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = indexes
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
@ -165,11 +167,11 @@ def main():
for i in range(len(keys)):
key = keys[i]
score = logits[i][:lengths[i]]
hyps = ctc_prefix_beam_search(score, lengths[i],
keywords_idxset)
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
hit_keyword = None
hit_score = 1.0
start = 0; end = 0
start = 0
end = 0
for one_hyp in hyps:
prefix_ids = one_hyp[0]
# path_score = one_hyp[1]
@ -181,7 +183,7 @@ def main():
if offset != -1:
hit_keyword = word
start = prefix_nodes[offset]['frame']
end = prefix_nodes[offset+len(lab)-1]['frame']
end = prefix_nodes[offset + len(lab) - 1]['frame']
for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob']
break
@ -193,7 +195,7 @@ def main():
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score))
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
f"duration {end-start}, score {hit_score}, Activated.")
f"duration {end - start}, score {hit_score}, Activated.")
else:
fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")

View File

@ -15,9 +15,11 @@
from __future__ import print_function
import argparse
import struct, wave
import struct
import wave
import logging
import os, math
import os
import math
import numpy as np
import torchaudio.compliance.kaldi as kaldi
@ -183,13 +185,13 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size)
class KeyWordSpotter(torch.nn.Module):
def __init__(self, ckpt_path, config_path, token_path, lexicon_path,
threshold, min_frames=5, max_frames=250, interval_frames=50,
score_beam = 3, path_beam = 20,
score_beam=3, path_beam=20,
gpu=-1, is_jit_model=False,):
super().__init__()
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
with open(config_path, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
dataset_conf=configs['dataset_conf']
dataset_conf = configs['dataset_conf']
# feature related
self.sample_rate = 16000
@ -198,7 +200,7 @@ class KeyWordSpotter(torch.nn.Module):
self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms
self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms
self.downsampling = dataset_conf.get('frame_skip', 1)
self.resolution = self.frame_shift/1000 * self.downsampling # in second
self.resolution = self.frame_shift / 1000 * self.downsampling # in second
# fsmn splice operation
self.context_expansion = dataset_conf.get('context_expansion', False)
self.left_context = 0
@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module):
for i in range(0, len(wave), 2):
value = struct.unpack('<h', wave[i:i + 2])[0]
data.append(value) # we don't divide by 32768.0 here, because kaldi.fbank accepts the original input
wave = np.array(data)
wave = np.append(self.wave_remained, wave)
wave_tensor = torch.from_numpy(wave).float().to(self.device)
@ -303,18 +306,18 @@ class KeyWordSpotter(torch.nn.Module):
if self.context_expansion:
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
# pad feats with remained feature from last chunk
if self.feature_remained == None: # first chunk
if self.feature_remained is None: # first chunk
# pad with the first frame at the beginning; replicate mode only supports the last dimension, so we transpose.
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
else:
feats_pad = torch.cat((self.feature_remained, feats) )
feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats.shape[0] - self.right_context
ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
for i in range(ctx_frm):
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i+ctx_win])).unsqueeze(0)
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
# update feature remained, and feats
self.feature_remained = feats[-self.left_context:]
@ -322,7 +325,7 @@ class KeyWordSpotter(torch.nn.Module):
if self.downsampling > 1:
feats = feats[self.feature_context_offset::self.downsampling, :]
complement = feats.size(1) % self.downsampling
self.feature_context_offset = complement if complement==0 else self.downsampling-complement
self.feature_context_offset = complement if complement == 0 else self.downsampling-complement
return feats
def decode_keywords(self, t, probs):
@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module):
def execute_detection(self, t):
absolute_time = t + self.total_frames
hit_keyword = None
start = 0; end = 0
start = 0
end = 0
# hyps for detection
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
@ -385,8 +389,8 @@ class KeyWordSpotter(torch.nn.Module):
self.result = {
"state": 1 if self.activated else 0,
"keyword": hit_keyword if self.activated else None,
"start": start*self.resolution if self.activated else None,
"end": end*self.resolution if self.activated else None,
"start": start * self.resolution if self.activated else None,
"end": end * self.resolution if self.activated else None,
"score": self.hit_score if self.activated else None
}

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse
import copy
import logging
import os, sys, math
import os
import sys
import math
import torch
import yaml
@ -197,7 +199,8 @@ def main():
hit_keyword = None
activated = False
hit_score = 1.0
start = 0; end = 0
start = 0
end = 0
# 2. CTC beam search step by step
for t in range(0, maxlen):
@ -298,7 +301,7 @@ def main():
break
duration = end - start
if hit_keyword is not None :
if hit_keyword is not None:
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames:
activated = True
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score))

View File

@ -353,7 +353,7 @@ def padding(data):
if isinstance(sample[0]['label'], int):
padded_labels = torch.tensor([sample[i]['label'] for i in order],
dtype=torch.int32)
dtype=torch.int32)
label_lengths = torch.tensor([1 for i in order],
dtype=torch.int32)
else:

View File

@ -221,19 +221,18 @@ class FSMNBlock(nn.Module):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1)
if in_cache is None or len(in_cache)==0 or in_cache[0]==None:
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride+self.rorder*self.rstride, 0])
else :
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
else:
in_cache = in_cache.to(x_per.device)
x_pad = torch.cat((in_cache, x_per), dim=2)
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride+self.rorder*self.rstride):, :]
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :]
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
y_left = self.quant(y_left)
y_left = self.conv_left(y_left)
y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
#out = out[:, :, :-self.rorder*self.rstride, :]
if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]

View File

@ -64,7 +64,7 @@ class KWSModel(nn.Module):
def forward(
self,
x: torch.Tensor,
in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None:
x = self.global_cmvn(x)

View File

@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch, math, sys
import torch
import math
import sys
import torch.nn.functional as F
from collections import defaultdict
from typing import List, Optional, Tuple
from typing import List, Tuple
from wekws.utils.mask import padding_mask

View File

@ -14,7 +14,8 @@
# limitations under the License.
import json
import math,re
import math
import re
import numpy as np