diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index 454aff6..ab543dd 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -56,6 +56,7 @@ class KWSModel(torch.nn.Module): x = self.preprocessing(x) x, _ = self.backbone(x) x = self.classifier(x) + x = torch.sigmoid(x) return x diff --git a/kws/model/loss.py b/kws/model/loss.py index 57f2670..d73722c 100644 --- a/kws/model/loss.py +++ b/kws/model/loss.py @@ -37,7 +37,6 @@ def max_polling_loss(logits: torch.Tensor, (float): loss of current batch (float): accuracy of current batch ''' - logits = torch.sigmoid(logits) mask = padding_mask(lengths) num_utts = logits.size(0) num_keywords = logits.size(2)