wekws/wekws/model/fsmn.py
Jean Du b233d46552
[ctc] KWS with CTCloss training and CTC prefix beam search detection. (#135)
* add ctcloss training scripts.

* update compute_det_ctc

* fix typo.

* add fsmn model, can use pretrained kws model from modelscope.

* Add streaming detection for the CTC model. Add ONNX export for the CTC model. Add the CTC model's results to the README. The CTC model runtime is not supported yet.

* QA run.sh; the max-pooling training scripts remain compatible. Ready for PR.

* Add a streaming kws demo, support fsmn online forward

* fix typo.

* Align Stream FSMN and Non-Stream FSMN, both in feature extraction and model forward.

* fix repeated activation, add an interval restriction.

* fix timestamp when subsampling!=1.

* fix flake8, update training script and README, give pretrained ckpt.

* fix quickcheck and flake8

* Add realtime CTC-KWS demo in README.

---------

Co-authored-by: dujing <dujing@xmov.ai>
2023-08-16 10:07:04 +08:00

559 lines
19 KiB
Python

'''
FSMN implementation.
Copyright: 2022-03-09 yueyue.nyy
2023 Jing Du
'''
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def toKaldiMatrix(np_mat):
    """Render a NumPy array as a Kaldi-style text matrix.

    The array is converted with ``str`` using untruncated, unwrapped print
    settings, every ``[`` and ``]`` is removed, and the result is wrapped as
    ``'[ <values> ]\\n'``.

    Args:
        np_mat: NumPy array to serialize.

    Returns:
        The text representation, terminated by a newline.
    """
    # Scope the print options to this call only: np.set_printoptions would
    # permanently change how every array in the process is printed.
    with np.printoptions(threshold=np.inf, linewidth=np.nan):
        out_str = str(np_mat)
    out_str = out_str.replace('[', '').replace(']', '')
    return '[ %s ]\n' % out_str
def printTensor(torch_tensor):
    """Print a tensor to stdout in the text format of ``toKaldiMatrix``.

    Args:
        torch_tensor: tensor to print; size-1 dimensions are squeezed away
            before formatting.
    """
    # Tensor.numpy() only works for CPU tensors, so copy to the CPU first;
    # this is a no-op for tensors that already live there.
    x = torch_tensor.detach().cpu().squeeze().numpy()
    print(toKaldiMatrix(x))
class LinearTransform(nn.Module):
    """Bias-free linear projection wrapped in quantization stubs.

    Serialized as the ``<LinearTransform>`` component of the Kaldi-style
    text network format handled by ``to_kaldi_net`` / ``to_pytorch_net``.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, output_dim):
        super(LinearTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim, bias=False)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self,
                input: Tuple[torch.Tensor, torch.Tensor]):
        """Apply the projection to the input features.

        Args:
            input: either an ``(input, in_cache)`` tuple or a bare tensor.
                The cache is not used by this layer; it is returned
                unchanged.

        Returns:
            ``(output, in_cache)``. For a bare tensor input, ``in_cache`` is
            an empty 4-D float tensor.
        """
        if isinstance(input, tuple):
            input, in_cache = input
        else:
            in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
        output = self.quant(input)
        output = self.linear(output)
        output = self.dequant(output)
        return (output, in_cache)

    def to_kaldi_net(self):
        """Serialize the layer to the Kaldi-style text format.

        Returns:
            The component header, the learn-rate line and the weight matrix
            with one text row per output unit.
        """
        re_str = '<LinearTransform> %d %d\n' % (self.output_dim,
                                                self.input_dim)
        re_str += '<LearnRateCoef> 1\n'
        # Keep the weight 2-D (no squeeze) so that a single-column matrix is
        # still written one row per output unit, which is what
        # to_pytorch_net reads back. Copy to the CPU because Tensor.numpy()
        # rejects tensors on other devices.
        linear_weights = self.state_dict()['linear.weight']
        re_str += toKaldiMatrix(linear_weights.cpu().numpy())
        return re_str

    def to_pytorch_net(self, fread):
        """Load the layer from an open Kaldi-style text file.

        Reads, in order, the ``<LinearTransform>`` header line, the
        learn-rate line and ``output_dim`` rows of ``input_dim`` weights.

        Args:
            fread: text file object positioned at the component header.

        Raises:
            AssertionError: if the header, the learn-rate line or a weight
                row does not have the expected form.
        """
        linear_split = fread.readline().strip().split()
        assert len(linear_split) == 3
        assert linear_split[0] == '<LinearTransform>'
        self.output_dim = int(linear_split[1])
        self.input_dim = int(linear_split[2])

        learn_rate_line = fread.readline()
        assert learn_rate_line.find('LearnRateCoef') != -1

        # Kept from the original flow (it draws from the RNG); the weights
        # are overwritten below.
        self.linear.reset_parameters()

        rows = []
        for _ in range(self.output_dim):
            # The first/last rows carry the enclosing '[' / ']'.
            splits = fread.readline().strip().strip('[]').strip().split()
            assert len(splits) == self.input_dim
            rows.append([float(item) for item in splits])
        new_weights = torch.tensor(rows, dtype=torch.float32).reshape(
            self.output_dim, self.input_dim)

        # Keep the loaded weights on the device the layer already lives on,
        # and keep nn.Linear's own metadata in sync with the loaded shape.
        self.linear.weight.data = new_weights.to(self.linear.weight.device)
        self.linear.in_features = self.input_dim
        self.linear.out_features = self.output_dim
class AffineTransform(nn.Module):
    """Linear projection with bias, wrapped in quantization stubs.

    Serialized as the ``<AffineTransform>`` component of the Kaldi-style
    text network format handled by ``to_kaldi_net`` / ``to_pytorch_net``.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, output_dim):
        super(AffineTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self,
                input: Tuple[torch.Tensor, torch.Tensor]):
        """Apply the affine projection to the input features.

        Args:
            input: either an ``(input, in_cache)`` tuple or a bare tensor.
                The cache is not used by this layer; it is returned
                unchanged.

        Returns:
            ``(output, in_cache)``. For a bare tensor input, ``in_cache`` is
            an empty 4-D float tensor.
        """
        if isinstance(input, tuple):
            input, in_cache = input
        else:
            in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
        output = self.quant(input)
        output = self.linear(output)
        output = self.dequant(output)
        return (output, in_cache)

    def to_kaldi_net(self):
        """Serialize the layer to the Kaldi-style text format.

        Returns:
            The component header, the learn-rate line, the weight matrix
            (one text row per output unit) and the bias vector.
        """
        re_str = '<AffineTransform> %d %d\n' % (self.output_dim,
                                                self.input_dim)
        re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
        state = self.state_dict()
        # No squeeze: a single-column weight must still be written one row
        # per output unit, which is what to_pytorch_net reads back. Copy to
        # the CPU because Tensor.numpy() rejects tensors on other devices.
        re_str += toKaldiMatrix(state['linear.weight'].cpu().numpy())
        re_str += toKaldiMatrix(state['linear.bias'].cpu().numpy())
        return re_str

    def to_pytorch_net(self, fread):
        """Load the layer from an open Kaldi-style text file.

        Reads, in order, the ``<AffineTransform>`` header line, the
        learn-rate line, ``output_dim`` rows of ``input_dim`` weights and one
        line holding ``output_dim`` bias values.

        Args:
            fread: text file object positioned at the component header.

        Raises:
            AssertionError: if the header, the learn-rate line, a weight row
                or the bias line does not have the expected form.
        """
        affine_split = fread.readline().strip().split()
        assert len(affine_split) == 3
        assert affine_split[0] == '<AffineTransform>'
        self.output_dim = int(affine_split[1])
        self.input_dim = int(affine_split[2])
        print('AffineTransform output/input dim: %d %d' %
              (self.output_dim, self.input_dim))

        learn_rate_line = fread.readline()
        assert learn_rate_line.find('LearnRateCoef') != -1

        # Kept from the original flow (it draws from the RNG); weight and
        # bias are overwritten below.
        self.linear.reset_parameters()

        rows = []
        for _ in range(self.output_dim):
            # The first/last rows carry the enclosing '[' / ']'.
            splits = fread.readline().strip().strip('[]').strip().split()
            assert len(splits) == self.input_dim
            rows.append([float(item) for item in splits])
        new_weights = torch.tensor(rows, dtype=torch.float32).reshape(
            self.output_dim, self.input_dim)

        bias_split = fread.readline().strip().strip('[]').strip().split()
        assert len(bias_split) == self.output_dim
        new_bias = torch.tensor([float(item) for item in bias_split],
                                dtype=torch.float32)

        # Keep the loaded parameters on the device the layer already lives
        # on, and keep nn.Linear's own metadata in sync with the new shape.
        self.linear.weight.data = new_weights.to(self.linear.weight.device)
        self.linear.bias.data = new_bias.to(self.linear.bias.device)
        self.linear.in_features = self.input_dim
        self.linear.out_features = self.output_dim
