From 64ccd5bb860cc6d2eaa839d1286acc7ac3547d44 Mon Sep 17 00:00:00 2001 From: Menglong Xu <32296227+mlxu995@users.noreply.github.com> Date: Tue, 8 Nov 2022 09:18:18 +0800 Subject: [PATCH] [wekws] add cache support for mdtc (#105) * [wekws] add cache support for mdtc * format Co-authored-by: 02Bigboy <570843154@qq.com> --- examples/hey_snips/s0/conf/mdtc_small.yaml | 2 +- examples/hi_xiaowen/s0/conf/mdtc.yaml | 3 +- examples/hi_xiaowen/s0/conf/mdtc_small.yaml | 3 +- examples/speechcommand_v1/s0/conf/mdtc.yaml | 4 +- wekws/model/kws_model.py | 2 +- wekws/model/mdtc.py | 145 ++++++++++++-------- 6 files changed, 94 insertions(+), 65 deletions(-) diff --git a/examples/hey_snips/s0/conf/mdtc_small.yaml b/examples/hey_snips/s0/conf/mdtc_small.yaml index 06932b9..c178eb1 100644 --- a/examples/hey_snips/s0/conf/mdtc_small.yaml +++ b/examples/hey_snips/s0/conf/mdtc_small.yaml @@ -28,7 +28,7 @@ dataset_conf: model: hidden_dim: 32 preprocessing: - type: none + type: linear backbone: type: mdtc num_stack: 3 diff --git a/examples/hi_xiaowen/s0/conf/mdtc.yaml b/examples/hi_xiaowen/s0/conf/mdtc.yaml index 4ec5f7e..0e1b44a 100644 --- a/examples/hi_xiaowen/s0/conf/mdtc.yaml +++ b/examples/hi_xiaowen/s0/conf/mdtc.yaml @@ -28,13 +28,14 @@ dataset_conf: model: hidden_dim: 64 preprocessing: - type: none + type: linear backbone: type: mdtc num_stack: 4 stack_size: 4 kernel_size: 5 hidden_dim: 64 + causal: True optim: adam optim_conf: diff --git a/examples/hi_xiaowen/s0/conf/mdtc_small.yaml b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml index a602a41..936a5d5 100644 --- a/examples/hi_xiaowen/s0/conf/mdtc_small.yaml +++ b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml @@ -28,13 +28,14 @@ dataset_conf: model: hidden_dim: 32 preprocessing: - type: none + type: linear backbone: type: mdtc num_stack: 3 stack_size: 4 kernel_size: 5 hidden_dim: 32 + causal: True optim: adam optim_conf: diff --git a/examples/speechcommand_v1/s0/conf/mdtc.yaml b/examples/speechcommand_v1/s0/conf/mdtc.yaml index d9046cd..768bf19 100644 --- a/examples/speechcommand_v1/s0/conf/mdtc.yaml +++ b/examples/speechcommand_v1/s0/conf/mdtc.yaml @@ -28,14 +28,14 @@ dataset_conf: model: hidden_dim: 64 preprocessing: - type: none + type: linear backbone: type: mdtc num_stack: 4 stack_size: 4 kernel_size: 5 hidden_dim: 64 - causal: False + causal: True classifier: type: global dropout: 0.5 diff --git a/wekws/model/kws_model.py b/wekws/model/kws_model.py index 92b76ab..993cbd4 100644 --- a/wekws/model/kws_model.py +++ b/wekws/model/kws_model.py @@ -131,7 +131,7 @@ def init_model(configs): causal = configs['backbone']['causal'] backbone = MDTC(num_stack, stack_size, - input_dim, + hidden_dim, hidden_dim, kernel_size, causal=causal) diff --git a/wekws/model/mdtc.py b/wekws/model/mdtc.py index 385cd63..090dde5 100644 --- a/wekws/model/mdtc.py +++ b/wekws/model/mdtc.py @@ -32,7 +32,7 @@ class DSDilatedConv1d(nn.Module): bias: bool = True, ): super(DSDilatedConv1d, self).__init__() - self.receptive_fields = dilation * (kernel_size - 1) + self.padding = dilation * (kernel_size - 1) self.conv = nn.Conv1d( in_channels, in_channels, @@ -73,8 +73,8 @@ class TCNBlock(nn.Module): self.kernel_size = kernel_size self.dilation = dilation self.causal = causal - self.receptive_fields = dilation * (kernel_size - 1) - self.half_receptive_fields = self.receptive_fields // 2 + self.padding = dilation * (kernel_size - 1) + self.half_padding = self.padding // 2 self.conv1 = DSDilatedConv1d( in_channels=in_channels, out_channels=res_channels, @@ -90,19 +90,33 @@ class TCNBlock(nn.Module): self.bn2 = nn.BatchNorm1d(res_channels) self.relu2 = nn.ReLU() - def forward(self, inputs: torch.Tensor): - outputs = self.relu1(self.bn1(self.conv1(inputs))) - outputs = self.bn2(self.conv2(outputs)) - if self.causal: - inputs = inputs[:, :, self.receptive_fields:] + def forward( + self, + inputs: torch.Tensor, + cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + inputs(torch.Tensor): Input tensor (B, D, T) + cache(torch.Tensor): Input cache(B, D, self.padding) + Returns: + torch.Tensor(B, D, T): outputs + torch.Tensor(B, D, self.padding): new cache + """ + if cache.size(0) == 0: + outputs = F.pad(inputs, (self.padding, 0), value=0.0) else: - inputs = inputs[:, :, self. - half_receptive_fields:-self.half_receptive_fields] + outputs = torch.cat((cache, inputs), dim=2) + assert outputs.size(2) > self.padding + new_cache = outputs[:, :, -self.padding:] + + outputs = self.relu1(self.bn1(self.conv1(outputs))) + outputs = self.bn2(self.conv2(outputs)) if self.in_channels == self.res_channels: res_out = self.relu2(outputs + inputs) else: res_out = self.relu2(outputs) - return res_out + return res_out, new_cache class TCNStack(nn.Module): @@ -123,14 +137,13 @@ class TCNStack(nn.Module): self.kernel_size = kernel_size self.causal = causal self.res_blocks = self.stack_tcn_blocks() - self.receptive_fields = self.calculate_receptive_fields() - self.res_blocks = nn.Sequential(*self.res_blocks) + self.padding = self.calculate_padding() - def calculate_receptive_fields(self): - receptive_fields = 0 + def calculate_padding(self): + padding = 0 for block in self.res_blocks: - receptive_fields += block.receptive_fields - return receptive_fields + padding += block.padding + return padding def build_dilations(self): dilations = [] @@ -162,10 +175,24 @@ class TCNStack(nn.Module): )) return res_blocks - def forward(self, inputs: torch.Tensor): - outputs = inputs - outputs = self.res_blocks(outputs) - return outputs + def forward( + self, + inputs: torch.Tensor, + in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + ) -> Tuple[torch.Tensor, torch.Tensor]: + outputs = inputs # (B, D, T) + out_caches = [] + offset = 0 + for block in self.res_blocks: + if in_cache.size(0) > 0: + c_in = in_cache[:, :, offset:offset + block.padding] + else: + c_in = torch.zeros(0, 0, 0) + outputs, c_out = block(outputs, c_in) + out_caches.append(c_out) + offset += block.padding + new_cache = torch.cat(out_caches, dim=2) + return outputs, new_cache class MDTC(nn.Module): @@ -190,6 +217,7 @@ class MDTC(nn.Module): super(MDTC, self).__init__() assert kernel_size % 2 == 1 self.kernel_size = kernel_size + assert causal is True, "we now only support causal mdtc" self.causal = causal self.preprocessor = TCNBlock(in_channels, res_channels, @@ -198,66 +226,65 @@ class MDTC(nn.Module): causal=causal) self.relu = nn.ReLU() self.blocks = nn.ModuleList() - self.receptive_fields = self.preprocessor.receptive_fields + self.padding = self.preprocessor.padding for i in range(stack_num): self.blocks.append( TCNStack(res_channels, stack_size, 1, res_channels, kernel_size, causal)) - self.receptive_fields += self.blocks[-1].receptive_fields - self.half_receptive_fields = self.receptive_fields // 2 - print('Receptive Fields: %d' % self.receptive_fields) + self.padding += self.blocks[-1].padding + self.half_padding = self.padding // 2 + print('Receptive Fields: %d' % self.padding) def forward( self, x: torch.Tensor, in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) ) -> Tuple[torch.Tensor, torch.Tensor]: - if self.causal: - outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0), - 'constant') - else: - outputs = F.pad( - x, - (0, 0, self.half_receptive_fields, self.half_receptive_fields, - 0, 0), - 'constant', - ) - outputs = outputs.transpose(1, 2) + outputs = x.transpose(1, 2) # (B, D, T) outputs_list = [] - outputs = self.relu(self.preprocessor(outputs)) - for block in self.blocks: - outputs = block(outputs) - outputs_list.append(outputs) + out_caches = [] + offset = 0 + if in_cache.size(0) > 0: + c_in = in_cache[:, :, offset:offset + self.preprocessor.padding] + else: + c_in = torch.zeros(0, 0, 0) - normalized_outputs = [] - output_size = outputs_list[-1].shape[-1] - for x in outputs_list: - remove_length = x.shape[-1] - output_size - if self.causal and remove_length > 0: - normalized_outputs.append(x[:, :, remove_length:]) - elif not self.causal and remove_length > 1: - half_remove_length = remove_length // 2 - normalized_outputs.append( - x[:, :, half_remove_length:-half_remove_length] - ) + outputs, c_out = self.preprocessor(outputs, c_in) + outputs = self.relu(outputs) + out_caches.append(c_out) + offset += self.preprocessor.padding + for block in self.blocks: + if in_cache.size(0) > 0: + c_in = in_cache[:, :, offset:offset + block.padding] else: - normalized_outputs.append(x) + c_in = torch.zeros(0, 0, 0) + outputs, c_out = block(outputs, c_in) + outputs_list.append(outputs) + out_caches.append(c_out) + offset += block.padding outputs = torch.zeros_like(outputs_list[-1], dtype=outputs_list[-1].dtype) - for x in normalized_outputs: + for x in outputs_list: outputs += x - outputs = outputs.transpose(1, 2) - # TODO(Binbin Zhang): Fix cache - return outputs, in_cache + outputs = outputs.transpose(1, 2) # (B, T, D) + new_cache = torch.cat(out_caches, dim=2) + return outputs, new_cache if __name__ == '__main__': - mdtc = MDTC(3, 4, 80, 64, 5, causal=True) + mdtc = MDTC(3, 4, 64, 64, 5, causal=True) print(mdtc) num_params = sum(p.numel() for p in mdtc.parameters()) print('the number of model params: {}'.format(num_params)) - x = torch.zeros(128, 200, 80) # batch-size * time * dim - y, _ = mdtc(x) # batch-size * time * dim + x = torch.randn(128, 200, 64) # batch-size * time * dim + y, c = mdtc(x) print('input shape: {}'.format(x.shape)) print('output shape: {}'.format(y.shape)) + print('cache shape: {}'.format(c.shape)) + + print('########################################') + for _ in range(10): + y, c = mdtc(y, c) + print('output shape: {}'.format(y.shape)) + print('cache shape: {}'.format(c.shape))