format
This commit is contained in:
parent
215cff687d
commit
95396a3885
@ -127,10 +127,10 @@ def init_model(configs):
|
|||||||
dropout = configs['classifier']['dropout']
|
dropout = configs['classifier']['dropout']
|
||||||
|
|
||||||
classifier_base = nn.Sequential(
|
classifier_base = nn.Sequential(
|
||||||
nn.Linear(hidden_dim, 64),
|
nn.Linear(hidden_dim, 64),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(64, output_dim))
|
nn.Linear(64, output_dim))
|
||||||
|
|
||||||
if classifier_type == 'global':
|
if classifier_type == 'global':
|
||||||
classifier = GlobalClassifier(classifier_base)
|
classifier = GlobalClassifier(classifier_base)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user