diff --git a/kws/bin/train.py b/kws/bin/train.py index c7b092a..4d178c9 100644 --- a/kws/bin/train.py +++ b/kws/bin/train.py @@ -157,7 +157,9 @@ def main(): infos = {} start_epoch = infos.get('epoch', -1) + 1 cv_loss = infos.get('cv_loss', 0.0) - + # get the last epoch lr + lr_last_epoch = infos.get('lr', configs['optim_conf']['lr']) + configs['optim_conf']['lr'] = lr_last_epoch model_dir = args.model_dir writer = None if rank == 0: