fix some bugs

This commit is contained in:
blessyyyu 2022-03-23 11:14:52 +08:00
parent c270cbe38f
commit 074a501a82
2 changed files with 8 additions and 9 deletions

View File

@ -7,13 +7,15 @@ stage=3
stop_stage=3
num_keywords=2
config=conf/tcn.yaml
norm_mean=true
norm_var=true
gpus="0,1"
dir=exp/ds_tcn
num_average=30
checkpoint=$dir/avg_${num_average}.pt
checkpoint=
score_checkpoint=$dir/avg_${num_average}.pt
@ -94,8 +96,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
mkdir -p $result_dir
python kws/bin/score_longwav.py \
--config $dir/config.yaml \
--test_data data/test/test_data.list \
--batch_size 5 \
--test_data data/test/data.list \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file_dir $result_dir \
--num_keywords $num_keywords \
@ -104,7 +106,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
for keyword in 0 1; do
python kws/bin/compute_det_longwav.py \
--keyword $keyword \
--test_data data/test/test_data.list \
--test_data data/test/data.list \
--score_file $result_dir/score_longwav.${keyword}.txt \
--stats_file $result_dir/stats_longwav.${keyword}.txt
done

View File

@ -68,9 +68,6 @@ if __name__ == '__main__':
keyword_table, filler_table, filler_duration = load_label_and_score(
args.keyword, args.test_data, args.score_file)
print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
# print('keyword_table.size = ', len(keyword_table))
# print('filler_table.size = ', len(filler_table))
# print('filler_duration = ', filler_duration)
with open(args.stats_file, 'w', encoding='utf8') as fout:
threshold = 0.0
while threshold <= 1.0:
@ -88,9 +85,9 @@ if __name__ == '__main__':
while i < len(score_list):
if score_list[i] >= threshold:
num_false_alarm += 1
i += 1
else:
i += window_shift
else:
i += 1
if len(keyword_table) != 0 :
false_reject_rate = num_false_reject / len(keyword_table)
num_false_alarm = max(num_false_alarm, 1e-6)