From 9f3167632afdcec4fe2181451c2d4e4c89029252 Mon Sep 17 00:00:00 2001 From: dujing Date: Mon, 22 May 2023 11:10:25 +0800 Subject: [PATCH] update compute_det_ctc --- examples/hi_xiaowen/s0/run_ctc.sh | 22 ++++++++---- wekws/bin/compute_det_ctc.py | 56 ++++++++++++++----------------- 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/examples/hi_xiaowen/s0/run_ctc.sh b/examples/hi_xiaowen/s0/run_ctc.sh index af848bb..a0147b8 100644 --- a/examples/hi_xiaowen/s0/run_ctc.sh +++ b/examples/hi_xiaowen/s0/run_ctc.sh @@ -1,6 +1,6 @@ #!/bin/bash # Copyright 2021 Binbin Zhang(binbzha@qq.com) -# 2023 dujing(thuduj12@163.com) +# 2023 Jing Du(thuduj12@163.com) . ./path.sh @@ -29,9 +29,6 @@ trainbase=true trainbase_dir=data/base trainbase_config=conf/ds_tcn_ctc_base.yaml trainbase_exp=exp/base -if $trainbase; then - checkpoint=$trainbase_exp/final.pt -fi if [ ${stage} -le -3 ] && [ ${stop_stage} -ge -3 ]; then echo "Download and extracte all datasets" @@ -131,7 +128,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then --num_keywords $num_keywords \ --min_duration 50 \ --seed 666 \ - $cmvn_opts + $cmvn_opts \ + --checkpoint $trainbase_exp/23.pt fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then @@ -141,6 +139,18 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn" $norm_var && cmvn_opts="$cmvn_opts --norm_var" num_gpus=$(echo $gpus | awk -F ',' '{print NF}') + + if $trainbase; then + echo "Use the base model you trained as checkpoint: $trainbase_exp/final.pt" + checkpoint=$trainbase_exp/final.pt + else + echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/final.pt" + if [ ! -d mobvoi_kws_transcription ] ;then + git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git + fi + checkpoint=mobvoi_kws_transcription/final.pt + fi + torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ wekws/bin/train.py --gpus $gpus \ --config $config \ @@ -176,7 +186,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --lexicon_file data/lexicon.txt python wekws/bin/compute_det_ctc.py \ - --keyword 嗨小问,你好问问 \ + --keywords 嗨小问,你好问问 \ --test_data data/test/data.list \ --window_shift $window_shift \ --step 0.001 \ diff --git a/wekws/bin/compute_det_ctc.py b/wekws/bin/compute_det_ctc.py index 94db816..ff3fe6d 100644 --- a/wekws/bin/compute_det_ctc.py +++ b/wekws/bin/compute_det_ctc.py @@ -16,10 +16,8 @@ import argparse, logging, glob import json, re, os, numpy as np -import matplotlib.font_manager as fm import matplotlib.pyplot as plt - -font = fm.FontProperties(size=15) +import pypinyin def split_mixed_label(input_str): tokens = [] @@ -41,7 +39,6 @@ def space_mixed_label(input_str): return space_str.strip() def load_label_and_score(keywords_list, label_file, score_file): - # score_table: {uttid: [keywordlist]} score_table = {} with open(score_file, 'r', encoding='utf8') as fin: # read score file and store in table @@ -84,7 +81,6 @@ def load_label_and_score(keywords_list, label_file, score_file): assert 'duration' in obj key = obj['key'] - # wav_file = obj['wav'] txt = "".join(obj['tok']) txt = space_mixed_label(txt) txt_regstr_lrblk = ' ' + txt + ' ' @@ -125,12 +121,8 @@ def load_stats_file(stats_file): values.reverse() return np.array(values) -def plot_det(dets_dir, figure_file, det_title="DetCurve"): - xlim = '[0,2]' - # xstep = kwargs.get('xstep', '1') - ylim = '[15,30]' - # ystep = kwargs.get('ystep', '5') - +def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5): + det_title = "DetCurve" plt.figure(dpi=200) plt.rcParams['xtick.direction'] = 'in' plt.rcParams['ytick.direction'] = 'in' @@ -138,33 +130,25 @@ def plot_det(dets_dir, figure_file, det_title="DetCurve"): for file in glob.glob(f'{dets_dir}/*stats*.txt'): logging.info(f'reading det data from {file}') - label = os.path.basename(file).split('.')[0] + label = os.path.basename(file).split('.')[1] + label = "".join(pypinyin.lazy_pinyin(label)) values = load_stats_file(file) plt.plot(values[:, 0], values[:, 1], label=label) - xlim_splits = xlim.strip().replace('[', '').replace(']', '').split(',') - assert len(xlim_splits) == 2 - ylim_splits = ylim.strip().replace('[', '').replace(']', '').split(',') - assert len(ylim_splits) == 2 - - plt.xlim(float(xlim_splits[0]), float(xlim_splits[1])) - plt.ylim(float(ylim_splits[0]), float(ylim_splits[1])) - - # plt.xticks(range(0, xlim + x_step, x_step)) - # plt.yticks(range(0, ylim + y_step, y_step)) + plt.xlim([0, xlim]) + plt.ylim([0, ylim]) + plt.xticks(range(0, xlim + x_step, x_step)) + plt.yticks(range(0, ylim + y_step, y_step)) plt.xlabel('False Alarm Per Hour') - plt.ylabel('False Rejection Rate (\\%)') - plt.title(det_title, fontproperties=font) + plt.ylabel('False Rejection Rate (%)') plt.grid(linestyle='--') - # plt.legend(loc='best', fontsize=6) - plt.legend(loc='upper right', fontsize=5) - # plt.show() + plt.legend(loc='best', fontsize=6) plt.savefig(figure_file) if __name__ == '__main__': parser = argparse.ArgumentParser(description='compute det curve') parser.add_argument('--test_data', required=True, help='label file') - parser.add_argument('--keyword', type=str, default=None, help='keyword label') + parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)') parser.add_argument('--score_file', required=True, help='score file') parser.add_argument('--step', type=float, default=0.01, help='threshold step') @@ -178,10 +162,22 @@ if __name__ == '__main__': required=False, default=None, help='det curve path, default is stats_dir/det.png') + parser.add_argument( + '--xlim', + type=int, + default=5, + help='xlim:range of x-axis, x is false alarm per hour') + parser.add_argument('--x_step', type=int, default=1, help='step on x-axis') + parser.add_argument( + '--ylim', + type=int, + default=75, + help='ylim:range of y-axis, y is false rejection rate') + parser.add_argument('--y_step', type=int, default=5, help='step on y-axis') args = parser.parse_args() window_shift = args.window_shift - keywords_list = args.keyword.strip().split(',') + keywords_list = args.keywords.strip().split(',') keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file) for keyword in keywords_list: @@ -238,4 +234,4 @@ if __name__ == '__main__': det_curve_path = args.det_curve_path else: det_curve_path = os.path.join(stats_dir, 'det.png') - plot_det(stats_dir, det_curve_path) \ No newline at end of file + plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step) \ No newline at end of file