format
This commit is contained in:
parent
6f91335452
commit
8cdf83a692
@ -42,7 +42,7 @@ model:
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
lr: 0.0002
|
||||
weight_decay: 0.00005
|
||||
|
||||
training_config:
|
||||
|
||||
@ -14,7 +14,7 @@ url=http://download.tensorflow.org/data/$file_name
|
||||
mkdir -p $data_dir
|
||||
if [ ! -f $data_dir/$file_name ]; then
|
||||
echo "downloading $url..."
|
||||
wget -O $data_dir/$file_name $url
|
||||
wget -O $data_dir/$file_name $url
|
||||
else
|
||||
echo "$file_name exist in $data_dir, skip download it"
|
||||
fi
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
|
||||
@ -9,14 +8,11 @@ CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description=
|
||||
'prepare kaldi format file for google speech command dataset ')
|
||||
description='prepare kaldi format file for google speech command')
|
||||
parser.add_argument(
|
||||
'--wav_list',
|
||||
required=True,
|
||||
help=
|
||||
'wave list is a file containts full path of a wav file in google speech command dataset'
|
||||
)
|
||||
help='full path of a wav file in google speech command dataset')
|
||||
parser.add_argument('--data_dir',
|
||||
required=True,
|
||||
help='folder to write kaldi format files')
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Splits the google speech commands into train, validation and test set """
|
||||
#!/usr/bin/env python
|
||||
'''Splits the google speech commands into train, validation and test set'''
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
stage=3
|
||||
stage=2
|
||||
stop_stage=4
|
||||
num_keywords=11
|
||||
|
||||
@ -21,7 +21,8 @@ dir=exp/mdtc_debug
|
||||
num_average=1
|
||||
score_checkpoint=$dir/avg_${num_average}.pt
|
||||
|
||||
download_dir=/mnt/mnt-data-3/jingyong.hou/data # your data dir
|
||||
# your data dir
|
||||
download_dir=/mnt/mnt-data-3/jingyong.hou/data
|
||||
speech_command_dir=$download_dir/speech_commands_v1
|
||||
. tools/parse_options.sh || exit 1;
|
||||
|
||||
@ -41,7 +42,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
data=data/$x
|
||||
mkdir -p $data
|
||||
# make wav.scp utt2spk text file
|
||||
find $speech_command_dir/$x -name *.wav | grep -v "_background_noise_" > $data/wav.list
|
||||
find $speech_command_dir/$x -name *.wav | grep -v "_background_noise_" > $data/wav.list
|
||||
python local/prepare_speech_command.py --wav_list=$data/wav.list --data_dir=$data
|
||||
done
|
||||
fi
|
||||
|
||||
@ -27,7 +27,6 @@ from kws.dataset.dataset import Dataset
|
||||
from kws.model.kws_model import init_model
|
||||
from kws.utils.checkpoint import load_checkpoint
|
||||
from kws.utils.executor import Executor
|
||||
from kws.utils.mask import padding_mask
|
||||
|
||||
|
||||
def get_args():
|
||||
|
||||
@ -220,10 +220,11 @@ def main():
|
||||
training_config['epoch'] = epoch
|
||||
lr = optimizer.param_groups[0]['lr']
|
||||
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
|
||||
executor.train(model, optimizer, train_data_loader, device, writer,
|
||||
training_config)
|
||||
#executor.train(model, optimizer, train_data_loader, device, writer,
|
||||
# training_config)
|
||||
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
|
||||
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'.format(epoch, cv_loss, cv_acc))
|
||||
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
|
||||
.format(epoch, cv_loss, cv_acc))
|
||||
|
||||
if args.rank == 0:
|
||||
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
|
||||
|
||||
@ -30,4 +30,5 @@ class ElementClassifier(nn.Module):
|
||||
self.classifier = classifier
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.classifier(x)
|
||||
return self.classifier(x)
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ class KWSModel(torch.nn.Module):
|
||||
global_cmvn: Optional[torch.nn.Module],
|
||||
preprocessing: Optional[torch.nn.Module],
|
||||
backbone: torch.nn.Module,
|
||||
classifier: torch.nn.Module
|
||||
classifier: torch.nn.Module
|
||||
):
|
||||
super().__init__()
|
||||
self.idim = idim
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from logging import log
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@ -44,8 +44,9 @@ class Executor:
|
||||
if num_utts == 0:
|
||||
continue
|
||||
logits = model(feats)
|
||||
loss, acc = criterion(args.get('criterion', 'max_pooling'),
|
||||
logits, target, feats_lengths)
|
||||
loss, acc = criterion(
|
||||
args.get('criterion', 'max_pooling'),
|
||||
logits, target, feats_lengths)
|
||||
loss.backward()
|
||||
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
||||
if torch.isfinite(grad_norm):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user