fix some bugs
This commit is contained in:
parent
c270cbe38f
commit
074a501a82
@ -7,13 +7,15 @@ stage=3
|
|||||||
stop_stage=3
|
stop_stage=3
|
||||||
num_keywords=2
|
num_keywords=2
|
||||||
|
|
||||||
|
config=conf/tcn.yaml
|
||||||
norm_mean=true
|
norm_mean=true
|
||||||
norm_var=true
|
norm_var=true
|
||||||
gpus="0,1"
|
gpus="0,1"
|
||||||
|
|
||||||
|
dir=exp/ds_tcn
|
||||||
|
|
||||||
num_average=30
|
num_average=30
|
||||||
checkpoint=$dir/avg_${num_average}.pt
|
checkpoint=
|
||||||
score_checkpoint=$dir/avg_${num_average}.pt
|
score_checkpoint=$dir/avg_${num_average}.pt
|
||||||
|
|
||||||
|
|
||||||
@ -94,8 +96,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
mkdir -p $result_dir
|
mkdir -p $result_dir
|
||||||
python kws/bin/score_longwav.py \
|
python kws/bin/score_longwav.py \
|
||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--test_data data/test/test_data.list \
|
--test_data data/test/data.list \
|
||||||
--batch_size 5 \
|
--batch_size 256 \
|
||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
--score_file_dir $result_dir \
|
--score_file_dir $result_dir \
|
||||||
--num_keywords $num_keywords \
|
--num_keywords $num_keywords \
|
||||||
@ -104,7 +106,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
for keyword in 0 1; do
|
for keyword in 0 1; do
|
||||||
python kws/bin/compute_det_longwav.py \
|
python kws/bin/compute_det_longwav.py \
|
||||||
--keyword $keyword \
|
--keyword $keyword \
|
||||||
--test_data data/test/test_data.list \
|
--test_data data/test/data.list \
|
||||||
--score_file $result_dir/score_longwav.${keyword}.txt \
|
--score_file $result_dir/score_longwav.${keyword}.txt \
|
||||||
--stats_file $result_dir/stats_longwav.${keyword}.txt
|
--stats_file $result_dir/stats_longwav.${keyword}.txt
|
||||||
done
|
done
|
||||||
|
|||||||
@ -68,9 +68,6 @@ if __name__ == '__main__':
|
|||||||
keyword_table, filler_table, filler_duration = load_label_and_score(
|
keyword_table, filler_table, filler_duration = load_label_and_score(
|
||||||
args.keyword, args.test_data, args.score_file)
|
args.keyword, args.test_data, args.score_file)
|
||||||
print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
|
print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
|
||||||
# print('keyword_table.size = ', len(keyword_table))
|
|
||||||
# print('filler_table.size = ', len(filler_table))
|
|
||||||
# print('filler_duration = ', filler_duration)
|
|
||||||
with open(args.stats_file, 'w', encoding='utf8') as fout:
|
with open(args.stats_file, 'w', encoding='utf8') as fout:
|
||||||
threshold = 0.0
|
threshold = 0.0
|
||||||
while threshold <= 1.0:
|
while threshold <= 1.0:
|
||||||
@ -88,9 +85,9 @@ if __name__ == '__main__':
|
|||||||
while i < len(score_list):
|
while i < len(score_list):
|
||||||
if score_list[i] >= threshold:
|
if score_list[i] >= threshold:
|
||||||
num_false_alarm += 1
|
num_false_alarm += 1
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
i += window_shift
|
i += window_shift
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
if len(keyword_table) != 0 :
|
if len(keyword_table) != 0 :
|
||||||
false_reject_rate = num_false_reject / len(keyword_table)
|
false_reject_rate = num_false_reject / len(keyword_table)
|
||||||
num_false_alarm = max(num_false_alarm, 1e-6)
|
num_false_alarm = max(num_false_alarm, 1e-6)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user