From 00c5acf31a67f861b6a847f3da48f200bd6ee35b Mon Sep 17 00:00:00 2001 From: blessyyyu <954793264@qq.com> Date: Thu, 24 Mar 2022 10:23:50 +0800 Subject: [PATCH] override the score and compute_det file --- examples/hi_xiaowen/s0/run.sh | 10 +-- kws/bin/compute_det.py | 43 +++++++---- kws/bin/compute_det_longwav.py | 95 ------------------------ kws/bin/score.py | 21 +++--- kws/bin/score_longwav.py | 129 --------------------------------- 5 files changed, 45 insertions(+), 253 deletions(-) delete mode 100644 kws/bin/compute_det_longwav.py delete mode 100644 kws/bin/score_longwav.py diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 0848225..3c964c5 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -95,21 +95,21 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --val_best result_dir=$dir/test_$(basename $score_checkpoint) mkdir -p $result_dir - python kws/bin/score_longwav.py \ + python kws/bin/score.py \ --config $dir/config.yaml \ --test_data data/test/data.list \ --batch_size 256 \ --checkpoint $score_checkpoint \ - --score_file $result_dir/score_longwav.txt \ + --score_file $result_dir/score.txt \ --num_workers 8 for keyword in 0 1; do - python kws/bin/compute_det_longwav.py \ + python kws/bin/compute_det.py \ --keyword $keyword \ --test_data data/test/data.list \ --window_shift $window_shift \ - --score_file $result_dir/score_longwav.txt \ - --stats_file $result_dir/stats_longwav.${keyword}.txt + --score_file $result_dir/score.txt \ + --stats_file $result_dir/stats.${keyword}.txt done fi diff --git a/kws/bin/compute_det.py b/kws/bin/compute_det.py index 53e3079..4c632f5 100644 --- a/kws/bin/compute_det.py +++ b/kws/bin/compute_det.py @@ -1,4 +1,5 @@ # Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# 2022 Shaoqing Yu(954793264@qq.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,16 +15,17 @@ import argparse import json - +from collections import defaultdict def load_label_and_score(keyword, label_file, score_file): - score_table = {} + score_table = defaultdict(list) with open(score_file, 'r', encoding='utf8') as fin: for line in fin: arr = line.strip().split() key = arr[0] - score = float(arr[keyword + 1]) - score_table[key] = score + str_list = arr[1:] + scores = list(map(float, str_list)) + score_table[key].append(scores) keyword_table = {} filler_table = {} filler_duration = 0.0 @@ -44,36 +46,49 @@ def load_label_and_score(keyword, label_file, score_file): filler_duration += duration return keyword_table, filler_table, filler_duration - if __name__ == '__main__': parser = argparse.ArgumentParser(description='compute det curve') parser.add_argument('--test_data', required=True, help='label file') parser.add_argument('--keyword', type=int, default=0, help='score file') parser.add_argument('--score_file', required=True, help='score file') parser.add_argument('--step', type=float, default=0.01, help='score file') + parser.add_argument('--window_shift', type=int, default=50, + help='window_shift is used to skip the frames after triggered') parser.add_argument('--stats_file', required=True, help='false reject/alarm stats file') args = parser.parse_args() - + window_shift = args.window_shift keyword_table, filler_table, filler_duration = load_label_and_score( args.keyword, args.test_data, args.score_file) print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) - with open(args.stats_file, 'w', encoding='utf8') as fout: + keyword_index = int(args.stats_file.split('/')[-1].split('.')[1]) threshold = 0.0 while threshold <= 1.0: num_false_reject = 0 - for key, score in keyword_table.items(): - if score < threshold: + # transverse the all keyword_table + for key, scores_list in keyword_table.items(): + # computer positive test sample, use the max score of list. + score = max(scores_list[keyword_index]) + if float(score) < threshold: num_false_reject += 1 num_false_alarm = 0 - for key, score in filler_table.items(): - if score >= threshold: - num_false_alarm += 1 - false_reject_rate = num_false_reject / len(keyword_table) + # transverse the all filler_table + for key, scores_list in filler_table.items(): + i = 0 + score_list = scores_list[keyword_index] + while i < len(score_list): + if score_list[i] >= threshold: + num_false_alarm += 1 + i += window_shift + else: + i += 1 + if len(keyword_table) != 0 : + false_reject_rate = num_false_reject / len(keyword_table) num_false_alarm = max(num_false_alarm, 1e-6) - false_alarm_per_hour = num_false_alarm / (filler_duration / 3600.0) + if filler_duration != 0: + false_alarm_per_hour = num_false_alarm / (filler_duration / 3600.0) fout.write('{:.6f} {:.6f} {:.6f}\n'.format(threshold, false_alarm_per_hour, false_reject_rate)) diff --git a/kws/bin/compute_det_longwav.py b/kws/bin/compute_det_longwav.py deleted file mode 100644 index 9728cf5..0000000 --- a/kws/bin/compute_det_longwav.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) -# 2022 Shaoqing Yu(yu954793264@163.com) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import json -from collections import defaultdict - -def load_label_and_score(keyword, label_file, score_file): - score_table = defaultdict(list) - with open(score_file, 'r', encoding='utf8') as fin: - for line in fin: - arr = line.strip().split() - key = arr[0] - str_list = arr[1:] - scores = list(map(float, str_list)) - score_table[key].append(scores) - keyword_table = {} - filler_table = {} - filler_duration = 0.0 - with open(label_file, 'r', encoding='utf8') as fin: - for line in fin: - obj = json.loads(line.strip()) - assert 'key' in obj - assert 'txt' in obj - assert 'duration' in obj - key = obj['key'] - index = obj['txt'] - duration = obj['duration'] - assert key in score_table - if index == keyword: - keyword_table[key] = score_table[key] - else: - filler_table[key] = score_table[key] - filler_duration += duration - return keyword_table, filler_table, filler_duration - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='compute det curve') - parser.add_argument('--test_data', required=True, help='label file') - parser.add_argument('--keyword', type=int, default=0, help='score file') - parser.add_argument('--score_file', required=True, help='score file') - parser.add_argument('--step', type=float, default=0.01, help='score file') - parser.add_argument('--window_shift', type=int, default=50, - help='window_shift is used to skip the frames after triggered') - parser.add_argument('--stats_file', - required=True, - help='false reject/alarm stats file') - args = parser.parse_args() - window_shift = args.window_shift - keyword_table, filler_table, filler_duration = load_label_and_score( - args.keyword, args.test_data, args.score_file) - print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) - with open(args.stats_file, 'w', encoding='utf8') as fout: - keyword_index = int(args.stats_file.split('/')[-1].split('.')[1]) - threshold = 0.0 - while threshold <= 1.0: - num_false_reject = 0 - # transverse the all keyword_table - for key, scores_list in keyword_table.items(): - # computer positive test sample, use the max score of list. - score = max(scores_list[keyword_index]) - if float(score) < threshold: - num_false_reject += 1 - num_false_alarm = 0 - # transverse the all filler_table - for key, scores_list in filler_table.items(): - i = 0 - score_list = scores_list[keyword_index] - while i < len(score_list): - if score_list[i] >= threshold: - num_false_alarm += 1 - i += window_shift - else: - i += 1 - if len(keyword_table) != 0 : - false_reject_rate = num_false_reject / len(keyword_table) - num_false_alarm = max(num_false_alarm, 1e-6) - if filler_duration != 0: - false_alarm_per_hour = num_false_alarm / (filler_duration / 3600.0) - fout.write('{:.6f} {:.6f} {:.6f}\n'.format(threshold, - false_alarm_per_hour, - false_reject_rate)) - threshold += args.step diff --git a/kws/bin/score.py b/kws/bin/score.py index b8e7c5c..19ccca8 100644 --- a/kws/bin/score.py +++ b/kws/bin/score.py @@ -1,4 +1,5 @@ # Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# 2022 Shaoqing Yu(954793264@qq.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +28,6 @@ from torch.utils.data import DataLoader from kws.dataset.dataset import Dataset from kws.model.kws_model import init_model from kws.utils.checkpoint import load_checkpoint -from kws.utils.mask import padding_mask def get_args(): @@ -102,23 +102,24 @@ def main(): use_cuda = args.gpu >= 0 and torch.cuda.is_available() device = torch.device('cuda' if use_cuda else 'cpu') model = model.to(device) - model.eval() - with torch.no_grad(), open(args.score_file, 'w', encoding='utf8') as fout: + score_abs_path = os.path.abspath(args.score_file) + with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout: for batch_idx, batch in enumerate(test_data_loader): keys, feats, target, lengths = batch feats = feats.to(device) lengths = lengths.to(device) - mask = padding_mask(lengths).unsqueeze(2) logits = model(feats) - logits = logits.masked_fill(mask, 0.0) - max_logits, _ = logits.max(dim=1) - max_logits = max_logits.cpu() + num_keywords = logits.shape[2] + logits = logits.cpu() for i in range(len(keys)): key = keys[i] - score = max_logits[i] - score = ' '.join([str(x) for x in score.tolist()]) - fout.write('{} {}\n'.format(key, score)) + score = logits[i][:lengths[i]] + for keyword_i in range(num_keywords): + keyword_scores = score[:, keyword_i] + score_frames = ' '.join(['{:.6f}'.format(x) + for x in keyword_scores.tolist()]) + fout.write('{} {}\n'.format(key, score_frames)) if batch_idx % 10 == 0: print('Progress batch {}'.format(batch_idx)) sys.stdout.flush() diff --git a/kws/bin/score_longwav.py b/kws/bin/score_longwav.py deleted file mode 100644 index 1d86289..0000000 --- a/kws/bin/score_longwav.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) -# 2022 Shaoqing Yu(yu954793264@163.com) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import argparse -import copy -import logging -import os -import sys - -import torch -import yaml -from torch.utils.data import DataLoader - -from kws.dataset.dataset import Dataset -from kws.model.kws_model import init_model -from kws.utils.checkpoint import load_checkpoint - - -def get_args(): - parser = argparse.ArgumentParser(description='recognize with your model') - parser.add_argument('--config', required=True, help='config file') - parser.add_argument('--test_data', required=True, help='test data file') - parser.add_argument('--gpu', - type=int, - default=-1, - help='gpu id for this rank, -1 for cpu') - parser.add_argument('--checkpoint', required=True, help='checkpoint model') - parser.add_argument('--batch_size', - default=16, - type=int, - help='batch size for inference') - parser.add_argument('--num_workers', - default=0, - type=int, - help='num of subprocess workers for reading') - parser.add_argument('--pin_memory', - action='store_true', - default=False, - help='Use pinned memory buffers used for reading') - parser.add_argument('--prefetch', - default=100, - type=int, - help='prefetch number') - parser.add_argument('--score_file', - required=True, - help='output score file') - parser.add_argument('--jit_model', - action='store_true', - default=False, - help='Use pinned memory buffers used for reading') - args = parser.parse_args() - return args - - -def main(): - args = get_args() - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s %(levelname)s %(message)s') - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) - - with open(args.config, 'r') as fin: - configs = yaml.load(fin, Loader=yaml.FullLoader) - - test_conf = copy.deepcopy(configs['dataset_conf']) - test_conf['filter_conf']['max_length'] = 102400 - test_conf['filter_conf']['min_length'] = 0 - test_conf['speed_perturb'] = False - test_conf['spec_aug'] = False - test_conf['shuffle'] = False - test_conf['feature_extraction_conf']['dither'] = 0.0 - test_conf['batch_conf']['batch_size'] = args.batch_size - - test_dataset = Dataset(args.test_data, test_conf) - test_data_loader = DataLoader(test_dataset, - batch_size=None, - pin_memory=args.pin_memory, - num_workers=args.num_workers, - prefetch_factor=args.prefetch) - - if args.jit_model: - model = torch.jit.load(args.checkpoint) - # For script model, only cpu is supported. - device = torch.device('cpu') - else: - # Init asr model from configs - model = init_model(configs['model']) - load_checkpoint(model, args.checkpoint) - use_cuda = args.gpu >= 0 and torch.cuda.is_available() - device = torch.device('cuda' if use_cuda else 'cpu') - model = model.to(device) - model.eval() - score_abs_path = os.path.abspath(args.score_file) - with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout: - for batch_idx, batch in enumerate(test_data_loader): - keys, feats, target, lengths = batch - feats = feats.to(device) - lengths = lengths.to(device) - logits = model(feats) - num_keywords = logits.shape[2] - logits = logits.cpu() - for i in range(len(keys)): - key = keys[i] - score = logits[i][:lengths[i]] - for keyword_i in range(num_keywords): - keyword_scores = score[:, keyword_i] - score_frames = ' '.join(['{:.6f}'.format(x) - for x in keyword_scores.tolist()]) - fout.write('{} {}\n'.format(key, score_frames)) - if batch_idx % 10 == 0: - print('Progress batch {}'.format(batch_idx)) - sys.stdout.flush() - - -if __name__ == '__main__': - main()