[kws] update cross_entropy loss
replace nn.CrossEntropyLoss() with F.cross_entropy()
This commit is contained in:
parent
66fcfa2ce5
commit
9b2d56e2d9
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from kws.utils.mask import padding_mask
|
from kws.utils.mask import padding_mask
|
||||||
|
|
||||||
@ -105,8 +106,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
|
|||||||
(float): loss of current batch
|
(float): loss of current batch
|
||||||
(float): accuracy of current batch
|
(float): accuracy of current batch
|
||||||
"""
|
"""
|
||||||
cross_entropy = nn.CrossEntropyLoss()
|
loss = F.cross_entropy(logits, target)
|
||||||
loss = cross_entropy(logits, target)
|
|
||||||
acc = acc_frame(logits, target)
|
acc = acc_frame(logits, target)
|
||||||
return loss, acc
|
return loss, acc
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user