format
This commit is contained in:
parent
95396a3885
commit
fb2777177f
@ -126,12 +126,10 @@ def init_model(configs):
|
||||
classifier_type = configs['classifier']['type']
|
||||
dropout = configs['classifier']['dropout']
|
||||
|
||||
classifier_base = nn.Sequential(
|
||||
nn.Linear(hidden_dim, 64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(64, output_dim))
|
||||
|
||||
classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(64, output_dim))
|
||||
if classifier_type == 'global':
|
||||
classifier = GlobalClassifier(classifier_base)
|
||||
elif classifier_type == 'last':
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user