From c7c5bd3edc41e11112b03b727e3961615bc91c39 Mon Sep 17 00:00:00 2001
From: Binbin Zhang
Date: Mon, 6 Dec 2021 17:24:48 +0800
Subject: [PATCH] [kws] refine tcn and ds_tcn, add batchnorm (#31)

* [kws] fix seed type

* [kws] refine tcn and ds_tcn, add batch norm
---
 kws/model/tcn.py | 48 +++++++++++++++++++++++++-----------------------
 1 file changed, 25 insertions(+), 23 deletions(-)

diff --git a/kws/model/tcn.py b/kws/model/tcn.py
index 6a002b1..255ea7d 100644
--- a/kws/model/tcn.py
+++ b/kws/model/tcn.py
@@ -29,12 +29,16 @@ class CnnBlock(nn.Module):
         super().__init__()
         # The CNN used here is causal convolution
         self.padding = (kernel_size - 1) * dilation
-        self.cnn = nn.Conv1d(channel,
-                             channel,
-                             kernel_size,
-                             stride=1,
-                             dilation=dilation)
-        self.dropout = nn.Dropout(dropout)
+        self.cnn = nn.Sequential(
+            nn.Conv1d(channel,
+                      channel,
+                      kernel_size,
+                      stride=1,
+                      dilation=dilation),
+            nn.BatchNorm1d(channel),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+        )
 
     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
         """
@@ -51,8 +55,6 @@ class CnnBlock(nn.Module):
         new_cache = y[:, :, -self.padding:]
 
         y = self.cnn(y)
-        y = F.relu(y)
-        y = self.dropout(y)
         y = y + x  # residual connection
         return y, new_cache
 
@@ -68,17 +70,20 @@ class DsCnnBlock(nn.Module):
         super().__init__()
         # The CNN used here is causal convolution
         self.padding = (kernel_size - 1) * dilation
-        self.depthwise_cnn = nn.Conv1d(channel,
-                                       channel,
-                                       kernel_size,
-                                       stride=1,
-                                       dilation=dilation,
-                                       groups=channel)
-        self.pointwise_cnn = nn.Conv1d(channel,
-                                       channel,
-                                       kernel_size=1,
-                                       stride=1)
-        self.dropout = nn.Dropout(dropout)
+        self.cnn = nn.Sequential(
+            nn.Conv1d(channel,
+                      channel,
+                      kernel_size,
+                      stride=1,
+                      dilation=dilation,
+                      groups=channel),
+            nn.BatchNorm1d(channel),
+            nn.ReLU(),
+            nn.Conv1d(channel, channel, kernel_size=1, stride=1),
+            nn.BatchNorm1d(channel),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+        )
 
     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
         """
@@ -94,10 +99,7 @@ class DsCnnBlock(nn.Module):
         assert y.size(2) > self.padding
         new_cache = y[:, :, -self.padding:]
 
-        y = self.depthwise_cnn(y)
-        y = self.pointwise_cnn(y)
-        y = F.relu(y)
-        y = self.dropout(y)
+        y = self.cnn(y)
         y = y + x  # residual connection
         return y, new_cache
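
For anyone reviewing the streaming behavior, here is a minimal, self-contained sketch of `CnnBlock` as it stands after this patch, plus a quick check that chunk-by-chunk inference with the cache reproduces whole-utterance inference. The channel count, chunk size, default dropout rate, and the `__main__` harness are illustrative assumptions for this sketch, not part of the repo.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class CnnBlock(nn.Module):
    """Causal Conv1d -> BatchNorm1d -> ReLU -> Dropout with a residual add."""

    def __init__(self, channel: int, kernel_size: int, dilation: int,
                 dropout: float = 0.1):
        super().__init__()
        # Left context consumed by the causal convolution.
        self.padding = (kernel_size - 1) * dilation
        self.cnn = nn.Sequential(
            nn.Conv1d(channel, channel, kernel_size, stride=1,
                      dilation=dilation),
            nn.BatchNorm1d(channel),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if cache is None:
            # Non-streaming: zero-pad on the left so output length == input length.
            y = F.pad(x, (self.padding, 0), value=0.0)
        else:
            # Streaming: prepend the left context carried over from the last chunk.
            y = torch.cat((cache, x), dim=2)
        assert y.size(2) > self.padding
        # The last `padding` frames become the left context for the next chunk.
        new_cache = y[:, :, -self.padding:]
        y = self.cnn(y)
        y = y + x  # residual connection
        return y, new_cache


if __name__ == '__main__':
    torch.manual_seed(0)
    # eval() matters here: it freezes BatchNorm to its running statistics and
    # disables Dropout; otherwise chunked and whole-utterance outputs differ.
    block = CnnBlock(channel=16, kernel_size=8, dilation=2).eval()
    x = torch.randn(1, 16, 40)  # (batch, channel, time)
    with torch.no_grad():
        full, _ = block(x)
        # Feed 10-frame chunks; an initial cache of zeros mimics F.pad.
        cache = torch.zeros(1, 16, block.padding)
        outs = []
        for t in range(0, x.size(2), 10):
            out, cache = block(x[:, :, t:t + 10], cache)
            outs.append(out)
        streamed = torch.cat(outs, dim=2)
        print(torch.allclose(full, streamed, atol=1e-6))  # True
```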
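
And a short sketch of the depthwise-separable stack that `DsCnnBlock` now builds, comparing its parameter count against a plain `nn.Conv1d` of the same shape. The channel and kernel settings are made-up examples, chosen only to make the arithmetic concrete.

```python
import torch.nn as nn

channel, kernel_size, dilation, dropout = 64, 8, 1, 0.1

# Plain causal Conv1d, as in CnnBlock.
plain = nn.Conv1d(channel, channel, kernel_size, stride=1, dilation=dilation)

# Depthwise-separable stack, as DsCnnBlock builds it after this patch.
separable = nn.Sequential(
    # Depthwise: one kernel per channel (groups=channel), no channel mixing.
    nn.Conv1d(channel, channel, kernel_size, stride=1,
              dilation=dilation, groups=channel),
    nn.BatchNorm1d(channel),
    nn.ReLU(),
    # Pointwise: 1x1 convolution that mixes channels.
    nn.Conv1d(channel, channel, kernel_size=1, stride=1),
    nn.BatchNorm1d(channel),
    nn.ReLU(),
    nn.Dropout(dropout),
)

def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# plain:     64*64*8 weights + 64 biases = 32832
# separable: (64*1*8 + 64) depthwise + (64*64*1 + 64) pointwise = 4736,
#            plus 2 * (64 + 64) BatchNorm affine parameters = 256 -> 4992
print(n_params(plain), n_params(separable))
```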