fix typo.

This commit is contained in:
dujing 2023-06-16 15:53:39 +08:00
parent 909480da52
commit 6f8207267e
11 changed files with 77 additions and 60 deletions

View File

@ -15,8 +15,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse, logging import argparse
import json, re import logging
import json
import re
symbol_str = '[!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+' symbol_str = '[!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
@ -189,7 +191,7 @@ if __name__ == '__main__':
arr = line.strip().split(maxsplit=1) arr = line.strip().split(maxsplit=1)
key = arr[0] key = arr[0]
tokens = None tokens = None
if token_table!=None and lexicon_table!=None : if token_table is not None and lexicon_table is not None :
if len(arr) < 2: # for some utterance, no text if len(arr) < 2: # for some utterance, no text
txt = [1] # the <blank>/sil is indexed by 1 txt = [1] # the <blank>/sil is indexed by 1
tokens = ["sil"] tokens = ["sil"]
@ -201,7 +203,7 @@ if __name__ == '__main__':
wav = wav_table[key] wav = wav_table[key]
assert key in duration_table assert key in duration_table
duration = duration_table[key] duration = duration_table[key]
if tokens == None: if tokens is None:
line = dict(key=key, txt=txt, duration=duration, wav=wav) line = dict(key=key, txt=txt, duration=duration, wav=wav)
else: else:
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav) line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)

View File

@ -14,8 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse, logging, glob import argparse
import json, re, os, numpy as np import logging
import glob
import json
import re
import os
import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pypinyin # for Chinese Character import pypinyin # for Chinese Character
@ -205,8 +210,7 @@ if __name__ == '__main__':
num_false_reject = 0 num_false_reject = 0
num_true_detect = 0 num_true_detect = 0
# traverse the whole keyword_table # traverse the whole keyword_table
for key, confi in keyword_filler_table[keyword][ for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
'keyword_table'].items():
if confi < threshold: if confi < threshold:
num_false_reject += 1 num_false_reject += 1
else: else:

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse import argparse
import copy import copy
import logging import logging
import os, sys, math import os
import sys
import math
import torch import torch
import yaml import yaml
@ -138,7 +140,7 @@ def main():
keywords_strset = {'<blk>'} keywords_strset = {'<blk>'}
keywords_tokenmap = {'<blk>': 0} keywords_tokenmap = {'<blk>': 0}
for keyword in keywords_list: for keyword in keywords_list:
strs, indexes = query_token_set(keyword, token_table,lexicon_table) strs, indexes = query_token_set(keyword, token_table, lexicon_table)
keywords_token[keyword] = {} keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = indexes keywords_token[keyword]['token_id'] = indexes
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
@ -165,11 +167,11 @@ def main():
for i in range(len(keys)): for i in range(len(keys)):
key = keys[i] key = keys[i]
score = logits[i][:lengths[i]] score = logits[i][:lengths[i]]
hyps = ctc_prefix_beam_search(score, lengths[i], hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
keywords_idxset)
hit_keyword = None hit_keyword = None
hit_score = 1.0 hit_score = 1.0
start = 0; end = 0 start = 0
end = 0
for one_hyp in hyps: for one_hyp in hyps:
prefix_ids = one_hyp[0] prefix_ids = one_hyp[0]
# path_score = one_hyp[1] # path_score = one_hyp[1]
@ -181,7 +183,7 @@ def main():
if offset != -1: if offset != -1:
hit_keyword = word hit_keyword = word
start = prefix_nodes[offset]['frame'] start = prefix_nodes[offset]['frame']
end = prefix_nodes[offset+len(lab)-1]['frame'] end = prefix_nodes[offset + len(lab) - 1]['frame']
for idx in range(offset, offset + len(lab)): for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob'] hit_score *= prefix_nodes[idx]['prob']
break break
@ -193,7 +195,7 @@ def main():
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score)) fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score))
logging.info( logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. " f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
f"duration {end-start}, score {hit_score}, Activated.") f"duration {end - start}, score {hit_score}, Activated.")
else: else:
fout.write('{} rejected\n'.format(key)) fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.") logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")

