fix typo.
This commit is contained in:
parent
9f3167632a
commit
c4b2ddbd11
@ -44,7 +44,7 @@ optim_conf:
|
|||||||
|
|
||||||
training_config:
|
training_config:
|
||||||
grad_clip: 5
|
grad_clip: 5
|
||||||
max_epoch: 30
|
max_epoch: 50
|
||||||
log_interval: 100
|
log_interval: 100
|
||||||
criterion: ctc
|
criterion: ctc
|
||||||
|
|
||||||
|
|||||||
@ -11,15 +11,19 @@ num_keywords=2599
|
|||||||
config=conf/ds_tcn_ctc.yaml
|
config=conf/ds_tcn_ctc.yaml
|
||||||
norm_mean=true
|
norm_mean=true
|
||||||
norm_var=true
|
norm_var=true
|
||||||
gpus="4,5,6,7"
|
gpus="0,1,2,3"
|
||||||
|
|
||||||
checkpoint=
|
checkpoint=
|
||||||
dir=exp/ds_tcn_ctc_ft
|
dir=exp/ds_tcn_ctc_ft
|
||||||
|
average_model=true
|
||||||
num_average=30
|
num_average=30
|
||||||
score_checkpoint=$dir/avg_${num_average}.pt
|
if $average_model ;then
|
||||||
|
score_checkpoint=$dir/avg_${num_average}.pt
|
||||||
|
else
|
||||||
|
score_checkpoint=$dir/final.pt
|
||||||
|
fi
|
||||||
|
|
||||||
download_dir=/data/megastore/Datasets/ASR/KWS/nihaoxiaowen # your data dir
|
download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir
|
||||||
|
|
||||||
. tools/parse_options.sh || exit 1;
|
. tools/parse_options.sh || exit 1;
|
||||||
window_shift=50
|
window_shift=50
|
||||||
@ -100,7 +104,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
|
|||||||
exit
|
exit
|
||||||
fi
|
fi
|
||||||
if [ ! -f $trainbase_dir/$x/wav.dur ]; then
|
if [ ! -f $trainbase_dir/$x/wav.dur ]; then
|
||||||
tools/wav_to_duration.sh --nj 8 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
|
tools/wav_to_duration.sh --nj 128 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Here we use tokens.txt and lexicon.txt to convert txt into index
|
# Here we use tokens.txt and lexicon.txt to convert txt into index
|
||||||
@ -124,12 +128,13 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
|
|||||||
--train_data $trainbase_dir/train/data.list \
|
--train_data $trainbase_dir/train/data.list \
|
||||||
--cv_data $trainbase_dir/dev/data.list \
|
--cv_data $trainbase_dir/dev/data.list \
|
||||||
--model_dir $trainbase_exp \
|
--model_dir $trainbase_exp \
|
||||||
--num_workers 8 \
|
--num_workers 2 \
|
||||||
|
--ddp.dist_backend nccl \
|
||||||
--num_keywords $num_keywords \
|
--num_keywords $num_keywords \
|
||||||
--min_duration 50 \
|
--min_duration 50 \
|
||||||
--seed 666 \
|
--seed 666 \
|
||||||
$cmvn_opts \
|
$cmvn_opts # \
|
||||||
--checkpoint $trainbase_exp/23.pt
|
#--checkpoint $trainbase_exp/23.pt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
@ -167,16 +172,19 @@ fi
|
|||||||
|
|
||||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
echo "Do model average, Compute FRR/FAR ..."
|
echo "Do model average, Compute FRR/FAR ..."
|
||||||
python wekws/bin/average_model.py \
|
if $average_model; then
|
||||||
--dst_model $score_checkpoint \
|
python wekws/bin/average_model.py \
|
||||||
--src_path $dir \
|
--dst_model $score_checkpoint \
|
||||||
--num ${num_average} \
|
--src_path $dir \
|
||||||
--val_best
|
--num ${num_average} \
|
||||||
|
--val_best
|
||||||
|
fi
|
||||||
result_dir=$dir/test_$(basename $score_checkpoint)
|
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||||
mkdir -p $result_dir
|
mkdir -p $result_dir
|
||||||
python wekws/bin/score_ctc.py \
|
python wekws/bin/score_ctc.py \
|
||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
|
--gpu 0 \
|
||||||
--batch_size 256 \
|
--batch_size 256 \
|
||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
--score_file $result_dir/score.txt \
|
--score_file $result_dir/score.txt \
|
||||||
|
|||||||
@ -17,7 +17,7 @@
|
|||||||
import argparse, logging, glob
|
import argparse, logging, glob
|
||||||
import json, re, os, numpy as np
|
import json, re, os, numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pypinyin
|
import pypinyin # for Chinese Character
|
||||||
|
|
||||||
def split_mixed_label(input_str):
|
def split_mixed_label(input_str):
|
||||||
tokens = []
|
tokens = []
|
||||||
@ -165,13 +165,13 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--xlim',
|
'--xlim',
|
||||||
type=int,
|
type=int,
|
||||||
default=5,
|
default=10,
|
||||||
help='xlim:range of x-axis, x is false alarm per hour')
|
help='xlim:range of x-axis, x is false alarm per hour')
|
||||||
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
|
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--ylim',
|
'--ylim',
|
||||||
type=int,
|
type=int,
|
||||||
default=75,
|
default=100,
|
||||||
help='ylim:range of y-axis, y is false rejection rate')
|
help='ylim:range of y-axis, y is false rejection rate')
|
||||||
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')
|
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')
|
||||||
|
|
||||||
@ -228,7 +228,7 @@ if __name__ == '__main__':
|
|||||||
false_alarm_rate = num_false_alarm / filler_num
|
false_alarm_rate = num_false_alarm / filler_num
|
||||||
|
|
||||||
fout.write('{:.3f} {:.6f} {:.6f}\n'.format(
|
fout.write('{:.3f} {:.6f} {:.6f}\n'.format(
|
||||||
threshold, false_alarm_per_hour, threshold))
|
threshold, false_alarm_per_hour, false_reject_rate))
|
||||||
threshold += args.step
|
threshold += args.step
|
||||||
if args.det_curve_path :
|
if args.det_curve_path :
|
||||||
det_curve_path = args.det_curve_path
|
det_curve_path = args.det_curve_path
|
||||||
|
|||||||
@ -192,10 +192,10 @@ def main():
|
|||||||
if hit_keyword is not None:
|
if hit_keyword is not None:
|
||||||
# fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\
|
# fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\
|
||||||
# .format(key, start*0.03, end*0.03, hit_keyword, hit_score))
|
# .format(key, start*0.03, end*0.03, hit_keyword, hit_score))
|
||||||
fout.write('{} {} {:.3f}\n'.format(
|
fout.write('{} detected {} {:.3f}\n'.format(
|
||||||
key, hit_keyword, hit_score))
|
key, hit_keyword, hit_score))
|
||||||
else:
|
else:
|
||||||
fout.write('{} -1 -1\n'.format(key))
|
fout.write('{} rejected\n'.format(key))
|
||||||
|
|
||||||
if batch_idx % 10 == 0:
|
if batch_idx % 10 == 0:
|
||||||
print('Progress batch {}'.format(batch_idx))
|
print('Progress batch {}'.format(batch_idx))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user