diff --git a/kws/utils/executor.py b/kws/utils/executor.py index 51fc529..94a7270 100644 --- a/kws/utils/executor.py +++ b/kws/utils/executor.py @@ -50,6 +50,7 @@ class Executor: grad_norm = clip_grad_norm_(model.parameters(), clip) if torch.isfinite(grad_norm): optimizer.step() + optimizer.zero_grad() if batch_idx % log_interval == 0: logging.debug( 'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format(