fix flake8, update training script and README, give pretrained ckpt.
This commit is contained in:
parent
45f0522f19
commit
ea6a0f5cda
@ -1,4 +1,6 @@
|
|||||||
Comparison among different backbones. FRRs with FAR fixed at once per hour:
|
Comparison among different backbones,
|
||||||
|
all models use Max-Pooling loss.
|
||||||
|
FRRs with FAR fixed at once per hour:
|
||||||
|
|
||||||
| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
|
| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
|
||||||
|-----------------------|-----------|-----------|------------|--------------|
|
|-----------------------|-----------|-----------|------------|--------------|
|
||||||
@ -9,32 +11,35 @@ Comparison among different backbones. FRRs with FAR fixed at once per hour:
|
|||||||
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
|
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
|
||||||
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
|
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
|
||||||
|
|
||||||
Next, we use CTC loss to train the model, with DS_TCN and FSMN.
|
Next, we use CTC loss to train the model, with DS_TCN and FSMN backbones.
|
||||||
and we use CTC prefix beam search to decode and detect keywords,
|
and we use CTC prefix beam search to decode and detect keywords,
|
||||||
the detection is either in non-streaming or streaming fashion.
|
the detection is either in non-streaming or streaming fashion.
|
||||||
|
|
||||||
Since the FAR is pretty low when using CTC loss,
|
Since the FAR is pretty low when using CTC loss,
|
||||||
the follow result is FRRs with FAR fixed at once per 12 hours:
|
the following results are FRRs with FAR fixed at once per 12 hours:
|
||||||
|
|
||||||
Comparison between Max-pooling and CTC loss.
|
Comparison between Max-pooling and CTC loss.
|
||||||
The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch).
|
The CTC model is fine-tuned from a base model pretrained on WenetSpeech (23 epochs, not converged).
|
||||||
FRRs with FAR fixed at once per 12 hours
|
FRRs with FAR fixed at once per 12 hours
|
||||||
|
|
||||||
|
| model | loss | hi_xiaowen | nihao_wenwen | model ckpt |
|
||||||
| model | loss | hi_xiaowen | nihao_wenwen |
|
|-----------------------|-------------|------------|--------------|------------|
|
||||||
|-----------------------|-------------|------------|--------------|
|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | [dstcn-maxpooling](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn/files) |
|
||||||
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 |
|
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
|
||||||
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 |
|
|
||||||
|
|
||||||
|
|
||||||
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch)
|
Comparison between DS_TCN(Pretrained with WenetSpeech, 23 epochs, not converged)
|
||||||
and FSMN(modelscope released, xiaoyunxiaoyun model).
|
and FSMN(Pretrained with modelscope released xiaoyunxiaoyun model, fully converged).
|
||||||
FRRs with FAR fixed at once per 12 hours:
|
FRRs with FAR fixed at once per 12 hours:
|
||||||
|
|
||||||
| model | params(K) | hi_xiaowen | nihao_wenwen |
|
| model | params(K) | hi_xiaowen | nihao_wenwen | model ckpt |
|
||||||
|-----------------------|-------------|------------|--------------|
|
|-----------------------|-------------|------------|--------------|-------------------------------------------------------------------------------|
|
||||||
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 |
|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
|
||||||
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 |
|
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | [fsmn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_fsmn_ctc/files) |
|
||||||
|
|
||||||
|
Now, the DS_TCN model with CTC loss may not get the best performance, because the
|
||||||
|
pretraining phase is not sufficiently converged. We recommend you use the pretrained
|
||||||
|
FSMN model as the initial checkpoint to train your own model.
|
||||||
|
|
||||||
Comparison Between stream_score_ctc and score_ctc.
|
Comparison Between stream_score_ctc and score_ctc.
|
||||||
FRRs with FAR fixed at once per 12 hours:
|
FRRs with FAR fixed at once per 12 hours:
|
||||||
@ -52,6 +57,6 @@ Actually the probability will increase through the time, so we record a lower va
|
|||||||
which results in a higher False Rejection Rate in the Detection Error Tradeoff result.
|
which results in a higher False Rejection Rate in the Detection Error Tradeoff result.
|
||||||
The actual FRR will be lower than what the DET curve gives at a given threshold.
|
The actual FRR will be lower than what the DET curve gives at a given threshold.
|
||||||
|
|
||||||
Now, the model with CTC loss may not get the best performance,
|
On some small data KWS tasks, we believe the FSMN-CTC model is more robust
|
||||||
but it's more robust compared with the classification model using CE/Max-pooling loss.
|
compared with the classification model using CE/Max-pooling loss.
|
||||||
For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).
|
For more information and results of the FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).
|
||||||
@ -194,12 +194,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
--score_file $result_dir/score.txt \
|
--score_file $result_dir/score.txt \
|
||||||
--num_workers 8 \
|
--num_workers 8 \
|
||||||
--keywords 嗨小问,你好问问 \
|
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
|
||||||
--token_file data/tokens.txt \
|
--token_file data/tokens.txt \
|
||||||
--lexicon_file data/lexicon.txt
|
--lexicon_file data/lexicon.txt
|
||||||
|
|
||||||
python wekws/bin/compute_det_ctc.py \
|
python wekws/bin/compute_det_ctc.py \
|
||||||
--keywords 嗨小问,你好问问 \
|
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
|
||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
--window_shift $window_shift \
|
--window_shift $window_shift \
|
||||||
--step 0.001 \
|
--step 0.001 \
|
||||||
|
|||||||
@ -145,12 +145,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
--score_file $result_dir/score.txt \
|
--score_file $result_dir/score.txt \
|
||||||
--num_workers 8 \
|
--num_workers 8 \
|
||||||
--keywords 嗨小问,你好问问 \
|
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
|
||||||
--token_file data/tokens.txt \
|
--token_file data/tokens.txt \
|
||||||
--lexicon_file data/lexicon.txt
|
--lexicon_file data/lexicon.txt
|
||||||
|
|
||||||
python wekws/bin/compute_det_ctc.py \
|
python wekws/bin/compute_det_ctc.py \
|
||||||
--keywords 嗨小问,你好问问 \
|
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
|
||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
--window_shift $window_shift \
|
--window_shift $window_shift \
|
||||||
--step 0.001 \
|
--step 0.001 \
|
||||||
@ -161,14 +161,13 @@ fi
|
|||||||
|
|
||||||
|
|
||||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||||
# NOTE: FSMN now is not support export to jit, beacuse of nn.Sequential with tuple input
|
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
|
||||||
# This issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
|
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
|
||||||
# jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
|
# For now, FSMN cannot be exported to TorchScript (JIT)
|
||||||
# python wekws/bin/export_jit.py \
|
# python wekws/bin/export_jit.py \
|
||||||
# --config $dir/config.yaml \
|
# --config $dir/config.yaml \
|
||||||
# --checkpoint $score_checkpoint \
|
# --checkpoint $score_checkpoint \
|
||||||
# --jit_model $dir/$jit_model
|
# --jit_model $dir/$jit_model
|
||||||
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
|
|
||||||
python wekws/bin/export_onnx.py \
|
python wekws/bin/export_onnx.py \
|
||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
|
|||||||
@ -45,7 +45,8 @@ def query_token_set(txt, symbol_table, lexicon_table):
|
|||||||
tokens_str = tokens_str + ('!sil', )
|
tokens_str = tokens_str + ('!sil', )
|
||||||
elif part == '<blk>' or part == '<blank>':
|
elif part == '<blk>' or part == '<blank>':
|
||||||
tokens_str = tokens_str + ('<blk>', )
|
tokens_str = tokens_str + ('<blk>', )
|
||||||
elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
|
elif part == '(noise)' or part == 'noise)' or \
|
||||||
|
part == '(noise' or part == '<noise>':
|
||||||
tokens_str = tokens_str + ('<GBG>', )
|
tokens_str = tokens_str + ('<GBG>', )
|
||||||
elif part in symbol_table:
|
elif part in symbol_table:
|
||||||
tokens_str = tokens_str + (part, )
|
tokens_str = tokens_str + (part, )
|
||||||
@ -75,11 +76,11 @@ def query_token_set(txt, symbol_table, lexicon_table):
|
|||||||
if '<GBG>' in symbol_table:
|
if '<GBG>' in symbol_table:
|
||||||
tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
|
tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
|
||||||
logging.info(
|
logging.info(
|
||||||
f'\'{ch}\' is not in token set, replace with <GBG>')
|
f'{ch} is not in token set, replace with <GBG>')
|
||||||
else:
|
else:
|
||||||
tokens_idx = tokens_idx + (symbol_table['<blk>'], )
|
tokens_idx = tokens_idx + (symbol_table['<blk>'], )
|
||||||
logging.info(
|
logging.info(
|
||||||
f'\'{ch}\' is not in token set, replace with <blk>')
|
f'{ch} is not in token set, replace with <blk>')
|
||||||
|
|
||||||
return tokens_str, tokens_idx
|
return tokens_str, tokens_idx
|
||||||
|
|
||||||
@ -94,7 +95,8 @@ def query_token_list(txt, symbol_table, lexicon_table):
|
|||||||
tokens_str.append('!sil')
|
tokens_str.append('!sil')
|
||||||
elif part == '<blk>' or part == '<blank>':
|
elif part == '<blk>' or part == '<blank>':
|
||||||
tokens_str.append('<blk>')
|
tokens_str.append('<blk>')
|
||||||
elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
|
elif part == '(noise)' or part == 'noise)' or \
|
||||||
|
part == '(noise' or part == '<noise>':
|
||||||
tokens_str.append('<GBG>')
|
tokens_str.append('<GBG>')
|
||||||
elif part in symbol_table:
|
elif part in symbol_table:
|
||||||
tokens_str.append(part)
|
tokens_str.append(part)
|
||||||
@ -124,11 +126,11 @@ def query_token_list(txt, symbol_table, lexicon_table):
|
|||||||
if '<GBG>' in symbol_table:
|
if '<GBG>' in symbol_table:
|
||||||
tokens_idx.append(symbol_table['<GBG>'])
|
tokens_idx.append(symbol_table['<GBG>'])
|
||||||
logging.info(
|
logging.info(
|
||||||
f'\'{ch}\' is not in token set, replace with <GBG>')
|
f'{ch} is not in token set, replace with <GBG>')
|
||||||
else:
|
else:
|
||||||
tokens_idx.append(symbol_table['<blk>'])
|
tokens_idx.append(symbol_table['<blk>'])
|
||||||
logging.info(
|
logging.info(
|
||||||
f'\'{ch}\' is not in token set, replace with <blk>')
|
f'{ch} is not in token set, replace with <blk>')
|
||||||
|
|
||||||
return tokens_str, tokens_idx
|
return tokens_str, tokens_idx
|
||||||
|
|
||||||
@ -160,8 +162,10 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('text_file', help='text file')
|
parser.add_argument('text_file', help='text file')
|
||||||
parser.add_argument('duration_file', help='duration file')
|
parser.add_argument('duration_file', help='duration file')
|
||||||
parser.add_argument('output_file', help='output list file')
|
parser.add_argument('output_file', help='output list file')
|
||||||
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
|
parser.add_argument('--token_file', type=str, default=None,
|
||||||
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
|
help='the path of tokens.txt')
|
||||||
|
parser.add_argument('--lexicon_file', type=str, default=None,
|
||||||
|
help='the path of lexicon.txt')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
wav_table = {}
|
wav_table = {}
|
||||||
@ -196,7 +200,9 @@ if __name__ == '__main__':
|
|||||||
txt = [1] # the <blank>/sil is indexed by 1
|
txt = [1] # the <blank>/sil is indexed by 1
|
||||||
tokens = ["sil"]
|
tokens = ["sil"]
|
||||||
else:
|
else:
|
||||||
tokens, txt = query_token_list(arr[1], token_table, lexicon_table)
|
tokens, txt = query_token_list(arr[1],
|
||||||
|
token_table,
|
||||||
|
lexicon_table)
|
||||||
else:
|
else:
|
||||||
txt = int(arr[1])
|
txt = int(arr[1])
|
||||||
assert key in wav_table
|
assert key in wav_table
|
||||||
@ -206,7 +212,8 @@ if __name__ == '__main__':
|
|||||||
if tokens is None:
|
if tokens is None:
|
||||||
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
||||||
else:
|
else:
|
||||||
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)
|
line = dict(key=key, tok=tokens, txt=txt,
|
||||||
|
duration=duration, wav=wav)
|
||||||
|
|
||||||
json_line = json.dumps(line, ensure_ascii=False)
|
json_line = json.dumps(line, ensure_ascii=False)
|
||||||
fout.write(json_line + '\n')
|
fout.write(json_line + '\n')
|
||||||
|
|||||||
@ -157,18 +157,23 @@ def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='compute det curve')
|
parser = argparse.ArgumentParser(description='compute det curve')
|
||||||
parser.add_argument('--test_data', required=True, help='label file')
|
parser.add_argument('--test_data', required=True, help='label file')
|
||||||
parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)')
|
parser.add_argument('--keywords', type=str, default=None,
|
||||||
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
|
help='keywords, split with comma(,)')
|
||||||
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
|
parser.add_argument('--token_file', type=str, default=None,
|
||||||
|
help='the path of tokens.txt')
|
||||||
|
parser.add_argument('--lexicon_file', type=str, default=None,
|
||||||
|
help='the path of lexicon.txt')
|
||||||
parser.add_argument('--score_file', required=True, help='score file')
|
parser.add_argument('--score_file', required=True, help='score file')
|
||||||
parser.add_argument('--step', type=float, default=0.01,
|
parser.add_argument('--step', type=float, default=0.01,
|
||||||
help='threshold step')
|
help='threshold step')
|
||||||
parser.add_argument('--window_shift', type=int, default=50,
|
parser.add_argument('--window_shift', type=int, default=50,
|
||||||
help='window_shift is used to skip the frames after triggered')
|
help='window_shift is used to '
|
||||||
|
'skip the frames after triggered')
|
||||||
parser.add_argument('--stats_dir',
|
parser.add_argument('--stats_dir',
|
||||||
required=False,
|
required=False,
|
||||||
default=None,
|
default=None,
|
||||||
help='false reject/alarm stats dir, default in score_file')
|
help='false reject/alarm stats dir, '
|
||||||
|
'default in score_file')
|
||||||
parser.add_argument('--det_curve_path',
|
parser.add_argument('--det_curve_path',
|
||||||
required=False,
|
required=False,
|
||||||
default=None,
|
default=None,
|
||||||
@ -188,7 +193,11 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
window_shift = args.window_shift
|
window_shift = args.window_shift
|
||||||
keywords_list = args.keywords.strip().split(',')
|
logging.info(f"keywords is {args.keywords}, "
|
||||||
|
f"Chinese is converted into Unicode.")
|
||||||
|
|
||||||
|
keywords = args.keywords.encode('utf-8').decode('unicode_escape')
|
||||||
|
keywords_list = keywords.strip().split(',')
|
||||||
|
|
||||||
token_table = read_token(args.token_file)
|
token_table = read_token(args.token_file)
|
||||||
lexicon_table = read_lexicon(args.lexicon_file)
|
lexicon_table = read_lexicon(args.lexicon_file)
|
||||||
@ -197,7 +206,8 @@ if __name__ == '__main__':
|
|||||||
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
||||||
true_keywords[keyword] = ''.join(strs)
|
true_keywords[keyword] = ''.join(strs)
|
||||||
|
|
||||||
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file, true_keywords)
|
keyword_filler_table = load_label_and_score(
|
||||||
|
keywords_list, args.test_data, args.score_file, true_keywords)
|
||||||
|
|
||||||
for keyword in keywords_list:
|
for keyword in keywords_list:
|
||||||
keyword = true_keywords[keyword]
|
keyword = true_keywords[keyword]
|
||||||
@ -206,8 +216,10 @@ if __name__ == '__main__':
|
|||||||
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
|
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
|
||||||
filler_dur = keyword_filler_table[keyword]['filler_duration']
|
filler_dur = keyword_filler_table[keyword]['filler_duration']
|
||||||
filler_num = len(keyword_filler_table[keyword]['filler_table'])
|
filler_num = len(keyword_filler_table[keyword]['filler_table'])
|
||||||
assert keyword_num > 0, 'Can\'t compute det for {} without positive sample'
|
assert keyword_num > 0, \
|
||||||
assert filler_num > 0, 'Can\'t compute det for {} without negative sample'
|
'Can\'t compute det for {} without positive sample'
|
||||||
|
assert filler_num > 0, \
|
||||||
|
'Can\'t compute det for {} without negative sample'
|
||||||
|
|
||||||
logging.info('Computing det for {}'.format(keyword))
|
logging.info('Computing det for {}'.format(keyword))
|
||||||
logging.info(' Keyword duration: {} Hours, wave number: {}'.format(
|
logging.info(' Keyword duration: {} Hours, wave number: {}'.format(
|
||||||
@ -218,14 +230,16 @@ if __name__ == '__main__':
|
|||||||
stats_dir = args.stats_dir
|
stats_dir = args.stats_dir
|
||||||
else:
|
else:
|
||||||
stats_dir = os.path.dirname(args.score_file)
|
stats_dir = os.path.dirname(args.score_file)
|
||||||
stats_file = os.path.join(stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
|
stats_file = os.path.join(
|
||||||
|
stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
|
||||||
with open(stats_file, 'w', encoding='utf8') as fout:
|
with open(stats_file, 'w', encoding='utf8') as fout:
|
||||||
threshold = 0.0
|
threshold = 0.0
|
||||||
while threshold <= 1.0:
|
while threshold <= 1.0:
|
||||||
num_false_reject = 0
|
num_false_reject = 0
|
||||||
num_true_detect = 0
|
num_true_detect = 0
|
||||||
# transverse the all keyword_table
|
# transverse the all keyword_table
|
||||||
for key, confi in keyword_filler_table[keyword]['keyword_table'].items():
|
for key, confi in \
|
||||||
|
keyword_filler_table[keyword]['keyword_table'].items():
|
||||||
if confi < threshold:
|
if confi < threshold:
|
||||||
num_false_reject += 1
|
num_false_reject += 1
|
||||||
else:
|
else:
|
||||||
@ -253,4 +267,5 @@ if __name__ == '__main__':
|
|||||||
det_curve_path = args.det_curve_path
|
det_curve_path = args.det_curve_path
|
||||||
else:
|
else:
|
||||||
det_curve_path = os.path.join(stats_dir, 'det.png')
|
det_curve_path = os.path.join(stats_dir, 'det.png')
|
||||||
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)
|
plot_det(stats_dir, det_curve_path,
|
||||||
|
args.xlim, args.x_step, args.ylim, args.y_step)
|
||||||
|
|||||||
@ -42,7 +42,7 @@ def main():
|
|||||||
feature_dim = configs['model']['input_dim']
|
feature_dim = configs['model']['input_dim']
|
||||||
model = init_model(configs['model'])
|
model = init_model(configs['model'])
|
||||||
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
|
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
|
||||||
# if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search
|
# if we use ctc_loss, the logits need to be convert into probs
|
||||||
model.forward = model.forward_softmax
|
model.forward = model.forward_softmax
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
|
|||||||
@ -65,9 +65,12 @@ def get_args():
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=False,
|
||||||
help='Use pinned memory buffers used for reading')
|
help='Use pinned memory buffers used for reading')
|
||||||
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)')
|
parser.add_argument('--keywords', type=str, default=None,
|
||||||
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
|
help='the keywords, split with comma(,)')
|
||||||
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
|
parser.add_argument('--token_file', type=str, default=None,
|
||||||
|
help='the path of tokens.txt')
|
||||||
|
parser.add_argument('--lexicon_file', type=str, default=None,
|
||||||
|
help='the path of lexicon.txt')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
@ -133,7 +136,9 @@ def main():
|
|||||||
lexicon_table = read_lexicon(args.lexicon_file)
|
lexicon_table = read_lexicon(args.lexicon_file)
|
||||||
# 4. parse keywords tokens
|
# 4. parse keywords tokens
|
||||||
assert args.keywords is not None, 'at least one keyword is needed'
|
assert args.keywords is not None, 'at least one keyword is needed'
|
||||||
keywords_str = args.keywords
|
logging.info(f"keywords is {args.keywords}, "
|
||||||
|
f"Chinese is converted into Unicode.")
|
||||||
|
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
|
||||||
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
||||||
keywords_token = {}
|
keywords_token = {}
|
||||||
keywords_idxset = {0}
|
keywords_idxset = {0}
|
||||||
@ -167,7 +172,9 @@ def main():
|
|||||||
for i in range(len(keys)):
|
for i in range(len(keys)):
|
||||||
key = keys[i]
|
key = keys[i]
|
||||||
score = logits[i][:lengths[i]]
|
score = logits[i][:lengths[i]]
|
||||||
hyps = ctc_prefix_beam_search(score, lengths[i], keywords_idxset)
|
hyps = ctc_prefix_beam_search(score,
|
||||||
|
lengths[i],
|
||||||
|
keywords_idxset)
|
||||||
hit_keyword = None
|
hit_keyword = None
|
||||||
hit_score = 1.0
|
hit_score = 1.0
|
||||||
start = 0
|
start = 0
|
||||||
@ -192,10 +199,13 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
if hit_keyword is not None:
|
if hit_keyword is not None:
|
||||||
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score))
|
fout.write('{} detected {} {:.3f}\n'.format(
|
||||||
|
key, hit_keyword, hit_score))
|
||||||
logging.info(
|
logging.info(
|
||||||
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
f"batch:{batch_idx}_{i} detect {hit_keyword} "
|
||||||
f"duration {end - start}, score {hit_score}, Activated.")
|
f"in {key} from {start} to {end} frame. "
|
||||||
|
f"duration {end - start}, "
|
||||||
|
f"score {hit_score}, Activated.")
|
||||||
else:
|
else:
|
||||||
fout.write('{} rejected\n'.format(key))
|
fout.write('{} rejected\n'.format(key))
|
||||||
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
||||||
|
|||||||
@ -16,7 +16,8 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import struct
|
import struct
|
||||||
import wave
|
#import wave
|
||||||
|
import librosa
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
@ -35,9 +36,12 @@ from tools.make_list import query_token_set, read_lexicon, read_token
|
|||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser(description='detect keywords online.')
|
parser = argparse.ArgumentParser(description='detect keywords online.')
|
||||||
parser.add_argument('--config', required=True, help='config file')
|
parser.add_argument('--config', required=True, help='config file')
|
||||||
parser.add_argument('--wav_path', required=False, default=None, help='test wave path.')
|
parser.add_argument('--wav_path', required=False,
|
||||||
parser.add_argument('--wav_scp', required=False, default=None, help='test wave scp.')
|
default=None, help='test wave path.')
|
||||||
parser.add_argument('--result_file', required=False, default=None, help='test result.')
|
parser.add_argument('--wav_scp', required=False,
|
||||||
|
default=None, help='test wave scp.')
|
||||||
|
parser.add_argument('--result_file', required=False,
|
||||||
|
default=None, help='test result.')
|
||||||
|
|
||||||
parser.add_argument('--gpu',
|
parser.add_argument('--gpu',
|
||||||
type=int,
|
type=int,
|
||||||
@ -48,21 +52,27 @@ def get_args():
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=False,
|
||||||
help='Use pinned memory buffers used for reading')
|
help='Use pinned memory buffers used for reading')
|
||||||
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)')
|
parser.add_argument('--keywords', type=str, default=None,
|
||||||
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
|
help='the keywords, split with comma(,)')
|
||||||
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
|
parser.add_argument('--token_file', type=str, default=None,
|
||||||
|
help='the path of tokens.txt')
|
||||||
|
parser.add_argument('--lexicon_file', type=str, default=None,
|
||||||
|
help='the path of lexicon.txt')
|
||||||
parser.add_argument('--score_beam_size',
|
parser.add_argument('--score_beam_size',
|
||||||
default=3,
|
default=3,
|
||||||
type=int,
|
type=int,
|
||||||
help='The first prune beam, filter out those frames with low scores.')
|
help='The first prune beam, '
|
||||||
|
'filter out those frames with low scores.')
|
||||||
parser.add_argument('--path_beam_size',
|
parser.add_argument('--path_beam_size',
|
||||||
default=20,
|
default=20,
|
||||||
type=int,
|
type=int,
|
||||||
help='The second prune beam, keep only path_beam_size candidates.')
|
help='The second prune beam, '
|
||||||
|
'keep only path_beam_size candidates.')
|
||||||
parser.add_argument('--threshold',
|
parser.add_argument('--threshold',
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.0,
|
||||||
help='The threshold of kws. If ctc_search probs exceed this value,'
|
help='The threshold of kws. '
|
||||||
|
'If ctc_search probs exceed this value,'
|
||||||
'the keyword will be activated.')
|
'the keyword will be activated.')
|
||||||
parser.add_argument('--min_frames',
|
parser.add_argument('--min_frames',
|
||||||
default=5,
|
default=5,
|
||||||
@ -98,16 +108,22 @@ def is_sublist(main_list, check_list):
|
|||||||
else:
|
else:
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size):
|
def ctc_prefix_beam_search(t, probs,
|
||||||
|
cur_hyps,
|
||||||
|
keywords_idxset,
|
||||||
|
score_beam_size):
|
||||||
'''
|
'''
|
||||||
|
|
||||||
:param t: the time in frame
|
:param t: the time in frame
|
||||||
:param probs: the probability in t_th frame, (vocab_size, )
|
:param probs: the probability in t_th frame, (vocab_size, )
|
||||||
:param cur_hyps: list of tuples. [(tuple(), (1.0, 0.0, []))]
|
:param cur_hyps: list of tuples. [(tuple(), (1.0, 0.0, []))]
|
||||||
in tuple, 1st is prefix id, 2nd include p_blank, p_non_blank, and path nodes list.
|
in tuple, 1st is prefix id, 2nd include p_blank,
|
||||||
in path nodes list, each node is a dict of {token=idx, frame=t, prob=ps}
|
p_non_blank, and path nodes list.
|
||||||
|
in path nodes list, each node is
|
||||||
|
a dict of {token=idx, frame=t, prob=ps}
|
||||||
:param keywords_idxset: the index of keywords in token.txt
|
:param keywords_idxset: the index of keywords in token.txt
|
||||||
:param score_beam_size: the probability threshold, to filter out those frames with low probs.
|
:param score_beam_size: the probability threshold,
|
||||||
|
to filter out those frames with low probs.
|
||||||
:return:
|
:return:
|
||||||
next_hyps: the hypothesis depend on current hyp and current frame.
|
next_hyps: the hypothesis depend on current hyp and current frame.
|
||||||
'''
|
'''
|
||||||
@ -170,7 +186,8 @@ def ctc_prefix_beam_search(t, probs, cur_hyps, keywords_idxset, score_beam_size)
|
|||||||
if ps > nodes[-1]['prob']: # update frame and prob
|
if ps > nodes[-1]['prob']: # update frame and prob
|
||||||
# nodes[-1]['prob'] = ps
|
# nodes[-1]['prob'] = ps
|
||||||
# nodes[-1]['frame'] = t
|
# nodes[-1]['frame'] = t
|
||||||
nodes.pop() # to avoid change other beam which has this node.
|
nodes.pop()
|
||||||
|
# to avoid change other beam which has this node.
|
||||||
nodes.append(dict(token=s, frame=t, prob=ps))
|
nodes.append(dict(token=s, frame=t, prob=ps))
|
||||||
else:
|
else:
|
||||||
nodes = cur_nodes.copy()
|
nodes = cur_nodes.copy()
|
||||||
@ -199,11 +216,14 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
# feature related
|
# feature related
|
||||||
self.sample_rate = 16000
|
self.sample_rate = 16000
|
||||||
self.wave_remained = np.array([])
|
self.wave_remained = np.array([])
|
||||||
self.num_mel_bins = dataset_conf['feature_extraction_conf']['num_mel_bins']
|
self.num_mel_bins = dataset_conf[
|
||||||
self.frame_length = dataset_conf['feature_extraction_conf']['frame_length'] # in ms
|
'feature_extraction_conf']['num_mel_bins']
|
||||||
self.frame_shift = dataset_conf['feature_extraction_conf']['frame_shift'] # in ms
|
self.frame_length = dataset_conf[
|
||||||
|
'feature_extraction_conf']['frame_length'] # in ms
|
||||||
|
self.frame_shift = dataset_conf[
|
||||||
|
'feature_extraction_conf']['frame_shift'] # in ms
|
||||||
self.downsampling = dataset_conf.get('frame_skip', 1)
|
self.downsampling = dataset_conf.get('frame_skip', 1)
|
||||||
self.resolution = self.frame_shift / 1000 # in second
|
self.resolution = self.frame_shift / 1000 # in second
|
||||||
# fsmn splice operation
|
# fsmn splice operation
|
||||||
self.context_expansion = dataset_conf.get('context_expansion', False)
|
self.context_expansion = dataset_conf.get('context_expansion', False)
|
||||||
self.left_context = 0
|
self.left_context = 0
|
||||||
@ -231,9 +251,11 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.model.eval()
|
self.model.eval()
|
||||||
logging.info(f'model {ckpt_path} loaded.')
|
logging.info(f'model {ckpt_path} loaded.')
|
||||||
self.token_table = read_token(token_path)
|
self.token_table = read_token(token_path)
|
||||||
logging.info(f'tokens {token_path} with {len(self.token_table)} units loaded.')
|
logging.info(f'tokens {token_path} with '
|
||||||
|
f'{len(self.token_table)} units loaded.')
|
||||||
self.lexicon_table = read_lexicon(lexicon_path)
|
self.lexicon_table = read_lexicon(lexicon_path)
|
||||||
logging.info(f'lexicons {lexicon_path} with {len(self.lexicon_table)} units loaded.')
|
logging.info(f'lexicons {lexicon_path} with '
|
||||||
|
f'{len(self.lexicon_table)} units loaded.')
|
||||||
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
|
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||||
|
|
||||||
|
|
||||||
@ -257,7 +279,9 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
|
|
||||||
def set_keywords(self, keywords):
|
def set_keywords(self, keywords):
|
||||||
# 4. parse keywords tokens
|
# 4. parse keywords tokens
|
||||||
assert keywords is not None, 'at least one keyword is needed, multiple keywords should be splitted with comma(,)'
|
assert keywords is not None, \
|
||||||
|
'at least one keyword is needed, ' \
|
||||||
|
'multiple keywords should be splitted with comma(,)'
|
||||||
keywords_str = keywords
|
keywords_str = keywords
|
||||||
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
||||||
keywords_token = {}
|
keywords_token = {}
|
||||||
@ -265,7 +289,8 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
keywords_strset = {'<blk>'}
|
keywords_strset = {'<blk>'}
|
||||||
keywords_tokenmap = {'<blk>': 0}
|
keywords_tokenmap = {'<blk>': 0}
|
||||||
for keyword in keywords_list:
|
for keyword in keywords_list:
|
||||||
strs, indexes = query_token_set(keyword, self.token_table, self.lexicon_table)
|
strs, indexes = query_token_set(
|
||||||
|
keyword, self.token_table, self.lexicon_table)
|
||||||
keywords_token[keyword] = {}
|
keywords_token[keyword] = {}
|
||||||
keywords_token[keyword]['token_id'] = indexes
|
keywords_token[keyword]['token_id'] = indexes
|
||||||
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
||||||
@ -284,16 +309,20 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.keywords_token = keywords_token
|
self.keywords_token = keywords_token
|
||||||
|
|
||||||
def accept_wave(self, wave):
|
def accept_wave(self, wave):
|
||||||
assert isinstance(wave, bytes), "please make sure the input format is bytes(raw PCM)"
|
assert isinstance(wave, bytes), \
|
||||||
|
"please make sure the input format is bytes(raw PCM)"
|
||||||
# convert bytes into float32
|
# convert bytes into float32
|
||||||
data = []
|
data = []
|
||||||
for i in range(0, len(wave), 2):
|
for i in range(0, len(wave), 2):
|
||||||
value = struct.unpack('<h', wave[i:i + 2])[0]
|
value = struct.unpack('<h', wave[i:i + 2])[0]
|
||||||
data.append(value) # here we don't divide 32768.0, because kaldi.fbank accept original input
|
data.append(value)
|
||||||
|
# here we don't divide 32768.0,
|
||||||
|
# because kaldi.fbank accept original input
|
||||||
|
|
||||||
wave = np.array(data)
|
wave = np.array(data)
|
||||||
wave = np.append(self.wave_remained, wave)
|
wave = np.append(self.wave_remained, wave)
|
||||||
if wave.size < (self.frame_length * self.sample_rate / 1000) * self.right_context :
|
if wave.size < (self.frame_length * self.sample_rate / 1000) \
|
||||||
|
* self.right_context :
|
||||||
self.wave_remained = wave
|
self.wave_remained = wave
|
||||||
return None
|
return None
|
||||||
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
||||||
@ -311,29 +340,37 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.wave_remained = wave[feat_len * frame_shift:]
|
self.wave_remained = wave[feat_len * frame_shift:]
|
||||||
|
|
||||||
if self.context_expansion:
|
if self.context_expansion:
|
||||||
assert feat_len > self.right_context, "make sure each chunk feat length is large than right context."
|
assert feat_len > self.right_context, \
|
||||||
|
"make sure each chunk feat length is large than right context."
|
||||||
# pad feats with remained feature from last chunk
|
# pad feats with remained feature from last chunk
|
||||||
if self.feature_remained is None: # first chunk
|
if self.feature_remained is None: # first chunk
|
||||||
# pad first frame at the beginning, replicate just support last dimension, so we do transpose.
|
# pad first frame at the beginning,
|
||||||
feats_pad = F.pad(feats.T, (self.left_context, 0), mode='replicate').T
|
# replicate just support last dimension, so we do transpose.
|
||||||
|
feats_pad = F.pad(
|
||||||
|
feats.T, (self.left_context, 0), mode='replicate').T
|
||||||
else:
|
else:
|
||||||
feats_pad = torch.cat((self.feature_remained, feats))
|
feats_pad = torch.cat((self.feature_remained, feats))
|
||||||
|
|
||||||
ctx_frm = feats_pad.shape[0] - (self.right_context+self.right_context)
|
ctx_frm = feats_pad.shape[0] - \
|
||||||
|
(self.right_context+self.right_context)
|
||||||
ctx_win = (self.left_context + self.right_context + 1)
|
ctx_win = (self.left_context + self.right_context + 1)
|
||||||
ctx_dim = feats.shape[1] * ctx_win
|
ctx_dim = feats.shape[1] * ctx_win
|
||||||
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
|
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
|
||||||
for i in range(ctx_frm):
|
for i in range(ctx_frm):
|
||||||
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
|
feats_ctx[i] = torch.cat(
|
||||||
|
tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
|
||||||
|
|
||||||
# update feature remained, and feats
|
# update feature remained, and feats
|
||||||
self.feature_remained = feats[-(self.left_context+self.right_context):]
|
self.feature_remained = \
|
||||||
|
feats[-(self.left_context+self.right_context):]
|
||||||
feats = feats_ctx.to(self.device)
|
feats = feats_ctx.to(self.device)
|
||||||
if self.downsampling > 1:
|
if self.downsampling > 1:
|
||||||
last_remainder = 0 if self.feats_ctx_offset==0 else self.downsampling-self.feats_ctx_offset
|
last_remainder = 0 if self.feats_ctx_offset==0 \
|
||||||
|
else self.downsampling-self.feats_ctx_offset
|
||||||
remainder = (feats.size(0)+last_remainder) % self.downsampling
|
remainder = (feats.size(0)+last_remainder) % self.downsampling
|
||||||
feats = feats[self.feats_ctx_offset::self.downsampling, :]
|
feats = feats[self.feats_ctx_offset::self.downsampling, :]
|
||||||
self.feats_ctx_offset = remainder if remainder == 0 else self.downsampling-remainder
|
self.feats_ctx_offset = remainder \
|
||||||
|
if remainder == 0 else self.downsampling-remainder
|
||||||
return feats
|
return feats
|
||||||
|
|
||||||
def decode_keywords(self, t, probs):
|
def decode_keywords(self, t, probs):
|
||||||
@ -344,7 +381,8 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.cur_hyps,
|
self.cur_hyps,
|
||||||
self.keywords_idxset,
|
self.keywords_idxset,
|
||||||
self.score_beam)
|
self.score_beam)
|
||||||
# update cur_hyps. note: the hyps is sort by path score(pnb+pb), not the keywords' probabilities.
|
# update cur_hyps. note: the hyps is sort by path score(pnb+pb),
|
||||||
|
# not the keywords' probabilities.
|
||||||
cur_hyps = next_hyps[:self.path_beam]
|
cur_hyps = next_hyps[:self.path_beam]
|
||||||
self.cur_hyps = cur_hyps
|
self.cur_hyps = cur_hyps
|
||||||
|
|
||||||
@ -381,27 +419,36 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
if hit_keyword is not None:
|
if hit_keyword is not None:
|
||||||
if self.hit_score >= self.threshold and \
|
if self.hit_score >= self.threshold and \
|
||||||
self.min_frames <= duration <= self.max_frames \
|
self.min_frames <= duration <= self.max_frames \
|
||||||
and (self.last_active_pos==-1 or end-self.last_active_pos >= self.interval_frames):
|
and (self.last_active_pos==-1 or
|
||||||
|
end-self.last_active_pos >= self.interval_frames):
|
||||||
self.activated = True
|
self.activated = True
|
||||||
self.last_active_pos = end
|
self.last_active_pos = end
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
f"Frame {absolute_time} detect {hit_keyword} "
|
||||||
|
f"from {start} to {end} frame. "
|
||||||
f"duration {duration}, score {self.hit_score}, Activated.")
|
f"duration {duration}, score {self.hit_score}, Activated.")
|
||||||
|
|
||||||
elif self.last_active_pos>0 and end-self.last_active_pos < self.interval_frames:
|
elif self.last_active_pos>0 and \
|
||||||
|
end-self.last_active_pos < self.interval_frames:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
f"Frame {absolute_time} detect {hit_keyword} "
|
||||||
f"but interval {end-self.last_active_pos} is lower than {self.interval_frames}, Deactivated. ")
|
f"from {start} to {end} frame. "
|
||||||
|
f"but interval {end-self.last_active_pos} "
|
||||||
|
f"is lower than {self.interval_frames}, Deactivated. ")
|
||||||
|
|
||||||
elif self.hit_score < self.threshold:
|
elif self.hit_score < self.threshold:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
f"Frame {absolute_time} detect {hit_keyword} "
|
||||||
f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ")
|
f"from {start} to {end} frame. "
|
||||||
|
f"but {self.hit_score} "
|
||||||
|
f"is lower than {self.threshold}, Deactivated. ")
|
||||||
|
|
||||||
elif self.min_frames > duration or duration > self.max_frames:
|
elif self.min_frames > duration or duration > self.max_frames:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
f"Frame {absolute_time} detect {hit_keyword} "
|
||||||
f"but {duration} beyond range({self.min_frames}~{self.max_frames}), Deactivated. ")
|
f"from {start} to {end} frame. "
|
||||||
|
f"but {duration} beyond range"
|
||||||
|
f"({self.min_frames}~{self.max_frames}), Deactivated. ")
|
||||||
|
|
||||||
self.result = {
|
self.result = {
|
||||||
"state": 1 if self.activated else 0,
|
"state": 1 if self.activated else 0,
|
||||||
@ -418,7 +465,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
feature = feature.unsqueeze(0) # add a batch dimension
|
feature = feature.unsqueeze(0) # add a batch dimension
|
||||||
logits, self.in_cache = self.model(feature, self.in_cache)
|
logits, self.in_cache = self.model(feature, self.in_cache)
|
||||||
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
||||||
probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search
|
probs = probs[0].cpu() # remove batch dimension
|
||||||
for (t, prob) in enumerate(probs):
|
for (t, prob) in enumerate(probs):
|
||||||
t *= self.downsampling
|
t *= self.downsampling
|
||||||
self.decode_keywords(t, prob)
|
self.decode_keywords(t, prob)
|
||||||
@ -426,10 +473,14 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
|
|
||||||
if self.activated:
|
if self.activated:
|
||||||
self.reset()
|
self.reset()
|
||||||
# since a chunk include about 30 frames, once activated, we can jump the latter frames.
|
# since a chunk include about 30 frames,
|
||||||
# TODO: there should give another method to update result, avoiding self.result being cleared.
|
# once activated, we can jump the latter frames.
|
||||||
|
# TODO: there should give another method to update result,
|
||||||
|
# avoiding self.result being cleared.
|
||||||
break
|
break
|
||||||
self.total_frames += len(probs) * self.downsampling # update frame offset
|
|
||||||
|
# update frame offset
|
||||||
|
self.total_frames += len(probs) * self.downsampling
|
||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@ -465,15 +516,20 @@ def demo():
|
|||||||
args.gpu,
|
args.gpu,
|
||||||
args.jit_model)
|
args.jit_model)
|
||||||
|
|
||||||
# actually this could be done in __init__ method, we pull it outside for changing keywords more freely.
|
# actually this could be done in __init__ method,
|
||||||
|
# we pull it outside for changing keywords more freely.
|
||||||
kws.set_keywords(args.keywords)
|
kws.set_keywords(args.keywords)
|
||||||
|
|
||||||
if args.wav_path:
|
if args.wav_path:
|
||||||
# Caution: input WAV should be standard 16k, 16 bits, 1 channel
|
# Caution: input WAV should be standard 16k, 16 bits, 1 channel
|
||||||
# In demo we read wave in non-streaming fashion.
|
# In demo we read wave in non-streaming fashion.
|
||||||
with wave.open(args.wav_path, 'rb') as fin:
|
# with wave.open(args.wav_path, 'rb') as fin:
|
||||||
assert fin.getnchannels() == 1
|
# assert fin.getnchannels() == 1
|
||||||
wav = fin.readframes(fin.getnframes())
|
# wav = fin.readframes(fin.getnframes())
|
||||||
|
|
||||||
|
y, _ = librosa.load(args.wav_path, sr=16000, mono=True)
|
||||||
|
# NOTE: model supports 16k sample_rate
|
||||||
|
wav = (y * (1 << 15)).astype("int16").tobytes()
|
||||||
|
|
||||||
# We inference every 0.3 seconds, in streaming fashion.
|
# We inference every 0.3 seconds, in streaming fashion.
|
||||||
interval = int(0.3 * 16000) * 2
|
interval = int(0.3 * 16000) * 2
|
||||||
@ -490,12 +546,18 @@ def demo():
|
|||||||
with open(args.wav_scp, 'r') as fscp:
|
with open(args.wav_scp, 'r') as fscp:
|
||||||
for line in fscp:
|
for line in fscp:
|
||||||
line = line.strip().split()
|
line = line.strip().split()
|
||||||
assert len(line) == 2, f"The scp should be in kaldi format: \"utt_name wav_path\", but got {line}"
|
assert len(line) == 2, \
|
||||||
|
f"The scp should be in kaldi format: " \
|
||||||
|
f"\"utt_name wav_path\", but got {line}"
|
||||||
|
|
||||||
utt_name, wav_path = line[0], line[1]
|
utt_name, wav_path = line[0], line[1]
|
||||||
with wave.open(wav_path, 'rb') as fin:
|
# with wave.open(args.wav_path, 'rb') as fin:
|
||||||
assert fin.getnchannels() == 1
|
# assert fin.getnchannels() == 1
|
||||||
wav = fin.readframes(fin.getnframes())
|
# wav = fin.readframes(fin.getnframes())
|
||||||
|
|
||||||
|
y, _ = librosa.load(args.wav_path, sr=16000, mono=True)
|
||||||
|
# NOTE: model supports 16k sample_rate
|
||||||
|
wav = (y * (1 << 15)).astype("int16").tobytes()
|
||||||
|
|
||||||
kws.reset_all()
|
kws.reset_all()
|
||||||
activated = False
|
activated = False
|
||||||
@ -510,7 +572,8 @@ def demo():
|
|||||||
if fout:
|
if fout:
|
||||||
hit_keyword = result['keyword']
|
hit_keyword = result['keyword']
|
||||||
hit_score = result['score']
|
hit_score = result['score']
|
||||||
fout.write('{} detected {} {:.3f}\n'.format(utt_name, hit_keyword, hit_score))
|
fout.write('{} detected {} {:.3f}\n'.format(
|
||||||
|
utt_name, hit_keyword, hit_score))
|
||||||
|
|
||||||
if not activated:
|
if not activated:
|
||||||
if fout:
|
if fout:
|
||||||
|
|||||||
@ -66,30 +66,36 @@ def get_args():
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=False,
|
||||||
help='Use pinned memory buffers used for reading')
|
help='Use pinned memory buffers used for reading')
|
||||||
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)')
|
parser.add_argument('--keywords', type=str, default=None,
|
||||||
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
|
help='the keywords, split with comma(,)')
|
||||||
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
|
parser.add_argument('--token_file', type=str, default=None,
|
||||||
|
help='the path of tokens.txt')
|
||||||
|
parser.add_argument('--lexicon_file', type=str, default=None,
|
||||||
|
help='the path of lexicon.txt')
|
||||||
parser.add_argument('--score_beam_size',
|
parser.add_argument('--score_beam_size',
|
||||||
default=3,
|
default=3,
|
||||||
type=int,
|
type=int,
|
||||||
help='The first prune beam, filter out those frames with low scores.')
|
help='The first prune beam, f'
|
||||||
|
'ilter out those frames with low scores.')
|
||||||
parser.add_argument('--path_beam_size',
|
parser.add_argument('--path_beam_size',
|
||||||
default=20,
|
default=20,
|
||||||
type=int,
|
type=int,
|
||||||
help='The second prune beam, keep only path_beam_size candidates.')
|
help='The second prune beam, '
|
||||||
|
'keep only path_beam_size candidates.')
|
||||||
parser.add_argument('--threshold',
|
parser.add_argument('--threshold',
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.0,
|
||||||
help='The threshold of kws. If ctc_search probs exceed this value,'
|
help='The threshold of kws. '
|
||||||
|
'If ctc_search probs exceed this value,'
|
||||||
'the keyword will be activated.')
|
'the keyword will be activated.')
|
||||||
parser.add_argument('--min_frames',
|
parser.add_argument('--min_frames',
|
||||||
default=5,
|
default=5,
|
||||||
type=int,
|
type=int,
|
||||||
help='The min frames of keyword\'s duration.')
|
help='The min frames of keyword duration.')
|
||||||
parser.add_argument('--max_frames',
|
parser.add_argument('--max_frames',
|
||||||
default=250,
|
default=250,
|
||||||
type=int,
|
type=int,
|
||||||
help='The max frames of keyword\'s duration.')
|
help='The max frames of keyword duration.')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
@ -158,7 +164,9 @@ def main():
|
|||||||
lexicon_table = read_lexicon(args.lexicon_file)
|
lexicon_table = read_lexicon(args.lexicon_file)
|
||||||
# 4. parse keywords tokens
|
# 4. parse keywords tokens
|
||||||
assert args.keywords is not None, 'at least one keyword is needed'
|
assert args.keywords is not None, 'at least one keyword is needed'
|
||||||
keywords_str = args.keywords
|
logging.info(f"keywords is {args.keywords}, "
|
||||||
|
f"Chinese is converted into Unicode.")
|
||||||
|
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
|
||||||
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
||||||
keywords_token = {}
|
keywords_token = {}
|
||||||
keywords_idxset = {0}
|
keywords_idxset = {0}
|
||||||
@ -217,7 +225,8 @@ def main():
|
|||||||
# filter prob score that is too small
|
# filter prob score that is too small
|
||||||
filter_probs = []
|
filter_probs = []
|
||||||
filter_index = []
|
filter_index = []
|
||||||
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
|
for prob, idx in zip(
|
||||||
|
top_k_probs.tolist(), top_k_index.tolist()):
|
||||||
if keywords_idxset is not None:
|
if keywords_idxset is not None:
|
||||||
if prob > 0.05 and idx in keywords_idxset:
|
if prob > 0.05 and idx in keywords_idxset:
|
||||||
filter_probs.append(prob)
|
filter_probs.append(prob)
|
||||||
@ -246,7 +255,8 @@ def main():
|
|||||||
n_pb, n_pnb, nodes = next_hyps[prefix]
|
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||||||
n_pnb = n_pnb + pnb * ps
|
n_pnb = n_pnb + pnb * ps
|
||||||
nodes = cur_nodes.copy()
|
nodes = cur_nodes.copy()
|
||||||
if ps > nodes[-1]['prob']: # update frame and prob
|
# update frame and prob
|
||||||
|
if ps > nodes[-1]['prob']:
|
||||||
nodes[-1]['prob'] = ps
|
nodes[-1]['prob'] = ps
|
||||||
nodes[-1]['frame'] = t
|
nodes[-1]['frame'] = t
|
||||||
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||||||
@ -257,32 +267,37 @@ def main():
|
|||||||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||||
n_pnb = n_pnb + pb * ps
|
n_pnb = n_pnb + pb * ps
|
||||||
nodes = cur_nodes.copy()
|
nodes = cur_nodes.copy()
|
||||||
nodes.append(dict(token=s, frame=t,
|
nodes.append(dict(
|
||||||
prob=ps)) # to record token prob
|
token=s, frame=t, prob=ps))
|
||||||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||||
else:
|
else:
|
||||||
n_prefix = prefix + (s,)
|
n_prefix = prefix + (s,)
|
||||||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||||
if nodes:
|
if nodes:
|
||||||
if ps > nodes[-1]['prob']: # update frame and prob
|
# update frame and prob
|
||||||
|
if ps > nodes[-1]['prob']:
|
||||||
# nodes[-1]['prob'] = ps
|
# nodes[-1]['prob'] = ps
|
||||||
# nodes[-1]['frame'] = t
|
# nodes[-1]['frame'] = t
|
||||||
nodes.pop() # to avoid change other beam which has this node.
|
# avoid change other beam has this node.
|
||||||
nodes.append(dict(token=s, frame=t, prob=ps))
|
nodes.pop()
|
||||||
|
nodes.append(dict(
|
||||||
|
token=s, frame=t, prob=ps))
|
||||||
else:
|
else:
|
||||||
nodes = cur_nodes.copy()
|
nodes = cur_nodes.copy()
|
||||||
nodes.append(dict(token=s, frame=t,
|
nodes.append(dict(
|
||||||
prob=ps)) # to record token prob
|
token=s, frame=t, prob=ps))
|
||||||
n_pnb = n_pnb + pb * ps + pnb * ps
|
n_pnb = n_pnb + pb * ps + pnb * ps
|
||||||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||||
|
|
||||||
# 2.2 Second beam prune
|
# 2.2 Second beam prune
|
||||||
next_hyps = sorted(
|
next_hyps = sorted(
|
||||||
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)
|
next_hyps.items(),
|
||||||
|
key=lambda x: (x[1][0] + x[1][1]), reverse=True)
|
||||||
|
|
||||||
cur_hyps = next_hyps[:args.path_beam_size]
|
cur_hyps = next_hyps[:args.path_beam_size]
|
||||||
|
|
||||||
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
|
hyps = [(y[0], y[1][0] + y[1][1], y[1][2])
|
||||||
|
for y in cur_hyps]
|
||||||
|
|
||||||
for one_hyp in hyps:
|
for one_hyp in hyps:
|
||||||
prefix_ids = one_hyp[0]
|
prefix_ids = one_hyp[0]
|
||||||
@ -295,7 +310,8 @@ def main():
|
|||||||
if offset != -1:
|
if offset != -1:
|
||||||
hit_keyword = word
|
hit_keyword = word
|
||||||
start = prefix_nodes[offset]['frame']
|
start = prefix_nodes[offset]['frame']
|
||||||
end = prefix_nodes[offset + len(lab) - 1]['frame']
|
end = prefix_nodes[
|
||||||
|
offset + len(lab) - 1]['frame']
|
||||||
for idx in range(offset, offset + len(lab)):
|
for idx in range(offset, offset + len(lab)):
|
||||||
hit_score *= prefix_nodes[idx]['prob']
|
hit_score *= prefix_nodes[idx]['prob']
|
||||||
break
|
break
|
||||||
@ -305,25 +321,35 @@ def main():
|
|||||||
|
|
||||||
duration = end - start
|
duration = end - start
|
||||||
if hit_keyword is not None:
|
if hit_keyword is not None:
|
||||||
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames:
|
if hit_score >= args.threshold and \
|
||||||
|
args.min_frames <= duration <= args.max_frames:
|
||||||
activated = True
|
activated = True
|
||||||
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score))
|
fout.write('{} detected {} {:.3f}\n'.format(
|
||||||
|
key, hit_keyword, hit_score))
|
||||||
logging.info(
|
logging.info(
|
||||||
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
f"batch:{batch_idx}_{i} detect {hit_keyword} "
|
||||||
f"duration {duration}, score {hit_score} Activated.")
|
f"in {key} from {start} to {end} frame. "
|
||||||
|
f"duration {duration}, s"
|
||||||
|
f"core {hit_score} Activated.")
|
||||||
|
|
||||||
# clear the ctc_prefix buffer, and clear hit_keyword
|
# clear the ctc_prefix buffer, and hit_keyword
|
||||||
cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
||||||
hit_keyword = None
|
hit_keyword = None
|
||||||
hit_score = 1.0
|
hit_score = 1.0
|
||||||
elif hit_score < args.threshold:
|
elif hit_score < args.threshold:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
f"batch:{batch_idx}_{i} detect {hit_keyword} "
|
||||||
f"but {hit_score} less than {args.threshold}, Deactivated. ")
|
f"in {key} from {start} to {end} frame. "
|
||||||
elif args.min_frames > duration or duration > args.max_frames:
|
f"but {hit_score} less than "
|
||||||
|
f"{args.threshold}, Deactivated. ")
|
||||||
|
elif args.min_frames > duration \
|
||||||
|
or duration > args.max_frames:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
f"batch:{batch_idx}_{i} detect {hit_keyword} "
|
||||||
f"but {duration} beyond range({args.min_frames}~{args.max_frames}), Deactivated. ")
|
f"in {key} from {start} to {end} frame. "
|
||||||
|
f"but {duration} beyond "
|
||||||
|
f"range({args.min_frames}~{args.max_frames}), "
|
||||||
|
f"Deactivated. ")
|
||||||
if not activated:
|
if not activated:
|
||||||
fout.write('{} rejected\n'.format(key))
|
fout.write('{} rejected\n'.format(key))
|
||||||
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
||||||
|
|||||||
@ -159,8 +159,12 @@ def main():
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
pass
|
pass
|
||||||
# TODO: for now streaming FSMN do not support export to JITScript,
|
# TODO: for now streaming FSMN do not support export to JITScript,
|
||||||
# TODO: because there is nn.Sequential with Tuple input in current FSMN modules.
|
# TODO: because there is nn.Sequential with Tuple input
|
||||||
# the issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
|
# in current FSMN modules.
|
||||||
|
# the issue is in https://stackoverflow.com/questions/75714299/
|
||||||
|
# pytorch-jit-script-error-when-sequential-container-
|
||||||
|
# takes-a-tuple-input/76553450#76553450
|
||||||
|
|
||||||
# script_model = torch.jit.script(model)
|
# script_model = torch.jit.script(model)
|
||||||
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||||
executor = Executor()
|
executor = Executor()
|
||||||
|
|||||||
@ -225,16 +225,19 @@ class FSMNBlock(nn.Module):
|
|||||||
x_per = x.permute(0, 3, 2, 1)
|
x_per = x.permute(0, 3, 2, 1)
|
||||||
|
|
||||||
if in_cache is None or len(in_cache) == 0 :
|
if in_cache is None or len(in_cache) == 0 :
|
||||||
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
|
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride
|
||||||
|
+ self.rorder * self.rstride, 0])
|
||||||
else:
|
else:
|
||||||
in_cache = in_cache.to(x_per.device)
|
in_cache = in_cache.to(x_per.device)
|
||||||
x_pad = torch.cat((in_cache, x_per), dim=2)
|
x_pad = torch.cat((in_cache, x_per), dim=2)
|
||||||
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride + self.rorder * self.rstride):, :]
|
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride
|
||||||
|
+ self.rorder * self.rstride):, :]
|
||||||
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
|
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
|
||||||
y_left = self.quant(y_left)
|
y_left = self.quant(y_left)
|
||||||
y_left = self.conv_left(y_left)
|
y_left = self.conv_left(y_left)
|
||||||
y_left = self.dequant(y_left)
|
y_left = self.dequant(y_left)
|
||||||
out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left
|
out = x_pad[:, :, (self.lorder - 1) * self.lstride:
|
||||||
|
-self.rorder * self.rstride, :] + y_left
|
||||||
|
|
||||||
if self.conv_right is not None:
|
if self.conv_right is not None:
|
||||||
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
||||||
@ -253,7 +256,8 @@ class FSMNBlock(nn.Module):
|
|||||||
def to_kaldi_net(self):
|
def to_kaldi_net(self):
|
||||||
re_str = ''
|
re_str = ''
|
||||||
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
|
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
|
||||||
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
|
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \
|
||||||
|
'<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
|
||||||
1, self.lorder, self.rorder, self.lstride, self.rstride)
|
1, self.lorder, self.rorder, self.lstride, self.rstride)
|
||||||
|
|
||||||
# print(self.conv_left.weight,self.conv_right.weight)
|
# print(self.conv_left.weight,self.conv_right.weight)
|
||||||
@ -441,7 +445,8 @@ class FSMN(nn.Module):
|
|||||||
self.output_affine_dim = output_affine_dim
|
self.output_affine_dim = output_affine_dim
|
||||||
self.output_dim = output_dim
|
self.output_dim = output_dim
|
||||||
|
|
||||||
self.padding = (self.lorder-1) * self.lstride + self.rorder * self.rstride
|
self.padding = (self.lorder-1) * self.lstride \
|
||||||
|
+ self.rorder * self.rstride
|
||||||
|
|
||||||
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
|
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
|
||||||
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
|
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
|
||||||
@ -469,7 +474,8 @@ class FSMN(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if in_cache is None or len(in_cache) == 0 :
|
if in_cache is None or len(in_cache) == 0 :
|
||||||
in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) for _ in range(len(self.fsmn))]
|
in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||||
|
for _ in range(len(self.fsmn))]
|
||||||
input = (input, in_cache)
|
input = (input, in_cache)
|
||||||
x1 = self.in_linear1(input)
|
x1 = self.in_linear1(input)
|
||||||
x2 = self.in_linear2(x1)
|
x2 = self.in_linear2(x1)
|
||||||
|
|||||||
@ -34,7 +34,7 @@ class KWSModel(nn.Module):
|
|||||||
"""Our model consists of four parts:
|
"""Our model consists of four parts:
|
||||||
1. global_cmvn: Optional, (idim, idim)
|
1. global_cmvn: Optional, (idim, idim)
|
||||||
2. preprocessing: feature dimention projection, (idim, hdim)
|
2. preprocessing: feature dimention projection, (idim, hdim)
|
||||||
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
|
3. backbone: backbone of the whole network, (hdim, hdim)
|
||||||
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
||||||
5. activation:
|
5. activation:
|
||||||
nn.Sigmoid for wakeup word
|
nn.Sigmoid for wakeup word
|
||||||
@ -76,7 +76,8 @@ class KWSModel(nn.Module):
|
|||||||
|
|
||||||
def forward_softmax(self,
|
def forward_softmax(self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
in_cache: torch.Tensor = torch.zeros(
|
||||||
|
0, 0, 0, dtype=torch.float)
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
if self.global_cmvn is not None:
|
if self.global_cmvn is not None:
|
||||||
x = self.global_cmvn(x)
|
x = self.global_cmvn(x)
|
||||||
@ -196,7 +197,8 @@ def init_model(configs):
|
|||||||
classifier = LinearClassifier(hidden_dim, output_dim)
|
classifier = LinearClassifier(hidden_dim, output_dim)
|
||||||
activation = nn.Sigmoid()
|
activation = nn.Sigmoid()
|
||||||
|
|
||||||
# Here we add a possible "activation_type", one can choose to use other activation function.
|
# Here we add a possible "activation_type",
|
||||||
|
# one can choose to use other activation function.
|
||||||
# We use nn.Identity just for CTC loss
|
# We use nn.Identity just for CTC loss
|
||||||
if "activation" in configs:
|
if "activation" in configs:
|
||||||
activation_type = configs["activation"]["type"]
|
activation_type = configs["activation"]["type"]
|
||||||
|
|||||||
@ -188,7 +188,8 @@ def criterion(type: str,
|
|||||||
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
|
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
|
||||||
return loss, acc
|
return loss, acc
|
||||||
elif type == 'ctc':
|
elif type == 'ctc':
|
||||||
loss, acc = ctc_loss(logits, target, lengths, target_lengths, validation)
|
loss, acc = ctc_loss(
|
||||||
|
logits, target, lengths, target_lengths, validation)
|
||||||
return loss, acc
|
return loss, acc
|
||||||
else:
|
else:
|
||||||
exit(1)
|
exit(1)
|
||||||
@ -281,7 +282,8 @@ def ctc_prefix_beam_search(
|
|||||||
if ps > nodes[-1]['prob']: # update frame and prob
|
if ps > nodes[-1]['prob']: # update frame and prob
|
||||||
# nodes[-1]['prob'] = ps
|
# nodes[-1]['prob'] = ps
|
||||||
# nodes[-1]['frame'] = t
|
# nodes[-1]['frame'] = t
|
||||||
nodes.pop() # to avoid change other beam which has this node.
|
# avoid change other beam which has this node.
|
||||||
|
nodes.pop()
|
||||||
nodes.append(dict(token=s, frame=t, prob=ps))
|
nodes.append(dict(token=s, frame=t, prob=ps))
|
||||||
else:
|
else:
|
||||||
nodes = cur_nodes.copy()
|
nodes = cur_nodes.copy()
|
||||||
@ -429,7 +431,8 @@ class Calculator:
|
|||||||
break
|
break
|
||||||
else: # shouldn't reach here
|
else: # shouldn't reach here
|
||||||
print(
|
print(
|
||||||
'this should not happen , i = {i} , j = {j} , error = {error}'
|
'this should not happen, '
|
||||||
|
'i = {i} , j = {j} , error = {error}'
|
||||||
.format(i=i, j=j, error=self.space[i][j]['error']))
|
.format(i=i, j=j, error=self.space[i][j]['error']))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user