[kws] update cross_entropy loss (#62)

* [kws] update cross_entropy loss

replace nn.CrossEntropyLoss() with F.cross_entropy()

* format
This commit is contained in:
Menglong Xu 2022-03-15 19:34:28 +08:00 committed by GitHub
parent 66fcfa2ce5
commit ff4b47f94d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 import torch
-import torch.nn as nn
+import torch.nn.functional as F
 from kws.utils.mask import padding_mask
@@ -105,8 +105,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
         (float): loss of current batch
         (float): accuracy of current batch
     """
-    cross_entropy = nn.CrossEntropyLoss()
-    loss = cross_entropy(logits, target)
+    loss = F.cross_entropy(logits, target)
     acc = acc_frame(logits, target)
     return loss, acc