From 6d7e7784b5b9cae9ae36cd227311109b3519e4b0 Mon Sep 17 00:00:00 2001 From: dujing Date: Tue, 30 May 2023 17:12:52 +0800 Subject: [PATCH] add fsmn model, can use pretrained kws model from modelscope. --- examples/hi_xiaowen/s0/conf/fsmn_ctc.yaml | 64 +++ examples/hi_xiaowen/s0/run_ctc.sh | 10 +- examples/hi_xiaowen/s0/run_fsmn_ctc.sh | 167 +++++++ wekws/bin/compute_det_ctc.py | 4 +- wekws/bin/train.py | 3 +- wekws/dataset/dataset.py | 10 + wekws/dataset/processor.py | 45 ++ wekws/model/fsmn.py | 523 ++++++++++++++++++++++ wekws/model/kws_model.py | 25 +- wekws/utils/cmvn.py | 49 +- 10 files changed, 889 insertions(+), 11 deletions(-) create mode 100644 examples/hi_xiaowen/s0/conf/fsmn_ctc.yaml create mode 100644 examples/hi_xiaowen/s0/run_fsmn_ctc.sh create mode 100644 wekws/model/fsmn.py diff --git a/examples/hi_xiaowen/s0/conf/fsmn_ctc.yaml b/examples/hi_xiaowen/s0/conf/fsmn_ctc.yaml new file mode 100644 index 0000000..5ce9596 --- /dev/null +++ b/examples/hi_xiaowen/s0/conf/fsmn_ctc.yaml @@ -0,0 +1,64 @@ +dataset_conf: + filter_conf: + max_length: 2048 + min_length: 0 + resample_conf: + resample_rate: 16000 + speed_perturb: false + feature_extraction_conf: + feature_type: 'fbank' + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1. + context_expansion: true + context_expansion_conf: + left: 2 + right: 2 + frame_skip: 3 + spec_aug: true + spec_aug_conf: + num_t_mask: 1 + num_f_mask: 1 + max_t: 20 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + batch_conf: + batch_size: 256 + +model: + input_dim: 400 + preprocessing: + type: none + hidden_dim: 128 + backbone: + type: fsmn + input_affine_dim: 140 + num_layers: 4 + linear_dim: 250 + proj_dim: 128 + left_order: 10 + right_order: 2 + left_stride: 1 + right_stride: 1 + output_affine_dim: 140 + classifier: + type: identity + dropout: 0.1 + activation: + type: identity + + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 0.0001 + +training_config: + grad_clip: 5 + max_epoch: 80 + log_interval: 10 + criterion: ctc + diff --git a/examples/hi_xiaowen/s0/run_ctc.sh b/examples/hi_xiaowen/s0/run_ctc.sh index 3085ff1..ee1f34d 100644 --- a/examples/hi_xiaowen/s0/run_ctc.sh +++ b/examples/hi_xiaowen/s0/run_ctc.sh @@ -11,10 +11,10 @@ num_keywords=2599 config=conf/ds_tcn_ctc.yaml norm_mean=true norm_var=true -gpus="0,1,2,3" +gpus="0" checkpoint= -dir=exp/ds_tcn_ctc_ft +dir=exp/ds_tcn_ctc average_model=true num_average=30 if $average_model ;then @@ -29,7 +29,7 @@ download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir window_shift=50 #Whether to train base model. If set true, must put train+dev data in trainbase_dir -trainbase=true +trainbase=false trainbase_dir=data/base trainbase_config=conf/ds_tcn_ctc_base.yaml trainbase_exp=exp/base @@ -149,11 +149,11 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then echo "Use the base model you trained as checkpoint: $trainbase_exp/final.pt" checkpoint=$trainbase_exp/final.pt else - echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/final.pt" + echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/23.pt" if [ ! -d mobvoi_kws_transcription ] ;then git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git fi - checkpoint=mobvoi_kws_transcription/final.pt + checkpoint=mobvoi_kws_transcription/23.pt # this ckpt may not be the best. fi torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ diff --git a/examples/hi_xiaowen/s0/run_fsmn_ctc.sh b/examples/hi_xiaowen/s0/run_fsmn_ctc.sh new file mode 100644 index 0000000..39c6bff --- /dev/null +++ b/examples/hi_xiaowen/s0/run_fsmn_ctc.sh @@ -0,0 +1,167 @@ +#!/bin/bash +# Copyright 2021 Binbin Zhang(binbzha@qq.com) +# 2023 Jing Du(thuduj12@163.com) + +. ./path.sh + +stage=$1 +stop_stage=$2 +num_keywords=2599 + +config=conf/fsmn_ctc.yaml +norm_mean=true +norm_var=true +gpus="0" + +checkpoint= +dir=exp/fsmn_ctc +average_model=true +num_average=30 +if $average_model ;then + score_checkpoint=$dir/avg_${num_average}.pt +else + score_checkpoint=$dir/final.pt +fi + +download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir + +. tools/parse_options.sh || exit 1; +window_shift=50 + +if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then + echo "Download and extracte all datasets" + local/mobvoi_data_download.sh --dl_dir $download_dir +fi + + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "Preparing datasets..." + mkdir -p dict + echo " -1" > dict/words.txt + echo "Hi_Xiaowen 0" >> dict/words.txt + echo "Nihao_Wenwen 1" >> dict/words.txt + + for folder in train dev test; do + mkdir -p data/$folder + for prefix in p n; do + mkdir -p data/${prefix}_$folder + json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json + local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \ + data/${prefix}_$folder + done + cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp + cat data/p_$folder/text data/n_$folder/text > data/$folder/text + rm -rf data/p_$folder data/n_$folder + done +fi + +if [ ${stage} -le -0 ] && [ ${stop_stage} -ge -0 ]; then +# Here we Use Paraformer Large(https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) +# to transcribe the negative wavs, and upload the transcription to modelscope. + git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git + for folder in train dev test; do + if [ -f data/$folder/text ];then + mv data/$folder/text data/$folder/text.label + fi + cp mobvoi_kws_transcription/$folder.text data/$folder/text + done + + # and we also copy the tokens and lexicon that used in + # https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary + cp mobvoi_kws_transcription/tokens.txt data/tokens.txt + cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt + +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "Compute CMVN and Format datasets" + tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \ + --in_scp data/train/wav.scp \ + --out_cmvn data/train/global_cmvn + + for x in train dev test; do + tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur + + # Here we use tokens.txt and lexicon.txt to convert txt into index + tools/make_list.py data/$x/wav.scp data/$x/text \ + data/$x/wav.dur data/$x/data.list \ + --token_file data/tokens.txt \ + --lexicon_file data/lexicon.txt + done +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + + echo "Use the base model from modelscope" + if [ ! -d speech_charctc_kws_phone-xiaoyun ] ;then + git lfs install + git clone https://www.modelscope.cn/damo/speech_charctc_kws_phone-xiaoyun.git + fi + checkpoint=speech_charctc_kws_phone-xiaoyun/train/base.pt + cp speech_charctc_kws_phone-xiaoyun/train/feature_transform.txt.80dim-l2r2 data/global_cmvn.kaldi + + echo "Start training ..." + mkdir -p $dir + cmvn_opts= + $norm_mean && cmvn_opts="--cmvn_file data/global_cmvn.kaldi" + $norm_var && cmvn_opts="$cmvn_opts --norm_var" + num_gpus=$(echo $gpus | awk -F ',' '{print NF}') + + torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ + wekws/bin/train.py --gpus $gpus \ + --config $config \ + --train_data data/train/data.list \ + --cv_data data/dev/data.list \ + --model_dir $dir \ + --num_workers 8 \ + --num_keywords $num_keywords \ + --min_duration 50 \ + --seed 666 \ + $cmvn_opts \ + ${checkpoint:+--checkpoint $checkpoint} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Do model average, Compute FRR/FAR ..." + if $average_model; then + python wekws/bin/average_model.py \ + --dst_model $score_checkpoint \ + --src_path $dir \ + --num ${num_average} \ + --val_best + fi + result_dir=$dir/test_$(basename $score_checkpoint) + mkdir -p $result_dir + python wekws/bin/score_ctc.py \ + --config $dir/config.yaml \ + --test_data data/test/data.list \ + --gpu 0 \ + --batch_size 256 \ + --checkpoint $score_checkpoint \ + --score_file $result_dir/score.txt \ + --num_workers 8 \ + --keywords 嗨小问,你好问问 \ + --token_file data/tokens.txt \ + --lexicon_file data/lexicon.txt + + python wekws/bin/compute_det_ctc.py \ + --keywords 嗨小问,你好问问 \ + --test_data data/test/data.list \ + --window_shift $window_shift \ + --step 0.001 \ + --score_file $result_dir/score.txt +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') + onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g') + python wekws/bin/export_jit.py \ + --config $dir/config.yaml \ + --checkpoint $score_checkpoint \ + --jit_model $dir/$jit_model + python wekws/bin/export_onnx.py \ + --config $dir/config.yaml \ + --checkpoint $score_checkpoint \ + --onnx_model $dir/$onnx_model +fi diff --git a/wekws/bin/compute_det_ctc.py b/wekws/bin/compute_det_ctc.py index f5221bb..fcc82bc 100644 --- a/wekws/bin/compute_det_ctc.py +++ b/wekws/bin/compute_det_ctc.py @@ -165,13 +165,13 @@ if __name__ == '__main__': parser.add_argument( '--xlim', type=int, - default=10, + default=5, help='xlim:range of x-axis, x is false alarm per hour') parser.add_argument('--x_step', type=int, default=1, help='step on x-axis') parser.add_argument( '--ylim', type=int, - default=100, + default=35, help='ylim:range of y-axis, y is false rejection rate') parser.add_argument('--y_step', type=int, default=5, help='step on y-axis') diff --git a/wekws/bin/train.py b/wekws/bin/train.py index 632a240..5c5c933 100644 --- a/wekws/bin/train.py +++ b/wekws/bin/train.py @@ -134,7 +134,8 @@ def main(): output_dim = args.num_keywords # Write model_dir/config.yaml for inference and export - configs['model']['input_dim'] = input_dim + if 'input_dim' not in configs['model']: + configs['model']['input_dim'] = input_dim configs['model']['output_dim'] = output_dim if args.cmvn_file is not None: configs['model']['cmvn'] = {} diff --git a/wekws/dataset/dataset.py b/wekws/dataset/dataset.py index 74aaf4d..897b87c 100644 --- a/wekws/dataset/dataset.py +++ b/wekws/dataset/dataset.py @@ -162,6 +162,16 @@ def Dataset(data_list_file, conf, spec_aug_conf = conf.get('spec_aug_conf', {}) dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) + context_expansion = conf.get('context_expansion', False) + if context_expansion: + context_expansion_conf = conf.get('context_expansion_conf', {}) + dataset = Processor(dataset, processor.context_expansion, + **context_expansion_conf) + + frame_skip = conf.get('frame_skip', 1) + if frame_skip > 1: + dataset = Processor(dataset, processor.frame_skip, frame_skip) + if shuffle: shuffle_conf = conf.get('shuffle_conf', {}) dataset = Processor(dataset, processor.shuffle, **shuffle_conf) diff --git a/wekws/dataset/processor.py b/wekws/dataset/processor.py index c21de63..f1d9073 100644 --- a/wekws/dataset/processor.py +++ b/wekws/dataset/processor.py @@ -263,6 +263,51 @@ def shuffle(data, shuffle_size=1000): for x in buf: yield x +def context_expansion(data, left=1, right=1): + """ expand left and right frames + Args: + data: Iterable[{key, feat, label}] + left (int): feature left context frames + right (int): feature right context frames + + Returns: + data: Iterable[{key, feat, label}] + """ + for sample in data: + index = 0 + feats = sample['feat'] + ctx_dim = feats.shape[0] + ctx_frm = feats.shape[1] * (left + right + 1) + feats_ctx = torch.zeros(ctx_dim, ctx_frm, dtype=torch.float32) + for lag in range(-left, right + 1): + feats_ctx[:, index:index + feats.shape[1]] = torch.roll( + feats, -lag, 0) + index = index + feats.shape[1] + + # replication pad left margin + for idx in range(left): + for cpx in range(left - idx): + feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1) + * feats.shape[1]] = feats_ctx[left, :feats.shape[1]] + + feats_ctx = feats_ctx[:feats_ctx.shape[0] - right] + sample['feat'] = feats_ctx + yield sample + + +def frame_skip(data, skip_rate=1): + """ skip frame + Args: + data: Iterable[{key, feat, label}] + skip_rate (int): take every N-frames for model input + + Returns: + data: Iterable[{key, feat, label}] + """ + for sample in data: + feats_skip = sample['feat'][::skip_rate, :] + sample['feat'] = feats_skip + yield sample def batch(data, batch_size=16): """ Static batch the data by `batch_size` diff --git a/wekws/model/fsmn.py b/wekws/model/fsmn.py new file mode 100644 index 0000000..3db5aac --- /dev/null +++ b/wekws/model/fsmn.py @@ -0,0 +1,523 @@ +''' +FSMN implementation. + +Copyright: 2022-03-09 yueyue.nyy +''' + +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def toKaldiMatrix(np_mat): + np.set_printoptions(threshold=np.inf, linewidth=np.nan) + out_str = str(np_mat) + out_str = out_str.replace('[', '') + out_str = out_str.replace(']', '') + return '[ %s ]\n' % out_str + + +def printTensor(torch_tensor): + re_str = '' + x = torch_tensor.detach().squeeze().numpy() + re_str += toKaldiMatrix(x) + # re_str += '\n' + print(re_str) + + +class LinearTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(LinearTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.linear = nn.Linear(input_dim, output_dim, bias=False) + self.quant = torch.quantization.QuantStub() + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, input): + output = self.quant(input) + output = self.linear(output) + output = self.dequant(output) + + return output + + def to_kaldi_net(self): + re_str = '' + re_str += ' %d %d\n' % (self.output_dim, + self.input_dim) + re_str += ' 1\n' + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += toKaldiMatrix(x) + # re_str += '\n' + + return re_str + + def to_pytorch_net(self, fread): + linear_line = fread.readline() + linear_split = linear_line.strip().split() + assert len(linear_split) == 3 + assert linear_split[0] == '' + self.output_dim = int(linear_split[1]) + self.input_dim = int(linear_split[2]) + + learn_rate_line = fread.readline() + assert learn_rate_line.find('LearnRateCoef') != -1 + + self.linear.reset_parameters() + + # linear_weights = self.state_dict()['linear.weight'] + # print(linear_weights.shape) + new_weights = torch.zeros((self.output_dim, self.input_dim), + dtype=torch.float32) + for i in range(self.output_dim): + line = fread.readline() + splits = line.strip().strip('[]').strip().split() + assert len(splits) == self.input_dim + cols = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + new_weights[i, :] = cols + + self.linear.weight.data = new_weights + + +class AffineTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(AffineTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + self.linear = nn.Linear(input_dim, output_dim) + self.quant = torch.quantization.QuantStub() + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, input): + output = self.quant(input) + output = self.linear(output) + output = self.dequant(output) + + return output + + def to_kaldi_net(self): + re_str = '' + re_str += ' %d %d\n' % (self.output_dim, + self.input_dim) + re_str += ' 1 1 0\n' + + linear_weights = self.state_dict()['linear.weight'] + x = linear_weights.squeeze().numpy() + re_str += toKaldiMatrix(x) + + linear_bias = self.state_dict()['linear.bias'] + x = linear_bias.squeeze().numpy() + re_str += toKaldiMatrix(x) + # re_str += '\n' + + return re_str + + def to_pytorch_net(self, fread): + affine_line = fread.readline() + affine_split = affine_line.strip().split() + assert len(affine_split) == 3 + assert affine_split[0] == '' + self.output_dim = int(affine_split[1]) + self.input_dim = int(affine_split[2]) + print('AffineTransform output/input dim: %d %d' % + (self.output_dim, self.input_dim)) + + learn_rate_line = fread.readline() + assert learn_rate_line.find('LearnRateCoef') != -1 + + # linear_weights = self.state_dict()['linear.weight'] + # print(linear_weights.shape) + self.linear.reset_parameters() + + new_weights = torch.zeros((self.output_dim, self.input_dim), + dtype=torch.float32) + for i in range(self.output_dim): + line = fread.readline() + splits = line.strip().strip('[]').strip().split() + assert len(splits) == self.input_dim + cols = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + new_weights[i, :] = cols + + self.linear.weight.data = new_weights + + # linear_bias = self.state_dict()['linear.bias'] + # print(linear_bias.shape) + bias_line = fread.readline() + splits = bias_line.strip().strip('[]').strip().split() + assert len(splits) == self.output_dim + new_bias = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + + self.linear.bias.data = new_bias + + +class FSMNBlock(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + lorder=None, + rorder=None, + lstride=1, + rstride=1, + ): + super(FSMNBlock, self).__init__() + + self.dim = input_dim + + if lorder is None: + return + + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + + self.conv_left = nn.Conv2d( + self.dim, + self.dim, [lorder, 1], + dilation=[lstride, 1], + groups=self.dim, + bias=False) + + if rorder > 0: + self.conv_right = nn.Conv2d( + self.dim, + self.dim, [rorder, 1], + dilation=[rstride, 1], + groups=self.dim, + bias=False) + else: + self.conv_right = None + + self.quant = torch.quantization.QuantStub() + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, input): + x = torch.unsqueeze(input, 1) + x_per = x.permute(0, 3, 2, 1) + + y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) + y_left = self.quant(y_left) + y_left = self.conv_left(y_left) + y_left = self.dequant(y_left) + out = x_per + y_left + + if self.conv_right is not None: + y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) + y_right = y_right[:, :, self.rstride:, :] + y_right = self.quant(y_right) + y_right = self.conv_right(y_right) + y_right = self.dequant(y_right) + out += y_right + + out_per = out.permute(0, 3, 2, 1) + output = out_per.squeeze(1) + + return output + + def to_kaldi_net(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + re_str += ' %d %d %d %d %d 0\n' % ( + 1, self.lorder, self.rorder, self.lstride, self.rstride) + + # print(self.conv_left.weight,self.conv_right.weight) + lfiters = self.state_dict()['conv_left.weight'] + x = np.flipud(lfiters.squeeze().numpy().T) + re_str += toKaldiMatrix(x) + + if self.conv_right is not None: + rfiters = self.state_dict()['conv_right.weight'] + x = (rfiters.squeeze().numpy().T) + re_str += toKaldiMatrix(x) + # re_str += '\n' + + return re_str + + def to_pytorch_net(self, fread): + fsmn_line = fread.readline() + fsmn_split = fsmn_line.strip().split() + assert len(fsmn_split) == 3 + assert fsmn_split[0] == '' + self.dim = int(fsmn_split[1]) + + params_line = fread.readline() + params_split = params_line.strip().strip('[]').strip().split() + assert len(params_split) == 12 + assert params_split[0] == '' + assert params_split[2] == '' + self.lorder = int(params_split[3]) + assert params_split[4] == '' + self.rorder = int(params_split[5]) + assert params_split[6] == '' + self.lstride = int(params_split[7]) + assert params_split[8] == '' + self.rstride = int(params_split[9]) + assert params_split[10] == '' + + # lfilters = self.state_dict()['conv_left.weight'] + # print(lfilters.shape) + print('read conv_left weight') + new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1), + dtype=torch.float32) + for i in range(self.lorder): + print('read conv_left weight -- %d' % i) + line = fread.readline() + splits = line.strip().strip('[]').strip().split() + assert len(splits) == self.dim + cols = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols + + new_lfilters = torch.transpose(new_lfilters, 0, 2) + # print(new_lfilters.shape) + + self.conv_left.reset_parameters() + self.conv_left.weight.data = new_lfilters + # print(self.conv_left.weight.shape) + + if self.rorder > 0: + # rfilters = self.state_dict()['conv_right.weight'] + # print(rfilters.shape) + print('read conv_right weight') + new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1), + dtype=torch.float32) + line = fread.readline() + for i in range(self.rorder): + print('read conv_right weight -- %d' % i) + line = fread.readline() + splits = line.strip().strip('[]').strip().split() + assert len(splits) == self.dim + cols = torch.tensor([float(item) for item in splits], + dtype=torch.float32) + new_rfilters[i, 0, :, 0] = cols + + new_rfilters = torch.transpose(new_rfilters, 0, 2) + # print(new_rfilters.shape) + self.conv_right.reset_parameters() + self.conv_right.weight.data = new_rfilters + # print(self.conv_right.weight.shape) + + +class RectifiedLinear(nn.Module): + + def __init__(self, input_dim, output_dim): + super(RectifiedLinear, self).__init__() + self.dim = input_dim + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.1) + + def forward(self, input): + out = self.relu(input) + # out = self.dropout(out) + return out + + def to_kaldi_net(self): + re_str = '' + re_str += ' %d %d\n' % (self.dim, self.dim) + # re_str += '\n' + return re_str + + # re_str = '' + # re_str += ' %d %d\n' % (self.dim, self.dim) + # re_str += ' 0 0\n' + # re_str += toKaldiMatrix(np.ones((self.dim), dtype = 'int32')) + # re_str += toKaldiMatrix(np.zeros((self.dim), dtype = 'int32')) + # re_str += '\n' + # return re_str + + def to_pytorch_net(self, fread): + line = fread.readline() + splits = line.strip().split() + assert len(splits) == 3 + assert splits[0] == '' + assert int(splits[1]) == int(splits[2]) + assert int(splits[1]) == self.dim + self.dim = int(splits[1]) + + +def _build_repeats( + fsmn_layers: int, + linear_dim: int, + proj_dim: int, + lorder: int, + rorder: int, + lstride=1, + rstride=1, +): + repeats = [ + nn.Sequential( + LinearTransform(linear_dim, proj_dim), + FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1), + AffineTransform(proj_dim, linear_dim), + RectifiedLinear(linear_dim, linear_dim)) + for i in range(fsmn_layers) + ] + + return nn.Sequential(*repeats) + + +class FSMN(nn.Module): + + def __init__( + self, + input_dim: int, + input_affine_dim: int, + fsmn_layers: int, + linear_dim: int, + proj_dim: int, + lorder: int, + rorder: int, + lstride: int, + rstride: int, + output_affine_dim: int, + output_dim: int, + ): + """ + Args: + input_dim: input dimension + input_affine_dim: input affine layer dimension + fsmn_layers: no. of fsmn units + linear_dim: fsmn input dimension + proj_dim: fsmn projection dimension + lorder: fsmn left order + rorder: fsmn right order + lstride: fsmn left stride + rstride: fsmn right stride + output_affine_dim: output affine layer dimension + output_dim: output dimension + """ + super(FSMN, self).__init__() + + self.input_dim = input_dim + self.input_affine_dim = input_affine_dim + self.fsmn_layers = fsmn_layers + self.linear_dim = linear_dim + self.proj_dim = proj_dim + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + self.output_affine_dim = output_affine_dim + self.output_dim = output_dim + + self.padding = (self.lorder-1) * self.lstride + self.rorder * self.rstride + + self.in_linear1 = AffineTransform(input_dim, input_affine_dim) + self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) + self.relu = RectifiedLinear(linear_dim, linear_dim) + + self.fsmn = _build_repeats(fsmn_layers, linear_dim, proj_dim, lorder, + rorder, lstride, rstride) + + self.out_linear1 = AffineTransform(linear_dim, output_affine_dim) + self.out_linear2 = AffineTransform(output_affine_dim, output_dim) + # self.softmax = nn.Softmax(dim = -1) + + def fuse_modules(self): + pass + + def forward( + self, + input: torch.Tensor, + in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + input (torch.Tensor): Input tensor (B, T, D) + in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size + """ + + # print("FSMN forward!!!!") + # print(input.shape) + # print(input) + # print(self.in_linear1.input_dim) + # print(self.in_linear1.output_dim) + + x1 = self.in_linear1(input) + x2 = self.in_linear2(x1) + x3 = self.relu(x2) + x4 = self.fsmn(x3) + x5 = self.out_linear1(x4) + x6 = self.out_linear2(x5) + # x7 = self.softmax(x6) + + # return x7, None + return x6, in_cache + + def to_kaldi_net(self): + re_str = '' + re_str += '\n' + re_str += self.in_linear1.to_kaldi_net() + re_str += self.in_linear2.to_kaldi_net() + re_str += self.relu.to_kaldi_net() + + for fsmn in self.fsmn: + re_str += fsmn[0].to_kaldi_net() + re_str += fsmn[1].to_kaldi_net() + re_str += fsmn[2].to_kaldi_net() + re_str += fsmn[3].to_kaldi_net() + + re_str += self.out_linear1.to_kaldi_net() + re_str += self.out_linear2.to_kaldi_net() + re_str += ' %d %d\n' % (self.output_dim, self.output_dim) + # re_str += '\n' + re_str += '\n' + + return re_str + + def to_pytorch_net(self, kaldi_file): + with open(kaldi_file, 'r', encoding='utf8') as fread: + fread = open(kaldi_file, 'r') + nnet_start_line = fread.readline() + assert nnet_start_line.strip() == '' + + self.in_linear1.to_pytorch_net(fread) + self.in_linear2.to_pytorch_net(fread) + self.relu.to_pytorch_net(fread) + + for fsmn in self.fsmn: + fsmn[0].to_pytorch_net(fread) + fsmn[1].to_pytorch_net(fread) + fsmn[2].to_pytorch_net(fread) + fsmn[3].to_pytorch_net(fread) + + self.out_linear1.to_pytorch_net(fread) + self.out_linear2.to_pytorch_net(fread) + + softmax_line = fread.readline() + softmax_split = softmax_line.strip().split() + assert softmax_split[0].strip() == '' + assert int(softmax_split[1]) == self.output_dim + assert int(softmax_split[2]) == self.output_dim + # '\n' + + nnet_end_line = fread.readline() + assert nnet_end_line.strip() == '' + fread.close() + + +if __name__ == '__main__': + fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599) + print(fsmn) + + num_params = sum(p.numel() for p in fsmn.parameters()) + print('the number of model params: {}'.format(num_params)) + x = torch.zeros(128, 200, 400) # batch-size * time * dim + y, _ = fsmn(x) # batch-size * time * dim + print('input shape: {}'.format(x.shape)) + print('output shape: {}'.format(y.shape)) + + print(fsmn.to_kaldi_net()) diff --git a/wekws/model/kws_model.py b/wekws/model/kws_model.py index 3df0849..d4d7b5c 100644 --- a/wekws/model/kws_model.py +++ b/wekws/model/kws_model.py @@ -1,4 +1,5 @@ # Copyright (c) 2021 Binbin Zhang +# 2023 Jing Du # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +26,8 @@ from wekws.model.subsampling import (LinearSubsampling1, Conv1dSubsampling1, NoSubsampling) from wekws.model.tcn import TCN, CnnBlock, DsCnnBlock from wekws.model.mdtc import MDTC -from wekws.utils.cmvn import load_cmvn +from wekws.utils.cmvn import load_cmvn, load_kaldi_cmvn +from wekws.model.fsmn import FSMN class KWSModel(nn.Module): @@ -80,7 +82,10 @@ class KWSModel(nn.Module): def init_model(configs): cmvn = configs.get('cmvn', {}) if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None: - mean, istd = load_cmvn(cmvn['cmvn_file']) + if "kaldi" in cmvn['cmvn_file']: + mean, istd = load_kaldi_cmvn(cmvn['cmvn_file']) + else: + mean, istd = load_cmvn(cmvn['cmvn_file']) global_cmvn = GlobalCMVN( torch.from_numpy(mean).float(), torch.from_numpy(istd).float(), @@ -135,6 +140,20 @@ def init_model(configs): hidden_dim, kernel_size, causal=causal) + elif backbone_type == 'fsmn': + input_affine_dim = configs['backbone']['input_affine_dim'] + num_layers = configs['backbone']['num_layers'] + linear_dim = configs['backbone']['linear_dim'] + proj_dim = configs['backbone']['proj_dim'] + left_order = configs['backbone']['left_order'] + right_order = configs['backbone']['right_order'] + left_stride = configs['backbone']['left_stride'] + right_stride = configs['backbone']['right_stride'] + output_affine_dim = configs['backbone']['output_affine_dim'] + backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim, + proj_dim, left_order, right_order, left_stride, + right_stride, output_affine_dim, output_dim) + else: print('Unknown body type {}'.format(backbone_type)) sys.exit(1) @@ -154,6 +173,8 @@ def init_model(configs): # last means we use last frame to do backpropagation, so the model # can be infered streamingly classifier = LastClassifier(classifier_base) + elif classifier_type == 'identity': + classifier = nn.Identity() else: print('Unknown classifier type {}'.format(classifier_type)) sys.exit(1) diff --git a/wekws/utils/cmvn.py b/wekws/utils/cmvn.py index 1d2ecbd..4280679 100644 --- a/wekws/utils/cmvn.py +++ b/wekws/utils/cmvn.py @@ -14,7 +14,7 @@ # limitations under the License. import json -import math +import math,re import numpy as np @@ -42,3 +42,50 @@ def load_cmvn(json_cmvn_file): variance[i] = 1.0 / math.sqrt(variance[i]) cmvn = np.array([means, variance]) return cmvn + +def load_kaldi_cmvn(cmvn_file): + """ Load the kaldi format cmvn stats file and no need to calculate + + Args: + cmvn_file: cmvn stats file in kaldi format + + Returns: + a numpy array of [means, vars] + """ + + means = None + variance = None + with open(cmvn_file) as f: + all_lines = f.readlines() + for idx, line in enumerate(all_lines): + if line.find('AddShift') != -1: + segs = line.strip().split(' ') + assert len(segs) == 3 + next_line = all_lines[idx + 1] + means_str = re.findall(r'[\[](.*?)[\]]', next_line)[0] + means_list = means_str.strip().split(' ') + means = [0 - float(s) for s in means_list] + assert len(means) == int(segs[1]) + elif line.find('Rescale') != -1: + segs = line.strip().split(' ') + assert len(segs) == 3 + next_line = all_lines[idx + 1] + vars_str = re.findall(r'[\[](.*?)[\]]', next_line)[0] + vars_list = vars_str.strip().split(' ') + variance = [float(s) for s in vars_list] + assert len(variance) == int(segs[1]) + elif line.find('Splice') != -1: + segs = line.strip().split(' ') + assert len(segs) == 3 + next_line = all_lines[idx + 1] + splice_str = re.findall(r'[\[](.*?)[\]]', next_line)[0] + splice_list = splice_str.strip().split(' ') + assert len(splice_list) * int(segs[2]) == int(segs[1]) + copy_times = len(splice_list) + else: + continue + + cmvn = np.array([means, variance]) + cmvn = np.tile(cmvn, (1, copy_times)) + + return cmvn \ No newline at end of file