Convert target to torch.int64 for cross_entropy

On my machine the original code threw an error 
RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Int'

I followed https://github.com/wenet-e2e/wekws#installation to setup the environment so I'm curious if this error has ever occured to other people.
This commit is contained in:
Tiance Wang 2023-08-22 14:34:37 +08:00 committed by GitHub
parent b233d46552
commit 8924db129e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -169,7 +169,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
(float): loss of current batch (float): loss of current batch
(float): accuracy of current batch (float): accuracy of current batch
""" """
loss = F.cross_entropy(logits, target) loss = F.cross_entropy(logits, target.type(torch.int64))
acc = acc_frame(logits, target) acc = acc_frame(logits, target)
return loss, acc return loss, acc