diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh
index 5254fa8..5a5f074 100755
--- a/examples/hi_xiaowen/s0/run.sh
+++ b/examples/hi_xiaowen/s0/run.sh
@@ -113,9 +113,37 @@ fi
 
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  python kws/bin/export_jit.py --config $dir/config.yaml \
+  echo "Static quantization, compute FRR/FAR..."
+  # Apply static quantization
+  quantize_score_checkpoint=$(basename $score_checkpoint | sed -e 's:\.pt$:.quant.zip:g')
+  # Calibrate on a fixed-seed random subset of the training list
+  cat data/train/data.list | python tools/shuffle_list.py --seed 777 | \
+    head -n 10000 > $dir/calibration.list
+  python kws/bin/static_quantize.py \
+    --config $dir/config.yaml \
+    --test_data $dir/calibration.list \
     --checkpoint $score_checkpoint \
-    --output_file $dir/final.zip \
-    --output_quant_file $dir/final.quant.zip
+    --num_workers 8 \
+    --script_model $dir/$quantize_score_checkpoint
+
+  # Score the quantized script model on the test set (score.py runs it on CPU)
+  result_dir=$dir/test_$(basename $quantize_score_checkpoint)
+  mkdir -p $result_dir
+  python kws/bin/score.py \
+    --config $dir/config.yaml \
+    --test_data data/test/data.list \
+    --batch_size 256 \
+    --jit_model \
+    --checkpoint $dir/$quantize_score_checkpoint \
+    --score_file $result_dir/score.txt \
+    --num_workers 8
+  # Run compute_det.py once per keyword id
+  for keyword in 0 1; do
+    python kws/bin/compute_det.py \
+      --keyword $keyword \
+      --test_data data/test/data.list \
+      --score_file $result_dir/score.txt \
+      --stats_file $result_dir/stats.${keyword}.txt
+  done
 fi
 
diff --git a/kws/bin/score.py b/kws/bin/score.py
index 40986d2..5d945f2 100644
--- a/kws/bin/score.py
+++ b/kws/bin/score.py
@@ -58,6 +58,10 @@ def get_args():
     parser.add_argument('--score_file',
                         required=True,
                         help='output score file')
+    parser.add_argument('--jit_model',
+                        action='store_true',
+                        default=False,
+                        help='Load checkpoint as a TorchScript (jit) model')
     args = parser.parse_args()
     return args
 
@@ -87,12 +91,16 @@ def main():
                                   num_workers=args.num_workers,
                                   prefetch_factor=args.prefetch)
 
-    # Init asr model from configs
-    model = init_model(configs['model'])
-
-    load_checkpoint(model, args.checkpoint)
-    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
-    device = torch.device('cuda' if use_cuda else 'cpu')
+    if args.jit_model:
+        model = torch.jit.load(args.checkpoint)
+        # For script model, only cpu is supported.
+        device = torch.device('cpu')
+    else:
+        # Init asr model from configs
+        model = init_model(configs['model'])
+        load_checkpoint(model, args.checkpoint)
+        use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+        device = torch.device('cuda' if use_cuda else 'cpu')
     model = model.to(device)
     model.eval()
 
