From fb10439f2fb72386210c8311ba3026d875429536 Mon Sep 17 00:00:00 2001 From: jingyong hou Date: Mon, 6 Dec 2021 14:54:47 +0800 Subject: [PATCH] update kes_model.py --- kws/model/kws_model.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index 5c56aef..0f1bc8d 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -122,24 +122,27 @@ def init_model(configs): else: print('Unknown body type {}'.format(backbone_type)) sys.exit(1) + + if 'classifier' in configs: + classifier_type = configs['classifier']['type'] + dropout = configs['classifier']['dropout'] - classifier_type = configs['classifier']['type'] - dropout = configs['classifier']['dropout'] - classifier_base = nn.Sequential( + classifier_base = nn.Sequential( nn.Linear(hidden_dim, 64), nn.ReLU(), nn.Dropout(dropout), - nn.Linear(64, output_dim), - ) - if classifier_type == 'linear': - classifier = classifier_base - elif classifier_type == 'global': - classifier = GlobalClassifier(classifier_base) - elif classifier_type == 'last': - classifier = LastClassifier(classifier_base) + nn.Linear(64, output_dim)) + + if classifier_type == 'global': + classifier = GlobalClassifier(classifier_base) + elif classifier_type == 'last': + classifier = LastClassifier(classifier_base) + else: + print('Unknown classifier type {}'.format(classifier_type)) + sys.exit(1) else: - print('Unknown classifier type {}'.format(classifier_type)) - sys.exit(1) + classifier = torch.nn.Linear(hidden_dim, output_dim) + kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, preprocessing, backbone, classifier) return kws_model