format
This commit is contained in:
parent
6f91335452
commit
8cdf83a692
@ -42,7 +42,7 @@ model:
|
|||||||
|
|
||||||
optim: adam
|
optim: adam
|
||||||
optim_conf:
|
optim_conf:
|
||||||
lr: 0.001
|
lr: 0.0002
|
||||||
weight_decay: 0.00005
|
weight_decay: 0.00005
|
||||||
|
|
||||||
training_config:
|
training_config:
|
||||||
|
|||||||
@ -14,7 +14,7 @@ url=http://download.tensorflow.org/data/$file_name
|
|||||||
mkdir -p $data_dir
|
mkdir -p $data_dir
|
||||||
if [ ! -f $data_dir/$file_name ]; then
|
if [ ! -f $data_dir/$file_name ]; then
|
||||||
echo "downloading $url..."
|
echo "downloading $url..."
|
||||||
wget -O $data_dir/$file_name $url
|
wget -O $data_dir/$file_name $url
|
||||||
else
|
else
|
||||||
echo "$file_name exist in $data_dir, skip download it"
|
echo "$file_name exist in $data_dir, skip download it"
|
||||||
fi
|
fi
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
|
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
|
||||||
@ -9,14 +8,11 @@ CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description=
|
description='prepare kaldi format file for google speech command')
|
||||||
'prepare kaldi format file for google speech command dataset ')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--wav_list',
|
'--wav_list',
|
||||||
required=True,
|
required=True,
|
||||||
help=
|
help='full path of a wav file in google speech command dataset')
|
||||||
'wave list is a file containts full path of a wav file in google speech command dataset'
|
|
||||||
)
|
|
||||||
parser.add_argument('--data_dir',
|
parser.add_argument('--data_dir',
|
||||||
required=True,
|
required=True,
|
||||||
help='folder to write kaldi format files')
|
help='folder to write kaldi format files')
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
"""Splits the google speech commands into train, validation and test set """
|
#!/usr/bin/env python
|
||||||
|
'''Splits the google speech commands into train, validation and test set'''
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|||||||
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0"
|
export CUDA_VISIBLE_DEVICES="0"
|
||||||
|
|
||||||
stage=3
|
stage=2
|
||||||
stop_stage=4
|
stop_stage=4
|
||||||
num_keywords=11
|
num_keywords=11
|
||||||
|
|
||||||
@ -21,7 +21,8 @@ dir=exp/mdtc_debug
|
|||||||
num_average=1
|
num_average=1
|
||||||
score_checkpoint=$dir/avg_${num_average}.pt
|
score_checkpoint=$dir/avg_${num_average}.pt
|
||||||
|
|
||||||
download_dir=/mnt/mnt-data-3/jingyong.hou/data # your data dir
|
# your data dir
|
||||||
|
download_dir=/mnt/mnt-data-3/jingyong.hou/data
|
||||||
speech_command_dir=$download_dir/speech_commands_v1
|
speech_command_dir=$download_dir/speech_commands_v1
|
||||||
. tools/parse_options.sh || exit 1;
|
. tools/parse_options.sh || exit 1;
|
||||||
|
|
||||||
@ -41,7 +42,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
|||||||
data=data/$x
|
data=data/$x
|
||||||
mkdir -p $data
|
mkdir -p $data
|
||||||
# make wav.scp utt2spk text file
|
# make wav.scp utt2spk text file
|
||||||
find $speech_command_dir/$x -name *.wav | grep -v "_background_noise_" > $data/wav.list
|
find $speech_command_dir/$x -name *.wav | grep -v "_background_noise_" > $data/wav.list
|
||||||
python local/prepare_speech_command.py --wav_list=$data/wav.list --data_dir=$data
|
python local/prepare_speech_command.py --wav_list=$data/wav.list --data_dir=$data
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|||||||
@ -27,7 +27,6 @@ from kws.dataset.dataset import Dataset
|
|||||||
from kws.model.kws_model import init_model
|
from kws.model.kws_model import init_model
|
||||||
from kws.utils.checkpoint import load_checkpoint
|
from kws.utils.checkpoint import load_checkpoint
|
||||||
from kws.utils.executor import Executor
|
from kws.utils.executor import Executor
|
||||||
from kws.utils.mask import padding_mask
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
|
|||||||
@ -220,10 +220,11 @@ def main():
|
|||||||
training_config['epoch'] = epoch
|
training_config['epoch'] = epoch
|
||||||
lr = optimizer.param_groups[0]['lr']
|
lr = optimizer.param_groups[0]['lr']
|
||||||
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
|
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
|
||||||
executor.train(model, optimizer, train_data_loader, device, writer,
|
#executor.train(model, optimizer, train_data_loader, device, writer,
|
||||||
training_config)
|
# training_config)
|
||||||
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
|
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
|
||||||
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'.format(epoch, cv_loss, cv_acc))
|
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
|
||||||
|
.format(epoch, cv_loss, cv_acc))
|
||||||
|
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
|
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
|
||||||
|
|||||||
@ -30,4 +30,5 @@ class ElementClassifier(nn.Module):
|
|||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
return self.classifier(x)
|
return self.classifier(x)
|
||||||
|
|
||||||
|
|||||||
@ -41,7 +41,7 @@ class KWSModel(torch.nn.Module):
|
|||||||
global_cmvn: Optional[torch.nn.Module],
|
global_cmvn: Optional[torch.nn.Module],
|
||||||
preprocessing: Optional[torch.nn.Module],
|
preprocessing: Optional[torch.nn.Module],
|
||||||
backbone: torch.nn.Module,
|
backbone: torch.nn.Module,
|
||||||
classifier: torch.nn.Module
|
classifier: torch.nn.Module
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.idim = idim
|
self.idim = idim
|
||||||
|
|||||||
@ -12,7 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from logging import log
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|||||||
@ -44,8 +44,9 @@ class Executor:
|
|||||||
if num_utts == 0:
|
if num_utts == 0:
|
||||||
continue
|
continue
|
||||||
logits = model(feats)
|
logits = model(feats)
|
||||||
loss, acc = criterion(args.get('criterion', 'max_pooling'),
|
loss, acc = criterion(
|
||||||
logits, target, feats_lengths)
|
args.get('criterion', 'max_pooling'),
|
||||||
|
logits, target, feats_lengths)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
||||||
if torch.isfinite(grad_norm):
|
if torch.isfinite(grad_norm):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user