update kes_model.py
This commit is contained in:
parent
3b36da11c9
commit
fb10439f2f
@ -123,23 +123,26 @@ def init_model(configs):
|
|||||||
print('Unknown body type {}'.format(backbone_type))
|
print('Unknown body type {}'.format(backbone_type))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
if 'classifier' in configs:
|
||||||
classifier_type = configs['classifier']['type']
|
classifier_type = configs['classifier']['type']
|
||||||
dropout = configs['classifier']['dropout']
|
dropout = configs['classifier']['dropout']
|
||||||
|
|
||||||
classifier_base = nn.Sequential(
|
classifier_base = nn.Sequential(
|
||||||
nn.Linear(hidden_dim, 64),
|
nn.Linear(hidden_dim, 64),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(64, output_dim),
|
nn.Linear(64, output_dim))
|
||||||
)
|
|
||||||
if classifier_type == 'linear':
|
if classifier_type == 'global':
|
||||||
classifier = classifier_base
|
|
||||||
elif classifier_type == 'global':
|
|
||||||
classifier = GlobalClassifier(classifier_base)
|
classifier = GlobalClassifier(classifier_base)
|
||||||
elif classifier_type == 'last':
|
elif classifier_type == 'last':
|
||||||
classifier = LastClassifier(classifier_base)
|
classifier = LastClassifier(classifier_base)
|
||||||
else:
|
else:
|
||||||
print('Unknown classifier type {}'.format(classifier_type))
|
print('Unknown classifier type {}'.format(classifier_type))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
classifier = torch.nn.Linear(hidden_dim, output_dim)
|
||||||
|
|
||||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||||
preprocessing, backbone, classifier)
|
preprocessing, backbone, classifier)
|
||||||
return kws_model
|
return kws_model
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user