Merge dc1799100dd9e0451f0d5e8b58c6aa10f282a11e into 059fd87a8fd151da8adf9f80c2d4e33052e57753
This commit is contained in:
commit
537f96ad15
@ -123,7 +123,6 @@ class TCNStack(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
stack_num: int,
|
|
||||||
stack_size: int,
|
stack_size: int,
|
||||||
res_channels: int,
|
res_channels: int,
|
||||||
kernel_size: int,
|
kernel_size: int,
|
||||||
@ -131,7 +130,6 @@ class TCNStack(nn.Module):
|
|||||||
):
|
):
|
||||||
super(TCNStack, self).__init__()
|
super(TCNStack, self).__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.stack_num = stack_num
|
|
||||||
self.stack_size = stack_size
|
self.stack_size = stack_size
|
||||||
self.res_channels = res_channels
|
self.res_channels = res_channels
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
@ -148,8 +146,7 @@ class TCNStack(nn.Module):
|
|||||||
def build_dilations(self):
|
def build_dilations(self):
|
||||||
dilations = []
|
dilations = []
|
||||||
for s in range(0, self.stack_size):
|
for s in range(0, self.stack_size):
|
||||||
for l in range(0, self.stack_num):
|
dilations.append(2**s)
|
||||||
dilations.append(2**l)
|
|
||||||
return dilations
|
return dilations
|
||||||
|
|
||||||
def stack_tcn_blocks(self):
|
def stack_tcn_blocks(self):
|
||||||
@ -229,7 +226,7 @@ class MDTC(nn.Module):
|
|||||||
self.padding = self.preprocessor.padding
|
self.padding = self.preprocessor.padding
|
||||||
for i in range(stack_num):
|
for i in range(stack_num):
|
||||||
self.blocks.append(
|
self.blocks.append(
|
||||||
TCNStack(res_channels, stack_size, 1, res_channels,
|
TCNStack(res_channels, stack_size, res_channels,
|
||||||
kernel_size, causal))
|
kernel_size, causal))
|
||||||
self.padding += self.blocks[-1].padding
|
self.padding += self.blocks[-1].padding
|
||||||
self.half_padding = self.padding // 2
|
self.half_padding = self.padding // 2
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user