[exampels] add speechcommand train (#30)

* [example] added code for training speech command dataset

* update kes_model.py

* update kes_model.py

* format

* format

* add more comments to explain the new classifier designed for speech command classification task

* add copyrigh info

* update copyrigh info of classifier.py
This commit is contained in:
xiaohou 2021-12-06 17:14:33 +08:00 committed by GitHub
parent 8be4bef405
commit 37f56db5af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 234 additions and 15 deletions

View File

@ -0,0 +1,52 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'mfcc'
num_ceps: 80
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
feature_dither: 0.0
spec_aug: true
spec_aug_conf:
num_t_mask: 2
num_f_mask: 2
max_t: 20
max_f: 40
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 100
model:
hidden_dim: 64
preprocessing:
type: none
backbone:
type: mdtc
num_stack: 4
stack_size: 4
kernel_size: 5
hidden_dim: 64
causal: False
classifier:
type: global
dropout: 0.5
optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.00005
training_config:
grad_clip: 50
max_epoch: 100
log_interval: 10
criterion: ce

View File

@ -7,7 +7,19 @@
export CUDA_VISIBLE_DEVICES="0"
stage=-1
stop_stage=0
stop_stage=2
num_keywords=11
config=conf/mdtc.yaml
norm_mean=false
norm_var=false
gpu_id=4
checkpoint=
dir=exp/mdtc
num_average=10
score_checkpoint=$dir/avg_${num_average}.pt
# your data dir
download_dir=/mnt/mnt-data-3/jingyong.hou/data
@ -35,3 +47,35 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Compute CMVN and Format datasets"
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
--in_scp data/train/wav.scp \
--out_cmvn data/train/global_cmvn
for x in train valid test; do
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
tools/make_list.py data/$x/wav.scp data/$x/text \
data/$x/wav.dur data/$x/data.list
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Start training ..."
mkdir -p $dir
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
python kws/bin/train.py --gpu $gpu_id \
--config $config \
--train_data data/train/data.list \
--cv_data data/valid/data.list \
--model_dir $dir \
--num_workers 8 \
--num_keywords $num_keywords \
--min_duration 50 \
$cmvn_opts \
${checkpoint:+--checkpoint $checkpoint}
fi

View File

@ -221,8 +221,9 @@ def main():
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
executor.train(model, optimizer, train_data_loader, device, writer,
training_config)
cv_loss = executor.cv(model, cv_data_loader, device, training_config)
logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
.format(epoch, cv_loss, cv_acc))
if args.rank == 0:
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
@ -232,6 +233,7 @@ def main():
'cv_loss': cv_loss,
})
writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
writer.add_scalar('epoch/cv_acc', cv_acc, epoch)
writer.add_scalar('epoch/lr', lr, epoch)
final_epoch = epoch
scheduler.step(cv_loss)

47
kws/model/classifier.py Normal file
View File

@ -0,0 +1,47 @@
# Copyright (c) 2021 Jingyong Hou
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class GlobalClassifier(nn.Module):
"""Add a global average pooling before the classifier"""
def __init__(self, classifier: nn.Module):
super(GlobalClassifier, self).__init__()
self.classifier = classifier
def forward(self, x: torch.Tensor):
x = torch.mean(x, dim=1)
return self.classifier(x)
class LastClassifier(nn.Module):
"""Select last frame to do the classification"""
def __init__(self, classifier: nn.Module):
super(LastClassifier, self).__init__()
self.classifier = classifier
def forward(self, x: torch.Tensor):
x = x[:, -1, :]
return self.classifier(x)
class ElementClassifier(nn.Module):
"""Classify all the frames in an utterance"""
def __init__(self, classifier: nn.Module):
super(ElementClassifier, self).__init__()
self.classifier = classifier
def forward(self, x: torch.Tensor):
return self.classifier(x)

View File

