diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index a26d16e..32b3257 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -7,13 +7,15 @@ stage=3 stop_stage=3 num_keywords=2 +config=conf/tcn.yaml norm_mean=true norm_var=true gpus="0,1" +dir=exp/ds_tcn num_average=30 -checkpoint=$dir/avg_${num_average}.pt +checkpoint= score_checkpoint=$dir/avg_${num_average}.pt @@ -94,8 +96,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then mkdir -p $result_dir python kws/bin/score_longwav.py \ --config $dir/config.yaml \ - --test_data data/test/test_data.list \ - --batch_size 5 \ + --test_data data/test/data.list \ + --batch_size 256 \ --checkpoint $score_checkpoint \ --score_file_dir $result_dir \ --num_keywords $num_keywords \ @@ -104,7 +106,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then for keyword in 0 1; do python kws/bin/compute_det_longwav.py \ --keyword $keyword \ - --test_data data/test/test_data.list \ + --test_data data/test/data.list \ --score_file $result_dir/score_longwav.${keyword}.txt \ --stats_file $result_dir/stats_longwav.${keyword}.txt done diff --git a/kws/bin/compute_det_longwav.py b/kws/bin/compute_det_longwav.py index 5d26afb..02c827e 100644 --- a/kws/bin/compute_det_longwav.py +++ b/kws/bin/compute_det_longwav.py @@ -68,9 +68,6 @@ if __name__ == '__main__': keyword_table, filler_table, filler_duration = load_label_and_score( args.keyword, args.test_data, args.score_file) print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) - # print('keyword_table.size = ', len(keyword_table)) - # print('filler_table.size = ', len(filler_table)) - # print('filler_duration = ', filler_duration) with open(args.stats_file, 'w', encoding='utf8') as fout: threshold = 0.0 while threshold <= 1.0: @@ -88,9 +85,9 @@ if __name__ == '__main__': while i < len(score_list): if score_list[i] >= threshold: num_false_alarm += 1 - i += 1 - else: i += window_shift + else: + i += 1 if len(keyword_table) != 0 : false_reject_rate = num_false_reject / len(keyword_table) num_false_alarm = max(num_false_alarm, 1e-6)