[model] refactor tcn and ds_tcn to share the same base class (#35)

Binbin Zhang 2021-12-07 10:52:14 +08:00 committed by GitHub
parent 92a4c19ffe
commit b55ae111ae


@@ -20,15 +20,43 @@ import torch.nn as nn
import torch.nn.functional as F
class CnnBlock(nn.Module):
class Block(nn.Module):
    def __init__(self,
                 channel: int,
                 kernel_size: int,
                 dilation: int,
                 dropout: float = 0.1):
        super().__init__()
        # The CNN used here is causal convolution
        self.padding = (kernel_size - 1) * dilation
    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        Args:
            x (torch.Tensor): Input tensor (B, D, T)
            cache (torch.Tensor, optional): cached history, shape (B, D, padding)
        Returns:
            (torch.Tensor, torch.Tensor): output (B, D, T) and the new cache
        """
        # The CNN used here is causal convolution
        if cache is None:
            y = F.pad(x, (self.padding, 0), value=0.0)
        else:
            y = torch.cat((cache, x), dim=2)
        assert y.size(2) > self.padding
        new_cache = y[:, :, -self.padding:]
        # self.cnn is defined in the subclass of Block
        y = self.cnn(y)
        y = y + x  # residual connection
        return y, new_cache
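The cached forward is what makes chunk-by-chunk streaming match whole-utterance processing: a zero cache of length padding reproduces the left zero-padding, and the cache returned for one chunk supplies the causal history for the next. A minimal usage sketch, assuming the CnnBlock defined below in this file is in scope (batch size, channels, and chunk size are illustrative):

import torch

# Illustrative streaming check; assumes CnnBlock from this file is importable.
block = CnnBlock(channel=64, kernel_size=8, dilation=2)
block.eval()  # make Dropout (and any norm layers) deterministic

x = torch.randn(1, 64, 100)                    # (B, D, T)
with torch.no_grad():
    y_full, _ = block(x)                       # whole sequence at once

    cache = torch.zeros(1, 64, block.padding)  # zero history == left padding
    outs = []
    for chunk in torch.split(x, 25, dim=2):    # feed 25-frame chunks
        y, cache = block(chunk, cache)
        outs.append(y)
    y_stream = torch.cat(outs, dim=2)

print(torch.allclose(y_full, y_stream, atol=1e-5))  # expected: True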
class CnnBlock(Block):
    def __init__(self,
                 channel: int,
                 kernel_size: int,
                 dilation: int,
                 dropout: float = 0.1):
        super().__init__(channel, kernel_size, dilation, dropout)
        self.cnn = nn.Sequential(
            nn.Conv1d(channel,
                      channel,
@@ -40,26 +68,8 @@ class CnnBlock(nn.Module):
            nn.Dropout(dropout),
        )
    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        Args:
            x(torch.Tensor): Input tensor (B, D, T)
        Returns:
            torch.Tensor(B, D, T)
        """
        if cache is None:
            y = F.pad(x, (self.padding, 0), value=0.0)
        else:
            y = torch.cat((cache, x), dim=2)
        assert y.size(2) > self.padding
        new_cache = y[:, :, -self.padding:]
        y = self.cnn(y)
        y = y + x  # residual connection
        return y, new_cache
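With the padding, cache handling, and residual add hoisted into Block, a subclass only has to construct self.cnn; any stack that maps the left-padded (B, D, T + padding) input back to (B, D, T) works with the inherited forward. A hedged sketch of a hypothetical further variant, not part of this commit, purely to illustrate the extension point:

import torch.nn as nn

class GatedBlock(Block):
    """Hypothetical variant: a gated causal convolution.
    Only self.cnn is defined; forward/caching are inherited from Block."""
    def __init__(self, channel: int, kernel_size: int, dilation: int,
                 dropout: float = 0.1):
        super().__init__(channel, kernel_size, dilation, dropout)
        self.cnn = nn.Sequential(
            # project to 2*channel so the GLU gate maps back to channel
            nn.Conv1d(channel, 2 * channel, kernel_size, dilation=dilation),
            nn.GLU(dim=1),
            nn.Dropout(dropout),
        )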
class DsCnnBlock(nn.Module):
class DsCnnBlock(Block):
    """ Depthwise Separable Convolution
    """
    def __init__(self,
@@ -67,9 +77,7 @@ class DsCnnBlock(nn.Module):
                 kernel_size: int,
                 dilation: int,
                 dropout: float = 0.1):
        super().__init__()
        # The CNN used here is causal convolution
        self.padding = (kernel_size - 1) * dilation
        super().__init__(channel, kernel_size, dilation, dropout)
        self.cnn = nn.Sequential(
            nn.Conv1d(channel,
                      channel,
@@ -85,24 +93,6 @@ class DsCnnBlock(nn.Module):
            nn.Dropout(dropout),
        )
    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        Args:
            x(torch.Tensor): Input tensor (B, D, T)
        Returns:
            torch.Tensor(B, D, T)
        """
        if cache is None:
            y = F.pad(x, (self.padding, 0), value=0.0)
        else:
            y = torch.cat((cache, x), dim=2)
        assert y.size(2) > self.padding
        new_cache = y[:, :, -self.padding:]
        y = self.cnn(y)
        y = y + x  # residual connection
        return y, new_cache
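DsCnnBlock differs from CnnBlock only in how self.cnn is built: a depthwise causal convolution (groups equal to the channel count) followed by a pointwise 1x1 convolution, which cuts the convolution's parameters from roughly kernel_size·C² to kernel_size·C + C² per block. The hunks above truncate the exact layer list, so the sketch below is an assumption about its shape, not a copy of the committed code:

import torch.nn as nn

# Assumed shape of a depthwise-separable stack compatible with Block;
# the commit's exact norm/activation layers may differ.
def make_ds_cnn(channel: int, kernel_size: int, dilation: int,
                dropout: float = 0.1) -> nn.Sequential:
    return nn.Sequential(
        nn.Conv1d(channel, channel, kernel_size,
                  dilation=dilation, groups=channel),  # depthwise
        nn.Conv1d(channel, channel, kernel_size=1),    # pointwise 1x1
        nn.Dropout(dropout),
    )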
class TCN(nn.Module):
    def __init__(self,