diff --git a/kws/bin/static_quantize.py b/kws/bin/static_quantize.py
new file mode 100644
index 0000000..f92e66f
--- /dev/null
+++ b/kws/bin/static_quantize.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import copy
+import logging
+import os
+import sys
+
+import torch
+import yaml
+from torch.utils.data import DataLoader
+
+from kws.dataset.dataset import Dataset
+from kws.model.kws_model import init_model
+from kws.utils.checkpoint import load_checkpoint
+
+
+def get_args():
+    """Parse and return the command line arguments."""
+    parser = argparse.ArgumentParser(description='statically quantize a model')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument('--test_data', required=True, help='test data file')
+    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
+    parser.add_argument('--num_workers',
+                        default=0,
+                        type=int,
+                        help='num of subprocess workers for reading')
+    parser.add_argument('--pin_memory',
+                        action='store_true',
+                        default=False,
+                        help='Use pinned memory buffers used for reading')
+    parser.add_argument('--prefetch',
+                        default=100,
+                        type=int,
+                        help='prefetch number')
+    parser.add_argument('--script_model',
+                        required=True,
+                        help='output script model')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+def main():
+    """Calibrate, statically quantize and export the model as TorchScript."""
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide all GPUs
+
+    with open(args.config, 'r') as fin:
+        configs = yaml.load(fin, Loader=yaml.FullLoader)
+
+    test_conf = copy.deepcopy(configs['dataset_conf'])
+    test_conf['filter_conf']['max_length'] = 102400
+    test_conf['filter_conf']['min_length'] = 0
+    test_conf['speed_perturb'] = False
+    test_conf['spec_aug'] = False
+    test_conf['shuffle'] = False
+    test_conf['feature_extraction_conf']['dither'] = 0.0
+    test_conf['batch_conf']['batch_size'] = 1
+
+    test_dataset = Dataset(args.test_data, test_conf)
+    test_data_loader = DataLoader(test_dataset,
+                                  batch_size=None,
+                                  pin_memory=args.pin_memory,
+                                  num_workers=args.num_workers,
+                                  prefetch_factor=args.prefetch)
+
+    # Init asr model from configs
+    model_fp32 = init_model(configs['model'])
+    load_checkpoint(model_fp32, args.checkpoint)
+    # model must be set to eval mode for static quantization logic to work
+    model_fp32.eval()
+
+    # Fuse the activations to preceding layers, where applicable.
+    # This needs to be done manually depending on the model architecture.
+    # Common fusions include `conv + relu` and `conv + batchnorm + relu`
+    print('================ Float 32 ======================')
+    print(model_fp32)
+    print('================ Float 32(fused) ===============')
+    model_fp32.fuse_modules()
+    print(model_fp32)
+
+    # Attach a global qconfig, which contains information about what kind
+    # of observers to attach. Use 'fbgemm' for server inference and
+    # 'qnnpack' for mobile inference. Other quantization configurations such
+    # as selecting symmetric or asymmetric quantization and MinMax or L2Norm
+    # calibration techniques can be specified here.
+    model_fp32.qconfig = torch.quantization.get_default_qconfig('qnnpack')
+
+    # Prepare the model for static quantization. This inserts observers in
+    # the model that will observe activation tensors during calibration.
+    model_fp32_prepared = torch.quantization.prepare(model_fp32)
+
+    # Calibrate the prepared model on a representative dataset to determine
+    # the quantization parameters for activations.
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(test_data_loader):
+            _, feats, _, _ = batch
+            model_fp32_prepared(feats)
+            if batch_idx % 100 == 0:
+                print('Progress utts {}'.format(batch_idx))
+                sys.stdout.flush()
+
+    # Convert the observed model to a quantized model: quantize the weights,
+    # compute and store the scale and bias value used with each activation
+    # tensor, and replace key operators with quantized implementations.
+    print('=================== int8 ======================')
+    model_int8 = torch.quantization.convert(model_fp32_prepared)
+    print(model_int8)
+
+    print('================ int8(script) ==================')
+    script_model = torch.jit.script(model_int8)
+    script_model.save(args.script_model)
+    print(script_model)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/shuffle_list.py b/tools/shuffle_list.py
new file mode 100644
index 0000000..8d38c27
--- /dev/null
+++ b/tools/shuffle_list.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shuffle the lines of a file (or stdin) and write them to a file (or stdout)."""
+
+import argparse
+import os
+import random
+import sys
+
+parser = argparse.ArgumentParser(description='shuffle input file by line')
+parser.add_argument('--seed', default=None, type=int, help='random seed')
+parser.add_argument('--input', help='input file')
+parser.add_argument('--output', help='output file')
+args = parser.parse_args()
+
+random.seed(args.seed)
+
+if args.input is not None:
+    with open(args.input, 'r', encoding='utf8') as fin:
+        lines = fin.readlines()
+else:
+    lines = sys.stdin.readlines()
+random.shuffle(lines)
+
+try:
+    if args.output is not None:
+        with open(args.output, 'w', encoding='utf8') as fout:
+            fout.writelines(lines)
+    else:
+        sys.stdout.writelines(lines)
+        sys.stdout.flush()
+except BrokenPipeError:
+    # The reader (e.g. `head`) closed the pipe early, which is expected.
+    # Redirect stdout to devnull so the flush at interpreter exit does not
+    # raise a second BrokenPipeError.
+    os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stdout.fileno())