fix mdtc
This commit is contained in:
parent
85350c38a8
commit
dc1799100d
@ -123,7 +123,6 @@ class TCNStack(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
stack_num: int,
|
||||
stack_size: int,
|
||||
res_channels: int,
|
||||
kernel_size: int,
|
||||
@ -131,7 +130,6 @@ class TCNStack(nn.Module):
|
||||
):
|
||||
super(TCNStack, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.stack_num = stack_num
|
||||
self.stack_size = stack_size
|
||||
self.res_channels = res_channels
|
||||
self.kernel_size = kernel_size
|
||||
@ -148,8 +146,7 @@ class TCNStack(nn.Module):
|
||||
def build_dilations(self):
|
||||
dilations = []
|
||||
for s in range(0, self.stack_size):
|
||||
for l in range(0, self.stack_num):
|
||||
dilations.append(2**l)
|
||||
dilations.append(2**s)
|
||||
return dilations
|
||||
|
||||
def stack_tcn_blocks(self):
|
||||
@ -229,7 +226,7 @@ class MDTC(nn.Module):
|
||||
self.padding = self.preprocessor.padding
|
||||
for i in range(stack_num):
|
||||
self.blocks.append(
|
||||
TCNStack(res_channels, stack_size, 1, res_channels,
|
||||
TCNStack(res_channels, stack_size, res_channels,
|
||||
kernel_size, causal))
|
||||
self.padding += self.blocks[-1].padding
|
||||
self.half_padding = self.padding // 2
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user