fix export onnx model, backbone has no attribute 'padding'
This commit is contained in:
parent
f3d6a0a40e
commit
29d43985da
@ -49,8 +49,9 @@ def main():
|
|||||||
dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
|
dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
|
||||||
cache = torch.zeros(1,
|
cache = torch.zeros(1,
|
||||||
model.hdim,
|
model.hdim,
|
||||||
model.backbone.padding,
|
0,
|
||||||
dtype=torch.float)
|
dtype=torch.float)
|
||||||
|
|
||||||
torch.onnx.export(model, (dummy_input, cache),
|
torch.onnx.export(model, (dummy_input, cache),
|
||||||
args.onnx_model,
|
args.onnx_model,
|
||||||
input_names=['input', 'cache'],
|
input_names=['input', 'cache'],
|
||||||
@ -71,7 +72,7 @@ def main():
|
|||||||
meta = onnx_model.metadata_props.add()
|
meta = onnx_model.metadata_props.add()
|
||||||
meta.key, meta.value = 'cache_dim', str(model.hdim)
|
meta.key, meta.value = 'cache_dim', str(model.hdim)
|
||||||
meta = onnx_model.metadata_props.add()
|
meta = onnx_model.metadata_props.add()
|
||||||
meta.key, meta.value = 'cache_len', str(model.backbone.padding)
|
meta.key, meta.value = 'cache_len', str(0)
|
||||||
onnx.save(onnx_model, args.onnx_model)
|
onnx.save(onnx_model, args.onnx_model)
|
||||||
|
|
||||||
# Verify onnx precision
|
# Verify onnx precision
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user