This commit is contained in:
William RUAN 2023-04-10 10:02:17 +08:00 committed by GitHub
parent 85350c38a8
commit dc1799100d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -123,7 +123,6 @@ class TCNStack(nn.Module):
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
stack_num: int,
stack_size: int, stack_size: int,
res_channels: int, res_channels: int,
kernel_size: int, kernel_size: int,
@ -131,7 +130,6 @@ class TCNStack(nn.Module):
): ):
super(TCNStack, self).__init__() super(TCNStack, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.stack_num = stack_num
self.stack_size = stack_size self.stack_size = stack_size
self.res_channels = res_channels self.res_channels = res_channels
self.kernel_size = kernel_size self.kernel_size = kernel_size
@ -148,8 +146,7 @@ class TCNStack(nn.Module):
def build_dilations(self): def build_dilations(self):
dilations = [] dilations = []
for s in range(0, self.stack_size): for s in range(0, self.stack_size):
for l in range(0, self.stack_num): dilations.append(2**s)
dilations.append(2**l)
return dilations return dilations
def stack_tcn_blocks(self): def stack_tcn_blocks(self):
@ -229,7 +226,7 @@ class MDTC(nn.Module):
self.padding = self.preprocessor.padding self.padding = self.preprocessor.padding
for i in range(stack_num): for i in range(stack_num):
self.blocks.append( self.blocks.append(
TCNStack(res_channels, stack_size, 1, res_channels, TCNStack(res_channels, stack_size, res_channels,
kernel_size, causal)) kernel_size, causal))
self.padding += self.blocks[-1].padding self.padding += self.blocks[-1].padding
self.half_padding = self.padding // 2 self.half_padding = self.padding // 2