This commit is contained in:
jingyong hou 2021-12-06 15:11:23 +08:00
parent 215cff687d
commit 95396a3885

View File

@ -127,10 +127,10 @@ def init_model(configs):
dropout = configs['classifier']['dropout']
classifier_base = nn.Sequential(
nn.Linear(hidden_dim, 64),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(64, output_dim))
nn.Linear(hidden_dim, 64),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(64, output_dim))
if classifier_type == 'global':
classifier = GlobalClassifier(classifier_base)