Merge pull request #17 from wenet-e2e/dev_jingyonghou

[network] update mdtc.py to prevent possible errors and remove useless functions
This commit is contained in:
Binbin Zhang 2021-11-30 17:20:09 +08:00 committed by GitHub
commit 1909fcd360
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -186,6 +186,7 @@ class MDTC(nn.Module):
causal: bool, causal: bool,
): ):
super(MDTC, self).__init__() super(MDTC, self).__init__()
assert kernel_size % 2 == 0
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.causal = causal self.causal = causal
self.preprocessor = TCNBlock(in_channels, self.preprocessor = TCNBlock(in_channels,
@ -204,29 +205,6 @@ class MDTC(nn.Module):
self.half_receptive_fields = self.receptive_fields // 2 self.half_receptive_fields = self.receptive_fields // 2
print('Receptive Fields: %d' % self.receptive_fields) print('Receptive Fields: %d' % self.receptive_fields)
def normalize_length_causal(self, skip_connections: list):
output_size = skip_connections[-1].shape[-1]
normalized_outputs = []
for x in skip_connections:
remove_length = x.shape[-1] - output_size
if remove_length != 0:
normalized_outputs.append(x[:, :, remove_length:])
else:
normalized_outputs.append(x)
return normalized_outputs
def normalize_length(self, skip_connections: list):
output_size = skip_connections[-1].shape[-1]
normalized_outputs = []
for x in skip_connections:
remove_length = (x.shape[-1] - output_size) // 2
if remove_length != 0:
normalized_outputs.append(x[:, :,
remove_length:-remove_length])
else:
normalized_outputs.append(x)
return normalized_outputs
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
if self.causal: if self.causal:
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0), outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
@ -249,13 +227,13 @@ class MDTC(nn.Module):
output_size = outputs_list[-1].shape[-1] output_size = outputs_list[-1].shape[-1]
for x in outputs_list: for x in outputs_list:
remove_length = x.shape[-1] - output_size remove_length = x.shape[-1] - output_size
if remove_length != 0: if self.causal and remove_length > 0:
if self.causal:
normalized_outputs.append(x[:, :, remove_length:]) normalized_outputs.append(x[:, :, remove_length:])
else: elif not self.causal and remove_length > 1:
remove_length = remove_length // 2 half_remove_length = remove_length // 2
normalized_outputs.append(x[:, :, normalized_outputs.append(
remove_length:-remove_length]) x[:, :, half_remove_length:-half_remove_length]
)
else: else:
normalized_outputs.append(x) normalized_outputs.append(x)