update loss.py and kws_model.py
This commit is contained in:
parent
9aaa4fc26c
commit
8cd12edfed
@ -56,6 +56,7 @@ class KWSModel(torch.nn.Module):
|
|||||||
x = self.preprocessing(x)
|
x = self.preprocessing(x)
|
||||||
x, _ = self.backbone(x)
|
x, _ = self.backbone(x)
|
||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
|
x = torch.sigmoid(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,6 @@ def max_polling_loss(logits: torch.Tensor,
|
|||||||
(float): loss of current batch
|
(float): loss of current batch
|
||||||
(float): accuracy of current batch
|
(float): accuracy of current batch
|
||||||
'''
|
'''
|
||||||
logits = torch.sigmoid(logits)
|
|
||||||
mask = padding_mask(lengths)
|
mask = padding_mask(lengths)
|
||||||
num_utts = logits.size(0)
|
num_utts = logits.size(0)
|
||||||
num_keywords = logits.size(2)
|
num_keywords = logits.size(2)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user