[kws] update parameter for plotting det curve (#54)

This commit is contained in:
Menglong Xu 2021-12-17 20:52:45 +08:00 committed by GitHub
parent 768900307a
commit f622c55b04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 6 deletions

View File

@ -107,7 +107,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--stats_file $result_dir/stats.${keyword}.txt
done
python kws/bin/plot_det_curve.py \
--keywords 'Hey_Snips' \
--keywords_dict dict/words.txt \
--stats_dir $result_dir \
--figure_file $result_dir/det.png \
--xlim 10 \

View File

@ -62,11 +62,9 @@ def plot_det_curve(
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='plot det curve')
parser.add_argument(
'--keywords',
'--keywords_dict',
required=True,
help=('keywords, must in the same order as in "dict/words.txt", ' +
'separated by ", "')
)
help='path to the dictionary of keywords')
parser.add_argument('--stats_dir', required=True, help='dir of stats files')
parser.add_argument(
'--figure_file',
@ -87,7 +85,13 @@ if __name__ == '__main__':
args = parser.parse_args()
keywords = args.keywords.strip().split(', ')
keywords = []
with open(args.keywords_dict, 'r', encoding='utf8') as fin:
for line in fin:
keyword, index = line.strip().split()
if int(index) > -1:
keywords.append(keyword)
plot_det_curve(
keywords,
args.stats_dir,