[kws] update cross_entropy loss (#62)
* [kws] update cross_entropy loss replace nn.CrossEntropyLoss() with F.cross_entropy() * format
This commit is contained in:
parent
66fcfa2ce5
commit
ff4b47f94d
@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from kws.utils.mask import padding_mask
|
from kws.utils.mask import padding_mask
|
||||||
|
|
||||||
@ -105,8 +105,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
|
|||||||
(float): loss of current batch
|
(float): loss of current batch
|
||||||
(float): accuracy of current batch
|
(float): accuracy of current batch
|
||||||
"""
|
"""
|
||||||
cross_entropy = nn.CrossEntropyLoss()
|
loss = F.cross_entropy(logits, target)
|
||||||
loss = cross_entropy(logits, target)
|
|
||||||
acc = acc_frame(logits, target)
|
acc = acc_frame(logits, target)
|
||||||
return loss, acc
|
return loss, acc
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user