class FSMNBlock(nn.Module):
    """FSMN memory block built from depthwise convolutions along time.

    The output is the (delayed) input plus a left-context convolution and,
    when ``rorder > 0``, a right-context convolution. A cache of the most
    recent padded frames is returned so that the input can be processed
    chunk by chunk.

    Args:
        input_dim: number of feature channels.
        output_dim: accepted for interface symmetry; not used by this block.
        lorder: left filter order. When ``None`` the constructor returns
            early and no convolution is created.
        rorder: right filter order; ``0`` disables the right filter.
        lstride: dilation of the left filter.
        rstride: dilation of the right filter.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        lorder=None,
        rorder=None,
        lstride=1,
        rstride=1,
    ):
        super(FSMNBlock, self).__init__()
        self.dim = input_dim
        if lorder is None:
            return
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        # groups == channels: every channel gets its own 1-D filter.
        self.conv_left = nn.Conv2d(
            self.dim,
            self.dim, [lorder, 1],
            dilation=[lstride, 1],
            groups=self.dim,
            bias=False)
        if rorder > 0:
            self.conv_right = nn.Conv2d(
                self.dim,
                self.dim, [rorder, 1],
                dilation=[rstride, 1],
                groups=self.dim,
                bias=False)
        else:
            self.conv_right = None
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self,
                input: Tuple[torch.Tensor, torch.Tensor]):
        """Run the memory block on a chunk of frames.

        Args:
            input: either an ``(input, in_cache)`` tuple or a bare tensor.
                The features are 3-D with time on dim 1 and channels on
                dim 2. ``in_cache`` holds the padded frames kept from the
                previous call; an empty cache (or ``None``) makes the block
                zero-pad on the left instead.

        Returns:
            ``(output, in_cache)`` where ``output`` has the shape of the
            input features and ``in_cache`` holds the last
            ``(lorder - 1) * lstride + rorder * rstride`` padded frames.
        """
        if isinstance(input, tuple):
            input, in_cache = input
        else:
            in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)

        left_ctx = (self.lorder - 1) * self.lstride
        right_ctx = self.rorder * self.rstride
        history = left_ctx + right_ctx

        # (B, T, D) -> (B, D, T, 1): channels first, time on the conv axis.
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)
        if in_cache is None or len(in_cache) == 0:
            x_pad = F.pad(x_per, [0, 0, history, 0])
        else:
            in_cache = in_cache.to(x_per.device)
            x_pad = torch.cat((in_cache, x_per), dim=2)

        # A slice bound of "-0" means "0", so zero-length contexts need
        # explicit handling: without it rorder == 0 selects no frames at all
        # and a zero-length history caches the whole input.
        if history > 0:
            in_cache = x_pad[:, :, -history:, :]
        else:
            in_cache = x_pad[:, :, :0, :]
        left_end = -right_ctx if right_ctx > 0 else None

        y_left = x_pad[:, :, :left_end, :]
        y_left = self.quant(y_left)
        y_left = self.conv_left(y_left)
        y_left = self.dequant(y_left)
        out = x_pad[:, :, left_ctx:left_end, :] + y_left

        if self.conv_right is not None:
            # Frames strictly after the current one: drop the padding the
            # left filter needed, then skip one right stride.
            y_right = x_pad[:, :, -(
                x_per.size(2) + self.rorder * self.rstride):, :]
            y_right = y_right[:, :, self.rstride:, :]
            y_right = self.quant(y_right)
            y_right = self.conv_right(y_right)
            y_right = self.dequant(y_right)
            out += y_right

        out_per = out.permute(0, 3, 2, 1)
        output = out_per.squeeze(1)
        return (output, in_cache)

    def to_kaldi_net(self):
        """Serialize the block to the Kaldi-style text format.

        Returns:
            The ``<Fsmn>`` header, the parameter line, the left filter
            matrix (taps in reversed order, one row per tap) and, when a
            right filter exists, the right filter matrix.
        """
        re_str = '<Fsmn> %d %d\n' % (self.dim, self.dim)
        re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \
            '<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
                1, self.lorder, self.rorder, self.lstride, self.rstride)
        # Index the (dim, 1, order, 1) weights explicitly instead of
        # squeezing: with order == 1 (or dim == 1) squeeze() yields a 1-D
        # array and flipud would then reverse channels instead of taps.
        # Copy to the CPU because Tensor.numpy() rejects other devices.
        lfilters = self.state_dict()['conv_left.weight']
        x = np.flipud(lfilters[:, 0, :, 0].cpu().numpy().T)
        re_str += toKaldiMatrix(x)
        if self.conv_right is not None:
            rfilters = self.state_dict()['conv_right.weight']
            x = rfilters[:, 0, :, 0].cpu().numpy().T
            re_str += toKaldiMatrix(x)
        return re_str

    def to_pytorch_net(self, fread):
        """Load the block from an open Kaldi-style text file.

        Reads the ``<Fsmn>`` header, the parameter line, ``lorder`` rows of
        left filter taps and, when ``rorder > 0``, the right filter rows.
        Only the attributes and the convolution weights are updated; the
        convolution modules themselves are not rebuilt.

        Args:
            fread: text file object positioned at the component header.

        Raises:
            AssertionError: if a line does not have the expected form.
        """
        fsmn_line = fread.readline()
        fsmn_split = fsmn_line.strip().split()
        assert len(fsmn_split) == 3
        assert fsmn_split[0] == '<Fsmn>'
        self.dim = int(fsmn_split[1])

        params_line = fread.readline()
        params_split = params_line.strip().strip('[]').strip().split()
        assert len(params_split) == 12
        assert params_split[0] == '<LearnRateCoef>'
        assert params_split[2] == '<LOrder>'
        self.lorder = int(params_split[3])
        assert params_split[4] == '<ROrder>'
        self.rorder = int(params_split[5])
        assert params_split[6] == '<LStride>'
        self.lstride = int(params_split[7])
        assert params_split[8] == '<RStride>'
        self.rstride = int(params_split[9])
        assert params_split[10] == '<MaxNorm>'

        print('read conv_left weight')
        new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1),
                                   dtype=torch.float32)
        for i in range(self.lorder):
            print('read conv_left weight -- %d' % i)
            line = fread.readline()
            splits = line.strip().strip('[]').strip().split()
            assert len(splits) == self.dim
            cols = torch.tensor([float(item) for item in splits],
                                dtype=torch.float32)
            # Rows are stored with the taps reversed (see to_kaldi_net).
            new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols
        # (lorder, 1, dim, 1) -> (dim, 1, lorder, 1), the conv weight layout.
        new_lfilters = torch.transpose(new_lfilters, 0, 2)
        self.conv_left.reset_parameters()
        # Keep the loaded weights on the device the block already lives on.
        self.conv_left.weight.data = new_lfilters.to(
            self.conv_left.weight.device)

        if self.rorder > 0:
            print('read conv_right weight')
            new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1),
                                       dtype=torch.float32)
            # NOTE(review): one line is consumed and discarded before the
            # right filter rows. to_kaldi_net in this class does not write
            # such a separator line, so confirm against the files this
            # loader is really used with before changing it.
            line = fread.readline()
            for i in range(self.rorder):
                print('read conv_right weight -- %d' % i)
                line = fread.readline()
                splits = line.strip().strip('[]').strip().split()
                assert len(splits) == self.dim
                cols = torch.tensor([float(item) for item in splits],
                                    dtype=torch.float32)
                new_rfilters[i, 0, :, 0] = cols
            new_rfilters = torch.transpose(new_rfilters, 0, 2)
            self.conv_right.reset_parameters()
            self.conv_right.weight.data = new_rfilters.to(
                self.conv_right.weight.device)
class RectifiedLinear(nn.Module):
    """ReLU activation layer that passes the streaming cache through.

    Serialized as the parameter-free ``<RectifiedLinear>`` component of the
    Kaldi-style text network format.

    Args:
        input_dim: feature dimension; stored as ``dim``.
        output_dim: accepted for interface symmetry; not used.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()
        # Registered on the module but not applied in forward().
        self.dropout = nn.Dropout(0.1)

    def forward(self,
                input: Tuple[torch.Tensor, torch.Tensor]):
        """Apply ReLU to the features.

        Args:
            input: either an ``(input, in_cache)`` tuple or a bare tensor.

        Returns:
            ``(output, in_cache)``. For a bare tensor input, ``in_cache`` is
            an empty 4-D float tensor.
        """
        if not isinstance(input, tuple):
            features = input
            in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
        else:
            features, in_cache = input
        return (self.relu(features), in_cache)

    def to_kaldi_net(self):
        """Return the one-line ``<RectifiedLinear>`` component header."""
        return '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)

    def to_pytorch_net(self, fread):
        """Consume and validate the ``<RectifiedLinear>`` header line.

        Args:
            fread: text file object positioned at the component header.

        Raises:
            AssertionError: if the line is not a ``<RectifiedLinear>``
                header whose two dimensions are equal and match ``dim``.
        """
        fields = fread.readline().strip().split()
        assert len(fields) == 3
        tag, out_dim, in_dim = fields
        assert tag == '<RectifiedLinear>'
        assert int(out_dim) == int(in_dim)
        assert int(out_dim) == self.dim
        self.dim = int(out_dim)
