add ctcloss training scripts.
This commit is contained in:
parent
85350c38a8
commit
a2f8d0e39e
50
examples/hi_xiaowen/s0/conf/ds_tcn_ctc.yaml
Normal file
50
examples/hi_xiaowen/s0/conf/ds_tcn_ctc.yaml
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
dataset_conf:
|
||||||
|
filter_conf:
|
||||||
|
max_length: 2048
|
||||||
|
min_length: 0
|
||||||
|
resample_conf:
|
||||||
|
resample_rate: 16000
|
||||||
|
speed_perturb: false
|
||||||
|
feature_extraction_conf:
|
||||||
|
feature_type: 'fbank'
|
||||||
|
num_mel_bins: 40
|
||||||
|
frame_shift: 10
|
||||||
|
frame_length: 25
|
||||||
|
dither: 1.0
|
||||||
|
spec_aug: true
|
||||||
|
spec_aug_conf:
|
||||||
|
num_t_mask: 1
|
||||||
|
num_f_mask: 1
|
||||||
|
max_t: 20
|
||||||
|
max_f: 10
|
||||||
|
shuffle: true
|
||||||
|
shuffle_conf:
|
||||||
|
shuffle_size: 1500
|
||||||
|
batch_conf:
|
||||||
|
batch_size: 256
|
||||||
|
|
||||||
|
model:
|
||||||
|
hidden_dim: 256
|
||||||
|
preprocessing:
|
||||||
|
type: linear
|
||||||
|
backbone:
|
||||||
|
type: tcn
|
||||||
|
ds: true
|
||||||
|
num_layers: 4
|
||||||
|
kernel_size: 8
|
||||||
|
dropout: 0.1
|
||||||
|
activation:
|
||||||
|
type: identity
|
||||||
|
|
||||||
|
|
||||||
|
optim: adam
|
||||||
|
optim_conf:
|
||||||
|
lr: 0.001
|
||||||
|
weight_decay: 0.0001
|
||||||
|
|
||||||
|
training_config:
|
||||||
|
grad_clip: 5
|
||||||
|
max_epoch: 80
|
||||||
|
log_interval: 10
|
||||||
|
criterion: ctc
|
||||||
|
|
||||||
50
examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml
Normal file
50
examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
dataset_conf:
|
||||||
|
filter_conf:
|
||||||
|
max_length: 2048
|
||||||
|
min_length: 0
|
||||||
|
resample_conf:
|
||||||
|
resample_rate: 16000
|
||||||
|
speed_perturb: false
|
||||||
|
feature_extraction_conf:
|
||||||
|
feature_type: 'fbank'
|
||||||
|
num_mel_bins: 40
|
||||||
|
frame_shift: 10
|
||||||
|
frame_length: 25
|
||||||
|
dither: 1.0
|
||||||
|
spec_aug: true
|
||||||
|
spec_aug_conf:
|
||||||
|
num_t_mask: 1
|
||||||
|
num_f_mask: 1
|
||||||
|
max_t: 20
|
||||||
|
max_f: 10
|
||||||
|
shuffle: true
|
||||||
|
shuffle_conf:
|
||||||
|
shuffle_size: 1500
|
||||||
|
batch_conf:
|
||||||
|
batch_size: 200
|
||||||
|
|
||||||
|
model:
|
||||||
|
hidden_dim: 256
|
||||||
|
preprocessing:
|
||||||
|
type: linear
|
||||||
|
backbone:
|
||||||
|
type: tcn
|
||||||
|
ds: true
|
||||||
|
num_layers: 4
|
||||||
|
kernel_size: 8
|
||||||
|
dropout: 0.1
|
||||||
|
activation:
|
||||||
|
type: identity
|
||||||
|
|
||||||
|
|
||||||
|
optim: adam
|
||||||
|
optim_conf:
|
||||||
|
lr: 0.001
|
||||||
|
weight_decay: 0.0001
|
||||||
|
|
||||||
|
training_config:
|
||||||
|
grad_clip: 5
|
||||||
|
max_epoch: 30
|
||||||
|
log_interval: 100
|
||||||
|
criterion: ctc
|
||||||
|
|
||||||
@ -111,6 +111,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
--score_file $result_dir/score.txt \
|
--score_file $result_dir/score.txt \
|
||||||
--stats_file $result_dir/stats.${keyword}.txt
|
--stats_file $result_dir/stats.${keyword}.txt
|
||||||
done
|
done
|
||||||
|
|
||||||
|
# plot det curve
|
||||||
|
python wekws/bin/plot_det_curve.py \
|
||||||
|
--keywords_dict dict/words.txt \
|
||||||
|
--stats_dir $result_dir \
|
||||||
|
--figure_file $result_dir/det.png
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
198
examples/hi_xiaowen/s0/run_ctc.sh
Normal file
198
examples/hi_xiaowen/s0/run_ctc.sh
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
#!/bin/bash
# Copyright 2021  Binbin Zhang(binbzha@qq.com)
#           2023  dujing(thuduj12@163.com)
#
# Recipe for CTC-based keyword spotting on the "hi xiaowen" data.
# Usage: run_ctc.sh <stage> <stop_stage>
# Stages:
#   -3  download the data
#   -2  prepare wav.scp/text for train/dev/test
#   -1  fetch transcriptions, tokens.txt and lexicon.txt
#    0  compute CMVN and build data.list files
#    1  (optional) train a base model on user-provided ASR data
#    2  train the KWS model (from the base model when trainbase=true)
#    3  average checkpoints, score and compute DET statistics
#    4  export TorchScript and ONNX models

. ./path.sh

stage=$1
stop_stage=$2
num_keywords=2599

config=conf/ds_tcn_ctc.yaml
norm_mean=true
norm_var=true
gpus="4,5,6,7"

checkpoint=
dir=exp/ds_tcn_ctc_ft

num_average=30
score_checkpoint=$dir/avg_${num_average}.pt

download_dir=/data/megastore/Datasets/ASR/KWS/nihaoxiaowen # your data dir

. tools/parse_options.sh || exit 1;
window_shift=50

# Whether to train the base model. If set to true, the train+dev data must
# be placed in trainbase_dir.
trainbase=true
trainbase_dir=data/base
trainbase_config=conf/ds_tcn_ctc_base.yaml
trainbase_exp=exp/base
if $trainbase; then
  # fine-tune from the base model produced by stage 1
  checkpoint=$trainbase_exp/final.pt
fi

if [ ${stage} -le -3 ] && [ ${stop_stage} -ge -3 ]; then
  echo "Download and extract all datasets"
  local/mobvoi_data_download.sh --dl_dir $download_dir
fi


if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
  echo "Preparing datasets..."
  mkdir -p dict
  echo "<filler> -1" > dict/words.txt
  echo "Hi_Xiaowen 0" >> dict/words.txt
  echo "Nihao_Wenwen 1" >> dict/words.txt

  for folder in train dev test; do
    mkdir -p data/$folder
    # p = positive (keyword) utterances, n = negative utterances
    for prefix in p n; do
      mkdir -p data/${prefix}_$folder
      json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
      local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
        data/${prefix}_$folder
    done
    cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp
    cat data/p_$folder/text data/n_$folder/text > data/$folder/text
    rm -rf data/p_$folder data/n_$folder
  done
