fix export in export_onnx (#71)

This commit is contained in:
ryoha 2022-05-29 10:30:53 +09:00 committed by GitHub
parent 663a31d9ea
commit 41a3432198
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 3 deletions

View File

@ -156,6 +156,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
--jit_model $dir/$jit_model --jit_model $dir/$jit_model
python kws/bin/export_onnx.py \ python kws/bin/export_onnx.py \
--config $dir/config.yaml \ --config $dir/config.yaml \
--jit_model $dir/$jit_model \ --checkpoint $score_checkpoint \
--onnx_model $dir/$onnx_model --onnx_model $dir/$onnx_model
fi fi

View File

@ -18,6 +18,9 @@ import torch
import yaml import yaml
import onnxruntime as ort import onnxruntime as ort
from kws.model.kws_model import init_model
from kws.utils.checkpoint import load_checkpoint
def get_args(): def get_args():
parser = argparse.ArgumentParser(description='export to onnx model') parser = argparse.ArgumentParser(description='export to onnx model')
@ -28,6 +31,7 @@ def get_args():
parser.add_argument('--onnx_model', parser.add_argument('--onnx_model',
required=True, required=True,
help='output onnx model') help='output onnx model')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
args = parser.parse_args() args = parser.parse_args()
return args return args
@ -37,8 +41,11 @@ def main():
with open(args.config, 'r') as fin: with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader) configs = yaml.load(fin, Loader=yaml.FullLoader)
feature_dim = configs['model']['input_dim'] feature_dim = configs['model']['input_dim']
model = torch.jit.load(args.jit_model) model = init_model(configs['model'])
print(model) print(model)
load_checkpoint(model, args.checkpoint)
model.eval()
# dummy_input: (batch, time, feature_dim) # dummy_input: (batch, time, feature_dim)
dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float) dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
torch.onnx.export(model, torch.onnx.export(model,
@ -48,7 +55,8 @@ def main():
output_names=['output'], output_names=['output'],
dynamic_axes={'input': { dynamic_axes={'input': {
1: 'T' 1: 'T'
}}) }},
opset_version=10)
torch_output = model(dummy_input) torch_output = model(dummy_input)
ort_sess = ort.InferenceSession(args.onnx_model) ort_sess = ort.InferenceSession(args.onnx_model)