[fix bug] add sigmoid() in score.py (#37)
The sigmoid() in kws/model/kws_model.py:KWSModel() was moved into kws/model/loss.py:max_pooling_loss() To compute the posterior score correctly, the sigmoid() should also be added to kws/bin/score.py:main()
This commit is contained in:
parent
bd504c3cee
commit
1eda27647b
@ -102,7 +102,7 @@ def main():
|
|||||||
feats = feats.to(device)
|
feats = feats.to(device)
|
||||||
lengths = lengths.to(device)
|
lengths = lengths.to(device)
|
||||||
mask = padding_mask(lengths).unsqueeze(2)
|
mask = padding_mask(lengths).unsqueeze(2)
|
||||||
logits = model(feats)
|
logits = torch.sigmoid(model(feats))
|
||||||
logits = logits.masked_fill(mask, 0.0)
|
logits = logits.masked_fill(mask, 0.0)
|
||||||
max_logits, _ = logits.max(dim=1)
|
max_logits, _ = logits.max(dim=1)
|
||||||
max_logits = max_logits.cpu()
|
max_logits = max_logits.cpu()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user