diff --git a/kws/bin/score.py b/kws/bin/score.py index 0e8c18f..4c80816 100644 --- a/kws/bin/score.py +++ b/kws/bin/score.py @@ -43,6 +43,18 @@ def get_args(): default=16, type=int, help='batch size for inference') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') parser.add_argument('--score_file', required=True, help='output score file') @@ -65,11 +77,15 @@ def main(): test_conf['speed_perturb'] = False test_conf['spec_aug'] = False test_conf['shuffle'] = False - test_conf['fbank_conf']['dither'] = 0.0 + test_conf['feature_extraction_conf']['dither'] = 0.0 test_conf['batch_conf']['batch_size'] = args.batch_size test_dataset = Dataset(args.test_data, test_conf) - test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + test_data_loader = DataLoader(test_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) # Init asr model from configs model = init_model(configs['model'])