Update mdtc.py
This commit is contained in:
parent
b9551bb716
commit
5236d42800
@ -227,9 +227,9 @@ class MDTC(nn.Module):
|
||||
output_size = outputs_list[-1].shape[-1]
|
||||
for x in outputs_list:
|
||||
remove_length = x.shape[-1] - output_size
|
||||
if self.causal and (remove_length > 0):
|
||||
if self.causal and remove_length > 0:
|
||||
normalized_outputs.append(x[:, :, remove_length:])
|
||||
elif (not self.causal) and (remove_length > 1):
|
||||
elif not self.causal and remove_length > 1:
|
||||
half_remove_length = remove_length // 2
|
||||
normalized_outputs.append(
|
||||
x[:, :, half_remove_length:-half_remove_length]
|
||||
@ -245,7 +245,7 @@ class MDTC(nn.Module):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mdtc = MDTC(3, 4, 80, 64, 3, causal=True)
|
||||
mdtc = MDTC(3, 4, 80, 64, 5, causal=True)
|
||||
print(mdtc)
|
||||
|
||||
num_params = sum(p.numel() for p in mdtc.parameters())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user