updated lint error
This commit is contained in:
parent
074a501a82
commit
17a67fe579
1
.gitattributes
vendored
1
.gitattributes
vendored
@ -1 +0,0 @@
|
|||||||
./examples/hi_xiaowen/s0/run.sh filter=gitignore
|
|
||||||
@ -3,21 +3,22 @@
|
|||||||
|
|
||||||
. ./path.sh
|
. ./path.sh
|
||||||
|
|
||||||
stage=3
|
stage=0
|
||||||
stop_stage=3
|
stop_stage=4
|
||||||
num_keywords=2
|
num_keywords=2
|
||||||
|
|
||||||
config=conf/tcn.yaml
|
config=conf/ds_tcn.yaml
|
||||||
norm_mean=true
|
norm_mean=true
|
||||||
norm_var=true
|
norm_var=true
|
||||||
gpus="0,1"
|
gpus="0,1"
|
||||||
|
|
||||||
|
checkpoint=
|
||||||
dir=exp/ds_tcn
|
dir=exp/ds_tcn
|
||||||
|
|
||||||
num_average=30
|
num_average=30
|
||||||
checkpoint=
|
|
||||||
score_checkpoint=$dir/avg_${num_average}.pt
|
score_checkpoint=$dir/avg_${num_average}.pt
|
||||||
|
|
||||||
|
download_dir=./mobvoi_data
|
||||||
|
|
||||||
. tools/parse_options.sh || exit 1;
|
. tools/parse_options.sh || exit 1;
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||||
|
# 2022 Shaoqing Yu(yu954793264@163.com)
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -24,8 +25,8 @@ def load_label_and_score(keyword, label_file, score_file):
|
|||||||
arr = line.strip().split()
|
arr = line.strip().split()
|
||||||
# key = utt_id
|
# key = utt_id
|
||||||
key = arr[0]
|
key = arr[0]
|
||||||
# scores is a list
|
# scores is a list
|
||||||
str_list = arr[1: ]
|
str_list = arr[1:]
|
||||||
scores = list(map(float, str_list))
|
scores = list(map(float, str_list))
|
||||||
score_table[key] = scores
|
score_table[key] = scores
|
||||||
keyword_table = {}
|
keyword_table = {}
|
||||||
@ -43,7 +44,7 @@ def load_label_and_score(keyword, label_file, score_file):
|
|||||||
index = obj['txt']
|
index = obj['txt']
|
||||||
duration = obj['duration']
|
duration = obj['duration']
|
||||||
assert key in score_table
|
assert key in score_table
|
||||||
# txt == keyword , correct
|
# txt == keyword , correct
|
||||||
if index == keyword:
|
if index == keyword:
|
||||||
keyword_table[key] = score_table[key]
|
keyword_table[key] = score_table[key]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||||
|
# 2022 Shaoqing Yu(yu954793264@163.com)
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -27,7 +28,6 @@ from torch.utils.data import DataLoader
|
|||||||
from kws.dataset.dataset import Dataset
|
from kws.dataset.dataset import Dataset
|
||||||
from kws.model.kws_model import init_model
|
from kws.model.kws_model import init_model
|
||||||
from kws.utils.checkpoint import load_checkpoint
|
from kws.utils.checkpoint import load_checkpoint
|
||||||
from kws.utils.mask import padding_mask
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@ -60,7 +60,7 @@ def get_args():
|
|||||||
help='output score file')
|
help='output score file')
|
||||||
parser.add_argument('--num_keywords',
|
parser.add_argument('--num_keywords',
|
||||||
required=True,
|
required=True,
|
||||||
help='the number of keywords')
|
help='the number of keywords')
|
||||||
parser.add_argument('--jit_model',
|
parser.add_argument('--jit_model',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=False,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user