add fsmn model, can use pretrained kws model from modelscope.

This commit is contained in:
dujing 2023-05-30 17:12:52 +08:00
parent c4b2ddbd11
commit 6d7e7784b5
10 changed files with 889 additions and 11 deletions

View File

@ -0,0 +1,64 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.
context_expansion: true
context_expansion_conf:
left: 2
right: 2
frame_skip: 3
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256
model:
input_dim: 400
preprocessing:
type: none
hidden_dim: 128
backbone:
type: fsmn
input_affine_dim: 140
num_layers: 4
linear_dim: 250
proj_dim: 128
left_order: 10
right_order: 2
left_stride: 1
right_stride: 1
output_affine_dim: 140
classifier:
type: identity
dropout: 0.1
activation:
type: identity
optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001
training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10
criterion: ctc

View File

@ -11,10 +11,10 @@ num_keywords=2599
config=conf/ds_tcn_ctc.yaml config=conf/ds_tcn_ctc.yaml
norm_mean=true norm_mean=true
norm_var=true norm_var=true
gpus="0,1,2,3" gpus="0"
checkpoint= checkpoint=
dir=exp/ds_tcn_ctc_ft dir=exp/ds_tcn_ctc
average_model=true average_model=true
num_average=30 num_average=30
if $average_model ;then if $average_model ;then
@ -29,7 +29,7 @@ download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir
window_shift=50 window_shift=50
#Whether to train base model. If set true, must put train+dev data in trainbase_dir #Whether to train base model. If set true, must put train+dev data in trainbase_dir
trainbase=true trainbase=false
trainbase_dir=data/base trainbase_dir=data/base
trainbase_config=conf/ds_tcn_ctc_base.yaml trainbase_config=conf/ds_tcn_ctc_base.yaml
trainbase_exp=exp/base trainbase_exp=exp/base
@ -149,11 +149,11 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Use the base model you trained as checkpoint: $trainbase_exp/final.pt" echo "Use the base model you trained as checkpoint: $trainbase_exp/final.pt"
checkpoint=$trainbase_exp/final.pt checkpoint=$trainbase_exp/final.pt
else else
echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/final.pt" echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/23.pt"
if [ ! -d mobvoi_kws_transcription ] ;then if [ ! -d mobvoi_kws_transcription ] ;then
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
fi fi
checkpoint=mobvoi_kws_transcription/final.pt checkpoint=mobvoi_kws_transcription/23.pt # this ckpt may not be the best.
fi fi
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \

View File

