diff --git a/kws/model/loss.py b/kws/model/loss.py index f862928..eac9597 100644 --- a/kws/model/loss.py +++ b/kws/model/loss.py @@ -13,7 +13,7 @@ # limitations under the License. import torch -import torch.nn as nn +import torch.nn.functional as F from kws.utils.mask import padding_mask @@ -105,8 +105,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor): (float): loss of current batch (float): accuracy of current batch """ - cross_entropy = nn.CrossEntropyLoss() - loss = cross_entropy(logits, target) + loss = F.cross_entropy(logits, target) acc = acc_frame(logits, target) return loss, acc