Add streaming detection of CTC model. Add CTC model onnx export. Add CTC model's result in README; For now CTC model runtime is not supported yet.
This commit is contained in:
parent
6d7e7784b5
commit
0ac7a417b1
@ -1,4 +1,4 @@
|
|||||||
FRRs with FAR fixed at once per hour:
|
Comparison among different backbones. FRRs with FAR fixed at once per hour:
|
||||||
|
|
||||||
| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
|
| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
|
||||||
|-----------------------|-----------|-----------|------------|--------------|
|
|-----------------------|-----------|-----------|------------|--------------|
|
||||||
@ -8,3 +8,50 @@ FRRs with FAR fixed at once per hour:
|
|||||||
| DS_TCN(spec_aug) | 287 | 80(avg30) | 0.008176 | 0.005075 |
|
| DS_TCN(spec_aug) | 287 | 80(avg30) | 0.008176 | 0.005075 |
|
||||||
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
|
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
|
||||||
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
|
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
|
||||||
|
|
||||||
|
Next, we use CTC loss to train the model, with DS_TCN and FSMN.
|
||||||
|
and we use CTC prefix beam search to decode and detect keywords,
|
||||||
|
the detection is either in non-streaming or streaming fashion.
|
||||||
|
|
||||||
|
Since the FAR is pretty low when using CTC loss,
|
||||||
|
the follow result is FRRs with FAR fixed at once per 12 hours:
|
||||||
|
|
||||||
|
Comparison between Max-pooling and CTC loss.
|
||||||
|
The CTC model is fine-tuned with base model trained on WenetSpeech(23 epoch).
|
||||||
|
FRRs with FAR fixed at once per 12 hours
|
||||||
|
|
||||||
|
|
||||||
|
| model | loss | hi_xiaowen | nihao_wenwen |
|
||||||
|
|-----------------------|-------------|------------|--------------|
|
||||||
|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 |
|
||||||
|
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 |
|
||||||
|
|
||||||
|
|
||||||
|
Comparison between DS_TCN(Pretrained with Wenetspeech, 23 epoch)
|
||||||
|
and FSMN(modelscope released, xiaoyunxiaoyun model).
|
||||||
|
FRRs with FAR fixed at once per 12 hours:
|
||||||
|
|
||||||
|
| model | params(K) | hi_xiaowen | nihao_wenwen |
|
||||||
|
|-----------------------|-------------|------------|--------------|
|
||||||
|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 |
|
||||||
|
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 |
|
||||||
|
|
||||||
|
Comparison Between stream_score_ctc and score_ctc.
|
||||||
|
FRRs with FAR fixed at once per 12 hours:
|
||||||
|
|
||||||
|
| model | stream | hi_xiaowen | nihao_wenwen |
|
||||||
|
|-----------------------|-------------|------------|--------------|
|
||||||
|
| DS_TCN(spec_aug) | no | 0.056574 | 0.056856 |
|
||||||
|
| DS_TCN(spec_aug) | yes | 0.132694 | 0.057044 |
|
||||||
|
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
|
||||||
|
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |
|
||||||
|
|
||||||
|
Note: when using CTC prefix beam search to detect keywords in streaming case(detect in each frame),
|
||||||
|
we record the probability of a keyword in a decoding path once the keyword appears in this path.
|
||||||
|
Actually the probability will increase through the time, so we record a lower value of probability,
|
||||||
|
which result in a higher False Rejection Rate in Detection Error Tradeoff result.
|
||||||
|
The actual FRR will be lower than the DET curve gives in a given threshold.
|
||||||
|
|
||||||
|
Now, the model with CTC loss may not get the best performance,
|
||||||
|
but it's more robust compared with the classification model using CE/Max-pooling loss.
|
||||||
|
For more result of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).
|
||||||
@ -153,7 +153,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
|||||||
if [ ! -d mobvoi_kws_transcription ] ;then
|
if [ ! -d mobvoi_kws_transcription ] ;then
|
||||||
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
|
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
|
||||||
fi
|
fi
|
||||||
checkpoint=mobvoi_kws_transcription/23.pt # this ckpt may not be the best.
|
checkpoint=mobvoi_kws_transcription/23.pt # this ckpt may not converge well.
|
||||||
fi
|
fi
|
||||||
|
|
||||||
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
|
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
|
||||||
@ -181,7 +181,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
fi
|
fi
|
||||||
result_dir=$dir/test_$(basename $score_checkpoint)
|
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||||
mkdir -p $result_dir
|
mkdir -p $result_dir
|
||||||
python wekws/bin/score_ctc.py \
|
stream=true # we detect keyword online with ctc_prefix_beam_search
|
||||||
|
score_prefix=""
|
||||||
|
if $stream ; then
|
||||||
|
score_prefix=stream_
|
||||||
|
fi
|
||||||
|
python wekws/bin/${score_prefix}score_ctc.py \
|
||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
--gpu 0 \
|
--gpu 0 \
|
||||||
|
|||||||
@ -132,7 +132,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
fi
|
fi
|
||||||
result_dir=$dir/test_$(basename $score_checkpoint)
|
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||||
mkdir -p $result_dir
|
mkdir -p $result_dir
|
||||||
python wekws/bin/score_ctc.py \
|
stream=true # we detect keyword online with ctc_prefix_beam_search
|
||||||
|
score_prefix=""
|
||||||
|
if $stream ; then
|
||||||
|
score_prefix=stream_
|
||||||
|
fi
|
||||||
|
python wekws/bin/${score_prefix}score_ctc.py \
|
||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
--gpu 0 \
|
--gpu 0 \
|
||||||
|
|||||||
@ -41,6 +41,9 @@ def main():
|
|||||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||||
feature_dim = configs['model']['input_dim']
|
feature_dim = configs['model']['input_dim']
|
||||||
model = init_model(configs['model'])
|
model = init_model(configs['model'])
|
||||||
|
if configs['training_config']['criterion'] == 'ctc':
|
||||||
|
# if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search
|
||||||
|
model.forward = model.forward_softmax
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
load_checkpoint(model, args.checkpoint)
|
load_checkpoint(model, args.checkpoint)
|
||||||
|
|||||||
@ -169,7 +169,7 @@ def main():
|
|||||||
keywords_idxset)
|
keywords_idxset)
|
||||||
hit_keyword = None
|
hit_keyword = None
|
||||||
hit_score = 1.0
|
hit_score = 1.0
|
||||||
# start = 0; end = 0
|
start = 0; end = 0
|
||||||
for one_hyp in hyps:
|
for one_hyp in hyps:
|
||||||
prefix_ids = one_hyp[0]
|
prefix_ids = one_hyp[0]
|
||||||
# path_score = one_hyp[1]
|
# path_score = one_hyp[1]
|
||||||
@ -180,8 +180,8 @@ def main():
|
|||||||
offset = is_sublist(prefix_ids, lab)
|
offset = is_sublist(prefix_ids, lab)
|
||||||
if offset != -1:
|
if offset != -1:
|
||||||
hit_keyword = word
|
hit_keyword = word
|
||||||
# start = prefix_nodes[offset]['frame']
|
start = prefix_nodes[offset]['frame']
|
||||||
# end = prefix_nodes[offset+len(lab)-1]['frame']
|
end = prefix_nodes[offset+len(lab)-1]['frame']
|
||||||
for idx in range(offset, offset + len(lab)):
|
for idx in range(offset, offset + len(lab)):
|
||||||
hit_score *= prefix_nodes[idx]['prob']
|
hit_score *= prefix_nodes[idx]['prob']
|
||||||
break
|
break
|
||||||
@ -190,12 +190,13 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
if hit_keyword is not None:
|
if hit_keyword is not None:
|
||||||
# fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\
|
fout.write('{} detected {} {:.3f}\n'.format(key, hit_keyword, hit_score))
|
||||||
# .format(key, start*0.03, end*0.03, hit_keyword, hit_score))
|
logging.info(
|
||||||
fout.write('{} detected {} {:.3f}\n'.format(
|
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
||||||
key, hit_keyword, hit_score))
|
f"duration {end-start}, score {hit_score}, Activated.")
|
||||||
else:
|
else:
|
||||||
fout.write('{} rejected\n'.format(key))
|
fout.write('{} rejected\n'.format(key))
|
||||||
|
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
||||||
|
|
||||||
if batch_idx % 10 == 0:
|
if batch_idx % 10 == 0:
|
||||||
print('Progress batch {}'.format(batch_idx))
|
print('Progress batch {}'.format(batch_idx))
|
||||||
|
|||||||
331
wekws/bin/stream_score_ctc.py
Normal file
331
wekws/bin/stream_score_ctc.py
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||||
|
# 2022 Shaoqing Yu(954793264@qq.com)
|
||||||
|
# 2023 Jing Du(thuduj12@163.com)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import os, sys, math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from collections import defaultdict
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from wekws.dataset.dataset import Dataset
|
||||||
|
from wekws.model.kws_model import init_model
|
||||||
|
from wekws.utils.checkpoint import load_checkpoint
|
||||||
|
from tools.make_list import query_token_set, read_lexicon, read_token
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description='recognize with your model')
|
||||||
|
parser.add_argument('--config', required=True, help='config file')
|
||||||
|
parser.add_argument('--test_data', required=True, help='test data file')
|
||||||
|
parser.add_argument('--gpu',
|
||||||
|
type=int,
|
||||||
|
default=-1,
|
||||||
|
help='gpu id for this rank, -1 for cpu')
|
||||||
|
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
||||||
|
parser.add_argument('--batch_size',
|
||||||
|
default=16,
|
||||||
|
type=int,
|
||||||
|
help='batch size for inference')
|
||||||
|
parser.add_argument('--num_workers',
|
||||||
|
default=0,
|
||||||
|
type=int,
|
||||||
|
help='num of subprocess workers for reading')
|
||||||
|
parser.add_argument('--pin_memory',
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help='Use pinned memory buffers used for reading')
|
||||||
|
parser.add_argument('--prefetch',
|
||||||
|
default=100,
|
||||||
|
type=int,
|
||||||
|
help='prefetch number')
|
||||||
|
parser.add_argument('--score_file',
|
||||||
|
required=True,
|
||||||
|
help='output score file')
|
||||||
|
parser.add_argument('--jit_model',
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help='Use pinned memory buffers used for reading')
|
||||||
|
parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)')
|
||||||
|
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
|
||||||
|
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
|
||||||
|
parser.add_argument('--score_beam_size',
|
||||||
|
default=3,
|
||||||
|
type=int,
|
||||||
|
help='The first prune beam, filter out those frames with low scores.')
|
||||||
|
parser.add_argument('--path_beam_size',
|
||||||
|
default=20,
|
||||||
|
type=int,
|
||||||
|
help='The second prune beam, keep only path_beam_size candidates.')
|
||||||
|
parser.add_argument('--threshold',
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help='The threshold of kws. If ctc_search probs exceed this value,'
|
||||||
|
'the keyword will be activated.')
|
||||||
|
parser.add_argument('--min_frames',
|
||||||
|
default=5,
|
||||||
|
type=int,
|
||||||
|
help='The min frames of keyword\'s duration.')
|
||||||
|
parser.add_argument('--max_frames',
|
||||||
|
default=250,
|
||||||
|
type=int,
|
||||||
|
help='The max frames of keyword\'s duration.')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def is_sublist(main_list, check_list):
|
||||||
|
if len(main_list) < len(check_list):
|
||||||
|
return -1
|
||||||
|
|
||||||
|
if len(main_list) == len(check_list):
|
||||||
|
return 0 if main_list == check_list else -1
|
||||||
|
|
||||||
|
for i in range(len(main_list) - len(check_list)):
|
||||||
|
if main_list[i] == check_list[0]:
|
||||||
|
for j in range(len(check_list)):
|
||||||
|
if main_list[i + j] != check_list[j]:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return i
|
||||||
|
else:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
||||||
|
|
||||||
|
with open(args.config, 'r') as fin:
|
||||||
|
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
test_conf = copy.deepcopy(configs['dataset_conf'])
|
||||||
|
test_conf['filter_conf']['max_length'] = 102400
|
||||||
|
test_conf['filter_conf']['min_length'] = 0
|
||||||
|
test_conf['speed_perturb'] = False
|
||||||
|
test_conf['spec_aug'] = False
|
||||||
|
test_conf['shuffle'] = False
|
||||||
|
test_conf['feature_extraction_conf']['dither'] = 0.0
|
||||||
|
test_conf['batch_conf']['batch_size'] = args.batch_size
|
||||||
|
|
||||||
|
test_dataset = Dataset(args.test_data, test_conf)
|
||||||
|
test_data_loader = DataLoader(test_dataset,
|
||||||
|
batch_size=None,
|
||||||
|
pin_memory=args.pin_memory,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
prefetch_factor=args.prefetch)
|
||||||
|
|
||||||
|
if args.jit_model:
|
||||||
|
model = torch.jit.load(args.checkpoint)
|
||||||
|
# For script model, only cpu is supported.
|
||||||
|
device = torch.device('cpu')
|
||||||
|
else:
|
||||||
|
# Init asr model from configs
|
||||||
|
model = init_model(configs['model'])
|
||||||
|
load_checkpoint(model, args.checkpoint)
|
||||||
|
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
||||||
|
device = torch.device('cuda' if use_cuda else 'cpu')
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
score_abs_path = os.path.abspath(args.score_file)
|
||||||
|
|
||||||
|
token_table = read_token(args.token_file)
|
||||||
|
lexicon_table = read_lexicon(args.lexicon_file)
|
||||||
|
# 4. parse keywords tokens
|
||||||
|
assert args.keywords is not None, 'at least one keyword is needed'
|
||||||
|
keywords_str = args.keywords
|
||||||
|
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
||||||
|
keywords_token = {}
|
||||||
|
keywords_idxset = {0}
|
||||||
|
keywords_strset = {'<blk>'}
|
||||||
|
keywords_tokenmap = {'<blk>': 0}
|
||||||
|
for keyword in keywords_list:
|
||||||
|
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
||||||
|
keywords_token[keyword] = {}
|
||||||
|
keywords_token[keyword]['token_id'] = indexes
|
||||||
|
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
||||||
|
for i in indexes)
|
||||||
|
[keywords_strset.add(i) for i in strs]
|
||||||
|
[keywords_idxset.add(i) for i in indexes]
|
||||||
|
for txt, idx in zip(strs, indexes):
|
||||||
|
if keywords_tokenmap.get(txt, None) is None:
|
||||||
|
keywords_tokenmap[txt] = idx
|
||||||
|
|
||||||
|
token_print = ''
|
||||||
|
for txt, idx in keywords_tokenmap.items():
|
||||||
|
token_print += f'{txt}({idx}) '
|
||||||
|
logging.info(f'Token set is: {token_print}')
|
||||||
|
|
||||||
|
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
|
||||||
|
for batch_idx, batch in enumerate(test_data_loader):
|
||||||
|
keys, feats, target, lengths, target_lengths = batch
|
||||||
|
feats = feats.to(device)
|
||||||
|
lengths = lengths.to(device)
|
||||||
|
logits, _ = model(feats)
|
||||||
|
logits = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
||||||
|
logits = logits.cpu()
|
||||||
|
for i in range(len(keys)):
|
||||||
|
key = keys[i]
|
||||||
|
score = logits[i][:lengths[i]]
|
||||||
|
# hyps = ctc_prefix_beam_search(score, lengths[i],
|
||||||
|
# keywords_idxset)
|
||||||
|
maxlen = score.size(0)
|
||||||
|
ctc_probs = score
|
||||||
|
cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
||||||
|
|
||||||
|
hit_keyword = None
|
||||||
|
activated = False
|
||||||
|
hit_score = 1.0
|
||||||
|
start = 0; end = 0
|
||||||
|
|
||||||
|
# 2. CTC beam search step by step
|
||||||
|
for t in range(0, maxlen):
|
||||||
|
probs = ctc_probs[t] # (vocab_size,)
|
||||||
|
# key: prefix, value (pb, pnb), default value(-inf, -inf)
|
||||||
|
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
|
||||||
|
|
||||||
|
# 2.1 First beam prune: select topk best
|
||||||
|
top_k_probs, top_k_index = probs.topk(args.score_beam_size)
|
||||||
|
|
||||||
|
# filter prob score that is too small
|
||||||
|
filter_probs = []
|
||||||
|
filter_index = []
|
||||||
|
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
|
||||||
|
if keywords_idxset is not None:
|
||||||
|
if prob > 0.05 and idx in keywords_idxset:
|
||||||
|
filter_probs.append(prob)
|
||||||
|
filter_index.append(idx)
|
||||||
|
else:
|
||||||
|
if prob > 0.05:
|
||||||
|
filter_probs.append(prob)
|
||||||
|
filter_index.append(idx)
|
||||||
|
|
||||||
|
if len(filter_index) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for s in filter_index:
|
||||||
|
ps = probs[s].item()
|
||||||
|
|
||||||
|
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
|
||||||
|
last = prefix[-1] if len(prefix) > 0 else None
|
||||||
|
if s == 0: # blank
|
||||||
|
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||||||
|
n_pb = n_pb + pb * ps + pnb * ps
|
||||||
|
nodes = cur_nodes.copy()
|
||||||
|
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||||||
|
elif s == last:
|
||||||
|
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
|
||||||
|
# Update *ss -> *s;
|
||||||
|
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||||||
|
n_pnb = n_pnb + pnb * ps
|
||||||
|
nodes = cur_nodes.copy()
|
||||||
|
if ps > nodes[-1]['prob']: # update frame and prob
|
||||||
|
nodes[-1]['prob'] = ps
|
||||||
|
nodes[-1]['frame'] = t
|
||||||
|
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||||||
|
|
||||||
|
if not math.isclose(pb, 0.0, abs_tol=0.000001):
|
||||||
|
# Update *s-s -> *ss, - is for blank
|
||||||
|
n_prefix = prefix + (s,)
|
||||||
|
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||||
|
n_pnb = n_pnb + pb * ps
|
||||||
|
nodes = cur_nodes.copy()
|
||||||
|
nodes.append(dict(token=s, frame=t,
|
||||||
|
prob=ps)) # to record token prob
|
||||||
|
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||||
|
else:
|
||||||
|
n_prefix = prefix + (s,)
|
||||||
|
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||||
|
if nodes:
|
||||||
|
if ps > nodes[-1]['prob']: # update frame and prob
|
||||||
|
# nodes[-1]['prob'] = ps
|
||||||
|
# nodes[-1]['frame'] = t
|
||||||
|
nodes.pop() # to avoid change other beam which has this node.
|
||||||
|
nodes.append(dict(token=s, frame=t, prob=ps))
|
||||||
|
else:
|
||||||
|
nodes = cur_nodes.copy()
|
||||||
|
nodes.append(dict(token=s, frame=t,
|
||||||
|
prob=ps)) # to record token prob
|
||||||
|
n_pnb = n_pnb + pb * ps + pnb * ps
|
||||||
|
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||||
|
|
||||||
|
# 2.2 Second beam prune
|
||||||
|
next_hyps = sorted(
|
||||||
|
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)
|
||||||
|
|
||||||
|
cur_hyps = next_hyps[:args.path_beam_size]
|
||||||
|
|
||||||
|
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
|
||||||
|
|
||||||
|
for one_hyp in hyps:
|
||||||
|
prefix_ids = one_hyp[0]
|
||||||
|
# path_score = one_hyp[1]
|
||||||
|
prefix_nodes = one_hyp[2]
|
||||||
|
assert len(prefix_ids) == len(prefix_nodes)
|
||||||
|
for word in keywords_token.keys():
|
||||||
|
lab = keywords_token[word]['token_id']
|
||||||
|
offset = is_sublist(prefix_ids, lab)
|
||||||
|
if offset != -1:
|
||||||
|
hit_keyword = word
|
||||||
|
start = prefix_nodes[offset]['frame']
|
||||||
|
end = prefix_nodes[offset + len(lab) - 1]['frame']
|
||||||
|
for idx in range(offset, offset + len(lab)):
|
||||||
|
hit_score *= prefix_nodes[idx]['prob']
|
||||||
|
break
|
||||||
|
if hit_keyword is not None:
|
||||||
|
hit_score = math.sqrt(hit_score)
|
||||||
|
break
|
||||||
|
|
||||||
|
duration = end - start
|
||||||
|
if hit_keyword is not None :
|
||||||
|
if hit_score >= args.threshold and args.min_frames <= duration <= args.max_frames:
|
||||||
|
activated = True
|
||||||
|
fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score))
|
||||||
|
logging.info(
|
||||||
|
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
||||||
|
f"duration {duration}, score {hit_score} Activated.")
|
||||||
|
|
||||||
|
# clear the ctc_prefix buffer, and clear hit_keyword
|
||||||
|
cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
||||||
|
hit_keyword = None
|
||||||
|
hit_score = 1.0
|
||||||
|
elif hit_score < args.threshold:
|
||||||
|
logging.info(
|
||||||
|
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
||||||
|
f"but {hit_score} less than {args.threshold}, Deactivated. ")
|
||||||
|
elif args.min_frames > duration or duration > args.max_frames:
|
||||||
|
logging.info(
|
||||||
|
f"batch:{batch_idx}_{i} detect {hit_keyword} in {key} from {start} to {end} frame. "
|
||||||
|
f"but {duration} beyond range({args.min_frames}~{args.max_frames}), Deactivated. ")
|
||||||
|
if not activated:
|
||||||
|
fout.write('{} rejected\n'.format(key))
|
||||||
|
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
||||||
|
|
||||||
|
if batch_idx % 10 == 0:
|
||||||
|
print('Progress batch {}'.format(batch_idx))
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -74,6 +74,19 @@ class KWSModel(nn.Module):
|
|||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
return x, out_cache
|
return x, out_cache
|
||||||
|
|
||||||
|
def forward_softmax(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self.global_cmvn is not None:
|
||||||
|
x = self.global_cmvn(x)
|
||||||
|
x = self.preprocessing(x)
|
||||||
|
x, out_cache = self.backbone(x, in_cache)
|
||||||
|
x = self.classifier(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = x.softmax(2)
|
||||||
|
return x, out_cache
|
||||||
|
|
||||||
def fuse_modules(self):
|
def fuse_modules(self):
|
||||||
self.preprocessing.fuse_modules()
|
self.preprocessing.fuse_modules()
|
||||||
self.backbone.fuse_modules()
|
self.backbone.fuse_modules()
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) 2021 Binbin Zhang
|
# Copyright (c) 2021 Binbin Zhang
|
||||||
|
# 2023 Jing Du
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -276,8 +277,10 @@ def ctc_prefix_beam_search(
|
|||||||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||||
if nodes:
|
if nodes:
|
||||||
if ps > nodes[-1]['prob']: # update frame and prob
|
if ps > nodes[-1]['prob']: # update frame and prob
|
||||||
nodes[-1]['prob'] = ps
|
# nodes[-1]['prob'] = ps
|
||||||
nodes[-1]['frame'] = t
|
# nodes[-1]['frame'] = t
|
||||||
|
nodes.pop() # to avoid change other beam which has this node.
|
||||||
|
nodes.append(dict(token=s, frame=t, prob=ps))
|
||||||
else:
|
else:
|
||||||
nodes = cur_nodes.copy()
|
nodes = cur_nodes.copy()
|
||||||
nodes.append(dict(token=s, frame=t,
|
nodes.append(dict(token=s, frame=t,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user