@ -0,0 +1,167 @@
#!/bin/bash
# Copyright 2021 Binbin Zhang(binbzha@qq.com)
# 2023 Jing Du(thuduj12@163.com)
. ./path.sh
stage=$1
stop_stage=$2
num_keywords=2599
config=conf/fsmn_ctc.yaml
norm_mean=true
norm_var=true
gpus="0"
checkpoint=
dir=exp/fsmn_ctc
average_model=true
num_average=30
if $average_model ;then
score_checkpoint=$dir/avg_${num_average}.pt
else
score_checkpoint=$dir/final.pt
fi
download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir
. tools/parse_options.sh || exit 1;
window_shift=50
if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
echo "Download and extracte all datasets"
local/mobvoi_data_download.sh --dl_dir $download_dir
fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Preparing datasets..."
mkdir -p dict
echo "<filler> -1" > dict/words.txt
echo "Hi_Xiaowen 0" >> dict/words.txt
echo "Nihao_Wenwen 1" >> dict/words.txt
for folder in train dev test; do
mkdir -p data/$folder
for prefix in p n; do
mkdir -p data/${prefix}_$folder
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
data/${prefix}_$folder
done
cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp
cat data/p_$folder/text data/n_$folder/text > data/$folder/text
rm -rf data/p_$folder data/n_$folder
done
fi
if [ ${stage} -le -0 ] && [ ${stop_stage} -ge -0 ]; then
# Here we Use Paraformer Large(https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
# to transcribe the negative wavs, and upload the transcription to modelscope.
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
for folder in train dev test; do
if [ -f data/$folder/text ];then
mv data/$folder/text data/$folder/text.label
fi
cp mobvoi_kws_transcription/$folder.text data/$folder/text
done
# and we also copy the tokens and lexicon that used in
# https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary
cp mobvoi_kws_transcription/tokens.txt data/tokens.txt
cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Compute CMVN and Format datasets"
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
--in_scp data/train/wav.scp \
--out_cmvn data/train/global_cmvn
for x in train dev test; do
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
# Here we use tokens.txt and lexicon.txt to convert txt into index
tools/make_list.py data/$x/wav.scp data/$x/text \
data/$x/wav.dur data/$x/data.list \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Use the base model from modelscope"
if [ ! -d speech_charctc_kws_phone-xiaoyun ] ;then
git lfs install
git clone https://www.modelscope.cn/damo/speech_charctc_kws_phone-xiaoyun.git
fi
checkpoint=speech_charctc_kws_phone-xiaoyun/train/base.pt
cp speech_charctc_kws_phone-xiaoyun/train/feature_transform.txt.80dim-l2r2 data/global_cmvn.kaldi
echo "Start training ..."
mkdir -p $dir
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/global_cmvn.kaldi"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wekws/bin/train.py --gpus $gpus \
--config $config \
--train_data data/train/data.list \
--cv_data data/dev/data.list \
--model_dir $dir \
--num_workers 8 \
--num_keywords $num_keywords \
--min_duration 50 \
--seed 666 \
$cmvn_opts \
${checkpoint:+--checkpoint $checkpoint}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Do model average, Compute FRR/FAR ..."
if $average_model; then
python wekws/bin/average_model.py \
--dst_model $score_checkpoint \
--src_path $dir \
--num ${num_average} \
--val_best
fi
result_dir=$dir/test_$(basename $score_checkpoint)
mkdir -p $result_dir
python wekws/bin/score_ctc.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
--num_workers 8 \
--keywords 嗨小问,你好问问 \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
python wekws/bin/compute_det_ctc.py \
--keywords 嗨小问,你好问问 \
--test_data data/test/data.list \
--window_shift $window_shift \
--step 0.001 \
--score_file $result_dir/score.txt
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
python wekws/bin/export_jit.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--jit_model $dir/$jit_model
python wekws/bin/export_onnx.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--onnx_model $dir/$onnx_model
fi

View File

@ -165,13 +165,13 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--xlim', '--xlim',
type=int, type=int,
default=10, default=5,
help='xlimrange of x-axis, x is false alarm per hour') help='xlimrange of x-axis, x is false alarm per hour')
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis') parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
parser.add_argument( parser.add_argument(
'--ylim', '--ylim',
type=int, type=int,
default=100, default=35,
help='ylimrange of y-axis, y is false rejection rate') help='ylimrange of y-axis, y is false rejection rate')
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis') parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')

View File

@ -134,7 +134,8 @@ def main():
output_dim = args.num_keywords output_dim = args.num_keywords
# Write model_dir/config.yaml for inference and export # Write model_dir/config.yaml for inference and export
configs['model']['input_dim'] = input_dim if 'input_dim' not in configs['model']:
configs['model']['input_dim'] = input_dim
configs['model']['output_dim'] = output_dim configs['model']['output_dim'] = output_dim
if args.cmvn_file is not None: if args.cmvn_file is not None:
configs['model']['cmvn'] = {} configs['model']['cmvn'] = {}

View File

