Update mdtc.py

This commit is contained in:
xiaohou 2021-11-30 17:14:30 +08:00 committed by GitHub
parent b9551bb716
commit 5236d42800
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -227,9 +227,9 @@ class MDTC(nn.Module):
output_size = outputs_list[-1].shape[-1]
for x in outputs_list:
remove_length = x.shape[-1] - output_size
if self.causal and (remove_length > 0):
if self.causal and remove_length > 0:
normalized_outputs.append(x[:, :, remove_length:])
elif (not self.causal) and (remove_length > 1):
elif not self.causal and remove_length > 1:
half_remove_length = remove_length // 2
normalized_outputs.append(
x[:, :, half_remove_length:-half_remove_length]
@ -245,7 +245,7 @@ class MDTC(nn.Module):
if __name__ == '__main__':
mdtc = MDTC(3, 4, 80, 64, 3, causal=True)
mdtc = MDTC(3, 4, 80, 64, 5, causal=True)
print(mdtc)
num_params = sum(p.numel() for p in mdtc.parameters())