diff --git a/kws/bin/plot_det_curve.py b/kws/bin/plot_det_curve.py index 231a9e1..64feebb 100644 --- a/kws/bin/plot_det_curve.py +++ b/kws/bin/plot_det_curve.py @@ -1,5 +1,5 @@ # Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) -# Menglong Xu +# Menglong Xu # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,18 +36,19 @@ def plot_det_curve(keywords, stats_dir, figure_file): plt.rcParams['ytick.direction'] = 'in' plt.rcParams['text.usetex'] = True plt.rcParams['font.size'] = 12 - + for index, keyword in enumerate(keywords): - stats_file = os.path.join(stats_dir, 'stats.'+str(index)+'.txt') + stats_file = os.path.join(stats_dir, 'stats.' + str(index) + '.txt') values = load_stats_file(stats_file) plt.plot(values[:, 0], values[:, 1], label=keyword) plt.xlim([0, 5]) plt.ylim([0, 0.35]) - plt.xticks([0, 1, 2, 3, 4, 5], ['0', '1','2','3','4','5']) - plt.yticks([0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35], ['0', '5','10','15','20','25','30','35']) + plt.xticks([0, 1, 2, 3, 4, 5], ['0', '1', '2', '3', '4', '5']) + plt.yticks([0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35], + ['0', '5', '10', '15', '20', '25', '30', '35']) plt.xlabel('False Alarm Per Hour') - plt.ylabel('False Rejection Rate (\%)') + plt.ylabel('False Rejection Rate (\\%)') plt.grid(linestyle='--') plt.legend(loc='best', fontsize=16) plt.savefig(figure_file) @@ -55,10 +56,11 @@ def plot_det_curve(keywords, stats_dir, figure_file): if __name__ == '__main__': parser = argparse.ArgumentParser(description='plot det curve') - parser.add_argument('--keywords', required=True, help='keywords, must in the same order as in "dict/words.txt" separated by ", "') + parser.add_argument('--keywords', required=True, + help='keywords, must in the same order as in "dict/words.txt" separated by ", "') parser.add_argument('--stats_dir', required=True, help='dir of stats files') parser.add_argument('--figure_file', required=True, help='path to save det curve') - + args = parser.parse_args() keywords = args.keywords.strip().split(', ')