fix typo.

This commit is contained in:
dujing 2023-05-25 13:35:00 +08:00
parent 9f3167632a
commit c4b2ddbd11
4 changed files with 28 additions and 20 deletions

View File

@@ -44,7 +44,7 @@ optim_conf:
training_config:
grad_clip: 5
max_epoch: 30
max_epoch: 50
log_interval: 100
criterion: ctc

View File

@@ -11,15 +11,19 @@ num_keywords=2599
config=conf/ds_tcn_ctc.yaml
norm_mean=true
norm_var=true
gpus="4,5,6,7"
gpus="0,1,2,3"
checkpoint=
dir=exp/ds_tcn_ctc_ft
average_model=true
num_average=30
score_checkpoint=$dir/avg_${num_average}.pt
if $average_model ;then
score_checkpoint=$dir/avg_${num_average}.pt
else
score_checkpoint=$dir/final.pt
fi
download_dir=/data/megastore/Datasets/ASR/KWS/nihaoxiaowen # your data dir
download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir
. tools/parse_options.sh || exit 1;
window_shift=50
@@ -100,7 +104,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
exit
fi
if [ ! -f $trainbase_dir/$x/wav.dur ]; then
tools/wav_to_duration.sh --nj 8 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
tools/wav_to_duration.sh --nj 128 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
fi
# Here we use tokens.txt and lexicon.txt to convert txt into index
@@ -124,12 +128,13 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
--train_data $trainbase_dir/train/data.list \
--cv_data $trainbase_dir/dev/data.list \
--model_dir $trainbase_exp \
--num_workers 8 \
--num_workers 2 \
--ddp.dist_backend nccl \
--num_keywords $num_keywords \
--min_duration 50 \
--seed 666 \
$cmvn_opts \
--checkpoint $trainbase_exp/23.pt
$cmvn_opts # \
#--checkpoint $trainbase_exp/23.pt
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@@ -167,16 +172,19 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Do model average, Compute FRR/FAR ..."
python wekws/bin/average_model.py \
--dst_model $score_checkpoint \
--src_path $dir \
--num ${num_average} \
--val_best
if $average_model; then
python wekws/bin/average_model.py \
--dst_model $score_checkpoint \
--src_path $dir \
--num ${num_average} \
--val_best
fi
result_dir=$dir/test_$(basename $score_checkpoint)
mkdir -p $result_dir
python wekws/bin/score_ctc.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \

View File

@@ -17,7 +17,7 @@
import argparse, logging, glob
import json, re, os, numpy as np
import matplotlib.pyplot as plt
import pypinyin
import pypinyin # for Chinese Character
def split_mixed_label(input_str):
tokens = []
@@ -165,13 +165,13 @@ if __name__ == '__main__':
parser.add_argument(
'--xlim',
type=int,
default=5,
default=10,
help='xlimrange of x-axis, x is false alarm per hour')
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
parser.add_argument(
'--ylim',
type=int,
default=75,
default=100,
help='ylimrange of y-axis, y is false rejection rate')
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')
@@ -228,7 +228,7 @@ if __name__ == '__main__':
false_alarm_rate = num_false_alarm / filler_num
fout.write('{:.3f} {:.6f} {:.6f}\n'.format(
threshold, false_alarm_per_hour, threshold))
threshold, false_alarm_per_hour, false_reject_rate))
threshold += args.step
if args.det_curve_path :
det_curve_path = args.det_curve_path

View File

@@ -192,10 +192,10 @@ def main():
if hit_keyword is not None:
# fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\
# .format(key, start*0.03, end*0.03, hit_keyword, hit_score))
fout.write('{} {} {:.3f}\n'.format(
fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score))
else:
fout.write('{} -1 -1\n'.format(key))
fout.write('{} rejected\n'.format(key))
if batch_idx % 10 == 0:
print('Progress batch {}'.format(batch_idx))