This commit is contained in:
jingyong hou 2021-12-03 20:39:03 +08:00
parent 8cdf83a692
commit 48856977b7
2 changed files with 4 additions and 5 deletions

View File

@ -220,8 +220,8 @@ def main():
training_config['epoch'] = epoch training_config['epoch'] = epoch
lr = optimizer.param_groups[0]['lr'] lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
#executor.train(model, optimizer, train_data_loader, device, writer, executor.train(model, optimizer, train_data_loader, device, writer,
# training_config) training_config)
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config) cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
logging.info('Epoch {} CV info cv_loss {} cv_acc {}' logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
.format(epoch, cv_loss, cv_acc)) .format(epoch, cv_loss, cv_acc))

View File

@ -44,9 +44,8 @@ class Executor:
if num_utts == 0: if num_utts == 0:
continue continue
logits = model(feats) logits = model(feats)
loss, acc = criterion( loss_type = args.get('criterion', 'max_pooling')
args.get('criterion', 'max_pooling'), loss, acc = criterion(loss_type, logits, target, feats_lengths)
logits, target, feats_lengths)
loss.backward() loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip) grad_norm = clip_grad_norm_(model.parameters(), clip)
if torch.isfinite(grad_norm): if torch.isfinite(grad_norm):