fi

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  # Here we use Paraformer Large(https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
  # to transcribe the negative wavs, and upload the transcription to modelscope.
  # Skip the clone when the directory already exists so that this stage
  # can be re-run (git clone refuses to clone into a non-empty directory).
  if [ ! -d mobvoi_kws_transcription ]; then
    git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git || exit 1
  fi
  for folder in train dev test; do
    # keep the original keyword-index labels next to the transcriptions
    if [ -f data/$folder/text ];then
      mv data/$folder/text data/$folder/text.label
    fi
    cp mobvoi_kws_transcription/$folder.text data/$folder/text
  done

  # and we also copy the tokens and lexicon that are used in
  # https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary
  cp mobvoi_kws_transcription/tokens.txt data/tokens.txt
  cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt

fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "Compute CMVN and Format datasets"
  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
    --in_scp data/train/wav.scp \
    --out_cmvn data/train/global_cmvn

  for x in train dev test; do
    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur

    # Here we use tokens.txt and lexicon.txt to convert txt into index
    tools/make_list.py data/$x/wav.scp data/$x/text \
      data/$x/wav.dur data/$x/data.list \
      --token_file data/tokens.txt \
      --lexicon_file data/lexicon.txt
  done
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ "$trainbase" == true ]; then
  for x in train dev ; do
    if [ ! -f $trainbase_dir/$x/wav.scp ] || [ ! -f $trainbase_dir/$x/text ]; then
      echo "If You Want to Train Base KWS-CTC Model, You Should Prepare ASR Data by Yourself."
      echo "The wav.scp and text in KALDI-format is Needed, You Should Put Them in $trainbase_dir/$x"
      # missing input data is an error: report it through the exit status
      exit 1
    fi
    if [ ! -f $trainbase_dir/$x/wav.dur ]; then
      tools/wav_to_duration.sh --nj 8 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
    fi

    # Here we use tokens.txt and lexicon.txt to convert txt into index
    if [ ! -f $trainbase_dir/$x/data.list ]; then
      tools/make_list.py $trainbase_dir/$x/wav.scp $trainbase_dir/$x/text \
        $trainbase_dir/$x/wav.dur $trainbase_dir/$x/data.list \
        --token_file data/tokens.txt \
        --lexicon_file data/lexicon.txt
    fi
  done

  echo "Start base training ..."
  mkdir -p $trainbase_exp
  cmvn_opts=
  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
  # one training process per GPU listed in $gpus
  num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
    wekws/bin/train.py --gpus $gpus \
      --config $trainbase_config \
      --train_data $trainbase_dir/train/data.list \
      --cv_data $trainbase_dir/dev/data.list \
      --model_dir $trainbase_exp \
      --num_workers 8 \
      --num_keywords $num_keywords \
      --min_duration 50 \
      --seed 666 \
      $cmvn_opts
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Start training ..."
  mkdir -p $dir
  cmvn_opts=
  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
  # one training process per GPU listed in $gpus
  num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
  # --checkpoint is only passed when $checkpoint is non-empty
  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
    wekws/bin/train.py --gpus $gpus \
      --config $config \
      --train_data data/train/data.list \
      --cv_data data/dev/data.list \
      --model_dir $dir \
      --num_workers 8 \
      --num_keywords $num_keywords \
      --min_duration 50 \
      --seed 666 \
      $cmvn_opts \
      ${checkpoint:+--checkpoint $checkpoint}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Do model average, Compute FRR/FAR ..."
  python wekws/bin/average_model.py \
    --dst_model $score_checkpoint \
    --src_path $dir  \
    --num ${num_average} \
    --val_best
  result_dir=$dir/test_$(basename $score_checkpoint)
  mkdir -p $result_dir
  python wekws/bin/score_ctc.py \
    --config $dir/config.yaml \
    --test_data data/test/data.list \
    --batch_size 256 \
    --checkpoint $score_checkpoint \
    --score_file $result_dir/score.txt \
    --num_workers 8 \
    --keywords 嗨小问,你好问问 \
    --token_file data/tokens.txt \
    --lexicon_file data/lexicon.txt

  python wekws/bin/compute_det_ctc.py \
    --keyword 嗨小问,你好问问 \
    --test_data data/test/data.list \
    --window_shift $window_shift \
    --step 0.001 \
    --score_file $result_dir/score.txt
fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  # exported models are named after the averaged checkpoint
  jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
  onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
  python wekws/bin/export_jit.py \
    --config $dir/config.yaml \
    --checkpoint $score_checkpoint \
    --jit_model $dir/$jit_model
  python wekws/bin/export_onnx.py \
    --config $dir/config.yaml \
    --checkpoint $score_checkpoint \
    --onnx_model $dir/$onnx_model
fi
|
||||||
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||||
|
# 2023 Jing Du(thuduj12@163.com)
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -14,8 +15,142 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse, logging
|
||||||
import json
|
import json, re
|
||||||
|
|
||||||
|
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
|
||||||
|
|
||||||
|
def split_mixed_label(input_str):
    """Split a mixed Chinese/English label into a list of tokens.

    The text is lower-cased first. A run of ASCII letters and the
    characters ``! ? , < > ( ) '`` is kept together as one token, so
    English words and tags such as ``<sil>`` or ``(noise)`` stay intact.
    Any other character becomes a token of its own. Spaces only separate
    tokens and are never returned as tokens.

    Args:
        input_str: the raw label text.

    Returns:
        The list of token strings, in order of appearance.
    """
    tokens = []
    # Strip surrounding spaces up front: otherwise a leading space falls
    # through to the single-character branch and becomes a bogus ' ' token.
    s = input_str.lower().strip(' ')
    while s:
        match = re.match(r'[A-Za-z!?,<>()\']+', s)
        word = match.group(0) if match is not None else s[0]
        tokens.append(word)
        # `word` is always a prefix of `s`, so slicing consumes exactly it.
        s = s[len(word):].strip(' ')
    return tokens
|
||||||
|
|
||||||
|
def query_token_set(txt, symbol_table, lexicon_table):
    """Tokenize ``txt`` and return the tokens and their indices as tuples.

    This is the tuple-returning twin of ``query_token_list``: the
    tokenization, the fallback rules and the logging are exactly the
    same, only the container type of the two results differs. Delegating
    avoids keeping two copies of the same rules in sync and avoids
    growing tuples by repeated concatenation.

    Args:
        txt: the label text to convert.
        symbol_table: dict mapping a token string to its index.
        lexicon_table: dict mapping a word to its list of tokens.

    Returns:
        ``(tokens_str, tokens_idx)``, both tuples: the token strings and
        the index looked up for each of them.
    """
    tokens_str, tokens_idx = query_token_list(txt, symbol_table,
                                              lexicon_table)
    return tuple(tokens_str), tuple(tokens_idx)
