From 831ba946d63f18ac33d8918969bfdfe8e18b5302 Mon Sep 17 00:00:00 2001 From: mlxu995 <228311995@qq.com> Date: Thu, 16 Dec 2021 14:57:41 +0800 Subject: [PATCH] set xlim and ylim by parameter --- examples/hey_snips/s0/run.sh | 10 +++++++--- kws/bin/plot_det_curve.py | 36 +++++++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/examples/hey_snips/s0/run.sh b/examples/hey_snips/s0/run.sh index 5e135ae..bed5d23 100755 --- a/examples/hey_snips/s0/run.sh +++ b/examples/hey_snips/s0/run.sh @@ -4,7 +4,7 @@ . ./path.sh -stage=0 +stage=3 stop_stage=4 num_keywords=1 @@ -14,7 +14,7 @@ norm_var=true gpus="0" checkpoint= -dir=exp/ds_tcn +dir=exp/ds_tcn_specargument num_average=30 score_checkpoint=$dir/avg_${num_average}.pt @@ -109,7 +109,11 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then python kws/bin/plot_det_curve.py \ --keywords 'Hey_Snips' \ --stats_dir $result_dir \ - --figure_file $result_dir/det.png + --figure_file $result_dir/det.png \ + --xlim 10 \ + --x_step 2 \ + --ylim 10 \ + --y_step 2 fi diff --git a/kws/bin/plot_det_curve.py b/kws/bin/plot_det_curve.py index ce1a980..c5f6ecb 100644 --- a/kws/bin/plot_det_curve.py +++ b/kws/bin/plot_det_curve.py @@ -25,16 +25,22 @@ def load_stats_file(stats_file): for line in fin: arr = line.strip().split() threshold, fa_per_hour, frr = arr - values.append([float(fa_per_hour), float(frr)]) + values.append([float(fa_per_hour), float(frr) * 100]) values.reverse() return np.array(values) -def plot_det_curve(keywords, stats_dir, figure_file): +def plot_det_curve( + keywords, + stats_dir, + figure_file, + xlim, + x_step, + ylim, + y_step): plt.figure(dpi=200) plt.rcParams['xtick.direction'] = 'in' plt.rcParams['ytick.direction'] = 'in' - plt.rcParams['text.usetex'] = True plt.rcParams['font.size'] = 12 for index, keyword in enumerate(keywords): @@ -42,11 +48,10 @@ def plot_det_curve(keywords, stats_dir, figure_file): values = load_stats_file(stats_file) plt.plot(values[:, 0], values[:, 1], label=keyword) - plt.xlim([0, 5]) - plt.ylim([0, 0.35]) - plt.xticks([0, 1, 2, 3, 4, 5], ['0', '1', '2', '3', '4', '5']) - plt.yticks([0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35], - ['0', '5', '10', '15', '20', '25', '30', '35']) + plt.xlim([0, xlim]) + plt.ylim([0, ylim]) + plt.xticks(range(0, xlim + x_step, x_step)) + plt.yticks(range(0, ylim + y_step, y_step)) plt.xlabel('False Alarm Per Hour') plt.ylabel('False Rejection Rate (\\%)') plt.grid(linestyle='--') @@ -67,8 +72,21 @@ if __name__ == '__main__': '--figure_file', required=True, help='path to save det curve') + parser.add_argument('--xlim', required=True, help='range of X-axis') + parser.add_argument('--x_step', required=True, help='step of X-axis') + parser.add_argument('--ylim', required=True, help='range of Y-axis') + parser.add_argument('--y_step', required=True, help='step of Y-axis') args = parser.parse_args() keywords = args.keywords.strip().split(', ') - plot_det_curve(keywords, args.stats_dir, args.figure_file) + xlim, x_step, ylim, y_step = map( + int, [args.xlim, args.x_step, args.ylim, args.y_step]) + plot_det_curve( + keywords, + args.stats_dir, + args.figure_file, + xlim, + x_step, + ylim, + y_step)