diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index a3c1d0a..56cbf76 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -127,10 +127,10 @@ def init_model(configs): dropout = configs['classifier']['dropout'] classifier_base = nn.Sequential( - nn.Linear(hidden_dim, 64), - nn.ReLU(), - nn.Dropout(dropout), - nn.Linear(64, output_dim)) + nn.Linear(hidden_dim, 64), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(64, output_dim)) if classifier_type == 'global': classifier = GlobalClassifier(classifier_base)