From c48c959807e7e80cdd514be9bd019b16e3b816eb Mon Sep 17 00:00:00 2001
From: xiaohou
Date: Fri, 3 Dec 2021 21:07:42 +0800
Subject: [PATCH] [recipe] support speech command dataset (#21)

* [recipe] support speech command dataset

* format

* format

* format

* update run.sh
---
 examples/speechcommand_v1/s0/conf/mdtc.yaml   |  52 +++++++++
 examples/speechcommand_v1/s0/kws              |   1 +
 .../s0/local/data_download.sh                 |  30 +++++
 .../s0/local/prepare_speech_command.py        |  35 ++++++
 .../s0/local/split_dataset.py                 |  42 +++++++
 examples/speechcommand_v1/s0/path.sh          |   1 +
 examples/speechcommand_v1/s0/run.sh           | 108 ++++++++++++++++++
 examples/speechcommand_v1/s0/tools            |   1 +
 kws/bin/test.py                               | 102 +++++++++++++++++
 kws/bin/train.py                              |  15 +--
 kws/model/classifier.py                       |  33 ++++++
 kws/model/kws_model.py                        |  29 ++++-
 kws/model/loss.py                             |  47 +++++++-
 kws/utils/executor.py                         |  17 ++-
 14 files changed, 494 insertions(+), 19 deletions(-)
 create mode 100644 examples/speechcommand_v1/s0/conf/mdtc.yaml
 create mode 120000 examples/speechcommand_v1/s0/kws
 create mode 100755 examples/speechcommand_v1/s0/local/data_download.sh
 create mode 100644 examples/speechcommand_v1/s0/local/prepare_speech_command.py
 create mode 100755 examples/speechcommand_v1/s0/local/split_dataset.py
 create mode 120000 examples/speechcommand_v1/s0/path.sh
 create mode 100644 examples/speechcommand_v1/s0/run.sh
 create mode 120000 examples/speechcommand_v1/s0/tools
 create mode 100644 kws/bin/test.py
 create mode 100644 kws/model/classifier.py

diff --git a/examples/speechcommand_v1/s0/conf/mdtc.yaml b/examples/speechcommand_v1/s0/conf/mdtc.yaml
new file mode 100644
index 0000000..6b5646f
--- /dev/null
+++ b/examples/speechcommand_v1/s0/conf/mdtc.yaml
@@ -0,0 +1,52 @@
+dataset_conf:
+    filter_conf:
+        max_length: 2048
+        min_length: 0
+    resample_conf:
+        resample_rate: 16000
+    speed_perturb: false
+    feature_extraction_conf:
+        feature_type: 'mfcc'
+        num_ceps: 80
+        num_mel_bins: 80
+        frame_shift: 10
+        frame_length: 25
+        dither: 1.0
+    feature_dither: 0.0
+    spec_aug: true
+    spec_aug_conf:
+        num_t_mask: 2
+        num_f_mask: 2
+        max_t: 20
+        max_f: 40
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 1500
+    batch_conf:
+        batch_size: 100
+
+model:
+    hidden_dim: 64
+    preprocessing:
+        type: none
+    backbone:
+        type: mdtc
+        num_stack: 4
+        stack_size: 4
+        kernel_size: 5
+        hidden_dim: 64
+        causal: False
+    classifier:
+        type: global
+        dropout: 0.5
+
+optim: adam
+optim_conf:
+    lr: 0.0002
+    weight_decay: 0.00005
+
+training_config:
+    grad_clip: 50
+    max_epoch: 100
+    log_interval: 10
+    criterion: ce
diff --git a/examples/speechcommand_v1/s0/kws b/examples/speechcommand_v1/s0/kws
new file mode 120000
index 0000000..7a3e8e1
--- /dev/null
+++ b/examples/speechcommand_v1/s0/kws
@@ -0,0 +1 @@
+../../../kws
\ No newline at end of file
diff --git a/examples/speechcommand_v1/s0/local/data_download.sh b/examples/speechcommand_v1/s0/local/data_download.sh
new file mode 100755
index 0000000..4b1000a
--- /dev/null
+++ b/examples/speechcommand_v1/s0/local/data_download.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+# Copyright 2021 Jingyong Hou (houjingyong@gmail.com)
+[ -f ./path.sh ] && . ./path.sh
+
+dl_dir=./data/local
+
+. tools/parse_options.sh || exit 1;
+data_dir=$dl_dir
+file_name=speech_commands_v0.01.tar.gz
+speech_command_dir=$data_dir/speech_commands_v1
+audio_dir=$data_dir/speech_commands_v1/audio
+url=http://download.tensorflow.org/data/$file_name
+mkdir -p $data_dir
+if [ ! -f $data_dir/$file_name ]; then
+  echo "downloading $url..."
+  wget -O $data_dir/$file_name $url
+else
+  echo "$file_name exist in $data_dir, skip download it"
+fi
+
+if [ ! -f $speech_command_dir/.extracted ]; then
+  mkdir -p $audio_dir
+  tar -xzvf $data_dir/$file_name -C $audio_dir
+  touch $speech_command_dir/.extracted
+else
+  echo "$speech_command_dir/.extracted exist in $speech_command_dir, skip extraction"
+fi
+
+exit 0
diff --git a/examples/speechcommand_v1/s0/local/prepare_speech_command.py b/examples/speechcommand_v1/s0/local/prepare_speech_command.py
new file mode 100644
index 0000000..a9d6982
--- /dev/null
+++ b/examples/speechcommand_v1/s0/local/prepare_speech_command.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+import os
+import argparse
+
+CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
+    ', ')
+CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='prepare kaldi format file for google speech command')
+    parser.add_argument(
+        '--wav_list',
+        required=True,
+        help='full path of a wav file in google speech command dataset')
+    parser.add_argument('--data_dir',
+                        required=True,
+                        help='folder to write kaldi format files')
+    args = parser.parse_args()
+
+    data_dir = args.data_dir
+    with open(os.path.join(data_dir, 'wav.scp'), 'w') as f_wav_scp, \
+            open(os.path.join(data_dir, 'text'), 'w') as f_text, \
+            open(args.wav_list) as f:
+        for line in f:
+            file_dir = line.strip()
+            if not file_dir:
+                continue
+            keyword, file_name = file_dir.split('/')[-2:]
+            file_name_new = file_name.split('.')[0]
+            wav_id = '_'.join([keyword, file_name_new])
+            f_wav_scp.write(wav_id + ' ' + file_dir + '\n')
+            # Any word outside the target commands is labeled as unknown.
+            label = CLASS_TO_IDX.get(keyword, CLASS_TO_IDX['unknown'])
+            f_text.write(wav_id + ' ' + label + '\n')
diff --git a/examples/speechcommand_v1/s0/local/split_dataset.py b/examples/speechcommand_v1/s0/local/split_dataset.py
new file mode 100755
index 0000000..b7b5326
--- /dev/null
+++ b/examples/speechcommand_v1/s0/local/split_dataset.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+'''Splits the google speech commands into train, validation and test set'''
+
+import os
+import shutil
+import argparse
+
+
+def move_files(src_folder, to_folder, list_file):
+    with open(list_file) as f:
+        for line in f.readlines():
+            line = line.rstrip()
+            dirname = os.path.dirname(line)
+            dest = os.path.join(to_folder, dirname)
+            if not os.path.exists(dest):
+                os.mkdir(dest)
+            shutil.move(os.path.join(src_folder, line), dest)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Split google command dataset.')
+    parser.add_argument(
+        'root',
+        type=str,
+        help='the path to the root folder of the google commands dataset')
+    args = parser.parse_args()
+
+    audio_folder = os.path.join(args.root, 'audio')
+    validation_path = os.path.join(audio_folder, 'validation_list.txt')
+    test_path = os.path.join(audio_folder, 'testing_list.txt')
+
+    valid_folder = os.path.join(args.root, 'valid')
+    test_folder = os.path.join(args.root, 'test')
+    train_folder = os.path.join(args.root, 'train')
+
+    os.mkdir(valid_folder)
+    os.mkdir(test_folder)
+
+    move_files(audio_folder, test_folder, test_path)
+    move_files(audio_folder, valid_folder, validation_path)
+    os.rename(audio_folder, train_folder)
diff --git a/examples/speechcommand_v1/s0/path.sh b/examples/speechcommand_v1/s0/path.sh
new file mode 120000
index 0000000..5587d29
--- /dev/null
+++ b/examples/speechcommand_v1/s0/path.sh
@@ -0,0 +1 @@
+../../hi_xiaowen/s0/path.sh
\ No newline at end of file
diff --git a/examples/speechcommand_v1/s0/run.sh b/examples/speechcommand_v1/s0/run.sh
new file mode 100644
index 0000000..413e824
--- /dev/null
+++ b/examples/speechcommand_v1/s0/run.sh
@@ -0,0 +1,108 @@
+#!/bin/bash
+# Copyright 2021  Binbin Zhang
+#                 Jingyong Hou
+
+. ./path.sh
+
+export CUDA_VISIBLE_DEVICES="0"
+
+stage=-1
+stop_stage=4
+num_keywords=11
+
+config=conf/mdtc.yaml
+norm_mean=false
+norm_var=false
+gpu_id=4
+
+checkpoint=
+dir=exp/mdtc_debug
+
+num_average=10
+score_checkpoint=$dir/avg_${num_average}.pt
+
+# your data dir
+download_dir=/mnt/mnt-data-3/jingyong.hou/data
+speech_command_dir=$download_dir/speech_commands_v1
+. tools/parse_options.sh || exit 1;
+
+set -euo pipefail
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+  echo "Download and extract all datasets"
+  local/data_download.sh --dl_dir $download_dir
+  python local/split_dataset.py $download_dir/speech_commands_v1
+fi
+
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+  echo "Start preparing Kaldi format files"
+  for x in train test valid;
+  do
+    data=data/$x
+    mkdir -p $data
+    # make wav.scp and text files
+    find $speech_command_dir/$x -name "*.wav" | grep -v "_background_noise_" > $data/wav.list
+    python local/prepare_speech_command.py --wav_list=$data/wav.list --data_dir=$data
+  done
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Compute CMVN and Format datasets"
+  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
+    --in_scp data/train/wav.scp \
+    --out_cmvn data/train/global_cmvn
+
+  for x in train valid test; do
+    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
+    tools/make_list.py data/$x/wav.scp data/$x/text \
+      data/$x/wav.dur data/$x/data.list
+  done
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Start training ..."
+  mkdir -p $dir
+  cmvn_opts=
+  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
+  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
+  python kws/bin/train.py --gpu $gpu_id \
+    --config $config \
+    --train_data data/train/data.list \
+    --cv_data data/valid/data.list \
+    --model_dir $dir \
+    --num_workers 8 \
+    --num_keywords $num_keywords \
+    --min_duration 50 \
+    $cmvn_opts \
+    ${checkpoint:+--checkpoint $checkpoint}
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  # Do model average
+  python kws/bin/average_model.py \
+    --dst_model $score_checkpoint \
+    --src_path $dir \
+    --num ${num_average} \
+    --val_best
+
+  # Testing
+  result_dir=$dir/test_$(basename $score_checkpoint)
+  mkdir -p $result_dir
+  python kws/bin/test.py --gpu $gpu_id \
+    --config $dir/config.yaml \
+    --test_data data/test/data.list \
+    --batch_size 256 \
+    --num_workers 8 \
+    --checkpoint $score_checkpoint
+fi
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  python kws/bin/export_jit.py --config $dir/config.yaml \
+    --checkpoint $score_checkpoint \
+    --output_file $dir/final.zip \
+    --output_quant_file $dir/final.quant.zip
+fi
diff --git a/examples/speechcommand_v1/s0/tools b/examples/speechcommand_v1/s0/tools
new file mode 120000
index 0000000..c92f417
--- /dev/null
+++ b/examples/speechcommand_v1/s0/tools
@@ -0,0 +1 @@
+../../../tools
\ No newline at end of file
diff --git a/kws/bin/test.py b/kws/bin/test.py
new file mode 100644
index 0000000..a6cde27
--- /dev/null
+++ b/kws/bin/test.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import copy
+import logging
+import os
+
+import torch
+import yaml
+from torch.utils.data import DataLoader
+
+from kws.dataset.dataset import Dataset
+from kws.model.kws_model import init_model
+from kws.utils.checkpoint import load_checkpoint
+from kws.utils.executor import Executor
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='recognize with your model')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument('--test_data', required=True, help='test data file')
+    parser.add_argument('--gpu',
+                        type=int,
+                        default=-1,
+                        help='gpu id for this rank, -1 for cpu')
+    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
+    parser.add_argument('--batch_size',
+                        default=16,
+                        type=int,
+                        help='batch size for inference')
+    parser.add_argument('--num_workers',
+                        default=0,
+                        type=int,
+                        help='num of subprocess workers for reading')
+    parser.add_argument('--pin_memory',
+                        action='store_true',
+                        default=False,
+                        help='Use pinned memory buffers used for reading')
+    parser.add_argument('--prefetch',
+                        default=100,
+                        type=int,
+                        help='prefetch number')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+    with open(args.config, 'r') as fin:
+        configs = yaml.load(fin, Loader=yaml.FullLoader)
+
+    test_conf = copy.deepcopy(configs['dataset_conf'])
+    test_conf['filter_conf']['max_length'] = 102400
+    test_conf['filter_conf']['min_length'] = 0
+    test_conf['speed_perturb'] = False
+    test_conf['spec_aug'] = False
+    test_conf['shuffle'] = False
+    test_conf['feature_extraction_conf']['dither'] = 0.0
+    test_conf['batch_conf']['batch_size'] = args.batch_size
+
+    test_dataset = Dataset(args.test_data, test_conf)
+    test_data_loader = DataLoader(test_dataset,
+                                  batch_size=None,
+                                  pin_memory=args.pin_memory,
+                                  num_workers=args.num_workers)
+
+    # Init kws model from configs
+    model = init_model(configs['model'])
+
+    load_checkpoint(model, args.checkpoint)
+    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+    device = torch.device('cuda' if use_cuda else 'cpu')
+    model = model.to(device)
+    executor = Executor()
+    model.eval()
+    training_config = configs['training_config']
+    with torch.no_grad():
+        test_loss, test_acc = executor.test(model, test_data_loader, device,
+                                            training_config)
+    logging.info('Test Loss {} Acc {}'.format(test_loss, test_acc))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/kws/bin/train.py b/kws/bin/train.py
index 3497322..51e2766 100644
--- a/kws/bin/train.py
+++ b/kws/bin/train.py
@@ -164,9 +164,9 @@ def main():
     # !!!IMPORTANT!!!
     # Try to export the model by script, if fails, we should refine
     # the code to satisfy the script export requirements
