[ctc] KWS with CTC loss training and CTC prefix beam search detection. (#135)

* add CTC loss training scripts.

* update compute_det_ctc

* fix typo.

* add FSMN model; can use a pretrained KWS model from modelscope.

* Add streaming detection for the CTC model. Add CTC model ONNX export. Add the CTC model's results to the README; for now, the CTC model runtime is not supported.

* QA run.sh; the max-pooling training scripts remain compatible. Ready for PR.

* Add a streaming KWS demo; support FSMN online forward.

* fix typo.

* Align streaming and non-streaming FSMN, both in feature extraction and in the model forward.

* fix repeated activation; add an interval restriction.

* fix timestamp when subsampling!=1.

* fix flake8, update training scripts and README, provide pretrained ckpt.

* fix quickcheck and flake8

* Add realtime CTC-KWS demo in README.

---------

Co-authored-by: dujing <dujing@xmov.ai>
Jean Du 2023-08-16 10:07:04 +08:00 committed by GitHub
parent 85350c38a8
commit b233d46552
22 changed files with 3328 additions and 19 deletions

View File

@@ -1,3 +1,5 @@
Comparison among different backbones;
all models use Max-Pooling loss.
FRRs with FAR fixed at once per hour:
| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
@@ -8,3 +10,58 @@ FRRs with FAR fixed at once per hour:
| DS_TCN(spec_aug) | 287 | 80(avg30) | 0.008176 | 0.005075 |
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
Next, we train models with CTC loss, using DS_TCN and FSMN backbones,
and use CTC prefix beam search to decode and detect keywords,
either in non-streaming or streaming fashion.
Since the FAR is quite low when using CTC loss,
the following results are FRRs with FAR fixed at once per 12 hours:
Comparison between Max-pooling and CTC loss.
The CTC model is fine-tuned from a base model pretrained on WenetSpeech (23 epochs, not fully converged).
FRRs with FAR fixed at once per 12 hours:
| model | loss | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|------------|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | [dstcn-maxpooling](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn/files) |
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
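Each stats file written by compute_det_ctc.py holds one `threshold fa_per_hour frr` triple per line, so the FRR at a fixed FAR budget can be read off directly. A minimal sketch in Python (the stats path is a placeholder) of picking the FRR at a FAR of once per 12 hours:

    def frr_at_far(stats_file, max_fa_per_hour=1.0 / 12):
        best = None
        with open(stats_file) as fin:
            for line in fin:
                threshold, fa_per_hour, frr = map(float, line.split())
                # keep the lowest FRR among thresholds within the FAR budget
                if fa_per_hour <= max_fa_per_hour and (best is None or frr < best):
                    best = frr
        return best

    print(frr_at_far('exp/ds_tcn_ctc/test_avg_30.pt/stats.keyword.txt'))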
Comparison between DS_TCN (pretrained on WenetSpeech, 23 epochs, not converged)
and FSMN (pretrained with the modelscope-released xiaoyunxiaoyun model, fully converged).
FRRs with FAR fixed at once per 12 hours:
| model | params(K) | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|-------------------------------------------------------------------------------|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | [fsmn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_fsmn_ctc/files) |
For now, the DS_TCN model with CTC loss may not achieve the best performance, because its
pretraining did not converge sufficiently. We recommend using the pretrained
FSMN model as the initial checkpoint when training your own model.
Comparison between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours:
| model | stream | hi_xiaowen | nihao_wenwen |
|-----------------------|-------------|------------|--------------|
| DS_TCN(spec_aug) | no | 0.056574 | 0.056856 |
| DS_TCN(spec_aug) | yes | 0.132694 | 0.057044 |
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |
Note: when using CTC prefix beam search to detect keywords in the streaming case (detection at each frame),
we record the probability of a keyword in a decoding path as soon as the keyword appears in that path.
Since this probability typically increases over time, we record a lower probability value,
which results in a higher False Rejection Rate on the Detection Error Tradeoff curve.
At a given threshold, the actual FRR will be lower than the DET curve suggests.
On some small-data KWS tasks, we believe the FSMN-CTC model is more robust
than a classification model trained with CE/Max-pooling loss.
For more information and results on the FSMN-CTC KWS model, see [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).
For realtime CTC-KWS, we need to process the wave input in streaming fashion,
including feature extraction, keyword decoding and detection, and some postprocessing.
Here is a [demo](https://modelscope.cn/studios/thuduj12/KWS_Nihao_Xiaojing/summary) in Python;
the core code is in wekws/bin/stream_kws_ctc.py, which you can refer to when implementing runtime code.
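A minimal sketch of driving the spotter from Python (paths, threshold and keywords below are placeholders; the constructor arguments follow wekws/bin/stream_kws_ctc.py):

    from wekws.bin.stream_kws_ctc import KeyWordSpotter

    kws = KeyWordSpotter(ckpt_path='avg_30.pt', config_path='config.yaml',
                         token_path='tokens.txt', lexicon_path='lexicon.txt',
                         threshold=0.02, min_frames=5, max_frames=250,
                         interval_frames=50, score_beam=3, path_beam=20,
                         gpu=-1, is_jit_model=False)
    kws.set_keywords('嗨小问,你好问问')
    with open('test.pcm', 'rb') as f:       # 16 kHz, 16-bit, mono raw PCM
        chunk_bytes = int(0.3 * 16000) * 2  # feed 0.3 s of audio per step
        while True:
            chunk = f.read(chunk_bytes)
            if not chunk:
                break
            result = kws.forward(chunk)     # {} until a keyword fires
            if result.get('state') == 1:
                print(result)               # keyword, start/end (s), score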

View File

@@ -0,0 +1,50 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256
model:
hidden_dim: 256
preprocessing:
type: linear
backbone:
type: tcn
ds: true
num_layers: 4
kernel_size: 8
dropout: 0.1
activation:
type: identity
optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001
training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10
criterion: ctc

View File

@@ -0,0 +1,50 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 200
model:
hidden_dim: 256
preprocessing:
type: linear
backbone:
type: tcn
ds: true
num_layers: 4
kernel_size: 8
dropout: 0.1
activation:
type: identity
optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001
training_config:
grad_clip: 5
max_epoch: 50
log_interval: 100
criterion: ctc

View File

@@ -0,0 +1,64 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.
context_expansion: true
context_expansion_conf:
left: 2
right: 2
frame_skip: 3
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256
model:
input_dim: 400
preprocessing:
type: none
hidden_dim: 128
backbone:
type: fsmn
input_affine_dim: 140
num_layers: 4
linear_dim: 250
proj_dim: 128
left_order: 10
right_order: 2
left_stride: 1
right_stride: 1
output_affine_dim: 140
classifier:
type: identity
dropout: 0.1
activation:
type: identity
optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001
training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10
criterion: ctc
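# Note (for clarity): with context_expansion (left: 2, right: 2) each frame is
# spliced with its neighbours, so model.input_dim = num_mel_bins * (left + right + 1)
#                                                 = 80 * 5 = 400, as set above.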

View File

@@ -3,8 +3,8 @@
. ./path.sh
stage=0
stop_stage=4
stage=$1
stop_stage=$2
num_keywords=2
config=conf/ds_tcn.yaml
@@ -98,6 +98,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
python wekws/bin/score.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
@@ -111,6 +112,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--score_file $result_dir/score.txt \
--stats_file $result_dir/stats.${keyword}.txt
done
# plot det curve
python wekws/bin/plot_det_curve.py \
--keywords_dict dict/words.txt \
--stats_dir $result_dir \
--figure_file $result_dir/det.png
fi

View File

@@ -0,0 +1,223 @@
#!/bin/bash
# Copyright 2021 Binbin Zhang(binbzha@qq.com)
# 2023 Jing Du(thuduj12@163.com)
. ./path.sh
stage=$1
stop_stage=$2
num_keywords=2599
config=conf/ds_tcn_ctc.yaml
norm_mean=true
norm_var=true
gpus="0"
checkpoint=
dir=exp/ds_tcn_ctc
average_model=true
num_average=30
if $average_model ;then
score_checkpoint=$dir/avg_${num_average}.pt
else
score_checkpoint=$dir/final.pt
fi
download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir
. tools/parse_options.sh || exit 1;
window_shift=50
# Whether to train the base model. If true, train+dev data must be put in trainbase_dir
trainbase=false
trainbase_dir=data/base
trainbase_config=conf/ds_tcn_ctc_base.yaml
trainbase_exp=exp/base
if [ ${stage} -le -3 ] && [ ${stop_stage} -ge -3 ]; then
echo "Download and extracte all datasets"
local/mobvoi_data_download.sh --dl_dir $download_dir
fi
if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
echo "Preparing datasets..."
mkdir -p dict
echo "<filler> -1" > dict/words.txt
echo "Hi_Xiaowen 0" >> dict/words.txt
echo "Nihao_Wenwen 1" >> dict/words.txt
for folder in train dev test; do
mkdir -p data/$folder
for prefix in p n; do
mkdir -p data/${prefix}_$folder
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
data/${prefix}_$folder
done
cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp
cat data/p_$folder/text data/n_$folder/text > data/$folder/text
rm -rf data/p_$folder data/n_$folder
done
fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# Here we use Paraformer-Large (https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
# to transcribe the negative wavs; the transcriptions have been uploaded to modelscope.
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
for folder in train dev test; do
if [ -f data/$folder/text ];then
mv data/$folder/text data/$folder/text.label
fi
cp mobvoi_kws_transcription/$folder.text data/$folder/text
done
# we also copy the tokens and lexicon that are used in
# https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary
cp mobvoi_kws_transcription/tokens.txt data/tokens.txt
cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Compute CMVN and Format datasets"
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
--in_scp data/train/wav.scp \
--out_cmvn data/train/global_cmvn
for x in train dev test; do
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
# Here we use tokens.txt and lexicon.txt to convert txt into index
tools/make_list.py data/$x/wav.scp data/$x/text \
data/$x/wav.dur data/$x/data.list \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
for x in train dev ; do
if [ ! -f $trainbase_dir/$x/wav.scp ] || [ ! -f $trainbase_dir/$x/text ]; then
echo "If You Want to Train Base KWS-CTC Model, You Should Prepare ASR Data by Yourself."
echo "The wav.scp and text in KALDI-format is Needed, You Should Put Them in $trainbase_dir/$x"
exit
fi
if [ ! -f $trainbase_dir/$x/wav.dur ]; then
tools/wav_to_duration.sh --nj 128 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
fi
# Here we use tokens.txt and lexicon.txt to convert txt into index
if [ ! -f $trainbase_dir/$x/data.list ]; then
tools/make_list.py $trainbase_dir/$x/wav.scp $trainbase_dir/$x/text \
$trainbase_dir/$x/wav.dur $trainbase_dir/$x/data.list \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
fi
done
echo "Start base training ..."
mkdir -p $trainbase_exp
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wekws/bin/train.py --gpus $gpus \
--config $trainbase_config \
--train_data $trainbase_dir/train/data.list \
--cv_data $trainbase_dir/dev/data.list \
--model_dir $trainbase_exp \
--num_workers 2 \
--ddp.dist_backend nccl \
--num_keywords $num_keywords \
--min_duration 50 \
--seed 666 \
$cmvn_opts # \
#--checkpoint $trainbase_exp/23.pt
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Start training ..."
mkdir -p $dir
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
if $trainbase; then
echo "Use the base model you trained as checkpoint: $trainbase_exp/final.pt"
checkpoint=$trainbase_exp/final.pt
else
echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/23.pt"
if [ ! -d mobvoi_kws_transcription ] ;then
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
fi
checkpoint=mobvoi_kws_transcription/23.pt  # this ckpt may not be fully converged.
fi
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wekws/bin/train.py --gpus $gpus \
--config $config \
--train_data data/train/data.list \
--cv_data data/dev/data.list \
--model_dir $dir \
--num_workers 8 \
--num_keywords $num_keywords \
--min_duration 50 \
--seed 666 \
$cmvn_opts \
${checkpoint:+--checkpoint $checkpoint}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Do model average, Compute FRR/FAR ..."
if $average_model; then
python wekws/bin/average_model.py \
--dst_model $score_checkpoint \
--src_path $dir \
--num ${num_average} \
--val_best
fi
result_dir=$dir/test_$(basename $score_checkpoint)
mkdir -p $result_dir
stream=true # we detect keyword online with ctc_prefix_beam_search
score_prefix=""
if $stream ; then
score_prefix=stream_
fi
python wekws/bin/${score_prefix}score_ctc.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
--num_workers 8 \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
python wekws/bin/compute_det_ctc.py \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--test_data data/test/data.list \
--window_shift $window_shift \
--step 0.001 \
--score_file $result_dir/score.txt \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
python wekws/bin/export_jit.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--jit_model $dir/$jit_model
python wekws/bin/export_onnx.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--onnx_model $dir/$onnx_model
fi
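# Usage note: stage and stop_stage are taken from the positional arguments,
# e.g. `bash run.sh -3 4` runs the whole pipeline from data download (stage -3)
# through export (stage 4), while `bash run.sh 2 3` only trains and evaluates.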

View File

@@ -0,0 +1,175 @@
#!/bin/bash
# Copyright 2021 Binbin Zhang(binbzha@qq.com)
# 2023 Jing Du(thuduj12@163.com)
. ./path.sh
stage=$1
stop_stage=$2
num_keywords=2599
config=conf/fsmn_ctc.yaml
norm_mean=true
norm_var=true
gpus="0"
checkpoint=
dir=exp/fsmn_ctc
average_model=true
num_average=30
if $average_model ;then
score_checkpoint=$dir/avg_${num_average}.pt
else
score_checkpoint=$dir/final.pt
fi
download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir
. tools/parse_options.sh || exit 1;
window_shift=50
if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
echo "Download and extracte all datasets"
local/mobvoi_data_download.sh --dl_dir $download_dir
fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Preparing datasets..."
mkdir -p dict
echo "<filler> -1" > dict/words.txt
echo "Hi_Xiaowen 0" >> dict/words.txt
echo "Nihao_Wenwen 1" >> dict/words.txt
for folder in train dev test; do
mkdir -p data/$folder
for prefix in p n; do
mkdir -p data/${prefix}_$folder
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
data/${prefix}_$folder
done
cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp
cat data/p_$folder/text data/n_$folder/text > data/$folder/text
rm -rf data/p_$folder data/n_$folder
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# Here we use Paraformer-Large (https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
# to transcribe the negative wavs; the transcriptions have been uploaded to modelscope.
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
for folder in train dev test; do
if [ -f data/$folder/text ];then
mv data/$folder/text data/$folder/text.label
fi
cp mobvoi_kws_transcription/$folder.text data/$folder/text
done
# we also copy the tokens and lexicon that are used in
# https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary
cp mobvoi_kws_transcription/tokens.txt data/tokens.txt
cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Compute CMVN and Format datasets"
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
--in_scp data/train/wav.scp \
--out_cmvn data/train/global_cmvn
for x in train dev test; do
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
# Here we use tokens.txt and lexicon.txt to convert txt into index
tools/make_list.py data/$x/wav.scp data/$x/text \
data/$x/wav.dur data/$x/data.list \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Use the base model from modelscope"
if [ ! -d speech_charctc_kws_phone-xiaoyun ] ;then
git lfs install
git clone https://www.modelscope.cn/damo/speech_charctc_kws_phone-xiaoyun.git
fi
checkpoint=speech_charctc_kws_phone-xiaoyun/train/base.pt
cp speech_charctc_kws_phone-xiaoyun/train/feature_transform.txt.80dim-l2r2 data/global_cmvn.kaldi
echo "Start training ..."
mkdir -p $dir
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/global_cmvn.kaldi"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wekws/bin/train.py --gpus $gpus \
--config $config \
--train_data data/train/data.list \
--cv_data data/dev/data.list \
--model_dir $dir \
--num_workers 8 \
--num_keywords $num_keywords \
--min_duration 50 \
--seed 666 \
$cmvn_opts \
${checkpoint:+--checkpoint $checkpoint}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Do model average, Compute FRR/FAR ..."
if $average_model; then
python wekws/bin/average_model.py \
--dst_model $score_checkpoint \
--src_path $dir \
--num ${num_average} \
--val_best
fi
result_dir=$dir/test_$(basename $score_checkpoint)
mkdir -p $result_dir
stream=true # we detect keyword online with ctc_prefix_beam_search
score_prefix=""
if $stream ; then
score_prefix=stream_
fi
python wekws/bin/${score_prefix}score_ctc.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
--num_workers 8 \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
python wekws/bin/compute_det_ctc.py \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--test_data data/test/data.list \
--window_shift $window_shift \
--step 0.001 \
--score_file $result_dir/score.txt \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
# For now, the FSMN model cannot be exported to TorchScript
# python wekws/bin/export_jit.py \
# --config $dir/config.yaml \
# --checkpoint $score_checkpoint \
# --jit_model $dir/$jit_model
python wekws/bin/export_onnx.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--onnx_model $dir/$onnx_model
fi

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2023 Jing Du(thuduj12@163.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +16,145 @@
# limitations under the License.
import argparse
import logging
import json
import re
symbol_str = '[!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
def split_mixed_label(input_str):
tokens = []
s = input_str.lower()
while len(s) > 0:
match = re.match(r'[A-Za-z!?,<>()\']+', s)
if match is not None:
word = match.group(0)
else:
word = s[0:1]
tokens.append(word)
s = s.replace(word, '', 1).strip(' ')
return tokens
def query_token_set(txt, symbol_table, lexicon_table):
tokens_str = tuple()
tokens_idx = tuple()
parts = split_mixed_label(txt)
for part in parts:
if part == '!sil' or part == '(sil)' or part == '<sil>':
tokens_str = tokens_str + ('!sil', )
elif part == '<blk>' or part == '<blank>':
tokens_str = tokens_str + ('<blk>', )
elif part == '(noise)' or part == 'noise)' or \
part == '(noise' or part == '<noise>':
tokens_str = tokens_str + ('<GBG>', )
elif part in symbol_table:
tokens_str = tokens_str + (part, )
elif part in lexicon_table:
for ch in lexicon_table[part]:
tokens_str = tokens_str + (ch, )
else:
# cases with symbols or meaningless English letter combinations
part = re.sub(symbol_str, '', part)
for ch in part:
tokens_str = tokens_str + (ch, )
for ch in tokens_str:
if ch in symbol_table:
tokens_idx = tokens_idx + (symbol_table[ch], )
elif ch == '!sil':
if 'sil' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['sil'], )
else:
tokens_idx = tokens_idx + (symbol_table['<blk>'], )
elif ch == '<GBG>':
if '<GBG>' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
else:
tokens_idx = tokens_idx + (symbol_table['<blk>'], )
else:
if '<GBG>' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
logging.info(
f'{ch} is not in token set, replace with <GBG>')
else:
tokens_idx = tokens_idx + (symbol_table['<blk>'], )
logging.info(
f'{ch} is not in token set, replace with <blk>')
return tokens_str, tokens_idx
def query_token_list(txt, symbol_table, lexicon_table):
tokens_str = []
tokens_idx = []
parts = split_mixed_label(txt)
for part in parts:
if part == '!sil' or part == '(sil)' or part == '<sil>':
tokens_str.append('!sil')
elif part == '<blk>' or part == '<blank>':
tokens_str.append('<blk>')
elif part == '(noise)' or part == 'noise)' or \
part == '(noise' or part == '<noise>':
tokens_str.append('<GBG>')
elif part in symbol_table:
tokens_str.append(part)
elif part in lexicon_table:
for ch in lexicon_table[part]:
tokens_str.append(ch)
else:
# cases with symbols or meaningless English letter combinations
part = re.sub(symbol_str, '', part)
for ch in part:
tokens_str.append(ch)
for ch in tokens_str:
if ch in symbol_table:
tokens_idx.append(symbol_table[ch])
elif ch == '!sil':
if 'sil' in symbol_table:
tokens_idx.append(symbol_table['sil'])
else:
tokens_idx.append(symbol_table['<blk>'])
elif ch == '<GBG>':
if '<GBG>' in symbol_table:
tokens_idx.append(symbol_table['<GBG>'])
else:
tokens_idx.append(symbol_table['<blk>'])
else:
if '<GBG>' in symbol_table:
tokens_idx.append(symbol_table['<GBG>'])
logging.info(
f'{ch} is not in token set, replace with <GBG>')
else:
tokens_idx.append(symbol_table['<blk>'])
logging.info(
f'{ch} is not in token set, replace with <blk>')
return tokens_str, tokens_idx
def read_token(token_file):
tokens_table = {}
with open(token_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
tokens_table[arr[0]] = int(arr[1]) - 1
fin.close()
return tokens_table
def read_lexicon(lexicon_file):
lexicon_table = {}
with open(lexicon_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().replace('\t', ' ').split()
assert len(arr) >= 2
lexicon_table[arr[0]] = arr[1:]
fin.close()
return lexicon_table
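# A small usage illustration (hypothetical paths; the file formats follow
# read_token and read_lexicon above: tokens.txt lines are "token index",
# lexicon.txt lines are "word token1 token2 ..."):
#   token_table = read_token('data/tokens.txt')
#   lexicon_table = read_lexicon('data/lexicon.txt')
#   tokens, ids = query_token_list('你好问问', token_table, lexicon_table)
#   # tokens -> ['你', '好', '问', '问'], ids -> their integer token indices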
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='')
@@ -23,6 +162,10 @@ if __name__ == '__main__':
parser.add_argument('text_file', help='text file')
parser.add_argument('duration_file', help='duration file')
parser.add_argument('output_file', help='output list file')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
args = parser.parse_args()
wav_table = {}
@@ -39,16 +182,38 @@ if __name__ == '__main__':
assert len(arr) == 2
duration_table[arr[0]] = float(arr[1])
token_table = None
if args.token_file:
token_table = read_token(args.token_file)
lexicon_table = None
if args.lexicon_file:
lexicon_table = read_lexicon(args.lexicon_file)
with open(args.text_file, 'r', encoding='utf8') as fin, \
open(args.output_file, 'w', encoding='utf8') as fout:
for line in fin:
arr = line.strip().split(maxsplit=1)
key = arr[0]
txt = int(arr[1])
tokens = None
if token_table is not None and lexicon_table is not None:
if len(arr) < 2:  # some utterances have no text
txt = [1] # the <blank>/sil is indexed by 1
tokens = ["sil"]
else:
tokens, txt = query_token_list(arr[1],
token_table,
lexicon_table)
else:
txt = int(arr[1])
assert key in wav_table
wav = wav_table[key]
assert key in duration_table
duration = duration_table[key]
line = dict(key=key, txt=txt, duration=duration, wav=wav)
if tokens is None:
line = dict(key=key, txt=txt, duration=duration, wav=wav)
else:
line = dict(key=key, tok=tokens, txt=txt,
duration=duration, wav=wav)
json_line = json.dumps(line, ensure_ascii=False)
fout.write(json_line + '\n')

View File

@@ -0,0 +1,271 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(954793264@qq.com)
# 2023 Jing Du(thuduj12@163.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import glob
import json
import re
import os
import numpy as np
import matplotlib.pyplot as plt
import pypinyin # for Chinese Character
from tools.make_list import query_token_set, read_lexicon, read_token
def split_mixed_label(input_str):
tokens = []
s = input_str.lower()
while len(s) > 0:
match = re.match(r'[A-Za-z!?,<>()\']+', s)
if match is not None:
word = match.group(0)
else:
word = s[0:1]
tokens.append(word)
s = s.replace(word, '', 1).strip(' ')
return tokens
def space_mixed_label(input_str):
splits = split_mixed_label(input_str)
space_str = ''.join(f'{sub} ' for sub in splits)
return space_str.strip()
def load_label_and_score(keywords_list, label_file, score_file, true_keywords):
score_table = {}
with open(score_file, 'r', encoding='utf8') as fin:
# read score file and store in table
for line in fin:
arr = line.strip().split()
key = arr[0]
is_detected = arr[1]
if is_detected == 'detected':
keyword = true_keywords[arr[2]]
if key not in score_table:
score_table.update({
key: {
'kw': space_mixed_label(keyword),
'confi': float(arr[3])
}
})
else:
if key not in score_table:
score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})
label_lists = []
with open(label_file, 'r', encoding='utf8') as fin:
for line in fin:
obj = json.loads(line.strip())
label_lists.append(obj)
# build empty structure for keyword-filler infos
keyword_filler_table = {}
for keyword in keywords_list:
keyword = true_keywords[keyword]
keyword = space_mixed_label(keyword)
keyword_filler_table[keyword] = {}
keyword_filler_table[keyword]['keyword_table'] = {}
keyword_filler_table[keyword]['keyword_duration'] = 0.0
keyword_filler_table[keyword]['filler_table'] = {}
keyword_filler_table[keyword]['filler_duration'] = 0.0
for obj in label_lists:
assert 'key' in obj
assert 'wav' in obj
assert 'tok' in obj # here we use the tokens
assert 'duration' in obj
key = obj['key']
txt = "".join(obj['tok'])
txt = space_mixed_label(txt)
txt_regstr_lrblk = ' ' + txt + ' '
duration = obj['duration']
assert key in score_table
for keyword in keywords_list:
keyword = true_keywords[keyword]
keyword = space_mixed_label(keyword)
keyword_regstr_lrblk = ' ' + keyword + ' '
if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['keyword_table'].update(
{key: score_table[key]['confi']})
else:
# utterance was detected, but not as this keyword
keyword_filler_table[keyword]['keyword_table'].update(
{key: -1.0})
keyword_filler_table[keyword]['keyword_duration'] += duration
else:
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['filler_table'].update(
{key: score_table[key]['confi']})
else:
# utterance was not detected as this keyword, so it is not a false alarm for it
keyword_filler_table[keyword]['filler_table'].update(
{key: -1.0})
keyword_filler_table[keyword]['filler_duration'] += duration
return keyword_filler_table
def load_stats_file(stats_file):
values = []
with open(stats_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
threshold, fa_per_hour, frr = arr
values.append([float(fa_per_hour), float(frr) * 100])
values.reverse()
return np.array(values)
def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5):
det_title = "DetCurve"
plt.figure(dpi=200)
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['font.size'] = 12
for file in glob.glob(f'{dets_dir}/*stats*.txt'):
logging.info(f'reading det data from {file}')
label = os.path.basename(file).split('.')[1]
label = "".join(pypinyin.lazy_pinyin(label))
values = load_stats_file(file)
plt.plot(values[:, 0], values[:, 1], label=label)
plt.xlim([0, xlim])
plt.ylim([0, ylim])
plt.xticks(range(0, xlim + x_step, x_step))
plt.yticks(range(0, ylim + y_step, y_step))
plt.xlabel('False Alarm Per Hour')
plt.ylabel('False Rejection Rate (%)')
plt.grid(linestyle='--')
plt.legend(loc='best', fontsize=6)
plt.savefig(figure_file)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='compute det curve')
parser.add_argument('--test_data', required=True, help='label file')
parser.add_argument('--keywords', type=str, default=None,
help='keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_file', required=True, help='score file')
parser.add_argument('--step', type=float, default=0.01,
help='threshold step')
parser.add_argument('--window_shift', type=int, default=50,
help='window_shift is used to '
'skip the frames after triggered')
parser.add_argument('--stats_dir',
required=False,
default=None,
help='false reject/alarm stats dir, '
'default in score_file')
parser.add_argument('--det_curve_path',
required=False,
default=None,
help='det curve path, default is stats_dir/det.png')
parser.add_argument(
'--xlim',
type=int,
default=5,
help='limit of the x-axis (false alarms per hour)')
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
parser.add_argument(
'--ylim',
type=int,
default=35,
help='limit of the y-axis (false rejection rate, %)')
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')
args = parser.parse_args()
window_shift = args.window_shift
logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords.strip().split(',')
token_table = read_token(args.token_file)
lexicon_table = read_lexicon(args.lexicon_file)
true_keywords = {}
for keyword in keywords_list:
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
true_keywords[keyword] = ''.join(strs)
keyword_filler_table = load_label_and_score(
keywords_list, args.test_data, args.score_file, true_keywords)
for keyword in keywords_list:
keyword = true_keywords[keyword]
keyword = space_mixed_label(keyword)
keyword_dur = keyword_filler_table[keyword]['keyword_duration']
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
filler_dur = keyword_filler_table[keyword]['filler_duration']
filler_num = len(keyword_filler_table[keyword]['filler_table'])
assert keyword_num > 0, \
'Can\'t compute det for {} without positive samples'.format(keyword)
assert filler_num > 0, \
'Can\'t compute det for {} without negative samples'.format(keyword)
logging.info('Computing det for {}'.format(keyword))
logging.info(' Keyword duration: {} Hours, wave number: {}'.format(
keyword_dur / 3600.0, keyword_num))
logging.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0))
if args.stats_dir :
stats_dir = args.stats_dir
else:
stats_dir = os.path.dirname(args.score_file)
stats_file = os.path.join(
stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
with open(stats_file, 'w', encoding='utf8') as fout:
threshold = 0.0
while threshold <= 1.0:
num_false_reject = 0
num_true_detect = 0
# traverse the whole keyword_table
for key, confi in \
keyword_filler_table[keyword]['keyword_table'].items():
if confi < threshold:
num_false_reject += 1
else:
num_true_detect += 1
num_false_alarm = 0
# traverse the whole filler_table
for key, confi in keyword_filler_table[
keyword]['filler_table'].items():
if confi >= threshold:
num_false_alarm += 1
# print(f'false alarm: {keyword}, {key}, {confi}')
false_reject_rate = num_false_reject / keyword_num
true_detect_rate = num_true_detect / keyword_num
num_false_alarm = max(num_false_alarm, 1e-6)
false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
false_alarm_rate = num_false_alarm / filler_num
fout.write('{:.3f} {:.6f} {:.6f}\n'.format(
threshold, false_alarm_per_hour, false_reject_rate))
threshold += args.step
if args.det_curve_path :
det_curve_path = args.det_curve_path
else:
det_curve_path = os.path.join(stats_dir, 'det.png')
plot_det(stats_dir, det_curve_path,
args.xlim, args.x_step, args.ylim, args.y_step)

View File

@@ -41,6 +41,9 @@ def main():
configs = yaml.load(fin, Loader=yaml.FullLoader)
feature_dim = configs['model']['input_dim']
model = init_model(configs['model'])
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
# if we use ctc_loss, the logits need to be converted into probs
model.forward = model.forward_softmax
print(model)
load_checkpoint(model, args.checkpoint)

View File

@@ -106,7 +106,7 @@ def main():
score_abs_path = os.path.abspath(args.score_file)
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(test_data_loader):
keys, feats, target, lengths = batch
keys, feats, target, lengths, target_lengths = batch
feats = feats.to(device)
lengths = lengths.to(device)
logits, _ = model(feats)

wekws/bin/score_ctc.py (new file, 219 lines)
View File

@@ -0,0 +1,219 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(954793264@qq.com)
# 2023 Jing Du(thuduj12@163.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import math
import torch
import yaml
from torch.utils.data import DataLoader
from wekws.dataset.dataset import Dataset
from wekws.model.kws_model import init_model
from wekws.utils.checkpoint import load_checkpoint
from wekws.model.loss import ctc_prefix_beam_search
from tools.make_list import query_token_set, read_lexicon, read_token
def get_args():
parser = argparse.ArgumentParser(description='recognize with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--test_data', required=True, help='test data file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--batch_size',
default=16,
type=int,
help='batch size for inference')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--score_file',
required=True,
help='output score file')
parser.add_argument('--jit_model',
action='store_true',
default=False,
help='Load a TorchScript (jit) model instead of a checkpoint')
parser.add_argument('--keywords', type=str, default=None,
help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
args = parser.parse_args()
return args
def is_sublist(main_list, check_list):
if len(main_list) < len(check_list):
return -1
if len(main_list) == len(check_list):
return 0 if main_list == check_list else -1
for i in range(len(main_list) - len(check_list) + 1):  # +1: allow a match ending at the last element
if main_list[i] == check_list[0]:
for j in range(len(check_list)):
if main_list[i + j] != check_list[j]:
break
else:
return i
else:
return -1
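# Example: is_sublist((1, 2, 3, 4), (2, 3)) returns 1 (the start offset),
# while is_sublist((1, 2, 3, 4), (4, 3)) returns -1 (not contained).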
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
test_conf = copy.deepcopy(configs['dataset_conf'])
test_conf['filter_conf']['max_length'] = 102400
test_conf['filter_conf']['min_length'] = 0
test_conf['speed_perturb'] = False
test_conf['spec_aug'] = False
test_conf['shuffle'] = False
test_conf['feature_extraction_conf']['dither'] = 0.0
test_conf['batch_conf']['batch_size'] = args.batch_size
test_dataset = Dataset(args.test_data, test_conf)
test_data_loader = DataLoader(test_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
if args.jit_model:
model = torch.jit.load(args.checkpoint)
# For script model, only cpu is supported.
device = torch.device('cpu')
else:
# Init the KWS model from configs
model = init_model(configs['model'])
load_checkpoint(model, args.checkpoint)
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
model.eval()
score_abs_path = os.path.abspath(args.score_file)
token_table = read_token(args.token_file)
lexicon_table = read_lexicon(args.lexicon_file)
# parse keyword tokens
assert args.keywords is not None, 'at least one keyword is needed'
logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {}
keywords_idxset = {0}
keywords_strset = {'<blk>'}
keywords_tokenmap = {'<blk>': 0}
for keyword in keywords_list:
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = indexes
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
for i in indexes)
[keywords_strset.add(i) for i in strs]
[keywords_idxset.add(i) for i in indexes]
for txt, idx in zip(strs, indexes):
if keywords_tokenmap.get(txt, None) is None:
keywords_tokenmap[txt] = idx
token_print = ''
for txt, idx in keywords_tokenmap.items():
token_print += f'{txt}({idx}) '
logging.info(f'Token set is: {token_print}')
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(test_data_loader):
keys, feats, target, lengths, target_lengths = batch
feats = feats.to(device)
lengths = lengths.to(device)
logits, _ = model(feats)
logits = logits.softmax(2) # (batch_size, maxlen, vocab_size)
logits = logits.cpu()
for i in range(len(keys)):
key = keys[i]
score = logits[i][:lengths[i]]
hyps = ctc_prefix_beam_search(score,
lengths[i],
keywords_idxset)
hit_keyword = None
hit_score = 1.0
start = 0
end = 0
for one_hyp in hyps:
prefix_ids = one_hyp[0]
# path_score = one_hyp[1]
prefix_nodes = one_hyp[2]
assert len(prefix_ids) == len(prefix_nodes)
for word in keywords_token.keys():
lab = keywords_token[word]['token_id']
offset = is_sublist(prefix_ids, lab)
if offset != -1:
hit_keyword = word
start = prefix_nodes[offset]['frame']
end = prefix_nodes[offset + len(lab) - 1]['frame']
for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob']
break
if hit_keyword is not None:
hit_score = math.sqrt(hit_score)
break
if hit_keyword is not None:
fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score))
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"duration {end - start}, "
f"score {hit_score}, Activated.")
else:
fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
if batch_idx % 10 == 0:
print('Progress batch {}'.format(batch_idx))
sys.stdout.flush()
if __name__ == '__main__':
main()

wekws/bin/stream_kws_ctc.py (new file, 587 lines)
View File

@@ -0,0 +1,587 @@
# Copyright (c) 2023 Jing Du(thuduj12@163.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import struct
# import wave
import librosa
import logging
import os
import math
import numpy as np
import torchaudio.compliance.kaldi as kaldi
import torch
import torch.nn.functional as F
import yaml
from collections import defaultdict
from wekws.model.kws_model import init_model
from wekws.utils.checkpoint import load_checkpoint
from tools.make_list import query_token_set, read_lexicon, read_token
def get_args():
parser = argparse.ArgumentParser(description='detect keywords online.')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--wav_path', required=False,
default=None, help='test wave path.')
parser.add_argument('--wav_scp', required=False,
default=None, help='test wave scp.')
parser.add_argument('--result_file', required=False,
default=None, help='test result.')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--jit_model',
action='store_true',
default=False,
help='Load a TorchScript (jit) model instead of a checkpoint')
parser.add_argument('--keywords', type=str, default=None,
help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_beam_size',
default=3,
type=int,
help='The first prune beam, '
'filter out those frames with low scores.')
parser.add_argument('--path_beam_size',
default=20,
type=int,
help='The second prune beam, '
'keep only path_beam_size candidates.')
parser.add_argument('--threshold',
type=float,
default=0.0,
help='The threshold of kws. '
'If ctc_search probs exceed this value,'
'the keyword will be activated.')
parser.add_argument('--min_frames',
default=5,
type=int,
help='The min frames of keyword\'s duration.')
parser.add_argument('--max_frames',
default=250,
type=int,
help='The max frames of keyword\'s duration.')
parser.add_argument('--interval_frames',
default=50,
type=int,
help='The interval frames of two continuous keywords.')
args = parser.parse_args()
return args
def is_sublist(main_list, check_list):
if len(main_list) < len(check_list):
return -1
if len(main_list) == len(check_list):
return 0 if main_list == check_list else -1
for i in range(len(main_list) - len(check_list) + 1):  # +1: allow a match ending at the last element
if main_list[i] == check_list[0]:
for j in range(len(check_list)):
if main_list[i + j] != check_list[j]:
break
else:
return i
else:
return -1
def ctc_prefix_beam_search(t, probs,
cur_hyps,
keywords_idxset,
score_beam_size):
'''
:param t: the time in frame
:param probs: the probability in t_th frame, (vocab_size, )
:param cur_hyps: list of tuples, initialized as [(tuple(), (1.0, 0.0, []))].
In each tuple, the 1st element is the prefix ids, and the 2nd holds
p_blank, p_non_blank, and the path nodes list.
In the path nodes list, each node is
a dict of {token: idx, frame: t, prob: ps}.
:param keywords_idxset: the index of keywords in token.txt
:param score_beam_size: the probability threshold,
to filter out those frames with low probs.
:return:
next_hyps: the hypothesis depend on current hyp and current frame.
'''
# key: prefix, value: (pb, pnb, nodes), default value (0.0, 0.0, [])
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
# 2.1 First beam prune: select topk best
top_k_probs, top_k_index = probs.topk(score_beam_size)
# filter prob score that is too small
filter_probs = []
filter_index = []
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
if keywords_idxset is not None:
if prob > 0.05 and idx in keywords_idxset:
filter_probs.append(prob)
filter_index.append(idx)
else:
if prob > 0.05:
filter_probs.append(prob)
filter_index.append(idx)
if len(filter_index) == 0:
return cur_hyps
for s in filter_index:
ps = probs[s].item()
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pb = n_pb + pb * ps + pnb * ps
nodes = cur_nodes.copy()
next_hyps[prefix] = (n_pb, n_pnb, nodes)
elif s == last:
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
# Update *ss -> *s;
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pnb = n_pnb + pnb * ps
nodes = cur_nodes.copy()
if ps > nodes[-1]['prob']: # update frame and prob
nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t
next_hyps[prefix] = (n_pb, n_pnb, nodes)
if not math.isclose(pb, 0.0, abs_tol=0.000001):
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s,)
n_pb, n_pnb, nodes = next_hyps[n_prefix]
n_pnb = n_pnb + pb * ps
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
else:
n_prefix = prefix + (s,)
n_pb, n_pnb, nodes = next_hyps[n_prefix]
if nodes:
if ps > nodes[-1]['prob']: # update frame and prob
# nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t
nodes.pop()
# pop and re-append, to avoid mutating other beams that share this node.
nodes.append(dict(token=s, frame=t, prob=ps))
else:
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
n_pnb = n_pnb + pb * ps + pnb * ps
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)
return next_hyps
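# Usage sketch: starting from cur_hyps = [(tuple(), (1.0, 0.0, []))] and feeding
# per-frame posteriors one frame at a time, each surviving prefix keeps its
# blank/non-blank probabilities plus the frame and prob of every emitted token,
# which execute_detection() below uses for timestamps and keyword scores.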
class KeyWordSpotter(torch.nn.Module):
def __init__(self, ckpt_path, config_path, token_path, lexicon_path,
threshold, min_frames=5, max_frames=250, interval_frames=50,
score_beam=3, path_beam=20,
gpu=-1, is_jit_model=False,):
super().__init__()
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
with open(config_path, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
dataset_conf = configs['dataset_conf']
# feature related
self.sample_rate = 16000
self.wave_remained = np.array([])
self.num_mel_bins = dataset_conf[
'feature_extraction_conf']['num_mel_bins']
self.frame_length = dataset_conf[
'feature_extraction_conf']['frame_length'] # in ms
self.frame_shift = dataset_conf[
'feature_extraction_conf']['frame_shift'] # in ms
self.downsampling = dataset_conf.get('frame_skip', 1)
self.resolution = self.frame_shift / 1000 # in second
# fsmn splice operation
self.context_expansion = dataset_conf.get('context_expansion', False)
self.left_context = 0
self.right_context = 0
if self.context_expansion:
self.left_context = dataset_conf['context_expansion_conf']['left']
self.right_context = dataset_conf['context_expansion_conf']['right']
self.feature_remained = None
self.feats_ctx_offset = 0  # an offset remains after downsampling
# model related
if is_jit_model:
model = torch.jit.load(ckpt_path)
# For script model, only cpu is supported.
device = torch.device('cpu')
else:
# Init model from configs
model = init_model(configs['model'])
load_checkpoint(model, ckpt_path)
use_cuda = gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
self.device = device
self.model = model.to(device)
self.model.eval()
logging.info(f'model {ckpt_path} loaded.')
self.token_table = read_token(token_path)
logging.info(f'tokens {token_path} with '
f'{len(self.token_table)} units loaded.')
self.lexicon_table = read_lexicon(lexicon_path)
logging.info(f'lexicons {lexicon_path} with '
f'{len(self.lexicon_table)} units loaded.')
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
# decoding and detection related
self.score_beam = score_beam
self.path_beam = path_beam
self.threshold = threshold
self.min_frames = min_frames
self.max_frames = max_frames
self.interval_frames = interval_frames
self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
self.hit_score = 1.0
self.hit_keyword = None
self.activated = False
self.total_frames = 0 # frame offset, for absolute time
self.last_active_pos = -1 # the last frame of being activated
self.result = {}
def set_keywords(self, keywords):
# parse keyword tokens
assert keywords is not None, \
'at least one keyword is needed; ' \
'multiple keywords should be split with comma (,)'
keywords_str = keywords
keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {}
keywords_idxset = {0}
keywords_strset = {'<blk>'}
keywords_tokenmap = {'<blk>': 0}
for keyword in keywords_list:
strs, indexes = query_token_set(
keyword, self.token_table, self.lexicon_table)
keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = indexes
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
for i in indexes)
[keywords_strset.add(i) for i in strs]
[keywords_idxset.add(i) for i in indexes]
for txt, idx in zip(strs, indexes):
if keywords_tokenmap.get(txt, None) is None:
keywords_tokenmap[txt] = idx
token_print = ''
for txt, idx in keywords_tokenmap.items():
token_print += f'{txt}({idx}) '
logging.info(f'Token set is: {token_print}')
self.keywords_idxset = keywords_idxset
self.keywords_token = keywords_token
def accept_wave(self, wave):
assert isinstance(wave, bytes), \
"please make sure the input format is bytes(raw PCM)"
# convert bytes into float32
data = []
for i in range(0, len(wave), 2):
value = struct.unpack('<h', wave[i:i + 2])[0]
data.append(value)
# we don't divide by 32768.0,
# because kaldi.fbank accepts the original 16-bit range
wave = np.array(data)
wave = np.append(self.wave_remained, wave)
if wave.size < (self.frame_length * self.sample_rate / 1000) \
* self.right_context :
self.wave_remained = wave
return None
wave_tensor = torch.from_numpy(wave).float().to(self.device)
wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension
feats = kaldi.fbank(wave_tensor,
num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=0,
energy_floor=0.0,
sample_frequency=self.sample_rate)
# update wave remained
feat_len = len(feats)
frame_shift = int(self.frame_shift / 1000 * self.sample_rate)
self.wave_remained = wave[feat_len * frame_shift:]
if self.context_expansion:
assert feat_len > self.right_context, \
"make sure each chunk's feature length is larger than the right context."
# pad feats with remained feature from last chunk
if self.feature_remained is None: # first chunk
# pad first frame at the beginning,
# 'replicate' padding only supports the last dimension, so we transpose.
feats_pad = F.pad(
feats.T, (self.left_context, 0), mode='replicate').T
else:
feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats_pad.shape[0] - (
self.left_context + self.right_context)
ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
for i in range(ctx_frm):
feats_ctx[i] = torch.cat(
tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
# update feature remained, and feats
self.feature_remained = \
feats[-(self.left_context + self.right_context):]
feats = feats_ctx.to(self.device)
if self.downsampling > 1:
last_remainder = 0 if self.feats_ctx_offset == 0 \
else self.downsampling - self.feats_ctx_offset
remainder = (feats.size(0) + last_remainder) % self.downsampling
feats = feats[self.feats_ctx_offset::self.downsampling, :]
self.feats_ctx_offset = remainder \
if remainder == 0 else self.downsampling - remainder
return feats
def decode_keywords(self, t, probs):
absolute_time = t + self.total_frames
# search next_hyps depend on current probs and hyps.
next_hyps = ctc_prefix_beam_search(absolute_time,
probs,
self.cur_hyps,
self.keywords_idxset,
self.score_beam)
# update cur_hyps. note: the hyps are sorted by path score (pb + pnb),
# not by the keywords' probabilities.
cur_hyps = next_hyps[:self.path_beam]
self.cur_hyps = cur_hyps
def execute_detection(self, t):
absolute_time = t + self.total_frames
hit_keyword = None
start = 0
end = 0
# hyps for detection
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]
# detect keywords in decoding paths.
for one_hyp in hyps:
prefix_ids = one_hyp[0]
# path_score = one_hyp[1]
prefix_nodes = one_hyp[2]
assert len(prefix_ids) == len(prefix_nodes)
for word in self.keywords_token.keys():
lab = self.keywords_token[word]['token_id']
offset = is_sublist(prefix_ids, lab)
if offset != -1:
hit_keyword = word
start = prefix_nodes[offset]['frame']
end = prefix_nodes[offset + len(lab) - 1]['frame']
for idx in range(offset, offset + len(lab)):
self.hit_score *= prefix_nodes[idx]['prob']
break
if hit_keyword is not None:
self.hit_score = math.sqrt(self.hit_score)
break
duration = end - start
if hit_keyword is not None:
if self.hit_score >= self.threshold and \
self.min_frames <= duration <= self.max_frames \
and (self.last_active_pos == -1 or
end - self.last_active_pos >= self.interval_frames):
self.activated = True
self.last_active_pos = end
logging.info(
f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "
f"duration {duration}, score {self.hit_score}, Activated.")
elif self.last_active_pos > 0 and \
end - self.last_active_pos < self.interval_frames:
logging.info(
f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "
f"but interval {end-self.last_active_pos} "
f"is lower than {self.interval_frames}, Deactivated. ")
elif self.hit_score < self.threshold:
logging.info(
f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "
f"but {self.hit_score} "
f"is lower than {self.threshold}, Deactivated. ")
elif self.min_frames > duration or duration > self.max_frames:
logging.info(
f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "
f"but {duration} beyond range"
f"({self.min_frames}~{self.max_frames}), Deactivated. ")
self.result = {
"state": 1 if self.activated else 0,
"keyword": hit_keyword if self.activated else None,
"start": start * self.resolution if self.activated else None,
"end": end * self.resolution if self.activated else None,
"score": self.hit_score if self.activated else None
}
def forward(self, wave_chunk):
feature = self.accept_wave(wave_chunk)
if feature is None or feature.size(0) < 1:
return {}  # not enough features to produce a result
feature = feature.unsqueeze(0) # add a batch dimension
logits, self.in_cache = self.model(feature, self.in_cache)
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
probs = probs[0].cpu() # remove batch dimension
for (t, prob) in enumerate(probs):
t *= self.downsampling
self.decode_keywords(t, prob)
self.execute_detection(t)
if self.activated:
self.reset()
# since a chunk includes about 30 frames,
# once activated, we can skip the remaining frames.
# TODO: provide another way to update the result,
# to avoid self.result being cleared.
break
# update frame offset
self.total_frames += len(probs) * self.downsampling
return self.result
def reset(self):
self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
self.activated = False
self.hit_score = 1.0
def reset_all(self):
self.reset()
self.wave_remained = np.array([])
self.feature_remained = None
self.feats_ctx_offset = 0  # offset after downsampling.
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
self.total_frames = 0 # frame offset, for absolute time
self.last_active_pos = -1  # frame index of the last activation
self.result = {}
def demo():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
kws = KeyWordSpotter(args.checkpoint,
args.config,
args.token_file,
args.lexicon_file,
args.threshold,
args.min_frames,
args.max_frames,
args.interval_frames,
args.score_beam_size,
args.path_beam_size,
args.gpu,
args.jit_model)
# actually this could be done in __init__ method,
# we pull it outside for changing keywords more freely.
kws.set_keywords(args.keywords)
if args.wav_path:
# Caution: input WAV should be standard 16k sample rate, 16 bit, 1 channel
# In this demo we read the wave in a non-streaming fashion.
# with wave.open(args.wav_path, 'rb') as fin:
# assert fin.getnchannels() == 1
# wav = fin.readframes(fin.getnframes())
y, _ = librosa.load(args.wav_path, sr=16000, mono=True)
# NOTE: model supports 16k sample_rate
wav = (y * (1 << 15)).astype("int16").tobytes()
# We run inference every 0.3 seconds, in a streaming fashion.
interval = int(0.3 * 16000) * 2
for i in range(0, len(wav), interval):
chunk_wav = wav[i: min(i + interval, len(wav))]
result = kws.forward(chunk_wav)
print(result)
fout = None
if args.result_file:
fout = open(args.result_file, 'w', encoding='utf-8')
if args.wav_scp:
with open(args.wav_scp, 'r') as fscp:
for line in fscp:
line = line.strip().split()
assert len(line) == 2, \
f"The scp should be in kaldi format: " \
f"\"utt_name wav_path\", but got {line}"
utt_name, wav_path = line[0], line[1]
# with wave.open(args.wav_path, 'rb') as fin:
# assert fin.getnchannels() == 1
# wav = fin.readframes(fin.getnframes())
y, _ = librosa.load(wav_path, sr=16000, mono=True)
# NOTE: model supports 16k sample_rate
wav = (y * (1 << 15)).astype("int16").tobytes()
kws.reset_all()
activated = False
# We run inference every 0.3 seconds, in a streaming fashion.
interval = int(0.3 * 16000) * 2
for i in range(0, len(wav), interval):
chunk_wav = wav[i: min(i + interval, len(wav))]
result = kws.forward(chunk_wav)
if 'state' in result and result['state'] == 1:
activated = True
if fout:
hit_keyword = result['keyword']
hit_score = result['score']
fout.write('{} detected {} {:.3f}\n'.format(
utt_name, hit_keyword, hit_score))
if not activated:
if fout:
fout.write('{} rejected\n'.format(utt_name))
if fout:
fout.close()
if __name__ == '__main__':
demo()


@ -0,0 +1,363 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(954793264@qq.com)
# 2023 Jing Du(thuduj12@163.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import math
import torch
import yaml
from collections import defaultdict
from torch.utils.data import DataLoader
from wekws.dataset.dataset import Dataset
from wekws.model.kws_model import init_model
from wekws.utils.checkpoint import load_checkpoint
from tools.make_list import query_token_set, read_lexicon, read_token
def get_args():
parser = argparse.ArgumentParser(description='detect keywords with a CTC-trained model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--test_data', required=True, help='test data file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--batch_size',
default=1,
type=int,
help='batch size for inference')
parser.add_argument('--num_workers',
default=1,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--score_file',
required=True,
help='output score file')
parser.add_argument('--jit_model',
action='store_true',
default=False,
help='Load the checkpoint as a torchscript (jit) model.')
parser.add_argument('--keywords', type=str, default=None,
help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_beam_size',
default=3,
type=int,
help='The first prune beam, '
'filter out those frames with low scores.')
parser.add_argument('--path_beam_size',
default=20,
type=int,
help='The second prune beam, '
'keep only path_beam_size candidates.')
parser.add_argument('--threshold',
type=float,
default=0.0,
help='The threshold of kws. '
'If the ctc_search probability exceeds this value, '
'the keyword will be activated.')
parser.add_argument('--min_frames',
default=5,
type=int,
help='The min frames of keyword duration.')
parser.add_argument('--max_frames',
default=250,
type=int,
help='The max frames of keyword duration.')
args = parser.parse_args()
return args
def is_sublist(main_list, check_list):
if len(main_list) < len(check_list):
return -1
if len(main_list) == len(check_list):
return 0 if main_list == check_list else -1
# + 1 so a match ending at the last element is also found
for i in range(len(main_list) - len(check_list) + 1):
if main_list[i] == check_list[0]:
for j in range(len(check_list)):
if main_list[i + j] != check_list[j]:
break
else:
return i
else:
return -1
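# Illustrative examples (added for clarity, not part of the original code),
# assuming tuple inputs as produced by the decoder:
# is_sublist((1, 2, 3, 4), (2, 3)) -> 1 (keyword starts at index 1)
# is_sublist((1, 2, 3, 4), (3, 5)) -> -1 (no full match)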
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
test_conf = copy.deepcopy(configs['dataset_conf'])
test_conf['filter_conf']['max_length'] = 102400
test_conf['filter_conf']['min_length'] = 0
test_conf['speed_perturb'] = False
test_conf['spec_aug'] = False
test_conf['shuffle'] = False
test_conf['feature_extraction_conf']['dither'] = 0.0
test_conf['batch_conf']['batch_size'] = args.batch_size
downsampling_factor = test_conf.get('frame_skip', 1)
test_dataset = Dataset(args.test_data, test_conf)
test_data_loader = DataLoader(test_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
if args.jit_model:
model = torch.jit.load(args.checkpoint)
# For script model, only cpu is supported.
device = torch.device('cpu')
else:
# Init kws model from configs
model = init_model(configs['model'])
load_checkpoint(model, args.checkpoint)
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
model.eval()
score_abs_path = os.path.abspath(args.score_file)
token_table = read_token(args.token_file)
lexicon_table = read_lexicon(args.lexicon_file)
# parse keyword tokens
assert args.keywords is not None, 'at least one keyword is needed'
logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {}
keywords_idxset = {0}
keywords_strset = {'<blk>'}
keywords_tokenmap = {'<blk>': 0}
for keyword in keywords_list:
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = indexes
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
for i in indexes)
for s in strs:
keywords_strset.add(s)
for idx in indexes:
keywords_idxset.add(idx)
for txt, idx in zip(strs, indexes):
if keywords_tokenmap.get(txt, None) is None:
keywords_tokenmap[txt] = idx
token_print = ''
for txt, idx in keywords_tokenmap.items():
token_print += f'{txt}({idx}) '
logging.info(f'Token set is: {token_print}')
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(test_data_loader):
keys, feats, target, lengths, target_lengths = batch
feats = feats.to(device)
lengths = lengths.to(device)
logits, _ = model(feats)
logits = logits.softmax(2) # (batch_size, maxlen, vocab_size)
logits = logits.cpu()
for i in range(len(keys)):
key = keys[i]
score = logits[i][:lengths[i]]
# hyps = ctc_prefix_beam_search(score, lengths[i],
# keywords_idxset)
maxlen = score.size(0)
ctc_probs = score
cur_hyps = [(tuple(), (1.0, 0.0, []))]
hit_keyword = None
activated = False
hit_score = 1.0
start = 0
end = 0
# 2. CTC beam search step by step
for t in range(0, maxlen):
probs = ctc_probs[t] # (vocab_size,)
t *= downsampling_factor # the real time
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
# 2.1 First beam prune: select topk best
top_k_probs, top_k_index = probs.topk(args.score_beam_size)
# filter prob score that is too small
filter_probs = []
filter_index = []
for prob, idx in zip(
top_k_probs.tolist(), top_k_index.tolist()):
if keywords_idxset is not None:
if prob > 0.05 and idx in keywords_idxset:
filter_probs.append(prob)
filter_index.append(idx)
else:
if prob > 0.05:
filter_probs.append(prob)
filter_index.append(idx)
if len(filter_index) == 0:
continue
for s in filter_index:
ps = probs[s].item()
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pb = n_pb + pb * ps + pnb * ps
nodes = cur_nodes.copy()
next_hyps[prefix] = (n_pb, n_pnb, nodes)
elif s == last:
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
# Update *ss -> *s;
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pnb = n_pnb + pnb * ps
nodes = cur_nodes.copy()
# update frame and prob
if ps > nodes[-1]['prob']:
nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t
next_hyps[prefix] = (n_pb, n_pnb, nodes)
if not math.isclose(pb, 0.0, abs_tol=0.000001):
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s,)
n_pb, n_pnb, nodes = next_hyps[n_prefix]
n_pnb = n_pnb + pb * ps
nodes = cur_nodes.copy()
nodes.append(dict(
token=s, frame=t, prob=ps))
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
else:
n_prefix = prefix + (s,)
n_pb, n_pnb, nodes = next_hyps[n_prefix]
if nodes:
# update frame and prob
if ps > nodes[-1]['prob']:
# nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t
# avoid changing other beams that share this node.
nodes.pop()
nodes.append(dict(
token=s, frame=t, prob=ps))
else:
nodes = cur_nodes.copy()
nodes.append(dict(
token=s, frame=t, prob=ps))
n_pnb = n_pnb + pb * ps + pnb * ps
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: (x[1][0] + x[1][1]), reverse=True)
cur_hyps = next_hyps[:args.path_beam_size]
hyps = [(y[0], y[1][0] + y[1][1], y[1][2])
for y in cur_hyps]
for one_hyp in hyps:
prefix_ids = one_hyp[0]
# path_score = one_hyp[1]
prefix_nodes = one_hyp[2]
assert len(prefix_ids) == len(prefix_nodes)
for word in keywords_token.keys():
lab = keywords_token[word]['token_id']
offset = is_sublist(prefix_ids, lab)
if offset != -1:
hit_keyword = word
start = prefix_nodes[offset]['frame']
end = prefix_nodes[
offset + len(lab) - 1]['frame']
for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob']
break
if hit_keyword is not None:
hit_score = math.sqrt(hit_score)
break
duration = end - start
if hit_keyword is not None:
if hit_score >= args.threshold and \
args.min_frames <= duration <= args.max_frames:
activated = True
fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score))
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"duration {duration}, s"
f"core {hit_score} Activated.")
# clear the ctc_prefix buffer, and hit_keyword
cur_hyps = [(tuple(), (1.0, 0.0, []))]
hit_keyword = None
hit_score = 1.0
elif hit_score < args.threshold:
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"but {hit_score} less than "
f"{args.threshold}, Deactivated. ")
elif args.min_frames > duration \
or duration > args.max_frames:
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"but {duration} beyond "
f"range({args.min_frames}~{args.max_frames}), "
f"Deactivated. ")
if not activated:
fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")
if batch_idx % 10 == 0:
print('Progress batch {}'.format(batch_idx))
sys.stdout.flush()
if __name__ == '__main__':
main()


@ -134,7 +134,8 @@ def main():
output_dim = args.num_keywords
# Write model_dir/config.yaml for inference and export
configs['model']['input_dim'] = input_dim
if 'input_dim' not in configs['model']:
configs['model']['input_dim'] = input_dim
configs['model']['output_dim'] = output_dim
if args.cmvn_file is not None:
configs['model']['cmvn'] = {}
@ -156,8 +157,16 @@ def main():
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
if rank == 0:
script_model = torch.jit.script(model)
script_model.save(os.path.join(args.model_dir, 'init.zip'))
pass
# TODO: for now the streaming FSMN does not support export to TorchScript,
# TODO: because there is an nn.Sequential with Tuple input
# in the current FSMN modules.
# the issue is in https://stackoverflow.com/questions/75714299/
# pytorch-jit-script-error-when-sequential-container-
# takes-a-tuple-input/76553450#76553450
# script_model = torch.jit.script(model)
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor()
# If specify checkpoint, load some info from checkpoint
if args.checkpoint is not None:


@ -162,6 +162,16 @@ def Dataset(data_list_file, conf,
spec_aug_conf = conf.get('spec_aug_conf', {})
dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
context_expansion = conf.get('context_expansion', False)
if context_expansion:
context_expansion_conf = conf.get('context_expansion_conf', {})
dataset = Processor(dataset, processor.context_expansion,
**context_expansion_conf)
frame_skip = conf.get('frame_skip', 1)
if frame_skip > 1:
dataset = Processor(dataset, processor.frame_skip, frame_skip)
if shuffle:
shuffle_conf = conf.get('shuffle_conf', {})
dataset = Processor(dataset, processor.shuffle, **shuffle_conf)


@ -263,6 +263,51 @@ def shuffle(data, shuffle_size=1000):
for x in buf:
yield x
def context_expansion(data, left=1, right=1):
""" expand left and right frames
Args:
data: Iterable[{key, feat, label}]
left (int): feature left context frames
right (int): feature right context frames
Returns:
data: Iterable[{key, feat, label}]
"""
for sample in data:
index = 0
feats = sample['feat']
ctx_dim = feats.shape[0]
ctx_frm = feats.shape[1] * (left + right + 1)
feats_ctx = torch.zeros(ctx_dim, ctx_frm, dtype=torch.float32)
for lag in range(-left, right + 1):
feats_ctx[:, index:index + feats.shape[1]] = torch.roll(
feats, -lag, 0)
index = index + feats.shape[1]
# replication pad left margin
for idx in range(left):
for cpx in range(left - idx):
feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1)
* feats.shape[1]] = feats_ctx[left, :feats.shape[1]]
feats_ctx = feats_ctx[:feats_ctx.shape[0] - right]
sample['feat'] = feats_ctx
yield sample
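# Illustrative shape note (added for clarity, not part of the original
# code): with left=1, right=1 and sample['feat'] of shape (T, D),
# torch.roll stacks the previous, current and next frame side by side,
# so the output is (T - right, 3 * D); the first `left` rows are
# replication-padded copies of the earliest valid context.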
def frame_skip(data, skip_rate=1):
""" skip frame
Args:
data: Iterable[{key, feat, label}]
skip_rate (int): take every N-frames for model input
Returns:
data: Iterable[{key, feat, label}]
"""
for sample in data:
feats_skip = sample['feat'][::skip_rate, :]
sample['feat'] = feats_skip
yield sample
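# Illustrative example (added for clarity, not part of the original code):
# with skip_rate=3, a hypothetical feat of shape (7, D) keeps rows 0, 3
# and 6, yielding shape (3, D); this should match the 'frame_skip' value
# the streaming detector uses as its downsampling factor.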
def batch(data, batch_size=16):
""" Static batch the data by `batch_size`
@ -302,12 +347,24 @@ def padding(data):
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['key'] for i in order]
sorted_labels = torch.tensor([sample[i]['label'] for i in order],
dtype=torch.int64)
padded_feats = pad_sequence(sorted_feats,
batch_first=True,
padding_value=0)
yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)
if isinstance(sample[0]['label'], int):
padded_labels = torch.tensor([sample[i]['label'] for i in order],
dtype=torch.int32)
label_lengths = torch.tensor([1 for i in order],
dtype=torch.int32)
else:
sorted_labels = [
torch.tensor(sample[i]['label'], dtype=torch.int32) for i in order
]
label_lengths = torch.tensor([len(sample[i]['label']) for i in order],
dtype=torch.int32)
padded_labels = pad_sequence(
sorted_labels, batch_first=True, padding_value=-1)
yield (sorted_keys, padded_feats, padded_labels, feats_lengths, label_lengths)
def add_reverb(data, reverb_source, aug_prob):
@ -320,6 +377,8 @@ def add_reverb(data, reverb_source, aug_prob):
rir_io = io.BytesIO(rir_data)
_, rir_audio = wavfile.read(rir_io)
rir_audio = rir_audio.astype(np.float32)
if len(rir_audio.shape) > 1:
rir_audio = rir_audio[:, 0]
rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2))
out_audio = signal.convolve(audio, rir_audio,
mode='full')[:audio_len]
@ -348,6 +407,8 @@ def add_noise(data, noise_source, aug_prob):
snr_range = [0, 15]
_, noise_audio = wavfile.read(io.BytesIO(noise_data))
noise_audio = noise_audio.astype(np.float32)
if len(noise_audio.shape) > 1:
noise_audio = noise_audio[:, 0]
if noise_audio.shape[0] > audio_len:
start = random.randint(0, noise_audio.shape[0] - audio_len)
noise_audio = noise_audio[start:start + audio_len]

wekws/model/fsmn.py (new file)

@ -0,0 +1,558 @@
'''
FSMN implementation.
Copyright: 2022-03-09 yueyue.nyy
2023 Jing Du
'''
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def toKaldiMatrix(np_mat):
np.set_printoptions(threshold=np.inf, linewidth=np.nan)
out_str = str(np_mat)
out_str = out_str.replace('[', '')
out_str = out_str.replace(']', '')
return '[ %s ]\n' % out_str
def printTensor(torch_tensor):
re_str = ''
x = torch_tensor.detach().squeeze().numpy()
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
print(re_str)
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self,
input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple):
input, in_cache = input
else:
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
output = self.quant(input)
output = self.linear(output)
output = self.dequant(output)
return (output, in_cache)
def to_kaldi_net(self):
re_str = ''
re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
self.input_dim)
re_str += '<LearnRateCoef> 1\n'
linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
return re_str
def to_pytorch_net(self, fread):
linear_line = fread.readline()
linear_split = linear_line.strip().split()
assert len(linear_split) == 3
assert linear_split[0] == '<LinearTransform>'
self.output_dim = int(linear_split[1])
self.input_dim = int(linear_split[2])
learn_rate_line = fread.readline()
assert learn_rate_line.find('LearnRateCoef') != -1
self.linear.reset_parameters()
# linear_weights = self.state_dict()['linear.weight']
# print(linear_weights.shape)
new_weights = torch.zeros((self.output_dim, self.input_dim),
dtype=torch.float32)
for i in range(self.output_dim):
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.input_dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_weights[i, :] = cols
self.linear.weight.data = new_weights
class AffineTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self,
input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple):
input, in_cache = input
else:
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
output = self.quant(input)
output = self.linear(output)
output = self.dequant(output)
return (output, in_cache)
def to_kaldi_net(self):
re_str = ''
re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
self.input_dim)
re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += toKaldiMatrix(x)
linear_bias = self.state_dict()['linear.bias']
x = linear_bias.squeeze().numpy()
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
return re_str
def to_pytorch_net(self, fread):
affine_line = fread.readline()
affine_split = affine_line.strip().split()
assert len(affine_split) == 3
assert affine_split[0] == '<AffineTransform>'
self.output_dim = int(affine_split[1])
self.input_dim = int(affine_split[2])
print('AffineTransform output/input dim: %d %d' %
(self.output_dim, self.input_dim))
learn_rate_line = fread.readline()
assert learn_rate_line.find('LearnRateCoef') != -1
# linear_weights = self.state_dict()['linear.weight']
# print(linear_weights.shape)
self.linear.reset_parameters()
new_weights = torch.zeros((self.output_dim, self.input_dim),
dtype=torch.float32)
for i in range(self.output_dim):
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.input_dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_weights[i, :] = cols
self.linear.weight.data = new_weights
# linear_bias = self.state_dict()['linear.bias']
# print(linear_bias.shape)
bias_line = fread.readline()
splits = bias_line.strip().strip('[]').strip().split()
assert len(splits) == self.output_dim
new_bias = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
self.linear.bias.data = new_bias
class FSMNBlock(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
lorder=None,
rorder=None,
lstride=1,
rstride=1,
):
super(FSMNBlock, self).__init__()
self.dim = input_dim
if lorder is None:
return
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.conv_left = nn.Conv2d(
self.dim,
self.dim, [lorder, 1],
dilation=[lstride, 1],
groups=self.dim,
bias=False)
if rorder > 0:
self.conv_right = nn.Conv2d(
self.dim,
self.dim, [rorder, 1],
dilation=[rstride, 1],
groups=self.dim,
bias=False)
else:
self.conv_right = None
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self,
input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple):
input, in_cache = input
else:
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1)
if in_cache is None or len(in_cache) == 0:
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride
+ self.rorder * self.rstride, 0])
else:
in_cache = in_cache.to(x_per.device)
x_pad = torch.cat((in_cache, x_per), dim=2)
in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride
+ self.rorder * self.rstride):, :]
y_left = x_pad[:, :, :-self.rorder * self.rstride, :]
y_left = self.quant(y_left)
y_left = self.conv_left(y_left)
y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride: -self.rorder *
self.rstride, :] + y_left
if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(
x_per.size(2) + self.rorder * self.rstride):, :]
y_right = y_right[:, :, self.rstride:, :]
y_right = self.quant(y_right)
y_right = self.conv_right(y_right)
y_right = self.dequant(y_right)
out += y_right
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
return (output, in_cache)
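# Illustrative cache note (added for clarity, not part of the original
# code): the in_cache returned above has shape
# (B, D, (lorder - 1) * lstride + rorder * rstride, 1),
# i.e. exactly the left/right context the next chunk needs to reproduce
# the non-streaming padding.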
def to_kaldi_net(self):
re_str = ''
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \
'<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride)
# print(self.conv_left.weight,self.conv_right.weight)
lfiters = self.state_dict()['conv_left.weight']
x = np.flipud(lfiters.squeeze().numpy().T)
re_str += toKaldiMatrix(x)
if self.conv_right is not None:
rfiters = self.state_dict()['conv_right.weight']
x = (rfiters.squeeze().numpy().T)
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
return re_str
def to_pytorch_net(self, fread):
fsmn_line = fread.readline()
fsmn_split = fsmn_line.strip().split()
assert len(fsmn_split) == 3
assert fsmn_split[0] == '<Fsmn>'
self.dim = int(fsmn_split[1])
params_line = fread.readline()
params_split = params_line.strip().strip('[]').strip().split()
assert len(params_split) == 12
assert params_split[0] == '<LearnRateCoef>'
assert params_split[2] == '<LOrder>'
self.lorder = int(params_split[3])
assert params_split[4] == '<ROrder>'
self.rorder = int(params_split[5])
assert params_split[6] == '<LStride>'
self.lstride = int(params_split[7])
assert params_split[8] == '<RStride>'
self.rstride = int(params_split[9])
assert params_split[10] == '<MaxNorm>'
# lfilters = self.state_dict()['conv_left.weight']
# print(lfilters.shape)
print('read conv_left weight')
new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1),
dtype=torch.float32)
for i in range(self.lorder):
print('read conv_left weight -- %d' % i)
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols
new_lfilters = torch.transpose(new_lfilters, 0, 2)
# print(new_lfilters.shape)
self.conv_left.reset_parameters()
self.conv_left.weight.data = new_lfilters
# print(self.conv_left.weight.shape)
if self.rorder > 0:
# rfilters = self.state_dict()['conv_right.weight']
# print(rfilters.shape)
print('read conv_right weight')
new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1),
dtype=torch.float32)
line = fread.readline()
for i in range(self.rorder):
print('read conv_right weight -- %d' % i)
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_rfilters[i, 0, :, 0] = cols
new_rfilters = torch.transpose(new_rfilters, 0, 2)
# print(new_rfilters.shape)
self.conv_right.reset_parameters()
self.conv_right.weight.data = new_rfilters
# print(self.conv_right.weight.shape)
class RectifiedLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
def forward(self,
input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple):
input, in_cache = input
else:
in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
out = self.relu(input)
# out = self.dropout(out)
return (out, in_cache)
def to_kaldi_net(self):
re_str = ''
re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
# re_str += '<!EndOfComponent>\n'
return re_str
# re_str = ''
# re_str += '<ParametricRelu> %d %d\n' % (self.dim, self.dim)
# re_str += '<AlphaLearnRateCoef> 0 <BetaLearnRateCoef> 0\n'
# re_str += toKaldiMatrix(np.ones((self.dim), dtype = 'int32'))
# re_str += toKaldiMatrix(np.zeros((self.dim), dtype = 'int32'))
# re_str += '<!EndOfComponent>\n'
# return re_str
def to_pytorch_net(self, fread):
line = fread.readline()
splits = line.strip().split()
assert len(splits) == 3
assert splits[0] == '<RectifiedLinear>'
assert int(splits[1]) == int(splits[2])
assert int(splits[1]) == self.dim
self.dim = int(splits[1])
def _build_repeats(
fsmn_layers: int,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride=1,
rstride=1,
):
repeats = [
nn.Sequential(
LinearTransform(linear_dim, proj_dim),
FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride),
AffineTransform(proj_dim, linear_dim),
RectifiedLinear(linear_dim, linear_dim))
for i in range(fsmn_layers)
]
return nn.Sequential(*repeats)
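# Illustrative structure (added for clarity, not part of the original
# code): each repeat is the FSMN sandwich
# LinearTransform(linear_dim -> proj_dim, no bias)
# -> FSMNBlock (depthwise memory over proj_dim)
# -> AffineTransform(proj_dim -> linear_dim)
# -> RectifiedLinear
# and every module forwards its (tensor, cache) tuple to the next one.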
class FSMN(nn.Module):
def __init__(
self,
input_dim: int,
input_affine_dim: int,
fsmn_layers: int,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride: int,
rstride: int,
output_affine_dim: int,
output_dim: int,
):
"""
Args:
input_dim: input dimension
input_affine_dim: input affine layer dimension
fsmn_layers: no. of fsmn units
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
lstride: fsmn left stride
rstride: fsmn right stride
output_affine_dim: output affine layer dimension
output_dim: output dimension
"""
super(FSMN, self).__init__()
self.input_dim = input_dim
self.input_affine_dim = input_affine_dim
self.fsmn_layers = fsmn_layers
self.linear_dim = linear_dim
self.proj_dim = proj_dim
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
self.padding = (self.lorder - 1) * self.lstride \
+ self.rorder * self.rstride
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder,
rorder, lstride, rstride)
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
# self.softmax = nn.Softmax(dim = -1)
def fuse_modules(self):
pass
def forward(
self,
input: torch.Tensor,
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
"""
if in_cache is None or len(in_cache) == 0:
in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float)
for _ in range(len(self.fsmn))]
input = (input, in_cache)
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)
x3 = self.relu(x2)
# x4 = self.fsmn(x3)
x4, _ = x3
for layer, module in enumerate(self.fsmn):
x4, in_cache[layer] = module((x4, in_cache[layer]))
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
# x7 = self.softmax(x6)
x7, _ = x6
# return x7, None
return x7, in_cache
def to_kaldi_net(self):
re_str = ''
re_str += '<Nnet>\n'
re_str += self.in_linear1.to_kaldi_net()
re_str += self.in_linear2.to_kaldi_net()
re_str += self.relu.to_kaldi_net()
for fsmn in self.fsmn:
re_str += fsmn[0].to_kaldi_net()
re_str += fsmn[1].to_kaldi_net()
re_str += fsmn[2].to_kaldi_net()
re_str += fsmn[3].to_kaldi_net()
re_str += self.out_linear1.to_kaldi_net()
re_str += self.out_linear2.to_kaldi_net()
re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
# re_str += '<!EndOfComponent>\n'
re_str += '</Nnet>\n'
return re_str
def to_pytorch_net(self, kaldi_file):
with open(kaldi_file, 'r', encoding='utf8') as fread:
nnet_start_line = fread.readline()
assert nnet_start_line.strip() == '<Nnet>'
self.in_linear1.to_pytorch_net(fread)
self.in_linear2.to_pytorch_net(fread)
self.relu.to_pytorch_net(fread)
for fsmn in self.fsmn:
fsmn[0].to_pytorch_net(fread)
fsmn[1].to_pytorch_net(fread)
fsmn[2].to_pytorch_net(fread)
fsmn[3].to_pytorch_net(fread)
self.out_linear1.to_pytorch_net(fread)
self.out_linear2.to_pytorch_net(fread)
softmax_line = fread.readline()
softmax_split = softmax_line.strip().split()
assert softmax_split[0].strip() == '<Softmax>'
assert int(softmax_split[1]) == self.output_dim
assert int(softmax_split[2]) == self.output_dim
# '<!EndOfComponent>\n'
nnet_end_line = fread.readline()
assert nnet_end_line.strip() == '</Nnet>'
if __name__ == '__main__':
fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
print(fsmn)
num_params = sum(p.numel() for p in fsmn.parameters())
print('the number of model params: {}'.format(num_params))
x = torch.zeros(128, 200, 400) # batch-size * time * dim
y, _ = fsmn(x) # batch-size * time * dim
print('input shape: {}'.format(x.shape))
print('output shape: {}'.format(y.shape))
print(fsmn.to_kaldi_net())


@ -1,4 +1,5 @@
# Copyright (c) 2021 Binbin Zhang
# 2023 Jing Du
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -25,14 +26,15 @@ from wekws.model.subsampling import (LinearSubsampling1, Conv1dSubsampling1,
NoSubsampling)
from wekws.model.tcn import TCN, CnnBlock, DsCnnBlock
from wekws.model.mdtc import MDTC
from wekws.utils.cmvn import load_cmvn
from wekws.utils.cmvn import load_cmvn, load_kaldi_cmvn
from wekws.model.fsmn import FSMN
class KWSModel(nn.Module):
"""Our model consists of four parts:
1. global_cmvn: Optional, (idim, idim)
2. preprocessing: feature dimension projection, (idim, hdim)
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
3. backbone: backbone of the whole network, (hdim, hdim)
4. classifier: output layer or classifier of KWS model, (hdim, odim)
5. activation:
nn.Sigmoid for wakeup word
@ -72,6 +74,20 @@ class KWSModel(nn.Module):
x = self.activation(x)
return x, out_cache
def forward_softmax(self,
x: torch.Tensor,
in_cache: torch.Tensor = torch.zeros(
0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None:
x = self.global_cmvn(x)
x = self.preprocessing(x)
x, out_cache = self.backbone(x, in_cache)
x = self.classifier(x)
x = self.activation(x)
x = x.softmax(2)
return x, out_cache
def fuse_modules(self):
self.preprocessing.fuse_modules()
self.backbone.fuse_modules()
@ -80,7 +96,10 @@ class KWSModel(nn.Module):
def init_model(configs):
cmvn = configs.get('cmvn', {})
if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None:
mean, istd = load_cmvn(cmvn['cmvn_file'])
if "kaldi" in cmvn['cmvn_file']:
mean, istd = load_kaldi_cmvn(cmvn['cmvn_file'])
else:
mean, istd = load_cmvn(cmvn['cmvn_file'])
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(),
torch.from_numpy(istd).float(),
@ -135,6 +154,20 @@ def init_model(configs):
hidden_dim,
kernel_size,
causal=causal)
elif backbone_type == 'fsmn':
input_affine_dim = configs['backbone']['input_affine_dim']
num_layers = configs['backbone']['num_layers']
linear_dim = configs['backbone']['linear_dim']
proj_dim = configs['backbone']['proj_dim']
left_order = configs['backbone']['left_order']
right_order = configs['backbone']['right_order']
left_stride = configs['backbone']['left_stride']
right_stride = configs['backbone']['right_stride']
output_affine_dim = configs['backbone']['output_affine_dim']
backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim,
proj_dim, left_order, right_order, left_stride,
right_stride, output_affine_dim, output_dim)
else:
print('Unknown body type {}'.format(backbone_type))
sys.exit(1)
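# Illustrative config sketch (an assumption, not taken from this commit):
# a 'backbone' section selecting the fsmn branch above could look like
# backbone:
# type: fsmn
# input_affine_dim: 140
# num_layers: 4
# linear_dim: 250
# proj_dim: 128
# left_order: 10
# right_order: 2
# left_stride: 1
# right_stride: 1
# output_affine_dim: 140
# (values borrowed from the fsmn.py __main__ demo above).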
@ -154,6 +187,8 @@ def init_model(configs):
# last means we use the last frame to do backpropagation, so the model
# can be inferred in a streaming fashion
classifier = LastClassifier(classifier_base)
elif classifier_type == 'identity':
classifier = nn.Identity()
else:
print('Unknown classifier type {}'.format(classifier_type))
sys.exit(1)
@ -162,6 +197,17 @@ def init_model(configs):
classifier = LinearClassifier(hidden_dim, output_dim)
activation = nn.Sigmoid()
# Here we add a possible "activation_type",
# so one can choose another activation function.
# We use nn.Identity for CTC loss.
if "activation" in configs:
activation_type = configs["activation"]["type"]
if activation_type == 'identity':
activation = nn.Identity()
else:
print('Unknown activation type {}'.format(activation_type))
sys.exit(1)
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone, classifier, activation)
return kws_model


@ -1,4 +1,5 @@
# Copyright (c) 2021 Binbin Zhang
# 2023 Jing Du
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,7 +14,11 @@
# limitations under the License.
import torch
import math
import sys
import torch.nn.functional as F
from collections import defaultdict
from typing import List, Tuple
from wekws.utils.mask import padding_mask
@ -93,6 +98,65 @@ def acc_frame(
correct = pred.eq(target.long().view_as(pred)).sum().item()
return correct * 100.0 / logits.size(0)
def acc_utterance(logits: torch.Tensor, target: torch.Tensor,
logits_length: torch.Tensor, target_length: torch.Tensor):
if logits is None:
return 0
logits = logits.softmax(2)  # (B, maxlen, vocab_size)
logits = logits.cpu()
target = target.cpu()
total_word = 0
total_ins = 0
total_sub = 0
total_del = 0
calculator = Calculator()
for i in range(logits.size(0)):
score = logits[i][:logits_length[i]]
hyps = ctc_prefix_beam_search(score, logits_length[i], None, 3, 5)
lab = [str(item) for item in target[i][:target_length[i]].tolist()]
rec = []
if len(hyps) > 0:
rec = [str(item) for item in hyps[0][0]]
result = calculator.calculate(lab, rec)
# print(f'result:{result}')
if result['all'] != 0:
total_word += result['all']
total_ins += result['ins']
total_sub += result['sub']
total_del += result['del']
if total_word == 0:
return 0.0
return float(total_word - total_ins - total_sub
- total_del) * 100.0 / total_word
def ctc_loss(logits: torch.Tensor,
target: torch.Tensor,
logits_lengths: torch.Tensor,
target_lengths: torch.Tensor,
need_acc: bool = False):
""" CTC Loss
Args:
logits: (B, L, D), D is the number of tokens plus 1 (the blank)
target: (B, Lmax), padded target token ids
logits_lengths: (B)
target_lengths: (B)
Returns:
(float): loss of current batch
"""
acc = 0.0
if need_acc:
acc = acc_utterance(logits, target, logits_lengths, target_lengths)
# logits: (B, L, D) -> (L, B, D)
logits = logits.transpose(0, 1)
logits = logits.log_softmax(2)
loss = F.ctc_loss(
logits, target, logits_lengths, target_lengths, reduction='sum')
loss = loss / logits.size(1) # batch mean
return loss, acc
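# Illustrative shape walk-through (added for clarity, not part of the
# original code): with hypothetical B=2 utterances, L=100 frames and
# D=410 outputs (409 tokens plus blank id 0), logits (2, 100, 410) is
# transposed to (100, 2, 410) for F.ctc_loss, and the summed loss is
# divided by B=2 to give the batch mean.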
def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
""" Cross Entropy Loss
@ -114,12 +178,284 @@ def criterion(type: str,
logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
target_lengths: torch.Tensor = None,
min_duration: int = 0,
validation: bool = False):
if type == 'ce':
loss, acc = cross_entropy(logits, target)
return loss, acc
elif type == 'max_pooling':
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
return loss, acc
elif type == 'ctc':
loss, acc = ctc_loss(
logits, target, lengths, target_lengths, validation)
return loss, acc
else:
exit(1)
def ctc_prefix_beam_search(
logits: torch.Tensor,
logits_lengths: torch.Tensor,
keywords_tokenset: set = None,
score_beam_size: int = 3,
path_beam_size: int = 20,
) -> List[Tuple]:
""" CTC prefix beam search inner implementation
Args:
logits (torch.Tensor): (max_len, vocab_size), post-softmax probabilities
logits_lengths (torch.Tensor): (1, )
keywords_tokenset (set): token set for filtering score
score_beam_size (int): beam size for score
path_beam_size (int): beam size for path
Returns:
List[Tuple]: n-best hypotheses, each (prefix token ids, path score, per-token nodes)
"""
maxlen = logits.size(0)
# ctc_probs = logits.softmax(1) # (1, maxlen, vocab_size)
ctc_probs = logits
cur_hyps = [(tuple(), (1.0, 0.0, []))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
probs = ctc_probs[t] # (vocab_size,)
# key: prefix; value: (pb, pnb, nodes), default (0.0, 0.0, [])
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
# 2.1 First beam prune: select topk best
top_k_probs, top_k_index = probs.topk(
score_beam_size) # (score_beam_size,)
# filter prob score that is too small
filter_probs = []
filter_index = []
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
if keywords_tokenset is not None:
if prob > 0.05 and idx in keywords_tokenset:
filter_probs.append(prob)
filter_index.append(idx)
else:
if prob > 0.05:
filter_probs.append(prob)
filter_index.append(idx)
if len(filter_index) == 0:
continue
for s in filter_index:
ps = probs[s].item()
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pb = n_pb + pb * ps + pnb * ps
nodes = cur_nodes.copy()
next_hyps[prefix] = (n_pb, n_pnb, nodes)
elif s == last:
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
# Update *ss -> *s;
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pnb = n_pnb + pnb * ps
nodes = cur_nodes.copy()
if ps > nodes[-1]['prob']: # update frame and prob
nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t
next_hyps[prefix] = (n_pb, n_pnb, nodes)
if not math.isclose(pb, 0.0, abs_tol=0.000001):
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb, nodes = next_hyps[n_prefix]
n_pnb = n_pnb + pb * ps
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb, nodes = next_hyps[n_prefix]
if nodes:
if ps > nodes[-1]['prob']: # update frame and prob
# nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t
# avoid changing other beams that share this node.
nodes.pop()
nodes.append(dict(token=s, frame=t, prob=ps))
else:
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
n_pnb = n_pnb + pb * ps + pnb * ps
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)
cur_hyps = next_hyps[:path_beam_size]
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
return hyps
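# Illustrative usage sketch (added for clarity, not part of the original
# code); token ids and shapes below are hypothetical:
# probs = torch.randn(50, 410).softmax(1) # (T, vocab), post-softmax
# hyps = ctc_prefix_beam_search(probs, torch.tensor([50]),
# keywords_tokenset={0, 7, 9})
# best_prefix, best_score, nodes = hyps[0]
# each entry of `nodes` records a token id, its frame index and its
# probability, which callers use to derive keyword timestamps.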
class Calculator:
def __init__(self):
self.data = {}
self.space = []
self.cost = {}
self.cost['cor'] = 0
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec):
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab):
self.space.append([])
for row in self.space:
for element in row:
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec):
row.append({'dist': 0, 'error': 'non'})
for i in range(len(lab)):
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)):
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
for token in rec:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
# Computing edit distance
for i, lab_token in enumerate(lab):
for j, rec_token in enumerate(rec):
if i == 0 or j == 0:
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i - 1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist:
min_dist = dist
min_error = error
dist = self.space[i][j - 1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist:
min_dist = dist
min_error = error
if lab_token == rec_token:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
error = 'cor'
else:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist:
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {
'lab': [],
'rec': [],
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
i = len(lab) - 1
j = len(rec) - 1
while True:
if self.space[i][j]['error'] == 'cor': # correct
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
result['cor'] = result['cor'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub': # substitution
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
result['sub'] = result['sub'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del': # deletion
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
result['del'] = result['del'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, '')
i = i - 1
elif self.space[i][j]['error'] == 'ins': # insertion
if len(rec[j]) > 0:
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, '')
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non': # starting point
break
else: # shouldn't reach here
print(
'this should not happen, '
'i = {i} , j = {j} , error = {error}'
.format(i=i, j=j, error=self.space[i][j]['error']))
return result
def overall(self):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in data:
if token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self):
return list(self.data.keys())


@ -15,6 +15,7 @@
import json
import math
import re
import numpy as np
@ -42,3 +43,50 @@ def load_cmvn(json_cmvn_file):
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def load_kaldi_cmvn(cmvn_file):
""" Load the kaldi format cmvn stats file and no need to calculate
Args:
cmvn_file: cmvn stats file in kaldi format
Returns:
a numpy array of [means, vars]
"""
means = None
variance = None
copy_times = 1  # assume no <Splice> component unless one is found
with open(cmvn_file) as f:
all_lines = f.readlines()
for idx, line in enumerate(all_lines):
if line.find('AddShift') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
means_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
means_list = means_str.strip().split(' ')
means = [0 - float(s) for s in means_list]
assert len(means) == int(segs[1])
elif line.find('Rescale') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
vars_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
vars_list = vars_str.strip().split(' ')
variance = [float(s) for s in vars_list]
assert len(variance) == int(segs[1])
elif line.find('Splice') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
splice_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
splice_list = splice_str.strip().split(' ')
assert len(splice_list) * int(segs[2]) == int(segs[1])
copy_times = len(splice_list)
else:
continue
cmvn = np.array([means, variance])
cmvn = np.tile(cmvn, (1, copy_times))
return cmvn
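# Illustrative input sketch (an assumption about the expected file layout,
# not taken from this commit): a kaldi-nnet style stats file read here
# looks roughly like
# <Splice> 400 80
# [ -2 -1 0 1 2 ]
# <AddShift> 80 80
# [ 9.1 8.7 ... ]
# <Rescale> 80 80
# [ 0.31 0.28 ... ]
# the AddShift row is negated on load to recover the means, the Rescale
# row holds inverse standard deviations, and both are tiled copy_times=5
# times to match the spliced 400-dim input.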


@ -34,17 +34,20 @@ class Executor:
min_duration = args.get('min_duration', 0)
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths = batch
key, feats, target, feats_lengths, label_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
label_lengths = label_lengths.to(device)
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
logits, _ = model(feats)
loss_type = args.get('criterion', 'max_pooling')
loss, acc = criterion(loss_type, logits, target, feats_lengths,
min_duration)
target_lengths=label_lengths,
min_duration=min_duration,
validation=False)
optimizer.zero_grad()
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip)
@ -67,16 +70,20 @@ class Executor:
total_acc = 0.0
with torch.no_grad():
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths = batch
key, feats, target, feats_lengths, label_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
label_lengths = label_lengths.to(device)
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
logits, _ = model(feats)
loss, acc = criterion(args.get('criterion', 'max_pooling'),
logits, target, feats_lengths)
logits, target, feats_lengths,
target_lengths=label_lengths,
min_duration=0,
validation=True)
if torch.isfinite(loss):
num_seen_utts += num_utts
total_loss += loss.item() * num_utts