From 8924db129e5d6bd11b87f6fc09ee3158f75a008d Mon Sep 17 00:00:00 2001 From: Tiance Wang Date: Tue, 22 Aug 2023 14:34:37 +0800 Subject: [PATCH] Convert target to torch.int64 for cross_entropy On my machine the original code threw an error RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Int' I followed https://github.com/wenet-e2e/wekws#installation to setup the environment so I'm curious if this error has ever occured to other people. --- wekws/model/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wekws/model/loss.py b/wekws/model/loss.py index fef17d6..42045a0 100644 --- a/wekws/model/loss.py +++ b/wekws/model/loss.py @@ -169,7 +169,7 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor): (float): loss of current batch (float): accuracy of current batch """ - loss = F.cross_entropy(logits, target) + loss = F.cross_entropy(logits, target.type(torch.int64)) acc = acc_frame(logits, target) return loss, acc