[ctc] KWS with CTCloss training and CTC prefix beam search detection. (#135)
* add ctcloss training scripts.
* update compute_det_ctc
* fix typo.
* add fsmn model, can use pretrained kws model from modelscope.
* Add streaming detection of CTC model. Add CTC model onnx export. Add CTC model's result in README; For now CTC model runtime is not supported yet.
* QA run.sh, maxpooling training scripts is compatible. Ready to PR.
* Add a streaming kws demo, support fsmn online forward
* fix typo.
* Align Stream FSMN and Non-Stream FSMN, both in feature extraction and model forward.
* fix repeat activation, add a interval restrict.
* fix timestamp when subsampling!=1.
* fix flake8, update training script and README, give pretrained ckpt.
* fix quickcheck and flake8
* Add realtime CTC-KWS demo in README.

---------

Co-authored-by: dujing <dujing@xmov.ai>
parent 85350c38a8
commit b233d46552
@@ -1,3 +1,5 @@
Comparison among different backbones,
all models use Max-Pooling loss.
FRRs with FAR fixed at once per hour:

| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
@@ -8,3 +10,58 @@ FRRs with FAR fixed at once per hour:
| DS_TCN(spec_aug) | 287 | 80(avg30) | 0.008176 | 0.005075 |
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |

Next, we train models with CTC loss, using DS_TCN and FSMN backbones,
and use CTC prefix beam search to decode and detect keywords.
Detection can run in either non-streaming or streaming fashion.
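
To make the decoding step more concrete, here is a minimal sketch of CTC prefix beam search in plain Python, plus a helper that looks for a keyword's token sequence inside a decoded prefix. It is an illustration under assumptions, not the code shipped in this repository: the names `ctc_prefix_beam_search` and `find_keyword`, the `blank_id=0` convention, and the toy probabilities are all hypothetical.

```python
from collections import defaultdict


def ctc_prefix_beam_search(probs, beam_size=3, blank_id=0):
    """Toy CTC prefix beam search over per-frame token probabilities.

    probs: list of frames, each a list of probabilities indexed by token id.
    Returns (prefix, probability) pairs, best first.
    """
    # prefix -> (prob. of paths ending in blank, prob. of paths ending in non-blank)
    beams = {(): (1.0, 0.0)}
    for frame in probs:
        candidates = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_blank, p_non_blank) in beams.items():
            for token, p in enumerate(frame):
                if token == blank_id:
                    # A blank leaves the prefix unchanged.
                    candidates[prefix][0] += (p_blank + p_non_blank) * p
                elif prefix and token == prefix[-1]:
                    # Same token again: it collapses into the prefix unless
                    # a blank separated the two emissions.
                    candidates[prefix][1] += p_non_blank * p
                    candidates[prefix + (token,)][1] += p_blank * p
                else:
                    candidates[prefix + (token,)][1] += (p_blank + p_non_blank) * p
        # Keep only the most probable prefixes (the beam).
        ranked = sorted(candidates.items(), key=lambda kv: sum(kv[1]), reverse=True)
        beams = {prefix: tuple(scores) for prefix, scores in ranked[:beam_size]}
    return [(prefix, p_b + p_nb) for prefix, (p_b, p_nb) in beams.items()]


def find_keyword(prefix, keyword):
    """Return the start index of `keyword` inside `prefix`, or -1 if absent."""
    n = len(keyword)
    for i in range(len(prefix) - n + 1):
        if tuple(prefix[i:i + n]) == tuple(keyword):
            return i
    return -1


# Two frames, vocabulary {0: blank, 1: "a"} (toy values).
hyps = ctc_prefix_beam_search([[0.4, 0.6], [0.4, 0.6]])
best_prefix, best_prob = hyps[0]
print(best_prefix, best_prob, find_keyword(best_prefix, (1,)))
```

In a streaming setting the same per-frame update would be applied as frames arrive, with `find_keyword` checked against the live beams after each frame.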

Since the FAR is very low when using CTC loss,
the following results are FRRs with FAR fixed at once per 12 hours.

Comparison between Max-pooling and CTC loss.
The CTC model is fine-tuned from a base model pretrained on WenetSpeech (23 epochs, not converged).
FRRs with FAR fixed at once per 12 hours:

| model | loss | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|------------|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | [dstcn-maxpooling](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn/files) |
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |


Comparison between DS_TCN (pretrained on WenetSpeech, 23 epochs, not converged)
and FSMN (pretrained from the xiaoyunxiaoyun model released on modelscope, fully converged).
FRRs with FAR fixed at once per 12 hours:

| model | params(K) | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|-------------------------------------------------------------------------------|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | [fsmn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_fsmn_ctc/files) |

For now, the DS_TCN model with CTC loss may not reach its best performance, because the
pretraining phase has not sufficiently converged. We recommend using the pretrained
FSMN model as the initial checkpoint when training your own model.

Comparison between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours:

| model | stream | hi_xiaowen | nihao_wenwen |
|-----------------------|-------------|------------|--------------|
| DS_TCN(spec_aug) | no | 0.056574 | 0.056856 |
| DS_TCN(spec_aug) | yes | 0.132694 | 0.057044 |
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |

Note: when using CTC prefix beam search to detect keywords in the streaming case (detecting at each frame),
we record the probability of a keyword in a decoding path as soon as the keyword appears in that path.
This probability actually keeps increasing over time, so the value we record is lower than the final one,
which results in a higher False Rejection Rate in the Detection Error Tradeoff (DET) result.
At a given threshold, the actual FRR will be lower than the DET curve indicates.
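
For readers who want to see how an "FRR with FAR fixed at once per 12 hours" figure can be derived, the following is a minimal sketch of a threshold sweep. It is only an illustration, not the logic of compute_det_ctc.py: the function name `frr_at_fixed_far`, its arguments, and the toy scores are hypothetical.

```python
def frr_at_fixed_far(pos_scores, neg_scores, neg_hours, max_fa_per_hour, step=0.001):
    """Sweep a score threshold and report the FRR at a false-alarm budget.

    pos_scores: best keyword score per positive utterance (0.0 if never detected).
    neg_scores: best keyword score per negative utterance.
    neg_hours: total duration of the negative audio, in hours.
    Returns (threshold, frr, fa_per_hour) for the lowest threshold whose
    false-alarm rate fits the budget, or None if no threshold does.
    """
    num_steps = int(round(1.0 / step))
    for i in range(num_steps + 1):
        threshold = i * step
        # A negative utterance scoring at or above the threshold is a false alarm.
        false_alarms = sum(1 for s in neg_scores if s >= threshold)
        fa_per_hour = false_alarms / neg_hours
        if fa_per_hour <= max_fa_per_hour:
            # A positive utterance scoring below the threshold is a false rejection.
            false_rejects = sum(1 for s in pos_scores if s < threshold)
            return threshold, false_rejects / len(pos_scores), fa_per_hour
    return None


# Toy data: 12 hours of negatives, budget of one false alarm per 12 hours.
result = frr_at_fixed_far(
    pos_scores=[0.9, 0.8, 0.25, 0.95],
    neg_scores=[0.05, 0.55, 0.15, 0.45],
    neg_hours=12,
    max_fa_per_hour=1 / 12,
    step=0.1,
)
print(result)
```

Under this kind of sweep, feeding in the lower first-hit scores recorded by streaming detection pushes more positive utterances below the threshold, which is why the streaming rows can report a higher FRR.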

On some small-data KWS tasks, we believe the FSMN-CTC model is more robust
than a classification model trained with CE/Max-pooling loss.
For more information and results on the FSMN-CTC KWS model, see [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).

For realtime CTC-KWS, the wave input should be processed in a streaming fashion,
including feature extraction, keyword decoding and detection, and some postprocessing.
Here is a [demo](https://modelscope.cn/studios/thuduj12/KWS_Nihao_Xiaojing/summary) in Python;
the core code is in wekws/bin/stream_kws_ctc.py, which you can refer to when implementing the runtime code.
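
As a rough picture of the streaming loop and the postprocessing step, the sketch below feeds audio in fixed-size chunks and suppresses repeated activations that fall inside a minimum interval (the commit message mentions an interval restriction for repeat activations). It is not the code in wekws/bin/stream_kws_ctc.py: `ActivationGate`, `iter_chunks`, `run_stream`, and the `score_fn` callback (standing in for feature extraction, model forward, and decoding) are hypothetical names.

```python
class ActivationGate:
    """Suppress repeated activations that fall inside a minimum interval."""

    def __init__(self, threshold=0.5, min_interval_s=1.0):
        self.threshold = threshold
        self.min_interval_s = min_interval_s
        self._last_fire_s = None

    def update(self, time_s, score):
        """Return True if a detection at `time_s` should be reported."""
        if score < self.threshold:
            return False
        if (self._last_fire_s is not None
                and time_s - self._last_fire_s < self.min_interval_s):
            return False
        self._last_fire_s = time_s
        return True


def iter_chunks(samples, chunk_size):
    """Yield (start_index, chunk) pairs covering `samples` in order."""
    for start in range(0, len(samples), chunk_size):
        yield start, samples[start:start + chunk_size]


def run_stream(samples, sample_rate, chunk_ms, score_fn, gate):
    """Score each chunk with `score_fn` and return the reported detection times."""
    chunk_size = sample_rate * chunk_ms // 1000
    fired = []
    for start, chunk in iter_chunks(samples, chunk_size):
        end_time_s = (start + len(chunk)) / sample_rate
        if gate.update(end_time_s, score_fn(chunk)):
            fired.append(end_time_s)
    return fired


# Three seconds of dummy audio and a scorer that always reports a hit.
detections = run_stream([0.0] * 48000, 16000, 300, lambda chunk: 0.9,
                        ActivationGate(threshold=0.5, min_interval_s=1.0))
print(detections)
```

Keeping the gate separate from the scorer means the interval and threshold can be tuned without touching the decoding code.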
50
examples/hi_xiaowen/s0/conf/ds_tcn_ctc.yaml
Normal file
@@ -0,0 +1,50 @@
dataset_conf:
    filter_conf:
        max_length: 2048
        min_length: 0
    resample_conf:
        resample_rate: 16000
    speed_perturb: false
    feature_extraction_conf:
        feature_type: 'fbank'
        num_mel_bins: 40
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 1
        num_f_mask: 1
        max_t: 20
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    batch_conf:
        batch_size: 256

model:
    hidden_dim: 256
    preprocessing:
        type: linear
    backbone:
        type: tcn
        ds: true
        num_layers: 4
        kernel_size: 8
        dropout: 0.1
    activation:
        type: identity


optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 0.0001

training_config:
    grad_clip: 5
    max_epoch: 80
    log_interval: 10
    criterion: ctc

50
examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml
Normal file
@@ -0,0 +1,50 @@
dataset_conf:
    filter_conf:
        max_length: 2048
        min_length: 0
    resample_conf:
        resample_rate: 16000
    speed_perturb: false
    feature_extraction_conf:
        feature_type: 'fbank'
        num_mel_bins: 40
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 1
        num_f_mask: 1
        max_t: 20
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    batch_conf:
        batch_size: 200

model:
    hidden_dim: 256
    preprocessing:
        type: linear
    backbone:
        type: tcn
        ds: true
        num_layers: 4
        kernel_size: 8
        dropout: 0.1
    activation:
        type: identity


optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 0.0001

training_config:
    grad_clip: 5
    max_epoch: 50
    log_interval: 100
    criterion: ctc

64
examples/hi_xiaowen/s0/conf/fsmn_ctc.yaml
Normal file
@@ -0,0 +1,64 @@
dataset_conf:
    filter_conf:
        max_length: 2048
        min_length: 0
    resample_conf:
        resample_rate: 16000
    speed_perturb: false
    feature_extraction_conf:
        feature_type: 'fbank'
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.
    context_expansion: true
    context_expansion_conf:
        left: 2
        right: 2
    frame_skip: 3
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 1
        num_f_mask: 1
        max_t: 20
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    batch_conf:
        batch_size: 256

model:
    input_dim: 400
    preprocessing:
        type: none
    hidden_dim: 128
    backbone:
        type: fsmn
        input_affine_dim: 140
        num_layers: 4
        linear_dim: 250
        proj_dim: 128
        left_order: 10
        right_order: 2
        left_stride: 1
        right_stride: 1
        output_affine_dim: 140
    classifier:
        type: identity
        dropout: 0.1
    activation:
        type: identity


optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 0.0001

training_config:
    grad_clip: 5
    max_epoch: 80
    log_interval: 10
    criterion: ctc

@@ -3,8 +3,8 @@
 
 . ./path.sh
 
-stage=0
-stop_stage=4
+stage=$1
+stop_stage=$2
 num_keywords=2
 
 config=conf/ds_tcn.yaml
@@ -98,6 +98,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  python wekws/bin/score.py \
    --config $dir/config.yaml \
    --test_data data/test/data.list \
    --gpu 0 \
    --batch_size 256 \
    --checkpoint $score_checkpoint \
    --score_file $result_dir/score.txt \
@@ -111,6 +112,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
      --score_file $result_dir/score.txt \
      --stats_file $result_dir/stats.${keyword}.txt
  done

  # plot det curve
  python wekws/bin/plot_det_curve.py \
      --keywords_dict dict/words.txt \
      --stats_dir $result_dir \
      --figure_file $result_dir/det.png
fi


223
examples/hi_xiaowen/s0/run_ctc.sh
Normal file
@@ -0,0 +1,223 @@
#!/bin/bash
# Copyright 2021  Binbin Zhang(binbzha@qq.com)
#           2023  Jing Du(thuduj12@163.com)

. ./path.sh

stage=$1
stop_stage=$2
num_keywords=2599

config=conf/ds_tcn_ctc.yaml
norm_mean=true
norm_var=true
gpus="0"

checkpoint=
dir=exp/ds_tcn_ctc
average_model=true
num_average=30
if $average_model ;then
  score_checkpoint=$dir/avg_${num_average}.pt
else
  score_checkpoint=$dir/final.pt
fi

download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir

. tools/parse_options.sh || exit 1;
window_shift=50

# Whether to train the base model. If set to true, you must put train+dev data in trainbase_dir
trainbase=false
trainbase_dir=data/base
trainbase_config=conf/ds_tcn_ctc_base.yaml
trainbase_exp=exp/base

if [ ${stage} -le -3 ] && [ ${stop_stage} -ge -3 ]; then
  echo "Download and extract all datasets"
  local/mobvoi_data_download.sh --dl_dir $download_dir
fi


if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
  echo "Preparing datasets..."
  mkdir -p dict
  echo "<filler> -1" > dict/words.txt
  echo "Hi_Xiaowen 0" >> dict/words.txt
  echo "Nihao_Wenwen 1" >> dict/words.txt

  for folder in train dev test; do
    mkdir -p data/$folder
    for prefix in p n; do
      mkdir -p data/${prefix}_$folder
      json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
      local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
        data/${prefix}_$folder
    done
    cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp
    cat data/p_$folder/text data/n_$folder/text > data/$folder/text
    rm -rf data/p_$folder data/n_$folder
  done
fi

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  # Here we use Paraformer Large (https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
  # to transcribe the negative wavs, and we uploaded the transcription to modelscope.
  git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
  for folder in train dev test; do
    if [ -f data/$folder/text ];then
      mv data/$folder/text data/$folder/text.label
    fi
    cp mobvoi_kws_transcription/$folder.text data/$folder/text
  done

  # We also copy the tokens and lexicon that are used in
  # https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary
  cp mobvoi_kws_transcription/tokens.txt data/tokens.txt
  cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt

fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "Compute CMVN and Format datasets"
  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
    --in_scp data/train/wav.scp \
    --out_cmvn data/train/global_cmvn

  for x in train dev test; do
    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur

    # Here we use tokens.txt and lexicon.txt to convert txt into index
    tools/make_list.py data/$x/wav.scp data/$x/text \
      data/$x/wav.dur data/$x/data.list \
      --token_file data/tokens.txt \
      --lexicon_file data/lexicon.txt
  done
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
  for x in train dev ; do
    if [ ! -f $trainbase_dir/$x/wav.scp ] || [ ! -f $trainbase_dir/$x/text ]; then
      echo "If You Want to Train Base KWS-CTC Model, You Should Prepare ASR Data by Yourself."
      echo "The wav.scp and text in KALDI-format is Needed, You Should Put Them in $trainbase_dir/$x"
      exit
    fi
    if [ ! -f $trainbase_dir/$x/wav.dur ]; then
      tools/wav_to_duration.sh --nj 128 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
    fi

    # Here we use tokens.txt and lexicon.txt to convert txt into index
    if [ ! -f $trainbase_dir/$x/data.list ]; then
      tools/make_list.py $trainbase_dir/$x/wav.scp $trainbase_dir/$x/text \
        $trainbase_dir/$x/wav.dur $trainbase_dir/$x/data.list \
        --token_file data/tokens.txt \
        --lexicon_file data/lexicon.txt
    fi
  done

  echo "Start base training ..."
  mkdir -p $trainbase_exp
  cmvn_opts=
  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
  num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
    wekws/bin/train.py --gpus $gpus \
      --config $trainbase_config \
      --train_data $trainbase_dir/train/data.list \
      --cv_data $trainbase_dir/dev/data.list \
      --model_dir $trainbase_exp \
      --num_workers 2 \
      --ddp.dist_backend nccl \
      --num_keywords $num_keywords \
      --min_duration 50 \
      --seed 666 \
      $cmvn_opts # \
      #--checkpoint $trainbase_exp/23.pt
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Start training ..."
  mkdir -p $dir
  cmvn_opts=
  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
  num_gpus=$(echo $gpus | awk -F ',' '{print NF}')

  if $trainbase; then
    echo "Use the base model you trained as checkpoint: $trainbase_exp/final.pt"
    checkpoint=$trainbase_exp/final.pt
  else
    echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/23.pt"
    if [ ! -d mobvoi_kws_transcription ] ;then
      git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
    fi
    checkpoint=mobvoi_kws_transcription/23.pt # this ckpt may not converge well.
  fi

  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
    wekws/bin/train.py --gpus $gpus \
      --config $config \
      --train_data data/train/data.list \
      --cv_data data/dev/data.list \
      --model_dir $dir \
      --num_workers 8 \
      --num_keywords $num_keywords \
      --min_duration 50 \
      --seed 666 \
      $cmvn_opts \
      ${checkpoint:+--checkpoint $checkpoint}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Do model average, Compute FRR/FAR ..."
  if $average_model; then
    python wekws/bin/average_model.py \
      --dst_model $score_checkpoint \
      --src_path $dir \
      --num ${num_average} \
      --val_best
  fi
  result_dir=$dir/test_$(basename $score_checkpoint)
  mkdir -p $result_dir
  stream=true # we detect keyword online with ctc_prefix_beam_search
  score_prefix=""
  if $stream ; then
    score_prefix=stream_
  fi
  python wekws/bin/${score_prefix}score_ctc.py \
    --config $dir/config.yaml \
    --test_data data/test/data.list \
    --gpu 0 \
    --batch_size 256 \
    --checkpoint $score_checkpoint \
    --score_file $result_dir/score.txt \
    --num_workers 8 \
    --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
    --token_file data/tokens.txt \
    --lexicon_file data/lexicon.txt

  python wekws/bin/compute_det_ctc.py \
    --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
    --test_data data/test/data.list \
    --window_shift $window_shift \
    --step 0.001 \
    --score_file $result_dir/score.txt \
    --token_file data/tokens.txt \
    --lexicon_file data/lexicon.txt
fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
  onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
  python wekws/bin/export_jit.py \
    --config $dir/config.yaml \
    --checkpoint $score_checkpoint \
    --jit_model $dir/$jit_model
  python wekws/bin/export_onnx.py \
    --config $dir/config.yaml \
    --checkpoint $score_checkpoint \
    --onnx_model $dir/$onnx_model
fi
175
examples/hi_xiaowen/s0/run_fsmn_ctc.sh
Normal file
@@ -0,0 +1,175 @@
#!/bin/bash
# Copyright 2021  Binbin Zhang(binbzha@qq.com)
#           2023  Jing Du(thuduj12@163.com)

. ./path.sh

stage=$1
stop_stage=$2
num_keywords=2599

config=conf/fsmn_ctc.yaml
norm_mean=true
norm_var=true
gpus="0"

checkpoint=
dir=exp/fsmn_ctc
average_model=true
num_average=30
if $average_model ;then
  score_checkpoint=$dir/avg_${num_average}.pt
else
  score_checkpoint=$dir/final.pt
fi

download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir

. tools/parse_options.sh || exit 1;
window_shift=50

if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
  echo "Download and extract all datasets"
  local/mobvoi_data_download.sh --dl_dir $download_dir
fi


if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "Preparing datasets..."
  mkdir -p dict
  echo "<filler> -1" > dict/words.txt
  echo "Hi_Xiaowen 0" >> dict/words.txt
  echo "Nihao_Wenwen 1" >> dict/words.txt

  for folder in train dev test; do
    mkdir -p data/$folder
    for prefix in p n; do
      mkdir -p data/${prefix}_$folder
      json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
      local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
        data/${prefix}_$folder
    done
    cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp
    cat data/p_$folder/text data/n_$folder/text > data/$folder/text
    rm -rf data/p_$folder data/n_$folder
  done
fi

if [ ${stage} -le -0 ] && [ ${stop_stage} -ge -0 ]; then
  # Here we use Paraformer Large (https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
  # to transcribe the negative wavs, and we uploaded the transcription to modelscope.
  git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
  for folder in train dev test; do
    if [ -f data/$folder/text ];then
      mv data/$folder/text data/$folder/text.label
    fi
    cp mobvoi_kws_transcription/$folder.text data/$folder/text
  done

  # We also copy the tokens and lexicon that are used in
  # https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary
  cp mobvoi_kws_transcription/tokens.txt data/tokens.txt
  cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt

fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Compute CMVN and Format datasets"
  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
    --in_scp data/train/wav.scp \
    --out_cmvn data/train/global_cmvn

  for x in train dev test; do
    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur

    # Here we use tokens.txt and lexicon.txt to convert txt into index
    tools/make_list.py data/$x/wav.scp data/$x/text \
      data/$x/wav.dur data/$x/data.list \
      --token_file data/tokens.txt \
      --lexicon_file data/lexicon.txt
  done
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then

  echo "Use the base model from modelscope"
  if [ ! -d speech_charctc_kws_phone-xiaoyun ] ;then
    git lfs install
    git clone https://www.modelscope.cn/damo/speech_charctc_kws_phone-xiaoyun.git
  fi
  checkpoint=speech_charctc_kws_phone-xiaoyun/train/base.pt
  cp speech_charctc_kws_phone-xiaoyun/train/feature_transform.txt.80dim-l2r2 data/global_cmvn.kaldi

  echo "Start training ..."
  mkdir -p $dir
  cmvn_opts=
  $norm_mean && cmvn_opts="--cmvn_file data/global_cmvn.kaldi"
  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
  num_gpus=$(echo $gpus | awk -F ',' '{print NF}')

  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
    wekws/bin/train.py --gpus $gpus \
      --config $config \
      --train_data data/train/data.list \
      --cv_data data/dev/data.list \
      --model_dir $dir \
      --num_workers 8 \
      --num_keywords $num_keywords \
      --min_duration 50 \
      --seed 666 \
      $cmvn_opts \
      ${checkpoint:+--checkpoint $checkpoint}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Do model average, Compute FRR/FAR ..."
  if $average_model; then
    python wekws/bin/average_model.py \
      --dst_model $score_checkpoint \
      --src_path $dir \
      --num ${num_average} \
      --val_best
  fi
  result_dir=$dir/test_$(basename $score_checkpoint)
  mkdir -p $result_dir
  stream=true # we detect keyword online with ctc_prefix_beam_search
  score_prefix=""
  if $stream ; then
    score_prefix=stream_
  fi
  python wekws/bin/${score_prefix}score_ctc.py \
    --config $dir/config.yaml \
    --test_data data/test/data.list \
    --gpu 0 \
    --batch_size 256 \
    --checkpoint $score_checkpoint \
    --score_file $result_dir/score.txt \
    --num_workers 8 \
    --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
    --token_file data/tokens.txt \
    --lexicon_file data/lexicon.txt

  python wekws/bin/compute_det_ctc.py \
    --keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
    --test_data data/test/data.list \
    --window_shift $window_shift \
    --step 0.001 \
    --score_file $result_dir/score.txt \
    --token_file data/tokens.txt \
    --lexicon_file data/lexicon.txt
fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
  onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
  # For now, FSMN can not export to JITScript
  # python wekws/bin/export_jit.py \
  #   --config $dir/config.yaml \
  #   --checkpoint $score_checkpoint \
  #   --jit_model $dir/$jit_model
  python wekws/bin/export_onnx.py \
    --config $dir/config.yaml \
    --checkpoint $score_checkpoint \
    --onnx_model $dir/$onnx_model
fi
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2023 Jing Du(thuduj12@163.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +16,145 @@
# limitations under the License.

import argparse
import logging
import json
import re

symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'

def split_mixed_label(input_str):
    tokens = []
    s = input_str.lower()
    while len(s) > 0:
        match = re.match(r'[A-Za-z!?,<>()\']+', s)
        if match is not None:
            word = match.group(0)
        else:
            word = s[0:1]
        tokens.append(word)
        s = s.replace(word, '', 1).strip(' ')
    return tokens

def query_token_set(txt, symbol_table, lexicon_table):
    tokens_str = tuple()
    tokens_idx = tuple()

    parts = split_mixed_label(txt)
    for part in parts:
        if part == '!sil' or part == '(sil)' or part == '<sil>':
            tokens_str = tokens_str + ('!sil', )
        elif part == '<blk>' or part == '<blank>':
            tokens_str = tokens_str + ('<blk>', )
        elif part == '(noise)' or part == 'noise)' or \
                part == '(noise' or part == '<noise>':
            tokens_str = tokens_str + ('<GBG>', )
        elif part in symbol_table:
            tokens_str = tokens_str + (part, )
        elif part in lexicon_table:
            for ch in lexicon_table[part]:
                tokens_str = tokens_str + (ch, )
        else:
            # case with symbols or meaningless english letter combination
            part = re.sub(symbol_str, '', part)
            for ch in part:
                tokens_str = tokens_str + (ch, )

    for ch in tokens_str:
        if ch in symbol_table:
            tokens_idx = tokens_idx + (symbol_table[ch], )
        elif ch == '!sil':
            if 'sil' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['sil'], )
            else:
                tokens_idx = tokens_idx + (symbol_table['<blk>'], )
        elif ch == '<GBG>':
            if '<GBG>' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
            else:
                tokens_idx = tokens_idx + (symbol_table['<blk>'], )
        else:
            if '<GBG>' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
                logging.info(
                    f'{ch} is not in token set, replace with <GBG>')
            else:
                tokens_idx = tokens_idx + (symbol_table['<blk>'], )
                logging.info(
                    f'{ch} is not in token set, replace with <blk>')

    return tokens_str, tokens_idx


def query_token_list(txt, symbol_table, lexicon_table):
    tokens_str = []
    tokens_idx = []

    parts = split_mixed_label(txt)
    for part in parts:
        if part == '!sil' or part == '(sil)' or part == '<sil>':
            tokens_str.append('!sil')
        elif part == '<blk>' or part == '<blank>':
            tokens_str.append('<blk>')
        elif part == '(noise)' or part == 'noise)' or \
                part == '(noise' or part == '<noise>':
            tokens_str.append('<GBG>')
        elif part in symbol_table:
            tokens_str.append(part)
        elif part in lexicon_table:
            for ch in lexicon_table[part]:
                tokens_str.append(ch)
        else:
            # case with symbols or meaningless english letter combination
            part = re.sub(symbol_str, '', part)
            for ch in part:
                tokens_str.append(ch)

    for ch in tokens_str:
        if ch in symbol_table:
            tokens_idx.append(symbol_table[ch])
        elif ch == '!sil':
            if 'sil' in symbol_table:
                tokens_idx.append(symbol_table['sil'])
            else:
                tokens_idx.append(symbol_table['<blk>'])
        elif ch == '<GBG>':
            if '<GBG>' in symbol_table:
                tokens_idx.append(symbol_table['<GBG>'])
            else:
                tokens_idx.append(symbol_table['<blk>'])
        else:
            if '<GBG>' in symbol_table:
                tokens_idx.append(symbol_table['<GBG>'])
                logging.info(
                    f'{ch} is not in token set, replace with <GBG>')
            else:
                tokens_idx.append(symbol_table['<blk>'])
                logging.info(
                    f'{ch} is not in token set, replace with <blk>')

    return tokens_str, tokens_idx

def read_token(token_file):
    tokens_table = {}
    with open(token_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            tokens_table[arr[0]] = int(arr[1]) - 1
    fin.close()
    return tokens_table


def read_lexicon(lexicon_file):
    lexicon_table = {}
    with open(lexicon_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().replace('\t', ' ').split()
            assert len(arr) >= 2
            lexicon_table[arr[0]] = arr[1:]
    fin.close()
    return lexicon_table


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
@@ -23,6 +162,10 @@ if __name__ == '__main__':
    parser.add_argument('text_file', help='text file')
    parser.add_argument('duration_file', help='duration file')
    parser.add_argument('output_file', help='output list file')
    parser.add_argument('--token_file', type=str, default=None,
                        help='the path of tokens.txt')
    parser.add_argument('--lexicon_file', type=str, default=None,
                        help='the path of lexicon.txt')
    args = parser.parse_args()

    wav_table = {}
@@ -39,16 +182,38 @@ if __name__ == '__main__':
             assert len(arr) == 2
             duration_table[arr[0]] = float(arr[1])
 
+    token_table = None
+    if args.token_file:
+        token_table = read_token(args.token_file)
+    lexicon_table = None
+    if args.lexicon_file:
+        lexicon_table = read_lexicon(args.lexicon_file)
+
     with open(args.text_file, 'r', encoding='utf8') as fin, \
          open(args.output_file, 'w', encoding='utf8') as fout:
         for line in fin:
             arr = line.strip().split(maxsplit=1)
             key = arr[0]
-            txt = int(arr[1])
+            tokens = None
+            if token_table is not None and lexicon_table is not None:
+                if len(arr) < 2:  # for some utterances, there is no text
+                    txt = [1]  # the <blank>/sil is indexed by 1
+                    tokens = ["sil"]
+                else:
+                    tokens, txt = query_token_list(arr[1],
+                                                   token_table,
+                                                   lexicon_table)
+            else:
+                txt = int(arr[1])
             assert key in wav_table
             wav = wav_table[key]
             assert key in duration_table
             duration = duration_table[key]
-            line = dict(key=key, txt=txt, duration=duration, wav=wav)
+            if tokens is None:
+                line = dict(key=key, txt=txt, duration=duration, wav=wav)
+            else:
+                line = dict(key=key, tok=tokens, txt=txt,
+                            duration=duration, wav=wav)
+
             json_line = json.dumps(line, ensure_ascii=False)
             fout.write(json_line + '\n')
271
wekws/bin/compute_det_ctc.py
Normal file
@ -0,0 +1,271 @@
|
||||
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||
# 2022 Shaoqing Yu(954793264@qq.com)
|
||||
# 2023 Jing Du(thuduj12@163.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import glob
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import pypinyin # for Chinese Character
|
||||
from tools.make_list import query_token_set, read_lexicon, read_token
|
||||
|
||||
def split_mixed_label(input_str):
|
||||
tokens = []
|
||||
s = input_str.lower()
|
||||
while len(s) > 0:
|
||||
match = re.match(r'[A-Za-z!?,<>()\']+', s)
|
||||
if match is not None:
|
||||
word = match.group(0)
|
||||
else:
|
||||
word = s[0:1]
|
||||
tokens.append(word)
|
||||
s = s.replace(word, '', 1).strip(' ')
|
||||
return tokens
|
||||
|
||||
|
||||
def space_mixed_label(input_str):
|
||||
splits = split_mixed_label(input_str)
|
||||
space_str = ''.join(f'{sub} ' for sub in splits)
|
||||
return space_str.strip()
|
||||
|
||||
def load_label_and_score(keywords_list, label_file, score_file, true_keywords):
|
||||
score_table = {}
|
||||
with open(score_file, 'r', encoding='utf8') as fin:
|
||||
# read score file and store in table
|
||||
for line in fin:
|
||||
arr = line.strip().split()
|
||||
key = arr[0]
|
||||
is_detected = arr[1]
|
||||
if is_detected == 'detected':
|
||||
keyword = true_keywords[arr[2]]
|
||||
if key not in score_table:
|
||||
score_table.update({
|
||||
key: {
|
||||
'kw': space_mixed_label(keyword),
|
||||
'confi': float(arr[3])
|
||||
}
|
||||
})
|
||||
else:
|
||||
if key not in score_table:
|
||||
score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})
|
||||
|
||||
label_lists = []
|
||||
with open(label_file, 'r', encoding='utf8') as fin:
|
||||
for line in fin:
|
||||
obj = json.loads(line.strip())
|
||||
label_lists.append(obj)
|
||||
|
||||
# build empty structure for keyword-filler infos
|
||||
keyword_filler_table = {}
|
||||
for keyword in keywords_list:
|
||||
keyword = true_keywords[keyword]
|
||||
keyword = space_mixed_label(keyword)
|
||||
keyword_filler_table[keyword] = {}
|
||||
keyword_filler_table[keyword]['keyword_table'] = {}
|
||||
keyword_filler_table[keyword]['keyword_duration'] = 0.0
|
||||
keyword_filler_table[keyword]['filler_table'] = {}
|
||||
keyword_filler_table[keyword]['filler_duration'] = 0.0
|
||||
|
||||
for obj in label_lists:
|
||||
assert 'key' in obj
|
||||
assert 'wav' in obj
|
||||
assert 'tok' in obj # here we use the tokens
|
||||
assert 'duration' in obj
|
||||
|
||||
key = obj['key']
|
||||
txt = "".join(obj['tok'])
|
||||
txt = space_mixed_label(txt)
|
||||
txt_regstr_lrblk = ' ' + txt + ' '
|
||||
duration = obj['duration']
|
||||
assert key in score_table
|
||||
|
||||
for keyword in keywords_list:
|
||||
keyword = true_keywords[keyword]
|
||||
keyword = space_mixed_label(keyword)
|
||||
keyword_regstr_lrblk = ' ' + keyword + ' '
|
||||
if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
|
||||
if keyword == score_table[key]['kw']:
|
||||
keyword_filler_table[keyword]['keyword_table'].update(
|
||||
{key: score_table[key]['confi']})
|
||||
else:
|
||||
# utterance was detected, but not as this keyword
|
||||
keyword_filler_table[keyword]['keyword_table'].update(
|
||||
{key: -1.0})
|
||||
keyword_filler_table[keyword]['keyword_duration'] += duration
|
||||
else:
|
||||
if keyword == score_table[key]['kw']:
|
||||
keyword_filler_table[keyword]['filler_table'].update(
|
||||
{key: score_table[key]['confi']})
|
||||
else:
|
||||
# even if the utterance was detected, it is not a false alarm for this keyword
|
||||
keyword_filler_table[keyword]['filler_table'].update(
|
||||
{key: -1.0})
|
||||
keyword_filler_table[keyword]['filler_duration'] += duration
|
||||
|
||||
return keyword_filler_table
|
||||
|
||||
def load_stats_file(stats_file):
|
||||
values = []
|
||||
with open(stats_file, 'r', encoding='utf8') as fin:
|
||||
for line in fin:
|
||||
arr = line.strip().split()
|
||||
threshold, fa_per_hour, frr = arr
|
||||
values.append([float(fa_per_hour), float(frr) * 100])
|
||||
values.reverse()
|
||||
return np.array(values)
|
||||
|
||||
def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5):
|
||||
det_title = "DetCurve"
|
||||
plt.figure(dpi=200)
|
||||
plt.rcParams['xtick.direction'] = 'in'
|
||||
plt.rcParams['ytick.direction'] = 'in'
|
||||
plt.rcParams['font.size'] = 12
|
||||
|
||||
for file in glob.glob(f'{dets_dir}/*stats*.txt'):
|
||||
logging.info(f'reading det data from {file}')
|
||||
label = os.path.basename(file).split('.')[1]
|
||||
label = "".join(pypinyin.lazy_pinyin(label))
|
||||
values = load_stats_file(file)
|
||||
plt.plot(values[:, 0], values[:, 1], label=label)
|
||||
|
||||
plt.xlim([0, xlim])
|
||||
plt.ylim([0, ylim])
|
||||
plt.xticks(range(0, xlim + x_step, x_step))
|
||||
plt.yticks(range(0, ylim + y_step, y_step))
|
||||
plt.xlabel('False Alarm Per Hour')
|
||||
plt.ylabel('False Rejection Rate (%)')
|
||||
plt.grid(linestyle='--')
|
||||
plt.legend(loc='best', fontsize=6)
|
||||
plt.savefig(figure_file)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='compute det curve')
|
||||
parser.add_argument('--test_data', required=True, help='label file')
|
||||
parser.add_argument('--keywords', type=str, default=None,
|
||||
help='keywords, split with comma(,)')
|
||||
parser.add_argument('--token_file', type=str, default=None,
|
||||
help='the path of tokens.txt')
|
||||
parser.add_argument('--lexicon_file', type=str, default=None,
|
||||
help='the path of lexicon.txt')
|
||||
parser.add_argument('--score_file', required=True, help='score file')
|
||||
parser.add_argument('--step', type=float, default=0.01,
|
||||
help='threshold step')
|
||||
parser.add_argument('--window_shift', type=int, default=50,
|
||||
help='window_shift is used to '
|
||||
'skip the frames after triggered')
|
||||
parser.add_argument('--stats_dir',
|
||||
required=False,
|
||||
default=None,
|
||||
help='false reject/alarm stats dir, '
|
||||
'default in score_file')
|
||||
parser.add_argument('--det_curve_path',
|
||||
required=False,
|
||||
default=None,
|
||||
help='det curve path, default is stats_dir/det.png')
|
||||
parser.add_argument(
|
||||
'--xlim',
|
||||
type=int,
|
||||
default=5,
|
||||
help='xlim:range of x-axis, x is false alarm per hour')
|
||||
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
|
||||
parser.add_argument(
|
||||
'--ylim',
|
||||
type=int,
|
||||
default=35,
|
||||
help='ylim:range of y-axis, y is false rejection rate')
|
||||
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')
|
||||
|
||||
args = parser.parse_args()
|
||||
window_shift = args.window_shift
|
||||
logging.info(f"keywords is {args.keywords}, "
|
||||
f"Chinese is converted into Unicode.")
|
||||
|
||||
keywords = args.keywords.encode('utf-8').decode('unicode_escape')
|
||||
keywords_list = keywords.strip().split(',')
|
||||
|
||||
token_table = read_token(args.token_file)
|
||||
lexicon_table = read_lexicon(args.lexicon_file)
|
||||
true_keywords = {}
|
||||
for keyword in keywords_list:
|
||||
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
||||
true_keywords[keyword] = ''.join(strs)
|
||||
|
||||
keyword_filler_table = load_label_and_score(
|
||||
keywords_list, args.test_data, args.score_file, true_keywords)
|
||||
|
||||
for keyword in keywords_list:
|
||||
keyword = true_keywords[keyword]
|
||||
keyword = space_mixed_label(keyword)
|
||||
keyword_dur = keyword_filler_table[keyword]['keyword_duration']
|
||||
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
|
||||
filler_dur = keyword_filler_table[keyword]['filler_duration']
|
||||
filler_num = len(keyword_filler_table[keyword]['filler_table'])
|
||||
assert keyword_num > 0, \
|
||||
'Can\'t compute det for {} without positive sample'.format(keyword)
|
||||
assert filler_num > 0, \
|
||||
'Can\'t compute det for {} without negative sample'.format(keyword)
|
||||
|
||||
logging.info('Computing det for {}'.format(keyword))
|
||||
logging.info(' Keyword duration: {} Hours, wave number: {}'.format(
|
||||
keyword_dur / 3600.0, keyword_num))
|
||||
logging.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0))
|
||||
|
||||
if args.stats_dir:
|
||||
stats_dir = args.stats_dir
|
||||
else:
|
||||
stats_dir = os.path.dirname(args.score_file)
|
||||
stats_file = os.path.join(
|
||||
stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
|
||||
with open(stats_file, 'w', encoding='utf8') as fout:
|
||||
threshold = 0.0
|
||||
while threshold <= 1.0:
|
||||
num_false_reject = 0
|
||||
num_true_detect = 0
|
||||
# traverse the whole keyword_table
|
||||
for key, confi in \
|
||||
keyword_filler_table[keyword]['keyword_table'].items():
|
||||
if confi < threshold:
|
||||
num_false_reject += 1
|
||||
else:
|
||||
num_true_detect += 1
|
||||
|
||||
num_false_alarm = 0
|
||||
# traverse the whole filler_table
|
||||
for key, confi in keyword_filler_table[
|
||||
keyword]['filler_table'].items():
|
||||
if confi >= threshold:
|
||||
num_false_alarm += 1
|
||||
# print(f'false alarm: {keyword}, {key}, {confi}')
|
||||
|
||||
false_reject_rate = num_false_reject / keyword_num
|
||||
true_detect_rate = num_true_detect / keyword_num
|
||||
|
||||
num_false_alarm = max(num_false_alarm, 1e-6)
|
||||
false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
|
||||
false_alarm_rate = num_false_alarm / filler_num
|
||||
|
||||
fout.write('{:.3f} {:.6f} {:.6f}\n'.format(
|
||||
threshold, false_alarm_per_hour, false_reject_rate))
|
||||
threshold += args.step
|
||||
if args.det_curve_path:
|
||||
det_curve_path = args.det_curve_path
|
||||
else:
|
||||
det_curve_path = os.path.join(stats_dir, 'det.png')
|
||||
plot_det(stats_dir, det_curve_path,
|
||||
args.xlim, args.x_step, args.ylim, args.y_step)
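The threshold sweep above boils down to one computation per threshold: count rejected positives and accepted fillers. Here is a minimal sketch of that single DET point, assuming the same table layout (`{utterance_key: confidence}`, with `-1.0` for utterances that did not trigger this keyword). The function names `det_point` and `thresholds` are hypothetical, and the sketch omits the `1e-6` floor that the script applies to the false-alarm count.

```python
def det_point(keyword_table, filler_table, filler_hours, threshold):
    """Return (false alarms per hour, false rejection rate) at a threshold."""
    # Positive utterances scoring below the threshold are false rejects.
    num_false_reject = sum(
        1 for confi in keyword_table.values() if confi < threshold)
    # Filler utterances scoring at or above the threshold are false alarms.
    num_false_alarm = sum(
        1 for confi in filler_table.values() if confi >= threshold)
    frr = num_false_reject / len(keyword_table)
    fa_per_hour = num_false_alarm / filler_hours
    return fa_per_hour, frr


def thresholds(step=0.01):
    """Thresholds from 0 to 1, built by index to avoid float accumulation."""
    n = int(round(1.0 / step))
    return [i * step for i in range(n + 1)]


# Made-up scores for illustration only.
kw = {'a': 0.9, 'b': -1.0, 'c': 0.4}
fl = {'x': 0.8, 'y': -1.0}
print(det_point(kw, fl, filler_hours=2.0, threshold=0.5))
```

Generating thresholds by index instead of repeatedly adding `step` is a design choice of this sketch: repeated float addition can drift, so a loop written as `while threshold <= 1.0` may or may not visit the final point.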
|
||||
@ -41,6 +41,9 @@ def main():
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
feature_dim = configs['model']['input_dim']
|
||||
model = init_model(configs['model'])
|
||||
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
|
||||
# if we use ctc_loss, the logits need to be converted into probs
|
||||
model.forward = model.forward_softmax
|
||||
print(model)
|
||||
|
||||
load_checkpoint(model, args.checkpoint)
|
||||
|
||||
@ -106,7 +106,7 @@ def main():
|
||||
score_abs_path = os.path.abspath(args.score_file)
|
||||
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
|
||||
for batch_idx, batch in enumerate(test_data_loader):
|
||||
keys, feats, target, lengths = batch
|
||||
keys, feats, target, lengths, target_lengths = batch
|
||||
feats = feats.to(device)
|
||||
lengths = lengths.to(device)
|
||||
logits, _ = model(feats)
|
||||
|
||||
219
wekws/bin/score_ctc.py
Normal file
@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||
# 2022 Shaoqing Yu(954793264@qq.com)
|
||||
# 2023 Jing Du(thuduj12@163.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from wekws.dataset.dataset import Dataset
|
||||
from wekws.model.kws_model import init_model
|
||||
from wekws.utils.checkpoint import load_checkpoint
|
||||
from wekws.model.loss import ctc_prefix_beam_search
|
||||
from tools.make_list import query_token_set, read_lexicon, read_token
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='recognize with your model')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--test_data', required=True, help='test data file')
|
||||
parser.add_argument('--gpu',
|
||||
type=int,
|
||||
default=-1,
|
||||
help='gpu id for this rank, -1 for cpu')
|
||||
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
||||
parser.add_argument('--batch_size',
|
||||
default=16,
|
||||
type=int,
|
||||
help='batch size for inference')
|
||||
parser.add_argument('--num_workers',
|
||||
default=0,
|
||||
type=int,
|
||||
help='num of subprocess workers for reading')
|
||||
parser.add_argument('--pin_memory',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use pinned memory buffers used for reading')
|
||||
parser.add_argument('--prefetch',
|
||||
default=100,
|
||||
type=int,
|
||||
help='prefetch number')
|
||||
parser.add_argument('--score_file',
|
||||
required=True,
|
||||
help='output score file')
|
||||
parser.add_argument('--jit_model',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Load the checkpoint as a TorchScript (jit) model')
|
||||
parser.add_argument('--keywords', type=str, default=None,
|
||||
help='the keywords, split with comma(,)')
|
||||
parser.add_argument('--token_file', type=str, default=None,
|
||||
help='the path of tokens.txt')
|
||||
parser.add_argument('--lexicon_file', type=str, default=None,
|
||||
help='the path of lexicon.txt')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def is_sublist(main_list, check_list):
|
||||
if len(main_list) < len(check_list):
|
||||
return -1
|
||||
|
||||
if len(main_list) == len(check_list):
|
||||
return 0 if main_list == check_list else -1
|
||||
|
||||
for i in range(len(main_list) - len(check_list) + 1):
|
||||
if main_list[i] == check_list[0]:
|
||||
for j in range(len(check_list)):
|
||||
if main_list[i + j] != check_list[j]:
|
||||
break
|
||||
else:
|
||||
return i
|
||||
else:
|
||||
return -1
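The sublist search is what decides whether a decoded prefix contains a keyword's token sequence, so the boundary cases matter: a keyword sitting at the very end of the prefix must still be found. A compact sketch of the same contiguous-match search using slices follows; the name `find_sublist` is hypothetical, and the `list(...)` conversions are there because the prefix may be a tuple while the keyword tokens may be a list.

```python
def find_sublist(main_list, check_list):
    """Return the first index where check_list occurs contiguously, or -1."""
    n, m = len(main_list), len(check_list)
    # range(n - m + 1) includes the last possible start position;
    # it is empty when check_list is longer than main_list.
    for i in range(n - m + 1):
        if list(main_list[i:i + m]) == list(check_list):
            return i
    return -1


print(find_sublist((5, 7, 9, 11), [9, 11]))
```

The returned offset can then be used to index the per-token path nodes, as the scoring loop does with `prefix_nodes[offset]`.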
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
||||
|
||||
with open(args.config, 'r') as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
|
||||
test_conf = copy.deepcopy(configs['dataset_conf'])
|
||||
test_conf['filter_conf']['max_length'] = 102400
|
||||
test_conf['filter_conf']['min_length'] = 0
|
||||
test_conf['speed_perturb'] = False
|
||||
test_conf['spec_aug'] = False
|
||||
test_conf['shuffle'] = False
|
||||
test_conf['feature_extraction_conf']['dither'] = 0.0
|
||||
test_conf['batch_conf']['batch_size'] = args.batch_size
|
||||
|
||||
test_dataset = Dataset(args.test_data, test_conf)
|
||||
test_data_loader = DataLoader(test_dataset,
|
||||
batch_size=None,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch)
|
||||
|
||||
if args.jit_model:
|
||||
model = torch.jit.load(args.checkpoint)
|
||||
# For script model, only cpu is supported.
|
||||
device = torch.device('cpu')
|
||||
else:
|
||||
# Init kws model from configs
|
||||
model = init_model(configs['model'])
|
||||
load_checkpoint(model, args.checkpoint)
|
||||
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
||||
device = torch.device('cuda' if use_cuda else 'cpu')
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
score_abs_path = os.path.abspath(args.score_file)
|
||||
|
||||
token_table = read_token(args.token_file)
|
||||
lexicon_table = read_lexicon(args.lexicon_file)
|
||||
# 4. parse keywords tokens
|
||||
assert args.keywords is not None, 'at least one keyword is needed'
|
||||
logging.info(f"keywords is {args.keywords}, "
|
||||
f"Chinese is converted into Unicode.")
|
||||
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
|
||||
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
||||
keywords_token = {}
|
||||
keywords_idxset = {0}
|
||||
keywords_strset = {'<blk>'}
|
||||
keywords_tokenmap = {'<blk>': 0}
|
||||
for keyword in keywords_list:
|
||||
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
||||
keywords_token[keyword] = {}
|
||||
keywords_token[keyword]['token_id'] = indexes
|
||||
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
||||
for i in indexes)
|
||||
[keywords_strset.add(i) for i in strs]
|
||||
[keywords_idxset.add(i) for i in indexes]
|
||||
for txt, idx in zip(strs, indexes):
|
||||
if keywords_tokenmap.get(txt, None) is None:
|
||||
keywords_tokenmap[txt] = idx
|
||||
|
||||
token_print = ''
|
||||
for txt, idx in keywords_tokenmap.items():
|
||||
token_print += f'{txt}({idx}) '
|
||||
logging.info(f'Token set is: {token_print}')
|
||||
|
||||
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
|
||||
for batch_idx, batch in enumerate(test_data_loader):
|
||||
keys, feats, target, lengths, target_lengths = batch
|
||||
feats = feats.to(device)
|
||||
lengths = lengths.to(device)
|
||||
logits, _ = model(feats)
|
||||
logits = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
||||
logits = logits.cpu()
|
||||
for i in range(len(keys)):
|
||||
key = keys[i]
|
||||
score = logits[i][:lengths[i]]
|
||||
hyps = ctc_prefix_beam_search(score,
|
||||
lengths[i],
|
||||
keywords_idxset)
|
||||
hit_keyword = None
|
||||
hit_score = 1.0
|
||||
start = 0
|
||||
end = 0
|
||||
for one_hyp in hyps:
|
||||
prefix_ids = one_hyp[0]
|
||||
# path_score = one_hyp[1]
|
||||
prefix_nodes = one_hyp[2]
|
||||
assert len(prefix_ids) == len(prefix_nodes)
|
||||
for word in keywords_token.keys():
|
||||
lab = keywords_token[word]['token_id']
|
||||
offset = is_sublist(prefix_ids, lab)
|
||||
if offset != -1:
|
||||
hit_keyword = word
|
||||
start = prefix_nodes[offset]['frame']
|
||||
end = prefix_nodes[offset + len(lab) - 1]['frame']
|
||||
for idx in range(offset, offset + len(lab)):
|
||||
hit_score *= prefix_nodes[idx]['prob']
|
||||
break
|
||||
if hit_keyword is not None:
|
||||
hit_score = math.sqrt(hit_score)
|
||||
break
|
||||
|
||||
if hit_keyword is not None:
|
||||
fout.write('{} detected {} {:.3f}\n'.format(
|
||||
key, hit_keyword, hit_score))
|
||||
logging.info(
|
||||
f"batch:{batch_idx}_{i} detect {hit_keyword} "
|
||||
f"in {key} from {start} to {end} frame. "
|
||||
f"duration {end - start}, "
|
||||
f"score {hit_score}, Activated.")
|
||||
else:
|
||||
fout.write('{} rejected\n'.format(key))
|
||||
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
||||
|
||||
if batch_idx % 10 == 0:
|
||||
print('Progress batch {}'.format(batch_idx))
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
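For reference, the score written for a detected keyword is the square root of the product of the per-token probabilities along the matched path, and each utterance produces one line in the score file. The sketch below reproduces that score and parses the two line formats (`<key> detected <keyword> <score>` and `<key> rejected`). The helper names are hypothetical. The sketch returns `None` for rejected utterances, whereas `compute_det_ctc.py` stores the placeholder keyword `'unknown'`; both use `-1.0` as the score.

```python
import math


def keyword_score(token_probs):
    """Square root of the product of the matched tokens' probabilities."""
    score = 1.0
    for prob in token_probs:
        score *= prob
    return math.sqrt(score)


def parse_score_line(line):
    """Return (key, keyword or None, confidence) for one score-file line."""
    arr = line.strip().split()
    if arr[1] == 'detected':
        return arr[0], arr[2], float(arr[3])
    return arr[0], None, -1.0


# Made-up values for illustration only.
print(keyword_score([0.9, 0.8, 0.95, 0.7]))
print(parse_score_line('utt_001 detected 你好问问 0.812'))
print(parse_score_line('utt_002 rejected'))
```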
|
||||
587
wekws/bin/stream_kws_ctc.py
Normal file
@ -0,0 +1,587 @@
|
||||
# Copyright (c) 2023 Jing Du(thuduj12@163.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import struct
|
||||
# import wave
|
||||
import librosa
|
||||
import logging
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import yaml
|
||||
from collections import defaultdict
|
||||
from wekws.model.kws_model import init_model
|
||||
from wekws.utils.checkpoint import load_checkpoint
|
||||
from tools.make_list import query_token_set, read_lexicon, read_token
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='detect keywords online.')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--wav_path', required=False,
|
||||
default=None, help='test wave path.')
|
||||
parser.add_argument('--wav_scp', required=False,
|
||||
default=None, help='test wave scp.')
|
||||
parser.add_argument('--result_file', required=False,
|
||||
default=None, help='test result.')
|
||||
|
||||
parser.add_argument('--gpu',
|
||||
type=int,
|
||||
default=-1,
|
||||
help='gpu id for this rank, -1 for cpu')
|
||||
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
||||
parser.add_argument('--jit_model',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Load the checkpoint as a TorchScript (jit) model')
|
||||
parser.add_argument('--keywords', type=str, default=None,
|
||||
help='the keywords, split with comma(,)')
|
||||
parser.add_argument('--token_file', type=str, default=None,
|
||||
help='the path of tokens.txt')
|
||||
parser.add_argument('--lexicon_file', type=str, default=None,
|
||||
help='the path of lexicon.txt')
|
||||
parser.add_argument('--score_beam_size',
|
||||
default=3,
|
||||
type=int,
|
||||
help='The first prune beam, '
|
||||
'filter out those frames with low scores.')
|
||||
parser.add_argument('--path_beam_size',
|
||||
default=20,
|
||||
type=int,
|
||||
help='The second prune beam, '
|
||||
'keep only path_beam_size candidates.')
|
||||
parser.add_argument('--threshold',
|
||||
type=float,
|
||||
default=0.0,
|
||||
help='The threshold of kws. '
|
||||
'If ctc_search probs exceed this value, '
|
||||
'the keyword will be activated.')
|
||||
parser.add_argument('--min_frames',
|
||||
default=5,
|
||||
type=int,
|
||||
help='The min frames of keyword\'s duration.')
|
||||
parser.add_argument('--max_frames',
|
||||
default=250,
|
||||
type=int,
|
||||
help='The max frames of keyword\'s duration.')
|
||||
parser.add_argument('--interval_frames',
|
||||
default=50,
|
||||
type=int,
|
||||
help='The interval frames of two continuous keywords.')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def is_sublist(main_list, check_list):
|
||||
if len(main_list) < len(check_list):
|
||||
return -1
|
||||
|
||||
if len(main_list) == len(check_list):
|
||||
return 0 if main_list == check_list else -1
|
||||
|
||||
for i in range(len(main_list) - len(check_list) + 1):
|
||||
if main_list[i] == check_list[0]:
|
||||
for j in range(len(check_list)):
|
||||
if main_list[i + j] != check_list[j]:
|
||||
break
|
||||
else:
|
||||
return i
|
||||
else:
|
||||
return -1
|
||||
|
||||
def ctc_prefix_beam_search(t, probs,
|
||||
cur_hyps,
|
||||
keywords_idxset,
|
||||
score_beam_size):
|
||||
'''
|
||||
|
||||
:param t: the time in frame
|
||||
:param probs: the probability in t_th frame, (vocab_size, )
|
||||
:param cur_hyps: list of tuples. [(tuple(), (1.0, 0.0, []))]
|
||||
in tuple, 1st is prefix id, 2nd include p_blank,
|
||||
p_non_blank, and path nodes list.
|
||||
in path nodes list, each node is
|
||||
a dict of {token=idx, frame=t, prob=ps}
|
||||
:param keywords_idxset: the index of keywords in token.txt
|
||||
:param score_beam_size: the top-k size of the first beam prune,
|
||||
to filter out those frames with low probs.
|
||||
:return:
|
||||
next_hyps: the hypothesis depend on current hyp and current frame.
|
||||
'''
|
||||
# key: prefix, value: (pb, pnb, nodes), default value (0.0, 0.0, [])
|
||||
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
|
||||
|
||||
# 2.1 First beam prune: select topk best
|
||||
top_k_probs, top_k_index = probs.topk(score_beam_size)
|
||||
|
||||
# filter prob score that is too small
|
||||
filter_probs = []
|
||||
filter_index = []
|
||||
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
|
||||
if keywords_idxset is not None:
|
||||
if prob > 0.05 and idx in keywords_idxset:
|
||||
filter_probs.append(prob)
|
||||
filter_index.append(idx)
|
||||
else:
|
||||
if prob > 0.05:
|
||||
filter_probs.append(prob)
|
||||
filter_index.append(idx)
|
||||
|
||||
if len(filter_index) == 0:
|
||||
return cur_hyps
|
||||
|
||||
for s in filter_index:
|
||||
ps = probs[s].item()
|
||||
|
||||
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
|
||||
last = prefix[-1] if len(prefix) > 0 else None
|
||||
if s == 0: # blank
|
||||
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||||
n_pb = n_pb + pb * ps + pnb * ps
|
||||
nodes = cur_nodes.copy()
|
||||
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||||
elif s == last:
|
||||
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
|
||||
# Update *ss -> *s;
|
||||
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||||
n_pnb = n_pnb + pnb * ps
|
||||
nodes = cur_nodes.copy()
|
||||
if ps > nodes[-1]['prob']: # update frame and prob
|
||||
nodes[-1]['prob'] = ps
|
||||
nodes[-1]['frame'] = t
|
||||
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||||
|
||||
if not math.isclose(pb, 0.0, abs_tol=0.000001):
|
||||
# Update *s-s -> *ss, - is for blank
|
||||
n_prefix = prefix + (s,)
|
||||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||
n_pnb = n_pnb + pb * ps
|
||||
nodes = cur_nodes.copy()
|
||||
nodes.append(dict(token=s, frame=t,
|
||||
prob=ps)) # to record token prob
|
||||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||
else:
|
||||
n_prefix = prefix + (s,)
|
||||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||
if nodes:
|
||||
if ps > nodes[-1]['prob']: # update frame and prob
|
||||
# nodes[-1]['prob'] = ps
|
||||
# nodes[-1]['frame'] = t
|
||||
nodes.pop()
|
||||
# to avoid change other beam which has this node.
|
||||
nodes.append(dict(token=s, frame=t, prob=ps))
|
||||
else:
|
||||
nodes = cur_nodes.copy()
|
||||
nodes.append(dict(token=s, frame=t,
|
||||
prob=ps)) # to record token prob
|
||||
n_pnb = n_pnb + pb * ps + pnb * ps
|
||||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||
|
||||
# 2.2 Second beam prune
|
||||
next_hyps = sorted(
|
||||
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)
|
||||
|
||||
return next_hyps
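The per-frame update above follows the usual CTC prefix beam search rules in the probability domain: a blank keeps the prefix, a repeated token either stays merged or starts a new copy after a blank, and any other token extends the prefix. The sketch below isolates just that bookkeeping of `(p_blank, p_non_blank)`. It is a simplification under stated assumptions: there is no top-k or keyword-set pruning, no path-node tracking, and zero-probability entries are not skipped (the source skips them with `math.isclose`). The name `prefix_beam_step` is hypothetical.

```python
from collections import defaultdict


def prefix_beam_step(probs, hyps, blank=0):
    """Advance all prefixes by one frame.

    probs: per-token probabilities for this frame.
    hyps: {prefix tuple: (p_blank, p_non_blank)}.
    """
    nxt = defaultdict(lambda: [0.0, 0.0])
    for s, ps in enumerate(probs):
        for prefix, (pb, pnb) in hyps.items():
            last = prefix[-1] if prefix else None
            if s == blank:
                # blank: prefix unchanged, mass goes to p_blank
                nxt[prefix][0] += (pb + pnb) * ps
            elif s == last:
                # *ss -> *s : repeated token merges into the same prefix
                nxt[prefix][1] += pnb * ps
                # *s-s -> *ss : a blank in between starts a new token
                nxt[prefix + (s,)][1] += pb * ps
            else:
                # new token extends the prefix
                nxt[prefix + (s,)][1] += (pb + pnb) * ps
    return {p: tuple(v) for p, v in nxt.items()}


# Two frames over a toy vocabulary {0: blank, 1: token}.
hyps = {(): (1.0, 0.0)}
for frame in ([0.5, 0.5], [0.5, 0.5]):
    hyps = prefix_beam_step(frame, hyps)
print(hyps)
```

Without pruning, the prefix probabilities sum to 1 after every frame, which makes a convenient sanity check when modifying the update rules.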
|
||||
|
||||
class KeyWordSpotter(torch.nn.Module):
|
||||
def __init__(self, ckpt_path, config_path, token_path, lexicon_path,
|
||||
threshold, min_frames=5, max_frames=250, interval_frames=50,
|
||||
score_beam=3, path_beam=20,
|
||||
gpu=-1, is_jit_model=False,):
|
||||
super().__init__()
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
|
||||
with open(config_path, 'r') as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
dataset_conf = configs['dataset_conf']
|
||||
|
||||
# feature related
|
||||
self.sample_rate = 16000
|
||||
self.wave_remained = np.array([])
|
||||
self.num_mel_bins = dataset_conf[
|
||||
'feature_extraction_conf']['num_mel_bins']
|
||||
self.frame_length = dataset_conf[
|
||||
'feature_extraction_conf']['frame_length'] # in ms
|
||||
self.frame_shift = dataset_conf[
|
||||
'feature_extraction_conf']['frame_shift'] # in ms
|
||||
self.downsampling = dataset_conf.get('frame_skip', 1)
|
||||
self.resolution = self.frame_shift / 1000 # in second
|
||||
# fsmn splice operation
|
||||
self.context_expansion = dataset_conf.get('context_expansion', False)
|
||||
self.left_context = 0
|
||||
self.right_context = 0
|
||||
if self.context_expansion:
|
||||
self.left_context = dataset_conf['context_expansion_conf']['left']
|
||||
self.right_context = dataset_conf['context_expansion_conf']['right']
|
||||
self.feature_remained = None
|
||||
self.feats_ctx_offset = 0 # after downsample, offset exist.
|
||||
|
||||
|
||||
# model related
|
||||
if is_jit_model:
|
||||
model = torch.jit.load(ckpt_path)
|
||||
# For script model, only cpu is supported.
|
||||
device = torch.device('cpu')
|
||||
else:
|
||||
# Init model from configs
|
||||
model = init_model(configs['model'])
|
||||
load_checkpoint(model, ckpt_path)
|
||||
use_cuda = gpu >= 0 and torch.cuda.is_available()
|
||||
device = torch.device('cuda' if use_cuda else 'cpu')
|
||||
self.device = device
|
||||
self.model = model.to(device)
|
||||
self.model.eval()
|
||||
logging.info(f'model {ckpt_path} loaded.')
|
||||
self.token_table = read_token(token_path)
|
||||
logging.info(f'tokens {token_path} with '
|
||||
f'{len(self.token_table)} units loaded.')
|
||||
self.lexicon_table = read_lexicon(lexicon_path)
|
||||
logging.info(f'lexicons {lexicon_path} with '
|
||||
f'{len(self.lexicon_table)} units loaded.')
|
||||
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
|
||||
|
||||
# decoding and detection related
|
||||
self.score_beam = score_beam
|
||||
self.path_beam = path_beam
|
||||
|
||||
self.threshold = threshold
|
||||
self.min_frames = min_frames
|
||||
self.max_frames = max_frames
|
||||
self.interval_frames = interval_frames
|
||||
|
||||
self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
||||
self.hit_score = 1.0
|
||||
self.hit_keyword = None
|
||||
self.activated = False
|
||||
|
||||
self.total_frames = 0 # frame offset, for absolute time
|
||||
self.last_active_pos = -1 # the last frame of being activated
|
||||
self.result = {}
|
||||
|
||||
def set_keywords(self, keywords):
|
||||
# 4. parse keywords tokens
|
||||
assert keywords is not None, \
|
||||
'at least one keyword is needed, ' \
|
||||
'multiple keywords should be split with comma(,)'
|
||||
keywords_str = keywords
|
||||
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
||||
keywords_token = {}
|
||||
keywords_idxset = {0}
|
||||
keywords_strset = {'<blk>'}
|
||||
keywords_tokenmap = {'<blk>': 0}
|
||||
for keyword in keywords_list:
|
||||
strs, indexes = query_token_set(
|
||||
keyword, self.token_table, self.lexicon_table)
|
||||
keywords_token[keyword] = {}
|
||||
keywords_token[keyword]['token_id'] = indexes
|
||||
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
||||
for i in indexes)
|
||||
[keywords_strset.add(i) for i in strs]
|
||||
[keywords_idxset.add(i) for i in indexes]
|
||||
for txt, idx in zip(strs, indexes):
|
||||
if keywords_tokenmap.get(txt, None) is None:
|
||||
keywords_tokenmap[txt] = idx
|
||||
|
||||
token_print = ''
|
||||
for txt, idx in keywords_tokenmap.items():
|
||||
token_print += f'{txt}({idx}) '
|
||||
logging.info(f'Token set is: {token_print}')
|
||||
self.keywords_idxset = keywords_idxset
|
||||
self.keywords_token = keywords_token
|
||||
|
||||
def accept_wave(self, wave):
|
||||
assert isinstance(wave, bytes), \
|
||||
"please make sure the input format is bytes(raw PCM)"
|
||||
# convert bytes into float32
|
||||
data = []
|
||||
for i in range(0, len(wave), 2):
|
||||
value = struct.unpack('<h', wave[i:i + 2])[0]
|
||||
data.append(value)
|
||||
# here we don't divide 32768.0,
|
||||
# because kaldi.fbank accept original input
|
||||
|
||||
wave = np.array(data)
|
||||
wave = np.append(self.wave_remained, wave)
|
||||
if wave.size < (self.frame_length * self.sample_rate / 1000) \
|
||||
* self.right_context :
|
||||
self.wave_remained = wave
|
||||
return None
|
||||
wave_tensor = torch.from_numpy(wave).float().to(self.device)
|
||||
wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension
|
||||
feats = kaldi.fbank(wave_tensor,
|
||||
num_mel_bins=self.num_mel_bins,
|
||||
frame_length=self.frame_length,
|
||||
frame_shift=self.frame_shift,
|
||||
dither=0,
|
||||
energy_floor=0.0,
|
||||
sample_frequency=self.sample_rate)
|
||||
# update wave remained
|
||||
feat_len = len(feats)
|
||||
frame_shift = int(self.frame_shift / 1000 * self.sample_rate)
|
||||
self.wave_remained = wave[feat_len * frame_shift:]
|
||||
|
||||
if self.context_expansion:
|
||||
assert feat_len > self.right_context, \
"make sure each chunk feat length is larger than right context."
|
||||
# pad feats with remained feature from last chunk
|
||||
if self.feature_remained is None: # first chunk
|
||||
# Replicate the first frame at the beginning.
# 'replicate' padding only supports the last dimension, so we transpose.
|
||||
feats_pad = F.pad(
|
||||
feats.T, (self.left_context, 0), mode='replicate').T
|
||||
else:
|
||||
feats_pad = torch.cat((self.feature_remained, feats))
|
||||
|
||||
ctx_frm = feats_pad.shape[0] - (
|
||||
self.left_context + self.right_context)
|
||||
ctx_win = (self.left_context + self.right_context + 1)
|
||||
ctx_dim = feats.shape[1] * ctx_win
|
||||
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
|
||||
for i in range(ctx_frm):
|
||||
feats_ctx[i] = torch.cat(
|
||||
tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
|
||||
|
||||
# update feature remained, and feats
|
||||
self.feature_remained = \
|
||||
feats[-(self.left_context + self.right_context):]
|
||||
feats = feats_ctx.to(self.device)
|
||||
if self.downsampling > 1:
|
||||
last_remainder = 0 if self.feats_ctx_offset == 0 \
|
||||
else self.downsampling - self.feats_ctx_offset
|
||||
remainder = (feats.size(0) + last_remainder) % self.downsampling
|
||||
feats = feats[self.feats_ctx_offset::self.downsampling, :]
|
||||
self.feats_ctx_offset = remainder \
|
||||
if remainder == 0 else self.downsampling - remainder
|
||||
return feats
|
||||
|
||||
def decode_keywords(self, t, probs):
|
||||
absolute_time = t + self.total_frames
|
||||
# search next_hyps depend on current probs and hyps.
|
||||
next_hyps = ctc_prefix_beam_search(absolute_time,
|
||||
probs,
|
||||
self.cur_hyps,
|
||||
self.keywords_idxset,
|
||||
self.score_beam)
|
||||
# Update cur_hyps. Note: the hyps are sorted by path score (pnb + pb),
# not by the keywords' probabilities.
|
||||
cur_hyps = next_hyps[:self.path_beam]
|
||||
self.cur_hyps = cur_hyps
|
||||
|
||||
def execute_detection(self, t):
|
||||
absolute_time = t + self.total_frames
|
||||
hit_keyword = None
|
||||
start = 0
|
||||
end = 0
|
||||
|
||||
# hyps for detection
|
||||
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
|
||||
|
||||
# detect keywords in decoding paths.
|
||||
for one_hyp in hyps:
|
||||
prefix_ids = one_hyp[0]
|
||||
# path_score = one_hyp[1]
|
||||
prefix_nodes = one_hyp[2]
|
||||
assert len(prefix_ids) == len(prefix_nodes)
|
||||
for word in self.keywords_token.keys():
|
||||
lab = self.keywords_token[word]['token_id']
|
||||
offset = is_sublist(prefix_ids, lab)
|
||||
if offset != -1:
|
||||
hit_keyword = word
|
||||
start = prefix_nodes[offset]['frame']
|
||||
end = prefix_nodes[offset + len(lab) - 1]['frame']
|
||||
for idx in range(offset, offset + len(lab)):
|
||||
self.hit_score *= prefix_nodes[idx]['prob']
|
||||
break
|
||||
if hit_keyword is not None:
|
||||
self.hit_score = math.sqrt(self.hit_score)
|
||||
break
|
||||
|
||||
duration = end - start
|
||||
if hit_keyword is not None:
|
||||
if self.hit_score >= self.threshold and \
|
||||
self.min_frames <= duration <= self.max_frames \
|
||||
and (self.last_active_pos == -1 or
|
||||
end - self.last_active_pos >= self.interval_frames):
|
||||
self.activated = True
|
||||
self.last_active_pos = end
|
||||
logging.info(
|
||||
f"Frame {absolute_time} detect {hit_keyword} "
|
||||
f"from {start} to {end} frame. "
|
||||
f"duration {duration}, score {self.hit_score}, Activated.")
|
||||
|
||||
elif self.last_active_pos > 0 and \
|
||||
end - self.last_active_pos < self.interval_frames:
|
||||
logging.info(
|
||||
f"Frame {absolute_time} detect {hit_keyword} "
|
||||
f"from {start} to {end} frame. "
|
||||
f"but interval {end-self.last_active_pos} "
|
||||
f"is lower than {self.interval_frames}, Deactivated. ")
|
||||
|
||||
elif self.hit_score < self.threshold:
|
||||
logging.info(
|
||||
f"Frame {absolute_time} detect {hit_keyword} "
|
||||
f"from {start} to {end} frame. "
|
||||
f"but {self.hit_score} "
|
||||
f"is lower than {self.threshold}, Deactivated. ")
|
||||
|
||||
elif self.min_frames > duration or duration > self.max_frames:
|
||||
logging.info(
|
||||
f"Frame {absolute_time} detect {hit_keyword} "
|
||||
f"from {start} to {end} frame. "
|
||||
f"but {duration} beyond range"
|
||||
f"({self.min_frames}~{self.max_frames}), Deactivated. ")
|
||||
|
||||
self.result = {
|
||||
"state": 1 if self.activated else 0,
|
||||
"keyword": hit_keyword if self.activated else None,
|
||||
"start": start * self.resolution if self.activated else None,
|
||||
"end": end * self.resolution if self.activated else None,
|
||||
"score": self.hit_score if self.activated else None
|
||||
}
|
||||
|
||||
def forward(self, wave_chunk):
|
||||
feature = self.accept_wave(wave_chunk)
|
||||
if feature is None or feature.size(0) < 1:
|
||||
return {}  # not enough features to produce a result yet
|
||||
feature = feature.unsqueeze(0) # add a batch dimension
|
||||
logits, self.in_cache = self.model(feature, self.in_cache)
|
||||
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
||||
probs = probs[0].cpu() # remove batch dimension
|
||||
for (t, prob) in enumerate(probs):
|
||||
t *= self.downsampling
|
||||
self.decode_keywords(t, prob)
|
||||
self.execute_detection(t)
|
||||
|
||||
if self.activated:
|
||||
self.reset()
|
||||
# A chunk includes about 30 frames, so once activated
# we can skip the remaining frames of this chunk.
# TODO: provide another method to update the result,
# avoiding self.result being cleared.
|
||||
break
|
||||
|
||||
# update frame offset
|
||||
self.total_frames += len(probs) * self.downsampling
|
||||
return self.result
|
||||
|
||||
def reset(self):
|
||||
self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
||||
self.activated = False
|
||||
self.hit_score = 1.0
|
||||
|
||||
def reset_all(self):
|
||||
self.reset()
|
||||
self.wave_remained = np.array([])
|
||||
self.feature_remained = None
|
||||
self.feats_ctx_offset = 0 # after downsample, offset exist.
|
||||
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
self.total_frames = 0 # frame offset, for absolute time
|
||||
self.last_active_pos = -1 # the last frame of being activated
|
||||
self.result = {}
|
||||
|
||||
def demo():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
|
||||
kws = KeyWordSpotter(args.checkpoint,
|
||||
args.config,
|
||||
args.token_file,
|
||||
args.lexicon_file,
|
||||
args.threshold,
|
||||
args.min_frames,
|
||||
args.max_frames,
|
||||
args.interval_frames,
|
||||
args.score_beam_size,
|
||||
args.path_beam_size,
|
||||
args.gpu,
|
||||
args.jit_model)
|
||||
|
||||
# actually this could be done in __init__ method,
|
||||
# we pull it outside for changing keywords more freely.
|
||||
kws.set_keywords(args.keywords)
|
||||
|
||||
if args.wav_path:
|
||||
# Caution: input WAV should be standard 16k, 16 bits, 1 channel
|
||||
# In demo we read wave in non-streaming fashion.
|
||||
# with wave.open(args.wav_path, 'rb') as fin:
|
||||
# assert fin.getnchannels() == 1
|
||||
# wav = fin.readframes(fin.getnframes())
|
||||
|
||||
y, _ = librosa.load(args.wav_path, sr=16000, mono=True)
|
||||
# NOTE: model supports 16k sample_rate
|
||||
wav = (y * (1 << 15)).astype("int16").tobytes()
|
||||
|
||||
# Run inference every 0.3 seconds, in streaming fashion.
|
||||
interval = int(0.3 * 16000) * 2
|
||||
for i in range(0, len(wav), interval):
|
||||
chunk_wav = wav[i: min(i + interval, len(wav))]
|
||||
result = kws.forward(chunk_wav)
|
||||
print(result)
|
||||
|
||||
fout = None
|
||||
if args.result_file:
|
||||
fout = open(args.result_file, 'w', encoding='utf-8')
|
||||
|
||||
if args.wav_scp:
|
||||
with open(args.wav_scp, 'r') as fscp:
|
||||
for line in fscp:
|
||||
line = line.strip().split()
|
||||
assert len(line) == 2, \
|
||||
f"The scp should be in kaldi format: " \
|
||||
f"\"utt_name wav_path\", but got {line}"
|
||||
|
||||
utt_name, wav_path = line[0], line[1]
|
||||
# with wave.open(args.wav_path, 'rb') as fin:
|
||||
# assert fin.getnchannels() == 1
|
||||
# wav = fin.readframes(fin.getnframes())
|
||||
|
||||
y, _ = librosa.load(wav_path, sr=16000, mono=True)
|
||||
# NOTE: model supports 16k sample_rate
|
||||
wav = (y * (1 << 15)).astype("int16").tobytes()
|
||||
|
||||
kws.reset_all()
|
||||
activated = False
|
||||
|
||||
# Run inference every 0.3 seconds, in streaming fashion.
|
||||
interval = int(0.3 * 16000) * 2
|
||||
for i in range(0, len(wav), interval):
|
||||
chunk_wav = wav[i: min(i + interval, len(wav))]
|
||||
result = kws.forward(chunk_wav)
|
||||
if 'state' in result and result['state'] == 1:
|
||||
activated = True
|
||||
if fout:
|
||||
hit_keyword = result['keyword']
|
||||
hit_score = result['score']
|
||||
fout.write('{} detected {} {:.3f}\n'.format(
|
||||
utt_name, hit_keyword, hit_score))
|
||||
|
||||
if not activated:
|
||||
if fout:
|
||||
fout.write('{} rejected\n'.format(utt_name))
|
||||
|
||||
|
||||
if fout:
|
||||
fout.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
demo()
|
||||
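Note on the frame-skipping bookkeeping in `accept_wave` above: when `downsampling > 1`, each chunk must start at the right local offset so that the kept frames match a global `[::downsampling]` slice over the whole stream. The following is a minimal sketch of that idea on plain Python lists. The function name `skip_frames_streaming` and its arguments are hypothetical and not part of this commit; it illustrates the offset arithmetic only, not the actual feature pipeline.

```python
def skip_frames_streaming(chunks, skip_rate):
    """Keep every `skip_rate`-th frame of a stream that arrives in chunks.

    Hypothetical helper: `chunks` is an iterable of lists of frames.
    The concatenated result equals all_frames[::skip_rate].
    """
    offset = 0  # local index of the first frame to keep in the next chunk
    kept = []
    for chunk in chunks:
        if offset >= len(chunk):
            # Chunk is too short to contain the next kept frame.
            offset -= len(chunk)
            continue
        kept.extend(chunk[offset::skip_rate])
        # Frames in this chunk counted from the first kept frame onwards.
        consumed = len(chunk) - offset
        remainder = consumed % skip_rate
        # Carry the distance to the next kept frame into the next chunk.
        offset = 0 if remainder == 0 else skip_rate - remainder
    return kept


if __name__ == "__main__":
    frames = list(range(50))
    sizes = [10, 10, 1, 7, 22]
    chunks, pos = [], 0
    for n in sizes:
        chunks.append(frames[pos:pos + n])
        pos += n
    print(skip_frames_streaming(chunks, 3) == frames[::3])
```

Carrying only a single integer offset between calls keeps the streaming and non-streaming paths aligned without buffering skipped frames.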
363
wekws/bin/stream_score_ctc.py
Normal file
@ -0,0 +1,363 @@
|
||||
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||
# 2022 Shaoqing Yu(954793264@qq.com)
|
||||
# 2023 Jing Du(thuduj12@163.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from collections import defaultdict
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from wekws.dataset.dataset import Dataset
|
||||
from wekws.model.kws_model import init_model
|
||||
from wekws.utils.checkpoint import load_checkpoint
|
||||
from tools.make_list import query_token_set, read_lexicon, read_token
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='recognize with your model')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--test_data', required=True, help='test data file')
|
||||
parser.add_argument('--gpu',
|
||||
type=int,
|
||||
default=-1,
|
||||
help='gpu id for this rank, -1 for cpu')
|
||||
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
||||
parser.add_argument('--batch_size',
|
||||
default=1,
|
||||
type=int,
|
||||
help='batch size for inference')
|
||||
parser.add_argument('--num_workers',
|
||||
default=1,
|
||||
type=int,
|
||||
help='num of subprocess workers for reading')
|
||||
parser.add_argument('--pin_memory',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use pinned memory buffers used for reading')
|
||||
parser.add_argument('--prefetch',
|
||||
default=100,
|
||||
type=int,
|
||||
help='prefetch number')
|
||||
parser.add_argument('--score_file',
|
||||
required=True,
|
||||
help='output score file')
|
||||
parser.add_argument('--jit_model',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use a TorchScript (JIT) model for inference')
|
||||
parser.add_argument('--keywords', type=str, default=None,
|
||||
help='the keywords, split with comma(,)')
|
||||
parser.add_argument('--token_file', type=str, default=None,
|
||||
help='the path of tokens.txt')
|
||||
parser.add_argument('--lexicon_file', type=str, default=None,
|
||||
help='the path of lexicon.txt')
|
||||
parser.add_argument('--score_beam_size',
|
||||
default=3,
|
||||
type=int,
|
||||
help='The first prune beam, '
'filter out those frames with low scores.')
|
||||
parser.add_argument('--path_beam_size',
|
||||
default=20,
|
||||
type=int,
|
||||
help='The second prune beam, '
|
||||
'keep only path_beam_size candidates.')
|
||||
parser.add_argument('--threshold',
|
||||
type=float,
|
||||
default=0.0,
|
||||
help='The threshold of kws. '
|
||||
'If ctc_search probs exceed this value, '
|
||||
'the keyword will be activated.')
|
||||
parser.add_argument('--min_frames',
|
||||
default=5,
|
||||
type=int,
|
||||
help='The min frames of keyword duration.')
|
||||
parser.add_argument('--max_frames',
|
||||
default=250,
|
||||
type=int,
|
||||
help='The max frames of keyword duration.')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def is_sublist(main_list, check_list):
    """Return the first index where check_list occurs in main_list, or -1."""
    if len(main_list) < len(check_list):
        return -1

    if len(main_list) == len(check_list):
        return 0 if main_list == check_list else -1

    # "+ 1" so that a match ending at the last element is also found.
    for i in range(len(main_list) - len(check_list) + 1):
        if main_list[i] == check_list[0]:
            for j in range(len(check_list)):
                if main_list[i + j] != check_list[j]:
                    break
            else:
                return i
    return -1
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
||||
|
||||
with open(args.config, 'r') as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
|
||||
test_conf = copy.deepcopy(configs['dataset_conf'])
|
||||
test_conf['filter_conf']['max_length'] = 102400
|
||||
test_conf['filter_conf']['min_length'] = 0
|
||||
test_conf['speed_perturb'] = False
|
||||
test_conf['spec_aug'] = False
|
||||
test_conf['shuffle'] = False
|
||||
test_conf['feature_extraction_conf']['dither'] = 0.0
|
||||
test_conf['batch_conf']['batch_size'] = args.batch_size
|
||||
|
||||
downsampling_factor = test_conf.get('frame_skip', 1)
|
||||
|
||||
test_dataset = Dataset(args.test_data, test_conf)
|
||||
test_data_loader = DataLoader(test_dataset,
|
||||
batch_size=None,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch)
|
||||
|
||||
if args.jit_model:
|
||||
model = torch.jit.load(args.checkpoint)
|
||||
# For script model, only cpu is supported.
|
||||
device = torch.device('cpu')
|
||||
else:
|
||||
# Init kws model from configs
|
||||
model = init_model(configs['model'])
|
||||
load_checkpoint(model, args.checkpoint)
|
||||
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
||||
device = torch.device('cuda' if use_cuda else 'cpu')
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
score_abs_path = os.path.abspath(args.score_file)
|
||||
|
||||
token_table = read_token(args.token_file)
|
||||
lexicon_table = read_lexicon(args.lexicon_file)
|
||||
# 4. parse keywords tokens
|
||||
assert args.keywords is not None, 'at least one keyword is needed'
|
||||
logging.info(f"keywords is {args.keywords}, "
|
||||
f"Chinese is converted into Unicode.")
|
||||
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
|
||||
keywords_list = keywords_str.strip().replace(' ', '').split(',')
|
||||
keywords_token = {}
|
||||
keywords_idxset = {0}
|
||||
keywords_strset = {'<blk>'}
|
||||
keywords_tokenmap = {'<blk>': 0}
|
||||
for keyword in keywords_list:
|
||||
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
|
||||
keywords_token[keyword] = {}
|
||||
keywords_token[keyword]['token_id'] = indexes
|
||||
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
|
||||
for i in indexes)
|
||||
[keywords_strset.add(i) for i in strs]
|
||||
[keywords_idxset.add(i) for i in indexes]
|
||||
for txt, idx in zip(strs, indexes):
|
||||
if keywords_tokenmap.get(txt, None) is None:
|
||||
keywords_tokenmap[txt] = idx
|
||||
|
||||
token_print = ''
|
||||
for txt, idx in keywords_tokenmap.items():
|
||||
token_print += f'{txt}({idx}) '
|
||||
logging.info(f'Token set is: {token_print}')
|
||||
|
||||
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
|
||||
for batch_idx, batch in enumerate(test_data_loader):
|
||||
keys, feats, target, lengths, target_lengths = batch
|
||||
feats = feats.to(device)
|
||||
lengths = lengths.to(device)
|
||||
logits, _ = model(feats)
|
||||
logits = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
||||
logits = logits.cpu()
|
||||
for i in range(len(keys)):
|
||||
key = keys[i]
|
||||
score = logits[i][:lengths[i]]
|
||||
# hyps = ctc_prefix_beam_search(score, lengths[i],
|
||||
# keywords_idxset)
|
||||
maxlen = score.size(0)
|
||||
ctc_probs = score
|
||||
cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
||||
|
||||
hit_keyword = None
|
||||
activated = False
|
||||
hit_score = 1.0
|
||||
start = 0
|
||||
end = 0
|
||||
|
||||
# 2. CTC beam search step by step
|
||||
for t in range(0, maxlen):
|
||||
probs = ctc_probs[t] # (vocab_size,)
|
||||
t *= downsampling_factor # the real time
|
||||
# key: prefix, value: (pb, pnb, nodes), default value: (0.0, 0.0, [])
|
||||
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
|
||||
|
||||
# 2.1 First beam prune: select topk best
|
||||
top_k_probs, top_k_index = probs.topk(args.score_beam_size)
|
||||
|
||||
# filter prob score that is too small
|
||||
filter_probs = []
|
||||
filter_index = []
|
||||
for prob, idx in zip(
|
||||
top_k_probs.tolist(), top_k_index.tolist()):
|
||||
if keywords_idxset is not None:
|
||||
if prob > 0.05 and idx in keywords_idxset:
|
||||
filter_probs.append(prob)
|
||||
filter_index.append(idx)
|
||||
else:
|
||||
if prob > 0.05:
|
||||
filter_probs.append(prob)
|
||||
filter_index.append(idx)
|
||||
|
||||
if len(filter_index) == 0:
|
||||
continue
|
||||
|
||||
for s in filter_index:
|
||||
ps = probs[s].item()
|
||||
|
||||
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
|
||||
last = prefix[-1] if len(prefix) > 0 else None
|
||||
if s == 0: # blank
|
||||
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||||
n_pb = n_pb + pb * ps + pnb * ps
|
||||
nodes = cur_nodes.copy()
|
||||
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||||
elif s == last:
|
||||
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
|
||||
# Update *ss -> *s;
|
||||
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||||
n_pnb = n_pnb + pnb * ps
|
||||
nodes = cur_nodes.copy()
|
||||
# update frame and prob
|
||||
if ps > nodes[-1]['prob']:
|
||||
nodes[-1]['prob'] = ps
|
||||
nodes[-1]['frame'] = t
|
||||
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||||
|
||||
if not math.isclose(pb, 0.0, abs_tol=0.000001):
|
||||
# Update *s-s -> *ss, - is for blank
|
||||
n_prefix = prefix + (s,)
|
||||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||
n_pnb = n_pnb + pb * ps
|
||||
nodes = cur_nodes.copy()
|
||||
nodes.append(dict(
|
||||
token=s, frame=t, prob=ps))
|
||||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||
else:
|
||||
n_prefix = prefix + (s,)
|
||||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||
if nodes:
|
||||
# update frame and prob
|
||||
if ps > nodes[-1]['prob']:
|
||||
# nodes[-1]['prob'] = ps
|
||||
# nodes[-1]['frame'] = t
|
||||
# avoid change other beam has this node.
|
||||
nodes.pop()
|
||||
nodes.append(dict(
|
||||
token=s, frame=t, prob=ps))
|
||||
else:
|
||||
nodes = cur_nodes.copy()
|
||||
nodes.append(dict(
|
||||
token=s, frame=t, prob=ps))
|
||||
n_pnb = n_pnb + pb * ps + pnb * ps
|
||||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||
|
||||
# 2.2 Second beam prune
|
||||
next_hyps = sorted(
|
||||
next_hyps.items(),
|
||||
key=lambda x: (x[1][0] + x[1][1]), reverse=True)
|
||||
|
||||
cur_hyps = next_hyps[:args.path_beam_size]
|
||||
|
||||
hyps = [(y[0], y[1][0] + y[1][1], y[1][2])
|
||||
for y in cur_hyps]
|
||||
|
||||
for one_hyp in hyps:
|
||||
prefix_ids = one_hyp[0]
|
||||
# path_score = one_hyp[1]
|
||||
prefix_nodes = one_hyp[2]
|
||||
assert len(prefix_ids) == len(prefix_nodes)
|
||||
for word in keywords_token.keys():
|
||||
lab = keywords_token[word]['token_id']
|
||||
offset = is_sublist(prefix_ids, lab)
|
||||
if offset != -1:
|
||||
hit_keyword = word
|
||||
start = prefix_nodes[offset]['frame']
|
||||
end = prefix_nodes[
|
||||
offset + len(lab) - 1]['frame']
|
||||
for idx in range(offset, offset + len(lab)):
|
||||
hit_score *= prefix_nodes[idx]['prob']
|
||||
break
|
||||
if hit_keyword is not None:
|
||||
hit_score = math.sqrt(hit_score)
|
||||
break
|
||||
|
||||
duration = end - start
|
||||
if hit_keyword is not None:
|
||||
if hit_score >= args.threshold and \
|
||||
args.min_frames <= duration <= args.max_frames:
|
||||
activated = True
|
||||
fout.write('{} detected {} {:.3f}\n'.format(
|
||||
key, hit_keyword, hit_score))
|
||||
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"duration {duration}, "
f"score {hit_score} Activated.")
|
||||
|
||||
# clear the ctc_prefix buffer, and hit_keyword
|
||||
cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
||||
hit_keyword = None
|
||||
hit_score = 1.0
|
||||
elif hit_score < args.threshold:
|
||||
logging.info(
|
||||
f"batch:{batch_idx}_{i} detect {hit_keyword} "
|
||||
f"in {key} from {start} to {end} frame. "
|
||||
f"but {hit_score} less than "
|
||||
f"{args.threshold}, Deactivated. ")
|
||||
elif args.min_frames > duration \
|
||||
or duration > args.max_frames:
|
||||
logging.info(
|
||||
f"batch:{batch_idx}_{i} detect {hit_keyword} "
|
||||
f"in {key} from {start} to {end} frame. "
|
||||
f"but {duration} beyond "
|
||||
f"range({args.min_frames}~{args.max_frames}), "
|
||||
f"Deactivated. ")
|
||||
if not activated:
|
||||
fout.write('{} rejected\n'.format(key))
|
||||
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
|
||||
|
||||
if batch_idx % 10 == 0:
|
||||
print('Progress batch {}'.format(batch_idx))
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
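Note on the decoding loop in `main()` above: it performs CTC prefix beam search in the probability domain, tracking for every prefix the probability of ending in blank (`pb`) and in non-blank (`pnb`), with a first prune over tokens and a second prune over paths. The sketch below shows only that recursion, under simplifying assumptions: it omits the per-token `nodes` (frame and probability) bookkeeping used for timestamps and scores, and the names `toy_prefix_beam_search` and `min_prob` are hypothetical, not part of this commit.

```python
from collections import defaultdict


def toy_prefix_beam_search(frames, allowed, score_beam=3, path_beam=20,
                           min_prob=0.05, blank=0):
    """Simplified CTC prefix beam search (probability domain).

    frames:  list of per-frame probability lists, indexed by token id
    allowed: set of token ids that may be emitted (should contain blank)
    Returns [(prefix, pb + pnb), ...] sorted by score, best first.
    """
    hyps = [(tuple(), (1.0, 0.0))]  # prefix -> (pb, pnb)
    for probs in frames:
        # First prune: top-k tokens that are allowed and likely enough.
        top = sorted(range(len(probs)), key=lambda i: probs[i],
                     reverse=True)[:score_beam]
        top = [i for i in top if i in allowed and probs[i] > min_prob]
        if not top:
            continue  # nothing usable in this frame, keep hypotheses
        nxt = defaultdict(lambda: (0.0, 0.0))
        for s in top:
            ps = probs[s]
            for prefix, (pb, pnb) in hyps:
                last = prefix[-1] if prefix else None
                if s == blank:
                    n_pb, n_pnb = nxt[prefix]
                    nxt[prefix] = (n_pb + (pb + pnb) * ps, n_pnb)
                elif s == last:
                    if pnb > 0.0:
                        # Repeat without blank in between: prefix unchanged.
                        n_pb, n_pnb = nxt[prefix]
                        nxt[prefix] = (n_pb, n_pnb + pnb * ps)
                    if pb > 0.0:
                        # Repeat after a blank: prefix grows by one token.
                        n_pb, n_pnb = nxt[prefix + (s,)]
                        nxt[prefix + (s,)] = (n_pb, n_pnb + pb * ps)
                else:
                    n_pb, n_pnb = nxt[prefix + (s,)]
                    nxt[prefix + (s,)] = (n_pb, n_pnb + (pb + pnb) * ps)
        # Second prune: keep the best paths by total score.
        hyps = sorted(nxt.items(), key=lambda kv: kv[1][0] + kv[1][1],
                      reverse=True)[:path_beam]
    return [(prefix, pb + pnb) for prefix, (pb, pnb) in hyps]


if __name__ == "__main__":
    frames = [[0.1, 0.8, 0.1], [0.1, 0.8, 0.1],
              [0.8, 0.1, 0.1], [0.1, 0.1, 0.8]]
    print(toy_prefix_beam_search(frames, {0, 1, 2})[0][0])
```

Restricting `allowed` to the keyword tokens plus blank, as the script does with `keywords_idxset`, shrinks the search space to paths that can actually spell a keyword.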
@ -134,7 +134,8 @@ def main():
|
||||
output_dim = args.num_keywords
|
||||
|
||||
# Write model_dir/config.yaml for inference and export
|
||||
configs['model']['input_dim'] = input_dim
|
||||
if 'input_dim' not in configs['model']:
|
||||
configs['model']['input_dim'] = input_dim
|
||||
configs['model']['output_dim'] = output_dim
|
||||
if args.cmvn_file is not None:
|
||||
configs['model']['cmvn'] = {}
|
||||
@ -156,8 +157,16 @@ def main():
|
||||
# Try to export the model by script, if fails, we should refine
|
||||
# the code to satisfy the script export requirements
|
||||
if rank == 0:
|
||||
script_model = torch.jit.script(model)
|
||||
script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||
pass
|
||||
# TODO: for now the streaming FSMN cannot be exported to TorchScript,
# because the current FSMN modules use nn.Sequential with Tuple input.
|
||||
# the issue is in https://stackoverflow.com/questions/75714299/
|
||||
# pytorch-jit-script-error-when-sequential-container-
|
||||
# takes-a-tuple-input/76553450#76553450
|
||||
|
||||
# script_model = torch.jit.script(model)
|
||||
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||
executor = Executor()
|
||||
# If specify checkpoint, load some info from checkpoint
|
||||
if args.checkpoint is not None:
|
||||
|
||||
@ -162,6 +162,16 @@ def Dataset(data_list_file, conf,
|
||||
spec_aug_conf = conf.get('spec_aug_conf', {})
|
||||
dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
|
||||
|
||||
context_expansion = conf.get('context_expansion', False)
|
||||
if context_expansion:
|
||||
context_expansion_conf = conf.get('context_expansion_conf', {})
|
||||
dataset = Processor(dataset, processor.context_expansion,
|
||||
**context_expansion_conf)
|
||||
|
||||
frame_skip = conf.get('frame_skip', 1)
|
||||
if frame_skip > 1:
|
||||
dataset = Processor(dataset, processor.frame_skip, frame_skip)
|
||||
|
||||
if shuffle:
|
||||
shuffle_conf = conf.get('shuffle_conf', {})
|
||||
dataset = Processor(dataset, processor.shuffle, **shuffle_conf)
|
||||
|
||||
@ -263,6 +263,51 @@ def shuffle(data, shuffle_size=1000):
|
||||
for x in buf:
|
||||
yield x
|
||||
|
||||
def context_expansion(data, left=1, right=1):
|
||||
""" expand left and right frames
|
||||
Args:
|
||||
data: Iterable[{key, feat, label}]
|
||||
left (int): feature left context frames
|
||||
right (int): feature right context frames
|
||||
|
||||
Returns:
|
||||
data: Iterable[{key, feat, label}]
|
||||
"""
|
||||
for sample in data:
|
||||
index = 0
|
||||
feats = sample['feat']
|
||||
ctx_dim = feats.shape[0]
|
||||
ctx_frm = feats.shape[1] * (left + right + 1)
|
||||
feats_ctx = torch.zeros(ctx_dim, ctx_frm, dtype=torch.float32)
|
||||
for lag in range(-left, right + 1):
|
||||
feats_ctx[:, index:index + feats.shape[1]] = torch.roll(
|
||||
feats, -lag, 0)
|
||||
index = index + feats.shape[1]
|
||||
|
||||
# replication pad left margin
|
||||
for idx in range(left):
|
||||
for cpx in range(left - idx):
|
||||
feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1)
|
||||
* feats.shape[1]] = feats_ctx[left, :feats.shape[1]]
|
||||
|
||||
feats_ctx = feats_ctx[:feats_ctx.shape[0] - right]
|
||||
sample['feat'] = feats_ctx
|
||||
yield sample
|
||||
|
||||
|
||||
def frame_skip(data, skip_rate=1):
|
||||
""" skip frame
|
||||
Args:
|
||||
data: Iterable[{key, feat, label}]
|
||||
skip_rate (int): take every N-frames for model input
|
||||
|
||||
Returns:
|
||||
data: Iterable[{key, feat, label}]
|
||||
"""
|
||||
for sample in data:
|
||||
feats_skip = sample['feat'][::skip_rate, :]
|
||||
sample['feat'] = feats_skip
|
||||
yield sample
|
||||
|
||||
def batch(data, batch_size=16):
|
||||
""" Static batch the data by `batch_size`
|
||||
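Note on `context_expansion` above: each output frame is the concatenation of `left + right + 1` neighbouring frames, the left margin is filled by replicating the first frame, and the last `right` frames are dropped because their right context is not available. A minimal sketch of that behaviour on plain Python lists follows. The function name `expand_context` is hypothetical; the real processor works on torch tensors via `torch.roll`.

```python
def expand_context(frames, left=1, right=1):
    """Stack each frame with its left/right neighbours.

    Hypothetical list-based illustration:
    - frames is a list of feature vectors (lists of numbers)
    - indices before the first frame are replaced by frame 0
    - the last `right` frames are dropped (no right context yet)
    """
    expanded = []
    for i in range(len(frames) - right):
        window = []
        for lag in range(-left, right + 1):
            # max(..., 0) replicates the first frame on the left margin.
            window.extend(frames[max(i + lag, 0)])
        expanded.append(window)
    return expanded


if __name__ == "__main__":
    feats = [[0], [1], [2], [3]]
    for row in expand_context(feats, left=2, right=1):
        print(row)
```

With `T` input frames of dimension `D`, the output has `T - right` frames of dimension `D * (left + right + 1)`, which is why the model's `input_dim` has to account for the context window.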
@ -302,12 +347,24 @@ def padding(data):
|
||||
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
|
||||
sorted_feats = [sample[i]['feat'] for i in order]
|
||||
sorted_keys = [sample[i]['key'] for i in order]
|
||||
sorted_labels = torch.tensor([sample[i]['label'] for i in order],
|
||||
dtype=torch.int64)
|
||||
padded_feats = pad_sequence(sorted_feats,
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)
|
||||
|
||||
if isinstance(sample[0]['label'], int):
|
||||
padded_labels = torch.tensor([sample[i]['label'] for i in order],
|
||||
dtype=torch.int32)
|
||||
label_lengths = torch.tensor([1 for i in order],
|
||||
dtype=torch.int32)
|
||||
else:
|
||||
sorted_labels = [
|
||||
torch.tensor(sample[i]['label'], dtype=torch.int32) for i in order
|
||||
]
|
||||
label_lengths = torch.tensor([len(sample[i]['label']) for i in order],
|
||||
dtype=torch.int32)
|
||||
padded_labels = pad_sequence(
|
||||
sorted_labels, batch_first=True, padding_value=-1)
|
||||
yield (sorted_keys, padded_feats, padded_labels, feats_lengths, label_lengths)
|
||||
|
||||
|
||||
def add_reverb(data, reverb_source, aug_prob):
|
||||
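Note on the `padding` change above: for CTC training the labels are token sequences of different lengths, so they are padded with -1 and returned together with their true lengths. The sketch below illustrates this on plain Python lists. The helper name `pad_labels` is hypothetical; the actual code uses `pad_sequence` on tensors and keeps a separate branch for integer (single-class) labels.

```python
def pad_labels(labels, padding_value=-1):
    """Pad variable-length label sequences to a common length.

    Hypothetical helper returning (padded_labels, label_lengths).
    """
    lengths = [len(label) for label in labels]
    max_len = max(lengths, default=0)
    padded = [
        list(label) + [padding_value] * (max_len - len(label))
        for label in labels
    ]
    return padded, lengths


if __name__ == "__main__":
    padded, lengths = pad_labels([[3, 4, 5], [7], [1, 2]])
    print(padded)
    print(lengths)
```

Returning the lengths alongside the padded batch lets the loss ignore the padding positions.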
@ -320,6 +377,8 @@ def add_reverb(data, reverb_source, aug_prob):
|
||||
rir_io = io.BytesIO(rir_data)
|
||||
_, rir_audio = wavfile.read(rir_io)
|
||||
rir_audio = rir_audio.astype(np.float32)
|
||||
if len(rir_audio.shape) > 1:
|
||||
rir_audio = rir_audio[:, 0]
|
||||
rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2))
|
||||
out_audio = signal.convolve(audio, rir_audio,
|
||||
mode='full')[:audio_len]
|
||||
@ -348,6 +407,8 @@ def add_noise(data, noise_source, aug_prob):
|
||||
snr_range = [0, 15]
|
||||
_, noise_audio = wavfile.read(io.BytesIO(noise_data))
|
||||
noise_audio = noise_audio.astype(np.float32)
|
||||
if len(noise_audio.shape) > 1:
|
||||
noise_audio = noise_audio[:, 0]
|
||||
if noise_audio.shape[0] > audio_len:
|
||||
start = random.randint(0, noise_audio.shape[0] - audio_len)
|
||||
noise_audio = noise_audio[start:start + audio_len]
|
||||
|
||||
558
wekws/model/fsmn.py
Normal file
@ -0,0 +1,558 @@
|
||||
'''
|
||||
FSMN implementation.
|
||||
|
||||
Copyright: 2022-03-09 yueyue.nyy
|
||||
2023 Jing Du
|
||||
'''
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def toKaldiMatrix(np_mat):
|
||||
np.set_printoptions(threshold=np.inf, linewidth=np.nan)
|
||||
out_str = str(np_mat)
|
||||
out_str = out_str.replace('[', '')
|
||||
out_str = out_str.replace(']', '')
|
||||
return '[ %s ]\n' % out_str
|
||||
|
||||
|
||||
def printTensor(torch_tensor):
|
||||
re_str = ''
|
||||
x = torch_tensor.detach().squeeze().numpy()
|
||||
re_str += toKaldiMatrix(x)
|
||||
# re_str += '<!EndOfComponent>\n'
|
||||
print(re_str)
|
||||
|
||||
|
||||
class LinearTransform(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(LinearTransform, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.linear = nn.Linear(input_dim, output_dim, bias=False)
|
||||
self.quant = torch.quantization.QuantStub()
|
||||
self.dequant = torch.quantization.DeQuantStub()
|
||||
|
||||
def forward(self,
|
||||
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||
if isinstance(input, tuple):
|
||||
input, in_cache = input
|
||||
else:
|
||||
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||
output = self.quant(input)
|
||||
output = self.linear(output)
|
||||
output = self.dequant(output)
|
||||
|
||||
return (output, in_cache)
|
||||
|
||||
def to_kaldi_net(self):
|
||||
re_str = ''
|
||||
re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
|
||||
self.input_dim)
|
||||
re_str += '<LearnRateCoef> 1\n'
|
||||
|
||||
linear_weights = self.state_dict()['linear.weight']
|
||||
x = linear_weights.squeeze().numpy()
|
||||
re_str += toKaldiMatrix(x)
|
||||
# re_str += '<!EndOfComponent>\n'
|
||||
|
||||
return re_str
|
||||
|
||||
def to_pytorch_net(self, fread):
|
||||
linear_line = fread.readline()
|
||||
linear_split = linear_line.strip().split()
|
||||
assert len(linear_split) == 3
|
||||
assert linear_split[0] == '<LinearTransform>'
|
||||
self.output_dim = int(linear_split[1])
|
||||
self.input_dim = int(linear_split[2])
|
||||
|
||||
learn_rate_line = fread.readline()
|
||||
assert learn_rate_line.find('LearnRateCoef') != -1
|
||||
|
||||
self.linear.reset_parameters()
|
||||
|
||||
# linear_weights = self.state_dict()['linear.weight']
|
||||
# print(linear_weights.shape)
|
||||
new_weights = torch.zeros((self.output_dim, self.input_dim),
|
||||
dtype=torch.float32)
|
||||
for i in range(self.output_dim):
|
||||
line = fread.readline()
|
||||
splits = line.strip().strip('[]').strip().split()
|
||||
assert len(splits) == self.input_dim
|
||||
cols = torch.tensor([float(item) for item in splits],
|
||||
dtype=torch.float32)
|
||||
new_weights[i, :] = cols
|
||||
|
||||
self.linear.weight.data = new_weights
|
||||
|
||||
|
||||
class AffineTransform(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(AffineTransform, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.linear = nn.Linear(input_dim, output_dim)
|
||||
self.quant = torch.quantization.QuantStub()
|
||||
self.dequant = torch.quantization.DeQuantStub()
|
||||
|
||||
def forward(self,
|
||||
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||
if isinstance(input, tuple):
|
||||
input, in_cache = input
|
||||
else:
|
||||
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||
output = self.quant(input)
|
||||
output = self.linear(output)
|
||||
output = self.dequant(output)
|
||||
|
||||
return (output, in_cache)
|
||||
|
||||
def to_kaldi_net(self):
|
||||
re_str = ''
|
||||
re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
|
||||
self.input_dim)
|
||||
re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
|
||||
|
||||
linear_weights = self.state_dict()['linear.weight']
|
||||
x = linear_weights.squeeze().numpy()
|
||||
re_str += toKaldiMatrix(x)
|
||||
|
||||
linear_bias = self.state_dict()['linear.bias']
|
||||
x = linear_bias.squeeze().numpy()
|
||||
re_str += toKaldiMatrix(x)
|
||||
# re_str += '<!EndOfComponent>\n'
|
||||
|
||||
return re_str
|
||||
|
||||
def to_pytorch_net(self, fread):
|
||||
affine_line = fread.readline()
|
||||
affine_split = affine_line.strip().split()
|
||||
assert len(affine_split) == 3
|
||||
assert affine_split[0] == '<AffineTransform>'
|
||||
self.output_dim = int(affine_split[1])
|
||||
self.input_dim = int(affine_split[2])
|
||||
print('AffineTransform output/input dim: %d %d' %
|
||||
(self.output_dim, self.input_dim))
|
||||
|
||||
learn_rate_line = fread.readline()
|
||||
assert learn_rate_line.find('LearnRateCoef') != -1
|
||||
|
||||
# linear_weights = self.state_dict()['linear.weight']
|
||||
# print(linear_weights.shape)
|
||||
self.linear.reset_parameters()
|
||||
|
||||
new_weights = torch.zeros((self.output_dim, self.input_dim),
|
||||
dtype=torch.float32)
|
||||
for i in range(self.output_dim):
|
||||
line = fread.readline()
|
||||
splits = line.strip().strip('[]').strip().split()
|
||||
assert len(splits) == self.input_dim
|
||||
cols = torch.tensor([float(item) for item in splits],
|
||||
dtype=torch.float32)
|
||||
new_weights[i, :] = cols
|
||||
|
||||
self.linear.weight.data = new_weights
|
||||
|
||||
# linear_bias = self.state_dict()['linear.bias']
|
||||
# print(linear_bias.shape)
|
||||
bias_line = fread.readline()
|
||||
splits = bias_line.strip().strip('[]').strip().split()
|
||||
assert len(splits) == self.output_dim
|
||||
new_bias = torch.tensor([float(item) for item in splits],
|
||||
dtype=torch.float32)
|
||||
|
||||
self.linear.bias.data = new_bias
|
||||
|
||||
|
||||
class FSMNBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
lorder=None,
|
||||
rorder=None,
|
||||
lstride=1,
|
||||
rstride=1,
|
||||
):
|
||||
super(FSMNBlock, self).__init__()
|
||||
|
||||
self.dim = input_dim
|
||||
|
||||
if lorder is None:
|
||||
return
|
||||
|
||||
self.lorder = lorder
|
||||
self.rorder = rorder
|
||||
self.lstride = lstride
|
||||
self.rstride = rstride
|
||||
|
||||
self.conv_left = nn.Conv2d(
|
||||
self.dim,
|
||||
self.dim, [lorder, 1],
|
||||
dilation=[lstride, 1],
|
||||
groups=self.dim,
|
||||
bias=False)
|
||||
|
||||
if rorder > 0:
|
||||
self.conv_right = nn.Conv2d(
|
||||
self.dim,
|
||||
self.dim, [rorder, 1],
|
||||
dilation=[rstride, 1],
|
||||
groups=self.dim,
|
||||
bias=False)
|
||||
else:
|
||||
self.conv_right = None
|
||||
|
||||
self.quant = torch.quantization.QuantStub()
|
||||
self.dequant = torch.quantization.DeQuantStub()
|
||||
|
||||
def forward(self,
|
||||
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||
if isinstance(input, tuple):
|
||||
input, in_cache = input
|
||||
else:
|
||||
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||
x = torch.unsqueeze(input, 1)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
|
||||
if in_cache is None or len(in_cache) == 0:
|
||||
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride
|
||||
+ self.rorder * self.rstride, 0])
|
||||
else:
|
||||
in_cache = in_cache.to(x_per.device)
|
||||
x_pad = torch.cat((in_cache, x_per), dim=2)
|
||||
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride
|
||||
+ self.rorder * self.rstride):, :]
|
||||
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
|
||||
y_left = self.quant(y_left)
|
||||
y_left = self.conv_left(y_left)
|
||||
y_left = self.dequant(y_left)
|
||||
out = x_pad[:, :, (self.lorder - 1) * self.lstride: -self.rorder *
|
||||
self.rstride, :] + y_left
|
||||
|
||||
if self.conv_right is not None:
|
||||
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
|
||||
y_right = x_pad[:, :, -(
|
||||
x_per.size(2) + self.rorder * self.rstride):, :]
|
||||
y_right = y_right[:, :, self.rstride:, :]
|
||||
y_right = self.quant(y_right)
|
||||
y_right = self.conv_right(y_right)
|
||||
y_right = self.dequant(y_right)
|
||||
out += y_right
|
||||
|
||||
out_per = out.permute(0, 3, 2, 1)
|
||||
output = out_per.squeeze(1)
|
||||
|
||||
return (output, in_cache)
|
||||
|
||||
def to_kaldi_net(self):
|
||||
re_str = ''
|
||||
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
|
||||
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \
|
||||
'<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
|
||||
1, self.lorder, self.rorder, self.lstride, self.rstride)
|
||||
|
||||
# print(self.conv_left.weight,self.conv_right.weight)
|
||||
lfiters = self.state_dict()['conv_left.weight']
|
||||
x = np.flipud(lfiters.squeeze().numpy().T)
|
||||
re_str += toKaldiMatrix(x)
|
||||
|
||||
if self.conv_right is not None:
|
||||
rfiters = self.state_dict()['conv_right.weight']
|
||||
x = (rfiters.squeeze().numpy().T)
|
||||
re_str += toKaldiMatrix(x)
|
||||
# re_str += '<!EndOfComponent>\n'
|
||||
|
||||
return re_str
|
||||
|
||||
def to_pytorch_net(self, fread):
|
||||
fsmn_line = fread.readline()
|
||||
fsmn_split = fsmn_line.strip().split()
|
||||
assert len(fsmn_split) == 3
|
||||
assert fsmn_split[0] == '<Fsmn>'
|
||||
self.dim = int(fsmn_split[1])
|
||||
|
||||
params_line = fread.readline()
|
||||
params_split = params_line.strip().strip('[]').strip().split()
|
||||
assert len(params_split) == 12
|
||||
assert params_split[0] == '<LearnRateCoef>'
|
||||
assert params_split[2] == '<LOrder>'
|
||||
self.lorder = int(params_split[3])
|
||||
assert params_split[4] == '<ROrder>'
|
||||
self.rorder = int(params_split[5])
|
||||
assert params_split[6] == '<LStride>'
|
||||
self.lstride = int(params_split[7])
|
||||
assert params_split[8] == '<RStride>'
|
||||
self.rstride = int(params_split[9])
|
||||
assert params_split[10] == '<MaxNorm>'
|
||||
|
||||
# lfilters = self.state_dict()['conv_left.weight']
|
||||
# print(lfilters.shape)
|
||||
print('read conv_left weight')
|
||||
new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1),
|
||||
dtype=torch.float32)
|
||||
for i in range(self.lorder):
|
||||
print('read conv_left weight -- %d' % i)
|
||||
line = fread.readline()
|
||||
splits = line.strip().strip('[]').strip().split()
|
||||
assert len(splits) == self.dim
|
||||
cols = torch.tensor([float(item) for item in splits],
|
||||
dtype=torch.float32)
|
||||
new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols
|
||||
|
||||
new_lfilters = torch.transpose(new_lfilters, 0, 2)
|
||||
# print(new_lfilters.shape)
|
||||
|
||||
self.conv_left.reset_parameters()
|
||||
self.conv_left.weight.data = new_lfilters
|
||||
# print(self.conv_left.weight.shape)
|
||||
|
||||
if self.rorder > 0:
|
||||
# rfilters = self.state_dict()['conv_right.weight']
|
||||
# print(rfilters.shape)
|
||||
print('read conv_right weight')
|
||||
new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1),
|
||||
dtype=torch.float32)
|
||||
line = fread.readline()
|
||||
for i in range(self.rorder):
|
||||
print('read conv_right weight -- %d' % i)
|
||||
line = fread.readline()
|
||||
splits = line.strip().strip('[]').strip().split()
|
||||
assert len(splits) == self.dim
|
||||
cols = torch.tensor([float(item) for item in splits],
|
||||
dtype=torch.float32)
|
||||
new_rfilters[i, 0, :, 0] = cols
|
||||
|
||||
new_rfilters = torch.transpose(new_rfilters, 0, 2)
|
||||
# print(new_rfilters.shape)
|
||||
self.conv_right.reset_parameters()
|
||||
self.conv_right.weight.data = new_rfilters
|
||||
# print(self.conv_right.weight.shape)
|
||||
|
||||
|
||||
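`FSMNBlock.forward` carries `in_cache` so that chunk-by-chunk processing sees the same left context as full-utterance processing. The idea can be sketched in plain Python with scalar frames. This sketch is a simplification under stated assumptions: it models only the left taps (the right-order lookahead is omitted), and `memory_filter` is a hypothetical name, not part of the commit.

```python
from typing import List, Tuple


def memory_filter(chunk: List[int], taps: List[int], stride: int,
                  cache: List[int]) -> Tuple[List[int], List[int]]:
    """Apply y[t] = x[t] + sum_k taps[k] * x[t - (lorder - 1 - k) * stride].

    `cache` holds the last (lorder - 1) * stride input frames seen so far.
    """
    lorder = len(taps)
    context = (lorder - 1) * stride
    if not cache:
        # First chunk: pad with zeros on the time axis.
        cache = [0] * context
    padded = cache + chunk
    out = []
    for t in range(len(chunk)):
        # padded[t + context] is the current frame x[t].
        acc = padded[t + context]
        for k, w in enumerate(taps):
            acc += w * padded[t + k * stride]
        out.append(acc)
    # Keep only the frames the next chunk still needs as left context.
    new_cache = padded[len(padded) - context:] if context > 0 else []
    return out, new_cache


frames = list(range(1, 11))
taps = [1, 2, 3]  # lorder = 3, so the context is (3 - 1) * 2 = 4 frames

full, _ = memory_filter(frames, taps, 2, [])

streamed, cache = [], []
for start in range(0, len(frames), 4):
    out, cache = memory_filter(frames[start:start + 4], taps, 2, cache)
    streamed.extend(out)

print(streamed == full)
```

Because the cache always holds exactly the frames that zero padding would otherwise replace, the streamed output matches the full-sequence output frame for frame.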
class RectifiedLinear(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(RectifiedLinear, self).__init__()
|
||||
self.dim = input_dim
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
def forward(self,
|
||||
input: Tuple[torch.Tensor, torch.Tensor]):
|
||||
if isinstance(input, tuple):
|
||||
input, in_cache = input
|
||||
else:
|
||||
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||
out = self.relu(input)
|
||||
# out = self.dropout(out)
|
||||
return (out, in_cache)
|
||||
|
||||
def to_kaldi_net(self):
|
||||
re_str = ''
|
||||
re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
|
||||
# re_str += '<!EndOfComponent>\n'
|
||||
return re_str
|
||||
|
||||
# re_str = ''
|
||||
# re_str += '<ParametricRelu> %d %d\n' % (self.dim, self.dim)
|
||||
# re_str += '<AlphaLearnRateCoef> 0 <BetaLearnRateCoef> 0\n'
|
||||
# re_str += toKaldiMatrix(np.ones((self.dim), dtype = 'int32'))
|
||||
# re_str += toKaldiMatrix(np.zeros((self.dim), dtype = 'int32'))
|
||||
# re_str += '<!EndOfComponent>\n'
|
||||
# return re_str
|
||||
|
||||
def to_pytorch_net(self, fread):
|
||||
line = fread.readline()
|
||||
splits = line.strip().split()
|
||||
assert len(splits) == 3
|
||||
assert splits[0] == '<RectifiedLinear>'
|
||||
assert int(splits[1]) == int(splits[2])
|
||||
assert int(splits[1]) == self.dim
|
||||
self.dim = int(splits[1])
|
||||
|
||||
|
||||
def _build_repeats(
|
||||
fsmn_layers: int,
|
||||
linear_dim: int,
|
||||
proj_dim: int,
|
||||
lorder: int,
|
||||
rorder: int,
|
||||
lstride=1,
|
||||
rstride=1,
|
||||
):
|
||||
repeats = [
|
||||
nn.Sequential(
|
||||
LinearTransform(linear_dim, proj_dim),
|
||||
FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1),
|
||||
AffineTransform(proj_dim, linear_dim),
|
||||
RectifiedLinear(linear_dim, linear_dim))
|
||||
for i in range(fsmn_layers)
|
||||
]
|
||||
|
||||
return nn.Sequential(*repeats)
|
||||
|
||||
|
||||
class FSMN(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
input_affine_dim: int,
|
||||
fsmn_layers: int,
|
||||
linear_dim: int,
|
||||
proj_dim: int,
|
||||
lorder: int,
|
||||
rorder: int,
|
||||
lstride: int,
|
||||
rstride: int,
|
||||
output_affine_dim: int,
|
||||
output_dim: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_dim: input dimension
|
||||
input_affine_dim: input affine layer dimension
|
||||
fsmn_layers: no. of fsmn units
|
||||
linear_dim: fsmn input dimension
|
||||
proj_dim: fsmn projection dimension
|
||||
lorder: fsmn left order
|
||||
rorder: fsmn right order
|
||||
lstride: fsmn left stride
|
||||
rstride: fsmn right stride
|
||||
output_affine_dim: output affine layer dimension
|
||||
output_dim: output dimension
|
||||
"""
|
||||
super(FSMN, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.input_affine_dim = input_affine_dim
|
||||
self.fsmn_layers = fsmn_layers
|
||||
self.linear_dim = linear_dim
|
||||
self.proj_dim = proj_dim
|
||||
self.lorder = lorder
|
||||
self.rorder = rorder
|
||||
self.lstride = lstride
|
||||
self.rstride = rstride
|
||||
self.output_affine_dim = output_affine_dim
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.padding = (self.lorder - 1) * self.lstride \
|
||||
+ self.rorder * self.rstride
|
||||
|
||||
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
|
||||
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
|
||||
self.relu = RectifiedLinear(linear_dim, linear_dim)
|
||||
|
||||
self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder,
|
||||
rorder, lstride, rstride)
|
||||
|
||||
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
|
||||
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
|
||||
# self.softmax = nn.Softmax(dim = -1)
|
||||
|
||||
def fuse_modules(self):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
input (torch.Tensor): Input tensor (B, T, D)
|
||||
in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
|
||||
"""
|
||||
|
||||
if in_cache is None or len(in_cache) == 0:
|
||||
in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float)
|
||||
for _ in range(len(self.fsmn))]
|
||||
input = (input, in_cache)
|
||||
x1 = self.in_linear1(input)
|
||||
x2 = self.in_linear2(x1)
|
||||
x3 = self.relu(x2)
|
||||
# x4 = self.fsmn(x3)
|
||||
x4, _ = x3
|
||||
for layer, module in enumerate(self.fsmn):
|
||||
x4, in_cache[layer] = module((x4, in_cache[layer]))
|
||||
x5 = self.out_linear1(x4)
|
||||
x6 = self.out_linear2(x5)
|
||||
# x7 = self.softmax(x6)
|
||||
x7, _ = x6
|
||||
# return x7, None
|
||||
return x7, in_cache
|
||||
|
||||
def to_kaldi_net(self):
|
||||
re_str = ''
|
||||
re_str += '<Nnet>\n'
|
||||
re_str += self.in_linear1.to_kaldi_net()
|
||||
re_str += self.in_linear2.to_kaldi_net()
|
||||
re_str += self.relu.to_kaldi_net()
|
||||
|
||||
for fsmn in self.fsmn:
|
||||
re_str += fsmn[0].to_kaldi_net()
|
||||
re_str += fsmn[1].to_kaldi_net()
|
||||
re_str += fsmn[2].to_kaldi_net()
|
||||
re_str += fsmn[3].to_kaldi_net()
|
||||
|
||||
re_str += self.out_linear1.to_kaldi_net()
|
||||
re_str += self.out_linear2.to_kaldi_net()
|
||||
re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
|
||||
# re_str += '<!EndOfComponent>\n'
|
||||
re_str += '</Nnet>\n'
|
||||
|
||||
return re_str
|
||||
|
||||
def to_pytorch_net(self, kaldi_file):
|
||||
with open(kaldi_file, 'r', encoding='utf8') as fread:
|
||||
|
||||
nnet_start_line = fread.readline()
|
||||
assert nnet_start_line.strip() == '<Nnet>'
|
||||
|
||||
self.in_linear1.to_pytorch_net(fread)
|
||||
self.in_linear2.to_pytorch_net(fread)
|
||||
self.relu.to_pytorch_net(fread)
|
||||
|
||||
for fsmn in self.fsmn:
|
||||
fsmn[0].to_pytorch_net(fread)
|
||||
fsmn[1].to_pytorch_net(fread)
|
||||
fsmn[2].to_pytorch_net(fread)
|
||||
fsmn[3].to_pytorch_net(fread)
|
||||
|
||||
self.out_linear1.to_pytorch_net(fread)
|
||||
self.out_linear2.to_pytorch_net(fread)
|
||||
|
||||
softmax_line = fread.readline()
|
||||
softmax_split = softmax_line.strip().split()
|
||||
assert softmax_split[0].strip() == '<Softmax>'
|
||||
assert int(softmax_split[1]) == self.output_dim
|
||||
assert int(softmax_split[2]) == self.output_dim
|
||||
# '<!EndOfComponent>\n'
|
||||
|
||||
nnet_end_line = fread.readline()
|
||||
assert nnet_end_line.strip() == '</Nnet>'
|
||||
fread.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
|
||||
print(fsmn)
|
||||
|
||||
num_params = sum(p.numel() for p in fsmn.parameters())
|
||||
print('the number of model params: {}'.format(num_params))
|
||||
x = torch.zeros(128, 200, 400) # batch-size * time * dim
|
||||
y, _ = fsmn(x) # batch-size * time * dim
|
||||
print('input shape: {}'.format(x.shape))
|
||||
print('output shape: {}'.format(y.shape))
|
||||
|
||||
print(fsmn.to_kaldi_net())
|
||||
@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
# 2023 Jing Du
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -25,14 +26,15 @@ from wekws.model.subsampling import (LinearSubsampling1, Conv1dSubsampling1,
|
||||
NoSubsampling)
|
||||
from wekws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
||||
from wekws.model.mdtc import MDTC
|
||||
from wekws.utils.cmvn import load_cmvn
|
||||
from wekws.utils.cmvn import load_cmvn, load_kaldi_cmvn
|
||||
from wekws.model.fsmn import FSMN
|
||||
|
||||
|
||||
class KWSModel(nn.Module):
|
||||
"""Our model consists of four parts:
|
||||
1. global_cmvn: Optional, (idim, idim)
|
||||
2. preprocessing: feature dimension projection, (idim, hdim)
|
||||
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
|
||||
3. backbone: backbone of the whole network, (hdim, hdim)
|
||||
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
||||
5. activation:
|
||||
nn.Sigmoid for wakeup word
|
||||
@ -72,6 +74,20 @@ class KWSModel(nn.Module):
|
||||
x = self.activation(x)
|
||||
return x, out_cache
|
||||
|
||||
def forward_softmax(self,
|
||||
x: torch.Tensor,
|
||||
in_cache: torch.Tensor = torch.zeros(
|
||||
0, 0, 0, dtype=torch.float)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.global_cmvn is not None:
|
||||
x = self.global_cmvn(x)
|
||||
x = self.preprocessing(x)
|
||||
x, out_cache = self.backbone(x, in_cache)
|
||||
x = self.classifier(x)
|
||||
x = self.activation(x)
|
||||
x = x.softmax(2)
|
||||
return x, out_cache
|
||||
|
||||
def fuse_modules(self):
|
||||
self.preprocessing.fuse_modules()
|
||||
self.backbone.fuse_modules()
|
||||
@ -80,7 +96,10 @@ class KWSModel(nn.Module):
|
||||
def init_model(configs):
|
||||
cmvn = configs.get('cmvn', {})
|
||||
if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None:
|
||||
mean, istd = load_cmvn(cmvn['cmvn_file'])
|
||||
if "kaldi" in cmvn['cmvn_file']:
|
||||
mean, istd = load_kaldi_cmvn(cmvn['cmvn_file'])
|
||||
else:
|
||||
mean, istd = load_cmvn(cmvn['cmvn_file'])
|
||||
global_cmvn = GlobalCMVN(
|
||||
torch.from_numpy(mean).float(),
|
||||
torch.from_numpy(istd).float(),
|
||||
@ -135,6 +154,20 @@ def init_model(configs):
|
||||
hidden_dim,
|
||||
kernel_size,
|
||||
causal=causal)
|
||||
elif backbone_type == 'fsmn':
|
||||
input_affine_dim = configs['backbone']['input_affine_dim']
|
||||
num_layers = configs['backbone']['num_layers']
|
||||
linear_dim = configs['backbone']['linear_dim']
|
||||
proj_dim = configs['backbone']['proj_dim']
|
||||
left_order = configs['backbone']['left_order']
|
||||
right_order = configs['backbone']['right_order']
|
||||
left_stride = configs['backbone']['left_stride']
|
||||
right_stride = configs['backbone']['right_stride']
|
||||
output_affine_dim = configs['backbone']['output_affine_dim']
|
||||
backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim,
|
||||
proj_dim, left_order, right_order, left_stride,
|
||||
right_stride, output_affine_dim, output_dim)
|
||||
|
||||
else:
|
||||
print('Unknown body type {}'.format(backbone_type))
|
||||
sys.exit(1)
|
||||
@ -154,6 +187,8 @@ def init_model(configs):
|
||||
# last means we use last frame to do backpropagation, so the model
|
||||
# can be inferred in a streaming fashion
|
||||
classifier = LastClassifier(classifier_base)
|
||||
elif classifier_type == 'identity':
|
||||
classifier = nn.Identity()
|
||||
else:
|
||||
print('Unknown classifier type {}'.format(classifier_type))
|
||||
sys.exit(1)
|
||||
@ -162,6 +197,17 @@ def init_model(configs):
|
||||
classifier = LinearClassifier(hidden_dim, output_dim)
|
||||
activation = nn.Sigmoid()
|
||||
|
||||
# Here we add a possible "activation_type",
|
||||
# so that other activation functions can be chosen.
|
||||
# We use nn.Identity just for CTC loss
|
||||
if "activation" in configs:
|
||||
activation_type = configs["activation"]["type"]
|
||||
if activation_type == 'identity':
|
||||
activation = nn.Identity()
|
||||
else:
|
||||
print('Unknown activation type {}'.format(activation_type))
|
||||
sys.exit(1)
|
||||
|
||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||
preprocessing, backbone, classifier, activation)
|
||||
return kws_model
|
||||
|
||||
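The `fsmn` branch of `init_model` reads nine keys from `configs['backbone']`, and the `activation` section switches the output to `nn.Identity` for CTC training. A hedged sketch of a matching config fragment, written as a Python dict, is shown below. The numeric values simply mirror the `FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)` smoke test and are not tuned settings; the `'type'` key under `backbone` and the helper `missing_fsmn_keys` are assumptions made for illustration.

```python
from typing import Dict, List

configs = {
    'backbone': {
        'type': 'fsmn',  # assumed key; how backbone_type is read is not shown
        'input_affine_dim': 140,
        'num_layers': 4,
        'linear_dim': 250,
        'proj_dim': 128,
        'left_order': 10,
        'right_order': 2,
        'left_stride': 1,
        'right_stride': 1,
        'output_affine_dim': 140,
    },
    # An identity activation leaves the classifier output unchanged.
    'activation': {'type': 'identity'},
}

# Keys that the fsmn branch of init_model looks up in configs['backbone'].
FSMN_KEYS = [
    'input_affine_dim', 'num_layers', 'linear_dim', 'proj_dim',
    'left_order', 'right_order', 'left_stride', 'right_stride',
    'output_affine_dim',
]


def missing_fsmn_keys(cfg: Dict) -> List[str]:
    """Return the backbone keys that init_model would fail to find."""
    backbone = cfg.get('backbone', {})
    return [key for key in FSMN_KEYS if key not in backbone]


print(missing_fsmn_keys(configs))
```

Checking the keys up front gives a readable error list instead of a `KeyError` on the first missing entry.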
@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
# 2023 Jing Du
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -13,7 +14,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import math
|
||||
import sys
|
||||
import torch.nn.functional as F
|
||||
from collections import defaultdict
|
||||
from typing import List, Tuple
|
||||
|
||||
from wekws.utils.mask import padding_mask
|
||||
|
||||
@ -93,6 +98,65 @@ def acc_frame(
|
||||
correct = pred.eq(target.long().view_as(pred)).sum().item()
|
||||
return correct * 100.0 / logits.size(0)
|
||||
|
||||
def acc_utterance(logits: torch.Tensor, target: torch.Tensor,
|
||||
logits_length: torch.Tensor, target_length: torch.Tensor):
|
||||
if logits is None:
|
||||
return 0
|
||||
|
||||
logits = logits.softmax(2) # (1, maxlen, vocab_size)
|
||||
logits = logits.cpu()
|
||||
target = target.cpu()
|
||||
|
||||
total_word = 0
|
||||
total_ins = 0
|
||||
total_sub = 0
|
||||
total_del = 0
|
||||
calculator = Calculator()
|
||||
for i in range(logits.size(0)):
|
||||
score = logits[i][:logits_length[i]]
|
||||
hyps = ctc_prefix_beam_search(score, logits_length[i], None, 3, 5)
|
||||
lab = [str(item) for item in target[i][:target_length[i]].tolist()]
|
||||
rec = []
|
||||
if len(hyps) > 0:
|
||||
rec = [str(item) for item in hyps[0][0]]
|
||||
result = calculator.calculate(lab, rec)
|
||||
# print(f'result:{result}')
|
||||
if result['all'] != 0:
|
||||
total_word += result['all']
|
||||
total_ins += result['ins']
|
||||
total_sub += result['sub']
|
||||
total_del += result['del']
|
||||
|
||||
return float(total_word - total_ins - total_sub
|
||||
- total_del) * 100.0 / total_word
|
||||
|
||||
def ctc_loss(logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
logits_lengths: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
need_acc: bool = False):
|
||||
""" CTC Loss
|
||||
Args:
|
||||
logits: (B, L, D), L is the number of frames, D is the output dimension
|
||||
target: (B, U), U is the padded target length
|
||||
logits_lengths: (B)
|
||||
target_lengths: (B)
|
||||
Returns:
|
||||
(loss, acc): batch-mean CTC loss, and utterance accuracy (0.0 unless need_acc is set)
|
||||
"""
|
||||
|
||||
acc = 0.0
|
||||
if need_acc:
|
||||
acc = acc_utterance(logits, target, logits_lengths, target_lengths)
|
||||
|
||||
# logits: (B, L, D) -> (L, B, D)
|
||||
logits = logits.transpose(0, 1)
|
||||
logits = logits.log_softmax(2)
|
||||
loss = F.ctc_loss(
|
||||
logits, target, logits_lengths, target_lengths, reduction='sum')
|
||||
loss = loss / logits.size(1) # batch mean
|
||||
|
||||
return loss, acc
|
||||
|
||||
def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
|
||||
""" Cross Entropy Loss
|
||||
@ -114,12 +178,284 @@ def criterion(type: str,
|
||||
logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
min_duration: int = 0):
|
||||
target_lengths: torch.Tensor = None,
|
||||
min_duration: int = 0,
|
||||
validation: bool = False, ):
|
||||
if type == 'ce':
|
||||
loss, acc = cross_entropy(logits, target)
|
||||
return loss, acc
|
||||
elif type == 'max_pooling':
|
||||
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
|
||||
return loss, acc
|
||||
elif type == 'ctc':
|
||||
loss, acc = ctc_loss(
|
||||
logits, target, lengths, target_lengths, validation)
|
||||
return loss, acc
|
||||
else:
|
||||
sys.exit(1)
|
||||
|
||||
def ctc_prefix_beam_search(
|
||||
logits: torch.Tensor,
|
||||
logits_lengths: torch.Tensor,
|
||||
keywords_tokenset: set = None,
|
||||
score_beam_size: int = 3,
|
||||
path_beam_size: int = 20,
|
||||
) -> Tuple[List[List[int]], torch.Tensor]:
|
||||
""" CTC prefix beam search inner implementation
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): (max_len, vocab_size), per-frame probabilities (softmax already applied)
|
||||
logits_lengths (torch.Tensor): (1, )
|
||||
keywords_tokenset (set): token set for filtering score
|
||||
score_beam_size (int): beam size for score
|
||||
path_beam_size (int): beam size for path
|
||||
|
||||
Returns:
|
||||
List[Tuple]: n-best hypotheses, each a (prefix, score, nodes) tuple
|
||||
"""
|
||||
maxlen = logits.size(0)
|
||||
# ctc_probs = logits.softmax(1) # (1, maxlen, vocab_size)
|
||||
ctc_probs = logits
|
||||
|
||||
cur_hyps = [(tuple(), (1.0, 0.0, []))]
|
||||
|
||||
# 2. CTC beam search step by step
|
||||
for t in range(0, maxlen):
|
||||
probs = ctc_probs[t] # (vocab_size,)
|
||||
# key: prefix, value: (pb, pnb, nodes), default value (0.0, 0.0, [])
|
||||
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
|
||||
|
||||
# 2.1 First beam prune: select topk best
|
||||
top_k_probs, top_k_index = probs.topk(
|
||||
score_beam_size) # (score_beam_size,)
|
||||
|
||||
# filter prob score that is too small
|
||||
filter_probs = []
|
||||
filter_index = []
|
||||
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
|
||||
if keywords_tokenset is not None:
|
||||
if prob > 0.05 and idx in keywords_tokenset:
|
||||
filter_probs.append(prob)
|
||||
filter_index.append(idx)
|
||||
else:
|
||||
if prob > 0.05:
|
||||
filter_probs.append(prob)
|
||||
filter_index.append(idx)
|
||||
|
||||
if len(filter_index) == 0:
|
||||
continue
|
||||
|
||||
for s in filter_index:
|
||||
ps = probs[s].item()
|
||||
|
||||
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
|
||||
last = prefix[-1] if len(prefix) > 0 else None
|
||||
if s == 0: # blank
|
||||
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||||
n_pb = n_pb + pb * ps + pnb * ps
|
||||
nodes = cur_nodes.copy()
|
||||
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||||
elif s == last:
|
||||
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
|
||||
# Update *ss -> *s;
|
||||
n_pb, n_pnb, nodes = next_hyps[prefix]
|
||||
n_pnb = n_pnb + pnb * ps
|
||||
nodes = cur_nodes.copy()
|
||||
if ps > nodes[-1]['prob']: # update frame and prob
|
||||
nodes[-1]['prob'] = ps
|
||||
nodes[-1]['frame'] = t
|
||||
next_hyps[prefix] = (n_pb, n_pnb, nodes)
|
||||
|
||||
if not math.isclose(pb, 0.0, abs_tol=0.000001):
|
||||
# Update *s-s -> *ss, - is for blank
|
||||
n_prefix = prefix + (s, )
|
||||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||
n_pnb = n_pnb + pb * ps
|
||||
nodes = cur_nodes.copy()
|
||||
nodes.append(dict(token=s, frame=t,
|
||||
prob=ps)) # to record token prob
|
||||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||
else:
|
||||
n_prefix = prefix + (s, )
|
||||
n_pb, n_pnb, nodes = next_hyps[n_prefix]
|
||||
if nodes:
|
||||
if ps > nodes[-1]['prob']: # update frame and prob
|
||||
# nodes[-1]['prob'] = ps
|
||||
# nodes[-1]['frame'] = t
|
||||
# avoid change other beam which has this node.
|
||||
nodes.pop()
|
||||
nodes.append(dict(token=s, frame=t, prob=ps))
|
||||
else:
|
||||
nodes = cur_nodes.copy()
|
||||
nodes.append(dict(token=s, frame=t,
|
||||
prob=ps)) # to record token prob
|
||||
n_pnb = n_pnb + pb * ps + pnb * ps
|
||||
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
|
||||
|
||||
# 2.2 Second beam prune
|
||||
next_hyps = sorted(
|
||||
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)
|
||||
|
||||
cur_hyps = next_hyps[:path_beam_size]
|
||||
|
||||
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
|
||||
return hyps
|
||||
|
||||
|
||||
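The core recursion of `ctc_prefix_beam_search` can be sketched in plain Python on a toy input. This is a simplified illustration, not the function above: it works in the probability domain like the original, but it skips the top-k and 0.05 threshold pruning, the keyword token filter, and the per-token `nodes` bookkeeping. The name `prefix_beam_search` and the toy probabilities are assumptions.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Prefix = Tuple[int, ...]


def prefix_beam_search(probs: List[List[float]],
                       beam_size: int = 4) -> List[Tuple[Prefix, float]]:
    """Return (prefix, probability) pairs, best first. Token 0 is blank."""
    # pb / pnb: probability of the prefix ending in blank / non-blank.
    cur: Dict[Prefix, Tuple[float, float]] = {(): (1.0, 0.0)}
    for frame in probs:
        nxt = defaultdict(lambda: (0.0, 0.0))
        for prefix, (pb, pnb) in cur.items():
            last = prefix[-1] if prefix else None
            for s, ps in enumerate(frame):
                if s == 0:
                    # Blank keeps the prefix and moves mass to pb.
                    n_pb, n_pnb = nxt[prefix]
                    nxt[prefix] = (n_pb + (pb + pnb) * ps, n_pnb)
                elif s == last:
                    # A repeat without a blank collapses into the prefix.
                    n_pb, n_pnb = nxt[prefix]
                    nxt[prefix] = (n_pb, n_pnb + pnb * ps)
                    # A repeat after a blank extends the prefix.
                    ext = prefix + (s,)
                    n_pb, n_pnb = nxt[ext]
                    nxt[ext] = (n_pb, n_pnb + pb * ps)
                else:
                    ext = prefix + (s,)
                    n_pb, n_pnb = nxt[ext]
                    nxt[ext] = (n_pb, n_pnb + (pb + pnb) * ps)
        ranked = sorted(nxt.items(),
                        key=lambda kv: kv[1][0] + kv[1][1],
                        reverse=True)
        cur = dict(ranked[:beam_size])
    return [(prefix, pb + pnb) for prefix, (pb, pnb) in cur.items()]


# Two frames over a toy vocabulary {0: blank, 1: token}.
hyps = prefix_beam_search([[0.4, 0.6], [0.4, 0.6]])
best_prefix, best_prob = hyps[0]
print(best_prefix, round(best_prob, 6))
```

In this toy case the prefix `(1,)` collects the mass of the alignments `1-`, `-1`, and `11`, which is why it outranks the empty prefix.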
class Calculator:
|
||||
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
self.space = []
|
||||
self.cost = {}
|
||||
self.cost['cor'] = 0
|
||||
self.cost['sub'] = 1
|
||||
self.cost['del'] = 1
|
||||
self.cost['ins'] = 1
|
||||
|
||||
def calculate(self, lab, rec):
|
||||
# Initialization
|
||||
lab.insert(0, '')
|
||||
rec.insert(0, '')
|
||||
while len(self.space) < len(lab):
|
||||
self.space.append([])
|
||||
for row in self.space:
|
||||
for element in row:
|
||||
element['dist'] = 0
|
||||
element['error'] = 'non'
|
||||
while len(row) < len(rec):
|
||||
row.append({'dist': 0, 'error': 'non'})
|
||||
for i in range(len(lab)):
|
||||
self.space[i][0]['dist'] = i
|
||||
self.space[i][0]['error'] = 'del'
|
||||
for j in range(len(rec)):
|
||||
self.space[0][j]['dist'] = j
|
||||
self.space[0][j]['error'] = 'ins'
|
||||
self.space[0][0]['error'] = 'non'
|
||||
for token in lab:
|
||||
if token not in self.data and len(token) > 0:
|
||||
self.data[token] = {
|
||||
'all': 0,
|
||||
'cor': 0,
|
||||
'sub': 0,
|
||||
'ins': 0,
|
||||
'del': 0
|
||||
}
|
||||
for token in rec:
|
||||
if token not in self.data and len(token) > 0:
|
||||
self.data[token] = {
|
||||
'all': 0,
|
||||
'cor': 0,
|
||||
'sub': 0,
|
||||
'ins': 0,
|
||||
'del': 0
|
||||
}
|
||||
# Computing edit distance
|
||||
for i, lab_token in enumerate(lab):
|
||||
for j, rec_token in enumerate(rec):
|
||||
if i == 0 or j == 0:
|
||||
continue
|
||||
min_dist = sys.maxsize
|
||||
min_error = 'none'
|
||||
dist = self.space[i - 1][j]['dist'] + self.cost['del']
|
||||
error = 'del'
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
min_error = error
|
||||
dist = self.space[i][j - 1]['dist'] + self.cost['ins']
|
||||
error = 'ins'
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
min_error = error
|
||||
if lab_token == rec_token:
|
||||
dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
|
||||
error = 'cor'
|
||||
else:
|
||||
dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
|
||||
error = 'sub'
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
min_error = error
|
||||
self.space[i][j]['dist'] = min_dist
|
||||
self.space[i][j]['error'] = min_error
|
||||
# Tracing back
|
||||
result = {
|
||||
'lab': [],
|
||||
'rec': [],
|
||||
'all': 0,
|
||||
'cor': 0,
|
||||
'sub': 0,
|
||||
'ins': 0,
|
||||
'del': 0
|
||||
}
|
||||
i = len(lab) - 1
|
||||
j = len(rec) - 1
|
||||
while True:
|
||||
if self.space[i][j]['error'] == 'cor': # correct
|
||||
if len(lab[i]) > 0:
|
||||
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
|
||||
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
|
||||
result['all'] = result['all'] + 1
|
||||
result['cor'] = result['cor'] + 1
|
||||
result['lab'].insert(0, lab[i])
|
||||
result['rec'].insert(0, rec[j])
|
||||
i = i - 1
|
||||
j = j - 1
|
||||
elif self.space[i][j]['error'] == 'sub': # substitution
|
||||
if len(lab[i]) > 0:
|
||||
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
                result['all'] = result['all'] + 1
                result['sub'] = result['sub'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif self.space[i][j]['error'] == 'del':  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
                result['all'] = result['all'] + 1
                result['del'] = result['del'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, '')
                i = i - 1
            elif self.space[i][j]['error'] == 'ins':  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
                result['ins'] = result['ins'] + 1
                result['lab'].insert(0, '')
                result['rec'].insert(0, rec[j])
                j = j - 1
            elif self.space[i][j]['error'] == 'non':  # starting point
                break
            else:  # shouldn't reach here
                print(
                    'this should not happen, '
                    'i = {i} , j = {j} , error = {error}'
                    .format(i=i, j=j, error=self.space[i][j]['error']))
        return result

    def overall(self):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in self.data:
            result['all'] = result['all'] + self.data[token]['all']
            result['cor'] = result['cor'] + self.data[token]['cor']
            result['sub'] = result['sub'] + self.data[token]['sub']
            result['ins'] = result['ins'] + self.data[token]['ins']
            result['del'] = result['del'] + self.data[token]['del']
        return result

    def cluster(self, data):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in data:
            if token in self.data:
                result['all'] = result['all'] + self.data[token]['all']
                result['cor'] = result['cor'] + self.data[token]['cor']
                result['sub'] = result['sub'] + self.data[token]['sub']
                result['ins'] = result['ins'] + self.data[token]['ins']
                result['del'] = result['del'] + self.data[token]['del']
        return result

    def keys(self):
        return list(self.data.keys())
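The counters returned by `overall()` and `cluster()` above are raw tallies. In the backtrace code, `all` is incremented for substitutions and deletions but not for insertions, so it appears to count reference tokens. A minimal sketch of turning such a dictionary into an error rate follows, assuming the usual WER definition (substitutions + deletions + insertions over reference tokens). `wer_percent` is a hypothetical helper for illustration and is not part of this commit.

```python
def wer_percent(counts):
    """Return the error rate in percent from a counter dict.

    `counts` is assumed to have the keys produced by overall()/cluster():
    'all' (reference tokens), 'cor', 'sub', 'ins', 'del'.
    """
    n_ref = counts['all']
    if n_ref == 0:
        # No reference tokens: report 0.0 rather than dividing by zero.
        return 0.0
    errors = counts['sub'] + counts['del'] + counts['ins']
    return 100.0 * errors / n_ref


# Hypothetical tallies: 20 reference tokens, 5 errors in total.
counts = {'all': 20, 'cor': 16, 'sub': 2, 'ins': 1, 'del': 2}
print(wer_percent(counts))  # prints 25.0
```

Because insertions are counted in the numerator but not in `all`, the rate can exceed 100% for very noisy hypotheses.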
@ -15,6 +15,7 @@

 import json
 import math
+import re

 import numpy as np

@ -42,3 +43,50 @@ def load_cmvn(json_cmvn_file):
         variance[i] = 1.0 / math.sqrt(variance[i])
     cmvn = np.array([means, variance])
     return cmvn
+
+def load_kaldi_cmvn(cmvn_file):
+    """ Load the kaldi format cmvn stats file and no need to calculate
+
+    Args:
+        cmvn_file: cmvn stats file in kaldi format
+
+    Returns:
+        a numpy array of [means, vars]
+    """
+
+    means = None
+    variance = None
+    with open(cmvn_file) as f:
+        all_lines = f.readlines()
+        for idx, line in enumerate(all_lines):
+            if line.find('AddShift') != -1:
+                segs = line.strip().split(' ')
+                assert len(segs) == 3
+                next_line = all_lines[idx + 1]
+                means_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
+                means_list = means_str.strip().split(' ')
+                means = [0 - float(s) for s in means_list]
+                assert len(means) == int(segs[1])
+            elif line.find('Rescale') != -1:
+                segs = line.strip().split(' ')
+                assert len(segs) == 3
+                next_line = all_lines[idx + 1]
+                vars_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
+                vars_list = vars_str.strip().split(' ')
+                variance = [float(s) for s in vars_list]
+                assert len(variance) == int(segs[1])
+            elif line.find('Splice') != -1:
+                segs = line.strip().split(' ')
+                assert len(segs) == 3
+                next_line = all_lines[idx + 1]
+                splice_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
+                splice_list = splice_str.strip().split(' ')
+                assert len(splice_list) * int(segs[2]) == int(segs[1])
+                copy_times = len(splice_list)
+            else:
+                continue
+
+    cmvn = np.array([means, variance])
+    cmvn = np.tile(cmvn, (1, copy_times))
+
+    return cmvn
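To make the last two steps of `load_kaldi_cmvn` concrete, here is a small sketch of what `np.tile(cmvn, (1, copy_times))` produces and how the result might be applied. It assumes the first row holds means (the negated `AddShift` values) and the second row holds the `Rescale` factors, used as `(feats - mean) * scale`; the numbers and the `feats` array are made up for illustration. Note that in the function above `copy_times` is assigned only in the `Splice` branch, so the sketch also assumes the stats file contains one.

```python
import numpy as np

# Hypothetical per-dimension statistics for a 2-dim base feature.
means = [1.0, 2.0]    # negated AddShift values
scale = [0.5, 0.25]   # Rescale values
copy_times = 3        # number of spliced frames

# Same two steps as at the end of load_kaldi_cmvn.
cmvn = np.array([means, scale])          # shape (2, 2)
cmvn = np.tile(cmvn, (1, copy_times))    # shape (2, 6): stats repeated per spliced frame

# Apply to a batch of spliced features (4 frames, 6 dims), assuming
# shift-then-rescale normalisation.
feats = np.ones((4, 6))
normed = (feats - cmvn[0]) * cmvn[1]
print(cmvn.shape)
print(normed[0])
```

Tiling along axis 1 keeps the statistics aligned with a spliced feature vector in which each context frame contributes its base dimensions in order.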
@ -34,17 +34,20 @@ class Executor:
         min_duration = args.get('min_duration', 0)

         for batch_idx, batch in enumerate(data_loader):
-            key, feats, target, feats_lengths = batch
+            key, feats, target, feats_lengths, label_lengths = batch
             feats = feats.to(device)
             target = target.to(device)
             feats_lengths = feats_lengths.to(device)
+            label_lengths = label_lengths.to(device)
             num_utts = feats_lengths.size(0)
             if num_utts == 0:
                 continue
             logits, _ = model(feats)
             loss_type = args.get('criterion', 'max_pooling')
             loss, acc = criterion(loss_type, logits, target, feats_lengths,
-                                  min_duration)
+                                  target_lengths=label_lengths,
+                                  min_duration=min_duration,
+                                  validation=False)
             optimizer.zero_grad()
             loss.backward()
             grad_norm = clip_grad_norm_(model.parameters(), clip)
@ -67,16 +70,20 @@ class Executor:
         total_acc = 0.0
         with torch.no_grad():
             for batch_idx, batch in enumerate(data_loader):
-                key, feats, target, feats_lengths = batch
+                key, feats, target, feats_lengths, label_lengths = batch
                 feats = feats.to(device)
                 target = target.to(device)
                 feats_lengths = feats_lengths.to(device)
+                label_lengths = label_lengths.to(device)
                 num_utts = feats_lengths.size(0)
                 if num_utts == 0:
                     continue
                 logits, _ = model(feats)
                 loss, acc = criterion(args.get('criterion', 'max_pooling'),
-                                      logits, target, feats_lengths)
+                                      logits, target, feats_lengths,
+                                      target_lengths=label_lengths,
+                                      min_duration=0,
+                                      validation=True)
                 if torch.isfinite(loss):
                     num_seen_utts += num_utts
                     total_loss += loss.item() * num_utts
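Both loops above now unpack a fifth batch element, `label_lengths`, and pass it to `criterion` as `target_lengths`. A CTC-style loss needs the true length of each label sequence because the labels in a batch are padded to a common length. The sketch below shows one way such a padded target and its lengths could be built. `pad_labels` and the pad value `-1` are hypothetical choices for illustration; the collate code actually used in this repository may differ.

```python
def pad_labels(label_seqs, pad_id=-1):
    """Pad variable-length label sequences to a common length.

    Returns the padded sequences together with the original lengths,
    which a CTC-style loss needs in order to ignore the padding.
    """
    lengths = [len(seq) for seq in label_seqs]
    max_len = max(lengths, default=0)
    padded = [list(seq) + [pad_id] * (max_len - len(seq)) for seq in label_seqs]
    return padded, lengths


# Two hypothetical keyword label sequences of different lengths.
target, label_lengths = pad_labels([[3, 5, 7, 2], [4, 6]])
print(target)
print(label_lengths)
```

In a training loop, the two returned lists would typically be converted to tensors and moved to the device alongside `feats` and `feats_lengths`.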