From 5037d51ed99dd25c97c24cded99683519a81e7a9 Mon Sep 17 00:00:00 2001
From: Binbin Zhang
Date: Sat, 27 Aug 2022 16:44:22 +0800
Subject: [PATCH] [wekws] add cache support (#78)

---
 wekws/bin/export_onnx.py | 52 ++++++++++++++++++++++++++++------------
 wekws/bin/score.py       |  2 +-
 wekws/model/kws_model.py | 10 +++++---
 wekws/model/tcn.py       | 31 ++++++++++++++++++------
 wekws/utils/executor.py  |  4 ++--
 5 files changed, 71 insertions(+), 28 deletions(-)

diff --git a/wekws/bin/export_onnx.py b/wekws/bin/export_onnx.py
index e752909..89bf002 100644
--- a/wekws/bin/export_onnx.py
+++ b/wekws/bin/export_onnx.py
@@ -16,6 +16,7 @@
 import argparse
 import torch
 import yaml
+import onnx
 import onnxruntime as ort
 
 from wekws.model.kws_model import init_model
@@ -25,9 +26,6 @@ from wekws.utils.checkpoint import load_checkpoint
 def get_args():
     parser = argparse.ArgumentParser(description='export to onnx model')
     parser.add_argument('--config', required=True, help='config file')
-    parser.add_argument('--jit_model',
-                        required=True,
-                        help='pytorch jit script model')
     parser.add_argument('--onnx_model',
                         required=True,
                         help='output onnx model')
@@ -48,21 +46,45 @@
     model.eval()
     # dummy_input: (batch, time, feature_dim)
     dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
-    torch.onnx.export(model,
-                      dummy_input,
+    cache = torch.zeros(1,
+                        model.hdim,
+                        model.backbone.padding,
+                        dtype=torch.float)
+    torch.onnx.export(model, (dummy_input, cache),
                       args.onnx_model,
-                      input_names=['input'],
-                      output_names=['output'],
-                      dynamic_axes={'input': {
-                          1: 'T'
-                      }},
-                      opset_version=10)
+                      input_names=['input', 'cache'],
+                      output_names=['output', 'r_cache'],
+                      dynamic_axes={
+                          'input': {
+                              1: 'T'
+                          },
+                          'output': {
+                              1: 'T'
+                          }},
+                      opset_version=13,
+                      verbose=False,
+                      do_constant_folding=True)
 
-    torch_output = model(dummy_input)
+    # Add hidden dim and cache size
+    onnx_model = onnx.load(args.onnx_model)
+    meta = onnx_model.metadata_props.add()
+    meta.key, meta.value = 'cache_dim', str(model.hdim)
+    meta = onnx_model.metadata_props.add()
+    meta.key, meta.value = 'cache_len', str(model.backbone.padding)
+    onnx.save(onnx_model, args.onnx_model)
+
+    # Verify onnx precision
+    torch_output = model(dummy_input, cache)
     ort_sess = ort.InferenceSession(args.onnx_model)
-    onnx_input = dummy_input.numpy()
-    onnx_output = ort_sess.run(None, {'input': onnx_input})
-    if torch.allclose(torch_output, torch.tensor(onnx_output[0])):
+    onnx_output = ort_sess.run(None, {
+        'input': dummy_input.numpy(),
+        'cache': cache.numpy()
+    })
+
+    if torch.allclose(torch_output[0],
+                      torch.tensor(onnx_output[0]), atol=1e-6) and \
+       torch.allclose(torch_output[1],
+                      torch.tensor(onnx_output[1]), atol=1e-6):
         print('Export to onnx succeed!')
     else:
         print('''Export to onnx succeed, but pytorch/onnx have different
diff --git a/wekws/bin/score.py b/wekws/bin/score.py
index f5d554f..b2c0483 100644
--- a/wekws/bin/score.py
+++ b/wekws/bin/score.py
@@ -109,7 +109,7 @@
             keys, feats, target, lengths = batch
             feats = feats.to(device)
             lengths = lengths.to(device)
-            logits = model(feats)
+            logits, _ = model(feats)
             num_keywords = logits.shape[2]
             logits = logits.cpu()
             for i in range(len(keys)):
diff --git a/wekws/model/kws_model.py b/wekws/model/kws_model.py
index 3ca5b39..7e44bb4 100644
--- a/wekws/model/kws_model.py
+++ b/wekws/model/kws_model.py
@@ -59,14 +59,18 @@ class KWSModel(nn.Module):
         self.classifier = classifier
         self.activation = activation
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.global_cmvn is not None:
             x = self.global_cmvn(x)
         x = self.preprocessing(x)
-        x, _ = self.backbone(x)
+        x, out_cache = self.backbone(x, in_cache)
         x = self.classifier(x)
         x = self.activation(x)
-        return x
+        return x, out_cache
 
     def fuse_modules(self):
         self.preprocessing.fuse_modules()
diff --git a/wekws/model/tcn.py b/wekws/model/tcn.py
index 84366b3..2c76ed4 100644
--- a/wekws/model/tcn.py
+++ b/wekws/model/tcn.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -31,15 +31,21 @@
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
 
-    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             x(torch.Tensor): Input tensor (B, D, T)
+            cache(torch.Tensor): Input cache(B, D, self.padding)
         Returns:
-            torch.Tensor(B, D, T)
+            torch.Tensor(B, D, T): output
+            torch.Tensor(B, D, self.padding): new cache
         """
         # The CNN used here is causal convolution
-        if cache is None:
+        if cache.size(0) == 0:
             y = F.pad(x, (self.padding, 0), value=0.0)
         else:
             y = torch.cat((cache, x), dim=2)
@@ -127,10 +133,15 @@
             self.network.append(
                 block_class(channel, kernel_size, dilation, dropout))
 
-    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             x (torch.Tensor): Input tensor (B, T, D)
+            in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
 
         Returns:
             torch.Tensor(B, T, D)
@@ -138,9 +149,15 @@
         """
         x = x.transpose(1, 2)  # (B, D, T)
         out_caches = []
+        offset = 0
         for block in self.network:
-            x, c = block(x)
-            out_caches.append(c)
+            if in_cache.size(0) > 0:
+                c_in = in_cache[:, :, offset:offset + block.padding]
+            else:
+                c_in = torch.zeros(0, 0, 0)
+            x, c_out = block(x, c_in)
+            out_caches.append(c_out)
+            offset += block.padding
         x = x.transpose(1, 2)  # (B, T, D)
         new_cache = torch.cat(out_caches, dim=2)
         return x, new_cache
diff --git a/wekws/utils/executor.py b/wekws/utils/executor.py
index 6740269..646884e 100644
--- a/wekws/utils/executor.py
+++ b/wekws/utils/executor.py
@@ -43,7 +43,7 @@
             num_utts = feats_lengths.size(0)
             if num_utts == 0:
                 continue
-            logits = model(feats)
+            logits, _ = model(feats)
             loss_type = args.get('criterion', 'max_pooling')
             loss, acc = criterion(loss_type, logits, target, feats_lengths,
                                   min_duration)
@@ -76,7 +76,7 @@
             num_utts = feats_lengths.size(0)
             if num_utts == 0:
                 continue
-            logits = model(feats)
+            logits, _ = model(feats)
             loss, acc = criterion(args.get('criterion', 'max_pooling'),
                                   logits, target, feats_lengths)
             if torch.isfinite(loss):
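
Usage note: with this patch the exported ONNX graph takes a second input
'cache' of shape (1, cache_dim, cache_len) and returns the updated history as
'r_cache'; feeding it back in turns the TCN into a streaming scorer. Below is
a minimal chunk-by-chunk driver sketch with onnxruntime; the model path
'kws.onnx', the 40-dim random stand-in features, and the 10-frame chunk size
are placeholder assumptions, not part of the patch.

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession('kws.onnx')  # path is a placeholder
    meta = session.get_modelmeta().custom_metadata_map
    cache_dim = int(meta['cache_dim'])  # hidden dim stored by export_onnx.py
    cache_len = int(meta['cache_len'])  # sum of the per-block paddings

    # An all-zero cache is equivalent to the zero left-padding the causal
    # convolutions use when there is no history yet.
    cache = np.zeros((1, cache_dim, cache_len), dtype=np.float32)
    feats = np.random.randn(1, 100, 40).astype(np.float32)  # stand-in feats

    for t in range(0, feats.shape[1], 10):  # 10-frame chunks
        chunk = feats[:, t:t + 10, :]
        # r_cache carries the convolution history into the next chunk, so
        # chunked scoring matches a single full-utterance forward pass.
        posterior, cache = session.run(None, {'input': chunk,
                                              'cache': cache})

Note that TCN.forward concatenates each block's slice of the cache along
dim 2, which is why 'cache_len' is the accumulated padding of all blocks
rather than a per-layer value.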