View File

@ -15,9 +15,11 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import struct, wave import struct
import wave
import logging import logging
import os, math import os
import math
import numpy as np import numpy as np
import torchaudio.compliance.kaldi as kaldi import torchaudio.compliance.kaldi as kaldi
@ -183,13 +185,13 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size)
class KeyWordSpotter(torch.nn.Module): class KeyWordSpotter(torch.nn.Module):
def __init__(self, ckpt_path, config_path, token_path, lexicon_path, def __init__(self, ckpt_path, config_path, token_path, lexicon_path,
threshold, min_frames=5, max_frames=250, interval_frames=50, threshold, min_frames=5, max_frames=250, interval_frames=50,
score_beam = 3, path_beam = 20, score_beam=3, path_beam=20,
gpu=-1, is_jit_model=False,): gpu=-1, is_jit_model=False,):
super().__init__() super().__init__()
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
with open(config_path, 'r') as fin: with open(config_path, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader) configs = yaml.load(fin, Loader=yaml.FullLoader)
dataset_conf=configs['dataset_conf'] dataset_conf = configs['dataset_conf']
# feature related # feature related
self.sample_rate = 16000 self.sample_rate = 16000
@ -198,7 +200,7 @@ class KeyWordSpotter(torch.nn.Module):
self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms
self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms
self.downsampling = dataset_conf.get('frame_skip', 1) self.downsampling = dataset_conf.get('frame_skip', 1)
self.resolution = self.frame_shift/1000 * self.downsampling # in second self.resolution = self.frame_shift / 1000 * self.downsampling # in second
# fsmn splice operation # fsmn splice operation
self.context_expansion = dataset_conf.get('context_expansion', False) self.context_expansion = dataset_conf.get('context_expansion', False)
self.left_context = 0 self.left_context = 0
@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module):
for i in range(0, len(wave), 2): for i in range(0, len(wave), 2):
value = struct.unpack('<h', wave[i:i + 2])[0] value = struct.unpack('<h', wave[i:i + 2])[0]
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
wave = np.array(data) wave = np.array(data)
wave = np.append(self.wave_remained, wave) wave = np.append(self.wave_remained, wave)
wave_tensor = torch.from_numpy(wave).float().to(self.device) wave_tensor = torch.from_numpy(wave).float().to(self.device)
@ -303,18 +306,18 @@ class KeyWordSpotter(torch.nn.Module):
if self.context_expansion: if self.context_expansion:
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context." assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
# pad feats with remained feature from last chunk # pad feats with remained feature from last chunk
if self.feature_remained == None: # first chunk if self.feature_remained is None: # first chunk
# pad the first frame at the beginning; replicate mode only supports the last dimension, so we transpose. # pad the first frame at the beginning; replicate mode only supports the last dimension, so we transpose.
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
else: else:
feats_pad = torch.cat((self.feature_remained, feats) ) feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats.shape[0] - self.right_context ctx_frm = feats.shape[0] - self.right_context
ctx_win = (self.left_context + self.right_context + 1) ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
for i in range(ctx_frm): for i in range(ctx_frm):
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i+ctx_win])).unsqueeze(0) feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
# update feature remained, and feats # update feature remained, and feats
self.feature_remained = feats[-self.left_context:] self.feature_remained = feats[-self.left_context:]
@ -322,7 +325,7 @@ class KeyWordSpotter(torch.nn.Module):
if self.downsampling > 1: if self.downsampling > 1:
feats = feats[self.feature_context_offset::self.downsampling, :] feats = feats[self.feature_context_offset::self.downsampling, :]
complement = feats.size(1) % self.downsampling complement = feats.size(1) % self.downsampling
self.feature_context_offset = complement if complement==0 else self.downsampling-complement self.feature_context_offset = complement if complement == 0 else self.downsampling-complement
return feats return feats
def decode_keywords(self, t, probs): def decode_keywords(self, t, probs):
@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module):
def execute_detection(self, t): def execute_detection(self, t):
absolute_time = t + self.total_frames absolute_time = t + self.total_frames
hit_keyword = None hit_keyword = None
start = 0; end = 0 start = 0
end = 0
# hyps for detection # hyps for detection
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps] hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
@ -385,8 +389,8 @@ class KeyWordSpotter(torch.nn.Module):
self.result = { self.result = {
"state": 1 if self.activated else 0, "state": 1 if self.activated else 0,
"keyword": hit_keyword if self.activated else None, "keyword": hit_keyword if self.activated else None,
"start": start*self.resolution if self.activated else None, "start": start * self.resolution if self.activated else None,
"end": end*self.resolution if self.activated else None, "end": end * self.resolution if self.activated else None,
"score": self.hit_score if self.activated else None "score": self.hit_score if self.activated else None
} }

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse import argparse
import copy import copy
import logging import logging
import os, sys, math import os
import sys
import math
import torch import torch
import yaml import yaml
@ -197,7 +199,8 @@ def main():
hit_keyword = None hit_keyword = None
activated = False activated = False
hit_score = 1.0 hit_score = 1.0
start = 0; end = 0 start = 0
end = 0
# 2. CTC beam search step by step # 2. CTC beam search step by step
for t in range(0, maxlen): for t in range(0, maxlen):
@ -298,7 +301,7 @@ def main():
break break
duration = end - start duration = end - start
if hit_keyword is not None : if hit_keyword is not None:
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames: if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames:
activated = True activated = True
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score)) fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score))

