[fix bug] add sigmoid() in score.py (#37)
The sigmoid() in kws/model/kws_model.py:KWSModel() was moved into kws/model/loss.py:max_pooling_loss() To compute the posterior score correctly, the sigmoid() should also be added to kws/bin/score.py:main()
This commit is contained in:
parent
bd504c3cee
commit
1eda27647b
@ -102,7 +102,7 @@ def main():
|
||||
feats = feats.to(device)
|
||||
lengths = lengths.to(device)
|
||||
mask = padding_mask(lengths).unsqueeze(2)
|
||||
logits = model(feats)
|
||||
logits = torch.sigmoid(model(feats))
|
||||
logits = logits.masked_fill(mask, 0.0)
|
||||
max_logits, _ = logits.max(dim=1)
|
||||
max_logits = max_logits.cpu()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user