@ -16,8 +16,10 @@ import sys
from typing import Optional
import torch
import torch.nn as nn
from kws.model.cmvn import GlobalCMVN
from kws.model.classifier import GlobalClassifier, LastClassifier
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
from kws.model.mdtc import MDTC
@ -39,6 +41,7 @@ class KWSModel(torch.nn.Module):
global_cmvn: Optional[torch.nn.Module],
preprocessing: Optional[torch.nn.Module],
backbone: torch.nn.Module,
classifier: torch.nn.Module
):
super().__init__()
self.idim = idim
@ -47,7 +50,7 @@ class KWSModel(torch.nn.Module):
self.global_cmvn = global_cmvn
self.preprocessing = preprocessing
self.backbone = backbone
self.classifier = torch.nn.Linear(hdim, odim)
self.classifier = classifier
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.global_cmvn is not None:
@ -55,7 +58,6 @@ class KWSModel(torch.nn.Module):
x = self.preprocessing(x)
x, _ = self.backbone(x)
x = self.classifier(x)
x = torch.sigmoid(x)
return x
@ -110,17 +112,39 @@ def init_model(configs):
num_stack = configs['backbone']['num_stack']
kernel_size = configs['backbone']['kernel_size']
hidden_dim = configs['backbone']['hidden_dim']
causal = configs['backbone']['causal']
backbone = MDTC(num_stack,
stack_size,
input_dim,
hidden_dim,
kernel_size,
causal=True)
causal=causal)
else:
print('Unknown body type {}'.format(backbone_type))
sys.exit(1)
if 'classifier' in configs:
# For speech command dataset, we use 2 FC layer as classifier,
# we add dropout after first FC layer to prevent overfitting
classifier_type = configs['classifier']['type']
dropout = configs['classifier']['dropout']
classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(64, output_dim))
if classifier_type == 'global':
# global means we add a global average pooling before classifier
classifier = GlobalClassifier(classifier_base)
elif classifier_type == 'last':
# last means we use last frame to do backpropagation, so the model
# can be infered streamingly
classifier = LastClassifier(classifier_base)
else:
print('Unknown classifier type {}'.format(classifier_type))
sys.exit(1)
else:
classifier = torch.nn.Linear(hidden_dim, output_dim)
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone)
preprocessing, backbone, classifier)
return kws_model

View File

@ -13,11 +13,12 @@
# limitations under the License.
import torch
import torch.nn as nn
from kws.utils.mask import padding_mask
def max_polling_loss(logits: torch.Tensor,
def max_pooling_loss(logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
@ -37,6 +38,7 @@ def max_polling_loss(logits: torch.Tensor,
(float): loss of current batch
(float): accuracy of current batch
'''
logits = torch.sigmoid(logits)
mask = padding_mask(lengths)
num_utts = logits.size(0)
num_keywords = logits.size(2)
@ -80,3 +82,46 @@ def max_polling_loss(logits: torch.Tensor,
acc = num_correct / num_utts
# acc = 0.0
return loss, acc
def acc_frame(
logits: torch.Tensor,
target: torch.Tensor,
):
if logits is None:
return 0
pred = logits.max(1, keepdim=True)[1]
correct = pred.eq(target.long().view_as(pred)).sum().item()
return correct * 100.0 / logits.size(0)
def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
""" Cross Entropy Loss
Attributes:
logits: (B, D), D is the number of keywords plus 1 (non-keyword)
target: (B)
lengths: (B)
min_duration: min duration of the keyword
Returns:
(float): loss of current batch
(float): accuracy of current batch
"""
cross_entropy = nn.CrossEntropyLoss()
loss = cross_entropy(logits, target)
acc = acc_frame(logits, target)
return loss, acc
def criterion(type: str,
logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
if type == 'ce':
loss, acc = cross_entropy(logits, target)
return loss, acc
elif type == 'max_pooling':
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
return loss, acc
else:
exit(1)

View File

@ -17,7 +17,7 @@ import logging
import torch
from torch.nn.utils import clip_grad_norm_
from kws.model.loss import max_polling_loss
from kws.model.loss import criterion
class Executor:
@ -44,8 +44,8 @@ class Executor:
if num_utts == 0:
continue
logits = model(feats)
loss, acc = max_polling_loss(logits, target, feats_lengths,
min_duration)
loss_type = args.get('criterion', 'max_pooling')
loss, acc = criterion(loss_type, logits, target, feats_lengths)
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip)
if torch.isfinite(grad_norm):
@ -64,6 +64,7 @@ class Executor:
# in order to avoid division by 0
num_seen_utts = 1
total_loss = 0.0
total_acc = 0.0
with torch.no_grad():
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths = batch
@ -73,15 +74,19 @@ class Executor:
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
num_seen_utts += num_utts
logits = model(feats)
loss, acc = max_polling_loss(logits, target, feats_lengths)
loss, acc = criterion(args.get('criterion', 'max_pooling'),
logits, target, feats_lengths)
if torch.isfinite(loss):
num_seen_utts += num_utts
total_loss += loss.item() * num_utts
total_acc += acc * num_utts
if batch_idx % log_interval == 0:
logging.debug(
'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}'
.format(epoch, batch_idx, loss.item(), acc,
total_loss / num_seen_utts))
return total_loss / num_seen_utts
return total_loss / num_seen_utts, total_acc / num_seen_utts
def test(self, model, data_loader, device, args):
return self.cv(model, data_loader, device, args)