diff --git a/kws/bin/train.py b/kws/bin/train.py index 34acc80..51e2766 100644 --- a/kws/bin/train.py +++ b/kws/bin/train.py @@ -220,8 +220,8 @@ def main(): training_config['epoch'] = epoch lr = optimizer.param_groups[0]['lr'] logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) - #executor.train(model, optimizer, train_data_loader, device, writer, - # training_config) + executor.train(model, optimizer, train_data_loader, device, writer, + training_config) cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config) logging.info('Epoch {} CV info cv_loss {} cv_acc {}' .format(epoch, cv_loss, cv_acc)) diff --git a/kws/utils/executor.py b/kws/utils/executor.py index a5c3d74..51fc529 100644 --- a/kws/utils/executor.py +++ b/kws/utils/executor.py @@ -44,9 +44,8 @@ class Executor: if num_utts == 0: continue logits = model(feats) - loss, acc = criterion( - args.get('criterion', 'max_pooling'), - logits, target, feats_lengths) + loss_type = args.get('criterion', 'max_pooling') + loss, acc = criterion(loss_type, logits, target, feats_lengths) loss.backward() grad_norm = clip_grad_norm_(model.parameters(), clip) if torch.isfinite(grad_norm):