update kes_model.py

This commit is contained in:
jingyong hou 2021-12-06 14:54:47 +08:00
parent 3b36da11c9
commit fb10439f2f

View File

@ -122,24 +122,27 @@ def init_model(configs):
else:
print('Unknown body type {}'.format(backbone_type))
sys.exit(1)
if 'classifier' in configs:
classifier_type = configs['classifier']['type']
dropout = configs['classifier']['dropout']
classifier_type = configs['classifier']['type']
dropout = configs['classifier']['dropout']
classifier_base = nn.Sequential(
classifier_base = nn.Sequential(
nn.Linear(hidden_dim, 64),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(64, output_dim),
)
if classifier_type == 'linear':
classifier = classifier_base
elif classifier_type == 'global':
classifier = GlobalClassifier(classifier_base)
elif classifier_type == 'last':
classifier = LastClassifier(classifier_base)
nn.Linear(64, output_dim))
if classifier_type == 'global':
classifier = GlobalClassifier(classifier_base)
elif classifier_type == 'last':
classifier = LastClassifier(classifier_base)
else:
print('Unknown classifier type {}'.format(classifier_type))
sys.exit(1)
else:
print('Unknown classifier type {}'.format(classifier_type))
sys.exit(1)
classifier = torch.nn.Linear(hidden_dim, output_dim)
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone, classifier)
return kws_model