From c270cbe38f75a358556b4250d64b974f61a639e8 Mon Sep 17 00:00:00 2001 From: blessyyyu <954793264@qq.com> Date: Tue, 22 Mar 2022 17:14:48 +0800 Subject: [PATCH] add long wav --- examples/hi_xiaowen/s0/run.sh | 27 +++--- kws/bin/compute_det_longwav.py | 102 +++++++++++++++++++++++ kws/bin/score_longwav.py | 148 +++++++++++++++++++++++++++++++++ 3 files changed, 263 insertions(+), 14 deletions(-) create mode 100644 kws/bin/compute_det_longwav.py create mode 100644 kws/bin/score_longwav.py diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index a735251..a26d16e 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -3,22 +3,19 @@ . ./path.sh -stage=0 -stop_stage=4 +stage=3 +stop_stage=3 num_keywords=2 -config=conf/ds_tcn.yaml norm_mean=true norm_var=true gpus="0,1" -checkpoint= -dir=exp/ds_tcn num_average=30 +checkpoint=$dir/avg_${num_average}.pt score_checkpoint=$dir/avg_${num_average}.pt -download_dir=./data/local # your data dir . tools/parse_options.sh || exit 1; @@ -95,19 +92,21 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --val_best result_dir=$dir/test_$(basename $score_checkpoint) mkdir -p $result_dir - python kws/bin/score.py \ + python kws/bin/score_longwav.py \ --config $dir/config.yaml \ - --test_data data/test/data.list \ - --batch_size 256 \ + --test_data data/test/test_data.list \ + --batch_size 5 \ --checkpoint $score_checkpoint \ - --score_file $result_dir/score.txt \ + --score_file_dir $result_dir \ + --num_keywords $num_keywords \ --num_workers 8 + for keyword in 0 1; do - python kws/bin/compute_det.py \ + python kws/bin/compute_det_longwav.py \ --keyword $keyword \ - --test_data data/test/data.list \ - --score_file $result_dir/score.txt \ - --stats_file $result_dir/stats.${keyword}.txt + --test_data data/test/test_data.list \ + --score_file $result_dir/score_longwav.${keyword}.txt \ + --stats_file $result_dir/stats_longwav.${keyword}.txt done fi diff --git 
# a/kws/bin/compute_det_longwav.py b/kws/bin/compute_det_longwav.py
# (new file in this patch: mode 100644, index 0000000..5d26afb)
#
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compute false-alarm / false-reject statistics from per-frame scores."""

import argparse
import json


def load_label_and_score(keyword, label_file, score_file):
    """Split per-utterance frame scores into keyword and filler tables.

    Args:
        keyword: label value of the keyword under evaluation; it is compared
            for equality with the ``txt`` field of every utterance.
        label_file: path to a JSON-lines file. Every object must provide the
            fields ``key``, ``txt`` and ``duration``.
        score_file: path to a text file with one utterance per line: the
            utterance id followed by whitespace-separated float scores.

    Returns:
        A tuple ``(keyword_table, filler_table, filler_duration)``. Both
        tables map an utterance id to its list of scores: utterances whose
        ``txt`` equals ``keyword`` go to ``keyword_table``, all others to
        ``filler_table``. ``filler_duration`` is the sum of ``duration`` over
        the filler utterances.

    Raises:
        AssertionError: if a label entry lacks a required field or has no
            matching line in ``score_file``.
    """
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.split()
            if not arr:
                # Tolerate blank lines instead of failing on arr[0].
                continue
            score_table[arr[0]] = [float(x) for x in arr[1:]]

    keyword_table = {}
    filler_table = {}
    filler_duration = 0.0
    with open(label_file, 'r', encoding='utf8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            for field in ('key', 'txt', 'duration'):
                assert field in obj, 'missing field {!r}'.format(field)
            key = obj['key']
            assert key in score_table, 'no scores for {!r}'.format(key)
            if obj['txt'] == keyword:
                keyword_table[key] = score_table[key]
            else:
                filler_table[key] = score_table[key]
                filler_duration += obj['duration']
    return keyword_table, filler_table, filler_duration


def _count_triggers(score_list, threshold, window_shift):
    """Count detections in one utterance, skipping frames after each one."""
    count = 0
    i = 0
    num_frames = len(score_list)
    while i < num_frames:
        if score_list[i] >= threshold:
            count += 1
            # A trigger suppresses the following frames so that one event
            # is not counted once per frame.
            i += window_shift
        else:
            i += 1
    return count


def compute_det_stats(keyword_table, filler_table, filler_duration,
                      step=0.01, window_shift=50):
    """Sweep the decision threshold over [0, 1] and collect DET statistics.

    Args:
        keyword_table: utterance id -> score list for keyword utterances.
        filler_table: utterance id -> score list for filler utterances.
        filler_duration: total filler duration; it is divided by 3600 to get
            hours, so it is expected to be in seconds.
        step: threshold increment, must be positive.
        window_shift: number of frames to advance after a trigger, must be
            at least 1.

    Returns:
        A list of ``(threshold, false_alarm_per_hour, false_reject_rate)``
        tuples, one per threshold. A keyword utterance is rejected when its
        maximum score is below the threshold. The false-alarm count is
        floored at 1e-6 before normalisation, as in the original script.
        Rates whose denominator is zero (no keyword utterances, or zero
        filler duration) are reported as 0.0.

    Raises:
        ValueError: if ``step`` is not positive or ``window_shift`` < 1.
    """
    if step <= 0:
        raise ValueError('step must be positive, got {}'.format(step))
    if window_shift < 1:
        raise ValueError(
            'window_shift must be >= 1, got {}'.format(window_shift))

    # The per-utterance maximum does not depend on the threshold, so compute
    # it once. An utterance without frames can never trigger.
    peak_scores = [
        max(score_list, default=float('-inf'))
        for score_list in keyword_table.values()
    ]
    filler_hours = filler_duration / 3600.0

    # Derive each threshold from an integer index: repeatedly adding `step`
    # accumulates rounding error and can drop the final threshold.
    num_steps = int(1.0 / step + 1e-9)
    stats = []
    for n in range(num_steps + 1):
        threshold = min(n * step, 1.0)
        num_false_reject = sum(1 for peak in peak_scores if peak < threshold)
        if peak_scores:
            false_reject_rate = num_false_reject / len(peak_scores)
        else:
            false_reject_rate = 0.0

        num_false_alarm = sum(
            _count_triggers(score_list, threshold, window_shift)
            for score_list in filler_table.values())
        num_false_alarm = max(num_false_alarm, 1e-6)
        if filler_duration != 0:
            false_alarm_per_hour = num_false_alarm / filler_hours
        else:
            false_alarm_per_hour = 0.0
        stats.append((threshold, false_alarm_per_hour, false_reject_rate))
    return stats


def main():
    """Parse the command line and write the DET statistics file."""
    parser = argparse.ArgumentParser(description='compute det curve')
    parser.add_argument('--test_data', required=True, help='label file')
    parser.add_argument('--keyword', type=int, default=0,
                        help='label of the keyword to evaluate')
    parser.add_argument('--score_file', required=True, help='score file')
    parser.add_argument('--step', type=float, default=0.01,
                        help='threshold step')
    parser.add_argument('--window_shift', type=int, default=50,
                        help='frames to skip after a trigger')
    parser.add_argument('--stats_file',
                        required=True,
                        help='false reject/alarm stats file')
    args = parser.parse_args()

    keyword_table, filler_table, filler_duration = load_label_and_score(
        args.keyword, args.test_data, args.score_file)
    print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
    stats = compute_det_stats(keyword_table, filler_table, filler_duration,
                              step=args.step,
                              window_shift=args.window_shift)
    with open(args.stats_file, 'w', encoding='utf8') as fout:
        fout.writelines('{:.6f} {:.6f} {:.6f}\n'.format(*row)
                        for row in stats)


if __name__ == '__main__':
    main()

# diff --git a/kws/bin/score_longwav.py b/kws/bin/score_longwav.py
# (new file in this patch: mode 100644, index 0000000..cfb6ae9)
#
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader + +from kws.dataset.dataset import Dataset +from kws.model.kws_model import init_model +from kws.utils.checkpoint import load_checkpoint +from kws.utils.mask import padding_mask + + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--batch_size', + default=16, + type=int, + help='batch size for inference') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--score_file_dir', + required=True, + help='output score file') + parser.add_argument('--num_keywords', + required=True, + help='the number of keywords') + parser.add_argument('--jit_model', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + test_conf = copy.deepcopy(configs['dataset_conf']) + test_conf['filter_conf']['max_length'] = 102400 + test_conf['filter_conf']['min_length'] = 0 + 
test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['shuffle'] = False + test_conf['feature_extraction_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_size'] = args.batch_size + + test_dataset = Dataset(args.test_data, test_conf) + test_data_loader = DataLoader(test_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + + if args.jit_model: + model = torch.jit.load(args.checkpoint) + # For script model, only cpu is supported. + device = torch.device('cpu') + else: + # Init asr model from configs + model = init_model(configs['model']) + load_checkpoint(model, args.checkpoint) + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + model.eval() + # add to write different keyword score file + num_keywords = int(args.num_keywords) + score_file_list = [] + dir_abs_path = os.path.abspath(args.score_file_dir) + for i in range(num_keywords): + temp_list = ['score_longwav', 'txt'] + temp_list.insert(1, str(i)) + suffix = '.'.join(temp_list) + # print('suffix = ', suffix) + score_abs_path = os.path.join(dir_abs_path, suffix) + score_file_list.append(score_abs_path) + + for abs_path in score_file_list: + with torch.no_grad(), open(abs_path, 'w', encoding='utf8') as fout: + keyword_label = abs_path.split('/')[-1].split('.')[1] + # print('keyword_label = ', keyword_label) + for batch_idx, batch in enumerate(test_data_loader): + keys, feats, target, lengths = batch + feats = feats.to(device) + lengths = lengths.to(device) + # mask = padding_mask(lengths).unsqueeze(2) + logits = model(feats) + # mask对应的true的部分用0填充 + # Getting every frames desn't need to mask + # logits = logits.masked_fill(mask, 0.0) + logits = logits.cpu() + for i in range(len(keys)): + key = keys[i] + score = logits[i][:lengths[i]] + score = score[:, int(keyword_label)] + # keep 2 significant digits + score = ' 
'.join([str("%.2g" % x) for x in score.tolist()]) + fout.write('{} {}\n'.format(key, score)) + if batch_idx % 10 == 0: + print('Progress batch {}'.format(batch_idx)) + sys.stdout.flush() + + +if __name__ == '__main__': + main()