diff --git a/kws/model/loss.py b/kws/model/loss.py index f862928..450cd95 100644 --- a/kws/model/loss.py +++ b/kws/model/loss.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from kws.utils.mask import padding_mask @@ -105,8 +106,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor): (float): loss of current batch (float): accuracy of current batch """ - cross_entropy = nn.CrossEntropyLoss() - loss = cross_entropy(logits, target) + loss = F.cross_entropy(logits, target) acc = acc_frame(logits, target) return loss, acc