From 515222c83e05c0b147ca09eee74e31392135e7ae Mon Sep 17 00:00:00 2001 From: lvanchao1 Date: Fri, 1 Apr 2022 19:28:02 +0800 Subject: [PATCH] fix lint --- examples/hi_xiaowen/s0/conf/mdtc.yaml | 1 + examples/hi_xiaowen/s0/conf/mdtc_small.yaml | 1 + examples/hi_xiaowen/s0/conf/tcn.yaml | 2 +- kws/model/kws_model.py | 5 +++-- kws/model/tcn.py | 8 ++++---- 5 files changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/hi_xiaowen/s0/conf/mdtc.yaml b/examples/hi_xiaowen/s0/conf/mdtc.yaml index 4ec5f7e..37f20da 100644 --- a/examples/hi_xiaowen/s0/conf/mdtc.yaml +++ b/examples/hi_xiaowen/s0/conf/mdtc.yaml @@ -35,6 +35,7 @@ model: stack_size: 4 kernel_size: 5 hidden_dim: 64 + causal: true optim: adam optim_conf: diff --git a/examples/hi_xiaowen/s0/conf/mdtc_small.yaml b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml index a602a41..6bcffff 100644 --- a/examples/hi_xiaowen/s0/conf/mdtc_small.yaml +++ b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml @@ -30,6 +30,7 @@ model: preprocessing: type: none backbone: + causal: true type: mdtc num_stack: 3 stack_size: 4 diff --git a/examples/hi_xiaowen/s0/conf/tcn.yaml b/examples/hi_xiaowen/s0/conf/tcn.yaml index 35ff1c0..2cbc046 100644 --- a/examples/hi_xiaowen/s0/conf/tcn.yaml +++ b/examples/hi_xiaowen/s0/conf/tcn.yaml @@ -41,6 +41,6 @@ optim_conf: training_config: grad_clip: 5 - max_epoch: 80 + max_epoch: 2 log_interval: 10 diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index 749e181..61f9e70 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -70,7 +70,7 @@ class KWSModel(nn.Module): return x else: stride = self.streaming_chunk - num_frames = x.size(1) #(B, T, D) + num_frames = x.size(1) cache: Optional[torch.Tensor] = None out = [] for cur in range(0, num_frames, stride): @@ -180,5 +180,6 @@ def init_model(configs): activation = nn.Sigmoid() kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, - preprocessing, backbone, classifier, activation, streaming_chunk) + preprocessing, backbone, classifier, activation, + streaming_chunk) return kws_model diff --git a/kws/model/tcn.py b/kws/model/tcn.py index a3017cf..d2a2a89 100644 --- a/kws/model/tcn.py +++ b/kws/model/tcn.py @@ -143,7 +143,7 @@ class TCN(nn.Module): out_caches.append(c) x = x.transpose(1, 2) # (B, T, D) new_cache = torch.cat(out_caches, dim=2) - return x, new_cache + return x, new_cache def forward_chunk(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): """ @@ -160,11 +160,11 @@ class TCN(nn.Module): if cache is None: cache = torch.zeros(x.size(0), x.size(1), self.padding) for block in self.network: - cur_padding += block.padding + cur_padding += block.padding x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding]) - out_caches.append(c) # c (B, D, T) + out_caches.append(c) x = x.transpose(1, 2) # (B, T, D) - new_cache = torch.cat(out_caches, dim=2) + new_cache = torch.cat(out_caches, dim=2) return x, new_cache def fuse_modules(self):