Update mdtc.py
This commit is contained in:
parent
b9551bb716
commit
5236d42800
@ -227,9 +227,9 @@ class MDTC(nn.Module):
|
|||||||
output_size = outputs_list[-1].shape[-1]
|
output_size = outputs_list[-1].shape[-1]
|
||||||
for x in outputs_list:
|
for x in outputs_list:
|
||||||
remove_length = x.shape[-1] - output_size
|
remove_length = x.shape[-1] - output_size
|
||||||
if self.causal and (remove_length > 0):
|
if self.causal and remove_length > 0:
|
||||||
normalized_outputs.append(x[:, :, remove_length:])
|
normalized_outputs.append(x[:, :, remove_length:])
|
||||||
elif (not self.causal) and (remove_length > 1):
|
elif not self.causal and remove_length > 1:
|
||||||
half_remove_length = remove_length // 2
|
half_remove_length = remove_length // 2
|
||||||
normalized_outputs.append(
|
normalized_outputs.append(
|
||||||
x[:, :, half_remove_length:-half_remove_length]
|
x[:, :, half_remove_length:-half_remove_length]
|
||||||
@ -245,7 +245,7 @@ class MDTC(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
mdtc = MDTC(3, 4, 80, 64, 3, causal=True)
|
mdtc = MDTC(3, 4, 80, 64, 5, causal=True)
|
||||||
print(mdtc)
|
print(mdtc)
|
||||||
|
|
||||||
num_params = sum(p.numel() for p in mdtc.parameters())
|
num_params = sum(p.numel() for p in mdtc.parameters())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user