Merge pull request #51 from wenet-e2e/binbin-activation
[kws] put activation in model, so the activation could be exported in…
This commit is contained in:
commit
20891f90e6
@ -110,7 +110,7 @@ def main():
|
|||||||
feats = feats.to(device)
|
feats = feats.to(device)
|
||||||
lengths = lengths.to(device)
|
lengths = lengths.to(device)
|
||||||
mask = padding_mask(lengths).unsqueeze(2)
|
mask = padding_mask(lengths).unsqueeze(2)
|
||||||
logits = torch.sigmoid(model(feats))
|
logits = model(feats)
|
||||||
logits = logits.masked_fill(mask, 0.0)
|
logits = logits.masked_fill(mask, 0.0)
|
||||||
max_logits, _ = logits.max(dim=1)
|
max_logits, _ = logits.max(dim=1)
|
||||||
max_logits = max_logits.cpu()
|
max_logits = max_logits.cpu()
|
||||||
|
|||||||
@ -26,22 +26,26 @@ from kws.model.mdtc import MDTC
|
|||||||
from kws.utils.cmvn import load_cmvn
|
from kws.utils.cmvn import load_cmvn
|
||||||
|
|
||||||
|
|
||||||
class KWSModel(torch.nn.Module):
|
class KWSModel(nn.Module):
|
||||||
"""Our model consists of four parts:
|
"""Our model consists of five parts:
|
||||||
1. global_cmvn: Optional, (idim, idim)
|
1. global_cmvn: Optional, (idim, idim)
|
||||||
2. preprocessing: feature dimension projection, (idim, hdim)
|
2. preprocessing: feature dimension projection, (idim, hdim)
|
||||||
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
|
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
|
||||||
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
||||||
|
5. activation:
|
||||||
|
nn.Sigmoid for wakeup word
|
||||||
|
nn.Identity for speech command dataset
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
idim: int,
|
idim: int,
|
||||||
odim: int,
|
odim: int,
|
||||||
hdim: int,
|
hdim: int,
|
||||||
global_cmvn: Optional[torch.nn.Module],
|
global_cmvn: Optional[nn.Module],
|
||||||
preprocessing: Optional[torch.nn.Module],
|
preprocessing: Optional[nn.Module],
|
||||||
backbone: torch.nn.Module,
|
backbone: nn.Module,
|
||||||
classifier: torch.nn.Module
|
classifier: nn.Module,
|
||||||
|
activation: nn.Module,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.idim = idim
|
self.idim = idim
|
||||||
@ -51,6 +55,7 @@ class KWSModel(torch.nn.Module):
|
|||||||
self.preprocessing = preprocessing
|
self.preprocessing = preprocessing
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if self.global_cmvn is not None:
|
if self.global_cmvn is not None:
|
||||||
@ -58,6 +63,7 @@ class KWSModel(torch.nn.Module):
|
|||||||
x = self.preprocessing(x)
|
x = self.preprocessing(x)
|
||||||
x, _ = self.backbone(x)
|
x, _ = self.backbone(x)
|
||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
|
x = self.activation(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def fuse_modules(self):
|
def fuse_modules(self):
|
||||||
@ -132,8 +138,7 @@ def init_model(configs):
|
|||||||
classifier_type = configs['classifier']['type']
|
classifier_type = configs['classifier']['type']
|
||||||
dropout = configs['classifier']['dropout']
|
dropout = configs['classifier']['dropout']
|
||||||
|
|
||||||
classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64),
|
classifier_base = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU(),
|
||||||
nn.ReLU(),
|
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(64, output_dim))
|
nn.Linear(64, output_dim))
|
||||||
if classifier_type == 'global':
|
if classifier_type == 'global':
|
||||||
@ -146,9 +151,11 @@ def init_model(configs):
|
|||||||
else:
|
else:
|
||||||
print('Unknown classifier type {}'.format(classifier_type))
|
print('Unknown classifier type {}'.format(classifier_type))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
activation = nn.Identity()
|
||||||
else:
|
else:
|
||||||
classifier = LinearClassifier(hidden_dim, output_dim)
|
classifier = LinearClassifier(hidden_dim, output_dim)
|
||||||
|
activation = nn.Sigmoid()
|
||||||
|
|
||||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||||
preprocessing, backbone, classifier)
|
preprocessing, backbone, classifier, activation)
|
||||||
return kws_model
|
return kws_model
|
||||||
|
|||||||
@ -38,7 +38,6 @@ def max_pooling_loss(logits: torch.Tensor,
|
|||||||
(float): loss of current batch
|
(float): loss of current batch
|
||||||
(float): accuracy of current batch
|
(float): accuracy of current batch
|
||||||
'''
|
'''
|
||||||
logits = torch.sigmoid(logits)
|
|
||||||
mask = padding_mask(lengths)
|
mask = padding_mask(lengths)
|
||||||
num_utts = logits.size(0)
|
num_utts = logits.size(0)
|
||||||
num_keywords = logits.size(2)
|
num_keywords = logits.size(2)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user