Merge pull request #3 from jingyonghou/master

add MDTC model support
commit 531d795bce
xiaohou 2021-11-11 10:01:03 +08:00, committed by GitHub
15 changed files with 490 additions and 89 deletions

View File

@@ -5,11 +5,12 @@ dataset_conf:
   resample_conf:
     resample_rate: 16000
   speed_perturb: false
-  fbank_conf:
+  feature_extraction_conf:
+    feature_type: 'fbank'
     num_mel_bins: 40
     frame_shift: 10
     frame_length: 25
-    dither: 0.1
+    dither: 1.0
   spec_aug: true
   spec_aug_conf:
     num_t_mask: 1
@@ -24,9 +25,9 @@ dataset_conf:
 
 model:
   hidden_dim: 64
-  subsampling:
+  preprocessing:
     type: linear
-  body:
+  backbone:
     type: tcn
     ds: true
     num_layers: 4

View File

@@ -5,11 +5,12 @@ dataset_conf:
   resample_conf:
     resample_rate: 16000
   speed_perturb: false
-  fbank_conf:
+  feature_extraction_conf:
+    feature_type: 'fbank'
     num_mel_bins: 40
     frame_shift: 10
     frame_length: 25
-    dither: 0.1
+    dither: 1.0
   spec_aug: false
   spec_aug_conf:
     num_t_mask: 2
@@ -24,9 +25,9 @@ dataset_conf:
 
 model:
   hidden_dim: 128
-  subsampling:
+  preprocessing:
     type: linear
-  body:
+  backbone:
     type: gru
     num_layers: 2

View File

@@ -0,0 +1,46 @@
dataset_conf:
  filter_conf:
    max_length: 2048
    min_length: 0
  resample_conf:
    resample_rate: 16000
  speed_perturb: false
  feature_extraction_conf:
    feature_type: 'mfcc'
    num_ceps: 80
    num_mel_bins: 80
    frame_shift: 10
    frame_length: 25
    dither: 1.0
  feature_dither: 0.0
  spec_aug: true
  spec_aug_conf:
    num_t_mask: 1
    num_f_mask: 1
    max_t: 20
    max_f: 40
  shuffle: true
  shuffle_conf:
    shuffle_size: 1500
  batch_conf:
    batch_size: 100

model:
  hidden_dim: 64
  preprocessing:
    type: none
  backbone:
    type: mdtc
    num_stack: 4
    stack_size: 4
    kernel_size: 5
    hidden_dim: 64

optim: adam
optim_conf:
  lr: 0.001

training_config:
  grad_clip: 5
  max_epoch: 100
  log_interval: 10
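
For orientation: with this config (num_stack: 4, stack_size: 4, kernel_size: 5), the receptive field of the MDTC backbone can be worked out from the TCNBlock/TCNStack definitions in kws/model/mdtc.py later in this commit. A small sketch of the arithmetic:

kernel_size, num_stack, stack_size = 5, 4, 4
# Preprocessor: one TCNBlock with dilation 1 -> (5 - 1) * 1 = 4 frames
preproc_rf = (kernel_size - 1) * 1
# Each stack runs dilations 1, 2, 4, 8 (the config's stack_size becomes
# TCNStack's stack_num) -> (5 - 1) * (1 + 2 + 4 + 8) = 60 frames
stack_rf = (kernel_size - 1) * sum(2**l for l in range(stack_size))
total_rf = preproc_rf + num_stack * stack_rf  # 4 + 4 * 60 = 244 frames
print(total_rf * 10 / 1000)  # ~2.44 s of context at a 10 ms frame shift

So in causal mode each output frame sees roughly 2.4 seconds of left context.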

View File

@@ -5,11 +5,12 @@ dataset_conf:
   resample_conf:
     resample_rate: 16000
   speed_perturb: false
-  fbank_conf:
+  feature_extraction_conf:
+    feature_type: 'fbank'
     num_mel_bins: 40
     frame_shift: 10
     frame_length: 25
-    dither: 0.1
+    dither: 1.0
   spec_aug: false
   spec_aug_conf:
     num_t_mask: 2
@@ -24,9 +25,9 @@ dataset_conf:
 
 model:
   hidden_dim: 64
-  subsampling:
+  preprocessing:
     type: linear
-  body:
+  backbone:
     type: tcn
     ds: false
     num_layers: 4

View File

@@ -3,24 +3,24 @@
 . ./path.sh
 
-export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+export CUDA_VISIBLE_DEVICES="0"
 
 stage=0
 stop_stage=4
 num_keywords=2
 
-config=conf/ds_tcn.yaml
-norm_mean=true
-norm_var=true
+config=conf/mdtc.yaml
+norm_mean=false
+norm_var=false
 
 gpu_id=0
 checkpoint=
-dir=exp/ds_tcn
-num_average=30
+dir=exp/mdtc
+num_average=10
 score_checkpoint=$dir/avg_${num_average}.pt
 
-download_dir=/export/expts6/binbinzhang/data/
+download_dir=./data/local # your data dir
 
 . tools/parse_options.sh || exit 1;
@@ -34,19 +34,16 @@ fi
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
   echo "Preparing datasets..."
-  mkdir dict
+  mkdir -p dict
   echo "<filler> -1" > dict/words.txt
   echo "Hi_Xiaowen 0" >> dict/words.txt
   echo "Nihao_Wenwen 1" >> dict/words.txt
 
-  for folder in train dev eval; do
+  for folder in train dev test; do
     mkdir -p data/$folder
     for prefix in p n; do
       mkdir -p data/${prefix}_$folder
       json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
-      if [ $folder = "eval" ]; then
-        json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_test.json
-      fi
       local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
         data/${prefix}_$folder
     done
@@ -63,7 +60,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     --in_scp data/train/wav.scp \
     --out_cmvn data/train/global_cmvn
-  for x in train dev eval; do
+  for x in train dev test; do
     tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
     tools/make_list.py data/$x/wav.scp data/$x/text \
       data/$x/wav.dur data/$x/data.list
@@ -100,27 +97,31 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   # Compute posterior score
   result_dir=$dir/test_$(basename $score_checkpoint)
   mkdir -p $result_dir
-  python kws/bin/score.py --gpu -1 \
+  python kws/bin/score.py --gpu 1 \
     --config $dir/config.yaml \
-    --test_data data/eval/data.list \
+    --test_data data/test/data.list \
     --batch_size 256 \
     --checkpoint $score_checkpoint \
     --score_file $result_dir/score.txt
+fi
 
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
   # Compute detection error tradeoff
+  result_dir=$dir/test_$(basename $score_checkpoint)
   for keyword in 0 1; do
     python kws/bin/compute_det.py \
       --keyword $keyword \
-      --test_data data/eval/data.list \
+      --test_data data/test/data.list \
       --score_file $result_dir/score.txt \
       --stats_file $result_dir/stats.${keyword}.txt
   done
 fi
 
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   python kws/bin/export_jit.py --config $dir/config.yaml \
     --checkpoint $score_checkpoint \
     --output_file $dir/final.zip \
     --output_quant_file $dir/final.quant.zip
 fi
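
Since run.sh sources tools/parse_options.sh, all of the defaults above can be overridden from the command line. Note that stop_stage still defaults to 4, so the JIT export (now stage 5) must be requested explicitly; an illustrative invocation:

bash run.sh --stage 0 --stop_stage 5 \
  --config conf/mdtc.yaml \
  --dir exp/mdtc \
  --download_dir ./data/local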

View File

@@ -56,8 +56,7 @@ def main():
     # Export quantized jit torch script model
     if args.output_quant_file:
         quantized_model = torch.quantization.quantize_dynamic(
-            model, {torch.nn.Linear}, dtype=torch.qint8
-        )
+            model, {torch.nn.Linear}, dtype=torch.qint8)
         print(quantized_model)
         script_quant_model = torch.jit.script(quantized_model)
         script_quant_model.save(args.output_quant_file)
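
Both final.zip and final.quant.zip are TorchScript archives, so either can be loaded for inference without the Python model definitions. A minimal sketch (the path is illustrative; the feature dimension of 80 matches conf/mdtc.yaml):

import torch

model = torch.jit.load('exp/mdtc/final.quant.zip')  # or final.zip
model.eval()
with torch.no_grad():
    feats = torch.zeros(1, 100, 80)  # (batch, frames, feature_dim)
    posterior = model(feats)         # per-frame keyword posteriors (after sigmoid)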

View File

@@ -135,7 +135,8 @@ def main():
         num_workers=args.num_workers,
         prefetch_factor=args.prefetch)
 
-    input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
+    input_dim = configs['dataset_conf']['feature_extraction_conf'][
+        'num_mel_bins']
     output_dim = args.num_keywords
 
     # Write model_dir/config.yaml for inference and export
@@ -160,9 +161,9 @@ def main():
     # !!!IMPORTANT!!!
     # Try to export the model by script, if fails, we should refine
     # the code to satisfy the script export requirements
-    if args.rank == 0:
-        script_model = torch.jit.script(model)
-        script_model.save(os.path.join(args.model_dir, 'init.zip'))
+    # if args.rank == 0:
+    #     script_model = torch.jit.script(model)
+    #     script_model.save(os.path.join(args.model_dir, 'init.zip'))
     executor = Executor()
 
     # If specify checkpoint, load some info from checkpoint
     if args.checkpoint is not None:

View File

@@ -136,10 +136,13 @@ def Dataset(data_list_file, conf, partition=True):
     speed_perturb = conf.get('speed_perturb', False)
     if speed_perturb:
         dataset = Processor(dataset, processor.speed_perturb)
 
-    fbank_conf = conf.get('fbank_conf', {})
-    dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
+    feature_extraction_conf = conf.get('feature_extraction_conf', {})
+    if feature_extraction_conf['feature_type'] == 'mfcc':
+        dataset = Processor(dataset, processor.compute_mfcc,
+                            **feature_extraction_conf)
+    elif feature_extraction_conf['feature_type'] == 'fbank':
+        dataset = Processor(dataset, processor.compute_fbank,
+                            **feature_extraction_conf)
 
     spec_aug = conf.get('spec_aug', True)
     if spec_aug:
         spec_aug_conf = conf.get('spec_aug_conf', {})
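
Because the whole feature_extraction_conf dict is forwarded as keyword arguments, compute_mfcc and compute_fbank each accept a feature_type kwarg that they simply ignore (see the processor.py change below). Processor itself just chains generator stages over the sample stream; a hypothetical, stripped-down sketch of that composition, for illustration only:

def Processor(source, fn, *args, **kwargs):
    # illustrative re-implementation: wrap an iterable with a generator stage
    return fn(source, *args, **kwargs)

def double(data):
    for x in data:
        yield 2 * x

print(list(Processor(range(3), double)))  # [0, 2, 4]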

View File

@@ -127,7 +127,47 @@ def speed_perturb(data, speeds=None):
             yield sample
 
 
+def compute_mfcc(
+    data,
+    feature_type='mfcc',
+    num_ceps=80,
+    num_mel_bins=80,
+    frame_length=25,
+    frame_shift=10,
+    dither=0.0,
+):
+    """Extract mfcc
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+
+        Returns:
+            Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'wav' in sample
+        assert 'key' in sample
+        assert 'label' in sample
+        sample_rate = sample['sample_rate']
+        waveform = sample['wav']
+        waveform = waveform * (1 << 15)
+        # Only keep key, feat, label
+        mat = kaldi.mfcc(
+            waveform,
+            num_ceps=num_ceps,
+            num_mel_bins=num_mel_bins,
+            frame_length=frame_length,
+            frame_shift=frame_shift,
+            dither=dither,
+            energy_floor=0.0,
+            sample_frequency=sample_rate,
+        )
+        yield dict(key=sample['key'], label=sample['label'], feat=mat)
+
+
 def compute_fbank(data,
+                  feature_type='fbank',
                   num_mel_bins=23,
                   frame_length=25,
                   frame_shift=10,
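
The new compute_mfcc mirrors compute_fbank: samples are scaled to 16-bit range before calling torchaudio's Kaldi-compatible frontend, with dither taken from the config. A self-contained sketch of the same call outside the pipeline (the wav path is illustrative):

import torchaudio
import torchaudio.compliance.kaldi as kaldi

waveform, sample_rate = torchaudio.load('example.wav')
waveform = waveform * (1 << 15)  # kaldi frontend expects 16-bit-range samples
mat = kaldi.mfcc(waveform,
                 num_ceps=80,
                 num_mel_bins=80,
                 frame_length=25,
                 frame_shift=10,
                 dither=0.0,
                 energy_floor=0.0,
                 sample_frequency=sample_rate)
print(mat.shape)  # (num_frames, num_ceps)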

View File

@@ -20,45 +20,55 @@ import torch
 
 from kws.model.cmvn import GlobalCMVN
 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
 from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
+from kws.model.mdtc import MDTC
 from kws.utils.cmvn import load_cmvn
 
 
-class KwsModel(torch.nn.Module):
-    """ Our model consists of four parts:
+class KWSModel(torch.nn.Module):
+    """Our model consists of four parts:
     1. global_cmvn: Optional, (idim, idim)
-    2. subsampling: subsampling the input, (idim, hdim)
-    3. body: body of the whole network, (hdim, hdim)
-    4. linear: a linear layer, (hdim, odim)
+    2. preprocessing: feature dimension projection, (idim, hdim)
+    3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
+    4. classifier: output layer or classifier of KWS model, (hdim, odim)
     """
-    def __init__(self, idim: int, odim: int, hdim: int,
-                 global_cmvn: Optional[torch.nn.Module],
-                 subsampling: torch.nn.Module, body: torch.nn.Module):
+    def __init__(
+        self,
+        idim: int,
+        odim: int,
+        hdim: int,
+        global_cmvn: Optional[torch.nn.Module],
+        preprocessing: Optional[torch.nn.Module],
+        backbone: torch.nn.Module,
+    ):
         super().__init__()
         self.idim = idim
         self.odim = odim
         self.hdim = hdim
         self.global_cmvn = global_cmvn
-        self.subsampling = subsampling
-        self.body = body
-        self.linear = torch.nn.Linear(hdim, odim)
+        self.preprocessing = preprocessing
+        self.backbone = backbone
+        self.classifier = torch.nn.Linear(hdim, odim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.global_cmvn is not None:
             x = self.global_cmvn(x)
-        x = self.subsampling(x)
-        x, _ = self.body(x)
-        x = self.linear(x)
+        if self.preprocessing:
+            x = self.preprocessing(x)
+        x, _ = self.backbone(x)
+        x = self.classifier(x)
         x = torch.sigmoid(x)
         return x
 
 
 def init_model(configs):
     cmvn = configs.get('cmvn', {})
-    if cmvn['cmvn_file'] is not None:
+    if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None:
         mean, istd = load_cmvn(cmvn['cmvn_file'])
         global_cmvn = GlobalCMVN(
             torch.from_numpy(mean).float(),
-            torch.from_numpy(istd).float(), cmvn['norm_var'])
+            torch.from_numpy(istd).float(),
+            cmvn['norm_var'],
+        )
     else:
         global_cmvn = None
 
@@ -66,36 +76,52 @@ def init_model(configs):
     output_dim = configs['output_dim']
     hidden_dim = configs['hidden_dim']
 
-    subsampling_type = configs['subsampling']['type']
-    if subsampling_type == 'linear':
-        subsampling = LinearSubsampling1(input_dim, hidden_dim)
-    elif subsampling_type == 'cnn1d_s1':
-        subsampling = Conv1dSubsampling1(input_dim, hidden_dim)
+    prep_type = configs['preprocessing']['type']
+    if prep_type == 'linear':
+        preprocessing = LinearSubsampling1(input_dim, hidden_dim)
+    elif prep_type == 'cnn1d_s1':
+        preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
+    elif prep_type == 'none':
+        preprocessing = None
     else:
-        print('Unknown subsampling type {}'.format(subsampling_type))
+        print('Unknown preprocessing type {}'.format(prep_type))
         sys.exit(1)
 
-    body_type = configs['body']['type']
-    num_layers = configs['body']['num_layers']
-    if body_type == 'gru':
-        body = torch.nn.GRU(hidden_dim,
-                            hidden_dim,
-                            num_layers=num_layers,
-                            batch_first=True)
-    elif body_type == 'tcn':
+    backbone_type = configs['backbone']['type']
+    if backbone_type == 'gru':
+        num_layers = configs['backbone']['num_layers']
+        backbone = torch.nn.GRU(hidden_dim,
+                                hidden_dim,
+                                num_layers=num_layers,
+                                batch_first=True)
+    elif backbone_type == 'tcn':
         # Depthwise Separable
-        ds = configs['body'].get('ds', False)
+        num_layers = configs['backbone']['num_layers']
+        ds = configs['backbone'].get('ds', False)
         if ds:
             block_class = DsCnnBlock
         else:
             block_class = CnnBlock
-        kernel_size = configs['body'].get('kernel_size', 8)
-        dropout = configs['body'].get('drouput', 0.1)
-        body = TCN(num_layers, hidden_dim, kernel_size, dropout, block_class)
+        kernel_size = configs['backbone'].get('kernel_size', 8)
+        dropout = configs['backbone'].get('dropout', 0.1)
+        backbone = TCN(num_layers, hidden_dim, kernel_size, dropout,
+                       block_class)
+    elif backbone_type == 'mdtc':
+        stack_size = configs['backbone']['stack_size']
+        num_stack = configs['backbone']['num_stack']
+        kernel_size = configs['backbone']['kernel_size']
+        hidden_dim = configs['backbone']['hidden_dim']
+        backbone = MDTC(num_stack,
+                        stack_size,
+                        input_dim,
+                        hidden_dim,
+                        kernel_size,
+                        causal=True)
     else:
-        print('Unknown body type {}'.format(body_type))
+        print('Unknown backbone type {}'.format(backbone_type))
         sys.exit(1)
 
-    kws_model = KwsModel(input_dim, output_dim, hidden_dim, global_cmvn,
-                         subsampling, body)
+    kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
+                         preprocessing, backbone)
     return kws_model
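
For the MDTC config above, init_model reduces to roughly the following (CMVN omitted; a sketch with values from conf/mdtc.yaml, not the exact call path):

import torch
from kws.model.kws_model import KWSModel
from kws.model.mdtc import MDTC

input_dim, output_dim, hidden_dim = 80, 2, 64  # num_ceps, num_keywords, hidden_dim
backbone = MDTC(4, 4, input_dim, hidden_dim, 5, causal=True)  # num_stack, stack_size, ...
model = KWSModel(input_dim, output_dim, hidden_dim,
                 global_cmvn=None, preprocessing=None, backbone=backbone)

x = torch.zeros(2, 200, input_dim)  # (batch, frames, feature_dim)
y = model(x)  # (2, 200, 2): per-frame posteriors in [0, 1] for the two keywords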

View File

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import torch
-import torch.nn.functional as F
 
 from kws.utils.mask import padding_mask

kws/model/mdtc.py (new file, +267 lines)
View File

@@ -0,0 +1,267 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class DSDilatedConv1d(nn.Module):
    """Dilated Depthwise-Separable Convolution"""
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int = 1,
        stride: int = 1,
        bias: bool = True,
    ):
        super(DSDilatedConv1d, self).__init__()
        self.receptive_fields = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            padding=0,
            dilation=dilation,
            stride=stride,
            groups=in_channels,
            bias=bias,
        )
        self.bn = nn.BatchNorm1d(in_channels)
        self.pointwise = nn.Conv1d(in_channels,
                                   out_channels,
                                   kernel_size=1,
                                   padding=0,
                                   dilation=1,
                                   bias=bias)

    def forward(self, inputs: torch.Tensor):
        outputs = self.conv(inputs)
        outputs = self.bn(outputs)
        outputs = self.pointwise(outputs)
        return outputs


class TCNBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        res_channels: int,
        kernel_size: int,
        dilation: int,
        causal: bool,
    ):
        super(TCNBlock, self).__init__()
        self.in_channels = in_channels
        self.res_channels = res_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.causal = causal
        self.receptive_fields = dilation * (kernel_size - 1)
        self.half_receptive_fields = self.receptive_fields // 2
        self.conv1 = DSDilatedConv1d(
            in_channels=in_channels,
            out_channels=res_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        self.bn1 = nn.BatchNorm1d(res_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv1d(in_channels=res_channels,
                               out_channels=res_channels,
                               kernel_size=1)
        self.bn2 = nn.BatchNorm1d(res_channels)
        self.relu2 = nn.ReLU()

    def forward(self, inputs: torch.Tensor):
        outputs = self.relu1(self.bn1(self.conv1(inputs)))
        outputs = self.bn2(self.conv2(outputs))
        if self.causal:
            inputs = inputs[:, :, self.receptive_fields:]
        else:
            inputs = inputs[:, :, self.
                            half_receptive_fields:-self.half_receptive_fields]
        if self.in_channels == self.res_channels:
            res_out = self.relu2(outputs + inputs)
        else:
            res_out = self.relu2(outputs)
        return res_out


class TCNStack(nn.Module):
    def __init__(
        self,
        in_channels: int,
        stack_num: int,
        stack_size: int,
        res_channels: int,
        kernel_size: int,
        causal: bool,
    ):
        super(TCNStack, self).__init__()
        self.in_channels = in_channels
        self.stack_num = stack_num
        self.stack_size = stack_size
        self.res_channels = res_channels
        self.kernel_size = kernel_size
        self.causal = causal
        self.res_blocks = self.stack_tcn_blocks()
        self.receptive_fields = self.calculate_receptive_fields()
        self.res_blocks = nn.Sequential(*self.res_blocks)

    def calculate_receptive_fields(self):
        receptive_fields = 0
        for block in self.res_blocks:
            receptive_fields += block.receptive_fields
        return receptive_fields

    def build_dilations(self):
        dilations = []
        for s in range(0, self.stack_size):
            for l in range(0, self.stack_num):
                dilations.append(2**l)
        return dilations

    def stack_tcn_blocks(self):
        dilations = self.build_dilations()
        res_blocks = nn.ModuleList()
        res_blocks.append(
            TCNBlock(
                self.in_channels,
                self.res_channels,
                self.kernel_size,
                dilations[0],
                self.causal,
            ))
        for dilation in dilations[1:]:
            res_blocks.append(
                TCNBlock(
                    self.res_channels,
                    self.res_channels,
                    self.kernel_size,
                    dilation,
                    self.causal,
                ))
        return res_blocks

    def forward(self, inputs: torch.Tensor):
        outputs = inputs
        outputs = self.res_blocks(outputs)
        return outputs


class MDTC(nn.Module):
    """Multi-scale Depthwise Temporal Convolution (MDTC).

    In MDTC, stacked depthwise one-dimensional (1-D) convolution with
    dilated connections is adopted to efficiently model long-range
    dependency of speech. With a large receptive field while
    keeping a small number of model parameters, the structure
    can model temporal context of speech effectively. It also
    extracts multi-scale features from different hidden layers
    of MDTC with different receptive fields.
    """
    def __init__(
        self,
        stack_num: int,
        stack_size: int,
        in_channels: int,
        res_channels: int,
        kernel_size: int,
        causal: bool,
    ):
        super(MDTC, self).__init__()
        self.kernel_size = kernel_size
        self.causal = causal
        self.preprocessor = TCNBlock(in_channels,
                                     res_channels,
                                     kernel_size,
                                     dilation=1,
                                     causal=causal)
        self.relu = nn.ReLU()
        self.blocks = nn.ModuleList()
        self.receptive_fields = self.preprocessor.receptive_fields
        for i in range(stack_num):
            self.blocks.append(
                TCNStack(res_channels, stack_size, 1, res_channels,
                         kernel_size, causal))
            self.receptive_fields += self.blocks[-1].receptive_fields
        self.half_receptive_fields = self.receptive_fields // 2
        print('Receptive Fields: %d' % self.receptive_fields)

    def normalize_length_causal(self, skip_connections: list):
        output_size = skip_connections[-1].shape[-1]
        normalized_outputs = []
        for x in skip_connections:
            remove_length = x.shape[-1] - output_size
            if remove_length != 0:
                normalized_outputs.append(x[:, :, remove_length:])
            else:
                normalized_outputs.append(x)
        return normalized_outputs

    def normalize_length(self, skip_connections: list):
        output_size = skip_connections[-1].shape[-1]
        normalized_outputs = []
        for x in skip_connections:
            remove_length = (x.shape[-1] - output_size) // 2
            if remove_length != 0:
                normalized_outputs.append(x[:, :,
                                            remove_length:-remove_length])
            else:
                normalized_outputs.append(x)
        return normalized_outputs

    def forward(self, x: torch.Tensor):
        if self.causal:
            outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
                            'constant')
        else:
            outputs = F.pad(
                x,
                (0, 0, self.half_receptive_fields, self.half_receptive_fields,
                 0, 0),
                'constant',
            )
        outputs = outputs.transpose(1, 2)
        outputs_list = []
        outputs = self.relu(self.preprocessor(outputs))
        for i in range(len(self.blocks)):
            outputs = self.blocks[i](outputs)
            outputs_list.append(outputs)

        if self.causal:
            outputs_list = self.normalize_length_causal(outputs_list)
        else:
            outputs_list = self.normalize_length(outputs_list)
        outputs = sum(outputs_list)
        outputs = outputs.transpose(1, 2)
        return outputs, None


if __name__ == '__main__':
    mdtc = MDTC(3, 4, 80, 64, 5, causal=True)
    print(mdtc)
    num_params = sum(p.numel() for p in mdtc.parameters())
    print('the number of model params: {}'.format(num_params))
    x = torch.zeros(128, 200, 80)  # batch-size * time * dim
    y, _ = mdtc(x)  # batch-size * time * dim
    print('input shape: {}'.format(x.shape))
    print('output shape: {}'.format(y.shape))

View File

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Optional
 
 import torch
 import torch.nn as nn

View File

@@ -18,9 +18,10 @@ torchaudio.set_audio_backend("sox_io")
 class CollateFunc(object):
     ''' Collate function for AudioDataset
     '''
-    def __init__(self, feat_dim, resample_rate):
+    def __init__(self, feat_dim, feat_type, resample_rate):
         self.feat_dim = feat_dim
         self.resample_rate = resample_rate
+        self.feat_type = feat_type
         pass
 
     def __call__(self, batch):
@@ -31,7 +32,8 @@ class CollateFunc(object):
             value = item[1].strip().split(",")
             assert len(value) == 3 or len(value) == 1
             wav_path = value[0]
-            sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate
+            sample_rate = torchaudio.backend.sox_io_backend.info(
+                wav_path).sample_rate
             resample_rate = sample_rate
             # len(value) == 3 means segmented wav.scp,
             # len(value) == 1 means original wav.scp
@@ -50,12 +52,21 @@ class CollateFunc(object):
                 resample_rate = self.resample_rate
                 waveform = torchaudio.transforms.Resample(
                     orig_freq=sample_rate, new_freq=resample_rate)(waveform)
 
-            mat = kaldi.fbank(waveform,
-                              num_mel_bins=self.feat_dim,
-                              dither=0.0,
-                              energy_floor=0.0,
-                              sample_frequency=resample_rate)
+            if self.feat_type == 'fbank':
+                mat = kaldi.fbank(waveform,
+                                  num_mel_bins=self.feat_dim,
+                                  dither=0.0,
+                                  energy_floor=0.0,
+                                  sample_frequency=resample_rate)
+            elif self.feat_type == 'mfcc':
+                mat = kaldi.mfcc(
+                    waveform,
+                    num_ceps=self.feat_dim,
+                    num_mel_bins=self.feat_dim,
+                    dither=0.0,
+                    energy_floor=0.0,
+                    sample_frequency=resample_rate,
+                )
 
             mean_stat += torch.sum(mat, axis=0)
             var_stat += torch.sum(torch.square(mat), axis=0)
             number += mat.shape[0]
@@ -95,13 +106,17 @@ if __name__ == '__main__':
     with open(args.train_config, 'r') as fin:
         configs = yaml.load(fin, Loader=yaml.FullLoader)
-        feat_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
+        feat_dim = configs['dataset_conf']['feature_extraction_conf'][
+            'num_mel_bins']
+        feat_type = configs['dataset_conf']['feature_extraction_conf'][
+            'feature_type']
 
     resample_rate = 0
     if 'resample_conf' in configs['dataset_conf']:
-        resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
+        resample_rate = configs['dataset_conf']['resample_conf'][
+            'resample_rate']
         print('using resample and new sample rate is {}'.format(resample_rate))
 
-    collate_func = CollateFunc(feat_dim, resample_rate)
+    collate_func = CollateFunc(feat_dim, feat_type, resample_rate)
     dataset = AudioDataset(args.in_scp)
     batch_size = 20
     data_loader = DataLoader(dataset,
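
mean_stat, var_stat and number accumulate first- and second-order statistics per feature dimension; the global CMVN mean and inverse stddev then follow in the usual way. A toy sketch of the finalization math (the accumulation mirrors CollateFunc above):

import torch

feats = [torch.randn(100, 80), torch.randn(80, 80)]  # toy (frames, dim) batches
mean_stat = sum(torch.sum(m, axis=0) for m in feats)
var_stat = sum(torch.sum(torch.square(m), axis=0) for m in feats)
number = sum(m.shape[0] for m in feats)

mean = mean_stat / number  # E[x] per dimension
istd = 1.0 / torch.sqrt(var_stat / number - mean**2)  # 1 / sqrt(E[x^2] - E[x]^2)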

View File

@@ -4,6 +4,7 @@
 import sys
 
+import torchaudio
 torchaudio.set_audio_backend("sox_io")
 
 scp = sys.argv[1]