[fix] Convert target to torch.int64 for cross_entropy (#141)
On my machine the original code threw an error RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Int' I followed https://github.com/wenet-e2e/wekws#installation to setup the environment so I'm curious if this error has ever occured to other people.
This commit is contained in:
parent
2c3c9ce383
commit
6ae98ef111
@ -169,7 +169,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
|
|||||||
(float): loss of current batch
|
(float): loss of current batch
|
||||||
(float): accuracy of current batch
|
(float): accuracy of current batch
|
||||||
"""
|
"""
|
||||||
loss = F.cross_entropy(logits, target)
|
loss = F.cross_entropy(logits, target.type(torch.int64))
|
||||||
acc = acc_frame(logits, target)
|
acc = acc_frame(logits, target)
|
||||||
return loss, acc
|
return loss, acc
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user