From db9fc7a7384405d47f24f0d1ea160b32396881e8 Mon Sep 17 00:00:00 2001 From: blessyyyu <954793264@qq.com> Date: Wed, 23 Mar 2022 18:28:06 +0800 Subject: [PATCH] remove 'num_keyword' parameter --- examples/hi_xiaowen/s0/run.sh | 1 - kws/bin/compute_det_longwav.py | 11 ++++++----- kws/bin/score_longwav.py | 11 ++++------- learnkws/learn_mask.py | 30 ++++++++++++++++++++++++++++++ learnkws/test_learn.py | 13 +++++++++++++ 5 files changed, 53 insertions(+), 13 deletions(-) create mode 100644 learnkws/learn_mask.py create mode 100644 learnkws/test_learn.py diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 1f00ff0..0848225 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -101,7 +101,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --batch_size 256 \ --checkpoint $score_checkpoint \ --score_file $result_dir/score_longwav.txt \ - --num_keywords $num_keywords \ --num_workers 8 for keyword in 0 1; do diff --git a/kws/bin/compute_det_longwav.py b/kws/bin/compute_det_longwav.py index ec20222..3710cd2 100644 --- a/kws/bin/compute_det_longwav.py +++ b/kws/bin/compute_det_longwav.py @@ -23,7 +23,7 @@ def load_label_and_score(keyword, label_file, score_file): for line in fin: arr = line.strip().split() key = arr[0] - str_list = arr[1:] + str_list = arr[1: ] scores = list(map(float, str_list)) score_table[key].append(scores) keyword_table = {} @@ -52,8 +52,8 @@ if __name__ == '__main__': parser.add_argument('--keyword', type=int, default=0, help='score file') parser.add_argument('--score_file', required=True, help='score file') parser.add_argument('--step', type=float, default=0.01, help='score file') - parser.add_argument('--window_shift', type=int, default=50, - help='window_shift is used to skip the frames after triggered') + parser.add_argument('--window_shift', type=int, default=50, + help='window_shift is used to skip the frames after triggered') parser.add_argument('--stats_file', required=True, help='false reject/alarm stats file') @@ -61,7 +61,7 @@ if __name__ == '__main__': window_shift = args.window_shift keyword_table, filler_table, filler_duration = load_label_and_score( args.keyword, args.test_data, args.score_file) - print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) + print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) with open(args.stats_file, 'w', encoding='utf8') as fout: keyword_index = int(args.stats_file.split('/')[-1].split('.')[1]) threshold = 0.0 @@ -92,4 +92,5 @@ if __name__ == '__main__': fout.write('{:.6f} {:.6f} {:.6f}\n'.format(threshold, false_alarm_per_hour, false_reject_rate)) - threshold += args.step \ No newline at end of file + threshold += args.step + \ No newline at end of file diff --git a/kws/bin/score_longwav.py b/kws/bin/score_longwav.py index 782763a..dfed853 100644 --- a/kws/bin/score_longwav.py +++ b/kws/bin/score_longwav.py @@ -57,10 +57,7 @@ def get_args(): help='prefetch number') parser.add_argument('--score_file', required=True, - help='output score file') - parser.add_argument('--num_keywords', - required=True, - help='the number of keywords') + help='output score file') parser.add_argument('--jit_model', action='store_true', default=False, @@ -106,22 +103,22 @@ def main(): device = torch.device('cuda' if use_cuda else 'cpu') model = model.to(device) model.eval() - score_abs_path = os.path.abspath(args.score_file) - num_keywords = int(args.num_keywords) with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout: for batch_idx, batch in enumerate(test_data_loader): keys, feats, target, lengths = batch feats = feats.to(device) lengths = lengths.to(device) logits = model(feats) + num_keywords = logits.shape[2] logits = logits.cpu() for i in range(len(keys)): key = keys[i] score = logits[i][:lengths[i]] for keyword_i in range(num_keywords): keyword_scores = score[:, keyword_i] - score_frames = ' '.join(['{:.3g}'.format(x) for x in keyword_scores.tolist()]) + score_frames = ' '.join(['{:.6f}'.format(x) + for x in keyword_scores.tolist()]) fout.write('{} {}\n'.format(key, score_frames)) if batch_idx % 10 == 0: print('Progress batch {}'.format(batch_idx)) diff --git a/learnkws/learn_mask.py b/learnkws/learn_mask.py new file mode 100644 index 0000000..4627ebd --- /dev/null +++ b/learnkws/learn_mask.py @@ -0,0 +1,30 @@ +''' +Date: 2022-03-04 18:10:52 +LastEditors: Cyan +LastEditTime: 2022-03-07 10:21:34 +''' + +import torch + + +def padding_mask(lengths: torch.Tensor) -> torch.Tensor: + """ + Examples: + >>> lengths = torch.tensor([2, 2, 3], dtype=torch.int32) + >>> mask = padding_mask(lengths) + >>> print(mask) + tensor([[False, False, True], + [False, False, True], + [False, False, False]]) + """ + batch_size = lengths.size(0) + max_len = int(lengths.max().item()) + seq = torch.arange(max_len, dtype=torch.int64, device=lengths.device) + seq = seq.expand(batch_size, max_len) + return seq >= lengths.unsqueeze(1) + +if __name__ == '__main__': + lengths = torch.tensor([2, 2, 3], dtype=torch.int32) + print(lengths.numel()) + mask = padding_mask(lengths) + print(mask, mask.size()) diff --git a/learnkws/test_learn.py b/learnkws/test_learn.py new file mode 100644 index 0000000..df1f7c7 --- /dev/null +++ b/learnkws/test_learn.py @@ -0,0 +1,13 @@ +''' +Date: 2022-03-04 18:10:52 +LastEditors: Cyan +LastEditTime: 2022-03-07 10:21:34 +''' + +if __name__ == '__main__': + a = [1,2,3,4,5,6,7] + for i in range(len(a)): + print('i = ', i) + if a[i] >= 3: + i += 2 + # print('a[i] = ' , a[i])