fix typo.
This commit is contained in:
parent
9f3167632a
commit
c4b2ddbd11
@ -44,7 +44,7 @@ optim_conf:
|
||||
|
||||
training_config:
|
||||
grad_clip: 5
|
||||
max_epoch: 30
|
||||
max_epoch: 50
|
||||
log_interval: 100
|
||||
criterion: ctc
|
||||
|
||||
|
||||
@ -11,15 +11,19 @@ num_keywords=2599
|
||||
config=conf/ds_tcn_ctc.yaml
|
||||
norm_mean=true
|
||||
norm_var=true
|
||||
gpus="4,5,6,7"
|
||||
gpus="0,1,2,3"
|
||||
|
||||
checkpoint=
|
||||
dir=exp/ds_tcn_ctc_ft
|
||||
|
||||
average_model=true
|
||||
num_average=30
|
||||
score_checkpoint=$dir/avg_${num_average}.pt
|
||||
if $average_model ;then
|
||||
score_checkpoint=$dir/avg_${num_average}.pt
|
||||
else
|
||||
score_checkpoint=$dir/final.pt
|
||||
fi
|
||||
|
||||
download_dir=/data/megastore/Datasets/ASR/KWS/nihaoxiaowen # your data dir
|
||||
download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir
|
||||
|
||||
. tools/parse_options.sh || exit 1;
|
||||
window_shift=50
|
||||
@ -100,7 +104,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
|
||||
exit
|
||||
fi
|
||||
if [ ! -f $trainbase_dir/$x/wav.dur ]; then
|
||||
tools/wav_to_duration.sh --nj 8 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
|
||||
tools/wav_to_duration.sh --nj 128 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
|
||||
fi
|
||||
|
||||
# Here we use tokens.txt and lexicon.txt to convert txt into index
|
||||
@ -124,12 +128,13 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
|
||||
--train_data $trainbase_dir/train/data.list \
|
||||
--cv_data $trainbase_dir/dev/data.list \
|
||||
--model_dir $trainbase_exp \
|
||||
--num_workers 8 \
|
||||
--num_workers 2 \
|
||||
--ddp.dist_backend nccl \
|
||||
--num_keywords $num_keywords \
|
||||
--min_duration 50 \
|
||||
--seed 666 \
|
||||
$cmvn_opts \
|
||||
--checkpoint $trainbase_exp/23.pt
|
||||
$cmvn_opts # \
|
||||
#--checkpoint $trainbase_exp/23.pt
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
@ -167,16 +172,19 @@ fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
echo "Do model average, Compute FRR/FAR ..."
|
||||
python wekws/bin/average_model.py \
|
||||
--dst_model $score_checkpoint \
|
||||
--src_path $dir \
|
||||
--num ${num_average} \
|
||||
--val_best
|
||||
if $average_model; then
|
||||
python wekws/bin/average_model.py \
|
||||
--dst_model $score_checkpoint \
|
||||
--src_path $dir \
|
||||
--num ${num_average} \
|
||||
--val_best
|
||||
fi
|
||||
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||
mkdir -p $result_dir
|
||||
python wekws/bin/score_ctc.py \
|
||||
--config $dir/config.yaml \
|
||||
--test_data data/test/data.list \
|
||||
--gpu 0 \
|
||||
--batch_size 256 \
|
||||
--checkpoint $score_checkpoint \
|
||||
--score_file $result_dir/score.txt \
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
import argparse, logging, glob
|
||||
import json, re, os, numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import pypinyin
|
||||
import pypinyin # for Chinese Character
|
||||
|
||||
def split_mixed_label(input_str):
|
||||
tokens = []
|
||||
@ -165,13 +165,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
'--xlim',
|
||||
type=int,
|
||||
default=5,
|
||||
default=10,
|
||||
help='xlim:range of x-axis, x is false alarm per hour')
|
||||
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
|
||||
parser.add_argument(
|
||||
'--ylim',
|
||||
type=int,
|
||||
default=75,
|
||||
default=100,
|
||||
help='ylim:range of y-axis, y is false rejection rate')
|
||||
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')
|
||||
|
||||
@ -228,7 +228,7 @@ if __name__ == '__main__':
|
||||
false_alarm_rate = num_false_alarm / filler_num
|
||||
|
||||
fout.write('{:.3f} {:.6f} {:.6f}\n'.format(
|
||||
threshold, false_alarm_per_hour, threshold))
|
||||
threshold, false_alarm_per_hour, false_reject_rate))
|
||||
threshold += args.step
|
||||
if args.det_curve_path :
|
||||
det_curve_path = args.det_curve_path
|
||||
|
||||
@ -192,10 +192,10 @@ def main():
|
||||
if hit_keyword is not None:
|
||||
# fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\
|
||||
# .format(key, start*0.03, end*0.03, hit_keyword, hit_score))
|
||||
fout.write('{} {} {:.3f}\n'.format(
|
||||
fout.write('{} detected {} {:.3f}\n'.format(
|
||||
key, hit_keyword, hit_score))
|
||||
else:
|
||||
fout.write('{} -1 -1\n'.format(key))
|
||||
fout.write('{} rejected\n'.format(key))
|
||||
|
||||
if batch_idx % 10 == 0:
|
||||
print('Progress batch {}'.format(batch_idx))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user