diff --git a/kws/bin/train.py b/kws/bin/train.py index a6c9472..6ba902b 100644 --- a/kws/bin/train.py +++ b/kws/bin/train.py @@ -195,9 +195,7 @@ def main(): device = torch.device('cuda' if use_cuda else 'cpu') model = model.to(device) - optimizer = optim.Adam(model.parameters(), - lr=configs['optim_conf']['lr'], - weight_decay=configs['optim_conf'].get('weight_decay', 0)) + optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='min',