update loss.py and kws_model.py

This commit is contained in:
jingyong hou 2021-11-19 17:14:11 +08:00
parent 9aaa4fc26c
commit 8cd12edfed
2 changed files with 1 additions and 1 deletions

View File

@ -56,6 +56,7 @@ class KWSModel(torch.nn.Module):
x = self.preprocessing(x) x = self.preprocessing(x)
x, _ = self.backbone(x) x, _ = self.backbone(x)
x = self.classifier(x) x = self.classifier(x)
x = torch.sigmoid(x)
return x return x

View File

@ -37,7 +37,6 @@ def max_polling_loss(logits: torch.Tensor,
(float): loss of current batch (float): loss of current batch
(float): accuracy of current batch (float): accuracy of current batch
''' '''
logits = torch.sigmoid(logits)
mask = padding_mask(lengths) mask = padding_mask(lengths)
num_utts = logits.size(0) num_utts = logits.size(0)
num_keywords = logits.size(2) num_keywords = logits.size(2)