fix export in export_onnx (#71)
This commit is contained in:
parent
663a31d9ea
commit
41a3432198
@ -156,6 +156,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
--jit_model $dir/$jit_model
|
||||
python kws/bin/export_onnx.py \
|
||||
--config $dir/config.yaml \
|
||||
--jit_model $dir/$jit_model \
|
||||
--checkpoint $score_checkpoint \
|
||||
--onnx_model $dir/$onnx_model
|
||||
fi
|
||||
@ -18,6 +18,9 @@ import torch
|
||||
import yaml
|
||||
import onnxruntime as ort
|
||||
|
||||
from kws.model.kws_model import init_model
|
||||
from kws.utils.checkpoint import load_checkpoint
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='export to onnx model')
|
||||
@ -28,6 +31,7 @@ def get_args():
|
||||
parser.add_argument('--onnx_model',
|
||||
required=True,
|
||||
help='output onnx model')
|
||||
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@ -37,8 +41,11 @@ def main():
|
||||
with open(args.config, 'r') as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
feature_dim = configs['model']['input_dim']
|
||||
model = torch.jit.load(args.jit_model)
|
||||
model = init_model(configs['model'])
|
||||
print(model)
|
||||
|
||||
load_checkpoint(model, args.checkpoint)
|
||||
model.eval()
|
||||
# dummy_input: (batch, time, feature_dim)
|
||||
dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
|
||||
torch.onnx.export(model,
|
||||
@ -48,7 +55,8 @@ def main():
|
||||
output_names=['output'],
|
||||
dynamic_axes={'input': {
|
||||
1: 'T'
|
||||
}})
|
||||
}},
|
||||
opset_version=10)
|
||||
|
||||
torch_output = model(dummy_input)
|
||||
ort_sess = ort.InferenceSession(args.onnx_model)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user