def _build_repeats(
    fsmn_layers: int,
    linear_dim: int,
    proj_dim: int,
    lorder: int,
    rorder: int,
    lstride=1,
    rstride=1,
):
    """Build the stack of FSMN units.

    Each unit is ``LinearTransform -> FSMNBlock -> AffineTransform ->
    RectifiedLinear`` and maps ``linear_dim`` features back to
    ``linear_dim`` features through a ``proj_dim`` bottleneck.

    Args:
        fsmn_layers: number of units to stack.
        linear_dim: input/output dimension of every unit.
        proj_dim: dimension of the projection the memory block runs on.
        lorder: left filter order of the memory block.
        rorder: right filter order of the memory block.
        lstride: left filter stride of the memory block.
        rstride: right filter stride of the memory block.

    Returns:
        ``nn.Sequential`` holding one ``nn.Sequential`` per unit.
    """
    repeats = [
        nn.Sequential(
            LinearTransform(linear_dim, proj_dim),
            # Forward the requested strides; they were previously accepted
            # but replaced by a hard-coded 1. The defaults keep the old
            # behaviour for stride-1 configurations.
            FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride),
            AffineTransform(proj_dim, linear_dim),
            RectifiedLinear(linear_dim, linear_dim))
        for _ in range(fsmn_layers)
    ]
    return nn.Sequential(*repeats)
