fix typo.
This commit is contained in:
parent
909480da52
commit
6f8207267e
@ -9,17 +9,17 @@ Comparison among different backbones. FRRs with FAR fixed at once per hour:
|
|||||||
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
|
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
|
||||||
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
|
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
|
||||||
|
|
||||||
Next, we use CTC loss to train the model, with DS_TCN and FSMN.
|
Next, we use CTC loss to train the model, with DS_TCN and FSMN.
|
||||||
and we use CTC prefix beam search to decode and detect keywords,
|
and we use CTC prefix beam search to decode and detect keywords,
|
||||||
the detection is either in non-streaming or streaming fashion.
|
the detection is either in non-streaming or streaming fashion.
|
||||||
|
|
||||||
Since the FAR is pretty low when using CTC loss,
|
Since the FAR is pretty low when using CTC loss,
|
||||||
the follow result is FRRs with FAR fixed at once per 12 hours:
|
the follow result is FRRs with FAR fixed at once per 12 hours:
|
||||||
|
|
||||||
Comparison between Max-pooling and CTC loss.
|
Comparison between Max-pooling and CTC loss.
|
||||||
The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch).
|
The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch).
|
||||||
FRRs with FAR fixed at once per 12 hours
|
FRRs with FAR fixed at once per 12 hours
|
||||||
|
|
||||||
|
|
||||||
| model | loss | hi_xiaowen | nihao_wenwen |
|
| model | loss | hi_xiaowen | nihao_wenwen |
|
||||||
|-----------------------|-------------|------------|--------------|
|
|-----------------------|-------------|------------|--------------|
|
||||||
@ -27,8 +27,8 @@ FRRs with FAR fixed at once per 12 hours
|
|||||||
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 |
|
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 |
|
||||||
|
|
||||||
|
|
||||||
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch)
|
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch)
|
||||||
and FSMN(modelscope released, xiaoyunxiaoyun model).
|
and FSMN(modelscope released, xiaoyunxiaoyun model).
|
||||||
FRRs with FAR fixed at once per 12 hours:
|
FRRs with FAR fixed at once per 12 hours:
|
||||||
|
|
||||||
| model | params(K) | hi_xiaowen | nihao_wenwen |
|
| model | params(K) | hi_xiaowen | nihao_wenwen |
|
||||||
@ -36,8 +36,8 @@ FRRs with FAR fixed at once per 12 hours:
|
|||||||
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 |
|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 |
|
||||||
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 |
|
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 |
|
||||||
|
|
||||||
Comparison Between stream_score_ctc and score_ctc.
|
Comparison Between stream_score_ctc and score_ctc.
|
||||||
FRRs with FAR fixed at once per 12 hours:
|
FRRs with FAR fixed at once per 12 hours:
|
||||||
|
|
||||||
| model | stream | hi_xiaowen | nihao_wenwen |
|
| model | stream | hi_xiaowen | nihao_wenwen |
|
||||||
|-----------------------|-------------|------------|--------------|
|
|-----------------------|-------------|------------|--------------|
|
||||||
@ -46,12 +46,12 @@ FRRs with FAR fixed at once per 12 hours:
|
|||||||
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
|
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
|
||||||
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |
|
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |
|
||||||
|
|
||||||
Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame),
|
Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame),
|
||||||
we record the probability of a keyword in a decoding path once the keyword appears in this path.
|
we record the probability of a keyword in a decoding path once the keyword appears in this path.
|
||||||
Actually the probability will increase through the time, so we record a lower value of probability,
|
Actually the probability will increase through the time, so we record a lower value of probability,
|
||||||
which result in a higher False Rejection Rate in Detection Error Tradeoff result.
|
which result in a higher False Rejection Rate in Detection Error Tradeoff result.
|
||||||
The actual FRR will be lower than the DET curve gives in a given threshold.
|
The actual FRR will be lower than the DET curve gives in a given threshold.
|
||||||
|
|
||||||
Now, the model with CTC loss may not get the best performance,
|
Now, the model with CTC loss may not get the best performance,
|
||||||
but it's more robust compared with the classification model using CE/Max-pooling loss.
|
but it's more robust compared with the classification model using CE/Max-pooling loss.
|
||||||
For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).
|
For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).
|
||||||
@ -15,8 +15,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse, logging
|
import argparse
|
||||||
import json, re
|
import logging
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
|
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
|
||||||
|
|
||||||
@ -189,7 +191,7 @@ if __name__ == '__main__':
|
|||||||
arr = line.strip().split(maxsplit=1)
|
arr = line.strip().split(maxsplit=1)
|
||||||
key = arr[0]
|
key = arr[0]
|
||||||
tokens = None
|
tokens = None
|
||||||
if token_table!=None and lexicon_table!=None :
|
if token_table is not None and lexicon_table is not None :
|
||||||
if len(arr) < 2: # for some utterence, no text
|
if len(arr) < 2: # for some utterence, no text
|
||||||
txt = [1] # the <blank>/sil is indexed by 1
|
txt = [1] # the <blank>/sil is indexed by 1
|
||||||
tokens = ["sil"]
|
tokens = ["sil"]
|
||||||
@ -201,7 +203,7 @@ if __name__ == '__main__':
|
|||||||
wav = wav_table[key]
|
wav = wav_table[key]
|
||||||
assert key in duration_table
|
assert key in duration_table
|
||||||
duration = duration_table[key]
|
duration = duration_table[key]
|
||||||
if tokens == None:
|
if tokens is None:
|
||||||
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
||||||
else:
|
else:
|
||||||
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)
|
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)
|
||||||
|
|||||||
@ -14,8 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse, logging, glob
|
import argparse
|
||||||
import json, re, os, numpy as np
|
import logging
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pypinyin # for Chinese Character
|
import pypinyin # for Chinese Character
|
||||||
|
|
||||||
@ -205,8 +210,7 @@ if __name__ == '__main__':
|
|||||||
num_false_reject = 0
|
num_false_reject = 0
|
||||||
num_true_detect = 0
|
num_true_detect = 0
|
||||||
# transverse the all keyword_table
|
# transverse the all keyword_table
|
||||||
for key, confi in keyword_filler_table[keyword][
|
for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
|
||||||
'keyword_table'].items():
|
|
||||||
if confi < threshold:
|
if confi < threshold:
|
||||||
num_false_reject += 1
|
num_false_reject += 1
|
||||||
else:
|
else:
|
||||||
@ -234,4 +238,4 @@ if __name__ == '__main__':
|
|||||||
det_curve_path = args.det_curve_path
|
det_curve_path = args.det_curve_path
|
||||||
else:
|
else:
|
||||||
det_curve_path = os.path.join(stats_dir, 'det.png')
|
det_curve_path = os.path.join(stats_dir, 'det.png')
|
||||||
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)
|
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)
|
||||||
|
|||||||
@ -19,7 +19,9 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os, sys, math
|
import os
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
@ -138,7 +140,7 @@ def main():
|
|||||||
keywords_strset = {'<blk>'}
|
keywords_strset = {'<blk>'}
|
||||||
keywords_tokenmap = {'<blk>': 0}
|
keywords_tokenmap = {'<blk>': 0}
|
||||||
for keyword in keywords_list:
|
for keyword in keywords_list:
|
||||||
strs, indexes = query_token_set(keyword, token_table,lexicon_table)
|
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
||||||
keywords_token[keyword] = {}
|
keywords_token[keyword] = {}
|
||||||
keywords_token[keyword]['token_id'] = indexes
|
keywords_token[keyword]['token_id'] = indexes
|
||||||
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
||||||
@ -165,11 +167,11 @@ def main():
|
|||||||
for i in range(len(keys)):
|
for i in range(len(keys)):
|
||||||
key = keys[i]
|
key = keys[i]
|
||||||
score = logits[i][:lengths[i]]
|
score = logits[i][:lengths[i]]
|
||||||
hyps = ctc_prefix_beam_search(score, lengths[i],
|
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
|
||||||
keywords_idxset)
|
|
||||||
hit_keyword = None
|
hit_keyword = None
|
||||||
hit_score = 1.0
|
hit_score = 1.0
|
||||||
start = 0; end = 0
|
start = 0
|
||||||
|
end = 0
|
||||||
for one_hyp in hyps:
|
for one_hyp in hyps:
|
||||||
prefix_ids = one_hyp[0]
|
prefix_ids = one_hyp[0]
|
||||||
# path_score = one_hyp[1]
|
# path_score = one_hyp[1]
|
||||||
@ -181,7 +183,7 @@ def main():
|
|||||||
if offset != -1:
|
if offset != -1:
|
||||||
hit_keyword = word
|
hit_keyword = word
|
||||||
start = prefix_nodes[offset]['frame']
|
start = prefix_nodes[offset]['frame']
|
||||||
end = prefix_nodes[offset+len(lab)-1]['frame']
|
end = prefix_nodes[offset + len(lab) - 1]['frame']
|
||||||
for idx in range(offset, offset + len(lab)):
|
for idx in range(offset, offset + len(lab)):
|
||||||
hit_score *= prefix_nodes[idx]['prob']
|
hit_score *= prefix_nodes[idx]['prob']
|
||||||
break
|
break
|
||||||
@ -193,7 +195,7 @@ def main():
|
|||||||
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score))
|
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score))
|
||||||
logging.info(
|
logging.info(
|
||||||
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
||||||
f"duration {end-start}, score {hit_score}, Activated.")
|
f"duration {end - start}, score {hit_score}, Activated.")
|
||||||
else:
|
else:
|
||||||
fout.write('{} rejected\n'.format(key))
|
fout.write('{} rejected\n'.format(key))
|
||||||
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
||||||
|
|||||||
@ -15,9 +15,11 @@
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import struct, wave
|
import struct
|
||||||
|
import wave
|
||||||
import logging
|
import logging
|
||||||
import os, math
|
import os
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torchaudio.compliance.kaldi as kaldi
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
|
|
||||||
@ -183,13 +185,13 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size)
|
|||||||
class KeyWordSpotter(torch.nn.Module):
|
class KeyWordSpotter(torch.nn.Module):
|
||||||
def __init__(self, ckpt_path, config_path, token_path, lexicon_path,
|
def __init__(self, ckpt_path, config_path, token_path, lexicon_path,
|
||||||
threshold, min_frames=5, max_frames=250, interval_frames=50,
|
threshold, min_frames=5, max_frames=250, interval_frames=50,
|
||||||
score_beam = 3, path_beam = 20,
|
score_beam=3, path_beam=20,
|
||||||
gpu=-1, is_jit_model=False,):
|
gpu=-1, is_jit_model=False,):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
|
||||||
with open(config_path, 'r') as fin:
|
with open(config_path, 'r') as fin:
|
||||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||||
dataset_conf=configs['dataset_conf']
|
dataset_conf = configs['dataset_conf']
|
||||||
|
|
||||||
# feature related
|
# feature related
|
||||||
self.sample_rate = 16000
|
self.sample_rate = 16000
|
||||||
@ -198,7 +200,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms
|
self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms
|
||||||
self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms
|
self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms
|
||||||
self.downsampling = dataset_conf.get('frame_skip', 1)
|
self.downsampling = dataset_conf.get('frame_skip', 1)
|
||||||
self.resolution = self.frame_shift/1000 * self.downsampling # in second
|
self.resolution = self.frame_shift / 1000 * self.downsampling # in second
|
||||||
# fsmn splice operation
|
# fsmn splice operation
|
||||||
self.context_expansion = dataset_conf.get('context_expansion', False)
|
self.context_expansion = dataset_conf.get('context_expansion', False)
|
||||||
self.left_context = 0
|
self.left_context = 0
|
||||||
@ -284,6 +286,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
for i in range(0, len(wave), 2):
|
for i in range(0, len(wave), 2):
|
||||||
value = struct.unpack('<h', wave[i:i + 2])[0]
|
value = struct.unpack('<h', wave[i:i + 2])[0]
|
||||||
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
|
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
|
||||||
|
|
||||||
wave = np.array(data)
|
wave = np.array(data)
|
||||||
wave = np.append(self.wave_remained, wave)
|
wave = np.append(self.wave_remained, wave)
|
||||||
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
||||||
@ -303,18 +306,18 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
if self.context_expansion:
|
if self.context_expansion:
|
||||||
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
|
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
|
||||||
# pad feats with remained feature from last chunk
|
# pad feats with remained feature from last chunk
|
||||||
if self.feature_remained == None: # first chunk
|
if self.feature_remained is None: # first chunk
|
||||||
# pad first frame at the beginning, replicate just support last dimension, so we do transpose.
|
# pad first frame at the beginning, replicate just support last dimension, so we do transpose.
|
||||||
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
|
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
|
||||||
else:
|
else:
|
||||||
feats_pad = torch.cat((self.feature_remained, feats) )
|
feats_pad = torch.cat((self.feature_remained, feats))
|
||||||
|
|
||||||
ctx_frm = feats.shape[0] - self.right_context
|
ctx_frm = feats.shape[0] - self.right_context
|
||||||
ctx_win = (self.left_context + self.right_context + 1)
|
ctx_win = (self.left_context + self.right_context + 1)
|
||||||
ctx_dim = feats.shape[1] * ctx_win
|
ctx_dim = feats.shape[1] * ctx_win
|
||||||
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
|
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
|
||||||
for i in range(ctx_frm):
|
for i in range(ctx_frm):
|
||||||
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i+ctx_win])).unsqueeze(0)
|
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
|
||||||
|
|
||||||
# update feature remained, and feats
|
# update feature remained, and feats
|
||||||
self.feature_remained = feats[-self.left_context:]
|
self.feature_remained = feats[-self.left_context:]
|
||||||
@ -322,7 +325,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
if self.downsampling > 1:
|
if self.downsampling > 1:
|
||||||
feats = feats[self.feature_context_offset::self.downsampling, :]
|
feats = feats[self.feature_context_offset::self.downsampling, :]
|
||||||
complement = feats.size(1) % self.downsampling
|
complement = feats.size(1) % self.downsampling
|
||||||
self.feature_context_offset = complement if complement==0 else self.downsampling-complement
|
self.feature_context_offset = complement if complement == 0 else self.downsampling-complement
|
||||||
return feats
|
return feats
|
||||||
|
|
||||||
def decode_keywords(self, t, probs):
|
def decode_keywords(self, t, probs):
|
||||||
@ -340,7 +343,8 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
def execute_detection(self, t):
|
def execute_detection(self, t):
|
||||||
absolute_time = t + self.total_frames
|
absolute_time = t + self.total_frames
|
||||||
hit_keyword = None
|
hit_keyword = None
|
||||||
start = 0; end = 0
|
start = 0
|
||||||
|
end = 0
|
||||||
|
|
||||||
# hyps for detection
|
# hyps for detection
|
||||||
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
|
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
|
||||||
@ -385,8 +389,8 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.result = {
|
self.result = {
|
||||||
"state": 1 if self.activated else 0,
|
"state": 1 if self.activated else 0,
|
||||||
"keyword": hit_keyword if self.activated else None,
|
"keyword": hit_keyword if self.activated else None,
|
||||||
"start": start*self.resolution if self.activated else None,
|
"start": start * self.resolution if self.activated else None,
|
||||||
"end": end*self.resolution if self.activated else None,
|
"end": end * self.resolution if self.activated else None,
|
||||||
"score": self.hit_score if self.activated else None
|
"score": self.hit_score if self.activated else None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,9 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os, sys, math
|
import os
|
||||||
|
import sys
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
@ -197,7 +199,8 @@ def main():
|
|||||||
hit_keyword = None
|
hit_keyword = None
|
||||||
activated = False
|
activated = False
|
||||||
hit_score = 1.0
|
hit_score = 1.0
|
||||||
start = 0; end = 0
|
start = 0
|
||||||
|
end = 0
|
||||||
|
|
||||||
# 2. CTC beam search step by step
|
# 2. CTC beam search step by step
|
||||||
for t in range(0, maxlen):
|
for t in range(0, maxlen):
|
||||||
@ -298,7 +301,7 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
duration = end - start
|
duration = end - start
|
||||||
if hit_keyword is not None :
|
if hit_keyword is not None:
|
||||||
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames:
|
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames:
|
||||||
activated = True
|
activated = True
|
||||||
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score))
|
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score))
|
||||||
|
|||||||
@ -353,7 +353,7 @@ def padding(data):
|
|||||||
|
|
||||||
if isinstance(sample[0]['label'], int):
|
if isinstance(sample[0]['label'], int):
|
||||||
padded_labels = torch.tensor([sample[i]['label'] for i in order],
|
padded_labels = torch.tensor([sample[i]['label'] for i in order],
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
label_lengths = torch.tensor([1 for i in order],
|
label_lengths = torch.tensor([1 for i in order],
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -221,19 +221,18 @@ class FSMNBlock(nn.Module):
|
|||||||
x = torch.unsqueeze(input, 1)
|
x = torch.unsqueeze(input, 1)
|
||||||
x_per = x.permute(0, 3, 2, 1)
|
x_per = x.permute(0, 3, 2, 1)
|
||||||
|
|
||||||
if in_cache is None or len(in_cache)==0 or in_cache[0]==None:
|
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None:
|
||||||
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride+self.rorder*self.rstride, 0])
|
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
|
||||||
else :
|
else:
|
||||||
in_cache = in_cache.to(x_per.device)
|
in_cache = in_cache.to(x_per.device)
|
||||||
x_pad = torch.cat((in_cache, x_per), dim=2)
|
x_pad = torch.cat((in_cache, x_per), dim=2)
|
||||||
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride+self.rorder*self.rstride):, :]
|
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :]
|
||||||
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
|
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
|
||||||
y_left = self.quant(y_left)
|
y_left = self.quant(y_left)
|
||||||
y_left = self.conv_left(y_left)
|
y_left = self.conv_left(y_left)
|
||||||
y_left = self.dequant(y_left)
|
y_left = self.dequant(y_left)
|
||||||
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
|
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
|
||||||
|
|
||||||
#out = out[:, :, :-self.rorder*self.rstride, :]
|
|
||||||
if self.conv_right is not None:
|
if self.conv_right is not None:
|
||||||
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
||||||
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]
|
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]
|
||||||
|
|||||||
@ -64,7 +64,7 @@ class KWSModel(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
in_cache=None #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
if self.global_cmvn is not None:
|
if self.global_cmvn is not None:
|
||||||
x = self.global_cmvn(x)
|
x = self.global_cmvn(x)
|
||||||
|
|||||||
@ -13,10 +13,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import torch, math, sys
|
import torch
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from wekws.utils.mask import padding_mask
|
from wekws.utils.mask import padding_mask
|
||||||
|
|
||||||
@ -453,4 +455,4 @@ class Calculator:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return list(self.data.keys())
|
return list(self.data.keys())
|
||||||
|
|||||||
@ -14,7 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math,re
|
import math
|
||||||
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -88,4 +89,4 @@ def load_kaldi_cmvn(cmvn_file):
|
|||||||
cmvn = np.array([means, variance])
|
cmvn = np.array([means, variance])
|
||||||
cmvn = np.tile(cmvn, (1, copy_times))
|
cmvn = np.tile(cmvn, (1, copy_times))
|
||||||
|
|
||||||
return cmvn
|
return cmvn
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user