fix typo.

This commit is contained in:
dujing 2023-06-16 15:53:39 +08:00
parent 909480da52
commit 6f8207267e
11 changed files with 77 additions and 60 deletions

View File

@ -15,8 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse, logging
import json, re
import argparse
import logging
import json
import re
symbol_str = '[!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
@ -189,7 +191,7 @@ if __name__ == '__main__':
arr = line.strip().split(maxsplit=1)
key = arr[0]
tokens = None
if token_table!=None and lexicon_table!=None :
if token_table is not None and lexicon_table is not None :
if len(arr) < 2: # for some utterence, no text
txt = [1] # the <blank>/sil is indexed by 1
tokens = ["sil"]
@ -201,7 +203,7 @@ if __name__ == '__main__':
wav = wav_table[key]
assert key in duration_table
duration = duration_table[key]
if tokens == None:
if tokens is None:
line = dict(key=key, txt=txt, duration=duration, wav=wav)
else:
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)

View File

@ -14,8 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse, logging, glob
import json, re, os, numpy as np
import argparse
import logging
import glob
import json
import re
import os
import numpy as np
import matplotlib.pyplot as plt
import pypinyin # for Chinese Character
@ -205,8 +210,7 @@ if __name__ == '__main__':
num_false_reject = 0
num_true_detect = 0
# transverse the all keyword_table
for key, confi in keyword_filler_table[keyword][
'keyword_table'].items():
for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
if confi < threshold:
num_false_reject += 1
else:

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse
import copy
import logging
import os, sys, math
import os
import sys
import math
import torch
import yaml
@ -165,11 +167,11 @@ def main():
for i in range(len(keys)):
key = keys[i]
score = logits[i][:lengths[i]]
hyps = ctc_prefix_beam_search(score, lengths[i],
keywords_idxset)
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
hit_keyword = None
hit_score = 1.0
start = 0; end = 0
start = 0
end = 0
for one_hyp in hyps:
prefix_ids = one_hyp[0]
# path_score = one_hyp[1]

View File

@ -15,9 +15,11 @@
from __future__ import print_function
import argparse
import struct, wave
import struct
import wave
import logging
import os, math
import os
import math
import numpy as np
import torchaudio.compliance.kaldi as kaldi
@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module):
for i in range(0, len(wave), 2):
value = struct.unpack('<h', wave[i:i + 2])[0]
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
wave = np.array(data)
wave = np.append(self.wave_remained, wave)
wave_tensor = torch.from_numpy(wave).float().to(self.device)
@ -303,7 +306,7 @@ class KeyWordSpotter(torch.nn.Module):
if self.context_expansion:
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
# pad feats with remained feature from last chunk
if self.feature_remained == None: # first chunk
if self.feature_remained is None: # first chunk
# pad first frame at the beginning, replicate just support last dimension, so we do transpose.
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
else:
@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module):
def execute_detection(self, t):
absolute_time = t + self.total_frames
hit_keyword = None
start = 0; end = 0
start = 0
end = 0
# hyps for detection
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]

View File

@ -19,7 +19,9 @@ from __future__ import print_function
import argparse
import copy
import logging
import os, sys, math
import os
import sys
import math
import torch
import yaml
@ -197,7 +199,8 @@ def main():
hit_keyword = None
activated = False
hit_score = 1.0
start = 0; end = 0
start = 0
end = 0
# 2. CTC beam search step by step
for t in range(0, maxlen):

View File

@ -221,7 +221,7 @@ class FSMNBlock(nn.Module):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1)
if in_cache is None or len(in_cache)==0 or in_cache[0]==None:
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
else:
in_cache = in_cache.to(x_per.device)
@ -233,7 +233,6 @@ class FSMNBlock(nn.Module):
y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
#out = out[:, :, :-self.rorder*self.rstride, :]
if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]

View File

@ -64,7 +64,7 @@ class KWSModel(nn.Module):
def forward(
self,
x: torch.Tensor,
in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None:
x = self.global_cmvn(x)

View File

@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch, math, sys
import torch
import math
import sys
import torch.nn.functional as F
from collections import defaultdict
from typing import List, Optional, Tuple
from typing import List, Tuple
from wekws.utils.mask import padding_mask

View File

@ -14,7 +14,8 @@
# limitations under the License.
import json
import math,re
import math
import re
import numpy as np