[kws] put activation in model, so the activation could be exported in script model

This commit is contained in:
Binbin Zhang 2021-12-15 21:10:27 +08:00
parent 566baca343
commit 8943acb51f
3 changed files with 16 additions and 10 deletions

View File

@ -110,7 +110,7 @@ def main():
feats = feats.to(device)
lengths = lengths.to(device)
mask = padding_mask(lengths).unsqueeze(2)
logits = torch.sigmoid(model(feats))
logits = model(feats)
logits = logits.masked_fill(mask, 0.0)
max_logits, _ = logits.max(dim=1)
max_logits = max_logits.cpu()

View File

@ -26,22 +26,26 @@ from kws.model.mdtc import MDTC
from kws.utils.cmvn import load_cmvn
class KWSModel(torch.nn.Module):
class KWSModel(nn.Module):
"""Our model consists of four parts:
1. global_cmvn: Optional, (idim, idim)
2. preprocessing: feature dimention projection, (idim, hdim)
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
4. classifier: output layer or classifier of KWS model, (hdim, odim)
5. activation:
nn.Sigmoid for wakeup word
nn.Identity for speech command dataset
"""
def __init__(
self,
idim: int,
odim: int,
hdim: int,
global_cmvn: Optional[torch.nn.Module],
preprocessing: Optional[torch.nn.Module],
backbone: torch.nn.Module,
classifier: torch.nn.Module
global_cmvn: Optional[nn.Module],
preprocessing: Optional[nn.Module],
backbone: nn.Module,
classifier: nn.Module,
activation: nn.Module,
):
super().__init__()
self.idim = idim
@ -51,6 +55,7 @@ class KWSModel(torch.nn.Module):
self.preprocessing = preprocessing
self.backbone = backbone
self.classifier = classifier
self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.global_cmvn is not None:
@ -58,6 +63,7 @@ class KWSModel(torch.nn.Module):
x = self.preprocessing(x)
x, _ = self.backbone(x)
x = self.classifier(x)
x = self.activation(x)
return x
def fuse_modules(self):
@ -132,8 +138,7 @@ def init_model(configs):
classifier_type = configs['classifier']['type']
dropout = configs['classifier']['dropout']
classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
nn.ReLU(),
classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(64, output_dim))
if classifier_type == 'global':
@ -146,9 +151,11 @@ def init_model(configs):
else:
print('Unknown classifier type {}'.format(classifier_type))
sys.exit(1)
activation = nn.Identity()
else:
classifier = LinearClassifier(hidden_dim, output_dim)
activation = nn.Sigmoid()
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone, classifier)
preprocessing, backbone, classifier, activation)
return kws_model

View File

@ -38,7 +38,6 @@ def max_pooling_loss(logits: torch.Tensor,
(float): loss of current batch
(float): accuracy of current batch
'''
logits = torch.sigmoid(logits)
mask = padding_mask(lengths)
num_utts = logits.size(0)
num_keywords = logits.size(2)