-    # if args.rank == 0:
-    #     script_model = torch.jit.script(model)
-    #     script_model.save(os.path.join(args.model_dir, 'init.zip'))
+    if args.rank == 0:
+        script_model = torch.jit.script(model)
+        script_model.save(os.path.join(args.model_dir, 'init.zip'))
     executor = Executor()
     # If specify checkpoint, load some info from checkpoint
     if args.checkpoint is not None:
@@ -196,8 +196,7 @@ def main():
         model = model.to(device)
 
     optimizer = optim.Adam(model.parameters(),
-                           lr=configs['optim_conf']['lr'],
-                           weight_decay=configs['optim_conf']['weight_decay'])
+                           lr=configs['optim_conf']['lr'])
     scheduler = optim.lr_scheduler.ReduceLROnPlateau(
         optimizer,
         mode='min',
@@ -223,8 +222,9 @@ def main():
         logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
         executor.train(model, optimizer, train_data_loader, device, writer,
                        training_config)
-        cv_loss = executor.cv(model, cv_data_loader, device, training_config)
-        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
+        cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
+        logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
+                     .format(epoch, cv_loss, cv_acc))
 
         if args.rank == 0:
             save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
@@ -234,6 +234,7 @@ def main():
                     'cv_loss': cv_loss,
                 })
             writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
