From 5236d42800b5af8273aac08eea05b04948c3139a Mon Sep 17 00:00:00 2001 From: xiaohou Date: Tue, 30 Nov 2021 17:14:30 +0800 Subject: [PATCH] Update mdtc.py --- kws/model/mdtc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kws/model/mdtc.py b/kws/model/mdtc.py index 74d2535..e0c3778 100644 --- a/kws/model/mdtc.py +++ b/kws/model/mdtc.py @@ -227,9 +227,9 @@ class MDTC(nn.Module): output_size = outputs_list[-1].shape[-1] for x in outputs_list: remove_length = x.shape[-1] - output_size - if self.causal and (remove_length > 0): + if self.causal and remove_length > 0: normalized_outputs.append(x[:, :, remove_length:]) - elif (not self.causal) and (remove_length > 1): + elif not self.causal and remove_length > 1: half_remove_length = remove_length // 2 normalized_outputs.append( x[:, :, half_remove_length:-half_remove_length] @@ -245,7 +245,7 @@ class MDTC(nn.Module): if __name__ == '__main__': - mdtc = MDTC(3, 4, 80, 64, 3, causal=True) + mdtc = MDTC(3, 4, 80, 64, 5, causal=True) print(mdtc) num_params = sum(p.numel() for p in mdtc.parameters())