update loss.py and kws_model.py
This commit is contained in:
parent
9aaa4fc26c
commit
8cd12edfed
@ -56,6 +56,7 @@ class KWSModel(torch.nn.Module):
|
||||
x = self.preprocessing(x)
|
||||
x, _ = self.backbone(x)
|
||||
x = self.classifier(x)
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -37,7 +37,6 @@ def max_polling_loss(logits: torch.Tensor,
|
||||
(float): loss of current batch
|
||||
(float): accuracy of current batch
|
||||
'''
|
||||
logits = torch.sigmoid(logits)
|
||||
mask = padding_mask(lengths)
|
||||
num_utts = logits.size(0)
|
||||
num_keywords = logits.size(2)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user