From 768900307ae8720cc9be444e5143b1b41c0aeddb Mon Sep 17 00:00:00 2001 From: Menglong Xu <32296227+mlxu995@users.noreply.github.com> Date: Thu, 16 Dec 2021 18:21:04 +0800 Subject: [PATCH] [kws] add code for plotting det curve (#52) * [kws] add code for plotting det curve * format * format * format * format * [kws] add code for plotting det curve format format format format * set xlim and ylim by parameter * set xlim and ylim optional * update help information * update parser type * Update run.sh --- examples/hey_snips/s0/run.sh | 8 +++ kws/bin/plot_det_curve.py | 98 ++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + 3 files changed, 107 insertions(+) create mode 100644 kws/bin/plot_det_curve.py diff --git a/examples/hey_snips/s0/run.sh b/examples/hey_snips/s0/run.sh index fc5801f..58f300a 100755 --- a/examples/hey_snips/s0/run.sh +++ b/examples/hey_snips/s0/run.sh @@ -106,6 +106,14 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --score_file $result_dir/score.txt \ --stats_file $result_dir/stats.${keyword}.txt done + python kws/bin/plot_det_curve.py \ + --keywords 'Hey_Snips' \ + --stats_dir $result_dir \ + --figure_file $result_dir/det.png \ + --xlim 10 \ + --x_step 2 \ + --ylim 10 \ + --y_step 2 fi diff --git a/kws/bin/plot_det_curve.py b/kws/bin/plot_det_curve.py new file mode 100644 index 0000000..fc935e3 --- /dev/null +++ b/kws/bin/plot_det_curve.py @@ -0,0 +1,98 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# Menglong Xu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import numpy as np +import matplotlib.pyplot as plt + + +def load_stats_file(stats_file): + values = [] + with open(stats_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + threshold, fa_per_hour, frr = arr + values.append([float(fa_per_hour), float(frr) * 100]) + values.reverse() + return np.array(values) + + +def plot_det_curve( + keywords, + stats_dir, + figure_file, + xlim, + x_step, + ylim, + y_step): + plt.figure(dpi=200) + plt.rcParams['xtick.direction'] = 'in' + plt.rcParams['ytick.direction'] = 'in' + plt.rcParams['font.size'] = 12 + + for index, keyword in enumerate(keywords): + stats_file = os.path.join(stats_dir, 'stats.' + str(index) + '.txt') + values = load_stats_file(stats_file) + plt.plot(values[:, 0], values[:, 1], label=keyword) + + plt.xlim([0, xlim]) + plt.ylim([0, ylim]) + plt.xticks(range(0, xlim + x_step, x_step)) + plt.yticks(range(0, ylim + y_step, y_step)) + plt.xlabel('False Alarm Per Hour') + plt.ylabel('False Rejection Rate (\\%)') + plt.grid(linestyle='--') + plt.legend(loc='best', fontsize=16) + plt.savefig(figure_file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='plot det curve') + parser.add_argument( + '--keywords', + required=True, + help=('keywords, must in the same order as in "dict/words.txt", ' + + 'separated by ", "') + ) + parser.add_argument('--stats_dir', required=True, help='dir of stats files') + parser.add_argument( + '--figure_file', + required=True, + help='path to save det curve') + parser.add_argument( + '--xlim', + type=int, + default=5, + help='xlim:range of x-axis, x is false alarm per hour') + parser.add_argument('--x_step', type=int, default=1, help='step on x-axis') + parser.add_argument( + '--ylim', + type=int, + default=35, + help='ylim:range of y-axis, y is false rejection rate') + parser.add_argument('--y_step', type=int, default=5, help='step on y-axis') + + args = parser.parse_args() + + keywords = args.keywords.strip().split(', ') + plot_det_curve( + keywords, + args.stats_dir, + args.figure_file, + args.xlim, + args.x_step, + args.ylim, + args.y_step) diff --git a/requirements.txt b/requirements.txt index b010971..cc3f45f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ flake8==3.8.2 pyyaml>=5.1 tensorboard tensorboardX +matplotlib \ No newline at end of file