fix typo.
This commit is contained in:
parent
909480da52
commit
6f8207267e
@ -15,8 +15,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse, logging
|
import argparse
|
||||||
import json, re
|
import logging
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
|
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
|
||||||
|
|
||||||
@ -189,7 +191,7 @@ if __name__ == '__main__':
|
|||||||
arr = line.strip().split(maxsplit=1)
|
arr = line.strip().split(maxsplit=1)
|
||||||
key = arr[0]
|
key = arr[0]
|
||||||
tokens = None
|
tokens = None
|
||||||
if token_table!=None and lexicon_table!=None :
|
if token_table is not None and lexicon_table is not None :
|
||||||
if len(arr) < 2: # for some utterence, no text
|
if len(arr) < 2: # for some utterence, no text
|
||||||
txt = [1] # the <blank>/sil is indexed by 1
|
txt = [1] # the <blank>/sil is indexed by 1
|
||||||
tokens = ["sil"]
|
tokens = ["sil"]
|
||||||
@ -201,7 +203,7 @@ if __name__ == '__main__':
|
|||||||
wav = wav_table[key]
|
wav = wav_table[key]
|
||||||
assert key in duration_table
|
assert key in duration_table
|
||||||
duration = duration_table[key]
|
duration = duration_table[key]
|
||||||
if tokens == None:
|
if tokens is None:
|
||||||
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
||||||
else:
|
else:
|
||||||
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)
|
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)
|
||||||
|
|||||||
@ -14,8 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse, logging, glob
|
import argparse
|
||||||
import json, re, os, numpy as np
|
import logging
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pypinyin # for Chinese Character
|
import pypinyin # for Chinese Character
|
||||||
|
|
||||||
@ -205,8 +210,7 @@ if __name__ == '__main__':
|
|||||||
num_false_reject = 0
|
num_false_reject = 0
|
||||||
num_true_detect = 0
|
num_true_detect = 0
|
||||||
# transverse the all keyword_table
|
# transverse the all keyword_table
|
||||||
for key, confi in keyword_filler_table[keyword][
|
for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
|
||||||
'keyword_table'].items():
|
|
||||||
if confi < threshold:
|
if confi < threshold:
|
||||||
num_false_reject += 1
|
num_false_reject += 1
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -19,7 +19,9 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os, sys, math
|
import os
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
@ -165,11 +167,11 @@ def main():
|
|||||||
for i in range(len(keys)):
|
for i in range(len(keys)):
|
||||||
key = keys[i]
|
key = keys[i]
|
||||||
score = logits[i][:lengths[i]]
|
score = logits[i][:lengths[i]]
|
||||||
hyps = ctc_prefix_beam_search(score, lengths[i],
|
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
|
||||||
keywords_idxset)
|
|
||||||
hit_keyword = None
|
hit_keyword = None
|
||||||
hit_score = 1.0
|
hit_score = 1.0
|
||||||
start = 0; end = 0
|
start = 0
|
||||||
|
end = 0
|
||||||
for one_hyp in hyps:
|
for one_hyp in hyps:
|
||||||
prefix_ids = one_hyp[0]
|
prefix_ids = one_hyp[0]
|
||||||
# path_score = one_hyp[1]
|
# path_score = one_hyp[1]
|
||||||
|
|||||||
@ -15,9 +15,11 @@
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import struct, wave
|
import struct
|
||||||
|
import wave
|
||||||
import logging
|
import logging
|
||||||
import os, math
|
import os
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torchaudio.compliance.kaldi as kaldi
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
|
||||||
@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
for i in range(0, len(wave), 2):
|
for i in range(0, len(wave), 2):
|
||||||
value = struct.unpack('<h', wave[i:i + 2])[0]
|
value = struct.unpack('<h', wave[i:i + 2])[0]
|
||||||
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
|
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
|
||||||
|
|
||||||
wave = np.array(data)
|
wave = np.array(data)
|
||||||
wave = np.append(self.wave_remained, wave)
|
wave = np.append(self.wave_remained, wave)
|
||||||
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
||||||
@ -303,7 +306,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
if self.context_expansion:
|
if self.context_expansion:
|
||||||
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
|
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
|
||||||
# pad feats with remained feature from last chunk
|
# pad feats with remained feature from last chunk
|
||||||
if self.feature_remained == None: # first chunk
|
if self.feature_remained is None: # first chunk
|
||||||
# pad first frame at the beginning, replicate just support last dimension, so we do transpose.
|
# pad first frame at the beginning, replicate just support last dimension, so we do transpose.
|
||||||
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
|
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
|
||||||
else:
|
else:
|
||||||
@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
def execute_detection(self, t):
|
def execute_detection(self, t):
|
||||||
absolute_time = t + self.total_frames
|
absolute_time = t + self.total_frames
|
||||||
hit_keyword = None
|
hit_keyword = None
|
||||||
start = 0; end = 0
|
start = 0
|
||||||
|
end = 0
|
||||||
|
|
||||||
# hyps for detection
|
# hyps for detection
|
||||||
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
|
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
|
||||||
|
|||||||
@ -19,7 +19,9 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os, sys, math
|
import os
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
@ -197,7 +199,8 @@ def main():
|
|||||||
hit_keyword = None
|
hit_keyword = None
|
||||||
activated = False
|
activated = False
|
||||||
hit_score = 1.0
|
hit_score = 1.0
|
||||||
start = 0; end = 0
|
start = 0
|
||||||
|
end = 0
|
||||||
|
|
||||||
# 2. CTC beam search step by step
|
# 2. CTC beam search step by step
|
||||||
for t in range(0, maxlen):
|
for t in range(0, maxlen):
|
||||||
|
|||||||
@ -221,7 +221,7 @@ class FSMNBlock(nn.Module):
|
|||||||
x = torch.unsqueeze(input, 1)
|
x = torch.unsqueeze(input, 1)
|
||||||
x_per = x.permute(0, 3, 2, 1)
|
x_per = x.permute(0, 3, 2, 1)
|
||||||
|
|
||||||
if in_cache is None or len(in_cache)==0 or in_cache[0]==None:
|
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
|
||||||
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
|
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
|
||||||
else:
|
else:
|
||||||
in_cache = in_cache.to(x_per.device)
|
in_cache = in_cache.to(x_per.device)
|
||||||
@ -233,7 +233,6 @@ class FSMNBlock(nn.Module):
|
|||||||
y_left = self.dequant(y_left)
|
y_left = self.dequant(y_left)
|
||||||
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
|
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
|
||||||
|
|
||||||
#out = out[:, :, :-self.rorder*self.rstride, :]
|
|
||||||
if self.conv_right is not None:
|
if self.conv_right is not None:
|
||||||
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
||||||
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]
|
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]
|
||||||
|
|||||||
@ -64,7 +64,7 @@ class KWSModel(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
if self.global_cmvn is not None:
|
if self.global_cmvn is not None:
|
||||||
x = self.global_cmvn(x)
|
x = self.global_cmvn(x)
|
||||||
|
|||||||
@ -13,10 +13,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import torch, math, sys
|
import torch
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from wekws.utils.mask import padding_mask
|
from wekws.utils.mask import padding_mask
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math,re
|
import math
|
||||||
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user