fix export in export_onnx (#71)
This commit is contained in:
parent
663a31d9ea
commit
41a3432198
@ -156,6 +156,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|||||||
--jit_model $dir/$jit_model
|
--jit_model $dir/$jit_model
|
||||||
python kws/bin/export_onnx.py \
|
python kws/bin/export_onnx.py \
|
||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--jit_model $dir/$jit_model \
|
--checkpoint $score_checkpoint \
|
||||||
--onnx_model $dir/$onnx_model
|
--onnx_model $dir/$onnx_model
|
||||||
fi
|
fi
|
||||||
@ -18,6 +18,9 @@ import torch
|
|||||||
import yaml
|
import yaml
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
from kws.model.kws_model import init_model
|
||||||
|
from kws.utils.checkpoint import load_checkpoint
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser(description='export to onnx model')
|
parser = argparse.ArgumentParser(description='export to onnx model')
|
||||||
@ -28,6 +31,7 @@ def get_args():
|
|||||||
parser.add_argument('--onnx_model',
|
parser.add_argument('--onnx_model',
|
||||||
required=True,
|
required=True,
|
||||||
help='output onnx model')
|
help='output onnx model')
|
||||||
|
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@ -37,8 +41,11 @@ def main():
|
|||||||
with open(args.config, 'r') as fin:
|
with open(args.config, 'r') as fin:
|
||||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||||
feature_dim = configs['model']['input_dim']
|
feature_dim = configs['model']['input_dim']
|
||||||
model = torch.jit.load(args.jit_model)
|
model = init_model(configs['model'])
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
|
load_checkpoint(model, args.checkpoint)
|
||||||
|
model.eval()
|
||||||
# dummy_input: (batch, time, feature_dim)
|
# dummy_input: (batch, time, feature_dim)
|
||||||
dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
|
dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
|
||||||
torch.onnx.export(model,
|
torch.onnx.export(model,
|
||||||
@ -48,7 +55,8 @@ def main():
|
|||||||
output_names=['output'],
|
output_names=['output'],
|
||||||
dynamic_axes={'input': {
|
dynamic_axes={'input': {
|
||||||
1: 'T'
|
1: 'T'
|
||||||
}})
|
}},
|
||||||
|
opset_version=10)
|
||||||
|
|
||||||
torch_output = model(dummy_input)
|
torch_output = model(dummy_input)
|
||||||
ort_sess = ort.InferenceSession(args.onnx_model)
|
ort_sess = ort.InferenceSession(args.onnx_model)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user