This commit is contained in:
jingyong hou 2021-12-03 20:23:32 +08:00
parent 6f91335452
commit 8cdf83a692
11 changed files with 20 additions and 21 deletions

View File

@@ -42,7 +42,7 @@ model:
optim: adam optim: adam
optim_conf: optim_conf:
lr: 0.001 lr: 0.0002
weight_decay: 0.00005 weight_decay: 0.00005
training_config: training_config:

View File

@@ -14,7 +14,7 @@ url=http://download.tensorflow.org/data/$file_name
mkdir -p $data_dir mkdir -p $data_dir
if [ ! -f $data_dir/$file_name ]; then if [ ! -f $data_dir/$file_name ]; then
echo "downloading $url..." echo "downloading $url..."
wget -O $data_dir/$file_name $url wget -O $data_dir/$file_name $url
else else
echo "$file_name exist in $data_dir, skip download it" echo "$file_name exist in $data_dir, skip download it"
fi fi

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import os import os
import sys
import argparse import argparse
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split( CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
@@ -9,14 +8,11 @@ CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description='prepare kaldi format file for google speech command')
'prepare kaldi format file for google speech command dataset ')
parser.add_argument( parser.add_argument(
'--wav_list', '--wav_list',
required=True, required=True,
help= help='full path of a wav file in google speech command dataset')
'wave list is a file containts full path of a wav file in google speech command dataset'
)
parser.add_argument('--data_dir', parser.add_argument('--data_dir',
required=True, required=True,
help='folder to write kaldi format files') help='folder to write kaldi format files')

View File

@@ -1,4 +1,5 @@
"""Splits the google speech commands into train, validation and test set """ #!/usr/bin/env python
'''Splits the google speech commands into train, validation and test set'''
import os import os
import shutil import shutil

View File

@@ -6,7 +6,7 @@
export CUDA_VISIBLE_DEVICES="0" export CUDA_VISIBLE_DEVICES="0"
stage=3 stage=2
stop_stage=4 stop_stage=4
num_keywords=11 num_keywords=11
@@ -21,7 +21,8 @@ dir=exp/mdtc_debug
num_average=1 num_average=1
score_checkpoint=$dir/avg_${num_average}.pt score_checkpoint=$dir/avg_${num_average}.pt
download_dir=/mnt/mnt-data-3/jingyong.hou/data # your data dir # your data dir
download_dir=/mnt/mnt-data-3/jingyong.hou/data
speech_command_dir=$download_dir/speech_commands_v1 speech_command_dir=$download_dir/speech_commands_v1
. tools/parse_options.sh || exit 1; . tools/parse_options.sh || exit 1;
@@ -41,7 +42,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
data=data/$x data=data/$x
mkdir -p $data mkdir -p $data
# make wav.scp utt2spk text file # make wav.scp utt2spk text file
find $speech_command_dir/$x -name *.wav | grep -v "_background_noise_" > $data/wav.list find $speech_command_dir/$x -name *.wav | grep -v "_background_noise_" > $data/wav.list
python local/prepare_speech_command.py --wav_list=$data/wav.list --data_dir=$data python local/prepare_speech_command.py --wav_list=$data/wav.list --data_dir=$data
done done
fi fi

View File

@@ -27,7 +27,6 @@ from kws.dataset.dataset import Dataset
from kws.model.kws_model import init_model from kws.model.kws_model import init_model
from kws.utils.checkpoint import load_checkpoint from kws.utils.checkpoint import load_checkpoint
from kws.utils.executor import Executor from kws.utils.executor import Executor
from kws.utils.mask import padding_mask
def get_args(): def get_args():

View File

@@ -220,10 +220,11 @@ def main():
training_config['epoch'] = epoch training_config['epoch'] = epoch
lr = optimizer.param_groups[0]['lr'] lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
executor.train(model, optimizer, train_data_loader, device, writer, #executor.train(model, optimizer, train_data_loader, device, writer,
training_config) # training_config)
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config) cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'.format(epoch, cv_loss, cv_acc)) logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
.format(epoch, cv_loss, cv_acc))
if args.rank == 0: if args.rank == 0:
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))

View File

@@ -30,4 +30,5 @@ class ElementClassifier(nn.Module):
self.classifier = classifier self.classifier = classifier
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
return self.classifier(x) return self.classifier(x)

View File

@@ -41,7 +41,7 @@ class KWSModel(torch.nn.Module):
global_cmvn: Optional[torch.nn.Module], global_cmvn: Optional[torch.nn.Module],
preprocessing: Optional[torch.nn.Module], preprocessing: Optional[torch.nn.Module],
backbone: torch.nn.Module, backbone: torch.nn.Module,
classifier: torch.nn.Module classifier: torch.nn.Module
): ):
super().__init__() super().__init__()
self.idim = idim self.idim = idim

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from logging import log
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@@ -44,8 +44,9 @@ class Executor:
if num_utts == 0: if num_utts == 0:
continue continue
logits = model(feats) logits = model(feats)
loss, acc = criterion(args.get('criterion', 'max_pooling'), loss, acc = criterion(
logits, target, feats_lengths) args.get('criterion', 'max_pooling'),
logits, target, feats_lengths)
loss.backward() loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip) grad_norm = clip_grad_norm_(model.parameters(), clip)
if torch.isfinite(grad_norm): if torch.isfinite(grad_norm):