diff --git a/kws/bin/train.py b/kws/bin/train.py index 8053b49..9e6117c 100644 --- a/kws/bin/train.py +++ b/kws/bin/train.py @@ -164,9 +164,9 @@ def main(): # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine # the code to satisfy the script export requirements - # if args.rank == 0: - # script_model = torch.jit.script(model) - # script_model.save(os.path.join(args.model_dir, 'init.zip')) + if args.rank == 0: + script_model = torch.jit.script(model) + script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() # If specify checkpoint, load some info from checkpoint if args.checkpoint is not None: