updated lint error

This commit is contained in:
blessyyyu 2022-03-23 14:04:00 +08:00
parent 074a501a82
commit 17a67fe579
4 changed files with 11 additions and 10 deletions

1
.gitattributes vendored
View File

@ -1 +0,0 @@
./examples/hi_xiaowen/s0/run.sh filter=gitignore

View File

@ -3,21 +3,22 @@
. ./path.sh . ./path.sh
stage=3 stage=0
stop_stage=3 stop_stage=4
num_keywords=2 num_keywords=2
config=conf/tcn.yaml config=conf/ds_tcn.yaml
norm_mean=true norm_mean=true
norm_var=true norm_var=true
gpus="0,1" gpus="0,1"
checkpoint=
dir=exp/ds_tcn dir=exp/ds_tcn
num_average=30 num_average=30
checkpoint=
score_checkpoint=$dir/avg_${num_average}.pt score_checkpoint=$dir/avg_${num_average}.pt
download_dir=./mobvoi_data
. tools/parse_options.sh || exit 1; . tools/parse_options.sh || exit 1;

View File

@ -1,4 +1,5 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) # Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(yu954793264@163.com)
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -24,8 +25,8 @@ def load_label_and_score(keyword, label_file, score_file):
arr = line.strip().split() arr = line.strip().split()
# key = utt_id # key = utt_id
key = arr[0] key = arr[0]
# scores is a list # scores is a list
str_list = arr[1: ] str_list = arr[1:]
scores = list(map(float, str_list)) scores = list(map(float, str_list))
score_table[key] = scores score_table[key] = scores
keyword_table = {} keyword_table = {}
@ -43,7 +44,7 @@ def load_label_and_score(keyword, label_file, score_file):
index = obj['txt'] index = obj['txt']
duration = obj['duration'] duration = obj['duration']
assert key in score_table assert key in score_table
# txt == keyword , correct # txt == keyword , correct
if index == keyword: if index == keyword:
keyword_table[key] = score_table[key] keyword_table[key] = score_table[key]
else: else:

View File

@ -1,4 +1,5 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) # Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(yu954793264@163.com)
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -27,7 +28,6 @@ from torch.utils.data import DataLoader
from kws.dataset.dataset import Dataset from kws.dataset.dataset import Dataset
from kws.model.kws_model import init_model from kws.model.kws_model import init_model
from kws.utils.checkpoint import load_checkpoint from kws.utils.checkpoint import load_checkpoint
from kws.utils.mask import padding_mask
def get_args(): def get_args():
@ -60,7 +60,7 @@ def get_args():
help='output score file') help='output score file')
parser.add_argument('--num_keywords', parser.add_argument('--num_keywords',
required=True, required=True,
help='the number of keywords') help='the number of keywords')
parser.add_argument('--jit_model', parser.add_argument('--jit_model',
action='store_true', action='store_true',
default=False, default=False,