This commit is contained in:
jingyong hou 2021-12-03 20:23:32 +08:00
parent 6f91335452
commit 8cdf83a692
11 changed files with 20 additions and 21 deletions

View File

@ -42,7 +42,7 @@ model:
optim: adam optim: adam
optim_conf: optim_conf:
lr: 0.001 lr: 0.0002
weight_decay: 0.00005 weight_decay: 0.00005
training_config: training_config:

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import os import os
import sys
import argparse import argparse
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split( CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
@ -9,14 +8,11 @@ CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description='prepare kaldi format file for google speech command')
'prepare kaldi format file for google speech command dataset ')
parser.add_argument( parser.add_argument(
'--wav_list', '--wav_list',
required=True, required=True,
help= help='full path of a wav file in google speech command dataset')
'wave list is a file containts full path of a wav file in google speech command dataset'
)
parser.add_argument('--data_dir', parser.add_argument('--data_dir',
required=True, required=True,
help='folder to write kaldi format files') help='folder to write kaldi format files')

View File

@ -1,4 +1,5 @@
"""Splits the google speech commands into train, validation and test set """ #!/usr/bin/env python
'''Splits the google speech commands into train, validation and test set'''
import os import os
import shutil import shutil

View File

@ -6,7 +6,7 @@
export CUDA_VISIBLE_DEVICES="0" export CUDA_VISIBLE_DEVICES="0"
stage=3 stage=2
stop_stage=4 stop_stage=4
num_keywords=11 num_keywords=11
@ -21,7 +21,8 @@ dir=exp/mdtc_debug
num_average=1 num_average=1
score_checkpoint=$dir/avg_${num_average}.pt score_checkpoint=$dir/avg_${num_average}.pt
download_dir=/mnt/mnt-data-3/jingyong.hou/data # your data dir # your data dir
download_dir=/mnt/mnt-data-3/jingyong.hou/data
speech_command_dir=$download_dir/speech_commands_v1 speech_command_dir=$download_dir/speech_commands_v1
. tools/parse_options.sh || exit 1; . tools/parse_options.sh || exit 1;

View File

@ -27,7 +27,6 @@ from kws.dataset.dataset import Dataset
from kws.model.kws_model import init_model from kws.model.kws_model import init_model
from kws.utils.checkpoint import load_checkpoint from kws.utils.checkpoint import load_checkpoint
from kws.utils.executor import Executor from kws.utils.executor import Executor
from kws.utils.mask import padding_mask
def get_args(): def get_args():

View File

@ -220,10 +220,11 @@ def main():
training_config['epoch'] = epoch training_config['epoch'] = epoch
lr = optimizer.param_groups[0]['lr'] lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
executor.train(model, optimizer, train_data_loader, device, writer, #executor.train(model, optimizer, train_data_loader, device, writer,
training_config) # training_config)
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config) cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'.format(epoch, cv_loss, cv_acc)) logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
.format(epoch, cv_loss, cv_acc))
if args.rank == 0: if args.rank == 0:
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))

View File

@ -31,3 +31,4 @@ class ElementClassifier(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
return self.classifier(x) return self.classifier(x)

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from logging import log
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -44,7 +44,8 @@ class Executor:
if num_utts == 0: if num_utts == 0:
continue continue
logits = model(feats) logits = model(feats)
loss, acc = criterion(args.get('criterion', 'max_pooling'), loss, acc = criterion(
args.get('criterion', 'max_pooling'),
logits, target, feats_lengths) logits, target, feats_lengths)
loss.backward() loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip) grad_norm = clip_grad_norm_(model.parameters(), clip)