diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index 56cbf76..2f2b2e2 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -126,12 +126,10 @@ def init_model(configs): classifier_type = configs['classifier']['type'] dropout = configs['classifier']['dropout'] - classifier_base = nn.Sequential( - nn.Linear(hidden_dim, 64), - nn.ReLU(), - nn.Dropout(dropout), - nn.Linear(64, output_dim)) - + classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(64, output_dim)) if classifier_type == 'global': classifier = GlobalClassifier(classifier_base) elif classifier_type == 'last':