Add more comments to explain the new classifier designed for the speech command classification task

This commit is contained in:
jingyong hou 2021-12-06 17:03:41 +08:00
parent fb2777177f
commit ac36f6e9dd
2 changed files with 6 additions and 1 deletions

View File

@ -6,7 +6,7 @@
export CUDA_VISIBLE_DEVICES="0" export CUDA_VISIBLE_DEVICES="0"
stage=2 stage=-1
stop_stage=2 stop_stage=2
num_keywords=11 num_keywords=11

View File

@ -123,6 +123,8 @@ def init_model(configs):
print('Unknown body type {}'.format(backbone_type)) print('Unknown body type {}'.format(backbone_type))
sys.exit(1) sys.exit(1)
if 'classifier' in configs: if 'classifier' in configs:
# For the speech command dataset, we use two FC layers as the classifier,
# and we add dropout after the first FC layer to prevent overfitting
classifier_type = configs['classifier']['type'] classifier_type = configs['classifier']['type']
dropout = configs['classifier']['dropout'] dropout = configs['classifier']['dropout']
@ -131,8 +133,11 @@ def init_model(configs):
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(64, output_dim)) nn.Linear(64, output_dim))
if classifier_type == 'global': if classifier_type == 'global':
# 'global' means we add global average pooling before the classifier
classifier = GlobalClassifier(classifier_base) classifier = GlobalClassifier(classifier_base)
elif classifier_type == 'last': elif classifier_type == 'last':
# 'last' means we use the last frame to do backpropagation, so the model
# can run inference in a streaming fashion
classifier = LastClassifier(classifier_base) classifier = LastClassifier(classifier_base)
else: else:
print('Unknown classifier type {}'.format(classifier_type)) print('Unknown classifier type {}'.format(classifier_type))