[kws] refine tcn and ds_tcn, add batchnorm (#31)

* [kws] fix seed type

* [kws] refine tcn and ds_tcn, add batch norm
Binbin Zhang 2021-12-06 17:24:48 +08:00, committed by GitHub
parent 37f56db5af
commit c7c5bd3edc
@@ -29,12 +29,16 @@ class CnnBlock(nn.Module):
         super().__init__()
         # The CNN used here is causal convolution
         self.padding = (kernel_size - 1) * dilation
-        self.cnn = nn.Conv1d(channel,
-                             channel,
-                             kernel_size,
-                             stride=1,
-                             dilation=dilation)
-        self.dropout = nn.Dropout(dropout)
+        self.cnn = nn.Sequential(
+            nn.Conv1d(channel,
+                      channel,
+                      kernel_size,
+                      stride=1,
+                      dilation=dilation),
+            nn.BatchNorm1d(channel),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+        )

     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
         """
@@ -51,8 +55,6 @@ class CnnBlock(nn.Module):
         new_cache = y[:, :, -self.padding:]

         y = self.cnn(y)
-        y = F.relu(y)
-        y = self.dropout(y)
         y = y + x  # residual connection
         return y, new_cache
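Taken together, the two hunks above fold BatchNorm and the activation/dropout into a single nn.Sequential, so forward only has to call self.cnn. For readers who want to run the refactored block outside the repo, here is a minimal self-contained sketch: the module structure follows the diff, while the F.pad-based handling of the first chunk and the default hyperparameter values are assumptions, not shown in this change.

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class CnnBlock(nn.Module):
    """Causal TCN block: dilated Conv1d -> BatchNorm1d -> ReLU -> Dropout,
    plus a residual connection, matching the structure after this change."""

    def __init__(self, channel: int, kernel_size: int = 8,
                 dilation: int = 1, dropout: float = 0.1):
        super().__init__()
        # The CNN used here is causal convolution: the receptive field
        # extends (kernel_size - 1) * dilation frames into the past only.
        self.padding = (kernel_size - 1) * dilation
        self.cnn = nn.Sequential(
            nn.Conv1d(channel, channel, kernel_size,
                      stride=1, dilation=dilation),
            nn.BatchNorm1d(channel),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(
        self, x: torch.Tensor, cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (batch, channel, time)
        if cache is None:
            # First chunk of a stream: zero-pad on the left (assumed here;
            # this diff does not show the padding branch).
            y = F.pad(x, (self.padding, 0), value=0.0)
        else:
            # Later chunks: prepend the cached tail of the previous chunk.
            y = torch.cat((cache, x), dim=2)
        assert y.size(2) > self.padding
        # Keep the last `padding` frames as the cache for the next chunk.
        new_cache = y[:, :, -self.padding:]

        y = self.cnn(y)
        y = y + x  # residual connection
        return y, new_cache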
@@ -68,17 +70,20 @@ class DsCnnBlock(nn.Module):
         super().__init__()
         # The CNN used here is causal convolution
         self.padding = (kernel_size - 1) * dilation
-        self.depthwise_cnn = nn.Conv1d(channel,
-                                       channel,
-                                       kernel_size,
-                                       stride=1,
-                                       dilation=dilation,
-                                       groups=channel)
-        self.pointwise_cnn = nn.Conv1d(channel,
-                                       channel,
-                                       kernel_size=1,
-                                       stride=1)
-        self.dropout = nn.Dropout(dropout)
+        self.cnn = nn.Sequential(
+            nn.Conv1d(channel,
+                      channel,
+                      kernel_size,
+                      stride=1,
+                      dilation=dilation,
+                      groups=channel),
+            nn.BatchNorm1d(channel),
+            nn.ReLU(),
+            nn.Conv1d(channel, channel, kernel_size=1, stride=1),
+            nn.BatchNorm1d(channel),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+        )

     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
         """
@@ -94,10 +99,7 @@ class DsCnnBlock(nn.Module):
         assert y.size(2) > self.padding
         new_cache = y[:, :, -self.padding:]

-        y = self.depthwise_cnn(y)
-        y = self.pointwise_cnn(y)
-        y = F.relu(y)
-        y = self.dropout(y)
+        y = self.cnn(y)
         y = y + x  # residual connection
         return y, new_cache