This commit is contained in:
jingyong hou 2021-12-03 20:23:32 +08:00
parent 6f91335452
commit 8cdf83a692
11 changed files with 20 additions and 21 deletions

View File

@ -42,7 +42,7 @@ model:
optim: adam
optim_conf:
lr: 0.001
lr: 0.0002
weight_decay: 0.00005
training_config:

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python
import os
import sys
import argparse
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
@ -9,14 +8,11 @@ CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description=
'prepare kaldi format file for google speech command dataset ')
description='prepare kaldi format file for google speech command')
parser.add_argument(
'--wav_list',
required=True,
help=
'wave list is a file containts full path of a wav file in google speech command dataset'
)
help='full path of a wav file in google speech command dataset')
parser.add_argument('--data_dir',
required=True,
help='folder to write kaldi format files')

View File

@ -1,4 +1,5 @@
"""Splits the google speech commands into train, validation and test set """
#!/usr/bin/env python
'''Splits the google speech commands into train, validation and test set'''
import os
import shutil

View File

@ -6,7 +6,7 @@
export CUDA_VISIBLE_DEVICES="0"
stage=3
stage=2
stop_stage=4
num_keywords=11
@ -21,7 +21,8 @@ dir=exp/mdtc_debug
num_average=1
score_checkpoint=$dir/avg_${num_average}.pt
download_dir=/mnt/mnt-data-3/jingyong.hou/data # your data dir
# your data dir
download_dir=/mnt/mnt-data-3/jingyong.hou/data
speech_command_dir=$download_dir/speech_commands_v1
. tools/parse_options.sh || exit 1;

View File

@ -27,7 +27,6 @@ from kws.dataset.dataset import Dataset
from kws.model.kws_model import init_model
from kws.utils.checkpoint import load_checkpoint
from kws.utils.executor import Executor
from kws.utils.mask import padding_mask
def get_args():

View File

@ -220,10 +220,11 @@ def main():
training_config['epoch'] = epoch
lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
executor.train(model, optimizer, train_data_loader, device, writer,
training_config)
#executor.train(model, optimizer, train_data_loader, device, writer,
# training_config)
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'.format(epoch, cv_loss, cv_acc))
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
.format(epoch, cv_loss, cv_acc))
if args.rank == 0:
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))

View File

@ -31,3 +31,4 @@ class ElementClassifier(nn.Module):
def forward(self, x: torch.Tensor):
return self.classifier(x)

View File

@ -41,7 +41,7 @@ class KWSModel(torch.nn.Module):
global_cmvn: Optional[torch.nn.Module],
preprocessing: Optional[torch.nn.Module],
backbone: torch.nn.Module,
classifier: torch.nn.Module
classifier: torch.nn.Module
):
super().__init__()
self.idim = idim

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import log
import torch
import torch.nn as nn

View File

@ -44,8 +44,9 @@ class Executor:
if num_utts == 0:
continue
logits = model(feats)
loss, acc = criterion(args.get('criterion', 'max_pooling'),
logits, target, feats_lengths)
loss, acc = criterion(
args.get('criterion', 'max_pooling'),
logits, target, feats_lengths)
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip)
if torch.isfinite(grad_norm):