[kws] refine tcn and ds_tcp, add batchnorm (#31)
* [kws] fix seed type * [kws] refine tcn and ds_tcn, add batch norm
This commit is contained in:
parent
37f56db5af
commit
c7c5bd3edc
@ -29,12 +29,16 @@ class CnnBlock(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
# The CNN used here is causal convolution
|
# The CNN used here is causal convolution
|
||||||
self.padding = (kernel_size - 1) * dilation
|
self.padding = (kernel_size - 1) * dilation
|
||||||
self.cnn = nn.Conv1d(channel,
|
self.cnn = nn.Sequential(
|
||||||
channel,
|
nn.Conv1d(channel,
|
||||||
kernel_size,
|
channel,
|
||||||
stride=1,
|
kernel_size,
|
||||||
dilation=dilation)
|
stride=1,
|
||||||
self.dropout = nn.Dropout(dropout)
|
dilation=dilation),
|
||||||
|
nn.BatchNorm1d(channel),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||||
"""
|
"""
|
||||||
@ -51,8 +55,6 @@ class CnnBlock(nn.Module):
|
|||||||
new_cache = y[:, :, -self.padding:]
|
new_cache = y[:, :, -self.padding:]
|
||||||
|
|
||||||
y = self.cnn(y)
|
y = self.cnn(y)
|
||||||
y = F.relu(y)
|
|
||||||
y = self.dropout(y)
|
|
||||||
y = y + x # residual connection
|
y = y + x # residual connection
|
||||||
return y, new_cache
|
return y, new_cache
|
||||||
|
|
||||||
@ -68,17 +70,20 @@ class DsCnnBlock(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
# The CNN used here is causal convolution
|
# The CNN used here is causal convolution
|
||||||
self.padding = (kernel_size - 1) * dilation
|
self.padding = (kernel_size - 1) * dilation
|
||||||
self.depthwise_cnn = nn.Conv1d(channel,
|
self.cnn = nn.Sequential(
|
||||||
channel,
|
nn.Conv1d(channel,
|
||||||
kernel_size,
|
channel,
|
||||||
stride=1,
|
kernel_size,
|
||||||
dilation=dilation,
|
stride=1,
|
||||||
groups=channel)
|
dilation=dilation,
|
||||||
self.pointwise_cnn = nn.Conv1d(channel,
|
groups=channel),
|
||||||
channel,
|
nn.BatchNorm1d(channel),
|
||||||
kernel_size=1,
|
nn.ReLU(),
|
||||||
stride=1)
|
nn.Conv1d(channel, channel, kernel_size=1, stride=1),
|
||||||
self.dropout = nn.Dropout(dropout)
|
nn.BatchNorm1d(channel),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||||
"""
|
"""
|
||||||
@ -94,10 +99,7 @@ class DsCnnBlock(nn.Module):
|
|||||||
assert y.size(2) > self.padding
|
assert y.size(2) > self.padding
|
||||||
new_cache = y[:, :, -self.padding:]
|
new_cache = y[:, :, -self.padding:]
|
||||||
|
|
||||||
y = self.depthwise_cnn(y)
|
y = self.cnn(y)
|
||||||
y = self.pointwise_cnn(y)
|
|
||||||
y = F.relu(y)
|
|
||||||
y = self.dropout(y)
|
|
||||||
y = y + x # residual connection
|
y = y + x # residual connection
|
||||||
return y, new_cache
|
return y, new_cache
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user