[kws] update cross_entropy loss (#62)
* [kws] update cross_entropy loss replace nn.CrossEntropyLoss() with F.cross_entropy() * format
This commit is contained in:
parent
66fcfa2ce5
commit
ff4b47f94d
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from kws.utils.mask import padding_mask
|
||||
|
||||
@ -105,8 +105,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
|
||||
(float): loss of current batch
|
||||
(float): accuracy of current batch
|
||||
"""
|
||||
cross_entropy = nn.CrossEntropyLoss()
|
||||
loss = cross_entropy(logits, target)
|
||||
loss = F.cross_entropy(logits, target)
|
||||
acc = acc_frame(logits, target)
|
||||
return loss, acc
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user