''' FSMN implementation. Copyright: 2022-03-09 yueyue.nyy 2023 Jing Du ''' from typing import Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def toKaldiMatrix(np_mat): np.set_printoptions(threshold=np.inf, linewidth=np.nan) out_str = str(np_mat) out_str = out_str.replace('[', '') out_str = out_str.replace(']', '') return '[ %s ]\n' % out_str def printTensor(torch_tensor): re_str = '' x = torch_tensor.detach().squeeze().numpy() re_str += toKaldiMatrix(x) # re_str += '\n' print(re_str) class LinearTransform(nn.Module): def __init__(self, input_dim, output_dim): super(LinearTransform, self).__init__() self.input_dim = input_dim self.output_dim = output_dim self.linear = nn.Linear(input_dim, output_dim, bias=False) self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() def forward(self, input): if isinstance(input, tuple): input, in_cache = input else: in_cache = None output = self.quant(input) output = self.linear(output) output = self.dequant(output) return (output, in_cache) def to_kaldi_net(self): re_str = '' re_str += ' %d %d\n' % (self.output_dim, self.input_dim) re_str += ' 1\n' linear_weights = self.state_dict()['linear.weight'] x = linear_weights.squeeze().numpy() re_str += toKaldiMatrix(x) # re_str += '\n' return re_str def to_pytorch_net(self, fread): linear_line = fread.readline() linear_split = linear_line.strip().split() assert len(linear_split) == 3 assert linear_split[0] == '' self.output_dim = int(linear_split[1]) self.input_dim = int(linear_split[2]) learn_rate_line = fread.readline() assert learn_rate_line.find('LearnRateCoef') != -1 self.linear.reset_parameters() # linear_weights = self.state_dict()['linear.weight'] # print(linear_weights.shape) new_weights = torch.zeros((self.output_dim, self.input_dim), dtype=torch.float32) for i in range(self.output_dim): line = fread.readline() splits = line.strip().strip('[]').strip().split() assert len(splits) == self.input_dim cols = torch.tensor([float(item) for item in splits], dtype=torch.float32) new_weights[i, :] = cols self.linear.weight.data = new_weights class AffineTransform(nn.Module): def __init__(self, input_dim, output_dim): super(AffineTransform, self).__init__() self.input_dim = input_dim self.output_dim = output_dim self.linear = nn.Linear(input_dim, output_dim) self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() def forward(self, input): if isinstance(input, tuple): input, in_cache = input else: in_cache = None output = self.quant(input) output = self.linear(output) output = self.dequant(output) return (output, in_cache) def to_kaldi_net(self): re_str = '' re_str += ' %d %d\n' % (self.output_dim, self.input_dim) re_str += ' 1 1 0\n' linear_weights = self.state_dict()['linear.weight'] x = linear_weights.squeeze().numpy() re_str += toKaldiMatrix(x) linear_bias = self.state_dict()['linear.bias'] x = linear_bias.squeeze().numpy() re_str += toKaldiMatrix(x) # re_str += '\n' return re_str def to_pytorch_net(self, fread): affine_line = fread.readline() affine_split = affine_line.strip().split() assert len(affine_split) == 3 assert affine_split[0] == '' self.output_dim = int(affine_split[1]) self.input_dim = int(affine_split[2]) print('AffineTransform output/input dim: %d %d' % (self.output_dim, self.input_dim)) learn_rate_line = fread.readline() assert learn_rate_line.find('LearnRateCoef') != -1 # linear_weights = self.state_dict()['linear.weight'] # print(linear_weights.shape) self.linear.reset_parameters() new_weights = torch.zeros((self.output_dim, self.input_dim), dtype=torch.float32) for i in range(self.output_dim): line = fread.readline() splits = line.strip().strip('[]').strip().split() assert len(splits) == self.input_dim cols = torch.tensor([float(item) for item in splits], dtype=torch.float32) new_weights[i, :] = cols self.linear.weight.data = new_weights # linear_bias = self.state_dict()['linear.bias'] # print(linear_bias.shape) bias_line = fread.readline() splits = bias_line.strip().strip('[]').strip().split() assert len(splits) == self.output_dim new_bias = torch.tensor([float(item) for item in splits], dtype=torch.float32) self.linear.bias.data = new_bias class FSMNBlock(nn.Module): def __init__( self, input_dim: int, output_dim: int, lorder=None, rorder=None, lstride=1, rstride=1, ): super(FSMNBlock, self).__init__() self.dim = input_dim if lorder is None: return self.lorder = lorder self.rorder = rorder self.lstride = lstride self.rstride = rstride self.conv_left = nn.Conv2d( self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False) if rorder > 0: self.conv_right = nn.Conv2d( self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False) else: self.conv_right = None self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() def forward(self, input): if isinstance(input, tuple): input, in_cache = input else : in_cache = None x = torch.unsqueeze(input, 1) x_per = x.permute(0, 3, 2, 1) if in_cache is None or len(in_cache)==0 or in_cache[0]==None: x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride+self.rorder*self.rstride, 0]) else : in_cache = in_cache.to(x_per.device) x_pad = torch.cat((in_cache, x_per), dim=2) in_cache = x_pad[:, :, -((self.lorder - 1) * self.lstride+self.rorder*self.rstride):, :] y_left = x_pad[:, :, :-self.rorder * self.rstride, :] y_left = self.quant(y_left) y_left = self.conv_left(y_left) y_left = self.dequant(y_left) out = x_pad[:, :, (self.lorder - 1) * self.lstride:-self.rorder * self.rstride, :] + y_left #out = out[:, :, :-self.rorder*self.rstride, :] if self.conv_right is not None: # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :] y_right = y_right[:, :, self.rstride:, :] y_right = self.quant(y_right) y_right = self.conv_right(y_right) y_right = self.dequant(y_right) out += y_right out_per = out.permute(0, 3, 2, 1) output = out_per.squeeze(1) return (output, in_cache) def to_kaldi_net(self): re_str = '' re_str += ' %d %d\n' % (self.dim, self.dim) re_str += ' %d %d %d %d %d 0\n' % ( 1, self.lorder, self.rorder, self.lstride, self.rstride) # print(self.conv_left.weight,self.conv_right.weight) lfiters = self.state_dict()['conv_left.weight'] x = np.flipud(lfiters.squeeze().numpy().T) re_str += toKaldiMatrix(x) if self.conv_right is not None: rfiters = self.state_dict()['conv_right.weight'] x = (rfiters.squeeze().numpy().T) re_str += toKaldiMatrix(x) # re_str += '\n' return re_str def to_pytorch_net(self, fread): fsmn_line = fread.readline() fsmn_split = fsmn_line.strip().split() assert len(fsmn_split) == 3 assert fsmn_split[0] == '' self.dim = int(fsmn_split[1]) params_line = fread.readline() params_split = params_line.strip().strip('[]').strip().split() assert len(params_split) == 12 assert params_split[0] == '' assert params_split[2] == '' self.lorder = int(params_split[3]) assert params_split[4] == '' self.rorder = int(params_split[5]) assert params_split[6] == '' self.lstride = int(params_split[7]) assert params_split[8] == '' self.rstride = int(params_split[9]) assert params_split[10] == '' # lfilters = self.state_dict()['conv_left.weight'] # print(lfilters.shape) print('read conv_left weight') new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1), dtype=torch.float32) for i in range(self.lorder): print('read conv_left weight -- %d' % i) line = fread.readline() splits = line.strip().strip('[]').strip().split() assert len(splits) == self.dim cols = torch.tensor([float(item) for item in splits], dtype=torch.float32) new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols new_lfilters = torch.transpose(new_lfilters, 0, 2) # print(new_lfilters.shape) self.conv_left.reset_parameters() self.conv_left.weight.data = new_lfilters # print(self.conv_left.weight.shape) if self.rorder > 0: # rfilters = self.state_dict()['conv_right.weight'] # print(rfilters.shape) print('read conv_right weight') new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1), dtype=torch.float32) line = fread.readline() for i in range(self.rorder): print('read conv_right weight -- %d' % i) line = fread.readline() splits = line.strip().strip('[]').strip().split() assert len(splits) == self.dim cols = torch.tensor([float(item) for item in splits], dtype=torch.float32) new_rfilters[i, 0, :, 0] = cols new_rfilters = torch.transpose(new_rfilters, 0, 2) # print(new_rfilters.shape) self.conv_right.reset_parameters() self.conv_right.weight.data = new_rfilters # print(self.conv_right.weight.shape) class RectifiedLinear(nn.Module): def __init__(self, input_dim, output_dim): super(RectifiedLinear, self).__init__() self.dim = input_dim self.relu = nn.ReLU() self.dropout = nn.Dropout(0.1) def forward(self, input): if isinstance(input, tuple): input, in_cache = input else : in_cache = None out = self.relu(input) # out = self.dropout(out) return (out, in_cache) def to_kaldi_net(self): re_str = '' re_str += ' %d %d\n' % (self.dim, self.dim) # re_str += '\n' return re_str # re_str = '' # re_str += ' %d %d\n' % (self.dim, self.dim) # re_str += ' 0 0\n' # re_str += toKaldiMatrix(np.ones((self.dim), dtype = 'int32')) # re_str += toKaldiMatrix(np.zeros((self.dim), dtype = 'int32')) # re_str += '\n' # return re_str def to_pytorch_net(self, fread): line = fread.readline() splits = line.strip().split() assert len(splits) == 3 assert splits[0] == '' assert int(splits[1]) == int(splits[2]) assert int(splits[1]) == self.dim self.dim = int(splits[1]) def _build_repeats( fsmn_layers: int, linear_dim: int, proj_dim: int, lorder: int, rorder: int, lstride=1, rstride=1, ): repeats = [ nn.Sequential( LinearTransform(linear_dim, proj_dim), FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1), AffineTransform(proj_dim, linear_dim), RectifiedLinear(linear_dim, linear_dim)) for i in range(fsmn_layers) ] return nn.Sequential(*repeats) class FSMN(nn.Module): def __init__( self, input_dim: int, input_affine_dim: int, fsmn_layers: int, linear_dim: int, proj_dim: int, lorder: int, rorder: int, lstride: int, rstride: int, output_affine_dim: int, output_dim: int, ): """ Args: input_dim: input dimension input_affine_dim: input affine layer dimension fsmn_layers: no. of fsmn units linear_dim: fsmn input dimension proj_dim: fsmn projection dimension lorder: fsmn left order rorder: fsmn right order lstride: fsmn left stride rstride: fsmn right stride output_affine_dim: output affine layer dimension output_dim: output dimension """ super(FSMN, self).__init__() self.input_dim = input_dim self.input_affine_dim = input_affine_dim self.fsmn_layers = fsmn_layers self.linear_dim = linear_dim self.proj_dim = proj_dim self.lorder = lorder self.rorder = rorder self.lstride = lstride self.rstride = rstride self.output_affine_dim = output_affine_dim self.output_dim = output_dim self.padding = (self.lorder-1) * self.lstride + self.rorder * self.rstride self.in_linear1 = AffineTransform(input_dim, input_affine_dim) self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) self.relu = RectifiedLinear(linear_dim, linear_dim) self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder, rorder, lstride, rstride) self.out_linear1 = AffineTransform(linear_dim, output_affine_dim) self.out_linear2 = AffineTransform(output_affine_dim, output_dim) # self.softmax = nn.Softmax(dim = -1) def fuse_modules(self): pass def forward( self, input: torch.Tensor, in_cache #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: input (torch.Tensor): Input tensor (B, T, D) in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size """ if in_cache is None or len(in_cache) == 0 or in_cache[0] == None: in_cache = [None for _ in range(len(self.fsmn))] input = (input, in_cache) x1 = self.in_linear1(input) x2 = self.in_linear2(x1) x3 = self.relu(x2) # x4 = self.fsmn(x3) x4, _ = x3 for layer, module in enumerate(self.fsmn): x4, in_cache[layer] = module((x4, in_cache[layer])) x5 = self.out_linear1(x4) x6 = self.out_linear2(x5) # x7 = self.softmax(x6) x7, _ = x6 # return x7, None return x7, in_cache def to_kaldi_net(self): re_str = '' re_str += '\n' re_str += self.in_linear1.to_kaldi_net() re_str += self.in_linear2.to_kaldi_net() re_str += self.relu.to_kaldi_net() for fsmn in self.fsmn: re_str += fsmn[0].to_kaldi_net() re_str += fsmn[1].to_kaldi_net() re_str += fsmn[2].to_kaldi_net() re_str += fsmn[3].to_kaldi_net() re_str += self.out_linear1.to_kaldi_net() re_str += self.out_linear2.to_kaldi_net() re_str += ' %d %d\n' % (self.output_dim, self.output_dim) # re_str += '\n' re_str += '\n' return re_str def to_pytorch_net(self, kaldi_file): with open(kaldi_file, 'r', encoding='utf8') as fread: fread = open(kaldi_file, 'r') nnet_start_line = fread.readline() assert nnet_start_line.strip() == '' self.in_linear1.to_pytorch_net(fread) self.in_linear2.to_pytorch_net(fread) self.relu.to_pytorch_net(fread) for fsmn in self.fsmn: fsmn[0].to_pytorch_net(fread) fsmn[1].to_pytorch_net(fread) fsmn[2].to_pytorch_net(fread) fsmn[3].to_pytorch_net(fread) self.out_linear1.to_pytorch_net(fread) self.out_linear2.to_pytorch_net(fread) softmax_line = fread.readline() softmax_split = softmax_line.strip().split() assert softmax_split[0].strip() == '' assert int(softmax_split[1]) == self.output_dim assert int(softmax_split[2]) == self.output_dim # '\n' nnet_end_line = fread.readline() assert nnet_end_line.strip() == '' fread.close() if __name__ == '__main__': fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599) print(fsmn) num_params = sum(p.numel() for p in fsmn.parameters()) print('the number of model params: {}'.format(num_params)) x = torch.zeros(128, 200, 400) # batch-size * time * dim y, _ = fsmn(x) # batch-size * time * dim print('input shape: {}'.format(x.shape)) print('output shape: {}'.format(y.shape)) print(fsmn.to_kaldi_net())