handle the exception that remove_length==1 and self.causal is false
This commit is contained in:
parent
ba6919baaf
commit
262ca57133
@ -249,11 +249,12 @@ class MDTC(nn.Module):
|
|||||||
output_size = outputs_list[-1].shape[-1]
|
output_size = outputs_list[-1].shape[-1]
|
||||||
for x in outputs_list:
|
for x in outputs_list:
|
||||||
remove_length = x.shape[-1] - output_size
|
remove_length = x.shape[-1] - output_size
|
||||||
|
if not self.causal:
|
||||||
|
remove_length = remove_length // 2
|
||||||
if remove_length != 0:
|
if remove_length != 0:
|
||||||
if self.causal:
|
if self.causal:
|
||||||
normalized_outputs.append(x[:, :, remove_length:])
|
normalized_outputs.append(x[:, :, remove_length:])
|
||||||
else:
|
else:
|
||||||
remove_length = remove_length // 2
|
|
||||||
normalized_outputs.append(x[:, :,
|
normalized_outputs.append(x[:, :,
|
||||||
remove_length:-remove_length])
|
remove_length:-remove_length])
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user