@ -162,6 +162,16 @@ def Dataset(data_list_file, conf,
spec_aug_conf = conf.get('spec_aug_conf', {}) spec_aug_conf = conf.get('spec_aug_conf', {})
dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
context_expansion = conf.get('context_expansion', False)
if context_expansion:
context_expansion_conf = conf.get('context_expansion_conf', {})
dataset = Processor(dataset, processor.context_expansion,
**context_expansion_conf)
frame_skip = conf.get('frame_skip', 1)
if frame_skip > 1:
dataset = Processor(dataset, processor.frame_skip, frame_skip)
if shuffle: if shuffle:
shuffle_conf = conf.get('shuffle_conf', {}) shuffle_conf = conf.get('shuffle_conf', {})
dataset = Processor(dataset, processor.shuffle, **shuffle_conf) dataset = Processor(dataset, processor.shuffle, **shuffle_conf)

View File

@ -263,6 +263,51 @@ def shuffle(data, shuffle_size=1000):
for x in buf: for x in buf:
yield x yield x
def context_expansion(data, left=1, right=1):
""" expand left and right frames
Args:
data: Iterable[{key, feat, label}]
left (int): feature left context frames
right (int): feature right context frames
Returns:
data: Iterable[{key, feat, label}]
"""
for sample in data:
index = 0
feats = sample['feat']
ctx_dim = feats.shape[0]
ctx_frm = feats.shape[1] * (left + right + 1)
feats_ctx = torch.zeros(ctx_dim, ctx_frm, dtype=torch.float32)
for lag in range(-left, right + 1):
feats_ctx[:, index:index + feats.shape[1]] = torch.roll(
feats, -lag, 0)
index = index + feats.shape[1]
# replication pad left margin
for idx in range(left):
for cpx in range(left - idx):
feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1)
* feats.shape[1]] = feats_ctx[left, :feats.shape[1]]
feats_ctx = feats_ctx[:feats_ctx.shape[0] - right]
sample['feat'] = feats_ctx
yield sample
def frame_skip(data, skip_rate=1):
""" skip frame
Args:
data: Iterable[{key, feat, label}]
skip_rate (int): take every N-frames for model input
Returns:
data: Iterable[{key, feat, label}]
"""
for sample in data:
feats_skip = sample['feat'][::skip_rate, :]
sample['feat'] = feats_skip
yield sample
def batch(data, batch_size=16): def batch(data, batch_size=16):
""" Static batch the data by `batch_size` """ Static batch the data by `batch_size`

523
wekws/model/fsmn.py Normal file
View File

