fix typo.

This commit is contained in:
dujing 2023-05-25 13:35:00 +08:00
parent 9f3167632a
commit c4b2ddbd11
4 changed files with 28 additions and 20 deletions

View File

@ -44,7 +44,7 @@ optim_conf:
training_config: training_config:
grad_clip: 5 grad_clip: 5
max_epoch: 30 max_epoch: 50
log_interval: 100 log_interval: 100
criterion: ctc criterion: ctc

View File

@ -11,15 +11,19 @@ num_keywords=2599
config=conf/ds_tcn_ctc.yaml config=conf/ds_tcn_ctc.yaml
norm_mean=true norm_mean=true
norm_var=true norm_var=true
gpus="4,5,6,7" gpus="0,1,2,3"
checkpoint= checkpoint=
dir=exp/ds_tcn_ctc_ft dir=exp/ds_tcn_ctc_ft
average_model=true
num_average=30 num_average=30
score_checkpoint=$dir/avg_${num_average}.pt if $average_model ;then
score_checkpoint=$dir/avg_${num_average}.pt
else
score_checkpoint=$dir/final.pt
fi
download_dir=/data/megastore/Datasets/ASR/KWS/nihaoxiaowen # your data dir download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir
. tools/parse_options.sh || exit 1; . tools/parse_options.sh || exit 1;
window_shift=50 window_shift=50
@ -100,7 +104,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
exit exit
fi fi
if [ ! -f $trainbase_dir/$x/wav.dur ]; then if [ ! -f $trainbase_dir/$x/wav.dur ]; then
tools/wav_to_duration.sh --nj 8 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur tools/wav_to_duration.sh --nj 128 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
fi fi
# Here we use tokens.txt and lexicon.txt to convert txt into index # Here we use tokens.txt and lexicon.txt to convert txt into index
@ -124,12 +128,13 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
--train_data $trainbase_dir/train/data.list \ --train_data $trainbase_dir/train/data.list \
--cv_data $trainbase_dir/dev/data.list \ --cv_data $trainbase_dir/dev/data.list \
--model_dir $trainbase_exp \ --model_dir $trainbase_exp \
--num_workers 8 \ --num_workers 2 \
--ddp.dist_backend nccl \
--num_keywords $num_keywords \ --num_keywords $num_keywords \
--min_duration 50 \ --min_duration 50 \
--seed 666 \ --seed 666 \
$cmvn_opts \ $cmvn_opts # \
--checkpoint $trainbase_exp/23.pt #--checkpoint $trainbase_exp/23.pt
fi fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@ -167,16 +172,19 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Do model average, Compute FRR/FAR ..." echo "Do model average, Compute FRR/FAR ..."
python wekws/bin/average_model.py \ if $average_model; then
--dst_model $score_checkpoint \ python wekws/bin/average_model.py \
--src_path $dir \ --dst_model $score_checkpoint \
--num ${num_average} \ --src_path $dir \
--val_best --num ${num_average} \
--val_best
fi
result_dir=$dir/test_$(basename $score_checkpoint) result_dir=$dir/test_$(basename $score_checkpoint)
mkdir -p $result_dir mkdir -p $result_dir
python wekws/bin/score_ctc.py \ python wekws/bin/score_ctc.py \
--config $dir/config.yaml \ --config $dir/config.yaml \
--test_data data/test/data.list \ --test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \ --batch_size 256 \
--checkpoint $score_checkpoint \ --checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \ --score_file $result_dir/score.txt \

View File

@ -17,7 +17,7 @@
import argparse, logging, glob import argparse, logging, glob
import json, re, os, numpy as np import json, re, os, numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pypinyin import pypinyin # for Chinese Character
def split_mixed_label(input_str): def split_mixed_label(input_str):
tokens = [] tokens = []
@ -165,13 +165,13 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--xlim', '--xlim',
type=int, type=int,
default=5, default=10,
help='xlimrange of x-axis, x is false alarm per hour') help='xlimrange of x-axis, x is false alarm per hour')
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis') parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
parser.add_argument( parser.add_argument(
'--ylim', '--ylim',
type=int, type=int,
default=75, default=100,
help='ylimrange of y-axis, y is false rejection rate') help='ylimrange of y-axis, y is false rejection rate')
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis') parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')
@ -228,7 +228,7 @@ if __name__ == '__main__':
false_alarm_rate = num_false_alarm / filler_num false_alarm_rate = num_false_alarm / filler_num
fout.write('{:.3f} {:.6f} {:.6f}\n'.format( fout.write('{:.3f} {:.6f} {:.6f}\n'.format(
threshold, false_alarm_per_hour, threshold)) threshold, false_alarm_per_hour, false_reject_rate))
threshold += args.step threshold += args.step
if args.det_curve_path : if args.det_curve_path :
det_curve_path = args.det_curve_path det_curve_path = args.det_curve_path

View File

@ -192,10 +192,10 @@ def main():
if hit_keyword is not None: if hit_keyword is not None:
# fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\ # fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\
# .format(key, start*0.03, end*0.03, hit_keyword, hit_score)) # .format(key, start*0.03, end*0.03, hit_keyword, hit_score))
fout.write('{} {} {:.3f}\n'.format( fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score)) key, hit_keyword, hit_score))
else: else:
fout.write('{} -1 -1\n'.format(key)) fout.write('{} rejected\n'.format(key))
if batch_idx % 10 == 0: if batch_idx % 10 == 0:
print('Progress batch {}'.format(batch_idx)) print('Progress batch {}'.format(batch_idx))