handle the exception that remove_length==1 and self.causal is false
This commit is contained in:
parent
ba6919baaf
commit
262ca57133
@ -249,11 +249,12 @@ class MDTC(nn.Module):
|
||||
output_size = outputs_list[-1].shape[-1]
|
||||
for x in outputs_list:
|
||||
remove_length = x.shape[-1] - output_size
|
||||
if not self.causal:
|
||||
remove_length = remove_length // 2
|
||||
if remove_length != 0:
|
||||
if self.causal:
|
||||
normalized_outputs.append(x[:, :, remove_length:])
|
||||
else:
|
||||
remove_length = remove_length // 2
|
||||
normalized_outputs.append(x[:, :,
|
||||
remove_length:-remove_length])
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user