From 9b2d56e2d999adec1b4a01ba002b554ee2bd3b34 Mon Sep 17 00:00:00 2001 From: Menglong Xu <32296227+mlxu995@users.noreply.github.com> Date: Tue, 15 Mar 2022 18:36:42 +0800 Subject: [PATCH] [kws] update cross_entropy loss replace nn.CrossEntropyLoss() with F.cross_entropy() --- kws/model/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kws/model/loss.py b/kws/model/loss.py index f862928..450cd95 100644 --- a/kws/model/loss.py +++ b/kws/model/loss.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from kws.utils.mask import padding_mask @@ -105,8 +106,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor): (float): loss of current batch (float): accuracy of current batch """ - cross_entropy = nn.CrossEntropyLoss() - loss = cross_entropy(logits, target) + loss = F.cross_entropy(logits, target) acc = acc_frame(logits, target) return loss, acc