[kws] support fuse and quantization (#41)
This commit is contained in:
parent
fd255fd7c6
commit
28652a766f
@ -45,3 +45,17 @@ class ElementClassifier(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
return self.classifier(x)
|
return self.classifier(x)
|
||||||
|
|
||||||
|
class LinearClassifier(nn.Module):
|
||||||
|
""" Wrapper of Linear """
|
||||||
|
def __init__(self, input_dim, output_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = torch.nn.Linear(input_dim, output_dim)
|
||||||
|
self.quant = torch.quantization.QuantStub()
|
||||||
|
self.dequant = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = self.quant(x)
|
||||||
|
x = self.linear(x)
|
||||||
|
x = self.dequant(x)
|
||||||
|
return x
|
||||||
|
|||||||
@ -19,7 +19,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from kws.model.cmvn import GlobalCMVN
|
from kws.model.cmvn import GlobalCMVN
|
||||||
from kws.model.classifier import GlobalClassifier, LastClassifier
|
from kws.model.classifier import GlobalClassifier, LastClassifier, LinearClassifier
|
||||||
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
|
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
|
||||||
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
||||||
from kws.model.mdtc import MDTC
|
from kws.model.mdtc import MDTC
|
||||||
@ -60,6 +60,10 @@ class KWSModel(torch.nn.Module):
|
|||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def fuse_modules(self):
|
||||||
|
self.preprocessing.fuse_modules()
|
||||||
|
self.backbone.fuse_modules()
|
||||||
|
|
||||||
|
|
||||||
def init_model(configs):
|
def init_model(configs):
|
||||||
cmvn = configs.get('cmvn', {})
|
cmvn = configs.get('cmvn', {})
|
||||||
@ -143,7 +147,7 @@ def init_model(configs):
|
|||||||
print('Unknown classifier type {}'.format(classifier_type))
|
print('Unknown classifier type {}'.format(classifier_type))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
else:
|
||||||
classifier = torch.nn.Linear(hidden_dim, output_dim)
|
classifier = LinearClassifier(hidden_dim, output_dim)
|
||||||
|
|
||||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||||
preprocessing, backbone, classifier)
|
preprocessing, backbone, classifier)
|
||||||
|
|||||||
@ -23,6 +23,7 @@ class SubsamplingBase(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.subsampling_rate = 1
|
self.subsampling_rate = 1
|
||||||
|
|
||||||
|
|
||||||
class NoSubsampling(SubsamplingBase):
|
class NoSubsampling(SubsamplingBase):
|
||||||
"""No subsampling in accordance to the 'none' preprocessing
|
"""No subsampling in accordance to the 'none' preprocessing
|
||||||
"""
|
"""
|
||||||
@ -32,6 +33,7 @@ class NoSubsampling(SubsamplingBase):
|
|||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LinearSubsampling1(SubsamplingBase):
|
class LinearSubsampling1(SubsamplingBase):
|
||||||
"""Linear transform the input without subsampling
|
"""Linear transform the input without subsampling
|
||||||
"""
|
"""
|
||||||
@ -39,15 +41,22 @@ class LinearSubsampling1(SubsamplingBase):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.out = torch.nn.Sequential(
|
self.out = torch.nn.Sequential(
|
||||||
torch.nn.Linear(idim, odim),
|
torch.nn.Linear(idim, odim),
|
||||||
# torch.nn.BatchNorm1d(odim),
|
|
||||||
torch.nn.ReLU(),
|
torch.nn.ReLU(),
|
||||||
)
|
)
|
||||||
self.subsampling_rate = 1
|
self.subsampling_rate = 1
|
||||||
|
self.quant = torch.quantization.QuantStub()
|
||||||
|
self.dequant = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.quant(x)
|
||||||
x = self.out(x)
|
x = self.out(x)
|
||||||
|
x = self.dequant(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def fuse_modules(self):
|
||||||
|
torch.quantization.fuse_modules(self, [['out.0', 'out.1']],
|
||||||
|
inplace=True)
|
||||||
|
|
||||||
|
|
||||||
class Conv1dSubsampling1(SubsamplingBase):
|
class Conv1dSubsampling1(SubsamplingBase):
|
||||||
"""Conv1d transform without subsampling
|
"""Conv1d transform without subsampling
|
||||||
|
|||||||
@ -28,6 +28,8 @@ class Block(nn.Module):
|
|||||||
dropout: float = 0.1):
|
dropout: float = 0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.padding = (kernel_size - 1) * dilation
|
self.padding = (kernel_size - 1) * dilation
|
||||||
|
self.quant = torch.quantization.QuantStub()
|
||||||
|
self.dequant = torch.quantization.DeQuantStub()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||||
"""
|
"""
|
||||||
@ -44,11 +46,16 @@ class Block(nn.Module):
|
|||||||
assert y.size(2) > self.padding
|
assert y.size(2) > self.padding
|
||||||
new_cache = y[:, :, -self.padding:]
|
new_cache = y[:, :, -self.padding:]
|
||||||
|
|
||||||
|
y = self.quant(y)
|
||||||
# self.cnn is defined in the subclass of Block
|
# self.cnn is defined in the subclass of Block
|
||||||
y = self.cnn(y)
|
y = self.cnn(y)
|
||||||
|
y = self.dequant(y)
|
||||||
y = y + x # residual connection
|
y = y + x # residual connection
|
||||||
return y, new_cache
|
return y, new_cache
|
||||||
|
|
||||||
|
def fuse_modules(self):
|
||||||
|
self.cnn.fuse_modules()
|
||||||
|
|
||||||
|
|
||||||
class CnnBlock(Block):
|
class CnnBlock(Block):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -68,6 +75,10 @@ class CnnBlock(Block):
|
|||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def fuse_modules(self):
|
||||||
|
torch.quantization.fuse_modules(self, [['cnn.0', 'cnn.1', 'cnn.2']],
|
||||||
|
inplace=True)
|
||||||
|
|
||||||
|
|
||||||
class DsCnnBlock(Block):
|
class DsCnnBlock(Block):
|
||||||
""" Depthwise Separable Convolution
|
""" Depthwise Separable Convolution
|
||||||
@ -93,6 +104,11 @@ class DsCnnBlock(Block):
|
|||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def fuse_modules(self):
|
||||||
|
torch.quantization.fuse_modules(
|
||||||
|
self, [['cnn.0', 'cnn.1', 'cnn.2'], ['cnn.3', 'cnn.4', 'cnn.5']],
|
||||||
|
inplace=True)
|
||||||
|
|
||||||
|
|
||||||
class TCN(nn.Module):
|
class TCN(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -129,6 +145,10 @@ class TCN(nn.Module):
|
|||||||
new_cache = torch.cat(out_caches, dim=2)
|
new_cache = torch.cat(out_caches, dim=2)
|
||||||
return x, new_cache
|
return x, new_cache
|
||||||
|
|
||||||
|
def fuse_modules(self):
|
||||||
|
for m in self.network:
|
||||||
|
m.fuse_modules()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tcn = TCN(4, 64, 8, block_class=CnnBlock)
|
tcn = TCN(4, 64, 8, block_class=CnnBlock)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user