handle the exception that remove_length==1 and self.causal is false

This commit is contained in:
hp 2021-11-29 13:45:56 +08:00
parent ba6919baaf
commit 262ca57133

View File

@ -249,11 +249,12 @@ class MDTC(nn.Module):
output_size = outputs_list[-1].shape[-1] output_size = outputs_list[-1].shape[-1]
for x in outputs_list: for x in outputs_list:
remove_length = x.shape[-1] - output_size remove_length = x.shape[-1] - output_size
if not self.causal:
remove_length = remove_length // 2
if remove_length != 0: if remove_length != 0:
if self.causal: if self.causal:
normalized_outputs.append(x[:, :, remove_length:]) normalized_outputs.append(x[:, :, remove_length:])
else: else:
remove_length = remove_length // 2
normalized_outputs.append(x[:, :, normalized_outputs.append(x[:, :,
remove_length:-remove_length]) remove_length:-remove_length])
else: else: