From f2cade7684b6fdf9ad8d281104af0ace3d540889 Mon Sep 17 00:00:00 2001 From: chmod740 Date: Mon, 30 May 2022 20:26:38 +0800 Subject: [PATCH] [fix bug] add zero_grad() above backward() in kws/utils/executor.py (#72) --- kws/utils/executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kws/utils/executor.py b/kws/utils/executor.py index 94a7270..d30fb0a 100644 --- a/kws/utils/executor.py +++ b/kws/utils/executor.py @@ -46,11 +46,11 @@ class Executor: logits = model(feats) loss_type = args.get('criterion', 'max_pooling') loss, acc = criterion(loss_type, logits, target, feats_lengths) + optimizer.zero_grad() loss.backward() grad_norm = clip_grad_norm_(model.parameters(), clip) if torch.isfinite(grad_norm): optimizer.step() - optimizer.zero_grad() if batch_idx % log_interval == 0: logging.debug( 'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format(