[kws] support fuse and quantization (#41)

Binbin Zhang 2021-12-13 19:16:50 +08:00 committed by GitHub
parent fd255fd7c6
commit 28652a766f
4 changed files with 50 additions and 3 deletions

kws/model/classifier.py

@@ -45,3 +45,17 @@ class ElementClassifier(nn.Module):
     def forward(self, x: torch.Tensor):
         return self.classifier(x)
 
+
+class LinearClassifier(nn.Module):
+    """ Wrapper of Linear """
+    def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.linear = torch.nn.Linear(input_dim, output_dim)
+        self.quant = torch.quantization.QuantStub()
+        self.dequant = torch.quantization.DeQuantStub()
+
+    def forward(self, x: torch.Tensor):
+        x = self.quant(x)
+        x = self.linear(x)
+        x = self.dequant(x)
+        return x
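
QuantStub and DeQuantStub are identity ops in an unconverted model; after torch.quantization.convert they become the points where activations enter and leave the int8 domain, so wrapping the Linear this way is what makes it quantizable. A minimal sketch of statically quantizing the new classifier on its own (the dimensions and random calibration batch are made up for illustration, not part of this commit):

    import torch

    model = LinearClassifier(input_dim=64, output_dim=10)  # hypothetical dims
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    prepared = torch.quantization.prepare(model)      # inserts observers
    prepared(torch.randn(8, 64))                      # calibration pass
    quantized = torch.quantization.convert(prepared)  # swaps in int8 kernels
    out = quantized(torch.randn(8, 64))               # float in, float out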

kws/model/kws_model.py

@@ -19,7 +19,7 @@ import torch
 import torch.nn as nn
 from kws.model.cmvn import GlobalCMVN
-from kws.model.classifier import GlobalClassifier, LastClassifier
+from kws.model.classifier import GlobalClassifier, LastClassifier, LinearClassifier
 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
 from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
 from kws.model.mdtc import MDTC
@@ -60,6 +60,10 @@ class KWSModel(torch.nn.Module):
         x = self.classifier(x)
         return x
 
+    def fuse_modules(self):
+        self.preprocessing.fuse_modules()
+        self.backbone.fuse_modules()
+
 
 def init_model(configs):
     cmvn = configs.get('cmvn', {})
@@ -143,7 +147,7 @@ def init_model(configs):
         print('Unknown classifier type {}'.format(classifier_type))
         sys.exit(1)
     else:
-        classifier = torch.nn.Linear(hidden_dim, output_dim)
+        classifier = LinearClassifier(hidden_dim, output_dim)
     kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
                          preprocessing, backbone, classifier)
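
With the classifier wrapped, KWSModel exposes a single fuse_modules() that recurses into its parts, and the usual eager-mode post-training quantization flow applies. A hedged sketch of the expected ordering (configs and calibration_loader are placeholders, not part of this commit); fusion must run on an eval-mode model, before observers are inserted:

    model = init_model(configs)
    model.eval()                  # fusion and calibration both need eval mode
    model.fuse_modules()          # delegates to preprocessing and backbone
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)
    for feats in calibration_loader:  # a few representative batches suffice
        model(feats)
    torch.quantization.convert(model, inplace=True)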

kws/model/subsampling.py

@@ -23,6 +23,7 @@ class SubsamplingBase(torch.nn.Module):
         super().__init__()
         self.subsampling_rate = 1
 
+
 class NoSubsampling(SubsamplingBase):
     """No subsampling in accordance to the 'none' preprocessing
     """
@@ -32,6 +33,7 @@ class NoSubsampling(SubsamplingBase):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
+
 class LinearSubsampling1(SubsamplingBase):
     """Linear transform the input without subsampling
     """
@@ -39,15 +41,22 @@ class LinearSubsampling1(SubsamplingBase):
         super().__init__()
         self.out = torch.nn.Sequential(
             torch.nn.Linear(idim, odim),
+            # torch.nn.BatchNorm1d(odim),
             torch.nn.ReLU(),
         )
         self.subsampling_rate = 1
+        self.quant = torch.quantization.QuantStub()
+        self.dequant = torch.quantization.DeQuantStub()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.quant(x)
         x = self.out(x)
+        x = self.dequant(x)
         return x
 
+    def fuse_modules(self):
+        torch.quantization.fuse_modules(self, [['out.0', 'out.1']],
+                                        inplace=True)
+
 
 class Conv1dSubsampling1(SubsamplingBase):
     """Conv1d transform without subsampling

kws/model/tcn.py

@@ -28,6 +28,8 @@ class Block(nn.Module):
                  dropout: float = 0.1):
         super().__init__()
         self.padding = (kernel_size - 1) * dilation
+        self.quant = torch.quantization.QuantStub()
+        self.dequant = torch.quantization.DeQuantStub()
 
     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
         """
@@ -44,11 +46,16 @@ class Block(nn.Module):
         assert y.size(2) > self.padding
         new_cache = y[:, :, -self.padding:]
 
+        y = self.quant(y)
         # self.cnn is defined in the subclass of Block
         y = self.cnn(y)
+        y = self.dequant(y)
         y = y + x  # residual connection
         return y, new_cache
 
+    def fuse_modules(self):
+        self.cnn.fuse_modules()
+
 
 class CnnBlock(Block):
     def __init__(self,
@@ -68,6 +75,10 @@ class CnnBlock(Block):
             nn.Dropout(dropout),
         )
 
+    def fuse_modules(self):
+        torch.quantization.fuse_modules(self, [['cnn.0', 'cnn.1', 'cnn.2']],
+                                        inplace=True)
+
 
 class DsCnnBlock(Block):
     """ Depthwise Separable Convolution
@@ -93,6 +104,11 @@ class DsCnnBlock(Block):
             nn.Dropout(dropout),
         )
 
+    def fuse_modules(self):
+        torch.quantization.fuse_modules(
+            self, [['cnn.0', 'cnn.1', 'cnn.2'], ['cnn.3', 'cnn.4', 'cnn.5']],
+            inplace=True)
+
 
 class TCN(nn.Module):
     def __init__(self,
@@ -129,6 +145,10 @@ class TCN(nn.Module):
         new_cache = torch.cat(out_caches, dim=2)
         return x, new_cache
 
+    def fuse_modules(self):
+        for m in self.network:
+            m.fuse_modules()
+
 
 if __name__ == '__main__':
     tcn = TCN(4, 64, 8, block_class=CnnBlock)
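
Each ['cnn.0', 'cnn.1', 'cnn.2'] run is presumably Conv1d, BatchNorm1d, ReLU (the depthwise-separable block simply has two such runs), so eval-mode fusion folds the BatchNorm into the convolution and leaves one ConvReLU1d per run. A sketch reusing the file's own __main__ construction, assuming self.network indexes like a ModuleList:

    tcn = TCN(4, 64, 8, block_class=CnnBlock)
    tcn.eval()                    # BatchNorm folding requires eval mode
    tcn.fuse_modules()
    print(type(tcn.network[0].cnn[0]).__name__)  # ConvReLU1d
    print(type(tcn.network[0].cnn[1]).__name__)  # Identity (BN folded)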