[fix bug] add sigmoid() in score.py (#37)

The sigmoid() in kws/model/kws_model.py:KWSModel() was moved into kws/model/loss.py:max_pooling_loss()
To compute the posterior score correctly, the sigmoid() should also be added to kws/bin/score.py:main()
This commit is contained in:
Menglong Xu 2021-12-08 13:41:20 +08:00 committed by GitHub
parent bd504c3cee
commit 1eda27647b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -102,7 +102,7 @@ def main():
feats = feats.to(device)
lengths = lengths.to(device)
mask = padding_mask(lengths).unsqueeze(2)
logits = model(feats)
logits = torch.sigmoid(model(feats))
logits = logits.masked_fill(mask, 0.0)
max_logits, _ = logits.max(dim=1)
max_logits = max_logits.cpu()