From 41a3432198f600bbbb049de91f6766e7f63df06f Mon Sep 17 00:00:00 2001 From: ryoha <48485650+ryoha000@users.noreply.github.com> Date: Sun, 29 May 2022 10:30:53 +0900 Subject: [PATCH] fix export in export_onnx (#71) --- examples/hi_xiaowen/s0/run.sh | 2 +- kws/bin/export_onnx.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 3c964c5..14f1bed 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -156,6 +156,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then --jit_model $dir/$jit_model python kws/bin/export_onnx.py \ --config $dir/config.yaml \ - --jit_model $dir/$jit_model \ + --checkpoint $score_checkpoint \ --onnx_model $dir/$onnx_model fi \ No newline at end of file diff --git a/kws/bin/export_onnx.py b/kws/bin/export_onnx.py index e261e4b..bb9a29d 100644 --- a/kws/bin/export_onnx.py +++ b/kws/bin/export_onnx.py @@ -18,6 +18,9 @@ import torch import yaml import onnxruntime as ort +from kws.model.kws_model import init_model +from kws.utils.checkpoint import load_checkpoint + def get_args(): parser = argparse.ArgumentParser(description='export to onnx model') @@ -28,6 +31,7 @@ def get_args(): parser.add_argument('--onnx_model', required=True, help='output onnx model') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') args = parser.parse_args() return args @@ -37,8 +41,11 @@ def main(): with open(args.config, 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) feature_dim = configs['model']['input_dim'] - model = torch.jit.load(args.jit_model) + model = init_model(configs['model']) print(model) + + load_checkpoint(model, args.checkpoint) + model.eval() # dummy_input: (batch, time, feature_dim) dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float) torch.onnx.export(model, @@ -48,7 +55,8 @@ def main(): output_names=['output'], dynamic_axes={'input': { 1: 'T' - }}) + }}, + opset_version=10) torch_output = model(dummy_input) ort_sess = ort.InferenceSession(args.onnx_model)