From ff4b47f94ddfde532de12c9730111a913d5dff92 Mon Sep 17 00:00:00 2001 From: Menglong Xu <32296227+mlxu995@users.noreply.github.com> Date: Tue, 15 Mar 2022 19:34:28 +0800 Subject: [PATCH] [kws] update cross_entropy loss (#62) * [kws] update cross_entropy loss replace nn.CrossEntropyLoss() with F.cross_entropy() * format --- kws/model/loss.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/kws/model/loss.py b/kws/model/loss.py index f862928..eac9597 100644 --- a/kws/model/loss.py +++ b/kws/model/loss.py @@ -13,7 +13,7 @@ # limitations under the License. import torch -import torch.nn as nn +import torch.nn.functional as F from kws.utils.mask import padding_mask @@ -105,8 +105,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor): (float): loss of current batch (float): accuracy of current batch """ - cross_entropy = nn.CrossEntropyLoss() - loss = cross_entropy(logits, target) + loss = F.cross_entropy(logits, target) acc = acc_frame(logits, target) return loss, acc