View File

@ -353,7 +353,7 @@ def padding(data):
if isinstance(sample[0]['label'], int): if isinstance(sample[0]['label'], int):
padded_labels = torch.tensor([sample[i]['label'] for i in order], padded_labels = torch.tensor([sample[i]['label'] for i in order],
dtype=torch.int32) dtype=torch.int32)
label_lengths = torch.tensor([1 for i in order], label_lengths = torch.tensor([1 for i in order],
dtype=torch.int32) dtype=torch.int32)
else: else:

View File

@ -221,19 +221,18 @@ class FSMNBlock(nn.Module):
x = torch.unsqueeze(input, 1) x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) x_per = x.permute(0, 3, 2, 1)
if in_cache is None or len(in_cache)==0 or in_cache[0]==None: if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride+self.rorder*self.rstride, 0]) x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
else : else:
in_cache = in_cache.to(x_per.device) in_cache = in_cache.to(x_per.device)
x_pad = torch.cat((in_cache, x_per), dim=2) x_pad = torch.cat((in_cache, x_per), dim=2)
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride+self.rorder*self.rstride):, :] in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :]
y_left = x_pad[:, :, :-self.rorder * self.rstride, :] y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
y_left = self.quant(y_left) y_left = self.quant(y_left)
y_left = self.conv_left(y_left) y_left = self.conv_left(y_left)
y_left = self.dequant(y_left) y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
#out = out[:, :, :-self.rorder*self.rstride, :]
if self.conv_right is not None: if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :] y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]

View File

@ -64,7 +64,7 @@ class KWSModel(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None: if self.global_cmvn is not None:
x = self.global_cmvn(x) x = self.global_cmvn(x)

View File

@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch, math, sys import torch
import math
import sys
import torch.nn.functional as F import torch.nn.functional as F
from collections import defaultdict from collections import defaultdict
from typing import List, Optional, Tuple from typing import List, Tuple
from wekws.utils.mask import padding_mask from wekws.utils.mask import padding_mask

View File

@ -14,7 +14,8 @@
# limitations under the License. # limitations under the License.
import json import json
import math,re import math
import re
import numpy as np import numpy as np