fix typo.

This commit is contained in:
dujing 2023-06-16 15:53:39 +08:00
parent 909480da52
commit 6f8207267e
11 changed files with 77 additions and 60 deletions

View File

@ -15,8 +15,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse, logging import argparse
import json, re import logging
import json
import re
symbol_str = '[!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+' symbol_str = '[!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
@ -189,7 +191,7 @@ if __name__ == '__main__':
arr = line.strip().split(maxsplit=1) arr = line.strip().split(maxsplit=1)
key = arr[0] key = arr[0]
tokens = None tokens = None
if token_table!=None and lexicon_table!=None : if token_table is not None and lexicon_table is not None :
if len(arr) < 2: # for some utterances, no text if len(arr) < 2: # for some utterances, no text
txt = [1] # the <blank>/sil is indexed by 1 txt = [1] # the <blank>/sil is indexed by 1
tokens = ["sil"] tokens = ["sil"]
@ -201,7 +203,7 @@ if __name__ == '__main__':
wav = wav_table[key] wav = wav_table[key]
assert key in duration_table assert key in duration_table
duration = duration_table[key] duration = duration_table[key]
if tokens == None: if tokens is None:
line = dict(key=key, txt=txt, duration=duration, wav=wav) line = dict(key=key, txt=txt, duration=duration, wav=wav)
else: else:
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav) line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)

View File

@ -14,8 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse, logging, glob import argparse
import json, re, os, numpy as np import logging
import glob
import json
import re
import os
import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pypinyin # for Chinese Character import pypinyin # for Chinese Character
@ -205,8 +210,7 @@ if __name__ == '__main__':
num_false_reject = 0 num_false_reject = 0
num_true_detect = 0 num_true_detect = 0
# traverse the whole keyword_table # traverse the whole keyword_table
for key, confi in keyword_filler_table[keyword][ for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
'keyword_table'].items():
if confi < threshold: if confi < threshold:
num_false_reject += 1 num_false_reject += 1
else: else:

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse import argparse
import copy import copy
import logging import logging
import os, sys, math import os
import sys
import math
import torch import torch
import yaml import yaml
@ -165,11 +167,11 @@ def main():
for i in range(len(keys)): for i in range(len(keys)):
key = keys[i] key = keys[i]
score = logits[i][:lengths[i]] score = logits[i][:lengths[i]]
hyps = ctc_prefix_beam_search(score, lengths[i], hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
keywords_idxset)
hit_keyword = None hit_keyword = None
hit_score = 1.0 hit_score = 1.0
start = 0; end = 0 start = 0
end = 0
for one_hyp in hyps: for one_hyp in hyps:
prefix_ids = one_hyp[0] prefix_ids = one_hyp[0]
# path_score = one_hyp[1] # path_score = one_hyp[1]

View File

@ -15,9 +15,11 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import struct, wave import struct
import wave
import logging import logging
import os, math import os
import math
import numpy as np import numpy as np
import torchaudio.compliance.kaldi as kaldi import torchaudio.compliance.kaldi as kaldi
@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module):
for i in range(0, len(wave), 2): for i in range(0, len(wave), 2):
value = struct.unpack('<h', wave[i:i + 2])[0] value = struct.unpack('<h', wave[i:i + 2])[0]
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
wave = np.array(data) wave = np.array(data)
wave = np.append(self.wave_remained, wave) wave = np.append(self.wave_remained, wave)
wave_tensor = torch.from_numpy(wave).float().to(self.device) wave_tensor = torch.from_numpy(wave).float().to(self.device)
@ -303,7 +306,7 @@ class KeyWordSpotter(torch.nn.Module):
if self.context_expansion: if self.context_expansion:
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context." assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
# pad feats with remained feature from last chunk # pad feats with remained feature from last chunk
if self.feature_remained == None: # first chunk if self.feature_remained is None: # first chunk
# pad first frame at the beginning; replicate mode only supports the last dimension, so we transpose. # pad first frame at the beginning; replicate mode only supports the last dimension, so we transpose.
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
else: else:
@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module):
def execute_detection(self, t): def execute_detection(self, t):
absolute_time = t + self.total_frames absolute_time = t + self.total_frames
hit_keyword = None hit_keyword = None
start = 0; end = 0 start = 0
end = 0
# hyps for detection # hyps for detection
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps] hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse import argparse
import copy import copy
import logging import logging
import os, sys, math import os
import sys
import math
import torch import torch
import yaml import yaml
@ -197,7 +199,8 @@ def main():
hit_keyword = None hit_keyword = None
activated = False activated = False
hit_score = 1.0 hit_score = 1.0
start = 0; end = 0 start = 0
end = 0
# 2. CTC beam search step by step # 2. CTC beam search step by step
for t in range(0, maxlen): for t in range(0, maxlen):

View File

@ -221,7 +221,7 @@ class FSMNBlock(nn.Module):
x = torch.unsqueeze(input, 1) x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) x_per = x.permute(0, 3, 2, 1)
if in_cache is None or len(in_cache)==0 or in_cache[0]==None: if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0]) x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
else: else:
in_cache = in_cache.to(x_per.device) in_cache = in_cache.to(x_per.device)
@ -233,7 +233,6 @@ class FSMNBlock(nn.Module):
y_left = self.dequant(y_left) y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
#out = out[:, :, :-self.rorder*self.rstride, :]
if self.conv_right is not None: if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :] y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]

View File

@ -64,7 +64,7 @@ class KWSModel(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None: if self.global_cmvn is not None:
x = self.global_cmvn(x) x = self.global_cmvn(x)

View File

@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch, math, sys import torch
import math
import sys
import torch.nn.functional as F import torch.nn.functional as F
from collections import defaultdict from collections import defaultdict
from typing import List, Optional, Tuple from typing import List, Tuple
from wekws.utils.mask import padding_mask from wekws.utils.mask import padding_mask

View File

@ -14,7 +14,8 @@
# limitations under the License. # limitations under the License.
import json import json
import math,re import math
import re
import numpy as np import numpy as np