[kws] support onnx export (#53)
This commit is contained in:
parent
665df6113e
commit
57021924cb
@ -144,3 +144,17 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
|
||||
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
|
||||
python kws/bin/export_jit.py \
|
||||
--config $dir/config.yaml \
|
||||
--checkpoint $score_checkpoint \
|
||||
--jit_model $dir/$jit_model
|
||||
python kws/bin/export_onnx.py \
|
||||
--config $dir/config.yaml \
|
||||
--jit_model $dir/$jit_model \
|
||||
--onnx_model $dir/$onnx_model
|
||||
fi
|
||||
|
||||
|
||||
@ -28,10 +28,7 @@ def get_args():
|
||||
parser = argparse.ArgumentParser(description='export your script model')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
|
||||
parser.add_argument('--output_file', required=True, help='output file')
|
||||
parser.add_argument('--output_quant_file',
|
||||
default=None,
|
||||
help='output quantized model file')
|
||||
parser.add_argument('--jit_model', required=True, help='output jit model')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@ -50,18 +47,8 @@ def main():
|
||||
# Export jit torch script model
|
||||
|
||||
script_model = torch.jit.script(model)
|
||||
script_model.save(args.output_file)
|
||||
print('Export model successfully, see {}'.format(args.output_file))
|
||||
|
||||
# Export quantized jit torch script model
|
||||
if args.output_quant_file:
|
||||
quantized_model = torch.quantization.quantize_dynamic(
|
||||
model, {torch.nn.Linear}, dtype=torch.qint8)
|
||||
print(quantized_model)
|
||||
script_quant_model = torch.jit.script(quantized_model)
|
||||
script_quant_model.save(args.output_quant_file)
|
||||
print('Export quantized model successfully, '
|
||||
'see {}'.format(args.output_quant_file))
|
||||
script_model.save(args.jit_model)
|
||||
print('Export model successfully, see {}'.format(args.jit_model))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
65
kws/bin/export_onnx.py
Normal file
65
kws/bin/export_onnx.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='export to onnx model')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--jit_model',
|
||||
required=True,
|
||||
help='pytorch jit script model')
|
||||
parser.add_argument('--onnx_model',
|
||||
required=True,
|
||||
help='output onnx model')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
with open(args.config, 'r') as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
feature_dim = configs['model']['input_dim']
|
||||
model = torch.jit.load(args.jit_model)
|
||||
print(model)
|
||||
# dummy_input: (batch, time, feature_dim)
|
||||
dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
|
||||
torch.onnx.export(model,
|
||||
dummy_input,
|
||||
args.onnx_model,
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
dynamic_axes={'input': {
|
||||
1: 'T'
|
||||
}})
|
||||
|
||||
torch_output = model(dummy_input)
|
||||
ort_sess = ort.InferenceSession(args.onnx_model)
|
||||
onnx_input = dummy_input.numpy()
|
||||
onnx_output = ort_sess.run(None, {'input': onnx_input})
|
||||
if torch.allclose(torch_output, torch.tensor(onnx_output[0])):
|
||||
print('Export to onnx succeed!')
|
||||
else:
|
||||
print('''Export to onnx succeed, but pytorch/onnx have different
|
||||
outputs when given the same input, please check!!!''')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -2,4 +2,5 @@ flake8==3.8.2
|
||||
pyyaml>=5.1
|
||||
tensorboard
|
||||
tensorboardX
|
||||
matplotlib
|
||||
matplotlib
|
||||
onnxruntime
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user