[fix bug] add zero_grad() above backward() in kws/utils/executor.py (#72)
This commit is contained in:
parent
cff8a5fe26
commit
f2cade7684
@ -46,11 +46,11 @@ class Executor:
|
|||||||
logits = model(feats)
|
logits = model(feats)
|
||||||
loss_type = args.get('criterion', 'max_pooling')
|
loss_type = args.get('criterion', 'max_pooling')
|
||||||
loss, acc = criterion(loss_type, logits, target, feats_lengths)
|
loss, acc = criterion(loss_type, logits, target, feats_lengths)
|
||||||
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
||||||
if torch.isfinite(grad_norm):
|
if torch.isfinite(grad_norm):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
|
||||||
if batch_idx % log_interval == 0:
|
if batch_idx % log_interval == 0:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format(
|
'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user