From b55ae111ae31d824e0f74d1883d2ff82aeb927d3 Mon Sep 17 00:00:00 2001
From: Binbin Zhang
Date: Tue, 7 Dec 2021 10:52:14 +0800
Subject: [PATCH] [model] refactor tcn and ds_tcn share the same base class (#35)

---
 kws/model/tcn.py | 74 +++++++++++++++++++++---------------------------
 1 file changed, 32 insertions(+), 42 deletions(-)

diff --git a/kws/model/tcn.py b/kws/model/tcn.py
index 255ea7d..958a951 100644
--- a/kws/model/tcn.py
+++ b/kws/model/tcn.py
@@ -20,15 +20,43 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-class CnnBlock(nn.Module):
+class Block(nn.Module):
     def __init__(self,
                  channel: int,
                  kernel_size: int,
                  dilation: int,
                  dropout: float = 0.1):
         super().__init__()
-        # The CNN used here is causal convolution
         self.padding = (kernel_size - 1) * dilation
+
+    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+        """
+        Args:
+            x(torch.Tensor): Input tensor (B, D, T)
+        Returns:
+            torch.Tensor(B, D, T)
+        """
+        # The CNN used here is causal convolution
+        if cache is None:
+            y = F.pad(x, (self.padding, 0), value=0.0)
+        else:
+            y = torch.cat((cache, x), dim=2)
+        assert y.size(2) > self.padding
+        new_cache = y[:, :, -self.padding:]
+
+        # self.cnn is defined in the subclass of Block
+        y = self.cnn(y)
+        y = y + x  # residual connection
+        return y, new_cache
+
+
+class CnnBlock(Block):
+    def __init__(self,
+                 channel: int,
+                 kernel_size: int,
+                 dilation: int,
+                 dropout: float = 0.1):
+        super().__init__(channel, kernel_size, dilation, dropout)
         self.cnn = nn.Sequential(
             nn.Conv1d(channel,
                       channel,
@@ -40,26 +68,8 @@ class CnnBlock(nn.Module):
             nn.Dropout(dropout),
         )
 
-    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
-        """
-        Args:
-            x(torch.Tensor): Input tensor (B, D, T)
-        Returns:
-            torch.Tensor(B, D, T)
-        """
-        if cache is None:
-            y = F.pad(x, (self.padding, 0), value=0.0)
-        else:
-            y = torch.cat((cache, x), dim=2)
-        assert y.size(2) > self.padding
-        new_cache = y[:, :, -self.padding:]
-
-        y = self.cnn(y)
-        y = y + x  # residual connection
-        return y, new_cache
-
 
-class DsCnnBlock(nn.Module):
+class DsCnnBlock(Block):
     """ Depthwise Separable Convolution """
     def __init__(self,
                  channel: int,
@@ -67,9 +77,7 @@ class DsCnnBlock(nn.Module):
                  kernel_size: int,
                  dilation: int,
                  dropout: float = 0.1):
-        super().__init__()
-        # The CNN used here is causal convolution
-        self.padding = (kernel_size - 1) * dilation
+        super().__init__(channel, kernel_size, dilation, dropout)
         self.cnn = nn.Sequential(
             nn.Conv1d(channel,
                       channel,
@@ -85,24 +93,6 @@ class DsCnnBlock(nn.Module):
             nn.Dropout(dropout),
         )
 
-    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
-        """
-        Args:
-            x(torch.Tensor): Input tensor (B, D, T)
-        Returns:
-            torch.Tensor(B, D, T)
-        """
-        if cache is None:
-            y = F.pad(x, (self.padding, 0), value=0.0)
-        else:
-            y = torch.cat((cache, x), dim=2)
-        assert y.size(2) > self.padding
-        new_cache = y[:, :, -self.padding:]
-
-        y = self.cnn(y)
-        y = y + x  # residual connection
-        return y, new_cache
-
 
 class TCN(nn.Module):
     def __init__(self,
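
Note (editorial addition, not part of the patch): after this refactor, Block.forward owns the causal left-padding and streaming-cache logic, and CnnBlock/DsCnnBlock differ only in how they build self.cnn. Below is a minimal sketch of the streaming contract this preserves, assuming the patched kws/model/tcn.py is importable; the channel/kernel_size/dilation values and the 10-frame chunk size are hypothetical.

    import torch
    from kws.model.tcn import CnnBlock

    # Hypothetical sizes; padding = (kernel_size - 1) * dilation = 28 frames.
    block = CnnBlock(channel=64, kernel_size=8, dilation=4)
    block.eval()  # freeze Dropout (and any norm layers) for a deterministic check

    x = torch.randn(1, 64, 40)  # (B, D, T)
    with torch.no_grad():
        # Full-utterance path: cache=None, so forward left-pads with zeros.
        y_full, _ = block(x)

        # Streaming path: a zero cache of `padding` frames plays the role of
        # the causal left padding; each call returns the cache for the next.
        cache = torch.zeros(1, 64, block.padding)
        outs = []
        for chunk in x.split(10, dim=2):  # feed the utterance in 10-frame chunks
            y_chunk, cache = block(chunk, cache)
            outs.append(y_chunk)
        y_stream = torch.cat(outs, dim=2)

    print(torch.allclose(y_full, y_stream, atol=1e-6))  # expected: True

The cache is exactly the last `padding` frames of the concatenated input, i.e. the receptive-field overhang of the dilated causal convolution, so chunked and full-utterance inference agree for any subclass that only supplies self.cnn.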