diff --git a/kws/utils/executor.py b/kws/utils/executor.py index 94a7270..d30fb0a 100644 --- a/kws/utils/executor.py +++ b/kws/utils/executor.py @@ -46,11 +46,11 @@ class Executor: logits = model(feats) loss_type = args.get('criterion', 'max_pooling') loss, acc = criterion(loss_type, logits, target, feats_lengths) + optimizer.zero_grad() loss.backward() grad_norm = clip_grad_norm_(model.parameters(), clip) if torch.isfinite(grad_norm): optimizer.step() - optimizer.zero_grad() if batch_idx % log_interval == 0: logging.debug( 'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format(