diff --git a/kws/utils/executor.py b/kws/utils/executor.py index 51fc529..d30fb0a 100644 --- a/kws/utils/executor.py +++ b/kws/utils/executor.py @@ -46,6 +46,7 @@ class Executor: logits = model(feats) loss_type = args.get('criterion', 'max_pooling') loss, acc = criterion(loss_type, logits, target, feats_lengths) + optimizer.zero_grad() loss.backward() grad_norm = clip_grad_norm_(model.parameters(), clip) if torch.isfinite(grad_norm):