update compute_det_ctc

This commit is contained in:
dujing 2023-05-22 11:10:25 +08:00
parent a2f8d0e39e
commit 9f3167632a
2 changed files with 42 additions and 36 deletions

View File

@ -1,6 +1,6 @@
#!/bin/bash
# Copyright 2021 Binbin Zhang(binbzha@qq.com)
# 2023 dujing(thuduj12@163.com)
# 2023 Jing Du(thuduj12@163.com)
. ./path.sh
@ -29,9 +29,6 @@ trainbase=true
trainbase_dir=data/base
trainbase_config=conf/ds_tcn_ctc_base.yaml
trainbase_exp=exp/base
if $trainbase; then
checkpoint=$trainbase_exp/final.pt
fi
if [ ${stage} -le -3 ] && [ ${stop_stage} -ge -3 ]; then
echo "Download and extracte all datasets"
@ -131,7 +128,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
--num_keywords $num_keywords \
--min_duration 50 \
--seed 666 \
$cmvn_opts
$cmvn_opts \
--checkpoint $trainbase_exp/23.pt
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
@ -141,6 +139,18 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
if $trainbase; then
echo "Use the base model you trained as checkpoint: $trainbase_exp/final.pt"
checkpoint=$trainbase_exp/final.pt
else
echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/final.pt"
if [ ! -d mobvoi_kws_transcription ] ;then
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
fi
checkpoint=mobvoi_kws_transcription/final.pt
fi
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wekws/bin/train.py --gpus $gpus \
--config $config \
@ -176,7 +186,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--lexicon_file data/lexicon.txt
python wekws/bin/compute_det_ctc.py \
--keyword 嗨小问,你好问问 \
--keywords 嗨小问,你好问问 \
--test_data data/test/data.list \
--window_shift $window_shift \
--step 0.001 \

View File

@ -16,10 +16,8 @@
import argparse, logging, glob
import json, re, os, numpy as np
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
font = fm.FontProperties(size=15)
import pypinyin
def split_mixed_label(input_str):
tokens = []
@ -41,7 +39,6 @@ def space_mixed_label(input_str):
return space_str.strip()
def load_label_and_score(keywords_list, label_file, score_file):
# score_table: {uttid: [keywordlist]}
score_table = {}
with open(score_file, 'r', encoding='utf8') as fin:
# read score file and store in table
@ -84,7 +81,6 @@ def load_label_and_score(keywords_list, label_file, score_file):
assert 'duration' in obj
key = obj['key']
# wav_file = obj['wav']
txt = "".join(obj['tok'])
txt = space_mixed_label(txt)
txt_regstr_lrblk = ' ' + txt + ' '
@ -125,12 +121,8 @@ def load_stats_file(stats_file):
values.reverse()
return np.array(values)
def plot_det(dets_dir, figure_file, det_title="DetCurve"):
xlim = '[0,2]'
# xstep = kwargs.get('xstep', '1')
ylim = '[15,30]'
# ystep = kwargs.get('ystep', '5')
def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5):
det_title = "DetCurve"
plt.figure(dpi=200)
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
@ -138,33 +130,25 @@ def plot_det(dets_dir, figure_file, det_title="DetCurve"):
for file in glob.glob(f'{dets_dir}/*stats*.txt'):
logging.info(f'reading det data from {file}')
label = os.path.basename(file).split('.')[0]
label = os.path.basename(file).split('.')[1]
label = "".join(pypinyin.lazy_pinyin(label))
values = load_stats_file(file)
plt.plot(values[:, 0], values[:, 1], label=label)
xlim_splits = xlim.strip().replace('[', '').replace(']', '').split(',')
assert len(xlim_splits) == 2
ylim_splits = ylim.strip().replace('[', '').replace(']', '').split(',')
assert len(ylim_splits) == 2
plt.xlim(float(xlim_splits[0]), float(xlim_splits[1]))
plt.ylim(float(ylim_splits[0]), float(ylim_splits[1]))
# plt.xticks(range(0, xlim + x_step, x_step))
# plt.yticks(range(0, ylim + y_step, y_step))
plt.xlim([0, xlim])
plt.ylim([0, ylim])
plt.xticks(range(0, xlim + x_step, x_step))
plt.yticks(range(0, ylim + y_step, y_step))
plt.xlabel('False Alarm Per Hour')
plt.ylabel('False Rejection Rate (\\%)')
plt.title(det_title, fontproperties=font)
plt.ylabel('False Rejection Rate (%)')
plt.grid(linestyle='--')
# plt.legend(loc='best', fontsize=6)
plt.legend(loc='upper right', fontsize=5)
# plt.show()
plt.legend(loc='best', fontsize=6)
plt.savefig(figure_file)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='compute det curve')
parser.add_argument('--test_data', required=True, help='label file')
parser.add_argument('--keyword', type=str, default=None, help='keyword label')
parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)')
parser.add_argument('--score_file', required=True, help='score file')
parser.add_argument('--step', type=float, default=0.01,
help='threshold step')
@ -178,10 +162,22 @@ if __name__ == '__main__':
required=False,
default=None,
help='det curve path, default is stats_dir/det.png')
parser.add_argument(
'--xlim',
type=int,
default=5,
help='xlimrange of x-axis, x is false alarm per hour')
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
parser.add_argument(
'--ylim',
type=int,
default=75,
help='ylimrange of y-axis, y is false rejection rate')
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')
args = parser.parse_args()
window_shift = args.window_shift
keywords_list = args.keyword.strip().split(',')
keywords_list = args.keywords.strip().split(',')
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file)
for keyword in keywords_list:
@ -238,4 +234,4 @@ if __name__ == '__main__':
det_curve_path = args.det_curve_path
else:
det_curve_path = os.path.join(stats_dir, 'det.png')
plot_det(stats_dir, det_curve_path)
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)