From 8cdf83a692aeb379869ad5a2f3ecdc1a76534e47 Mon Sep 17 00:00:00 2001 From: jingyong hou Date: Fri, 3 Dec 2021 20:23:32 +0800 Subject: [PATCH] format --- examples/speechcommand_v1/s0/conf/mdtc.yaml | 2 +- examples/speechcommand_v1/s0/local/data_download.sh | 2 +- .../speechcommand_v1/s0/local/prepare_speech_command.py | 8 ++------ examples/speechcommand_v1/s0/local/split_dataset.py | 3 ++- examples/speechcommand_v1/s0/run.sh | 7 ++++--- kws/bin/test.py | 1 - kws/bin/train.py | 7 ++++--- kws/model/classifier.py | 3 ++- kws/model/kws_model.py | 2 +- kws/model/loss.py | 1 - kws/utils/executor.py | 5 +++-- 11 files changed, 20 insertions(+), 21 deletions(-) diff --git a/examples/speechcommand_v1/s0/conf/mdtc.yaml b/examples/speechcommand_v1/s0/conf/mdtc.yaml index 80a9c6d..6b5646f 100644 --- a/examples/speechcommand_v1/s0/conf/mdtc.yaml +++ b/examples/speechcommand_v1/s0/conf/mdtc.yaml @@ -42,7 +42,7 @@ model: optim: adam optim_conf: - lr: 0.001 + lr: 0.0002 weight_decay: 0.00005 training_config: diff --git a/examples/speechcommand_v1/s0/local/data_download.sh b/examples/speechcommand_v1/s0/local/data_download.sh index 0fd3a64..4b1000a 100755 --- a/examples/speechcommand_v1/s0/local/data_download.sh +++ b/examples/speechcommand_v1/s0/local/data_download.sh @@ -14,7 +14,7 @@ url=http://download.tensorflow.org/data/$file_name mkdir -p $data_dir if [ ! -f $data_dir/$file_name ]; then echo "downloading $url..." - wget -O $data_dir/$file_name $url + wget -O $data_dir/$file_name $url else echo "$file_name exist in $data_dir, skip download it" fi diff --git a/examples/speechcommand_v1/s0/local/prepare_speech_command.py b/examples/speechcommand_v1/s0/local/prepare_speech_command.py index f7898f3..a9d6982 100644 --- a/examples/speechcommand_v1/s0/local/prepare_speech_command.py +++ b/examples/speechcommand_v1/s0/local/prepare_speech_command.py @@ -1,6 +1,5 @@ #!/usr/bin/env python import os -import sys import argparse CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split( @@ -9,14 +8,11 @@ CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))} if __name__ == '__main__': parser = argparse.ArgumentParser( - description= - 'prepare kaldi format file for google speech command dataset ') + description='prepare kaldi format file for google speech command') parser.add_argument( '--wav_list', required=True, - help= - 'wave list is a file containts full path of a wav file in google speech command dataset' - ) + help='full path of a wav file in google speech command dataset') parser.add_argument('--data_dir', required=True, help='folder to write kaldi format files') diff --git a/examples/speechcommand_v1/s0/local/split_dataset.py b/examples/speechcommand_v1/s0/local/split_dataset.py index 8d625db..b7b5326 100755 --- a/examples/speechcommand_v1/s0/local/split_dataset.py +++ b/examples/speechcommand_v1/s0/local/split_dataset.py @@ -1,4 +1,5 @@ -"""Splits the google speech commands into train, validation and test set """ +#!/usr/bin/env python +'''Splits the google speech commands into train, validation and test set''' import os import shutil diff --git a/examples/speechcommand_v1/s0/run.sh b/examples/speechcommand_v1/s0/run.sh index 2441059..bb9f8ac 100644 --- a/examples/speechcommand_v1/s0/run.sh +++ b/examples/speechcommand_v1/s0/run.sh @@ -6,7 +6,7 @@ export CUDA_VISIBLE_DEVICES="0" -stage=3 +stage=2 stop_stage=4 num_keywords=11 @@ -21,7 +21,8 @@ dir=exp/mdtc_debug num_average=1 score_checkpoint=$dir/avg_${num_average}.pt -download_dir=/mnt/mnt-data-3/jingyong.hou/data # your data dir +# your data dir +download_dir=/mnt/mnt-data-3/jingyong.hou/data speech_command_dir=$download_dir/speech_commands_v1 . tools/parse_options.sh || exit 1; @@ -41,7 +42,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then data=data/$x mkdir -p $data # make wav.scp utt2spk text file - find $speech_command_dir/$x -name *.wav | grep -v "_background_noise_" > $data/wav.list + find $speech_command_dir/$x -name *.wav | grep -v "_background_noise_" > $data/wav.list python local/prepare_speech_command.py --wav_list=$data/wav.list --data_dir=$data done fi diff --git a/kws/bin/test.py b/kws/bin/test.py index ca6d7ef..a6cde27 100644 --- a/kws/bin/test.py +++ b/kws/bin/test.py @@ -27,7 +27,6 @@ from kws.dataset.dataset import Dataset from kws.model.kws_model import init_model from kws.utils.checkpoint import load_checkpoint from kws.utils.executor import Executor -from kws.utils.mask import padding_mask def get_args(): diff --git a/kws/bin/train.py b/kws/bin/train.py index f9f05b7..34acc80 100644 --- a/kws/bin/train.py +++ b/kws/bin/train.py @@ -220,10 +220,11 @@ def main(): training_config['epoch'] = epoch lr = optimizer.param_groups[0]['lr'] logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) - executor.train(model, optimizer, train_data_loader, device, writer, - training_config) + #executor.train(model, optimizer, train_data_loader, device, writer, + # training_config) cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config) - logging.info('Epoch {} CV info cv_loss {} cv_acc {}'.format(epoch, cv_loss, cv_acc)) + logging.info('Epoch {} CV info cv_loss {} cv_acc {}' + .format(epoch, cv_loss, cv_acc)) if args.rank == 0: save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) diff --git a/kws/model/classifier.py b/kws/model/classifier.py index b4a0983..3bbcb85 100644 --- a/kws/model/classifier.py +++ b/kws/model/classifier.py @@ -30,4 +30,5 @@ class ElementClassifier(nn.Module): self.classifier = classifier def forward(self, x: torch.Tensor): - return self.classifier(x) \ No newline at end of file + return self.classifier(x) + diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index 004cd79..5c56aef 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -41,7 +41,7 @@ class KWSModel(torch.nn.Module): global_cmvn: Optional[torch.nn.Module], preprocessing: Optional[torch.nn.Module], backbone: torch.nn.Module, - classifier: torch.nn.Module + classifier: torch.nn.Module ): super().__init__() self.idim = idim diff --git a/kws/model/loss.py b/kws/model/loss.py index 1785c2b..80a7fec 100644 --- a/kws/model/loss.py +++ b/kws/model/loss.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from logging import log import torch import torch.nn as nn diff --git a/kws/utils/executor.py b/kws/utils/executor.py index 4bf0303..a5c3d74 100644 --- a/kws/utils/executor.py +++ b/kws/utils/executor.py @@ -44,8 +44,9 @@ class Executor: if num_utts == 0: continue logits = model(feats) - loss, acc = criterion(args.get('criterion', 'max_pooling'), - logits, target, feats_lengths) + loss, acc = criterion( + args.get('criterion', 'max_pooling'), + logits, target, feats_lengths) loss.backward() grad_norm = clip_grad_norm_(model.parameters(), clip) if torch.isfinite(grad_norm):