+            writer.add_scalar('epoch/cv_acc', cv_acc, epoch)
             writer.add_scalar('epoch/lr', lr, epoch)
         final_epoch = epoch
         scheduler.step(cv_loss)
diff --git a/kws/model/classifier.py b/kws/model/classifier.py
new file mode 100644
index 0000000..c742b7b
--- /dev/null
+++ b/kws/model/classifier.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+
+class GlobalClassifier(nn.Module):
+    """Add a global average pooling before the classifier"""
+    def __init__(self, classifier: nn.Module):
+        super(GlobalClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        x = torch.mean(x, dim=1)
+        return self.classifier(x)
+
+
+class LastClassifier(nn.Module):
+    """Select last frame to do the classification"""
+    def __init__(self, classifier: nn.Module):
+        super(LastClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        x = x[:, -1, :]
+        return self.classifier(x)
+
+class ElementClassifier(nn.Module):
+    """Classify all the frames in an utterance"""
+    def __init__(self, classifier: nn.Module):
+        super(ElementClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        return self.classifier(x)
diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py
index dd6142a..5c56aef 100644
--- a/kws/model/kws_model.py
+++ b/kws/model/kws_model.py
@@ -16,8 +16,10 @@ import sys
 from typing import Optional
 
 import torch
+import torch.nn as nn
 
 from kws.model.cmvn import GlobalCMVN
+from kws.model.classifier import GlobalClassifier, LastClassifier
 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
 from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
 from kws.model.mdtc import MDTC
@@ -39,6 +41,7 @@ class KWSModel(torch.nn.Module):
         global_cmvn: Optional[torch.nn.Module],
         preprocessing: Optional[torch.nn.Module],
         backbone: torch.nn.Module,
+        classifier: torch.nn.Module
     ):
         super().__init__()
         self.idim = idim
@@ -47,7 +50,7 @@ class KWSModel(torch.nn.Module):
         self.global_cmvn = global_cmvn
         self.preprocessing = preprocessing
         self.backbone = backbone
-        self.classifier = torch.nn.Linear(hdim, odim)
+        self.classifier = classifier
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.global_cmvn is not None:
@@ -55,7 +58,6 @@ class KWSModel(torch.nn.Module):
         x = self.preprocessing(x)
         x, _ = self.backbone(x)
         x = self.classifier(x)
-        x = torch.sigmoid(x)
         return x
 
 
@@ -110,17 +112,34 @@ def init_model(configs):
         num_stack = configs['backbone']['num_stack']
         kernel_size = configs['backbone']['kernel_size']
         hidden_dim = configs['backbone']['hidden_dim']
-
+        causal = configs['backbone']['causal']
         backbone = MDTC(num_stack,
                         stack_size,
                         input_dim,
                         hidden_dim,
                         kernel_size,
-                        causal=True)
+                        causal=causal)
     else:
         print('Unknown body type {}'.format(backbone_type))
         sys.exit(1)
+    classifier_type = configs['classifier']['type']
+    dropout = configs['classifier']['dropout']
+    classifier_base = nn.Sequential(
+        nn.Linear(hidden_dim, 64),
+        nn.ReLU(),
+        nn.Dropout(dropout),
+        nn.Linear(64, output_dim),
+    )
+    if classifier_type == 'linear':
+        classifier = classifier_base
+    elif classifier_type == 'global':
+        classifier = GlobalClassifier(classifier_base)
+    elif classifier_type == 'last':
+        classifier = LastClassifier(classifier_base)
+    else:
+        print('Unknown classifier type {}'.format(classifier_type))
+        sys.exit(1)
 
     kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
-                         preprocessing, backbone)
+                         preprocessing, backbone, classifier)
     return kws_model
diff --git a/kws/model/loss.py b/kws/model/loss.py
index d73722c..80a7fec 100644
--- a/kws/model/loss.py
+++ b/kws/model/loss.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import torch
+import torch.nn as nn
 
 from kws.utils.mask import padding_mask
 
 
-def max_polling_loss(logits: torch.Tensor,
+def max_pooling_loss(logits: torch.Tensor,
                      target: torch.Tensor,
                      lengths: torch.Tensor,
                      min_duration: int = 0):
@@ -37,6 +38,7 @@ def max_polling_loss(logits: torch.Tensor,
         (float): loss of current batch
         (float): accuracy of current batch
     '''
+    logits = torch.sigmoid(logits)
     mask = padding_mask(lengths)
     num_utts = logits.size(0)
     num_keywords = logits.size(2)
@@ -80,3 +82,46 @@ def max_polling_loss(logits: torch.Tensor,
     acc = num_correct / num_utts
     # acc = 0.0
     return loss, acc
+
+
+def acc_frame(
+    logits: torch.Tensor,
+    target: torch.Tensor,
+):
+    if logits is None:
+        return 0
+    pred = logits.max(1, keepdim=True)[1]
+    correct = pred.eq(target.long().view_as(pred)).sum().item()
+    return correct * 100.0 / logits.size(0)
+
+
+def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
+    """ Cross Entropy Loss
+
+    Attributes:
+        logits: (B, D), D is the number of keywords plus 1 (non-keyword)
+        target: (B)
+
+    Returns:
+        (torch.Tensor): scalar loss of current batch
+        (float): accuracy of current batch, in percent (see acc_frame)
+    """
+    cross_entropy = nn.CrossEntropyLoss()
+    loss = cross_entropy(logits, target)
+    acc = acc_frame(logits, target)
+    return loss, acc
+
+
+def criterion(type: str,
+              logits: torch.Tensor,
+              target: torch.Tensor,
+              lengths: torch.Tensor,
+              min_duration: int = 0):
+    if type == 'ce':
+        loss, acc = cross_entropy(logits, target)
+        return loss, acc
+    elif type == 'max_pooling':
+        loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
+        return loss, acc
+    else:
+        raise ValueError('Unknown criterion type {}'.format(type))
diff --git a/kws/utils/executor.py b/kws/utils/executor.py
index 2f8fe47..51fc529 100644
--- a/kws/utils/executor.py
+++ b/kws/utils/executor.py
@@ -17,7 +17,7 @@ import logging
 import torch
 from torch.nn.utils import clip_grad_norm_
 
-from kws.model.loss import max_polling_loss
+from kws.model.loss import criterion
 
 
 class Executor:
@@ -44,8 +44,8 @@ class Executor:
             if num_utts == 0:
                 continue
             logits = model(feats)
-            loss, acc = max_polling_loss(logits, target, feats_lengths,
-                                         min_duration)
+            loss, acc = criterion(args.get('criterion', 'max_pooling'), logits,
+                                  target, feats_lengths, min_duration)
             loss.backward()
             grad_norm = clip_grad_norm_(model.parameters(), clip)
             if torch.isfinite(grad_norm):
@@ -64,6 +64,7 @@ class Executor:
         # in order to avoid division by 0
         num_seen_utts = 1
         total_loss = 0.0
+        total_acc = 0.0
         with torch.no_grad():
             for batch_idx, batch in enumerate(data_loader):
                 key, feats, target, feats_lengths = batch
@@ -73,15 +74,19 @@ class Executor:
                 num_utts = feats_lengths.size(0)
                 if num_utts == 0:
                     continue
-                num_seen_utts += num_utts
                 logits = model(feats)
-                loss, acc = max_polling_loss(logits, target, feats_lengths)
+                loss, acc = criterion(args.get('criterion', 'max_pooling'),
+                                      logits, target, feats_lengths)
                 if torch.isfinite(loss):
                     num_seen_utts += num_utts
                     total_loss += loss.item() * num_utts
+                    total_acc += acc * num_utts
                 if batch_idx % log_interval == 0:
                     logging.debug(
                         'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}'
                         .format(epoch, batch_idx, loss.item(), acc,
                                 total_loss / num_seen_utts))
-        return total_loss / num_seen_utts
+        return total_loss / num_seen_utts, total_acc / num_seen_utts
+
+    def test(self, model, data_loader, device, args):
+        return self.cv(model, data_loader, device, args)