[wekws] add cache support for mdtc (#105)
* [wekws] add cache support for mdtc

* format

Co-authored-by: 02Bigboy <570843154@qq.com>
This commit is contained in:
parent 80285fa696
commit 64ccd5bb86
@@ -28,7 +28,7 @@ dataset_conf:
 model:
   hidden_dim: 32
   preprocessing:
-    type: none
+    type: linear
   backbone:
     type: mdtc
     num_stack: 3
@@ -28,13 +28,14 @@ dataset_conf:
 model:
   hidden_dim: 64
   preprocessing:
-    type: none
+    type: linear
   backbone:
     type: mdtc
     num_stack: 4
     stack_size: 4
     kernel_size: 5
     hidden_dim: 64
+    causal: True
 
 optim: adam
 optim_conf:
@@ -28,13 +28,14 @@ dataset_conf:
 model:
   hidden_dim: 32
   preprocessing:
-    type: none
+    type: linear
   backbone:
     type: mdtc
     num_stack: 3
     stack_size: 4
     kernel_size: 5
     hidden_dim: 32
+    causal: True
 
 optim: adam
 optim_conf:
@@ -28,14 +28,14 @@ dataset_conf:
 model:
   hidden_dim: 64
   preprocessing:
-    type: none
+    type: linear
   backbone:
     type: mdtc
     num_stack: 4
     stack_size: 4
     kernel_size: 5
     hidden_dim: 64
-    causal: False
+    causal: True
   classifier:
     type: global
     dropout: 0.5
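The four config hunks above make the same two adjustments: preprocessing.type moves from none to linear, and backbone.causal is set to True, matching the causal-only constraint asserted in MDTC further down. A minimal sketch of reading such a section follows; the inline YAML is illustrative only, not a file from this commit.

# Sketch: parse a config shaped like the hunks above and read the new keys.
# The inline YAML here is illustrative only, not a file from this commit.
import yaml

example = yaml.safe_load("""
model:
  hidden_dim: 64
  preprocessing:
    type: linear
  backbone:
    type: mdtc
    num_stack: 4
    stack_size: 4
    kernel_size: 5
    hidden_dim: 64
    causal: True
""")

backbone = example['model']['backbone']
assert backbone['causal'] is True   # cache support only targets the causal model
print(backbone['num_stack'], backbone['stack_size'], backbone['kernel_size'])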
@@ -131,7 +131,7 @@ def init_model(configs):
         causal = configs['backbone']['causal']
         backbone = MDTC(num_stack,
                         stack_size,
                         input_dim,
                         hidden_dim,
                         hidden_dim,
                         kernel_size,
                         causal=causal)
@@ -32,7 +32,7 @@ class DSDilatedConv1d(nn.Module):
         bias: bool = True,
     ):
         super(DSDilatedConv1d, self).__init__()
-        self.receptive_fields = dilation * (kernel_size - 1)
+        self.padding = dilation * (kernel_size - 1)
         self.conv = nn.Conv1d(
             in_channels,
             in_channels,
@@ -73,8 +73,8 @@ class TCNBlock(nn.Module):
         self.kernel_size = kernel_size
         self.dilation = dilation
         self.causal = causal
-        self.receptive_fields = dilation * (kernel_size - 1)
-        self.half_receptive_fields = self.receptive_fields // 2
+        self.padding = dilation * (kernel_size - 1)
+        self.half_padding = self.padding // 2
         self.conv1 = DSDilatedConv1d(
             in_channels=in_channels,
             out_channels=res_channels,
@@ -90,19 +90,33 @@ class TCNBlock(nn.Module):
         self.bn2 = nn.BatchNorm1d(res_channels)
         self.relu2 = nn.ReLU()
 
-    def forward(self, inputs: torch.Tensor):
-        outputs = self.relu1(self.bn1(self.conv1(inputs)))
-        outputs = self.bn2(self.conv2(outputs))
-        if self.causal:
-            inputs = inputs[:, :, self.receptive_fields:]
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            inputs(torch.Tensor): Input tensor (B, D, T)
+            cache(torch.Tensor): Input cache(B, D, self.padding)
+        Returns:
+            torch.Tensor(B, D, T): outputs
+            torch.Tensor(B, D, self.padding): new cache
+        """
+        if cache.size(0) == 0:
+            outputs = F.pad(inputs, (self.padding, 0), value=0.0)
         else:
-            inputs = inputs[:, :, self.
-                            half_receptive_fields:-self.half_receptive_fields]
+            outputs = torch.cat((cache, inputs), dim=2)
+        assert outputs.size(2) > self.padding
+        new_cache = outputs[:, :, -self.padding:]
+
+        outputs = self.relu1(self.bn1(self.conv1(outputs)))
+        outputs = self.bn2(self.conv2(outputs))
         if self.in_channels == self.res_channels:
             res_out = self.relu2(outputs + inputs)
         else:
             res_out = self.relu2(outputs)
-        return res_out
+        return res_out, new_cache
 
 
 class TCNStack(nn.Module):
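The rewritten TCNBlock.forward above replaces the old receptive-field trimming with explicit left-context handling: when the cache is empty the input is left-padded with self.padding zeros, otherwise the cache is prepended, and the last self.padding frames of the padded input are returned as the next cache. The toy below reproduces that pattern with a plain dilated nn.Conv1d instead of the wekws classes; it is a sketch only and omits the BatchNorm, ReLU and residual of the real block.

# Toy illustration of the caching pattern used in TCNBlock.forward above,
# written against a plain dilated Conv1d rather than the wekws classes.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
kernel_size, dilation = 5, 2
padding = dilation * (kernel_size - 1)      # left context needed per output frame
conv = nn.Conv1d(8, 8, kernel_size, dilation=dilation)

def forward_with_cache(x, cache):
    # x: (B, D, T); cache: (B, D, padding), or an empty tensor for the first chunk
    if cache.size(0) == 0:
        x = F.pad(x, (padding, 0), value=0.0)  # zeros stand in for missing history
    else:
        x = torch.cat((cache, x), dim=2)       # prepend the real history
    new_cache = x[:, :, -padding:]             # keep the newest frames for the next call
    return conv(x), new_cache

full = torch.randn(1, 8, 40)
ref = conv(F.pad(full, (padding, 0), value=0.0))   # non-streaming reference

cache = torch.zeros(0, 0, 0)
outs = []
for chunk in torch.split(full, 10, dim=2):         # stream in 10-frame chunks
    y, cache = forward_with_cache(chunk, cache)
    outs.append(y)
streamed = torch.cat(outs, dim=2)
print(torch.allclose(ref, streamed, atol=1e-6))    # expect True: chunked == full pass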
@@ -123,14 +137,13 @@ class TCNStack(nn.Module):
         self.kernel_size = kernel_size
         self.causal = causal
         self.res_blocks = self.stack_tcn_blocks()
-        self.receptive_fields = self.calculate_receptive_fields()
-        self.res_blocks = nn.Sequential(*self.res_blocks)
+        self.padding = self.calculate_padding()
 
-    def calculate_receptive_fields(self):
-        receptive_fields = 0
+    def calculate_padding(self):
+        padding = 0
         for block in self.res_blocks:
-            receptive_fields += block.receptive_fields
-        return receptive_fields
+            padding += block.padding
+        return padding
 
     def build_dilations(self):
         dilations = []
@@ -162,10 +175,24 @@ class TCNStack(nn.Module):
                 ))
         return res_blocks
 
-    def forward(self, inputs: torch.Tensor):
-        outputs = inputs
-        outputs = self.res_blocks(outputs)
-        return outputs
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        outputs = inputs  # (B, D, T)
+        out_caches = []
+        offset = 0
+        for block in self.res_blocks:
+            if in_cache.size(0) > 0:
+                c_in = in_cache[:, :, offset:offset + block.padding]
+            else:
+                c_in = torch.zeros(0, 0, 0)
+            outputs, c_out = block(outputs, c_in)
+            out_caches.append(c_out)
+            offset += block.padding
+        new_cache = torch.cat(out_caches, dim=2)
+        return outputs, new_cache
 
 
 class MDTC(nn.Module):
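TCNStack.forward above slices a width of block.padding out of the incoming cache for each block and concatenates the per-block caches along the time axis, so the stack cache width is the sum of dilation * (kernel_size - 1) over its blocks. A small bookkeeping sketch, assuming build_dilations yields powers of two (1, 2, 4, ...), which this diff does not show:

# Cache-width bookkeeping for one TCNStack. The powers-of-two dilations are an
# assumption about build_dilations, not something shown in this diff.
kernel_size = 5
stack_size = 4
dilations = [2 ** i for i in range(stack_size)]        # [1, 2, 4, 8]
paddings = [d * (kernel_size - 1) for d in dilations]  # [4, 8, 16, 32]

# Each block contributes a (B, D, padding) slice; the stack cache is their
# concatenation along dim=2, addressed by the running offset used in forward().
offsets, offset = [], 0
for p in paddings:
    offsets.append((offset, offset + p))
    offset += p
print(sum(paddings))   # 60 cached frames for this stack
print(offsets)         # [(0, 4), (4, 12), (12, 28), (28, 60)]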
@@ -190,6 +217,7 @@ class MDTC(nn.Module):
         super(MDTC, self).__init__()
         assert kernel_size % 2 == 1
         self.kernel_size = kernel_size
+        assert causal is True, "we now only support causal mdtc"
         self.causal = causal
         self.preprocessor = TCNBlock(in_channels,
                                      res_channels,
@@ -198,66 +226,65 @@ class MDTC(nn.Module):
                                      causal=causal)
         self.relu = nn.ReLU()
         self.blocks = nn.ModuleList()
-        self.receptive_fields = self.preprocessor.receptive_fields
+        self.padding = self.preprocessor.padding
         for i in range(stack_num):
             self.blocks.append(
                 TCNStack(res_channels, stack_size, 1, res_channels,
                          kernel_size, causal))
-            self.receptive_fields += self.blocks[-1].receptive_fields
-        self.half_receptive_fields = self.receptive_fields // 2
-        print('Receptive Fields: %d' % self.receptive_fields)
+            self.padding += self.blocks[-1].padding
+        self.half_padding = self.padding // 2
+        print('Receptive Fields: %d' % self.padding)
 
     def forward(
         self,
         x: torch.Tensor,
         in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.causal:
-            outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
-                            'constant')
-        else:
-            outputs = F.pad(
-                x,
-                (0, 0, self.half_receptive_fields, self.half_receptive_fields,
-                 0, 0),
-                'constant',
-            )
-        outputs = outputs.transpose(1, 2)
+        outputs = x.transpose(1, 2)  # (B, D, T)
         outputs_list = []
-        outputs = self.relu(self.preprocessor(outputs))
-        for block in self.blocks:
-            outputs = block(outputs)
-            outputs_list.append(outputs)
-
-        normalized_outputs = []
-        output_size = outputs_list[-1].shape[-1]
-        for x in outputs_list:
-            remove_length = x.shape[-1] - output_size
-            if self.causal and remove_length > 0:
-                normalized_outputs.append(x[:, :, remove_length:])
-            elif not self.causal and remove_length > 1:
-                half_remove_length = remove_length // 2
-                normalized_outputs.append(
-                    x[:, :, half_remove_length:-half_remove_length]
-                )
-            else:
-                normalized_outputs.append(x)
+        out_caches = []
+        offset = 0
+        if in_cache.size(0) > 0:
+            c_in = in_cache[:, :, offset:offset + self.preprocessor.padding]
+        else:
+            c_in = torch.zeros(0, 0, 0)
+        outputs, c_out = self.preprocessor(outputs, c_in)
+        outputs = self.relu(outputs)
+        out_caches.append(c_out)
+        offset += self.preprocessor.padding
+        for block in self.blocks:
+            if in_cache.size(0) > 0:
+                c_in = in_cache[:, :, offset:offset + block.padding]
+            else:
+                c_in = torch.zeros(0, 0, 0)
+            outputs, c_out = block(outputs, c_in)
+            outputs_list.append(outputs)
+            out_caches.append(c_out)
+            offset += block.padding
 
         outputs = torch.zeros_like(outputs_list[-1], dtype=outputs_list[-1].dtype)
-        for x in normalized_outputs:
+        for x in outputs_list:
             outputs += x
-        outputs = outputs.transpose(1, 2)
-        # TODO(Binbin Zhang): Fix cache
-        return outputs, in_cache
+        outputs = outputs.transpose(1, 2)  # (B, T, D)
+        new_cache = torch.cat(out_caches, dim=2)
+        return outputs, new_cache
 
 
 if __name__ == '__main__':
-    mdtc = MDTC(3, 4, 80, 64, 5, causal=True)
+    mdtc = MDTC(3, 4, 64, 64, 5, causal=True)
     print(mdtc)
 
     num_params = sum(p.numel() for p in mdtc.parameters())
     print('the number of model params: {}'.format(num_params))
-    x = torch.zeros(128, 200, 80)  # batch-size * time * dim
-    y, _ = mdtc(x)  # batch-size * time * dim
+    x = torch.randn(128, 200, 64)  # batch-size * time * dim
+    y, c = mdtc(x)
     print('input shape: {}'.format(x.shape))
     print('output shape: {}'.format(y.shape))
+    print('cache shape: {}'.format(c.shape))
+
+    print('########################################')
+    for _ in range(10):
+        y, c = mdtc(y, c)
+        print('output shape: {}'.format(y.shape))
+        print('cache shape: {}'.format(c.shape))
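Besides the feedback smoke test in __main__ above, the new cache is meant for chunk-by-chunk streaming over a single utterance: start from an empty cache and pass each returned cache into the next call. A sketch under the assumption that MDTC is importable from wekws.model.mdtc; adjust the path to wherever the class lives in the repo.

# Streaming sketch: process one utterance in fixed-size chunks, carrying the
# MDTC cache between calls. The import path is an assumption.
import torch
from wekws.model.mdtc import MDTC

mdtc = MDTC(3, 4, 64, 64, 5, causal=True)
mdtc.eval()                               # BatchNorm uses running stats while streaming

feats = torch.randn(1, 200, 64)           # (batch, time, feature_dim)
cache = torch.zeros(0, 0, 0)              # empty cache for the first chunk
outputs = []
with torch.no_grad():
    for chunk in torch.split(feats, 20, dim=1):   # 20-frame chunks
        y, cache = mdtc(chunk, cache)
        outputs.append(y)
streamed = torch.cat(outputs, dim=1)
print(streamed.shape)                     # (1, 200, 64)
print(cache.shape)                        # (1, 64, total cache width)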