remove files
This commit is contained in:
parent
db9fc7a738
commit
07b7beabad
@ -1,30 +0,0 @@
|
|||||||
'''
|
|
||||||
Date: 2022-03-04 18:10:52
|
|
||||||
LastEditors: Cyan
|
|
||||||
LastEditTime: 2022-03-07 10:21:34
|
|
||||||
'''
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def padding_mask(lengths: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Examples:
|
|
||||||
>>> lengths = torch.tensor([2, 2, 3], dtype=torch.int32)
|
|
||||||
>>> mask = padding_mask(lengths)
|
|
||||||
>>> print(mask)
|
|
||||||
tensor([[False, False, True],
|
|
||||||
[False, False, True],
|
|
||||||
[False, False, False]])
|
|
||||||
"""
|
|
||||||
batch_size = lengths.size(0)
|
|
||||||
max_len = int(lengths.max().item())
|
|
||||||
seq = torch.arange(max_len, dtype=torch.int64, device=lengths.device)
|
|
||||||
seq = seq.expand(batch_size, max_len)
|
|
||||||
return seq >= lengths.unsqueeze(1)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
lengths = torch.tensor([2, 2, 3], dtype=torch.int32)
|
|
||||||
print(lengths.numel())
|
|
||||||
mask = padding_mask(lengths)
|
|
||||||
print(mask, mask.size())
|
|
||||||
@ -1,13 +0,0 @@
|
|||||||
'''
|
|
||||||
Date: 2022-03-04 18:10:52
|
|
||||||
LastEditors: Cyan
|
|
||||||
LastEditTime: 2022-03-07 10:21:34
|
|
||||||
'''
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
a = [1,2,3,4,5,6,7]
|
|
||||||
for i in range(len(a)):
|
|
||||||
print('i = ', i)
|
|
||||||
if a[i] >= 3:
|
|
||||||
i += 2
|
|
||||||
# print('a[i] = ' , a[i])
|
|
||||||
Loading…
x
Reference in New Issue
Block a user