[examples] add speechcommand train (#30)

* [example] added code for training on the speech command dataset

* update kws_model.py

* update kws_model.py

* format

* format

* add more comments to explain the new classifier designed for the speech command classification task

* add copyright info

* update copyright info of classifier.py
xiaohou 2021-12-06 17:14:33 +08:00 committed by GitHub
parent 8be4bef405
commit 37f56db5af
7 changed files with 234 additions and 15 deletions


@@ -0,0 +1,52 @@
dataset_conf:
    filter_conf:
        max_length: 2048
        min_length: 0
    resample_conf:
        resample_rate: 16000
    speed_perturb: false
    feature_extraction_conf:
        feature_type: 'mfcc'
        num_ceps: 80
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    feature_dither: 0.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 20
        max_f: 40
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    batch_conf:
        batch_size: 100

model:
    hidden_dim: 64
    preprocessing:
        type: none
    backbone:
        type: mdtc
        num_stack: 4
        stack_size: 4
        kernel_size: 5
        hidden_dim: 64
        causal: False
    classifier:
        type: global
        dropout: 0.5

optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 0.00005

training_config:
    grad_clip: 50
    max_epoch: 100
    log_interval: 10
    criterion: ce
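For context, a minimal sketch of how a config like this is typically consumed on the Python side (assuming PyYAML; init_model() is the factory updated later in this commit, and the exact call site in kws/bin/train.py is not shown here, so the last call is hypothetical):

import yaml

# Sketch only: parse conf/mdtc.yaml and pull out the sections used by training.
with open('conf/mdtc.yaml', 'r') as fin:
    configs = yaml.safe_load(fin)

model_conf = configs['model']               # preprocessing / backbone / classifier sections
training_conf = configs['training_config']  # grad_clip, max_epoch, log_interval, criterion
# kws_model = init_model(model_conf)        # hypothetical call into init_model() shown further down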


@@ -7,7 +7,19 @@
 export CUDA_VISIBLE_DEVICES="0"
 stage=-1
-stop_stage=0
+stop_stage=2
+num_keywords=11
+config=conf/mdtc.yaml
+norm_mean=false
+norm_var=false
+gpu_id=4
+checkpoint=
+dir=exp/mdtc
+num_average=10
+score_checkpoint=$dir/avg_${num_average}.pt
 # your data dir
 download_dir=/mnt/mnt-data-3/jingyong.hou/data
@@ -35,3 +47,35 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
   done
 fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Compute CMVN and Format datasets"
+  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
+    --in_scp data/train/wav.scp \
+    --out_cmvn data/train/global_cmvn
+  for x in train valid test; do
+    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
+    tools/make_list.py data/$x/wav.scp data/$x/text \
+      data/$x/wav.dur data/$x/data.list
+  done
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Start training ..."
+  mkdir -p $dir
+  cmvn_opts=
+  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
+  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
+  python kws/bin/train.py --gpu $gpu_id \
+    --config $config \
+    --train_data data/train/data.list \
+    --cv_data data/valid/data.list \
+    --model_dir $dir \
+    --num_workers 8 \
+    --num_keywords $num_keywords \
+    --min_duration 50 \
+    $cmvn_opts \
+    ${checkpoint:+--checkpoint $checkpoint}
+fi


@@ -221,8 +221,9 @@ def main():
         logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
         executor.train(model, optimizer, train_data_loader, device, writer,
                        training_config)
-        cv_loss = executor.cv(model, cv_data_loader, device, training_config)
-        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
+        cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
+        logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
+                     .format(epoch, cv_loss, cv_acc))

         if args.rank == 0:
             save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
@@ -232,6 +233,7 @@ def main():
                     'cv_loss': cv_loss,
                 })
             writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
+            writer.add_scalar('epoch/cv_acc', cv_acc, epoch)
             writer.add_scalar('epoch/lr', lr, epoch)
         final_epoch = epoch
         scheduler.step(cv_loss)

kws/model/classifier.py (new file, 47 lines)

@@ -0,0 +1,47 @@
# Copyright (c) 2021 Jingyong Hou
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class GlobalClassifier(nn.Module):
    """Add a global average pooling before the classifier"""
    def __init__(self, classifier: nn.Module):
        super(GlobalClassifier, self).__init__()
        self.classifier = classifier

    def forward(self, x: torch.Tensor):
        x = torch.mean(x, dim=1)
        return self.classifier(x)


class LastClassifier(nn.Module):
    """Select last frame to do the classification"""
    def __init__(self, classifier: nn.Module):
        super(LastClassifier, self).__init__()
        self.classifier = classifier

    def forward(self, x: torch.Tensor):
        x = x[:, -1, :]
        return self.classifier(x)


class ElementClassifier(nn.Module):
    """Classify all the frames in an utterance"""
    def __init__(self, classifier: nn.Module):
        super(ElementClassifier, self).__init__()
        self.classifier = classifier

    def forward(self, x: torch.Tensor):
        return self.classifier(x)
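A quick sketch of how the two wrappers used in this commit behave, assuming the (batch, time, hidden) layout the backbone outputs; the base classifier and tensor sizes are illustrative:

import torch
import torch.nn as nn
from kws.model.classifier import GlobalClassifier, LastClassifier

base = nn.Linear(64, 12)            # hidden_dim -> number of classes (illustrative sizes)
x = torch.randn(8, 100, 64)         # 8 utterances, 100 frames, 64 hidden dims

print(GlobalClassifier(base)(x).shape)  # torch.Size([8, 12]), mean over all frames
print(LastClassifier(base)(x).shape)    # torch.Size([8, 12]), last frame only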


@@ -16,8 +16,10 @@ import sys
 from typing import Optional

 import torch
+import torch.nn as nn

 from kws.model.cmvn import GlobalCMVN
+from kws.model.classifier import GlobalClassifier, LastClassifier
 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
 from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
 from kws.model.mdtc import MDTC
@@ -39,6 +41,7 @@ class KWSModel(torch.nn.Module):
         global_cmvn: Optional[torch.nn.Module],
         preprocessing: Optional[torch.nn.Module],
         backbone: torch.nn.Module,
+        classifier: torch.nn.Module
     ):
         super().__init__()
         self.idim = idim
@@ -47,7 +50,7 @@ class KWSModel(torch.nn.Module):
         self.global_cmvn = global_cmvn
         self.preprocessing = preprocessing
         self.backbone = backbone
-        self.classifier = torch.nn.Linear(hdim, odim)
+        self.classifier = classifier

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.global_cmvn is not None:
@@ -55,7 +58,6 @@ class KWSModel(torch.nn.Module):
         x = self.preprocessing(x)
         x, _ = self.backbone(x)
         x = self.classifier(x)
-        x = torch.sigmoid(x)
         return x
@@ -110,17 +112,39 @@ def init_model(configs):
         num_stack = configs['backbone']['num_stack']
         kernel_size = configs['backbone']['kernel_size']
         hidden_dim = configs['backbone']['hidden_dim']
+        causal = configs['backbone']['causal']
         backbone = MDTC(num_stack,
                         stack_size,
                         input_dim,
                         hidden_dim,
                         kernel_size,
-                        causal=True)
+                        causal=causal)
     else:
         print('Unknown body type {}'.format(backbone_type))
         sys.exit(1)

+    if 'classifier' in configs:
+        # For the speech command dataset we use two FC layers as the classifier,
+        # with dropout after the first FC layer to prevent overfitting.
+        classifier_type = configs['classifier']['type']
+        dropout = configs['classifier']['dropout']
+        classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
+                                        nn.ReLU(),
+                                        nn.Dropout(dropout),
+                                        nn.Linear(64, output_dim))
+        if classifier_type == 'global':
+            # 'global' adds a global average pooling before the classifier
+            classifier = GlobalClassifier(classifier_base)
+        elif classifier_type == 'last':
+            # 'last' uses only the last frame for classification, so the
+            # model can run inference in a streaming fashion
+            classifier = LastClassifier(classifier_base)
+        else:
+            print('Unknown classifier type {}'.format(classifier_type))
+            sys.exit(1)
+    else:
+        classifier = torch.nn.Linear(hidden_dim, output_dim)
+
     kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
-                         preprocessing, backbone)
+                         preprocessing, backbone, classifier)
     return kws_model
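For the conf/mdtc.yaml settings above (type: global, dropout: 0.5, hidden_dim: 64), the head that init_model() builds is equivalent to the following sketch; output_dim is assumed to be num_keywords=11 from run.sh:

import torch.nn as nn
from kws.model.classifier import GlobalClassifier

# Two FC layers with dropout in between, wrapped by a global-average-pooling head.
classifier = GlobalClassifier(
    nn.Sequential(nn.Linear(64, 64),   # hidden_dim -> 64
                  nn.ReLU(),
                  nn.Dropout(0.5),
                  nn.Linear(64, 11)))  # 64 -> output_dim (assumed 11)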


@@ -13,11 +13,12 @@
 # limitations under the License.

 import torch
+import torch.nn as nn

 from kws.utils.mask import padding_mask

-def max_polling_loss(logits: torch.Tensor,
+def max_pooling_loss(logits: torch.Tensor,
                      target: torch.Tensor,
                      lengths: torch.Tensor,
                      min_duration: int = 0):
@@ -37,6 +38,7 @@ def max_polling_loss(logits: torch.Tensor,
         (float): loss of current batch
         (float): accuracy of current batch
     '''
+    logits = torch.sigmoid(logits)
     mask = padding_mask(lengths)
     num_utts = logits.size(0)
     num_keywords = logits.size(2)
@@ -80,3 +82,46 @@ def max_polling_loss(logits: torch.Tensor,
     acc = num_correct / num_utts
     # acc = 0.0
     return loss, acc
+
+
+def acc_frame(
+    logits: torch.Tensor,
+    target: torch.Tensor,
+):
+    if logits is None:
+        return 0
+    pred = logits.max(1, keepdim=True)[1]
+    correct = pred.eq(target.long().view_as(pred)).sum().item()
+    return correct * 100.0 / logits.size(0)
+
+
+def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
+    """ Cross Entropy Loss
+    Attributes:
+        logits: (B, D), D is the number of keywords plus 1 (non-keyword)
+        target: (B)
+        lengths: (B)
+        min_duration: min duration of the keyword
+    Returns:
+        (float): loss of current batch
+        (float): accuracy of current batch
+    """
+    cross_entropy = nn.CrossEntropyLoss()
+    loss = cross_entropy(logits, target)
+    acc = acc_frame(logits, target)
+    return loss, acc
+
+
+def criterion(type: str,
+              logits: torch.Tensor,
+              target: torch.Tensor,
+              lengths: torch.Tensor,
+              min_duration: int = 0):
+    if type == 'ce':
+        loss, acc = cross_entropy(logits, target)
+        return loss, acc
+    elif type == 'max_pooling':
+        loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
+        return loss, acc
+    else:
+        exit(1)
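A short sketch of the 'ce' path that conf/mdtc.yaml selects; shapes follow the cross_entropy docstring above (logits (B, D), integer targets (B,)), and the class count of 11 is an assumption taken from num_keywords in run.sh:

import torch
from kws.model.loss import criterion

logits = torch.randn(4, 11)           # 4 utterances, 11 classes (assumed)
target = torch.tensor([0, 3, 7, 10])  # one class index per utterance
loss, acc = criterion('ce', logits, target, lengths=None)  # lengths is unused for 'ce'
print(loss.item(), acc)               # acc is a percentage, per acc_frame()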


@@ -17,7 +17,7 @@ import logging
 import torch
 from torch.nn.utils import clip_grad_norm_

-from kws.model.loss import max_polling_loss
+from kws.model.loss import criterion


 class Executor:
@@ -44,8 +44,8 @@ class Executor:
             if num_utts == 0:
                 continue
             logits = model(feats)
-            loss, acc = max_polling_loss(logits, target, feats_lengths,
-                                         min_duration)
+            loss_type = args.get('criterion', 'max_pooling')
+            loss, acc = criterion(loss_type, logits, target, feats_lengths)
             loss.backward()
             grad_norm = clip_grad_norm_(model.parameters(), clip)
             if torch.isfinite(grad_norm):
@@ -64,6 +64,7 @@ class Executor:
         # in order to avoid division by 0
         num_seen_utts = 1
         total_loss = 0.0
+        total_acc = 0.0
         with torch.no_grad():
             for batch_idx, batch in enumerate(data_loader):
                 key, feats, target, feats_lengths = batch
@@ -73,15 +74,19 @@ class Executor:
                 num_utts = feats_lengths.size(0)
                 if num_utts == 0:
                     continue
-                num_seen_utts += num_utts
                 logits = model(feats)
-                loss, acc = max_polling_loss(logits, target, feats_lengths)
+                loss, acc = criterion(args.get('criterion', 'max_pooling'),
+                                      logits, target, feats_lengths)
                 if torch.isfinite(loss):
                     num_seen_utts += num_utts
                     total_loss += loss.item() * num_utts
+                    total_acc += acc * num_utts
                 if batch_idx % log_interval == 0:
                     logging.debug(
                         'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}'
                         .format(epoch, batch_idx, loss.item(), acc,
                                 total_loss / num_seen_utts))
-        return total_loss / num_seen_utts
+        return total_loss / num_seen_utts, total_acc / num_seen_utts
+
+    def test(self, model, data_loader, device, args):
+        return self.cv(model, data_loader, device, args)
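With this change, cv() returns a (loss, accuracy) pair and test() simply reuses it, so callers unpack both values; a minimal usage sketch, assuming model, loaders, device and training_config are set up as in kws/bin/train.py (test_data_loader is a hypothetical name):

# Epoch-level validation, as called from the training loop in kws/bin/train.py.
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)

# Final evaluation on a held-out set; test() delegates to cv() with the same args.
test_loss, test_acc = executor.test(model, test_data_loader, device, training_config)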