From 28652a766fa3b0ecdd482c4e0e0acb2c274f5ba7 Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Mon, 13 Dec 2021 19:16:50 +0800 Subject: [PATCH] [kws] support fuse and quantization (#41) --- kws/model/classifier.py | 14 ++++++++++++++ kws/model/kws_model.py | 8 ++++++-- kws/model/subsampling.py | 11 ++++++++++- kws/model/tcn.py | 20 ++++++++++++++++++++ 4 files changed, 50 insertions(+), 3 deletions(-) diff --git a/kws/model/classifier.py b/kws/model/classifier.py index 42af97c..190b30e 100644 --- a/kws/model/classifier.py +++ b/kws/model/classifier.py @@ -45,3 +45,17 @@ class ElementClassifier(nn.Module): def forward(self, x: torch.Tensor): return self.classifier(x) + +class LinearClassifier(nn.Module): + """ Wrapper of Linear """ + def __init__(self, input_dim, output_dim): + super().__init__() + self.linear = torch.nn.Linear(input_dim, output_dim) + self.quant = torch.quantization.QuantStub() + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, x: torch.Tensor): + x = self.quant(x) + x = self.linear(x) + x = self.dequant(x) + return x diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index effc889..f99f0ab 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -19,7 +19,7 @@ import torch import torch.nn as nn from kws.model.cmvn import GlobalCMVN -from kws.model.classifier import GlobalClassifier, LastClassifier +from kws.model.classifier import GlobalClassifier, LastClassifier, LinearClassifier from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling from kws.model.tcn import TCN, CnnBlock, DsCnnBlock from kws.model.mdtc import MDTC @@ -60,6 +60,10 @@ class KWSModel(torch.nn.Module): x = self.classifier(x) return x + def fuse_modules(self): + self.preprocessing.fuse_modules() + self.backbone.fuse_modules() + def init_model(configs): cmvn = configs.get('cmvn', {}) @@ -143,7 +147,7 @@ def init_model(configs): print('Unknown classifier type {}'.format(classifier_type)) sys.exit(1) else: - classifier = torch.nn.Linear(hidden_dim, output_dim) + classifier = LinearClassifier(hidden_dim, output_dim) kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, preprocessing, backbone, classifier) diff --git a/kws/model/subsampling.py b/kws/model/subsampling.py index 1aba989..06d6311 100644 --- a/kws/model/subsampling.py +++ b/kws/model/subsampling.py @@ -23,6 +23,7 @@ class SubsamplingBase(torch.nn.Module): super().__init__() self.subsampling_rate = 1 + class NoSubsampling(SubsamplingBase): """No subsampling in accordance to the 'none' preprocessing """ @@ -32,6 +33,7 @@ class NoSubsampling(SubsamplingBase): def forward(self, x: torch.Tensor) -> torch.Tensor: return x + class LinearSubsampling1(SubsamplingBase): """Linear transform the input without subsampling """ @@ -39,15 +41,22 @@ class LinearSubsampling1(SubsamplingBase): super().__init__() self.out = torch.nn.Sequential( torch.nn.Linear(idim, odim), - # torch.nn.BatchNorm1d(odim), torch.nn.ReLU(), ) self.subsampling_rate = 1 + self.quant = torch.quantization.QuantStub() + self.dequant = torch.quantization.DeQuantStub() def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.quant(x) x = self.out(x) + x = self.dequant(x) return x + def fuse_modules(self): + torch.quantization.fuse_modules(self, [['out.0', 'out.1']], + inplace=True) + class Conv1dSubsampling1(SubsamplingBase): """Conv1d transform without subsampling diff --git a/kws/model/tcn.py b/kws/model/tcn.py index 958a951..84366b3 100644 --- a/kws/model/tcn.py +++ b/kws/model/tcn.py @@ -28,6 +28,8 @@ class Block(nn.Module): dropout: float = 0.1): super().__init__() self.padding = (kernel_size - 1) * dilation + self.quant = torch.quantization.QuantStub() + self.dequant = torch.quantization.DeQuantStub() def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): """ @@ -44,11 +46,16 @@ class Block(nn.Module): assert y.size(2) > self.padding new_cache = y[:, :, -self.padding:] + y = self.quant(y) # self.cnn is defined in the subclass of Block y = self.cnn(y) + y = self.dequant(y) y = y + x # residual connection return y, new_cache + def fuse_modules(self): + self.cnn.fuse_modules() + class CnnBlock(Block): def __init__(self, @@ -68,6 +75,10 @@ class CnnBlock(Block): nn.Dropout(dropout), ) + def fuse_modules(self): + torch.quantization.fuse_modules(self, [['cnn.0', 'cnn.1', 'cnn.2']], + inplace=True) + class DsCnnBlock(Block): """ Depthwise Separable Convolution @@ -93,6 +104,11 @@ class DsCnnBlock(Block): nn.Dropout(dropout), ) + def fuse_modules(self): + torch.quantization.fuse_modules( + self, [['cnn.0', 'cnn.1', 'cnn.2'], ['cnn.3', 'cnn.4', 'cnn.5']], + inplace=True) + class TCN(nn.Module): def __init__(self, @@ -129,6 +145,10 @@ class TCN(nn.Module): new_cache = torch.cat(out_caches, dim=2) return x, new_cache + def fuse_modules(self): + for m in self.network: + m.fuse_modules() + if __name__ == '__main__': tcn = TCN(4, 64, 8, block_class=CnnBlock)