[kws] put activation in the model, so the activation can be exported in the scripted model

parent 566baca343
commit 8943acb51f
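Why this change matters for deployment: torch.jit.script serializes whatever runs inside the module's forward, so an activation applied by caller code never makes it into the exported graph, and every consumer has to remember to re-apply it. A minimal, self-contained sketch of the idea (ToyKWS is a hypothetical stand-in for KWSModel, not code from this repo):

    import torch
    import torch.nn as nn

    class ToyKWS(nn.Module):
        def __init__(self):
            super().__init__()
            self.classifier = nn.Linear(80, 1)
            self.activation = nn.Sigmoid()  # part of the module, so it gets scripted

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.activation(self.classifier(x))

    script_model = torch.jit.script(ToyKWS().eval())
    script_model.save('toy_kws.zip')
    # the loaded model already applies the activation; no caller-side sigmoid
    scores = torch.jit.load('toy_kws.zip')(torch.randn(1, 100, 80))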
@@ -110,7 +110,7 @@ def main():
             feats = feats.to(device)
             lengths = lengths.to(device)
             mask = padding_mask(lengths).unsqueeze(2)
-            logits = torch.sigmoid(model(feats))
+            logits = model(feats)
             logits = logits.masked_fill(mask, 0.0)
             max_logits, _ = logits.max(dim=1)
             max_logits = max_logits.cpu()
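For context, the hunk above relies on masked max-pooling over time: padded frames are zeroed so they can never win the max. A self-contained sketch of that logic; this padding_mask is a hypothetical reimplementation (assumed to return True at padded positions), not the repo's actual helper:

    import torch

    def padding_mask(lengths: torch.Tensor) -> torch.Tensor:
        # True at padded frames, shape (batch, max_frames) -- assumed semantics
        idx = torch.arange(int(lengths.max())).unsqueeze(0)  # (1, max_frames)
        return idx >= lengths.unsqueeze(1)                   # broadcast to (batch, max_frames)

    lengths = torch.tensor([3, 5])
    probs = torch.rand(2, 5, 4)                    # (batch, frames, keywords), model output
    mask = padding_mask(lengths).unsqueeze(2)      # (batch, frames, 1)
    probs = probs.masked_fill(mask, 0.0)           # padded frames can never be the max
    max_probs, _ = probs.max(dim=1)                # per-keyword peak score per utterance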
@@ -26,22 +26,26 @@ from kws.model.mdtc import MDTC
 from kws.utils.cmvn import load_cmvn


-class KWSModel(torch.nn.Module):
+class KWSModel(nn.Module):
     """Our model consists of five parts:
     1. global_cmvn: Optional, (idim, idim)
     2. preprocessing: feature dimension projection, (idim, hdim)
     3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
     4. classifier: output layer or classifier of KWS model, (hdim, odim)
+    5. activation:
+        nn.Sigmoid for wakeup word
+        nn.Identity for speech command dataset
     """
     def __init__(
         self,
         idim: int,
         odim: int,
         hdim: int,
-        global_cmvn: Optional[torch.nn.Module],
-        preprocessing: Optional[torch.nn.Module],
-        backbone: torch.nn.Module,
-        classifier: torch.nn.Module
+        global_cmvn: Optional[nn.Module],
+        preprocessing: Optional[nn.Module],
+        backbone: nn.Module,
+        classifier: nn.Module,
+        activation: nn.Module,
     ):
         super().__init__()
         self.idim = idim
@@ -51,6 +55,7 @@ class KWSModel(torch.nn.Module):
         self.preprocessing = preprocessing
         self.backbone = backbone
         self.classifier = classifier
+        self.activation = activation

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.global_cmvn is not None:
@@ -58,6 +63,7 @@ class KWSModel(torch.nn.Module):
         x = self.preprocessing(x)
         x, _ = self.backbone(x)
         x = self.classifier(x)
+        x = self.activation(x)
         return x

     def fuse_modules(self):
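Putting the pieces of KWSModel together, a hedged construction sketch: DummyBackbone below is a made-up stand-in that returns an (output, cache) pair the way `x, _ = self.backbone(x)` expects; the real repo would pass MDTC or another backbone here:

    import torch
    import torch.nn as nn
    from typing import Tuple

    class DummyBackbone(nn.Module):
        """Hypothetical backbone: (batch, frames, hdim) -> ((batch, frames, hdim), cache)."""
        def __init__(self, hdim: int):
            super().__init__()
            self.proj = nn.Linear(hdim, hdim)

        def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return self.proj(x), torch.zeros(0)

    idim, hdim, odim = 80, 128, 1
    model = KWSModel(idim, odim, hdim,
                     global_cmvn=None,
                     preprocessing=nn.Linear(idim, hdim),
                     backbone=DummyBackbone(hdim),
                     classifier=nn.Linear(hdim, odim),
                     activation=nn.Sigmoid())      # wakeup-word setting
    probs = model(torch.randn(2, 100, idim))       # (2, 100, 1), already in [0, 1]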
@@ -132,8 +138,7 @@ def init_model(configs):
         classifier_type = configs['classifier']['type']
         dropout = configs['classifier']['dropout']

-        classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
-                                        nn.ReLU(),
+        classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU(),
                                         nn.Dropout(dropout),
                                         nn.Linear(64, output_dim))
         if classifier_type == 'global':
@@ -146,9 +151,11 @@ def init_model(configs):
         else:
             print('Unknown classifier type {}'.format(classifier_type))
             sys.exit(1)
+        activation = nn.Identity()
     else:
         classifier = LinearClassifier(hidden_dim, output_dim)
+        activation = nn.Sigmoid()

     kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
-                         preprocessing, backbone, classifier)
+                         preprocessing, backbone, classifier, activation)
     return kws_model
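To make the branch above concrete: the activation is chosen per task, nn.Identity when a classifier head is configured (speech-command style, scores consumed raw by the loss) and nn.Sigmoid for the wakeup-word path. The outer condition is not visible in this hunk, so the following config fragments are hypothetical and use only the keys that do appear in the diff:

    # hypothetical fragments -- assumes init_model branches on the presence
    # of a 'classifier' section, which this hunk does not show
    configs_speech_command = {
        'classifier': {'type': 'global', 'dropout': 0.1},  # -> activation = nn.Identity()
        # ... other sections omitted
    }
    configs_wakeup = {
        # no 'classifier' section -> LinearClassifier + nn.Sigmoid()
    }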
@@ -38,7 +38,6 @@ def max_pooling_loss(logits: torch.Tensor,
         (float): loss of current batch
         (float): accuracy of current batch
     '''
-    logits = torch.sigmoid(logits)
     mask = padding_mask(lengths)
     num_utts = logits.size(0)
     num_keywords = logits.size(2)
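This last hunk is the counterpart of the model change: for the wakeup-word path the model now applies nn.Sigmoid itself, so max_pooling_loss must not sigmoid a second time. Note the parameter keeps the name logits even though it now receives probabilities. A quick equivalence check with hypothetical tensors:

    import torch

    raw = torch.randn(2, 5, 3)                # pre-activation classifier output
    old_probs = torch.sigmoid(raw)            # old: sigmoid applied inside the loss
    new_probs = torch.nn.Sigmoid()(raw)       # new: sigmoid applied inside the model
    assert torch.equal(old_probs, new_probs)  # same numbers, just computed earlier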