From 262ca571333b82689dfee39c7381bef23d678488 Mon Sep 17 00:00:00 2001 From: hp Date: Mon, 29 Nov 2021 13:45:56 +0800 Subject: [PATCH] handle the exception that remove_length==1 and self.causal is false --- kws/model/mdtc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kws/model/mdtc.py b/kws/model/mdtc.py index a50d401..36e46f1 100644 --- a/kws/model/mdtc.py +++ b/kws/model/mdtc.py @@ -249,11 +249,12 @@ class MDTC(nn.Module): output_size = outputs_list[-1].shape[-1] for x in outputs_list: remove_length = x.shape[-1] - output_size + if not self.causal: + remove_length = remove_length // 2 if remove_length != 0: if self.causal: normalized_outputs.append(x[:, :, remove_length:]) else: - remove_length = remove_length // 2 normalized_outputs.append(x[:, :, remove_length:-remove_length]) else: