[kws] support fuse and quantization (#41)
parent fd255fd7c6
commit 28652a766f
@@ -45,3 +45,17 @@ class ElementClassifier(nn.Module):
     def forward(self, x: torch.Tensor):
         return self.classifier(x)
 
+class LinearClassifier(nn.Module):
+    """ Wrapper of Linear """
+    def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.linear = torch.nn.Linear(input_dim, output_dim)
+        self.quant = torch.quantization.QuantStub()
+        self.dequant = torch.quantization.DeQuantStub()
+
+    def forward(self, x: torch.Tensor):
+        x = self.quant(x)
+        x = self.linear(x)
+        x = self.dequant(x)
+        return x
+
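The QuantStub/DeQuantStub pair in LinearClassifier marks where activations enter and leave the quantized region, which is what PyTorch eager-mode post-training static quantization keys on. A minimal sketch of quantizing just this wrapper, assuming the import path shown later in this commit; the dimensions, backend choice, and calibration tensor are illustrative, not part of the diff:

import torch
from kws.model.classifier import LinearClassifier

model = LinearClassifier(128, 10)        # hypothetical input/output dims
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')   # x86 server backend
torch.quantization.prepare(model, inplace=True)    # attach observers to record activation ranges
with torch.no_grad():
    model(torch.randn(4, 128))                     # calibration pass on sample data
torch.quantization.convert(model, inplace=True)    # swap nn.Linear for its quantized counterpart
print(model.linear)                                # now a quantized Linear module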
@@ -19,7 +19,7 @@ import torch
 import torch.nn as nn
 
 from kws.model.cmvn import GlobalCMVN
-from kws.model.classifier import GlobalClassifier, LastClassifier
+from kws.model.classifier import GlobalClassifier, LastClassifier, LinearClassifier
 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
 from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
 from kws.model.mdtc import MDTC
@@ -60,6 +60,10 @@ class KWSModel(torch.nn.Module):
         x = self.classifier(x)
         return x
 
+    def fuse_modules(self):
+        self.preprocessing.fuse_modules()
+        self.backbone.fuse_modules()
+
 
 def init_model(configs):
     cmvn = configs.get('cmvn', {})
@@ -143,7 +147,7 @@ def init_model(configs):
             print('Unknown classifier type {}'.format(classifier_type))
             sys.exit(1)
     else:
-        classifier = torch.nn.Linear(hidden_dim, output_dim)
+        classifier = LinearClassifier(hidden_dim, output_dim)
 
     kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
                          preprocessing, backbone, classifier)
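Together with the new KWSModel.fuse_modules() hook, the LinearClassifier swap makes the whole model quantizable end to end. A rough sketch of the intended flow, fuse first and then run static quantization, assuming a config whose preprocessing and backbone both define fuse_modules (e.g. LinearSubsampling1 plus TCN); configs, the calibration loader, and the forward signature are assumptions, not defined by this diff:

import torch

model = init_model(configs)      # configs: hypothetical dict in the project's config format
model.eval()
model.fuse_modules()             # fuse Linear+ReLU / Conv+BN+ReLU runs before quantization
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
with torch.no_grad():
    for feats in calibration_loader:   # hypothetical loader yielding feature batches
        model(feats)                   # assumes forward takes the feature tensor directly
torch.quantization.convert(model, inplace=True)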
@@ -23,6 +23,7 @@ class SubsamplingBase(torch.nn.Module):
         super().__init__()
         self.subsampling_rate = 1
 
+
 class NoSubsampling(SubsamplingBase):
     """No subsampling in accordance to the 'none' preprocessing
     """
@@ -32,6 +33,7 @@ class NoSubsampling(SubsamplingBase):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
+
 class LinearSubsampling1(SubsamplingBase):
     """Linear transform the input without subsampling
     """
@@ -39,15 +41,22 @@ class LinearSubsampling1(SubsamplingBase):
         super().__init__()
         self.out = torch.nn.Sequential(
             torch.nn.Linear(idim, odim),
             # torch.nn.BatchNorm1d(odim),
             torch.nn.ReLU(),
         )
         self.subsampling_rate = 1
+        self.quant = torch.quantization.QuantStub()
+        self.dequant = torch.quantization.DeQuantStub()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.quant(x)
         x = self.out(x)
+        x = self.dequant(x)
         return x
 
+    def fuse_modules(self):
+        torch.quantization.fuse_modules(self, [['out.0', 'out.1']],
+                                        inplace=True)
+
 
 class Conv1dSubsampling1(SubsamplingBase):
     """Conv1d transform without subsampling
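In LinearSubsampling1, fuse_modules merges 'out.0' (the Linear) with 'out.1' (the ReLU) into a single intrinsic LinearReLU module, which is the form the quantized kernels expect, and leaves an Identity in the former ReLU slot. A quick check of that behavior, with made-up feature dimensions:

import torch
import torch.nn.intrinsic as nni
from kws.model.subsampling import LinearSubsampling1

sub = LinearSubsampling1(80, 64)   # idim/odim chosen arbitrarily for the example
sub.eval()
sub.fuse_modules()
assert isinstance(sub.out[0], nni.LinearReLU)      # Linear and ReLU fused into one module
assert isinstance(sub.out[1], torch.nn.Identity)   # former ReLU slot is now a no-op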
@@ -28,6 +28,8 @@ class Block(nn.Module):
                  dropout: float = 0.1):
         super().__init__()
         self.padding = (kernel_size - 1) * dilation
+        self.quant = torch.quantization.QuantStub()
+        self.dequant = torch.quantization.DeQuantStub()
 
     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
         """
@@ -44,11 +46,16 @@
         assert y.size(2) > self.padding
         new_cache = y[:, :, -self.padding:]
 
+        y = self.quant(y)
         # self.cnn is defined in the subclass of Block
         y = self.cnn(y)
+        y = self.dequant(y)
         y = y + x  # residual connection
         return y, new_cache
 
+    def fuse_modules(self):
+        self.cnn.fuse_modules()
+
 
 class CnnBlock(Block):
     def __init__(self,
@@ -68,6 +75,10 @@ class CnnBlock(Block):
             nn.Dropout(dropout),
         )
 
+    def fuse_modules(self):
+        torch.quantization.fuse_modules(self, [['cnn.0', 'cnn.1', 'cnn.2']],
+                                        inplace=True)
+
 
 class DsCnnBlock(Block):
     """ Depthwise Separable Convolution
@@ -93,6 +104,11 @@ class DsCnnBlock(Block):
             nn.Dropout(dropout),
         )
 
+    def fuse_modules(self):
+        torch.quantization.fuse_modules(
+            self, [['cnn.0', 'cnn.1', 'cnn.2'], ['cnn.3', 'cnn.4', 'cnn.5']],
+            inplace=True)
+
 
 class TCN(nn.Module):
     def __init__(self,
@@ -129,6 +145,10 @@ class TCN(nn.Module):
         new_cache = torch.cat(out_caches, dim=2)
         return x, new_cache
 
+    def fuse_modules(self):
+        for m in self.network:
+            m.fuse_modules()
+
 
 if __name__ == '__main__':
     tcn = TCN(4, 64, 8, block_class=CnnBlock)
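In the TCN blocks, each ['cnn.0', 'cnn.1', 'cnn.2'] group is, judging by the fuse lists, a Conv1d, BatchNorm1d, ReLU run (DsCnnBlock fuses two such groups, presumably its depthwise and pointwise stages), so fusion folds the batch norm into the convolution weights before quantization. A small sketch reusing the toy configuration from the __main__ block above; the assertions about the module layout are assumptions based on those fuse lists:

import torch
from kws.model.tcn import TCN, CnnBlock

tcn = TCN(4, 64, 8, block_class=CnnBlock)   # same toy setup as the __main__ block
tcn.eval()                                  # Conv+BN folding assumes inference mode
tcn.fuse_modules()                          # each block fuses its own cnn.0/cnn.1/cnn.2 run
print(tcn.network[0].cnn)                   # fused conv in slot 0, Identity in the old BN/ReLU slots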