diff --git a/kws/model/mdtc.py b/kws/model/mdtc.py index a50d401..36e46f1 100644 --- a/kws/model/mdtc.py +++ b/kws/model/mdtc.py @@ -249,11 +249,12 @@ class MDTC(nn.Module): output_size = outputs_list[-1].shape[-1] for x in outputs_list: remove_length = x.shape[-1] - output_size + if not self.causal: + remove_length = remove_length // 2 if remove_length != 0: if self.causal: normalized_outputs.append(x[:, :, remove_length:]) else: - remove_length = remove_length // 2 normalized_outputs.append(x[:, :, remove_length:-remove_length]) else: