update compute_det_ctc
This commit is contained in:
parent
a2f8d0e39e
commit
9f3167632a
@ -1,6 +1,6 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Copyright 2021 Binbin Zhang(binbzha@qq.com)
|
# Copyright 2021 Binbin Zhang(binbzha@qq.com)
|
||||||
# 2023 dujing(thuduj12@163.com)
|
# 2023 Jing Du(thuduj12@163.com)
|
||||||
|
|
||||||
. ./path.sh
|
. ./path.sh
|
||||||
|
|
||||||
@ -29,9 +29,6 @@ trainbase=true
|
|||||||
trainbase_dir=data/base
|
trainbase_dir=data/base
|
||||||
trainbase_config=conf/ds_tcn_ctc_base.yaml
|
trainbase_config=conf/ds_tcn_ctc_base.yaml
|
||||||
trainbase_exp=exp/base
|
trainbase_exp=exp/base
|
||||||
if $trainbase; then
|
|
||||||
checkpoint=$trainbase_exp/final.pt
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ ${stage} -le -3 ] && [ ${stop_stage} -ge -3 ]; then
|
if [ ${stage} -le -3 ] && [ ${stop_stage} -ge -3 ]; then
|
||||||
echo "Download and extracte all datasets"
|
echo "Download and extracte all datasets"
|
||||||
@ -131,7 +128,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
|
|||||||
--num_keywords $num_keywords \
|
--num_keywords $num_keywords \
|
||||||
--min_duration 50 \
|
--min_duration 50 \
|
||||||
--seed 666 \
|
--seed 666 \
|
||||||
$cmvn_opts
|
$cmvn_opts \
|
||||||
|
--checkpoint $trainbase_exp/23.pt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
@ -141,6 +139,18 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
|||||||
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
|
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
|
||||||
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
|
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
|
||||||
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
|
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
|
||||||
|
|
||||||
|
if $trainbase; then
|
||||||
|
echo "Use the base model you trained as checkpoint: $trainbase_exp/final.pt"
|
||||||
|
checkpoint=$trainbase_exp/final.pt
|
||||||
|
else
|
||||||
|
echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/final.pt"
|
||||||
|
if [ ! -d mobvoi_kws_transcription ] ;then
|
||||||
|
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
|
||||||
|
fi
|
||||||
|
checkpoint=mobvoi_kws_transcription/final.pt
|
||||||
|
fi
|
||||||
|
|
||||||
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
|
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
|
||||||
wekws/bin/train.py --gpus $gpus \
|
wekws/bin/train.py --gpus $gpus \
|
||||||
--config $config \
|
--config $config \
|
||||||
@ -176,7 +186,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
--lexicon_file data/lexicon.txt
|
--lexicon_file data/lexicon.txt
|
||||||
|
|
||||||
python wekws/bin/compute_det_ctc.py \
|
python wekws/bin/compute_det_ctc.py \
|
||||||
--keyword 嗨小问,你好问问 \
|
--keywords 嗨小问,你好问问 \
|
||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
--window_shift $window_shift \
|
--window_shift $window_shift \
|
||||||
--step 0.001 \
|
--step 0.001 \
|
||||||
|
|||||||
@ -16,10 +16,8 @@
|
|||||||
|
|
||||||
import argparse, logging, glob
|
import argparse, logging, glob
|
||||||
import json, re, os, numpy as np
|
import json, re, os, numpy as np
|
||||||
import matplotlib.font_manager as fm
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import pypinyin
|
||||||
font = fm.FontProperties(size=15)
|
|
||||||
|
|
||||||
def split_mixed_label(input_str):
|
def split_mixed_label(input_str):
|
||||||
tokens = []
|
tokens = []
|
||||||
@ -41,7 +39,6 @@ def space_mixed_label(input_str):
|
|||||||
return space_str.strip()
|
return space_str.strip()
|
||||||
|
|
||||||
def load_label_and_score(keywords_list, label_file, score_file):
|
def load_label_and_score(keywords_list, label_file, score_file):
|
||||||
# score_table: {uttid: [keywordlist]}
|
|
||||||
score_table = {}
|
score_table = {}
|
||||||
with open(score_file, 'r', encoding='utf8') as fin:
|
with open(score_file, 'r', encoding='utf8') as fin:
|
||||||
# read score file and store in table
|
# read score file and store in table
|
||||||
@ -84,7 +81,6 @@ def load_label_and_score(keywords_list, label_file, score_file):
|
|||||||
assert 'duration' in obj
|
assert 'duration' in obj
|
||||||
|
|
||||||
key = obj['key']
|
key = obj['key']
|
||||||
# wav_file = obj['wav']
|
|
||||||
txt = "".join(obj['tok'])
|
txt = "".join(obj['tok'])
|
||||||
txt = space_mixed_label(txt)
|
txt = space_mixed_label(txt)
|
||||||
txt_regstr_lrblk = ' ' + txt + ' '
|
txt_regstr_lrblk = ' ' + txt + ' '
|
||||||
@ -125,12 +121,8 @@ def load_stats_file(stats_file):
|
|||||||
values.reverse()
|
values.reverse()
|
||||||
return np.array(values)
|
return np.array(values)
|
||||||
|
|
||||||
def plot_det(dets_dir, figure_file, det_title="DetCurve"):
|
def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5):
|
||||||
xlim = '[0,2]'
|
det_title = "DetCurve"
|
||||||
# xstep = kwargs.get('xstep', '1')
|
|
||||||
ylim = '[15,30]'
|
|
||||||
# ystep = kwargs.get('ystep', '5')
|
|
||||||
|
|
||||||
plt.figure(dpi=200)
|
plt.figure(dpi=200)
|
||||||
plt.rcParams['xtick.direction'] = 'in'
|
plt.rcParams['xtick.direction'] = 'in'
|
||||||
plt.rcParams['ytick.direction'] = 'in'
|
plt.rcParams['ytick.direction'] = 'in'
|
||||||
@ -138,33 +130,25 @@ def plot_det(dets_dir, figure_file, det_title="DetCurve"):
|
|||||||
|
|
||||||
for file in glob.glob(f'{dets_dir}/*stats*.txt'):
|
for file in glob.glob(f'{dets_dir}/*stats*.txt'):
|
||||||
logging.info(f'reading det data from {file}')
|
logging.info(f'reading det data from {file}')
|
||||||
label = os.path.basename(file).split('.')[0]
|
label = os.path.basename(file).split('.')[1]
|
||||||
|
label = "".join(pypinyin.lazy_pinyin(label))
|
||||||
values = load_stats_file(file)
|
values = load_stats_file(file)
|
||||||
plt.plot(values[:, 0], values[:, 1], label=label)
|
plt.plot(values[:, 0], values[:, 1], label=label)
|
||||||
|
|
||||||
xlim_splits = xlim.strip().replace('[', '').replace(']', '').split(',')
|
plt.xlim([0, xlim])
|
||||||
assert len(xlim_splits) == 2
|
plt.ylim([0, ylim])
|
||||||
ylim_splits = ylim.strip().replace('[', '').replace(']', '').split(',')
|
plt.xticks(range(0, xlim + x_step, x_step))
|
||||||
assert len(ylim_splits) == 2
|
plt.yticks(range(0, ylim + y_step, y_step))
|
||||||
|
|
||||||
plt.xlim(float(xlim_splits[0]), float(xlim_splits[1]))
|
|
||||||
plt.ylim(float(ylim_splits[0]), float(ylim_splits[1]))
|
|
||||||
|
|
||||||
# plt.xticks(range(0, xlim + x_step, x_step))
|
|
||||||
# plt.yticks(range(0, ylim + y_step, y_step))
|
|
||||||
plt.xlabel('False Alarm Per Hour')
|
plt.xlabel('False Alarm Per Hour')
|
||||||
plt.ylabel('False Rejection Rate (\\%)')
|
plt.ylabel('False Rejection Rate (%)')
|
||||||
plt.title(det_title, fontproperties=font)
|
|
||||||
plt.grid(linestyle='--')
|
plt.grid(linestyle='--')
|
||||||
# plt.legend(loc='best', fontsize=6)
|
plt.legend(loc='best', fontsize=6)
|
||||||
plt.legend(loc='upper right', fontsize=5)
|
|
||||||
# plt.show()
|
|
||||||
plt.savefig(figure_file)
|
plt.savefig(figure_file)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='compute det curve')
|
parser = argparse.ArgumentParser(description='compute det curve')
|
||||||
parser.add_argument('--test_data', required=True, help='label file')
|
parser.add_argument('--test_data', required=True, help='label file')
|
||||||
parser.add_argument('--keyword', type=str, default=None, help='keyword label')
|
parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)')
|
||||||
parser.add_argument('--score_file', required=True, help='score file')
|
parser.add_argument('--score_file', required=True, help='score file')
|
||||||
parser.add_argument('--step', type=float, default=0.01,
|
parser.add_argument('--step', type=float, default=0.01,
|
||||||
help='threshold step')
|
help='threshold step')
|
||||||
@ -178,10 +162,22 @@ if __name__ == '__main__':
|
|||||||
required=False,
|
required=False,
|
||||||
default=None,
|
default=None,
|
||||||
help='det curve path, default is stats_dir/det.png')
|
help='det curve path, default is stats_dir/det.png')
|
||||||
|
parser.add_argument(
|
||||||
|
'--xlim',
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help='xlim:range of x-axis, x is false alarm per hour')
|
||||||
|
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
|
||||||
|
parser.add_argument(
|
||||||
|
'--ylim',
|
||||||
|
type=int,
|
||||||
|
default=75,
|
||||||
|
help='ylim:range of y-axis, y is false rejection rate')
|
||||||
|
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
window_shift = args.window_shift
|
window_shift = args.window_shift
|
||||||
keywords_list = args.keyword.strip().split(',')
|
keywords_list = args.keywords.strip().split(',')
|
||||||
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file)
|
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file)
|
||||||
|
|
||||||
for keyword in keywords_list:
|
for keyword in keywords_list:
|
||||||
@ -238,4 +234,4 @@ if __name__ == '__main__':
|
|||||||
det_curve_path = args.det_curve_path
|
det_curve_path = args.det_curve_path
|
||||||
else:
|
else:
|
||||||
det_curve_path = os.path.join(stats_dir, 'det.png')
|
det_curve_path = os.path.join(stats_dir, 'det.png')
|
||||||
plot_det(stats_dir, det_curve_path)
|
plot_det(stats_dir, det_curve_path, args.xlim, args.x_step, args.ylim, args.y_step)
|
||||||
Loading…
x
Reference in New Issue
Block a user