wekws/learnkws/learn_mask.py
2022-03-23 18:28:06 +08:00

31 lines
839 B
Python

'''
Date: 2022-03-04 18:10:52
LastEditors: Cyan
LastEditTime: 2022-03-07 10:21:34
'''
import torch
def padding_mask(lengths: torch.Tensor) -> torch.Tensor:
"""
Examples:
>>> lengths = torch.tensor([2, 2, 3], dtype=torch.int32)
>>> mask = padding_mask(lengths)
>>> print(mask)
tensor([[False, False, True],
[False, False, True],
[False, False, False]])
"""
batch_size = lengths.size(0)
max_len = int(lengths.max().item())
seq = torch.arange(max_len, dtype=torch.int64, device=lengths.device)
seq = seq.expand(batch_size, max_len)
return seq >= lengths.unsqueeze(1)
if __name__ == '__main__':
lengths = torch.tensor([2, 2, 3], dtype=torch.int32)
print(lengths.numel())
mask = padding_mask(lengths)
print(mask, mask.size())