[examples] added training code of speech command dataset

This commit is contained in:
jingyong hou 2021-12-06 12:03:12 +08:00
parent 1240136eba
commit dc596b41a1
7 changed files with 221 additions and 21 deletions

View File

@ -0,0 +1,52 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'mfcc'
num_ceps: 80
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
feature_dither: 0.0
spec_aug: true
spec_aug_conf:
num_t_mask: 2
num_f_mask: 2
max_t: 20
max_f: 40
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 100
model:
hidden_dim: 64
preprocessing:
type: none
backbone:
type: mdtc
num_stack: 4
stack_size: 4
kernel_size: 5
hidden_dim: 64
causal: False
classifier:
type: global
dropout: 0.5
optim: adam
optim_conf:
lr: 0.0002
weight_decay: 0.00005
training_config:
grad_clip: 50
max_epoch: 100
log_interval: 10
criterion: ce

View File

@ -6,8 +6,20 @@
export CUDA_VISIBLE_DEVICES="0"
stage=-1
stop_stage=0
stage=1
stop_stage=4
num_keywords=11
config=conf/mdtc.yaml
norm_mean=false
norm_var=false
gpu_id=4
checkpoint=
dir=exp/mdtc
num_average=10
score_checkpoint=$dir/avg_${num_average}.pt
# your data dir
download_dir=/mnt/mnt-data-3/jingyong.hou/data
@ -35,3 +47,36 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Compute CMVN and Format datasets"
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
--in_scp data/train/wav.scp \
--out_cmvn data/train/global_cmvn
for x in train valid test; do
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
tools/make_list.py data/$x/wav.scp data/$x/text \
data/$x/wav.dur data/$x/data.list
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Start training ..."
mkdir -p $dir
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
python kws/bin/train.py --gpu $gpu_id \
--config $config \
--train_data data/train/data.list \
--cv_data data/valid/data.list \
--model_dir $dir \
--num_workers 8 \
--num_keywords $num_keywords \
--min_duration 50 \
$cmvn_opts \
${checkpoint:+--checkpoint $checkpoint}
fi

View File

@ -164,9 +164,9 @@ def main():
# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
# if args.rank == 0:
# script_model = torch.jit.script(model)
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
if args.rank == 0:
script_model = torch.jit.script(model)
script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor()
# If specify checkpoint, load some info from checkpoint
if args.checkpoint is not None:
@ -196,8 +196,7 @@ def main():
model = model.to(device)
optimizer = optim.Adam(model.parameters(),
lr=configs['optim_conf']['lr'],
weight_decay=configs['optim_conf'].get('weight_decay', 0))
lr=configs['optim_conf']['lr'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
@ -223,8 +222,9 @@ def main():
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
executor.train(model, optimizer, train_data_loader, device, writer,
training_config)
cv_loss = executor.cv(model, cv_data_loader, device, training_config)
logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
.format(epoch, cv_loss, cv_acc))
if args.rank == 0:
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
@ -234,6 +234,7 @@ def main():
'cv_loss': cv_loss,
})
writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
writer.add_scalar('epoch/cv_acc', cv_acc, epoch)
writer.add_scalar('epoch/lr', lr, epoch)
final_epoch = epoch
scheduler.step(cv_loss)

33
kws/model/classifier.py Normal file
View File

@ -0,0 +1,33 @@
import torch
import torch.nn as nn
class GlobalClassifier(nn.Module):
"""Add a global average pooling before the classifier"""
def __init__(self, classifier: nn.Module):
super(GlobalClassifier, self).__init__()
self.classifier = classifier
def forward(self, x: torch.Tensor):
x = torch.mean(x, dim=1)
return self.classifier(x)
class LastClassifier(nn.Module):
"""Select last frame to do the classification"""
def __init__(self, classifier: nn.Module):
super(LastClassifier, self).__init__()
self.classifier = classifier
def forward(self, x: torch.Tensor):
x = x[:, -1, :]
return self.classifier(x)
class ElementClassifier(nn.Module):
"""Classify all the frames in an utterance"""
def __init__(self, classifier: nn.Module):
super(ElementClassifier, self).__init__()
self.classifier = classifier
def forward(self, x: torch.Tensor):
return self.classifier(x)

View File

