From 8943acb51f7768de0417683fac6db5e207fd5d33 Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Wed, 15 Dec 2021 21:10:27 +0800 Subject: [PATCH] [kws] put activation in model, so the activation could be exported in script model --- kws/bin/score.py | 2 +- kws/model/kws_model.py | 23 +++++++++++++++-------- kws/model/loss.py | 1 - 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/kws/bin/score.py b/kws/bin/score.py index 5d945f2..b8e7c5c 100644 --- a/kws/bin/score.py +++ b/kws/bin/score.py @@ -110,7 +110,7 @@ def main(): feats = feats.to(device) lengths = lengths.to(device) mask = padding_mask(lengths).unsqueeze(2) - logits = torch.sigmoid(model(feats)) + logits = model(feats) logits = logits.masked_fill(mask, 0.0) max_logits, _ = logits.max(dim=1) max_logits = max_logits.cpu() diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index f99f0ab..0744524 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -26,22 +26,26 @@ from kws.model.mdtc import MDTC from kws.utils.cmvn import load_cmvn -class KWSModel(torch.nn.Module): +class KWSModel(nn.Module): """Our model consists of four parts: 1. global_cmvn: Optional, (idim, idim) 2. preprocessing: feature dimention projection, (idim, hdim) 3. backbone: backbone or feature extractor of the whole network, (hdim, hdim) 4. classifier: output layer or classifier of KWS model, (hdim, odim) + 5. activation: + nn.Sigmoid for wakeup word + nn.Identity for speech command dataset """ def __init__( self, idim: int, odim: int, hdim: int, - global_cmvn: Optional[torch.nn.Module], - preprocessing: Optional[torch.nn.Module], - backbone: torch.nn.Module, - classifier: torch.nn.Module + global_cmvn: Optional[nn.Module], + preprocessing: Optional[nn.Module], + backbone: nn.Module, + classifier: nn.Module, + activation: nn.Module, ): super().__init__() self.idim = idim @@ -51,6 +55,7 @@ class KWSModel(torch.nn.Module): self.preprocessing = preprocessing self.backbone = backbone self.classifier = classifier + self.activation = activation def forward(self, x: torch.Tensor) -> torch.Tensor: if self.global_cmvn is not None: @@ -58,6 +63,7 @@ class KWSModel(torch.nn.Module): x = self.preprocessing(x) x, _ = self.backbone(x) x = self.classifier(x) + x = self.activation(x) return x def fuse_modules(self): @@ -132,8 +138,7 @@ def init_model(configs): classifier_type = configs['classifier']['type'] dropout = configs['classifier']['dropout'] - classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64), - nn.ReLU(), + classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU(), nn.Dropout(dropout), nn.Linear(64, output_dim)) if classifier_type == 'global': @@ -146,9 +151,11 @@ def init_model(configs): else: print('Unknown classifier type {}'.format(classifier_type)) sys.exit(1) + activation = nn.Identity() else: classifier = LinearClassifier(hidden_dim, output_dim) + activation = nn.Sigmoid() kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, - preprocessing, backbone, classifier) + preprocessing, backbone, classifier, activation) return kws_model diff --git a/kws/model/loss.py b/kws/model/loss.py index 80a7fec..f862928 100644 --- a/kws/model/loss.py +++ b/kws/model/loss.py @@ -38,7 +38,6 @@ def max_pooling_loss(logits: torch.Tensor, (float): loss of current batch (float): accuracy of current batch ''' - logits = torch.sigmoid(logits) mask = padding_mask(lengths) num_utts = logits.size(0) num_keywords = logits.size(2)