remove files
This commit is contained in:
parent
db9fc7a738
commit
07b7beabad
@ -1,30 +0,0 @@
|
||||
'''
|
||||
Date: 2022-03-04 18:10:52
|
||||
LastEditors: Cyan
|
||||
LastEditTime: 2022-03-07 10:21:34
|
||||
'''
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def padding_mask(lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Examples:
|
||||
>>> lengths = torch.tensor([2, 2, 3], dtype=torch.int32)
|
||||
>>> mask = padding_mask(lengths)
|
||||
>>> print(mask)
|
||||
tensor([[False, False, True],
|
||||
[False, False, True],
|
||||
[False, False, False]])
|
||||
"""
|
||||
batch_size = lengths.size(0)
|
||||
max_len = int(lengths.max().item())
|
||||
seq = torch.arange(max_len, dtype=torch.int64, device=lengths.device)
|
||||
seq = seq.expand(batch_size, max_len)
|
||||
return seq >= lengths.unsqueeze(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
lengths = torch.tensor([2, 2, 3], dtype=torch.int32)
|
||||
print(lengths.numel())
|
||||
mask = padding_mask(lengths)
|
||||
print(mask, mask.size())
|
||||
@ -1,13 +0,0 @@
|
||||
'''
|
||||
Date: 2022-03-04 18:10:52
|
||||
LastEditors: Cyan
|
||||
LastEditTime: 2022-03-07 10:21:34
|
||||
'''
|
||||
|
||||
if __name__ == '__main__':
|
||||
a = [1,2,3,4,5,6,7]
|
||||
for i in range(len(a)):
|
||||
print('i = ', i)
|
||||
if a[i] >= 3:
|
||||
i += 2
|
||||
# print('a[i] = ' , a[i])
|
||||
Loading…
x
Reference in New Issue
Block a user