@ -0,0 +1,523 @@
'''
FSMN implementation.
Copyright: 2022-03-09 yueyue.nyy
'''
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def toKaldiMatrix(np_mat):
np.set_printoptions(threshold=np.inf, linewidth=np.nan)
out_str = str(np_mat)
out_str = out_str.replace('[', '')
out_str = out_str.replace(']', '')
return '[ %s ]\n' % out_str
def printTensor(torch_tensor):
re_str = ''
x = torch_tensor.detach().squeeze().numpy()
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
print(re_str)
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, input):
output = self.quant(input)
output = self.linear(output)
output = self.dequant(output)
return output
def to_kaldi_net(self):
re_str = ''
re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
self.input_dim)
re_str += '<LearnRateCoef> 1\n'
linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
return re_str
def to_pytorch_net(self, fread):
linear_line = fread.readline()
linear_split = linear_line.strip().split()
assert len(linear_split) == 3
assert linear_split[0] == '<LinearTransform>'
self.output_dim = int(linear_split[1])
self.input_dim = int(linear_split[2])
learn_rate_line = fread.readline()
assert learn_rate_line.find('LearnRateCoef') != -1
self.linear.reset_parameters()
# linear_weights = self.state_dict()['linear.weight']
# print(linear_weights.shape)
new_weights = torch.zeros((self.output_dim, self.input_dim),
dtype=torch.float32)
for i in range(self.output_dim):
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.input_dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_weights[i, :] = cols
self.linear.weight.data = new_weights
class AffineTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, input):
output = self.quant(input)
output = self.linear(output)
output = self.dequant(output)
return output
def to_kaldi_net(self):
re_str = ''
re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
self.input_dim)
re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
linear_weights = self.state_dict()['linear.weight']
x = linear_weights.squeeze().numpy()
re_str += toKaldiMatrix(x)
linear_bias = self.state_dict()['linear.bias']
x = linear_bias.squeeze().numpy()
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
return re_str
def to_pytorch_net(self, fread):
affine_line = fread.readline()
affine_split = affine_line.strip().split()
assert len(affine_split) == 3
assert affine_split[0] == '<AffineTransform>'
self.output_dim = int(affine_split[1])
self.input_dim = int(affine_split[2])
print('AffineTransform output/input dim: %d %d' %
(self.output_dim, self.input_dim))
learn_rate_line = fread.readline()
assert learn_rate_line.find('LearnRateCoef') != -1
# linear_weights = self.state_dict()['linear.weight']
# print(linear_weights.shape)
self.linear.reset_parameters()
new_weights = torch.zeros((self.output_dim, self.input_dim),
dtype=torch.float32)
for i in range(self.output_dim):
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.input_dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_weights[i, :] = cols
self.linear.weight.data = new_weights
# linear_bias = self.state_dict()['linear.bias']
# print(linear_bias.shape)
bias_line = fread.readline()
splits = bias_line.strip().strip('[]').strip().split()
assert len(splits) == self.output_dim
new_bias = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
self.linear.bias.data = new_bias
class FSMNBlock(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
lorder=None,
rorder=None,
lstride=1,
rstride=1,
):
super(FSMNBlock, self).__init__()
self.dim = input_dim
if lorder is None:
return
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.conv_left = nn.Conv2d(
self.dim,
self.dim, [lorder, 1],
dilation=[lstride, 1],
groups=self.dim,
bias=False)
if rorder > 0:
self.conv_right = nn.Conv2d(
self.dim,
self.dim, [rorder, 1],
dilation=[rstride, 1],
groups=self.dim,
bias=False)
else:
self.conv_right = None
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, input):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1)
y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
y_left = self.quant(y_left)
y_left = self.conv_left(y_left)
y_left = self.dequant(y_left)
out = x_per + y_left
if self.conv_right is not None:
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
y_right = self.quant(y_right)
y_right = self.conv_right(y_right)
y_right = self.dequant(y_right)
out += y_right
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
return output
def to_kaldi_net(self):
re_str = ''
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride)
# print(self.conv_left.weight,self.conv_right.weight)
lfiters = self.state_dict()['conv_left.weight']
x = np.flipud(lfiters.squeeze().numpy().T)
re_str += toKaldiMatrix(x)
if self.conv_right is not None:
rfiters = self.state_dict()['conv_right.weight']
x = (rfiters.squeeze().numpy().T)
re_str += toKaldiMatrix(x)
# re_str += '<!EndOfComponent>\n'
return re_str
def to_pytorch_net(self, fread):
fsmn_line = fread.readline()
fsmn_split = fsmn_line.strip().split()
assert len(fsmn_split) == 3
assert fsmn_split[0] == '<Fsmn>'
self.dim = int(fsmn_split[1])
params_line = fread.readline()
params_split = params_line.strip().strip('[]').strip().split()
assert len(params_split) == 12
assert params_split[0] == '<LearnRateCoef>'
assert params_split[2] == '<LOrder>'
self.lorder = int(params_split[3])
assert params_split[4] == '<ROrder>'
self.rorder = int(params_split[5])
assert params_split[6] == '<LStride>'
self.lstride = int(params_split[7])
assert params_split[8] == '<RStride>'
self.rstride = int(params_split[9])
assert params_split[10] == '<MaxNorm>'
# lfilters = self.state_dict()['conv_left.weight']
# print(lfilters.shape)
print('read conv_left weight')
new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1),
dtype=torch.float32)
for i in range(self.lorder):
print('read conv_left weight -- %d' % i)
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols
new_lfilters = torch.transpose(new_lfilters, 0, 2)
# print(new_lfilters.shape)
self.conv_left.reset_parameters()
self.conv_left.weight.data = new_lfilters
# print(self.conv_left.weight.shape)
if self.rorder > 0:
# rfilters = self.state_dict()['conv_right.weight']
# print(rfilters.shape)
print('read conv_right weight')
new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1),
dtype=torch.float32)
line = fread.readline()
for i in range(self.rorder):
print('read conv_right weight -- %d' % i)
line = fread.readline()
splits = line.strip().strip('[]').strip().split()
assert len(splits) == self.dim
cols = torch.tensor([float(item) for item in splits],
dtype=torch.float32)
new_rfilters[i, 0, :, 0] = cols
new_rfilters = torch.transpose(new_rfilters, 0, 2)
# print(new_rfilters.shape)
self.conv_right.reset_parameters()
self.conv_right.weight.data = new_rfilters
# print(self.conv_right.weight.shape)
class RectifiedLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
def forward(self, input):
out = self.relu(input)
# out = self.dropout(out)
return out
def to_kaldi_net(self):
re_str = ''
re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
# re_str += '<!EndOfComponent>\n'
return re_str
# re_str = ''
# re_str += '<ParametricRelu> %d %d\n' % (self.dim, self.dim)
# re_str += '<AlphaLearnRateCoef> 0 <BetaLearnRateCoef> 0\n'
# re_str += toKaldiMatrix(np.ones((self.dim), dtype = 'int32'))
# re_str += toKaldiMatrix(np.zeros((self.dim), dtype = 'int32'))
# re_str += '<!EndOfComponent>\n'
# return re_str
def to_pytorch_net(self, fread):
line = fread.readline()
splits = line.strip().split()
assert len(splits) == 3
assert splits[0] == '<RectifiedLinear>'
assert int(splits[1]) == int(splits[2])
assert int(splits[1]) == self.dim
self.dim = int(splits[1])
def _build_repeats(
fsmn_layers: int,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride=1,
rstride=1,
):
repeats = [
nn.Sequential(
LinearTransform(linear_dim, proj_dim),
FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1),
AffineTransform(proj_dim, linear_dim),
RectifiedLinear(linear_dim, linear_dim))
for i in range(fsmn_layers)
]
return nn.Sequential(*repeats)
class FSMN(nn.Module):
def __init__(
self,
input_dim: int,
input_affine_dim: int,
fsmn_layers: int,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride: int,
rstride: int,
output_affine_dim: int,
output_dim: int,
):
"""
Args:
input_dim: input dimension
input_affine_dim: input affine layer dimension
fsmn_layers: no. of fsmn units
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
lstride: fsmn left stride
rstride: fsmn right stride
output_affine_dim: output affine layer dimension
output_dim: output dimension
"""
super(FSMN, self).__init__()
self.input_dim = input_dim
self.input_affine_dim = input_affine_dim
self.fsmn_layers = fsmn_layers
self.linear_dim = linear_dim
self.proj_dim = proj_dim
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
self.padding = (self.lorder-1) * self.lstride + self.rorder * self.rstride
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder,
rorder, lstride, rstride)
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
# self.softmax = nn.Softmax(dim = -1)
def fuse_modules(self):
pass
def forward(
self,
input: torch.Tensor,
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
"""
# print("FSMN forward!!!!")
# print(input.shape)
# print(input)
# print(self.in_linear1.input_dim)
# print(self.in_linear1.output_dim)
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)
x3 = self.relu(x2)
x4 = self.fsmn(x3)
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
# x7 = self.softmax(x6)
# return x7, None
return x6, in_cache
def to_kaldi_net(self):
re_str = ''
re_str += '<Nnet>\n'
re_str += self.in_linear1.to_kaldi_net()
re_str += self.in_linear2.to_kaldi_net()
re_str += self.relu.to_kaldi_net()
for fsmn in self.fsmn:
re_str += fsmn[0].to_kaldi_net()
re_str += fsmn[1].to_kaldi_net()
re_str += fsmn[2].to_kaldi_net()
re_str += fsmn[3].to_kaldi_net()
re_str += self.out_linear1.to_kaldi_net()
re_str += self.out_linear2.to_kaldi_net()
re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
# re_str += '<!EndOfComponent>\n'
re_str += '</Nnet>\n'
return re_str
def to_pytorch_net(self, kaldi_file):
with open(kaldi_file, 'r', encoding='utf8') as fread:
fread = open(kaldi_file, 'r')
nnet_start_line = fread.readline()
assert nnet_start_line.strip() == '<Nnet>'
self.in_linear1.to_pytorch_net(fread)
self.in_linear2.to_pytorch_net(fread)
self.relu.to_pytorch_net(fread)
for fsmn in self.fsmn:
fsmn[0].to_pytorch_net(fread)
fsmn[1].to_pytorch_net(fread)
fsmn[2].to_pytorch_net(fread)
fsmn[3].to_pytorch_net(fread)
self.out_linear1.to_pytorch_net(fread)
self.out_linear2.to_pytorch_net(fread)
softmax_line = fread.readline()
softmax_split = softmax_line.strip().split()
assert softmax_split[0].strip() == '<Softmax>'
assert int(softmax_split[1]) == self.output_dim
assert int(softmax_split[2]) == self.output_dim
# '<!EndOfComponent>\n'
nnet_end_line = fread.readline()
assert nnet_end_line.strip() == '</Nnet>'
fread.close()
if __name__ == '__main__':
fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
print(fsmn)
num_params = sum(p.numel() for p in fsmn.parameters())
print('the number of model params: {}'.format(num_params))
x = torch.zeros(128, 200, 400) # batch-size * time * dim
y, _ = fsmn(x) # batch-size * time * dim
print('input shape: {}'.format(x.shape))
print('output shape: {}'.format(y.shape))
print(fsmn.to_kaldi_net())

View File

@ -1,4 +1,5 @@
# Copyright (c) 2021 Binbin Zhang # Copyright (c) 2021 Binbin Zhang
# 2023 Jing Du
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -25,7 +26,8 @@ from wekws.model.subsampling import (LinearSubsampling1, Conv1dSubsampling1,
NoSubsampling) NoSubsampling)
from wekws.model.tcn import TCN, CnnBlock, DsCnnBlock from wekws.model.tcn import TCN, CnnBlock, DsCnnBlock
from wekws.model.mdtc import MDTC from wekws.model.mdtc import MDTC
from wekws.utils.cmvn import load_cmvn from wekws.utils.cmvn import load_cmvn, load_kaldi_cmvn
from wekws.model.fsmn import FSMN
class KWSModel(nn.Module): class KWSModel(nn.Module):
@ -80,7 +82,10 @@ class KWSModel(nn.Module):
def init_model(configs): def init_model(configs):
cmvn = configs.get('cmvn', {}) cmvn = configs.get('cmvn', {})
if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None: if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None:
mean, istd = load_cmvn(cmvn['cmvn_file']) if "kaldi" in cmvn['cmvn_file']:
mean, istd = load_kaldi_cmvn(cmvn['cmvn_file'])
else:
mean, istd = load_cmvn(cmvn['cmvn_file'])
global_cmvn = GlobalCMVN( global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(), torch.from_numpy(mean).float(),
torch.from_numpy(istd).float(), torch.from_numpy(istd).float(),
@ -135,6 +140,20 @@ def init_model(configs):
hidden_dim, hidden_dim,
kernel_size, kernel_size,
causal=causal) causal=causal)
elif backbone_type == 'fsmn':
input_affine_dim = configs['backbone']['input_affine_dim']
num_layers = configs['backbone']['num_layers']
linear_dim = configs['backbone']['linear_dim']
proj_dim = configs['backbone']['proj_dim']
left_order = configs['backbone']['left_order']
right_order = configs['backbone']['right_order']
left_stride = configs['backbone']['left_stride']
right_stride = configs['backbone']['right_stride']
output_affine_dim = configs['backbone']['output_affine_dim']
backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim,
proj_dim, left_order, right_order, left_stride,
right_stride, output_affine_dim, output_dim)
else: else:
print('Unknown body type {}'.format(backbone_type)) print('Unknown body type {}'.format(backbone_type))
sys.exit(1) sys.exit(1)
@ -154,6 +173,8 @@ def init_model(configs):
# last means we use last frame to do backpropagation, so the model # last means we use last frame to do backpropagation, so the model
# can be infered streamingly # can be infered streamingly
classifier = LastClassifier(classifier_base) classifier = LastClassifier(classifier_base)
elif classifier_type == 'identity':
classifier = nn.Identity()
else: else:
print('Unknown classifier type {}'.format(classifier_type)) print('Unknown classifier type {}'.format(classifier_type))
sys.exit(1) sys.exit(1)

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import json import json
import math import math,re
import numpy as np import numpy as np
@ -42,3 +42,50 @@ def load_cmvn(json_cmvn_file):
variance[i] = 1.0 / math.sqrt(variance[i]) variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance]) cmvn = np.array([means, variance])
return cmvn return cmvn
def load_kaldi_cmvn(cmvn_file):
""" Load the kaldi format cmvn stats file and no need to calculate
Args:
cmvn_file: cmvn stats file in kaldi format
Returns:
a numpy array of [means, vars]
"""
means = None
variance = None
with open(cmvn_file) as f:
all_lines = f.readlines()
for idx, line in enumerate(all_lines):
if line.find('AddShift') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
means_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
means_list = means_str.strip().split(' ')
means = [0 - float(s) for s in means_list]
assert len(means) == int(segs[1])
elif line.find('Rescale') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
vars_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
vars_list = vars_str.strip().split(' ')
variance = [float(s) for s in vars_list]
assert len(variance) == int(segs[1])
elif line.find('Splice') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
splice_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
splice_list = splice_str.strip().split(' ')
assert len(splice_list) * int(segs[2]) == int(segs[1])
copy_times = len(splice_list)
else:
continue
cmvn = np.array([means, variance])
cmvn = np.tile(cmvn, (1, copy_times))
return cmvn