|
||||||
|
|
||||||
|
|
||||||
|
def query_token_list(txt, symbol_table, lexicon_table):
    """Convert a label text into token strings and token indices.

    The text is split with ``split_mixed_label``; each part is then mapped
    to one or more token strings:

    * ``!sil`` / ``(sil)`` / ``<sil>`` become ``'!sil'``;
    * ``<blk>`` / ``<blank>`` become ``'<blk>'``;
    * ``(noise)`` / ``noise)`` / ``(noise`` / ``<noise>`` become ``'<GBG>'``;
    * a part found in ``symbol_table`` is kept as is;
    * a part found in ``lexicon_table`` is replaced by its lexicon entries;
    * anything else has the characters of ``symbol_str`` removed and is
      then split into single characters.

    Each token string is then looked up in ``symbol_table``. A token that
    is missing falls back to ``'sil'`` (for ``'!sil'`` only), then to
    ``'<GBG>'``, and finally to ``'<blk>'``. A ``KeyError`` is raised if
    the ``'<blk>'`` fallback is needed but absent from ``symbol_table``.

    Args:
        txt: the label text to convert.
        symbol_table: dict mapping a token string to its index.
        lexicon_table: dict mapping a word to its list of tokens.

    Returns:
        ``(tokens_str, tokens_idx)``, two lists of equal length.
    """
    tokens_str = []
    for piece in split_mixed_label(txt):
        if piece in ('!sil', '(sil)', '<sil>'):
            tokens_str.append('!sil')
        elif piece in ('<blk>', '<blank>'):
            tokens_str.append('<blk>')
        elif piece in ('(noise)', 'noise)', '(noise', '<noise>'):
            tokens_str.append('<GBG>')
        elif piece in symbol_table:
            tokens_str.append(piece)
        elif piece in lexicon_table:
            tokens_str.extend(lexicon_table[piece])
        else:
            # case with symbols or meaningless english letter combination
            tokens_str.extend(re.sub(symbol_str, '', piece))

    tokens_idx = []
    for tok in tokens_str:
        if tok in symbol_table:
            tokens_idx.append(symbol_table[tok])
            continue

        if tok == '!sil':
            fallback = 'sil' if 'sil' in symbol_table else '<blk>'
            tokens_idx.append(symbol_table[fallback])
            continue

        if tok == '<GBG>':
            # '<GBG>' itself is not in the table here, so only blank is left
            tokens_idx.append(symbol_table['<blk>'])
            continue

        # unknown token: map it to garbage when available, else to blank
        if '<GBG>' in symbol_table:
            tokens_idx.append(symbol_table['<GBG>'])
            logging.info(
                f'\'{tok}\' is not in token set, replace with <GBG>')
        else:
            tokens_idx.append(symbol_table['<blk>'])
            logging.info(
                f'\'{tok}\' is not in token set, replace with <blk>')

    return tokens_str, tokens_idx
|
||||||
|
|
||||||
|
def read_token(token_file):
    """Load the token table from ``token_file``.

    Every non-blank line must hold exactly two whitespace-separated
    fields, ``<token> <index>``. The index stored in the table is the
    index found in the file minus one.

    Args:
        token_file: path of the UTF-8 encoded token file.

    Returns:
        dict mapping each token string to ``index - 1``.

    Raises:
        AssertionError: if a non-blank line does not have two fields.
    """
    tokens_table = {}
    # the ``with`` block closes the file, no explicit close() is needed
    with open(token_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.split()
            if not arr:
                # tolerate blank lines, e.g. an empty line at end of file
                continue
            assert len(arr) == 2, f'malformed token line: {line!r}'
            tokens_table[arr[0]] = int(arr[1]) - 1
    return tokens_table
|
||||||
|
|
||||||
|
|
||||||
|
def read_lexicon(lexicon_file):
    """Load the lexicon table from ``lexicon_file``.

    Every non-blank line holds a word followed by one or more tokens,
    separated by spaces and/or tabs. If a word appears on several lines,
    the last line wins.

    Args:
        lexicon_file: path of the UTF-8 encoded lexicon file.

    Returns:
        dict mapping each word to the list of its token strings.

    Raises:
        AssertionError: if a non-blank line has fewer than two fields.
    """
    lexicon_table = {}
    # the ``with`` block closes the file, no explicit close() is needed
    with open(lexicon_file, 'r', encoding='utf8') as fin:
        for line in fin:
            # split() without arguments already splits on tabs and spaces
            arr = line.split()
            if not arr:
                # tolerate blank lines, e.g. an empty line at end of file
                continue
            assert len(arr) >= 2, f'malformed lexicon line: {line!r}'
            lexicon_table[arr[0]] = arr[1:]
    return lexicon_table
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='')
|
parser = argparse.ArgumentParser(description='')
|
||||||
@ -23,6 +158,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('text_file', help='text file')
|
parser.add_argument('text_file', help='text file')
|
||||||
parser.add_argument('duration_file', help='duration file')
|
parser.add_argument('duration_file', help='duration file')
|
||||||
parser.add_argument('output_file', help='output list file')
|
parser.add_argument('output_file', help='output list file')
|
||||||
|
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
|
||||||
|
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
wav_table = {}
|
wav_table = {}
|
||||||
@ -39,16 +176,35 @@ if __name__ == '__main__':
|
|||||||
assert len(arr) == 2
|
assert len(arr) == 2
|
||||||
duration_table[arr[0]] = float(arr[1])
|
duration_table[arr[0]] = float(arr[1])
|
||||||
|
|
||||||
|
token_table = None
|
||||||
|
if args.token_file:
|
||||||
|
token_table = read_token(args.token_file)
|
||||||
|
lexicon_table = None
|
||||||
|
if args.lexicon_file:
|
||||||
|
lexicon_table = read_lexicon(args.lexicon_file)
|
||||||
|
|
||||||
with open(args.text_file, 'r', encoding='utf8') as fin, \
|
with open(args.text_file, 'r', encoding='utf8') as fin, \
|
||||||
open(args.output_file, 'w', encoding='utf8') as fout:
|
open(args.output_file, 'w', encoding='utf8') as fout:
|
||||||
for line in fin:
|
for line in fin:
|
||||||
arr = line.strip().split(maxsplit=1)
|
arr = line.strip().split(maxsplit=1)
|
||||||
key = arr[0]
|
key = arr[0]
|
||||||
|
tokens = None
|
||||||
|
if token_table!=None and lexicon_table!=None :
|
||||||
|
if len(arr) < 2: # for some utterence, no text
|
||||||
|
txt = [1] # the <blank>/sil is indexed by 1
|
||||||
|
tokens = ["sil"]
|
||||||
|
else:
|
||||||
|
tokens, txt = query_token_list(arr[1], token_table, lexicon_table)
|
||||||
|
else:
|
||||||
txt = int(arr[1])
|
txt = int(arr[1])
|
||||||
assert key in wav_table
|
assert key in wav_table
|
||||||
wav = wav_table[key]
|
wav = wav_table[key]
|
||||||
assert key in duration_table
|
assert key in duration_table
|
||||||
duration = duration_table[key]
|
duration = duration_table[key]
|
||||||
|
if tokens == None:
|
||||||
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
line = dict(key=key, txt=txt, duration=duration, wav=wav)
|
||||||
|
else:
|
||||||
|
line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav)
|
||||||
|
|
||||||
json_line = json.dumps(line, ensure_ascii=False)
|
json_line = json.dumps(line, ensure_ascii=False)
|
||||||
fout.write(json_line + '\n')
|
fout.write(json_line + '\n')
|
||||||
|
|||||||
241
wekws/bin/compute_det_ctc.py
Normal file
241
wekws/bin/compute_det_ctc.py
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||||
|
# 2022 Shaoqing Yu(954793264@qq.com)
|
||||||
|
# 2023 Jing Du(thuduj12@163.com)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse, logging, glob
|
||||||
|
import json, re, os, numpy as np
|
||||||
|
import matplotlib.font_manager as fm
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
font = fm.FontProperties(size=15)
|
||||||
|
|
||||||
|
def split_mixed_label(input_str):
    """Break a lower-cased copy of ``input_str`` into label tokens.

    A run of ASCII letters and the characters ``! ? , < > ( ) '`` forms
    one token; any other character is a token of its own. Spaces
    surrounding the remaining text are dropped after each token is taken.

    Args:
        input_str: the text to split.

    Returns:
        The list of tokens in order of appearance.
    """
    pieces = []
    remaining = input_str.lower()
    while remaining:
        hit = re.match(r'[A-Za-z!?,<>()\']+', remaining)
        if hit is None:
            piece = remaining[0:1]
        else:
            piece = hit.group(0)
        pieces.append(piece)
        # the piece is always the head of the remaining text
        remaining = remaining[len(piece):].strip(' ')
    return pieces
