diff --git a/kws/bin/score.py b/kws/bin/score.py index 4c80816..40986d2 100644 --- a/kws/bin/score.py +++ b/kws/bin/score.py @@ -102,7 +102,7 @@ def main(): feats = feats.to(device) lengths = lengths.to(device) mask = padding_mask(lengths).unsqueeze(2) - logits = model(feats) + logits = torch.sigmoid(model(feats)) logits = logits.masked_fill(mask, 0.0) max_logits, _ = logits.max(dim=1) max_logits = max_logits.cpu()