fix typo.
This commit is contained in:
parent
909480da52
commit
6f8207267e
@ -15,8 +15,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse, logging
|
||||
import json, re
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
|
||||
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
|
||||
|
||||
@ -189,7 +191,7 @@ if __name__ == '__main__':
|
||||
arr = line.strip().split(maxsplit=1)
|
||||
key = arr[0]
|
||||
tokens = None
|
||||
if token_table!=None and lexicon_table!=None :
|
||||
if token_table is not None and lexicon_table is not None :
|
||||
if len(arr) < 2: # for some utterances, there is no text
|
||||
txt = [1] # the <blank>/sil is indexed by 1
|
||||
tokens = ["sil"]
|
||||
@ -201,7 +203,7 @@ if __name__ == '__main__':
|
||||
wav = wav_table[key]
|
||||
assert key in duration_table
|
||||
duration = duration_table[key]
|
||||
if tokens == None:
|
||||
if tokens is None:
|
||||
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
||||
else:
|
||||
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)
|
||||
|
||||
@ -14,8 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse, logging, glob
|
||||
import json, re, os, numpy as np
|
||||
import argparse
|
||||
import logging
|
||||
import glob
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import pypinyin # for Chinese Character
|
||||
|
||||
@ -205,8 +210,7 @@ if __name__ == '__main__':
|
||||
num_false_reject = 0
|
||||
num_true_detect = 0
|
||||
# traverse the entire keyword_table
|
||||
for key, confi in keyword_filler_table[keyword][
|
||||
'keyword_table'].items():
|
||||
for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
|
||||
if confi < threshold:
|
||||
num_false_reject += 1
|
||||
else:
|
||||
|
||||
@ -19,7 +19,9 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os, sys, math
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@ -165,11 +167,11 @@ def main():
|
||||
for i in range(len(keys)):
|
||||
key = keys[i]
|
||||
score = logits[i][:lengths[i]]
|
||||
hyps = ctc_prefix_beam_search(score, lengths[i],
|
||||
keywords_idxset)
|
||||
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
|
||||
hit_keyword = None
|
||||
hit_score = 1.0
|
||||
start = 0; end = 0
|
||||
start = 0
|
||||
end = 0
|
||||
for one_hyp in hyps:
|
||||
prefix_ids = one_hyp[0]
|
||||
# path_score = one_hyp[1]
|
||||
|
||||
@ -15,9 +15,11 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import struct, wave
|
||||
import struct
|
||||
import wave
|
||||
import logging
|
||||
import os, math
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
for i in range(0, len(wave), 2):
|
||||
value = struct.unpack('<h', wave[i:i + 2])[0]
|
||||
data.append(value) # here we don't divide by 32768.0, because kaldi.fbank accepts the original input
|
||||
|
||||
wave = np.array(data)
|
||||
wave = np.append(self.wave_remained, wave)
|
||||
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
||||
@ -303,7 +306,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
if self.context_expansion:
|
||||
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
|
||||
# pad feats with the remaining features from the last chunk
|
||||
if self.feature_remained == None: # first chunk
|
||||
if self.feature_remained is None: # first chunk
|
||||
# pad the first frame at the beginning; 'replicate' mode only supports the last dimension, so we transpose.
|
||||
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
|
||||
else:
|
||||
@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
def execute_detection(self, t):
|
||||
absolute_time = t + self.total_frames
|
||||
hit_keyword = None
|
||||
start = 0; end = 0
|
||||
start = 0
|
||||
end = 0
|
||||
|
||||
# hyps for detection
|
||||
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
|
||||
|
||||
@ -19,7 +19,9 @@ from __future__ import print_function
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os, sys, math
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@ -197,7 +199,8 @@ def main():
|
||||
hit_keyword = None
|
||||
activated = False
|
||||
hit_score = 1.0
|
||||
start = 0; end = 0
|
||||
start = 0
|
||||
end = 0
|
||||
|
||||
# 2. CTC beam search step by step
|
||||
for t in range(0, maxlen):
|
||||
|
||||
@ -221,7 +221,7 @@ class FSMNBlock(nn.Module):
|
||||
x = torch.unsqueeze(input, 1)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
|
||||
if in_cache is None or len(in_cache)==0 or in_cache[0]==None:
|
||||
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
|
||||
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
|
||||
else:
|
||||
in_cache = in_cache.to(x_per.device)
|
||||
@ -233,7 +233,6 @@ class FSMNBlock(nn.Module):
|
||||
y_left = self.dequant(y_left)
|
||||
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
|
||||
|
||||
#out = out[:, :, :-self.rorder*self.rstride, :]
|
||||
if self.conv_right is not None:
|
||||
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
||||
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]
|
||||
|
||||
@ -64,7 +64,7 @@ class KWSModel(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.global_cmvn is not None:
|
||||
x = self.global_cmvn(x)
|
||||
|
||||
@ -13,10 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch, math, sys
|
||||
import torch
|
||||
import math
|
||||
import sys
|
||||
import torch.nn.functional as F
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
from wekws.utils.mask import padding_mask
|
||||
|
||||
|
||||
@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import math,re
|
||||
import math
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user