[model] refactor tcn and ds_tcn share the same base class (#35)
This commit is contained in:
parent
92a4c19ffe
commit
b55ae111ae
@ -20,15 +20,43 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class CnnBlock(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
channel: int,
|
channel: int,
|
||||||
kernel_size: int,
|
kernel_size: int,
|
||||||
dilation: int,
|
dilation: int,
|
||||||
dropout: float = 0.1):
|
dropout: float = 0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# The CNN used here is causal convolution
|
|
||||||
self.padding = (kernel_size - 1) * dilation
|
self.padding = (kernel_size - 1) * dilation
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x(torch.Tensor): Input tensor (B, D, T)
|
||||||
|
Returns:
|
||||||
|
torch.Tensor(B, D, T)
|
||||||
|
"""
|
||||||
|
# The CNN used here is causal convolution
|
||||||
|
if cache is None:
|
||||||
|
y = F.pad(x, (self.padding, 0), value=0.0)
|
||||||
|
else:
|
||||||
|
y = torch.cat((cache, x), dim=2)
|
||||||
|
assert y.size(2) > self.padding
|
||||||
|
new_cache = y[:, :, -self.padding:]
|
||||||
|
|
||||||
|
# self.cnn is defined in the subclass of Block
|
||||||
|
y = self.cnn(y)
|
||||||
|
y = y + x # residual connection
|
||||||
|
return y, new_cache
|
||||||
|
|
||||||
|
|
||||||
|
class CnnBlock(Block):
|
||||||
|
def __init__(self,
|
||||||
|
channel: int,
|
||||||
|
kernel_size: int,
|
||||||
|
dilation: int,
|
||||||
|
dropout: float = 0.1):
|
||||||
|
super().__init__(channel, kernel_size, dilation, dropout)
|
||||||
self.cnn = nn.Sequential(
|
self.cnn = nn.Sequential(
|
||||||
nn.Conv1d(channel,
|
nn.Conv1d(channel,
|
||||||
channel,
|
channel,
|
||||||
@ -40,26 +68,8 @@ class CnnBlock(nn.Module):
|
|||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x(torch.Tensor): Input tensor (B, D, T)
|
|
||||||
Returns:
|
|
||||||
torch.Tensor(B, D, T)
|
|
||||||
"""
|
|
||||||
if cache is None:
|
|
||||||
y = F.pad(x, (self.padding, 0), value=0.0)
|
|
||||||
else:
|
|
||||||
y = torch.cat((cache, x), dim=2)
|
|
||||||
assert y.size(2) > self.padding
|
|
||||||
new_cache = y[:, :, -self.padding:]
|
|
||||||
|
|
||||||
y = self.cnn(y)
|
class DsCnnBlock(Block):
|
||||||
y = y + x # residual connection
|
|
||||||
return y, new_cache
|
|
||||||
|
|
||||||
|
|
||||||
class DsCnnBlock(nn.Module):
|
|
||||||
""" Depthwise Separable Convolution
|
""" Depthwise Separable Convolution
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -67,9 +77,7 @@ class DsCnnBlock(nn.Module):
|
|||||||
kernel_size: int,
|
kernel_size: int,
|
||||||
dilation: int,
|
dilation: int,
|
||||||
dropout: float = 0.1):
|
dropout: float = 0.1):
|
||||||
super().__init__()
|
super().__init__(channel, kernel_size, dilation, dropout)
|
||||||
# The CNN used here is causal convolution
|
|
||||||
self.padding = (kernel_size - 1) * dilation
|
|
||||||
self.cnn = nn.Sequential(
|
self.cnn = nn.Sequential(
|
||||||
nn.Conv1d(channel,
|
nn.Conv1d(channel,
|
||||||
channel,
|
channel,
|
||||||
@ -85,24 +93,6 @@ class DsCnnBlock(nn.Module):
|
|||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x(torch.Tensor): Input tensor (B, D, T)
|
|
||||||
Returns:
|
|
||||||
torch.Tensor(B, D, T)
|
|
||||||
"""
|
|
||||||
if cache is None:
|
|
||||||
y = F.pad(x, (self.padding, 0), value=0.0)
|
|
||||||
else:
|
|
||||||
y = torch.cat((cache, x), dim=2)
|
|
||||||
assert y.size(2) > self.padding
|
|
||||||
new_cache = y[:, :, -self.padding:]
|
|
||||||
|
|
||||||
y = self.cnn(y)
|
|
||||||
y = y + x # residual connection
|
|
||||||
return y, new_cache
|
|
||||||
|
|
||||||
|
|
||||||
class TCN(nn.Module):
|
class TCN(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user