Merge pull request #51 from wenet-e2e/binbin-activation

[kws] put activation in model, so the activation could be exported in…
commit 20891f90e6
Menglong Xu, 2021-12-15 21:32:52 +08:00, committed by GitHub
3 changed files with 16 additions and 10 deletions
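
The point of the change: once the activation is a submodule invoked in forward(), export paths that serialize the forward graph (e.g. torch.jit.script) carry it along, so runtime code no longer applies the sigmoid itself. A minimal sketch of that effect, using hypothetical names not taken from this repo:

import torch
import torch.nn as nn

# Toy stand-in for KWSModel: because self.activation is used inside
# forward(), scripting captures it, and the exported artifact emits
# probabilities directly instead of raw logits.
class TinyKWS(nn.Module):
    def __init__(self, idim: int, odim: int):
        super().__init__()
        self.classifier = nn.Linear(idim, odim)
        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(self.classifier(x))

scripted = torch.jit.script(TinyKWS(40, 1))
scripted.save('tiny_kws.pt')  # the sigmoid travels with the model file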


@@ -110,7 +110,7 @@ def main():
             feats = feats.to(device)
             lengths = lengths.to(device)
             mask = padding_mask(lengths).unsqueeze(2)
-            logits = torch.sigmoid(model(feats))
+            logits = model(feats)
             logits = logits.masked_fill(mask, 0.0)
             max_logits, _ = logits.max(dim=1)
             max_logits = max_logits.cpu()
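
The scoring loop above drops its own torch.sigmoid call because the model output now already includes the activation. For context, a hypothetical padding_mask consistent with how it is called here (the real helper lives elsewhere in the repo): True marks padded frames, which masked_fill then zeroes before the max over time.

import torch

def padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    # lengths: (batch,) valid frame counts -> (batch, max_len) bool,
    # True at padded positions beyond each utterance's length.
    max_len = int(lengths.max().item())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)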


@@ -26,22 +26,26 @@ from kws.model.mdtc import MDTC
 from kws.utils.cmvn import load_cmvn
 
 
-class KWSModel(torch.nn.Module):
+class KWSModel(nn.Module):
     """Our model consists of four parts:
     1. global_cmvn: Optional, (idim, idim)
     2. preprocessing: feature dimention projection, (idim, hdim)
     3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
     4. classifier: output layer or classifier of KWS model, (hdim, odim)
+    5. activation:
+        nn.Sigmoid for wakeup word
+        nn.Identity for speech command dataset
     """
     def __init__(
         self,
         idim: int,
         odim: int,
         hdim: int,
-        global_cmvn: Optional[torch.nn.Module],
-        preprocessing: Optional[torch.nn.Module],
-        backbone: torch.nn.Module,
-        classifier: torch.nn.Module
+        global_cmvn: Optional[nn.Module],
+        preprocessing: Optional[nn.Module],
+        backbone: nn.Module,
+        classifier: nn.Module,
+        activation: nn.Module,
     ):
         super().__init__()
         self.idim = idim
@@ -51,6 +55,7 @@ class KWSModel(torch.nn.Module):
         self.preprocessing = preprocessing
         self.backbone = backbone
         self.classifier = classifier
+        self.activation = activation
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.global_cmvn is not None:
@@ -58,6 +63,7 @@ class KWSModel(torch.nn.Module):
         x = self.preprocessing(x)
         x, _ = self.backbone(x)
         x = self.classifier(x)
+        x = self.activation(x)
         return x
 
     def fuse_modules(self):
@@ -132,8 +138,7 @@ def init_model(configs):
         classifier_type = configs['classifier']['type']
         dropout = configs['classifier']['dropout']
-        classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
-                                        nn.ReLU(),
+        classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU(),
                                         nn.Dropout(dropout),
                                         nn.Linear(64, output_dim))
         if classifier_type == 'global':
@@ -146,9 +151,11 @@ def init_model(configs):
         else:
             print('Unknown classifier type {}'.format(classifier_type))
             sys.exit(1)
+        activation = nn.Identity()
     else:
         classifier = LinearClassifier(hidden_dim, output_dim)
+        activation = nn.Sigmoid()
     kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
-                         preprocessing, backbone, classifier)
+                         preprocessing, backbone, classifier, activation)
     return kws_model
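
The branch above pairs each task with an activation: nn.Sigmoid for wake-word models, which need independent per-keyword probabilities, and nn.Identity for speech-command models, whose raw logits presumably feed a cross-entropy style loss. A quick sketch of the difference, with assumed shapes:

import torch
import torch.nn as nn

logits = torch.randn(4, 100, 3)  # assumed (batch, time, num_keywords)
wakeup = nn.Sigmoid()(logits)    # each entry squashed into (0, 1)
command = nn.Identity()(logits)  # passed through unchanged
assert torch.equal(command, logits)
assert bool(((wakeup > 0) & (wakeup < 1)).all())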


@@ -38,7 +38,6 @@ def max_pooling_loss(logits: torch.Tensor,
         (float): loss of current batch
         (float): accuracy of current batch
     '''
-    logits = torch.sigmoid(logits)
     mask = padding_mask(lengths)
     num_utts = logits.size(0)
     num_keywords = logits.size(2)
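
With the sigmoid removed here, max_pooling_loss now assumes its logits argument already holds probabilities in (0, 1) (the parameter keeps its old name); the masking and per-keyword max over time that follow are unchanged. A simplified sketch of that pooling step under assumed shapes, not the repo's full loss:

import torch

probs = torch.rand(2, 50, 3)                 # activated model output
mask = torch.zeros(2, 50, dtype=torch.bool)  # True would mark padding
probs = probs.masked_fill(mask.unsqueeze(2), 0.0)
max_probs, _ = probs.max(dim=1)              # (batch, num_keywords)
# positive utterances push the max toward 1; the clamp guards log(0)
loss = -torch.log(max_probs.clamp(min=1e-8)).mean()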