Merge pull request #17 from wenet-e2e/dev_jingyonghou
[network] update mdtc.py to prevent possible errors and remove useless functions
This commit is contained in:
commit
1909fcd360
@ -186,6 +186,7 @@ class MDTC(nn.Module):
|
|||||||
causal: bool,
|
causal: bool,
|
||||||
):
|
):
|
||||||
super(MDTC, self).__init__()
|
super(MDTC, self).__init__()
|
||||||
|
assert kernel_size % 2 == 0
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.causal = causal
|
self.causal = causal
|
||||||
self.preprocessor = TCNBlock(in_channels,
|
self.preprocessor = TCNBlock(in_channels,
|
||||||
@ -204,29 +205,6 @@ class MDTC(nn.Module):
|
|||||||
self.half_receptive_fields = self.receptive_fields // 2
|
self.half_receptive_fields = self.receptive_fields // 2
|
||||||
print('Receptive Fields: %d' % self.receptive_fields)
|
print('Receptive Fields: %d' % self.receptive_fields)
|
||||||
|
|
||||||
def normalize_length_causal(self, skip_connections: list):
|
|
||||||
output_size = skip_connections[-1].shape[-1]
|
|
||||||
normalized_outputs = []
|
|
||||||
for x in skip_connections:
|
|
||||||
remove_length = x.shape[-1] - output_size
|
|
||||||
if remove_length != 0:
|
|
||||||
normalized_outputs.append(x[:, :, remove_length:])
|
|
||||||
else:
|
|
||||||
normalized_outputs.append(x)
|
|
||||||
return normalized_outputs
|
|
||||||
|
|
||||||
def normalize_length(self, skip_connections: list):
|
|
||||||
output_size = skip_connections[-1].shape[-1]
|
|
||||||
normalized_outputs = []
|
|
||||||
for x in skip_connections:
|
|
||||||
remove_length = (x.shape[-1] - output_size) // 2
|
|
||||||
if remove_length != 0:
|
|
||||||
normalized_outputs.append(x[:, :,
|
|
||||||
remove_length:-remove_length])
|
|
||||||
else:
|
|
||||||
normalized_outputs.append(x)
|
|
||||||
return normalized_outputs
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
if self.causal:
|
if self.causal:
|
||||||
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
|
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
|
||||||
@ -249,13 +227,13 @@ class MDTC(nn.Module):
|
|||||||
output_size = outputs_list[-1].shape[-1]
|
output_size = outputs_list[-1].shape[-1]
|
||||||
for x in outputs_list:
|
for x in outputs_list:
|
||||||
remove_length = x.shape[-1] - output_size
|
remove_length = x.shape[-1] - output_size
|
||||||
if remove_length != 0:
|
if self.causal and remove_length > 0:
|
||||||
if self.causal:
|
normalized_outputs.append(x[:, :, remove_length:])
|
||||||
normalized_outputs.append(x[:, :, remove_length:])
|
elif not self.causal and remove_length > 1:
|
||||||
else:
|
half_remove_length = remove_length // 2
|
||||||
remove_length = remove_length // 2
|
normalized_outputs.append(
|
||||||
normalized_outputs.append(x[:, :,
|
x[:, :, half_remove_length:-half_remove_length]
|
||||||
remove_length:-remove_length])
|
)
|
||||||
else:
|
else:
|
||||||
normalized_outputs.append(x)
|
normalized_outputs.append(x)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user