update mdtc.py to prevent possible errors and remove useless functions
This commit is contained in:
parent
ba6919baaf
commit
f642c0952a
@ -187,6 +187,9 @@ class MDTC(nn.Module):
|
||||
):
|
||||
super(MDTC, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
if kernel_size % 2 == 0:
|
||||
print("The kernel size of MDTC must be an odd number")
|
||||
exit(1)
|
||||
self.causal = causal
|
||||
self.preprocessor = TCNBlock(in_channels,
|
||||
res_channels,
|
||||
@ -204,29 +207,6 @@ class MDTC(nn.Module):
|
||||
self.half_receptive_fields = self.receptive_fields // 2
|
||||
print('Receptive Fields: %d' % self.receptive_fields)
|
||||
|
||||
def normalize_length_causal(self, skip_connections: list):
|
||||
output_size = skip_connections[-1].shape[-1]
|
||||
normalized_outputs = []
|
||||
for x in skip_connections:
|
||||
remove_length = x.shape[-1] - output_size
|
||||
if remove_length != 0:
|
||||
normalized_outputs.append(x[:, :, remove_length:])
|
||||
else:
|
||||
normalized_outputs.append(x)
|
||||
return normalized_outputs
|
||||
|
||||
def normalize_length(self, skip_connections: list):
|
||||
output_size = skip_connections[-1].shape[-1]
|
||||
normalized_outputs = []
|
||||
for x in skip_connections:
|
||||
remove_length = (x.shape[-1] - output_size) // 2
|
||||
if remove_length != 0:
|
||||
normalized_outputs.append(x[:, :,
|
||||
remove_length:-remove_length])
|
||||
else:
|
||||
normalized_outputs.append(x)
|
||||
return normalized_outputs
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.causal:
|
||||
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
|
||||
@ -249,13 +229,13 @@ class MDTC(nn.Module):
|
||||
output_size = outputs_list[-1].shape[-1]
|
||||
for x in outputs_list:
|
||||
remove_length = x.shape[-1] - output_size
|
||||
if remove_length != 0:
|
||||
if self.causal:
|
||||
normalized_outputs.append(x[:, :, remove_length:])
|
||||
else:
|
||||
remove_length = remove_length // 2
|
||||
normalized_outputs.append(x[:, :,
|
||||
remove_length:-remove_length])
|
||||
if self.causal and (remove_length > 0):
|
||||
normalized_outputs.append(x[:, :, remove_length:])
|
||||
elif (not self.causal) and (remove_length > 1):
|
||||
half_remove_length = remove_length // 2
|
||||
normalized_outputs.append(
|
||||
x[:, :, half_remove_length:-half_remove_length]
|
||||
)
|
||||
else:
|
||||
normalized_outputs.append(x)
|
||||
|
||||
@ -267,7 +247,7 @@ class MDTC(nn.Module):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mdtc = MDTC(3, 4, 80, 64, 5, causal=True)
|
||||
mdtc = MDTC(3, 4, 80, 64, 3, causal=True)
|
||||
print(mdtc)
|
||||
|
||||
num_params = sum(p.numel() for p in mdtc.parameters())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user