From dc1799100dd9e0451f0d5e8b58c6aa10f282a11e Mon Sep 17 00:00:00 2001 From: William RUAN <56035086+William1617@users.noreply.github.com> Date: Mon, 10 Apr 2023 10:02:17 +0800 Subject: [PATCH] fix mdtc --- wekws/model/mdtc.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/wekws/model/mdtc.py b/wekws/model/mdtc.py index 090dde5..34f0e25 100644 --- a/wekws/model/mdtc.py +++ b/wekws/model/mdtc.py @@ -123,7 +123,6 @@ class TCNStack(nn.Module): def __init__( self, in_channels: int, - stack_num: int, stack_size: int, res_channels: int, kernel_size: int, @@ -131,7 +130,6 @@ class TCNStack(nn.Module): ): super(TCNStack, self).__init__() self.in_channels = in_channels - self.stack_num = stack_num self.stack_size = stack_size self.res_channels = res_channels self.kernel_size = kernel_size @@ -148,8 +146,7 @@ class TCNStack(nn.Module): def build_dilations(self): dilations = [] for s in range(0, self.stack_size): - for l in range(0, self.stack_num): - dilations.append(2**l) + dilations.append(2**s) return dilations def stack_tcn_blocks(self): @@ -229,7 +226,7 @@ class MDTC(nn.Module): self.padding = self.preprocessor.padding for i in range(stack_num): self.blocks.append( - TCNStack(res_channels, stack_size, 1, res_channels, + TCNStack(res_channels, stack_size, res_channels, kernel_size, causal)) self.padding += self.blocks[-1].padding self.half_padding = self.padding // 2