diff --git a/kws/model/mdtc.py b/kws/model/mdtc.py index a50d401..6bdf639 100644 --- a/kws/model/mdtc.py +++ b/kws/model/mdtc.py @@ -187,6 +187,9 @@ class MDTC(nn.Module): ): super(MDTC, self).__init__() self.kernel_size = kernel_size + if kernel_size % 2 == 0: + print("The kernel size of MDTC must be an odd number") + exit(1) self.causal = causal self.preprocessor = TCNBlock(in_channels, res_channels, @@ -204,29 +207,6 @@ class MDTC(nn.Module): self.half_receptive_fields = self.receptive_fields // 2 print('Receptive Fields: %d' % self.receptive_fields) - def normalize_length_causal(self, skip_connections: list): - output_size = skip_connections[-1].shape[-1] - normalized_outputs = [] - for x in skip_connections: - remove_length = x.shape[-1] - output_size - if remove_length != 0: - normalized_outputs.append(x[:, :, remove_length:]) - else: - normalized_outputs.append(x) - return normalized_outputs - - def normalize_length(self, skip_connections: list): - output_size = skip_connections[-1].shape[-1] - normalized_outputs = [] - for x in skip_connections: - remove_length = (x.shape[-1] - output_size) // 2 - if remove_length != 0: - normalized_outputs.append(x[:, :, - remove_length:-remove_length]) - else: - normalized_outputs.append(x) - return normalized_outputs - def forward(self, x: torch.Tensor): if self.causal: outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0), @@ -249,13 +229,13 @@ class MDTC(nn.Module): output_size = outputs_list[-1].shape[-1] for x in outputs_list: remove_length = x.shape[-1] - output_size - if remove_length != 0: - if self.causal: - normalized_outputs.append(x[:, :, remove_length:]) - else: - remove_length = remove_length // 2 - normalized_outputs.append(x[:, :, - remove_length:-remove_length]) + if self.causal and (remove_length > 0): + normalized_outputs.append(x[:, :, remove_length:]) + elif (not self.causal) and (remove_length > 1): + half_remove_length = remove_length // 2 + normalized_outputs.append( + x[:, :, half_remove_length:-half_remove_length] + ) else: normalized_outputs.append(x) @@ -267,7 +247,7 @@ class MDTC(nn.Module): if __name__ == '__main__': - mdtc = MDTC(3, 4, 80, 64, 5, causal=True) + mdtc = MDTC(3, 4, 80, 64, 3, causal=True) print(mdtc) num_params = sum(p.numel() for p in mdtc.parameters())