From c4b2ddbd1134b5b5c8982b92369a0167e5cb05d3 Mon Sep 17 00:00:00 2001 From: dujing Date: Thu, 25 May 2023 13:35:00 +0800 Subject: [PATCH] fix typo. --- .../hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml | 2 +- examples/hi_xiaowen/s0/run_ctc.sh | 34 ++++++++++++------- wekws/bin/compute_det_ctc.py | 8 ++--- wekws/bin/score_ctc.py | 4 +-- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml b/examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml index 6eb753f..a97e9a4 100644 --- a/examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml +++ b/examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml @@ -44,7 +44,7 @@ optim_conf: training_config: grad_clip: 5 - max_epoch: 30 + max_epoch: 50 log_interval: 100 criterion: ctc diff --git a/examples/hi_xiaowen/s0/run_ctc.sh b/examples/hi_xiaowen/s0/run_ctc.sh index a0147b8..3085ff1 100644 --- a/examples/hi_xiaowen/s0/run_ctc.sh +++ b/examples/hi_xiaowen/s0/run_ctc.sh @@ -11,15 +11,19 @@ num_keywords=2599 config=conf/ds_tcn_ctc.yaml norm_mean=true norm_var=true -gpus="4,5,6,7" +gpus="0,1,2,3" checkpoint= dir=exp/ds_tcn_ctc_ft - +average_model=true num_average=30 -score_checkpoint=$dir/avg_${num_average}.pt +if $average_model ;then + score_checkpoint=$dir/avg_${num_average}.pt +else + score_checkpoint=$dir/final.pt +fi -download_dir=/data/megastore/Datasets/ASR/KWS/nihaoxiaowen # your data dir +download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir . tools/parse_options.sh || exit 1; window_shift=50 @@ -100,7 +104,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then exit fi if [ ! -f $trainbase_dir/$x/wav.dur ]; then - tools/wav_to_duration.sh --nj 8 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur + tools/wav_to_duration.sh --nj 128 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur fi # Here we use tokens.txt and lexicon.txt to convert txt into index @@ -124,12 +128,13 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then --train_data $trainbase_dir/train/data.list \ --cv_data $trainbase_dir/dev/data.list \ --model_dir $trainbase_exp \ - --num_workers 8 \ + --num_workers 2 \ + --ddp.dist_backend nccl \ --num_keywords $num_keywords \ --min_duration 50 \ --seed 666 \ - $cmvn_opts \ - --checkpoint $trainbase_exp/23.pt + $cmvn_opts # \ + #--checkpoint $trainbase_exp/23.pt fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -167,16 +172,19 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then echo "Do model average, Compute FRR/FAR ..." - python wekws/bin/average_model.py \ - --dst_model $score_checkpoint \ - --src_path $dir \ - --num ${num_average} \ - --val_best + if $average_model; then + python wekws/bin/average_model.py \ + --dst_model $score_checkpoint \ + --src_path $dir \ + --num ${num_average} \ + --val_best + fi result_dir=$dir/test_$(basename $score_checkpoint) mkdir -p $result_dir python wekws/bin/score_ctc.py \ --config $dir/config.yaml \ --test_data data/test/data.list \ + --gpu 0 \ --batch_size 256 \ --checkpoint $score_checkpoint \ --score_file $result_dir/score.txt \ diff --git a/wekws/bin/compute_det_ctc.py b/wekws/bin/compute_det_ctc.py index ff3fe6d..f5221bb 100644 --- a/wekws/bin/compute_det_ctc.py +++ b/wekws/bin/compute_det_ctc.py @@ -17,7 +17,7 @@ import argparse, logging, glob import json, re, os, numpy as np import matplotlib.pyplot as plt -import pypinyin +import pypinyin # for Chinese Character def split_mixed_label(input_str): tokens = [] @@ -165,13 +165,13 @@ if __name__ == '__main__': parser.add_argument( '--xlim', type=int, - default=5, + default=10, help='xlim:range of x-axis, x is false alarm per hour') parser.add_argument('--x_step', type=int, default=1, help='step on x-axis') parser.add_argument( '--ylim', type=int, - default=75, + default=100, help='ylim:range of y-axis, y is false rejection rate') parser.add_argument('--y_step', type=int, default=5, help='step on y-axis') @@ -228,7 +228,7 @@ if __name__ == '__main__': false_alarm_rate = num_false_alarm / filler_num fout.write('{:.3f} {:.6f} {:.6f}\n'.format( - threshold, false_alarm_per_hour, threshold)) + threshold, false_alarm_per_hour, false_reject_rate)) threshold += args.step if args.det_curve_path : det_curve_path = args.det_curve_path diff --git a/wekws/bin/score_ctc.py b/wekws/bin/score_ctc.py index a2221cd..161f388 100644 --- a/wekws/bin/score_ctc.py +++ b/wekws/bin/score_ctc.py @@ -192,10 +192,10 @@ def main(): if hit_keyword is not None: # fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\ # .format(key, start*0.03, end*0.03, hit_keyword, hit_score)) - fout.write('{} {} {:.3f}\n'.format( + fout.write('{} detected {} {:.3f}\n'.format( key, hit_keyword, hit_score)) else: - fout.write('{} -1 -1\n'.format(key)) + fout.write('{} rejected\n'.format(key)) if batch_idx % 10 == 0: print('Progress batch {}'.format(batch_idx))