This commit is contained in:
William RUAN 2023-04-10 10:02:17 +08:00 committed by GitHub
parent 85350c38a8
commit dc1799100d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -123,7 +123,6 @@ class TCNStack(nn.Module):
def __init__(
self,
in_channels: int,
stack_num: int,
stack_size: int,
res_channels: int,
kernel_size: int,
@ -131,7 +130,6 @@ class TCNStack(nn.Module):
):
super(TCNStack, self).__init__()
self.in_channels = in_channels
self.stack_num = stack_num
self.stack_size = stack_size
self.res_channels = res_channels
self.kernel_size = kernel_size
@ -148,8 +146,7 @@ class TCNStack(nn.Module):
def build_dilations(self):
dilations = []
for s in range(0, self.stack_size):
for l in range(0, self.stack_num):
dilations.append(2**l)
dilations.append(2**s)
return dilations
def stack_tcn_blocks(self):
@ -229,7 +226,7 @@ class MDTC(nn.Module):
self.padding = self.preprocessor.padding
for i in range(stack_num):
self.blocks.append(
TCNStack(res_channels, stack_size, 1, res_channels,
TCNStack(res_channels, stack_size, res_channels,
kernel_size, causal))
self.padding += self.blocks[-1].padding
self.half_padding = self.padding // 2