diff --git a/kws/bin/compute_det_longwav.py b/kws/bin/compute_det_longwav.py index 3710cd2..9728cf5 100644 --- a/kws/bin/compute_det_longwav.py +++ b/kws/bin/compute_det_longwav.py @@ -23,7 +23,7 @@ def load_label_and_score(keyword, label_file, score_file): for line in fin: arr = line.strip().split() key = arr[0] - str_list = arr[1: ] + str_list = arr[1:] scores = list(map(float, str_list)) score_table[key].append(scores) keyword_table = {} @@ -93,4 +93,3 @@ if __name__ == '__main__': false_alarm_per_hour, false_reject_rate)) threshold += args.step - \ No newline at end of file diff --git a/kws/bin/score_longwav.py b/kws/bin/score_longwav.py index dfed853..1d86289 100644 --- a/kws/bin/score_longwav.py +++ b/kws/bin/score_longwav.py @@ -57,7 +57,7 @@ def get_args(): help='prefetch number') parser.add_argument('--score_file', required=True, - help='output score file') + help='output score file') parser.add_argument('--jit_model', action='store_true', default=False, @@ -117,7 +117,7 @@ def main(): score = logits[i][:lengths[i]] for keyword_i in range(num_keywords): keyword_scores = score[:, keyword_i] - score_frames = ' '.join(['{:.6f}'.format(x) + score_frames = ' '.join(['{:.6f}'.format(x) for x in keyword_scores.tolist()]) fout.write('{} {}\n'.format(key, score_frames)) if batch_idx % 10 == 0: