From 4dd49e0f5c6641b798dbb440cbdf69a82148421b Mon Sep 17 00:00:00 2001 From: mlxu995 <228311995@qq.com> Date: Thu, 16 Dec 2021 12:29:30 +0800 Subject: [PATCH 1/5] [kws] add code for plotting det curve --- examples/hey_snips/s0/run.sh | 4 +++ kws/bin/plot_det_curve.py | 65 ++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + 3 files changed, 70 insertions(+) create mode 100644 kws/bin/plot_det_curve.py diff --git a/examples/hey_snips/s0/run.sh b/examples/hey_snips/s0/run.sh index fc5801f..5e135ae 100755 --- a/examples/hey_snips/s0/run.sh +++ b/examples/hey_snips/s0/run.sh @@ -106,6 +106,10 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --score_file $result_dir/score.txt \ --stats_file $result_dir/stats.${keyword}.txt done + python kws/bin/plot_det_curve.py \ + --keywords 'Hey_Snips' \ + --stats_dir $result_dir \ + --figure_file $result_dir/det.png fi diff --git a/kws/bin/plot_det_curve.py b/kws/bin/plot_det_curve.py new file mode 100644 index 0000000..231a9e1 --- /dev/null +++ b/kws/bin/plot_det_curve.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# Menglong Xu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import numpy as np +import matplotlib.pyplot as plt + + +def load_stats_file(stats_file): + values = [] + with open(stats_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + threshold, fa_per_hour, frr = arr + values.append([float(fa_per_hour), float(frr)]) + values.reverse() + return np.array(values) + + +def plot_det_curve(keywords, stats_dir, figure_file): + plt.figure(dpi=200) + plt.rcParams['xtick.direction'] = 'in' + plt.rcParams['ytick.direction'] = 'in' + plt.rcParams['text.usetex'] = True + plt.rcParams['font.size'] = 12 + + for index, keyword in enumerate(keywords): + stats_file = os.path.join(stats_dir, 'stats.'+str(index)+'.txt') + values = load_stats_file(stats_file) + plt.plot(values[:, 0], values[:, 1], label=keyword) + + plt.xlim([0, 5]) + plt.ylim([0, 0.35]) + plt.xticks([0, 1, 2, 3, 4, 5], ['0', '1','2','3','4','5']) + plt.yticks([0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35], ['0', '5','10','15','20','25','30','35']) + plt.xlabel('False Alarm Per Hour') + plt.ylabel('False Rejection Rate (\%)') + plt.grid(linestyle='--') + plt.legend(loc='best', fontsize=16) + plt.savefig(figure_file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='plot det curve') + parser.add_argument('--keywords', required=True, help='keywords, must in the same order as in "dict/words.txt" separated by ", "') + parser.add_argument('--stats_dir', required=True, help='dir of stats files') + parser.add_argument('--figure_file', required=True, help='path to save det curve') + + args = parser.parse_args() + + keywords = args.keywords.strip().split(', ') + plot_det_curve(keywords, args.stats_dir, args.figure_file) diff --git a/requirements.txt b/requirements.txt index b010971..cc3f45f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ flake8==3.8.2 pyyaml>=5.1 tensorboard tensorboardX +matplotlib \ No newline at end of file From 8383ae2f9090a29258970464913526a31f368754 Mon Sep 17 00:00:00 2001 From: mlxu995 <228311995@qq.com> Date: Thu, 16 Dec 2021 12:50:10 +0800 Subject: [PATCH 2/5] format --- kws/bin/plot_det_curve.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/kws/bin/plot_det_curve.py b/kws/bin/plot_det_curve.py index 231a9e1..64feebb 100644 --- a/kws/bin/plot_det_curve.py +++ b/kws/bin/plot_det_curve.py @@ -1,5 +1,5 @@ # Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) -# Menglong Xu +# Menglong Xu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,18 +36,19 @@ def plot_det_curve(keywords, stats_dir, figure_file): plt.rcParams['ytick.direction'] = 'in' plt.rcParams['text.usetex'] = True plt.rcParams['font.size'] = 12 - + for index, keyword in enumerate(keywords): - stats_file = os.path.join(stats_dir, 'stats.'+str(index)+'.txt') + stats_file = os.path.join(stats_dir, 'stats.' + str(index) + '.txt') values = load_stats_file(stats_file) plt.plot(values[:, 0], values[:, 1], label=keyword) plt.xlim([0, 5]) plt.ylim([0, 0.35]) - plt.xticks([0, 1, 2, 3, 4, 5], ['0', '1','2','3','4','5']) - plt.yticks([0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35], ['0', '5','10','15','20','25','30','35']) + plt.xticks([0, 1, 2, 3, 4, 5], ['0', '1', '2', '3', '4', '5']) + plt.yticks([0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35], + ['0', '5', '10', '15', '20', '25', '30', '35']) plt.xlabel('False Alarm Per Hour') - plt.ylabel('False Rejection Rate (\%)') + plt.ylabel('False Rejection Rate (\\%)') plt.grid(linestyle='--') plt.legend(loc='best', fontsize=16) plt.savefig(figure_file) @@ -55,10 +56,11 @@ def plot_det_curve(keywords, stats_dir, figure_file): if __name__ == '__main__': parser = argparse.ArgumentParser(description='plot det curve') - parser.add_argument('--keywords', required=True, help='keywords, must in the same order as in "dict/words.txt" separated by ", "') + parser.add_argument('--keywords', required=True, + help='keywords, must in the same order as in "dict/words.txt" separated by ", "') parser.add_argument('--stats_dir', required=True, help='dir of stats files') parser.add_argument('--figure_file', required=True, help='path to save det curve') - + args = parser.parse_args() keywords = args.keywords.strip().split(', ') From 5f50d01ae32f57bd8774ac228f15b255f38b2c03 Mon Sep 17 00:00:00 2001 From: mlxu995 <228311995@qq.com> Date: Thu, 16 Dec 2021 13:08:23 +0800 Subject: [PATCH 3/5] format --- kws/bin/plot_det_curve.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/kws/bin/plot_det_curve.py b/kws/bin/plot_det_curve.py index 64feebb..71e1472 100644 --- a/kws/bin/plot_det_curve.py +++ b/kws/bin/plot_det_curve.py @@ -46,7 +46,7 @@ def plot_det_curve(keywords, stats_dir, figure_file): plt.ylim([0, 0.35]) plt.xticks([0, 1, 2, 3, 4, 5], ['0', '1', '2', '3', '4', '5']) plt.yticks([0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35], - ['0', '5', '10', '15', '20', '25', '30', '35']) + ['0', '5', '10', '15', '20', '25', '30', '35']) plt.xlabel('False Alarm Per Hour') plt.ylabel('False Rejection Rate (\\%)') plt.grid(linestyle='--') @@ -56,10 +56,15 @@ def plot_det_curve(keywords, stats_dir, figure_file): if __name__ == '__main__': parser = argparse.ArgumentParser(description='plot det curve') - parser.add_argument('--keywords', required=True, + parser.add_argument( + '--keywords', + required=True, help='keywords, must in the same order as in "dict/words.txt" separated by ", "') parser.add_argument('--stats_dir', required=True, help='dir of stats files') - parser.add_argument('--figure_file', required=True, help='path to save det curve') + parser.add_argument( + '--figure_file', + required=True, + help='path to save det curve') args = parser.parse_args() From abdc9f125f289ee7a5a4cb26653dd6333782899e Mon Sep 17 00:00:00 2001 From: mlxu995 <228311995@qq.com> Date: Thu, 16 Dec 2021 13:49:21 +0800 Subject: [PATCH 4/5] format --- kws/bin/plot_det_curve.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/kws/bin/plot_det_curve.py b/kws/bin/plot_det_curve.py index 71e1472..c9036e8 100644 --- a/kws/bin/plot_det_curve.py +++ b/kws/bin/plot_det_curve.py @@ -59,7 +59,9 @@ if __name__ == '__main__': parser.add_argument( '--keywords', required=True, - help='keywords, must in the same order as in "dict/words.txt" separated by ", "') + help=('keywords, must in the same order as in "dict/words.txt", ' + + 'separated by ", "') + ) parser.add_argument('--stats_dir', required=True, help='dir of stats files') parser.add_argument( '--figure_file', From 92c948b720516946c6974aa74b5da5bc05157d30 Mon Sep 17 00:00:00 2001 From: mlxu995 <228311995@qq.com> Date: Thu, 16 Dec 2021 13:54:26 +0800 Subject: [PATCH 5/5] format --- kws/bin/plot_det_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kws/bin/plot_det_curve.py b/kws/bin/plot_det_curve.py index c9036e8..ce1a980 100644 --- a/kws/bin/plot_det_curve.py +++ b/kws/bin/plot_det_curve.py @@ -59,7 +59,7 @@ if __name__ == '__main__': parser.add_argument( '--keywords', required=True, - help=('keywords, must in the same order as in "dict/words.txt", ' + + help=('keywords, must in the same order as in "dict/words.txt", ' + 'separated by ", "') ) parser.add_argument('--stats_dir', required=True, help='dir of stats files')