add more comments to explain the new classifier designed for speech command classification task
This commit is contained in:
parent
fb2777177f
commit
ac36f6e9dd
@ -6,7 +6,7 @@
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
stage=2
|
||||
stage=-1
|
||||
stop_stage=2
|
||||
num_keywords=11
|
||||
|
||||
|
||||
@ -123,6 +123,8 @@ def init_model(configs):
|
||||
print('Unknown body type {}'.format(backbone_type))
|
||||
sys.exit(1)
|
||||
if 'classifier' in configs:
|
||||
# For speech command dataset, we use 2 FC layer as classifier,
|
||||
# we add dropout after first FC layer to prevent overfitting
|
||||
classifier_type = configs['classifier']['type']
|
||||
dropout = configs['classifier']['dropout']
|
||||
|
||||
@ -131,8 +133,11 @@ def init_model(configs):
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(64, output_dim))
|
||||
if classifier_type == 'global':
|
||||
# global means we add a global average pooling before classifier
|
||||
classifier = GlobalClassifier(classifier_base)
|
||||
elif classifier_type == 'last':
|
||||
# last means we use last frame to do backpropagation, so the model
|
||||
# can be infered streamingly
|
||||
classifier = LastClassifier(classifier_base)
|
||||
else:
|
||||
print('Unknown classifier type {}'.format(classifier_type))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user