This commit is contained in:
blessyyyu 2022-03-23 14:32:04 +08:00
commit c2572d9abf

View File

@ -13,7 +13,7 @@
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from kws.utils.mask import padding_mask
@ -105,8 +105,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
(float): loss of current batch
(float): accuracy of current batch
"""
cross_entropy = nn.CrossEntropyLoss()
loss = cross_entropy(logits, target)
loss = F.cross_entropy(logits, target)
acc = acc_frame(logits, target)
return loss, acc