[examples] add speechcommand train (#30)
* [example] add code for training on the speech command dataset
* update kws_model.py
* update kws_model.py
* format
* format
* add more comments to explain the new classifier designed for the speech command classification task
* add copyright info
* update copyright info of classifier.py
parent 8be4bef405
commit 37f56db5af
examples/speechcommand_v1/s0/conf/mdtc.yaml (new file, 52 lines)
@@ -0,0 +1,52 @@
+dataset_conf:
+    filter_conf:
+        max_length: 2048
+        min_length: 0
+    resample_conf:
+        resample_rate: 16000
+    speed_perturb: false
+    feature_extraction_conf:
+        feature_type: 'mfcc'
+        num_ceps: 80
+        num_mel_bins: 80
+        frame_shift: 10
+        frame_length: 25
+        dither: 1.0
+    feature_dither: 0.0
+    spec_aug: true
+    spec_aug_conf:
+        num_t_mask: 2
+        num_f_mask: 2
+        max_t: 20
+        max_f: 40
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 1500
+    batch_conf:
+        batch_size: 100
+
+model:
+    hidden_dim: 64
+    preprocessing:
+        type: none
+    backbone:
+        type: mdtc
+        num_stack: 4
+        stack_size: 4
+        kernel_size: 5
+        hidden_dim: 64
+        causal: False
+    classifier:
+        type: global
+        dropout: 0.5
+
+optim: adam
+optim_conf:
+    lr: 0.001
+    weight_decay: 0.00005
+
+training_config:
+    grad_clip: 50
+    max_epoch: 100
+    log_interval: 10
+    criterion: ce
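This one file drives the whole recipe: dataset_conf configures the MFCC/SpecAugment data pipeline, model selects the MDTC backbone plus the new global classifier head, and training_config picks the new cross-entropy criterion. A minimal sketch of how a wenet-style entry point such as kws/bin/train.py consumes it (PyYAML load assumed; the path below is illustrative):

    import yaml  # PyYAML

    # Each top-level key of the recipe config feeds a different component.
    with open('examples/speechcommand_v1/s0/conf/mdtc.yaml') as fin:
        configs = yaml.safe_load(fin)

    dataset_conf = configs['dataset_conf']      # feature pipeline: MFCC, spec_aug, batching
    model_conf = configs['model']               # handed to init_model(): backbone + classifier
    optim_conf = configs['optim_conf']          # Adam hyper-parameters
    train_conf = configs['training_config']     # grad_clip, max_epoch, criterion ('ce')

    print(model_conf['classifier'])             # {'type': 'global', 'dropout': 0.5}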
@@ -7,7 +7,19 @@
 export CUDA_VISIBLE_DEVICES="0"
 
 stage=-1
-stop_stage=0
+stop_stage=2
 num_keywords=11
+
+config=conf/mdtc.yaml
+norm_mean=false
+norm_var=false
+gpu_id=4
+
+checkpoint=
+dir=exp/mdtc
+
+num_average=10
+score_checkpoint=$dir/avg_${num_average}.pt
+
 # your data dir
 download_dir=/mnt/mnt-data-3/jingyong.hou/data
@@ -35,3 +47,35 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
   done
 fi
 
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Compute CMVN and Format datasets"
+  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
+    --in_scp data/train/wav.scp \
+    --out_cmvn data/train/global_cmvn
+
+  for x in train valid test; do
+    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
+    tools/make_list.py data/$x/wav.scp data/$x/text \
+      data/$x/wav.dur data/$x/data.list
+  done
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Start training ..."
+  mkdir -p $dir
+  cmvn_opts=
+  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
+  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
+  python kws/bin/train.py --gpu $gpu_id \
+    --config $config \
+    --train_data data/train/data.list \
+    --cv_data data/valid/data.list \
+    --model_dir $dir \
+    --num_workers 8 \
+    --num_keywords $num_keywords \
+    --min_duration 50 \
+    $cmvn_opts \
+    ${checkpoint:+--checkpoint $checkpoint}
+fi
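The new stage 1 writes global CMVN statistics (data/train/global_cmvn) plus per-utterance durations, and stage 2 only forwards them to training when norm_mean/norm_var are switched on; with both left at false, training runs on unnormalized MFCCs. As a rough sketch of what that normalization amounts to when enabled (this mirrors the usual GlobalCMVN behaviour in wenet-style code, not an exact copy of this repo's module):

    import torch

    def apply_global_cmvn(feats: torch.Tensor, mean: torch.Tensor,
                          istd: torch.Tensor, norm_var: bool = False) -> torch.Tensor:
        """feats: (batch, frames, dim); mean/istd: (dim,) estimated on the training set."""
        feats = feats - mean       # cepstral mean normalization
        if norm_var:
            feats = feats * istd   # optional variance normalization
        return feats

    feats = torch.randn(2, 100, 80)                  # 80-dim MFCC, as in mdtc.yaml
    flat = feats.reshape(-1, feats.size(-1))
    mean, std = flat.mean(dim=0), flat.std(dim=0)
    normed = apply_global_cmvn(feats, mean, 1.0 / std.clamp(min=1e-8), norm_var=True)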
@@ -221,8 +221,9 @@ def main():
         logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
         executor.train(model, optimizer, train_data_loader, device, writer,
                        training_config)
-        cv_loss = executor.cv(model, cv_data_loader, device, training_config)
-        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
+        cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
+        logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
+                     .format(epoch, cv_loss, cv_acc))
 
         if args.rank == 0:
             save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
@@ -232,6 +233,7 @@ def main():
                 'cv_loss': cv_loss,
             })
             writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
+            writer.add_scalar('epoch/cv_acc', cv_acc, epoch)
             writer.add_scalar('epoch/lr', lr, epoch)
             final_epoch = epoch
         scheduler.step(cv_loss)
kws/model/classifier.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+# Copyright (c) 2021 Jingyong Hou
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+
+class GlobalClassifier(nn.Module):
+    """Add a global average pooling before the classifier"""
+    def __init__(self, classifier: nn.Module):
+        super(GlobalClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        x = torch.mean(x, dim=1)
+        return self.classifier(x)
+
+
+class LastClassifier(nn.Module):
+    """Select last frame to do the classification"""
+    def __init__(self, classifier: nn.Module):
+        super(LastClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        x = x[:, -1, :]
+        return self.classifier(x)
+
+class ElementClassifier(nn.Module):
+    """Classify all the frames in an utterance"""
+    def __init__(self, classifier: nn.Module):
+        super(ElementClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        return self.classifier(x)
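All three wrappers share one contract: they take backbone output of shape (batch, frames, hidden_dim) and decide how the time axis reaches the wrapped head. A small usage sketch (dimensions are illustrative; the two-layer head mirrors the one built in init_model below):

    import torch
    import torch.nn as nn
    from kws.model.classifier import (GlobalClassifier, LastClassifier,
                                      ElementClassifier)

    hidden_dim, num_classes = 64, 11                   # e.g. --num_keywords 11 in run.sh
    head = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU(),
                         nn.Dropout(0.5), nn.Linear(64, num_classes))

    x = torch.randn(8, 100, hidden_dim)                # (batch, frames, hidden_dim) from the backbone

    print(GlobalClassifier(head)(x).shape)             # mean over frames -> (8, 11)
    print(LastClassifier(head)(x).shape)               # last frame only  -> (8, 11)
    print(ElementClassifier(head)(x).shape)            # every frame      -> (8, 100, 11)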
@@ -16,8 +16,10 @@ import sys
 from typing import Optional
 
 import torch
+import torch.nn as nn
 
 from kws.model.cmvn import GlobalCMVN
+from kws.model.classifier import GlobalClassifier, LastClassifier
 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
 from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
 from kws.model.mdtc import MDTC
@@ -39,6 +41,7 @@ class KWSModel(torch.nn.Module):
                  global_cmvn: Optional[torch.nn.Module],
                  preprocessing: Optional[torch.nn.Module],
                  backbone: torch.nn.Module,
+                 classifier: torch.nn.Module
                  ):
         super().__init__()
         self.idim = idim
@@ -47,7 +50,7 @@ class KWSModel(torch.nn.Module):
         self.global_cmvn = global_cmvn
         self.preprocessing = preprocessing
         self.backbone = backbone
-        self.classifier = torch.nn.Linear(hdim, odim)
+        self.classifier = classifier
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.global_cmvn is not None:
@@ -55,7 +58,6 @@ class KWSModel(torch.nn.Module):
             x = self.preprocessing(x)
         x, _ = self.backbone(x)
         x = self.classifier(x)
-        x = torch.sigmoid(x)
         return x
 
 
@@ -110,17 +112,39 @@ def init_model(configs):
         num_stack = configs['backbone']['num_stack']
         kernel_size = configs['backbone']['kernel_size']
         hidden_dim = configs['backbone']['hidden_dim']
-
+        causal = configs['backbone']['causal']
         backbone = MDTC(num_stack,
                         stack_size,
                         input_dim,
                         hidden_dim,
                         kernel_size,
-                        causal=True)
+                        causal=causal)
     else:
         print('Unknown body type {}'.format(backbone_type))
         sys.exit(1)
+    if 'classifier' in configs:
+        # For the speech command dataset we use two FC layers as the classifier,
+        # with dropout after the first FC layer to prevent overfitting.
+        classifier_type = configs['classifier']['type']
+        dropout = configs['classifier']['dropout']
+
+        classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
+                                        nn.ReLU(),
+                                        nn.Dropout(dropout),
+                                        nn.Linear(64, output_dim))
+        if classifier_type == 'global':
+            # 'global' adds a global average pooling before the classifier.
+            classifier = GlobalClassifier(classifier_base)
+        elif classifier_type == 'last':
+            # 'last' backpropagates through the last frame only, so the model
+            # can be run in a streaming fashion at inference time.
+            classifier = LastClassifier(classifier_base)
+        else:
+            print('Unknown classifier type {}'.format(classifier_type))
+            sys.exit(1)
+    else:
+        classifier = torch.nn.Linear(hidden_dim, output_dim)
 
     kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
-                         preprocessing, backbone)
+                         preprocessing, backbone, classifier)
    return kws_model
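Note that KWSModel.forward no longer applies a sigmoid: the model now returns raw logits, max_pooling_loss applies the sigmoid itself (see the loss.py hunk below), and the new cross-entropy path needs unnormalized logits because nn.CrossEntropyLoss combines log-softmax and NLL internally. A toy illustration of that last point (shapes only, not the project's code):

    import torch
    import torch.nn as nn

    logits = torch.randn(4, 11)            # (batch, num_classes): raw scores from the model
    target = torch.randint(0, 11, (4,))    # class index per utterance

    ce = nn.CrossEntropyLoss()
    loss = ce(logits, target)              # correct: pass the logits directly
    # Passing sigmoid(logits) or softmax(logits) here would silently distort the loss.
    print(loss.item())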
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import torch
+import torch.nn as nn
 
 from kws.utils.mask import padding_mask
 
 
-def max_polling_loss(logits: torch.Tensor,
+def max_pooling_loss(logits: torch.Tensor,
                      target: torch.Tensor,
                      lengths: torch.Tensor,
                      min_duration: int = 0):
@@ -37,6 +38,7 @@ def max_polling_loss(logits: torch.Tensor,
         (float): loss of current batch
         (float): accuracy of current batch
     '''
+    logits = torch.sigmoid(logits)
     mask = padding_mask(lengths)
     num_utts = logits.size(0)
     num_keywords = logits.size(2)
@@ -80,3 +82,46 @@ def max_polling_loss(logits: torch.Tensor,
     acc = num_correct / num_utts
     # acc = 0.0
     return loss, acc
+
+
+def acc_frame(
+    logits: torch.Tensor,
+    target: torch.Tensor,
+):
+    if logits is None:
+        return 0
+    pred = logits.max(1, keepdim=True)[1]
+    correct = pred.eq(target.long().view_as(pred)).sum().item()
+    return correct * 100.0 / logits.size(0)
+
+
+def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
+    """ Cross Entropy Loss
+    Attributes:
+        logits: (B, D), D is the number of keywords plus 1 (non-keyword)
+        target: (B)
+        lengths: (B)
+        min_duration: min duration of the keyword
+    Returns:
+        (float): loss of current batch
+        (float): accuracy of current batch
+    """
+    cross_entropy = nn.CrossEntropyLoss()
+    loss = cross_entropy(logits, target)
+    acc = acc_frame(logits, target)
+    return loss, acc
+
+
+def criterion(type: str,
+              logits: torch.Tensor,
+              target: torch.Tensor,
+              lengths: torch.Tensor,
+              min_duration: int = 0):
+    if type == 'ce':
+        loss, acc = cross_entropy(logits, target)
+        return loss, acc
+    elif type == 'max_pooling':
+        loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
+        return loss, acc
+    else:
+        exit(1)
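The new criterion() dispatcher lets the executor switch between the original max-pooling loss (wake-word style training) and plain cross entropy (speech command classification) via criterion: ce in the YAML. A quick sketch of calling the 'ce' branch directly (dummy shapes; lengths is only consumed by the 'max_pooling' branch):

    import torch
    from kws.model.loss import criterion

    B, D = 4, 11                                  # batch size, number of classes (illustrative)
    logits = torch.randn(B, D)                    # utterance-level scores, e.g. from GlobalClassifier
    target = torch.randint(0, D, (B,))            # class index per utterance
    lengths = torch.full((B,), 100)               # ignored by the 'ce' branch

    loss, acc = criterion('ce', logits, target, lengths)
    print(loss.item(), acc)                       # acc is a percentage (see acc_frame above)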
@@ -17,7 +17,7 @@ import logging
 import torch
 from torch.nn.utils import clip_grad_norm_
 
-from kws.model.loss import max_polling_loss
+from kws.model.loss import criterion
 
 
 class Executor:
@@ -44,8 +44,8 @@
             if num_utts == 0:
                 continue
             logits = model(feats)
-            loss, acc = max_polling_loss(logits, target, feats_lengths,
-                                         min_duration)
+            loss_type = args.get('criterion', 'max_pooling')
+            loss, acc = criterion(loss_type, logits, target, feats_lengths)
             loss.backward()
             grad_norm = clip_grad_norm_(model.parameters(), clip)
             if torch.isfinite(grad_norm):
@@ -64,6 +64,7 @@
         # in order to avoid division by 0
         num_seen_utts = 1
         total_loss = 0.0
+        total_acc = 0.0
         with torch.no_grad():
             for batch_idx, batch in enumerate(data_loader):
                 key, feats, target, feats_lengths = batch
@@ -73,15 +74,19 @@
                 num_utts = feats_lengths.size(0)
                 if num_utts == 0:
                     continue
-                num_seen_utts += num_utts
                 logits = model(feats)
-                loss, acc = max_polling_loss(logits, target, feats_lengths)
+                loss, acc = criterion(args.get('criterion', 'max_pooling'),
+                                      logits, target, feats_lengths)
                 if torch.isfinite(loss):
+                    num_seen_utts += num_utts
                     total_loss += loss.item() * num_utts
+                    total_acc += acc * num_utts
                 if batch_idx % log_interval == 0:
                     logging.debug(
                         'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}'
                         .format(epoch, batch_idx, loss.item(), acc,
                                 total_loss / num_seen_utts))
-        return total_loss / num_seen_utts
+        return total_loss / num_seen_utts, total_acc / num_seen_utts
 
     def test(self, model, data_loader, device, args):
         return self.cv(model, data_loader, device, args)
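cv() now accumulates both loss and accuracy weighted by the number of utterances in each batch, so the returned pair is an utterance-weighted average rather than a per-batch mean (up to the num_seen_utts = 1 initialization that guards against division by zero); train.py unpacks it as (cv_loss, cv_acc) in the hunk above. The weighting reduces to the following (a toy check, not project code):

    # Utterance-weighted averaging as done in Executor.cv().
    batches = [(0.50, 90.0, 100), (0.70, 80.0, 60)]   # (loss, acc, num_utts) per batch
    total = sum(n for _, _, n in batches)
    avg_loss = sum(l * n for l, _, n in batches) / total
    avg_acc = sum(a * n for _, a, n in batches) / total
    print(avg_loss, avg_acc)   # 0.575 86.25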