diff --git a/kws/bin/train.py b/kws/bin/train.py index 3497322..a6c9472 100644 --- a/kws/bin/train.py +++ b/kws/bin/train.py @@ -197,7 +197,7 @@ def main(): optimizer = optim.Adam(model.parameters(), lr=configs['optim_conf']['lr'], - weight_decay=configs['optim_conf']['weight_decay']) + weight_decay=configs['optim_conf'].get('weight_decay', 0)) scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='min',