|
||||||
|
|
||||||
|
|
||||||
|
def space_mixed_label(input_str):
    """Return ``input_str`` tokenized and re-joined with single spaces.

    The tokens come from ``split_mixed_label``; leading and trailing
    whitespace of the joined string is removed.
    """
    return ' '.join(split_mixed_label(input_str)).strip()
|
||||||
|
|
||||||
|
def load_label_and_score(keywords_list, label_file, score_file):
    """Sort every labelled utterance into keyword/filler tables per keyword.

    Args:
        keywords_list: list of keyword strings.
        label_file: path of a file with one JSON object per line; each
            object must contain the keys ``key``, ``wav``, ``tok`` and
            ``duration``.
        score_file: path of a whitespace-separated score file whose lines
            start with ``<key> <status>``; when ``<status>`` is
            ``detected`` the next two fields are the keyword and its
            confidence. Only the first line of each key is used.

    Returns:
        dict keyed by the space-separated form of each keyword. Each value
        is a dict with ``keyword_table`` and ``filler_table`` (utterance
        key -> confidence, ``-1.0`` when the utterance was not detected as
        this keyword) and ``keyword_duration`` / ``filler_duration`` (sum
        of the ``duration`` values of the utterances in each table).

    Raises:
        AssertionError: if a label object lacks a required key or its
            utterance has no entry in the score file.
    """
    # score_table: {utt_id: {'kw': detected keyword, 'confi': confidence}}
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        for line in fin:
            fields = line.strip().split()
            utt_id, status = fields[0], fields[1]
            if utt_id in score_table:
                # only the first line seen for an utterance is kept
                continue
            if status == 'detected':
                score_table[utt_id] = {
                    'kw': space_mixed_label(fields[2]),
                    'confi': float(fields[3]),
                }
            else:
                score_table[utt_id] = {'kw': 'unknown', 'confi': -1.0}

    with open(label_file, 'r', encoding='utf8') as fin:
        label_lists = [json.loads(line.strip()) for line in fin]

    # keywords in the same space-separated form as the labels
    spaced_keywords = [space_mixed_label(kw) for kw in keywords_list]

    # build empty structure for keyword-filler infos
    keyword_filler_table = {}
    for kw in spaced_keywords:
        keyword_filler_table[kw] = {
            'keyword_table': {},
            'keyword_duration': 0.0,
            'filler_table': {},
            'filler_duration': 0.0,
        }

    for obj in label_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'tok' in obj  # here we use the tokens
        assert 'duration' in obj

        utt_id = obj['key']
        # pad with spaces so that a keyword only matches whole tokens
        padded_txt = ' ' + space_mixed_label("".join(obj['tok'])) + ' '
        duration = obj['duration']
        assert utt_id in score_table
        detected = score_table[utt_id]

        for kw in spaced_keywords:
            stats = keyword_filler_table[kw]
            if (' ' + kw + ' ') in padded_txt:
                # the utterance contains this keyword: positive sample
                table_name, duration_name = 'keyword_table', 'keyword_duration'
            else:
                # the utterance does not contain it: negative sample
                table_name, duration_name = 'filler_table', 'filler_duration'

            if kw == detected['kw']:
                stats[table_name][utt_id] = detected['confi']
            else:
                # nothing detected, or another keyword was detected
                stats[table_name][utt_id] = -1.0
            stats[duration_name] += duration

    return keyword_filler_table
|
||||||
|
|
||||||
|
def load_stats_file(stats_file):
    """Read a DET stats file into an array, in reverse line order.

    Each line must hold exactly three whitespace-separated fields. The
    first is ignored, the second is kept as is and the third is
    multiplied by 100.

    Args:
        stats_file: path of the UTF-8 encoded stats file.

    Returns:
        ``numpy`` array with one ``[second field, third field * 100]`` row
        per line; the last line of the file comes first.
    """
    rows = []
    with open(stats_file, 'r', encoding='utf8') as fin:
        for raw in fin:
            _, fa_per_hour, frr = raw.strip().split()
            rows.append([float(fa_per_hour), float(frr) * 100])
    return np.array(rows[::-1])
|
||||||
|
|
||||||
|
def plot_det(dets_dir, figure_file, det_title="DetCurve",
             xlim=(0, 2), ylim=(15, 30)):
    """Plot all DET stats files of a directory into a single figure.

    Every file matching ``*stats*.txt`` in ``dets_dir`` is loaded with
    ``load_stats_file`` and drawn as one curve (x: false alarms per hour,
    y: false rejection rate in percent).

    Args:
        dets_dir: directory that holds the stats files.
        figure_file: path of the image file to write.
        det_title: title of the figure.
        xlim: ``(min, max)`` of the x axis.
        ylim: ``(min, max)`` of the y axis. Curves outside this range are
            not visible, so adjust it to the FRR range of the model.
    """
    plt.figure(dpi=200)
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.rcParams['font.size'] = 12

    # sorted(): glob returns files in arbitrary order, which would make
    # the curve colours and the legend order vary between runs.
    for file in sorted(glob.glob(f'{dets_dir}/*stats*.txt')):
        logging.info(f'reading det data from {file}')
        name_parts = os.path.basename(file).split('.')
        if name_parts[0] == 'stats' and len(name_parts) > 2:
            # files named 'stats.<keyword>.txt': label the curve with the
            # keyword, otherwise every curve would be labelled 'stats'
            label = name_parts[1]
        else:
            label = name_parts[0]
        values = load_stats_file(file)
        plt.plot(values[:, 0], values[:, 1], label=label)

    plt.xlim(float(xlim[0]), float(xlim[1]))
    plt.ylim(float(ylim[0]), float(ylim[1]))

    plt.xlabel('False Alarm Per Hour')
    plt.ylabel('False Rejection Rate (\\%)')
    plt.title(det_title, fontproperties=font)
    plt.grid(linestyle='--')
    plt.legend(loc='upper right', fontsize=5)
    plt.savefig(figure_file)
    # release the figure; pyplot keeps figures alive until they are closed
    plt.close()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Compute, for every keyword, the false alarms per hour and the false
    # rejection rate over a sweep of confidence thresholds, write them to
    # 'stats.<keyword>.txt' and plot all curves into one DET figure.
    parser = argparse.ArgumentParser(description='compute det curve')
    parser.add_argument('--test_data', required=True, help='label file')
    # the keyword list is mandatory: nothing can be computed without it
    parser.add_argument('--keyword', type=str, required=True,
                        help='keyword label')
    parser.add_argument('--score_file', required=True, help='score file')
    parser.add_argument('--step', type=float, default=0.01,
                        help='threshold step')
    # --window_shift is accepted so that existing command lines keep
    # working; its value is not used by this script.
    parser.add_argument('--window_shift', type=int, default=50,
                        help='window_shift is used to skip the frames after triggered')
    parser.add_argument('--stats_dir',
                        required=False,
                        default=None,
                        help='false reject/alarm stats dir, default in score_file')
    parser.add_argument('--det_curve_path',
                        required=False,
                        default=None,
                        help='det curve path, default is stats_dir/det.png')

    args = parser.parse_args()
    if args.step <= 0:
        # a non-positive step would never reach the end of the sweep
        parser.error('--step must be positive')

    keywords_list = args.keyword.strip().split(',')
    keyword_filler_table = load_label_and_score(keywords_list, args.test_data,
                                                args.score_file)

    # dirname() is '' for a bare file name; fall back to the current
    # directory so that plot_det does not glob in the filesystem root.
    stats_dir = args.stats_dir or os.path.dirname(args.score_file) or '.'
    os.makedirs(stats_dir, exist_ok=True)

    # Derive each threshold from an integer counter: accumulating a float
    # step drifts and can skip the last point of the sweep.
    num_steps = int(1.0 / args.step + 1e-9)

    for keyword in keywords_list:
        keyword = space_mixed_label(keyword)
        keyword_table = keyword_filler_table[keyword]['keyword_table']
        filler_table = keyword_filler_table[keyword]['filler_table']
        keyword_dur = keyword_filler_table[keyword]['keyword_duration']
        keyword_num = len(keyword_table)
        filler_dur = keyword_filler_table[keyword]['filler_duration']
        filler_num = len(filler_table)
        assert keyword_num > 0, \
            'Can\'t compute det for {} without positive sample'.format(keyword)
        assert filler_num > 0, \
            'Can\'t compute det for {} without negative sample'.format(keyword)
        assert filler_dur > 0, \
            'Can\'t compute det for {} with zero filler duration'.format(keyword)

        logging.info('Computing det for {}'.format(keyword))
        logging.info('  Keyword duration: {} Hours, wave number: {}'.format(
            keyword_dur / 3600.0, keyword_num))
        logging.info('  Filler duration: {} Hours'.format(filler_dur / 3600.0))

        stats_file = os.path.join(
            stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
        with open(stats_file, 'w', encoding='utf8') as fout:
            for i in range(num_steps + 1):
                threshold = i * args.step

                # positive utterances scored below the threshold are missed
                num_false_reject = sum(
                    1 for confi in keyword_table.values()
                    if confi < threshold)
                # negative utterances scored at/above it are false alarms
                num_false_alarm = sum(
                    1 for confi in filler_table.values()
                    if confi >= threshold)

                false_reject_rate = num_false_reject / keyword_num

                # keep the false alarm count strictly positive
                num_false_alarm = max(num_false_alarm, 1e-6)
                false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)

                # columns: threshold, false alarms per hour, false reject
                # rate -- the layout load_stats_file() expects
                fout.write('{:.3f} {:.6f} {:.6f}\n'.format(
                    threshold, false_alarm_per_hour, false_reject_rate))

    det_curve_path = args.det_curve_path or os.path.join(stats_dir, 'det.png')
    plot_det(stats_dir, det_curve_path)
|
||||||
206
wekws/bin/score_ctc.py
Normal file
206
wekws/bin/score_ctc.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||||
|
# 2022 Shaoqing Yu(954793264@qq.com)
|
||||||
|
# 2023 Jing Du(thuduj12@163.com)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import os, sys, math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from wekws.dataset.dataset import Dataset
|
||||||
|
from wekws.model.kws_model import init_model
|
||||||
|
from wekws.utils.checkpoint import load_checkpoint
|
||||||
|
from wekws.model.loss import ctc_prefix_beam_search
|
||||||
|
from tools.make_list import query_token_set, read_lexicon, read_token
|
||||||
|
|
||||||
|
def get_args():
    """Parse the command-line options of the CTC keyword scoring script.

    Options are read from ``sys.argv``. ``--config``, ``--test_data``,
    ``--checkpoint`` and ``--score_file`` are mandatory; everything else
    has a default.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--batch_size',
                        default=16,
                        type=int,
                        help='batch size for inference')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers used for reading')
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')
    parser.add_argument('--score_file',
                        required=True,
                        help='output score file')
    # The help text used to be a copy of the --pin_memory one; the flag
    # actually makes main() load the checkpoint with torch.jit.load.
    parser.add_argument('--jit_model',
                        action='store_true',
                        default=False,
                        help='Load the checkpoint as a TorchScript (jit) '
                             'model')
    parser.add_argument('--keywords',
                        type=str,
                        default=None,
                        help='the keywords, split with comma(,)')
    parser.add_argument('--token_file',
                        type=str,
                        default=None,
                        help='the path of tokens.txt')
    parser.add_argument('--lexicon_file',
                        type=str,
                        default=None,
                        help='the path of lexicon.txt')

    args = parser.parse_args()
    return args
|
||||||
|
|
||||||
|
def is_sublist(main_list, check_list):
    """Locate ``check_list`` as a contiguous run inside ``main_list``.

    Both arguments may be any sequence (list or tuple); they are compared
    element by element, so a tuple prefix matches a list of token ids.

    Args:
        main_list: sequence to search in.
        check_list: sequence to search for.

    Returns:
        int: index in ``main_list`` where the first match starts, or -1
        when there is no match. An empty ``check_list`` matches at 0.
    """
    haystack = list(main_list)
    needle = list(check_list)
    width = len(needle)

    # "+ 1" so that the alignment ending on the last element is tried too;
    # the range is empty when main_list is shorter than check_list.
    for start in range(len(haystack) - width + 1):
        if haystack[start:start + width] == needle:
            return start
    return -1
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Score a test set with a CTC keyword-spotting model.

    Steps:
      1. Read the options and the YAML config, and derive a test-time
         dataset config (no filtering by length, no augmentation, no
         shuffling, no dither).
      2. Load the model, either as a TorchScript module (``--jit_model``)
         or from a checkpoint built with ``init_model``.
      3. Map every keyword in ``--keywords`` to its token ids.
      4. For each utterance, run CTC prefix beam search restricted to the
         keyword tokens (plus id 0) and look for a keyword token sequence
         inside the decoded prefixes.

    One line per utterance is written to ``--score_file``:
    ``<key> <keyword> <score>`` on a hit, ``<key> -1 -1`` otherwise.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Start from the training dataset config and switch off everything
    # that would drop or perturb test utterances.
    test_conf = copy.deepcopy(configs['dataset_conf'])
    test_conf['filter_conf']['max_length'] = 102400
    test_conf['filter_conf']['min_length'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['shuffle'] = False
    test_conf['feature_extraction_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_size'] = args.batch_size

    test_dataset = Dataset(args.test_data, test_conf)
    loader_kwargs = dict(batch_size=None,
                         pin_memory=args.pin_memory,
                         num_workers=args.num_workers)
    # DataLoader rejects an explicit prefetch_factor when it runs without
    # worker processes (num_workers == 0, which is this script's default),
    # so only forward it when workers are actually used.
    if args.num_workers > 0:
        loader_kwargs['prefetch_factor'] = args.prefetch
    test_data_loader = DataLoader(test_dataset, **loader_kwargs)

    if args.jit_model:
        model = torch.jit.load(args.checkpoint)
        # For script model, only cpu is supported.
        device = torch.device('cpu')
    else:
        # Init model from configs
        model = init_model(configs['model'])
        load_checkpoint(model, args.checkpoint)
        use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
    model = model.to(device)
    model.eval()
    score_abs_path = os.path.abspath(args.score_file)

    token_table = read_token(args.token_file)
    lexicon_table = read_lexicon(args.lexicon_file)
    # Parse the keywords into token sequences.
    assert args.keywords is not None, 'at least one keyword is needed'
    keywords_str = args.keywords
    keywords_list = keywords_str.strip().replace(' ', '').split(',')
    keywords_token = {}
    # Id 0 / '<blk>' is always part of the allowed token set.
    keywords_idxset = {0}
    keywords_strset = {'<blk>'}
    keywords_tokenmap = {'<blk>': 0}
    for keyword in keywords_list:
        strs, indexes = query_token_set(keyword, token_table, lexicon_table)
        keywords_token[keyword] = {}
        keywords_token[keyword]['token_id'] = indexes
        keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
                                                       for i in indexes)
        keywords_strset.update(strs)
        keywords_idxset.update(indexes)
        for txt, idx in zip(strs, indexes):
            if keywords_tokenmap.get(txt, None) is None:
                keywords_tokenmap[txt] = idx

    token_print = ''
    for txt, idx in keywords_tokenmap.items():
        token_print += f'{txt}({idx}) '
    logging.info(f'Token set is: {token_print}')

    with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
        for batch_idx, batch in enumerate(test_data_loader):
            keys, feats, target, lengths, target_lengths = batch
            feats = feats.to(device)
            lengths = lengths.to(device)
            logits, _ = model(feats)
            logits = logits.softmax(2)  # (batch_size, maxlen, vocab_size)
            logits = logits.cpu()
            for i in range(len(keys)):
                key = keys[i]
                # Drop the padded frames of this utterance.
                score = logits[i][:lengths[i]]
                hyps = ctc_prefix_beam_search(score, lengths[i],
                                              keywords_idxset)
                hit_keyword = None
                hit_score = 1.0
                # Hypotheses are visited in the order returned by the
                # search; the first one containing a keyword wins.
                for one_hyp in hyps:
                    prefix_ids = one_hyp[0]
                    prefix_nodes = one_hyp[2]
                    assert len(prefix_ids) == len(prefix_nodes)
                    for word in keywords_token.keys():
                        lab = keywords_token[word]['token_id']
                        offset = is_sublist(prefix_ids, lab)
                        if offset != -1:
                            hit_keyword = word
                            # Score = product of the recorded per-token
                            # probabilities of the matched span.
                            for idx in range(offset, offset + len(lab)):
                                hit_score *= prefix_nodes[idx]['prob']
                            break
                    if hit_keyword is not None:
                        hit_score = math.sqrt(hit_score)
                        break

                if hit_keyword is not None:
                    fout.write('{} {} {:.3f}\n'.format(
                        key, hit_keyword, hit_score))
                else:
                    fout.write('{} -1 -1\n'.format(key))

            if batch_idx % 10 == 0:
                print('Progress batch {}'.format(batch_idx))
                sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
# Run the scoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
||||||
@ -302,12 +302,24 @@ def padding(data):
|
|||||||
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
|
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
|
||||||
sorted_feats = [sample[i]['feat'] for i in order]
|
sorted_feats = [sample[i]['feat'] for i in order]
|
||||||
sorted_keys = [sample[i]['key'] for i in order]
|
sorted_keys = [sample[i]['key'] for i in order]
|
||||||
sorted_labels = torch.tensor([sample[i]['label'] for i in order],
|
|
||||||
dtype=torch.int64)
|
|
||||||
padded_feats = pad_sequence(sorted_feats,
|
padded_feats = pad_sequence(sorted_feats,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
padding_value=0)
|
padding_value=0)
|
||||||
yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)
|
|
||||||
|
if isinstance(sample[0]['label'], int):
|
||||||
|
padded_labels = torch.tensor([sample[i]['label'] for i in order],
|
||||||
|
dtype=torch.int32)
|
||||||
|
label_lengths = torch.tensor([1 for i in order],
|
||||||
|
dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
sorted_labels = [
|
||||||
|
torch.tensor(sample[i]['label'], dtype=torch.int32) for i in order
|
||||||
|
]
|
||||||
|
label_lengths = torch.tensor([len(sample[i]['label']) for i in order],
|
||||||
|
dtype=torch.int32)
|
||||||
|
padded_labels = pad_sequence(
|
||||||
|
sorted_labels, batch_first=True, padding_value=-1)
|
||||||
|
yield (sorted_keys, padded_feats, padded_labels, feats_lengths, label_lengths)
|
||||||
|
|
||||||
|
|
||||||
def add_reverb(data, reverb_source, aug_prob):
|
def add_reverb(data, reverb_source, aug_prob):
|
||||||
|
|||||||
@ -162,6 +162,16 @@ def init_model(configs):
|
|||||||
classifier = LinearClassifier(hidden_dim, output_dim)
|
classifier = LinearClassifier(hidden_dim, output_dim)
|
||||||
activation = nn.Sigmoid()
|
activation = nn.Sigmoid()
|
||||||
|
|
||||||
|
# Here we add a possible "activation_type", one can choose to use other activation function.
|
||||||
|
# We use nn.Identity just for CTC loss
|
||||||
|
if "activation" in configs:
|
||||||
|
activation_type = configs["activation"]["type"]
|
||||||
|
if activation_type == 'identity':
|
||||||
|
activation = nn.Identity()
|
||||||
|
else:
|
||||||
|
print('Unknown activation type {}'.format(activation_type))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||||
preprocessing, backbone, classifier, activation)
|
preprocessing, backbone, classifier, activation)
|
||||||
return kws_model
|
return kws_model
|
||||||
|
|||||||
@ -12,8 +12,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch, math, sys
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from wekws.utils.mask import padding_mask
|
from wekws.utils.mask import padding_mask
|
||||||
|
|
||||||
@ -93,6 +95,65 @@ def acc_frame(
|
|||||||
correct = pred.eq(target.long().view_as(pred)).sum().item()
|
correct = pred.eq(target.long().view_as(pred)).sum().item()
|
||||||
return correct * 100.0 / logits.size(0)
|
return correct * 100.0 / logits.size(0)
|
||||||
|
|
||||||
|
def acc_utterance(logits: torch.Tensor, target: torch.Tensor,
                  logits_length: torch.Tensor, target_length: torch.Tensor):
    """Token-level accuracy of CTC decoding over a batch.

    Every utterance is decoded with ``ctc_prefix_beam_search`` (score beam
    3, path beam 5); the best hypothesis is aligned against the reference
    labels with ``Calculator`` and the error counts are summed up.

    Args:
        logits: model output before softmax, indexed as
            (batch, frame, token); softmax is taken over dim 2.
        target: reference token ids, one row per utterance; only the first
            ``target_length[i]`` entries of row ``i`` are used.
        logits_length: number of valid frames per utterance.
        target_length: number of valid labels per utterance.

    Returns:
        ``(all - ins - sub - del) * 100 / all`` as a float, where ``all``
        is the number of aligned reference tokens. Returns 0 when
        ``logits`` is None and 0.0 when no reference token was aligned
        (instead of dividing by zero).
    """
    if logits is None:
        return 0

    logits = logits.softmax(2)  # (batch, maxlen, vocab_size)
    logits = logits.cpu()
    target = target.cpu()

    total_word = 0
    total_ins = 0
    total_sub = 0
    total_del = 0
    calculator = Calculator()
    for i in range(logits.size(0)):
        # Drop the padded frames of this utterance.
        score = logits[i][:logits_length[i]]
        hyps = ctc_prefix_beam_search(score, logits_length[i], None, 3, 5)
        # Calculator works on string tokens.
        lab = [str(item) for item in target[i][:target_length[i]].tolist()]
        rec = []
        if len(hyps) > 0:
            rec = [str(item) for item in hyps[0][0]]
        result = calculator.calculate(lab, rec)
        if result['all'] != 0:
            total_word += result['all']
            total_ins += result['ins']
            total_sub += result['sub']
            total_del += result['del']

    # Nothing to score (empty batch or only empty references): report 0
    # rather than raising ZeroDivisionError.
    if total_word == 0:
        return 0.0

    return float(total_word - total_ins - total_sub
                 - total_del) * 100.0 / total_word
|
||||||
|
|
||||||
|
def ctc_loss(logits: torch.Tensor,
             target: torch.Tensor,
             logits_lengths: torch.Tensor,
             target_lengths: torch.Tensor,
             need_acc: bool = False):
    """ CTC Loss
    Args:
        logits: (B, L, D), un-normalized scores per frame; log-softmax is
            applied over the last dimension here
        target: label ids, passed to F.ctc_loss unchanged
        logits_lengths: (B), number of valid frames per utterance
        target_lengths: (B), number of valid labels per utterance
        need_acc: also compute the decoding accuracy with acc_utterance
    Returns:
        (loss, acc): summed CTC loss divided by the batch size, and the
            accuracy (0.0 when need_acc is False)
    """
    if need_acc:
        acc = acc_utterance(logits, target, logits_lengths, target_lengths)
    else:
        acc = 0.0

    # F.ctc_loss wants time-major log-probabilities: (B, L, D) -> (L, B, D)
    log_probs = logits.transpose(0, 1).log_softmax(2)
    batch_size = log_probs.size(1)
    summed = F.ctc_loss(log_probs,
                        target,
                        logits_lengths,
                        target_lengths,
                        reduction='sum')

    # batch mean
    return summed / batch_size, acc
|
||||||
|
|
||||||
def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
|
def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
|
||||||
""" Cross Entropy Loss
|
""" Cross Entropy Loss
|
||||||
@ -114,12 +175,279 @@ def criterion(type: str,
|
|||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
target: torch.Tensor,
|
target: torch.Tensor,
|
||||||
lengths: torch.Tensor,
|
lengths: torch.Tensor,
|
||||||
min_duration: int = 0):
|
target_lengths: torch.Tensor = None,
|
||||||
|
min_duration: int = 0,
|
||||||
|
validation: bool = False, ):
|
||||||
if type == 'ce':
|
if type == 'ce':
|
||||||
loss, acc = cross_entropy(logits, target)
|
loss, acc = cross_entropy(logits, target)
|
||||||
return loss, acc
|
return loss, acc
|
||||||
elif type == 'max_pooling':
|
elif type == 'max_pooling':
|
||||||
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
|
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
|
||||||
return loss, acc
|
return loss, acc
|
||||||
|
elif type == 'ctc':
|
||||||
|
loss, acc = ctc_loss(logits, target, lengths, target_lengths, validation)
|
||||||
|
return loss, acc
|
||||||
else:
|
else:
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
def ctc_prefix_beam_search(
    logits: torch.Tensor,
    logits_lengths: torch.Tensor,
    keywords_tokenset: set = None,
    score_beam_size: int = 3,
    path_beam_size: int = 20,
) -> List[Tuple[Tuple[int, ...], float, List[dict]]]:
    """ CTC prefix beam search inner implementation

    Works on a single utterance in the probability domain (scores are
    multiplied and added, not log-added). Token id 0 is treated as blank.

    Args:
        logits (torch.Tensor): (max_len, vocab_size); indexed by frame on
            dim 0. No softmax is applied here, so the values are expected
            to already be probabilities (the callers in this file apply
            softmax before calling).
        logits_lengths (torch.Tensor): not used by this function; the
            number of frames is taken from logits.size(0).
        keywords_tokenset (set): if not None, only token ids in this set
            are considered at each frame
        score_beam_size (int): number of top tokens examined per frame
        path_beam_size (int): number of prefixes kept after each frame

    Returns:
        list of (prefix, score, nodes) tuples, best first:
            prefix: tuple of token ids,
            score: pb + pnb of the prefix,
            nodes: one dict(token, frame, prob) per token of the prefix.
        If no frame contributes any candidate the single initial
        hypothesis ((), 1.0, []) is returned.
    """
    maxlen = logits.size(0)
    # ctc_probs = logits.softmax(1)  # (1, maxlen, vocab_size)
    ctc_probs = logits

    # Each hypothesis is (prefix, (pb, pnb, nodes)): probability of the
    # prefix ending in blank / in non-blank, plus the per-token records.
    cur_hyps = [(tuple(), (1.0, 0.0, []))]

    # 2. CTC beam search step by step
    for t in range(0, maxlen):
        probs = ctc_probs[t]  # (vocab_size,)
        # key: prefix, value (pb, pnb, nodes), default value (0.0, 0.0, [])
        next_hyps = defaultdict(lambda: (0.0, 0.0, []))

        # 2.1 First beam prune: select topk best
        top_k_probs, top_k_index = probs.topk(
            score_beam_size)  # (score_beam_size,)

        # filter prob score that is too small (<= 0.05) and, when a token
        # set is given, tokens outside of it
        filter_probs = []
        filter_index = []
        for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
            if keywords_tokenset is not None:
                if prob > 0.05 and idx in keywords_tokenset:
                    filter_probs.append(prob)
                    filter_index.append(idx)
            else:
                if prob > 0.05:
                    filter_probs.append(prob)
                    filter_index.append(idx)

        # No usable token at this frame: skip it and keep cur_hyps as is.
        if len(filter_index) == 0:
            continue

        for s in filter_index:
            ps = probs[s].item()

            for prefix, (pb, pnb, cur_nodes) in cur_hyps:
                last = prefix[-1] if len(prefix) > 0 else None
                if s == 0:  # blank
                    # Prefix is unchanged, mass moves to "ends in blank".
                    n_pb, n_pnb, nodes = next_hyps[prefix]
                    n_pb = n_pb + pb * ps + pnb * ps
                    # NOTE(review): list.copy() is shallow, so the node
                    # dicts stay shared between hypotheses; the in-place
                    # prob/frame updates below are therefore visible
                    # through every list holding the same dict. Confirm
                    # this sharing is intended.
                    nodes = cur_nodes.copy()
                    next_hyps[prefix] = (n_pb, n_pnb, nodes)
                elif s == last:
                    if not math.isclose(pnb, 0.0, abs_tol=0.000001):
                        # Update *ss -> *s;
                        n_pb, n_pnb, nodes = next_hyps[prefix]
                        n_pnb = n_pnb + pnb * ps
                        nodes = cur_nodes.copy()
                        if ps > nodes[-1]['prob']:  # update frame and prob
                            nodes[-1]['prob'] = ps
                            nodes[-1]['frame'] = t
                        next_hyps[prefix] = (n_pb, n_pnb, nodes)

                    if not math.isclose(pb, 0.0, abs_tol=0.000001):
                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb, nodes = next_hyps[n_prefix]
                        n_pnb = n_pnb + pb * ps
                        nodes = cur_nodes.copy()
                        nodes.append(dict(token=s, frame=t,
                                          prob=ps))  # to record token prob
                        next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
                else:
                    # New token different from the last one: extend prefix.
                    n_prefix = prefix + (s, )
                    n_pb, n_pnb, nodes = next_hyps[n_prefix]
                    if nodes:
                        # n_prefix was already reached at this frame; keep
                        # its nodes and only raise the last token's prob.
                        if ps > nodes[-1]['prob']:  # update frame and prob
                            nodes[-1]['prob'] = ps
                            nodes[-1]['frame'] = t
                    else:
                        nodes = cur_nodes.copy()
                        nodes.append(dict(token=s, frame=t,
                                          prob=ps))  # to record token prob
                    n_pnb = n_pnb + pb * ps + pnb * ps
                    next_hyps[n_prefix] = (n_pb, n_pnb, nodes)

        # 2.2 Second beam prune: keep the path_beam_size prefixes with the
        # highest total probability pb + pnb
        next_hyps = sorted(
            next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)

        cur_hyps = next_hyps[:path_beam_size]

    # Flatten to (prefix, total score, nodes).
    hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
    return hyps
|
||||||
|
|
||||||
|
|
||||||
|
class Calculator:
    """Edit-distance aligner that accumulates per-token error counts.

    ``calculate`` aligns one reference/hypothesis pair and returns the
    counts for that pair; the same counts are also added to ``self.data``
    (token -> counts), which ``overall`` and ``cluster`` summarise.
    Tokens are expected to be strings.
    """

    def __init__(self):
        # token -> {'all', 'cor', 'sub', 'ins', 'del'} running counts
        self.data = {}
        # DP table, kept between calls and only ever grown
        self.space = []
        # cost of each edit operation
        self.cost = {}
        self.cost['cor'] = 0
        self.cost['sub'] = 1
        self.cost['del'] = 1
        self.cost['ins'] = 1

    @staticmethod
    def _new_counts():
        """Return a zeroed per-token / overall counter dict."""
        return {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}

    def calculate(self, lab, rec):
        """Align reference ``lab`` with hypothesis ``rec``.

        The input sequences are not modified.

        Args:
            lab: sequence of reference tokens (strings).
            rec: sequence of recognised tokens (strings).

        Returns:
            dict with the aligned sequences under 'lab' and 'rec' ('' marks
            a gap) and the counts 'all', 'cor', 'sub', 'ins', 'del', where
            'all' is the number of aligned reference tokens.

        Raises:
            RuntimeError: if the DP table holds an unknown operation during
                traceback (an internal invariant violation).
        """
        # Work on copies with an empty sentinel at index 0 so that row 0 /
        # column 0 of the table stand for "nothing consumed yet".
        lab = [''] + list(lab)
        rec = [''] + list(rec)

        # Grow the reused table if needed and reset every cell.
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for element in row:
                element['dist'] = 0
                element['error'] = 'non'
            while len(row) < len(rec):
                row.append({'dist': 0, 'error': 'non'})
        for i in range(len(lab)):
            self.space[i][0]['dist'] = i
            self.space[i][0]['error'] = 'del'
        for j in range(len(rec)):
            self.space[0][j]['dist'] = j
            self.space[0][j]['error'] = 'ins'
        self.space[0][0]['error'] = 'non'

        # Make sure every token seen has a counter.
        for token in lab:
            if token not in self.data and len(token) > 0:
                self.data[token] = self._new_counts()
        for token in rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = self._new_counts()

        # Computing edit distance. Ties are resolved in the order
        # deletion, insertion, match/substitution (strict "<").
        for i in range(1, len(lab)):
            for j in range(1, len(rec)):
                min_dist = self.space[i - 1][j]['dist'] + self.cost['del']
                min_error = 'del'
                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
                if dist < min_dist:
                    min_dist = dist
                    min_error = 'ins'
                error = 'cor' if lab[i] == rec[j] else 'sub'
                dist = self.space[i - 1][j - 1]['dist'] + self.cost[error]
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                self.space[i][j]['dist'] = min_dist
                self.space[i][j]['error'] = min_error

        # Tracing back
        result = {
            'lab': [],
            'rec': [],
            'all': 0,
            'cor': 0,
            'sub': 0,
            'ins': 0,
            'del': 0
        }
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            error = self.space[i][j]['error']
            if error == 'cor' or error == 'sub':  # correct / substitution
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] += 1
                    self.data[lab[i]][error] += 1
                    result['all'] += 1
                    result[error] += 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif error == 'del':  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] += 1
                    self.data[lab[i]]['del'] += 1
                    result['all'] += 1
                    result['del'] += 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, '')
                i = i - 1
            elif error == 'ins':  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]['ins'] += 1
                    result['ins'] += 1
                result['lab'].insert(0, '')
                result['rec'].insert(0, rec[j])
                j = j - 1
            elif error == 'non':  # starting point
                break
            else:  # shouldn't reach here
                # Printing and looping on would never terminate because
                # neither i nor j changes; fail loudly instead.
                raise RuntimeError(
                    'this should not happen , i = {i} , j = {j} , error = {error}'
                    .format(i=i, j=j, error=error))
        return result

    def overall(self):
        """Return the counts summed over every token seen so far."""
        result = self._new_counts()
        for token in self.data:
            for name in result:
                result[name] = result[name] + self.data[token][name]
        return result

    def cluster(self, data):
        """Return the counts summed over the tokens in ``data``.

        Tokens that were never seen are ignored.
        """
        result = self._new_counts()
        for token in data:
            if token in self.data:
                for name in result:
                    result[name] = result[name] + self.data[token][name]
        return result

    def keys(self):
        """Return the list of tokens seen so far."""
        return list(self.data.keys())
|
||||||
@ -34,17 +34,20 @@ class Executor:
|
|||||||
min_duration = args.get('min_duration', 0)
|
min_duration = args.get('min_duration', 0)
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(data_loader):
|
for batch_idx, batch in enumerate(data_loader):
|
||||||
key, feats, target, feats_lengths = batch
|
key, feats, target, feats_lengths, label_lengths = batch
|
||||||
feats = feats.to(device)
|
feats = feats.to(device)
|
||||||
target = target.to(device)
|
target = target.to(device)
|
||||||
feats_lengths = feats_lengths.to(device)
|
feats_lengths = feats_lengths.to(device)
|
||||||
|
label_lengths = label_lengths.to(device)
|
||||||
num_utts = feats_lengths.size(0)
|
num_utts = feats_lengths.size(0)
|
||||||
if num_utts == 0:
|
if num_utts == 0:
|
||||||
continue
|
continue
|
||||||
logits, _ = model(feats)
|
logits, _ = model(feats)
|
||||||
loss_type = args.get('criterion', 'max_pooling')
|
loss_type = args.get('criterion', 'max_pooling')
|
||||||
loss, acc = criterion(loss_type, logits, target, feats_lengths,
|
loss, acc = criterion(loss_type, logits, target, feats_lengths,
|
||||||
min_duration)
|
target_lengths=label_lengths,
|
||||||
|
min_duration=min_duration,
|
||||||
|
validation=False)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
||||||
@ -67,16 +70,20 @@ class Executor:
|
|||||||
total_acc = 0.0
|
total_acc = 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, batch in enumerate(data_loader):
|
for batch_idx, batch in enumerate(data_loader):
|
||||||
key, feats, target, feats_lengths = batch
|
key, feats, target, feats_lengths, label_lengths = batch
|
||||||
feats = feats.to(device)
|
feats = feats.to(device)
|
||||||
target = target.to(device)
|
target = target.to(device)
|
||||||
feats_lengths = feats_lengths.to(device)
|
feats_lengths = feats_lengths.to(device)
|
||||||
|
label_lengths = label_lengths.to(device)
|
||||||
num_utts = feats_lengths.size(0)
|
num_utts = feats_lengths.size(0)
|
||||||
if num_utts == 0:
|
if num_utts == 0:
|
||||||
continue
|
continue
|
||||||
logits, _ = model(feats)
|
logits, _ = model(feats)
|
||||||
loss, acc = criterion(args.get('criterion', 'max_pooling'),
|
loss, acc = criterion(args.get('criterion', 'max_pooling'),
|
||||||
logits, target, feats_lengths)
|
logits, target, feats_lengths,
|
||||||
|
target_lengths=label_lengths,
|
||||||
|
min_duration=0,
|
||||||
|
validation=True)
|
||||||
if torch.isfinite(loss):
|
if torch.isfinite(loss):
|
||||||
num_seen_utts += num_utts
|
num_seen_utts += num_utts
|
||||||
total_loss += loss.item() * num_utts
|
total_loss += loss.item() * num_utts
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user