From 37f56db5af921cc81430ea2e2ca676f716d8b6e5 Mon Sep 17 00:00:00 2001
From: xiaohou
Date: Mon, 6 Dec 2021 17:14:33 +0800
Subject: [PATCH] [examples] add speechcommand train (#30)

* [example] added code for training speech command dataset
* update kws_model.py
* update kws_model.py
* format
* format
* add more comments to explain the new classifier designed for the
  speech command classification task
* add copyright info
* update copyright info of classifier.py
---
 examples/speechcommand_v1/s0/conf/mdtc.yaml | 52 +++++++++++++++++++++
 examples/speechcommand_v1/s0/run.sh         | 46 +++++++++++++++++-
 kws/bin/train.py                            |  6 ++-
 kws/model/classifier.py                     | 47 +++++++++++++++++++
 kws/model/kws_model.py                      | 34 ++++++++++++--
 kws/model/loss.py                           | 47 ++++++++++++++++++-
 kws/utils/executor.py                       | 17 ++++---
 7 files changed, 234 insertions(+), 15 deletions(-)
 create mode 100644 examples/speechcommand_v1/s0/conf/mdtc.yaml
 create mode 100644 kws/model/classifier.py

diff --git a/examples/speechcommand_v1/s0/conf/mdtc.yaml b/examples/speechcommand_v1/s0/conf/mdtc.yaml
new file mode 100644
index 0000000..80a9c6d
--- /dev/null
+++ b/examples/speechcommand_v1/s0/conf/mdtc.yaml
@@ -0,0 +1,52 @@
+dataset_conf:
+    filter_conf:
+        max_length: 2048
+        min_length: 0
+    resample_conf:
+        resample_rate: 16000
+    speed_perturb: false
+    feature_extraction_conf:
+        feature_type: 'mfcc'
+        num_ceps: 80
+        num_mel_bins: 80
+        frame_shift: 10
+        frame_length: 25
+        dither: 1.0
+    feature_dither: 0.0
+    spec_aug: true
+    spec_aug_conf:
+        num_t_mask: 2
+        num_f_mask: 2
+        max_t: 20
+        max_f: 40
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 1500
+    batch_conf:
+        batch_size: 100
+
+model:
+    hidden_dim: 64
+    preprocessing:
+        type: none
+    backbone:
+        type: mdtc
+        num_stack: 4
+        stack_size: 4
+        kernel_size: 5
+        hidden_dim: 64
+        causal: False
+    classifier:
+        type: global
+        dropout: 0.5
+
+optim: adam
+optim_conf:
+    lr: 0.001
+    weight_decay: 0.00005
+
+training_config:
+    grad_clip: 50
+    max_epoch: 100
+    log_interval: 10
+    criterion: ce
diff --git a/examples/speechcommand_v1/s0/run.sh b/examples/speechcommand_v1/s0/run.sh
index 0149e54..44bf258 100755
--- a/examples/speechcommand_v1/s0/run.sh
+++ b/examples/speechcommand_v1/s0/run.sh
@@ -7,7 +7,19 @@ export CUDA_VISIBLE_DEVICES="0"
 
 stage=-1
-stop_stage=0
+stop_stage=2
+num_keywords=11
+
+config=conf/mdtc.yaml
+norm_mean=false
+norm_var=false
+gpu_id=4
+
+checkpoint=
+dir=exp/mdtc
+
+num_average=10
+score_checkpoint=$dir/avg_${num_average}.pt
 
 # your data dir
 download_dir=/mnt/mnt-data-3/jingyong.hou/data
@@ -35,3 +47,35 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
   done
 fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Compute CMVN and format datasets"
+  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
+    --in_scp data/train/wav.scp \
+    --out_cmvn data/train/global_cmvn
+
+  for x in train valid test; do
+    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
+    tools/make_list.py data/$x/wav.scp data/$x/text \
+      data/$x/wav.dur data/$x/data.list
+  done
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Start training ..."
+  mkdir -p $dir
+  cmvn_opts=
+  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
+  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
+  python kws/bin/train.py --gpu $gpu_id \
+    --config $config \
+    --train_data data/train/data.list \
+    --cv_data data/valid/data.list \
+    --model_dir $dir \
+    --num_workers 8 \
+    --num_keywords $num_keywords \
+    --min_duration 50 \
+    $cmvn_opts \
+    ${checkpoint:+--checkpoint $checkpoint}
+fi
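How the new config is consumed: kws/bin/train.py loads the YAML, the 'model' section drives the model builder in kws/model/kws_model.py, and training_config.criterion selects the loss dispatched in kws/utils/executor.py. A minimal standalone sketch (illustrative only, not part of the patch; the path assumes you run from the repo root):

    import yaml

    with open('examples/speechcommand_v1/s0/conf/mdtc.yaml') as fin:
        configs = yaml.safe_load(fin)

    # 'backbone' and 'classifier' are read by the model builder below
    print(configs['model']['backbone']['type'])     # mdtc
    print(configs['model']['classifier']['type'])   # global
    # 'criterion' selects the loss used by the Executor ('ce' = cross entropy)
    print(configs['training_config']['criterion'])  # ce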
diff --git a/kws/bin/train.py b/kws/bin/train.py
index f6fe33d..8053b49 100644
--- a/kws/bin/train.py
+++ b/kws/bin/train.py
@@ -221,8 +221,9 @@ def main():
         logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
         executor.train(model, optimizer, train_data_loader, device, writer,
                        training_config)
-        cv_loss = executor.cv(model, cv_data_loader, device, training_config)
-        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
+        cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
+        logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
+                     .format(epoch, cv_loss, cv_acc))
 
         if args.rank == 0:
             save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
@@ -232,6 +233,7 @@ def main():
                 'cv_loss': cv_loss,
             })
             writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
+            writer.add_scalar('epoch/cv_acc', cv_acc, epoch)
             writer.add_scalar('epoch/lr', lr, epoch)
         final_epoch = epoch
         scheduler.step(cv_loss)
diff --git a/kws/model/classifier.py b/kws/model/classifier.py
new file mode 100644
index 0000000..42af97c
--- /dev/null
+++ b/kws/model/classifier.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2021 Jingyong Hou
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+
+class GlobalClassifier(nn.Module):
+    """Apply global average pooling over frames before the classifier."""
+    def __init__(self, classifier: nn.Module):
+        super(GlobalClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        x = torch.mean(x, dim=1)
+        return self.classifier(x)
+
+
+class LastClassifier(nn.Module):
+    """Select the last frame to do the classification."""
+    def __init__(self, classifier: nn.Module):
+        super(LastClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        x = x[:, -1, :]
+        return self.classifier(x)
+
+
+class ElementClassifier(nn.Module):
+    """Classify every frame in an utterance independently."""
+    def __init__(self, classifier: nn.Module):
+        super(ElementClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        return self.classifier(x)
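To compare the two wrappers used by this example, a quick shape check (a sketch only; the 64-dim hidden size, 11 classes, and the two-layer head mirror the config above and the head built in kws_model.py below):

    import torch
    import torch.nn as nn
    from kws.model.classifier import GlobalClassifier, LastClassifier

    head = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Dropout(0.5),
                         nn.Linear(64, 11))
    x = torch.randn(8, 100, 64)             # (batch, frames, hidden_dim)
    print(GlobalClassifier(head)(x).shape)  # torch.Size([8, 11]), mean over frames
    print(LastClassifier(head)(x).shape)    # torch.Size([8, 11]), last frame only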
diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py
index dd6142a..effc889 100644
--- a/kws/model/kws_model.py
+++ b/kws/model/kws_model.py
@@ -16,8 +16,10 @@ import sys
 from typing import Optional
 
 import torch
+import torch.nn as nn
 
 from kws.model.cmvn import GlobalCMVN
+from kws.model.classifier import GlobalClassifier, LastClassifier
 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
 from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
 from kws.model.mdtc import MDTC
@@ -39,6 +41,7 @@ class KWSModel(torch.nn.Module):
         global_cmvn: Optional[torch.nn.Module],
         preprocessing: Optional[torch.nn.Module],
         backbone: torch.nn.Module,
+        classifier: torch.nn.Module,
     ):
         super().__init__()
         self.idim = idim
@@ -47,7 +50,7 @@ class KWSModel(torch.nn.Module):
         self.global_cmvn = global_cmvn
         self.preprocessing = preprocessing
         self.backbone = backbone
-        self.classifier = torch.nn.Linear(hdim, odim)
+        self.classifier = classifier
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.global_cmvn is not None:
@@ -55,7 +58,6 @@ class KWSModel(torch.nn.Module):
         x = self.preprocessing(x)
         x, _ = self.backbone(x)
         x = self.classifier(x)
-        x = torch.sigmoid(x)
         return x
@@ -110,17 +112,39 @@ def init_model(configs):
         num_stack = configs['backbone']['num_stack']
         kernel_size = configs['backbone']['kernel_size']
         hidden_dim = configs['backbone']['hidden_dim']
-
+        causal = configs['backbone']['causal']
         backbone = MDTC(num_stack,
                         stack_size,
                         input_dim,
                         hidden_dim,
                         kernel_size,
-                        causal=True)
+                        causal=causal)
     else:
         print('Unknown body type {}'.format(backbone_type))
         sys.exit(1)
 
+    if 'classifier' in configs:
+        # For the speech command dataset we use two FC layers as the
+        # classifier; dropout after the first FC layer helps prevent
+        # overfitting.
+        classifier_type = configs['classifier']['type']
+        dropout = configs['classifier']['dropout']
+
+        classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
+                                        nn.ReLU(),
+                                        nn.Dropout(dropout),
+                                        nn.Linear(64, output_dim))
+        if classifier_type == 'global':
+            # 'global' applies global average pooling over frames before
+            # the classifier.
+            classifier = GlobalClassifier(classifier_base)
+        elif classifier_type == 'last':
+            # 'last' classifies only the last frame, so the model can be
+            # used for streaming inference.
+            classifier = LastClassifier(classifier_base)
+        else:
+            print('Unknown classifier type {}'.format(classifier_type))
+            sys.exit(1)
+    else:
+        classifier = torch.nn.Linear(hidden_dim, output_dim)
 
     kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
-                         preprocessing, backbone)
+                         preprocessing, backbone, classifier)
     return kws_model
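Note the removal of torch.sigmoid from forward(): the model now returns raw logits, and the activation moves into the loss (max_pooling_loss applies sigmoid itself below, while nn.CrossEntropyLoss expects unnormalized scores). At inference time the caller applies the activation, roughly as follows (model and feats are placeholders):

    import torch

    logits = model(feats)                     # raw scores from KWSModel.forward()
    kws_post = torch.sigmoid(logits)          # keyword posteriors (max_pooling setup)
    cls_post = torch.softmax(logits, dim=-1)  # class posteriors (ce + 'global' head)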
diff --git a/kws/model/loss.py b/kws/model/loss.py
index d73722c..80a7fec 100644
--- a/kws/model/loss.py
+++ b/kws/model/loss.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import torch
+import torch.nn as nn
 
 from kws.utils.mask import padding_mask
 
 
-def max_polling_loss(logits: torch.Tensor,
+def max_pooling_loss(logits: torch.Tensor,
                      target: torch.Tensor,
                      lengths: torch.Tensor,
                      min_duration: int = 0):
@@ -37,6 +38,7 @@ def max_pooling_loss(logits: torch.Tensor,
         (float): loss of current batch
         (float): accuracy of current batch
     '''
+    logits = torch.sigmoid(logits)
     mask = padding_mask(lengths)
     num_utts = logits.size(0)
     num_keywords = logits.size(2)
@@ -80,3 +82,46 @@ def max_pooling_loss(logits: torch.Tensor,
     acc = num_correct / num_utts
     # acc = 0.0
     return loss, acc
+
+
+def acc_frame(
+    logits: torch.Tensor,
+    target: torch.Tensor,
+):
+    if logits is None:
+        return 0
+    pred = logits.max(1, keepdim=True)[1]
+    correct = pred.eq(target.long().view_as(pred)).sum().item()
+    return correct * 100.0 / logits.size(0)
+
+
+def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
+    """ Cross entropy loss
+    Args:
+        logits: (B, D), D is the number of keywords plus 1 (non-keyword)
+        target: (B)
+    Returns:
+        (float): loss of current batch
+        (float): accuracy of current batch
+    """
+    ce = nn.CrossEntropyLoss()
+    loss = ce(logits, target)
+    acc = acc_frame(logits, target)
+    return loss, acc
+
+
+def criterion(type: str,
+              logits: torch.Tensor,
+              target: torch.Tensor,
+              lengths: torch.Tensor,
+              min_duration: int = 0):
+    if type == 'ce':
+        loss, acc = cross_entropy(logits, target)
+        return loss, acc
+    elif type == 'max_pooling':
+        loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
+        return loss, acc
+    else:
+        raise ValueError('Unknown criterion type {}'.format(type))
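A usage sketch for the new criterion() dispatcher with dummy tensors (shapes follow the docstrings above; 11 classes matches num_keywords in this example):

    import torch
    from kws.model.loss import criterion

    # 'ce' takes pooled (B, D) logits and (B) integer class targets; lengths
    # is accepted for a uniform signature but unused on this branch.
    logits = torch.randn(4, 11)
    target = torch.randint(0, 11, (4,))
    lengths = torch.tensor([98, 87, 76, 65])
    loss, acc = criterion('ce', logits, target, lengths)

    # 'max_pooling' instead takes per-frame (B, T, num_keywords) logits and
    # uses lengths/min_duration to mask padding before pooling over time.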
diff --git a/kws/utils/executor.py b/kws/utils/executor.py
index 2f8fe47..51fc529 100644
--- a/kws/utils/executor.py
+++ b/kws/utils/executor.py
@@ -17,7 +17,7 @@ import logging
 
 import torch
 from torch.nn.utils import clip_grad_norm_
-from kws.model.loss import max_polling_loss
+from kws.model.loss import criterion
 
 
 class Executor:
@@ -44,8 +44,9 @@ class Executor:
             if num_utts == 0:
                 continue
             logits = model(feats)
-            loss, acc = max_polling_loss(logits, target, feats_lengths,
-                                         min_duration)
+            loss_type = args.get('criterion', 'max_pooling')
+            loss, acc = criterion(loss_type, logits, target, feats_lengths,
+                                  min_duration)
             loss.backward()
             grad_norm = clip_grad_norm_(model.parameters(), clip)
             if torch.isfinite(grad_norm):
@@ -64,6 +64,7 @@ class Executor:
         # in order to avoid division by 0
         num_seen_utts = 1
         total_loss = 0.0
+        total_acc = 0.0
         with torch.no_grad():
             for batch_idx, batch in enumerate(data_loader):
                 key, feats, target, feats_lengths = batch
@@ -73,15 +74,19 @@ class Executor:
                 num_utts = feats_lengths.size(0)
                 if num_utts == 0:
                     continue
-                num_seen_utts += num_utts
                 logits = model(feats)
-                loss, acc = max_polling_loss(logits, target, feats_lengths)
+                loss, acc = criterion(args.get('criterion', 'max_pooling'),
+                                      logits, target, feats_lengths)
                 if torch.isfinite(loss):
                     num_seen_utts += num_utts
                     total_loss += loss.item() * num_utts
+                    total_acc += acc * num_utts
                 if batch_idx % log_interval == 0:
                     logging.debug(
                         'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}'
                         .format(epoch, batch_idx, loss.item(), acc,
                                 total_loss / num_seen_utts))
-        return total_loss / num_seen_utts
+        return total_loss / num_seen_utts, total_acc / num_seen_utts
+
+    def test(self, model, data_loader, device, args):
+        return self.cv(model, data_loader, device, args)
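Since test() simply reuses the cv() loop, evaluation code gets the same (loss, acc) pair that train.py now unpacks. A usage sketch (executor, model, test_data_loader, device, and training_config are placeholders for objects built the way kws/bin/train.py builds them):

    test_loss, test_acc = executor.test(model, test_data_loader, device,
                                        training_config)
    print('TEST loss {:.6f} acc {:.2f}'.format(test_loss, test_acc))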