From 1628f8c2219258380f0809dc00b410f154198017 Mon Sep 17 00:00:00 2001 From: lvanchao1 Date: Fri, 1 Apr 2022 18:12:29 +0800 Subject: [PATCH] add streaming inference at tcn and ds_tcn model --- examples/hi_xiaowen/s0/run.sh | 8 +++++--- kws/bin/score.py | 6 ++++++ kws/model/kws_model.py | 35 +++++++++++++++++++++++++++++------ kws/model/tcn.py | 22 ++++++++++++++++++++++ 4 files changed, 62 insertions(+), 9 deletions(-) diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 3c964c5..d5d8438 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -101,7 +101,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --batch_size 256 \ --checkpoint $score_checkpoint \ --score_file $result_dir/score.txt \ - --num_workers 8 + --num_workers 8 \ + --streaming_chunk 20 for keyword in 0 1; do python kws/bin/compute_det.py \ @@ -136,7 +137,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --jit_model \ --checkpoint $dir/$quantize_score_checkpoint \ --score_file $result_dir/score.txt \ - --num_workers 8 + --num_workers 8 \ + --streaming_chunk 20 for keyword in 0 1; do python kws/bin/compute_det.py \ --keyword $keyword \ @@ -158,4 +160,4 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --config $dir/config.yaml \ --jit_model $dir/$jit_model \ --onnx_model $dir/$onnx_model -fi \ No newline at end of file +fi diff --git a/kws/bin/score.py b/kws/bin/score.py index a894704..bb12455 100644 --- a/kws/bin/score.py +++ b/kws/bin/score.py @@ -55,6 +55,10 @@ def get_args(): default=100, type=int, help='prefetch number') + parser.add_argument('--streaming_chunk', + default=-1, + type=int, + help='streaming chunk size for inference') parser.add_argument('--score_file', required=True, help='output score file') @@ -84,6 +88,8 @@ def main(): test_conf['feature_extraction_conf']['dither'] = 0.0 test_conf['batch_conf']['batch_size'] = args.batch_size + configs['model']['streaming_chunk'] = args.streaming_chunk + test_dataset = Dataset(args.test_data, test_conf) test_data_loader = DataLoader(test_dataset, batch_size=None, diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index 0744524..749e181 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -46,6 +46,7 @@ class KWSModel(nn.Module): backbone: nn.Module, classifier: nn.Module, activation: nn.Module, + streaming_chunk: int, ): super().__init__() self.idim = idim @@ -56,15 +57,33 @@ class KWSModel(nn.Module): self.backbone = backbone self.classifier = classifier self.activation = activation + self.streaming_chunk = streaming_chunk def forward(self, x: torch.Tensor) -> torch.Tensor: if self.global_cmvn is not None: x = self.global_cmvn(x) - x = self.preprocessing(x) - x, _ = self.backbone(x) - x = self.classifier(x) - x = self.activation(x) - return x + if self.streaming_chunk <= 0: + x = self.preprocessing(x) + x, _ = self.backbone(x) + x = self.classifier(x) + x = self.activation(x) + return x + else: + stride = self.streaming_chunk + num_frames = x.size(1) #(B, T, D) + cache: Optional[torch.Tensor] = None + out = [] + for cur in range(0, num_frames, stride): + end = min(cur + stride, num_frames) + x_s = x[:, cur:end, :] + x_s = self.preprocessing(x_s) + x_s, cache = self.backbone.forward_chunk(x_s, cache) + x_s = self.classifier(x_s) + x_s = self.activation(x_s) + out.append(x_s) + out_ = torch.cat(out, dim=1) + # wait all num_frames out, then return + return out_ def fuse_modules(self): self.preprocessing.fuse_modules() @@ -87,6 +106,10 @@ def init_model(configs): output_dim = configs['output_dim'] hidden_dim = configs['hidden_dim'] + streaming_chunk = -1 + if ('streaming_chunk' in configs and configs['streaming_chunk'] > 0): + streaming_chunk = configs['streaming_chunk'] + prep_type = configs['preprocessing']['type'] if prep_type == 'linear': preprocessing = LinearSubsampling1(input_dim, hidden_dim) @@ -157,5 +180,5 @@ def init_model(configs): activation = nn.Sigmoid() kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, - preprocessing, backbone, classifier, activation) + preprocessing, backbone, classifier, activation, streaming_chunk) return kws_model diff --git a/kws/model/tcn.py b/kws/model/tcn.py index 84366b3..a3017cf 100644 --- a/kws/model/tcn.py +++ b/kws/model/tcn.py @@ -143,6 +143,28 @@ class TCN(nn.Module): out_caches.append(c) x = x.transpose(1, 2) # (B, T, D) new_cache = torch.cat(out_caches, dim=2) + return x, new_cache + + def forward_chunk(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + """ + Args: + x (torch.Tensor): Input tensor (B, T, D) + + Returns: + torch.Tensor(B, T, D) + torch.Tensor(B, D, C): C is the accumulated cache size + """ + x = x.transpose(1, 2) # (B, D, T) + out_caches = [] + cur_padding = 0 + if cache is None: + cache = torch.zeros(x.size(0), x.size(1), self.padding) + for block in self.network: + cur_padding += block.padding + x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding]) + out_caches.append(c) # c (B, D, T) + x = x.transpose(1, 2) # (B, T, D) + new_cache = torch.cat(out_caches, dim=2) return x, new_cache def fuse_modules(self):