diff --git a/wekws/bin/export_onnx.py b/wekws/bin/export_onnx.py index 75eb5b0..8121830 100644 --- a/wekws/bin/export_onnx.py +++ b/wekws/bin/export_onnx.py @@ -49,8 +49,9 @@ def main(): dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float) cache = torch.zeros(1, model.hdim, - model.backbone.padding, + 0, dtype=torch.float) + torch.onnx.export(model, (dummy_input, cache), args.onnx_model, input_names=['input', 'cache'], @@ -71,7 +72,7 @@ def main(): meta = onnx_model.metadata_props.add() meta.key, meta.value = 'cache_dim', str(model.hdim) meta = onnx_model.metadata_props.add() - meta.key, meta.value = 'cache_len', str(model.backbone.padding) + meta.key, meta.value = 'cache_len', str(0) onnx.save(onnx_model, args.onnx_model) # Verify onnx precision