This commit is contained in:
jingyong hou 2021-12-06 15:14:21 +08:00
parent 95396a3885
commit fb2777177f

View File

@ -126,12 +126,10 @@ def init_model(configs):
classifier_type = configs['classifier']['type'] classifier_type = configs['classifier']['type']
dropout = configs['classifier']['dropout'] dropout = configs['classifier']['dropout']
classifier_base = nn.Sequential( classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
nn.Linear(hidden_dim, 64), nn.ReLU(),
nn.ReLU(), nn.Dropout(dropout),
nn.Dropout(dropout), nn.Linear(64, output_dim))
nn.Linear(64, output_dim))
if classifier_type == 'global': if classifier_type == 'global':
classifier = GlobalClassifier(classifier_base) classifier = GlobalClassifier(classifier_base)
elif classifier_type == 'last': elif classifier_type == 'last':