format
This commit is contained in:
parent
8cdf83a692
commit
48856977b7
@ -220,8 +220,8 @@ def main():
|
|||||||
training_config['epoch'] = epoch
|
training_config['epoch'] = epoch
|
||||||
lr = optimizer.param_groups[0]['lr']
|
lr = optimizer.param_groups[0]['lr']
|
||||||
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
|
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
|
||||||
#executor.train(model, optimizer, train_data_loader, device, writer,
|
executor.train(model, optimizer, train_data_loader, device, writer,
|
||||||
# training_config)
|
training_config)
|
||||||
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
|
cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
|
||||||
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
|
logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
|
||||||
.format(epoch, cv_loss, cv_acc))
|
.format(epoch, cv_loss, cv_acc))
|
||||||
|
|||||||
@ -44,9 +44,8 @@ class Executor:
|
|||||||
if num_utts == 0:
|
if num_utts == 0:
|
||||||
continue
|
continue
|
||||||
logits = model(feats)
|
logits = model(feats)
|
||||||
loss, acc = criterion(
|
loss_type = args.get('criterion', 'max_pooling')
|
||||||
args.get('criterion', 'max_pooling'),
|
loss, acc = criterion(loss_type, logits, target, feats_lengths)
|
||||||
logits, target, feats_lengths)
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
||||||
if torch.isfinite(grad_norm):
|
if torch.isfinite(grad_norm):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user