class FSMN(nn.Module):
    """FSMN acoustic model.

    Two input affine layers and a ReLU feed a stack of FSMN units, followed
    by two output affine layers. The model can be written to and read from
    a Kaldi-style text network file.
    """

    def __init__(
        self,
        input_dim: int,
        input_affine_dim: int,
        fsmn_layers: int,
        linear_dim: int,
        proj_dim: int,
        lorder: int,
        rorder: int,
        lstride: int,
        rstride: int,
        output_affine_dim: int,
        output_dim: int,
    ):
        """
        Args:
            input_dim: input dimension
            input_affine_dim: input affine layer dimension
            fsmn_layers: no. of fsmn units
            linear_dim: fsmn input dimension
            proj_dim: fsmn projection dimension
            lorder: fsmn left order
            rorder: fsmn right order
            lstride: fsmn left stride
            rstride: fsmn right stride
            output_affine_dim: output affine layer dimension
            output_dim: output dimension
        """
        super(FSMN, self).__init__()

        self.input_dim = input_dim
        self.input_affine_dim = input_affine_dim
        self.fsmn_layers = fsmn_layers
        self.linear_dim = linear_dim
        self.proj_dim = proj_dim
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        self.output_affine_dim = output_affine_dim
        self.output_dim = output_dim

        # Number of context frames one FSMN unit needs around a frame.
        self.padding = (self.lorder - 1) * self.lstride \
            + self.rorder * self.rstride

        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)

        self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder,
                                   rorder, lstride, rstride)

        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)

    def fuse_modules(self):
        """Do nothing; placeholder hook with no modules to fuse."""
        pass

    def forward(
        self,
        input: torch.Tensor,
        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network on a batch of feature frames.

        Args:
            input (torch.Tensor): Input tensor (B, T, D)
            in_cache: streaming caches of the FSMN units. When it is
                ``None`` or empty, a fresh list with one empty cache per
                unit is created. Otherwise it is indexed per unit as
                ``in_cache[layer]`` and every entry is overwritten with the
                updated cache of that unit.
                NOTE(review): the earlier doc described this as a
                (B, D, C) tensor with C the accumulated cache size, while
                the code indexes it per layer; confirm what callers pass.

        Returns:
            ``(output, in_cache)``: the output of the last affine layer and
            the updated per-unit caches.
        """
        if in_cache is None or len(in_cache) == 0:
            in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float)
                        for _ in range(len(self.fsmn))]

        # The layers exchange (features, cache) tuples.
        input = (input, in_cache)
        x1 = self.in_linear1(input)
        x2 = self.in_linear2(x1)
        x3 = self.relu(x2)

        # Each FSMN unit gets, and replaces, its own cache entry.
        x4, _ = x3
        for layer, module in enumerate(self.fsmn):
            x4, in_cache[layer] = module((x4, in_cache[layer]))

        x5 = self.out_linear1(x4)
        x6 = self.out_linear2(x5)
        x7, _ = x6
        return x7, in_cache

    def to_kaldi_net(self):
        """Serialize the whole network to the Kaldi-style text format.

        Returns:
            The text of all components between ``<Nnet>`` and ``</Nnet>``,
            ending with a ``<Softmax>`` component line.
        """
        parts = ['<Nnet>\n']
        parts.append(self.in_linear1.to_kaldi_net())
        parts.append(self.in_linear2.to_kaldi_net())
        parts.append(self.relu.to_kaldi_net())

        for fsmn in self.fsmn:
            parts.append(fsmn[0].to_kaldi_net())
            parts.append(fsmn[1].to_kaldi_net())
            parts.append(fsmn[2].to_kaldi_net())
            parts.append(fsmn[3].to_kaldi_net())

        parts.append(self.out_linear1.to_kaldi_net())
        parts.append(self.out_linear2.to_kaldi_net())
        parts.append('<Softmax> %d %d\n' % (self.output_dim, self.output_dim))
        parts.append('</Nnet>\n')
        return ''.join(parts)

    def to_pytorch_net(self, kaldi_file):
        """Load all layer parameters from a Kaldi-style text network file.

        Components are read in the order ``to_kaldi_net`` writes them.

        Args:
            kaldi_file: path of the text network file.

        Raises:
            AssertionError: if the file does not have the expected layout.
        """
        # Open the file once, inside the context manager, so the handle is
        # closed even when one of the format assertions fails.
        with open(kaldi_file, 'r', encoding='utf8') as fread:
            nnet_start_line = fread.readline()
            assert nnet_start_line.strip() == '<Nnet>'

            self.in_linear1.to_pytorch_net(fread)
            self.in_linear2.to_pytorch_net(fread)
            self.relu.to_pytorch_net(fread)

            for fsmn in self.fsmn:
                fsmn[0].to_pytorch_net(fread)
                fsmn[1].to_pytorch_net(fread)
                fsmn[2].to_pytorch_net(fread)
                fsmn[3].to_pytorch_net(fread)

            self.out_linear1.to_pytorch_net(fread)
            self.out_linear2.to_pytorch_net(fread)

            softmax_line = fread.readline()
            softmax_split = softmax_line.strip().split()
            assert softmax_split[0].strip() == '<Softmax>'
            assert int(softmax_split[1]) == self.output_dim
            assert int(softmax_split[2]) == self.output_dim

            nnet_end_line = fread.readline()
            assert nnet_end_line.strip() == '</Nnet>'
if __name__ == '__main__':
    # Smoke test: build a model, report its size, run one forward pass on
    # zeros and dump the network in Kaldi-style text format.
    model = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
    print(model)

    param_count = 0
    for param in model.parameters():
        param_count += param.numel()
    print(f'the number of model params: {param_count}')

    # batch-size * time * dim
    features = torch.zeros(128, 200, 400)
    outputs, _ = model(features)

    print(f'input shape: {features.shape}')
    print(f'output shape: {outputs.shape}')

    print(model.to_kaldi_net())