From 57021924cb0897e74d0ce8f46874e02dcc01c3c9 Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Sat, 15 Jan 2022 13:50:34 +0800 Subject: [PATCH] [kws] support onnx export (#53) --- examples/hi_xiaowen/s0/run.sh | 14 ++++++++ kws/bin/export_jit.py | 19 ++-------- kws/bin/export_onnx.py | 65 +++++++++++++++++++++++++++++++++++ requirements.txt | 3 +- 4 files changed, 84 insertions(+), 17 deletions(-) create mode 100644 kws/bin/export_onnx.py diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 5a5f074..a735251 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -144,3 +144,17 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then done fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') + onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g') + python kws/bin/export_jit.py \ + --config $dir/config.yaml \ + --checkpoint $score_checkpoint \ + --jit_model $dir/$jit_model + python kws/bin/export_onnx.py \ + --config $dir/config.yaml \ + --jit_model $dir/$jit_model \ + --onnx_model $dir/$onnx_model +fi + diff --git a/kws/bin/export_jit.py b/kws/bin/export_jit.py index 81f48f9..f127377 100644 --- a/kws/bin/export_jit.py +++ b/kws/bin/export_jit.py @@ -28,10 +28,7 @@ def get_args(): parser = argparse.ArgumentParser(description='export your script model') parser.add_argument('--config', required=True, help='config file') parser.add_argument('--checkpoint', required=True, help='checkpoint model') - parser.add_argument('--output_file', required=True, help='output file') - parser.add_argument('--output_quant_file', - default=None, - help='output quantized model file') + parser.add_argument('--jit_model', required=True, help='output jit model') args = parser.parse_args() return args @@ -50,18 +47,8 @@ def main(): # Export jit torch script model script_model = torch.jit.script(model) - script_model.save(args.output_file) - print('Export model successfully, see {}'.format(args.output_file)) - - # Export quantized jit torch script model - if args.output_quant_file: - quantized_model = torch.quantization.quantize_dynamic( - model, {torch.nn.Linear}, dtype=torch.qint8) - print(quantized_model) - script_quant_model = torch.jit.script(quantized_model) - script_quant_model.save(args.output_quant_file) - print('Export quantized model successfully, ' - 'see {}'.format(args.output_quant_file)) + script_model.save(args.jit_model) + print('Export model successfully, see {}'.format(args.jit_model)) if __name__ == '__main__': diff --git a/kws/bin/export_onnx.py b/kws/bin/export_onnx.py new file mode 100644 index 0000000..e261e4b --- /dev/null +++ b/kws/bin/export_onnx.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +import yaml +import onnxruntime as ort + + +def get_args(): + parser = argparse.ArgumentParser(description='export to onnx model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--jit_model', + required=True, + help='pytorch jit script model') + parser.add_argument('--onnx_model', + required=True, + help='output onnx model') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + feature_dim = configs['model']['input_dim'] + model = torch.jit.load(args.jit_model) + print(model) + # dummy_input: (batch, time, feature_dim) + dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float) + torch.onnx.export(model, + dummy_input, + args.onnx_model, + input_names=['input'], + output_names=['output'], + dynamic_axes={'input': { + 1: 'T' + }}) + + torch_output = model(dummy_input) + ort_sess = ort.InferenceSession(args.onnx_model) + onnx_input = dummy_input.numpy() + onnx_output = ort_sess.run(None, {'input': onnx_input}) + if torch.allclose(torch_output, torch.tensor(onnx_output[0])): + print('Export to onnx succeed!') + else: + print('''Export to onnx succeed, but pytorch/onnx have different + outputs when given the same input, please check!!!''') + + +if __name__ == '__main__': + main() diff --git a/requirements.txt b/requirements.txt index cc3f45f..c9e45a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ flake8==3.8.2 pyyaml>=5.1 tensorboard tensorboardX -matplotlib \ No newline at end of file +matplotlib +onnxruntime