@ -16,8 +16,10 @@ import sys
from typing import Optional
import torch
import torch.nn as nn
from kws.model.cmvn import GlobalCMVN
from kws.model.classifier import GlobalClassifier, LastClassifier
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
from kws.model.mdtc import MDTC
@ -39,6 +41,7 @@ class KWSModel(torch.nn.Module):
global_cmvn: Optional[torch.nn.Module],
preprocessing: Optional[torch.nn.Module],
backbone: torch.nn.Module,
classifier: torch.nn.Module
):
super().__init__()
self.idim = idim
@ -47,7 +50,7 @@ class KWSModel(torch.nn.Module):
self.global_cmvn = global_cmvn
self.preprocessing = preprocessing
self.backbone = backbone
self.classifier = torch.nn.Linear(hdim, odim)
self.classifier = classifier
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.global_cmvn is not None:
@ -55,7 +58,6 @@ class KWSModel(torch.nn.Module):
x = self.preprocessing(x)
x, _ = self.backbone(x)
x = self.classifier(x)
x = torch.sigmoid(x)
return x
@ -110,17 +112,34 @@ def init_model(configs):
num_stack = configs['backbone']['num_stack']
kernel_size = configs['backbone']['kernel_size']
hidden_dim = configs['backbone']['hidden_dim']
causal = configs['backbone']['causal']
backbone = MDTC(num_stack,
stack_size,
input_dim,
hidden_dim,
kernel_size,
causal=True)
causal=causal)
else:
print('Unknown body type {}'.format(backbone_type))
sys.exit(1)
classifier_type = configs['classifier']['type']
dropout = configs['classifier']['dropout']
classifier_base = nn.Sequential(
nn.Linear(hidden_dim, 64),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(64, output_dim),
)
if classifier_type == 'linear':
classifier = classifier_base
elif classifier_type == 'global':
classifier = GlobalClassifier(classifier_base)
elif classifier_type == 'last':
classifier = LastClassifier(classifier_base)
else:
print('Unknown classifier type {}'.format(classifier_type))
sys.exit(1)
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone)
preprocessing, backbone, classifier)
return kws_model

View File

@ -13,11 +13,12 @@
# limitations under the License.
import torch
import torch.nn as nn
from kws.utils.mask import padding_mask
def max_polling_loss(logits: torch.Tensor,
def max_pooling_loss(logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
@ -37,6 +38,7 @@ def max_polling_loss(logits: torch.Tensor,
(float): loss of current batch
(float): accuracy of current batch
'''
logits = torch.sigmoid(logits)
mask = padding_mask(lengths)
num_utts = logits.size(0)
num_keywords = logits.size(2)
@ -80,3 +82,46 @@ def max_polling_loss(logits: torch.Tensor,
acc = num_correct / num_utts
# acc = 0.0
return loss, acc
def acc_frame(
logits: torch.Tensor,
target: torch.Tensor,
):
if logits is None:
return 0
pred = logits.max(1, keepdim=True)[1]
correct = pred.eq(target.long().view_as(pred)).sum().item()
return correct * 100.0 / logits.size(0)
def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
""" Cross Entropy Loss
Attributes:
logits: (B, D), D is the number of keywords plus 1 (non-keyword)
target: (B)
lengths: (B)
min_duration: min duration of the keyword
Returns:
(float): loss of current batch
(float): accuracy of current batch
"""
cross_entropy = nn.CrossEntropyLoss()
loss = cross_entropy(logits, target)
acc = acc_frame(logits, target)
return loss, acc
def criterion(type: str,
logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
if type == 'ce':
loss, acc = cross_entropy(logits, target)
return loss, acc
elif type == 'max_pooling':
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
return loss, acc
else:
exit(1)

View File

@ -17,7 +17,7 @@ import logging
import torch
from torch.nn.utils import clip_grad_norm_
from kws.model.loss import max_polling_loss
from kws.model.loss import criterion
class Executor:
@ -44,8 +44,8 @@ class Executor:
if num_utts == 0:
continue
logits = model(feats)
loss, acc = max_polling_loss(logits, target, feats_lengths,
min_duration)
loss_type = args.get('criterion', 'max_pooling')
loss, acc = criterion(loss_type, logits, target, feats_lengths)
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip)
if torch.isfinite(grad_norm):
@ -64,6 +64,7 @@ class Executor:
# in order to avoid division by 0
num_seen_utts = 1
total_loss = 0.0
total_acc = 0.0
with torch.no_grad():
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths = batch
@ -73,15 +74,19 @@ class Executor:
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
num_seen_utts += num_utts
logits = model(feats)
loss, acc = max_polling_loss(logits, target, feats_lengths)
loss, acc = criterion(args.get('criterion', 'max_pooling'),
logits, target, feats_lengths)
if torch.isfinite(loss):
num_seen_utts += num_utts
total_loss += loss.item() * num_utts
total_acc += acc * num_utts
if batch_idx % log_interval == 0:
logging.debug(
'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}'
.format(epoch, batch_idx, loss.item(), acc,
total_loss / num_seen_utts))
return total_loss / num_seen_utts
return total_loss / num_seen_utts, total_acc / num_seen_utts
def test(self, model, data_loader, device, args):
return self.cv(model, data_loader, device, args)