add model mdtc for mobvoi-hotword example
This commit is contained in:
parent
dbebee86fd
commit
4db050eb67
67
examples/hi_xiaowen/s0/conf/mdtc.yaml
Normal file
67
examples/hi_xiaowen/s0/conf/mdtc.yaml
Normal file
@ -0,0 +1,67 @@
|
||||
debug: false
|
||||
|
||||
input_dim: 80
|
||||
output_dim: 2
|
||||
|
||||
dataset_conf:
|
||||
filter_conf:
|
||||
max_length: 2048
|
||||
min_length: 0
|
||||
|
||||
resample_conf:
|
||||
resample_rate: 16000
|
||||
|
||||
speed_perturb: false
|
||||
|
||||
feature_extraction_conf:
|
||||
feature_type: 'mfcc'
|
||||
num_ceps: 80
|
||||
num_mel_bins: 80
|
||||
frame_shift: 10
|
||||
frame_length: 25
|
||||
dither: 1.0
|
||||
|
||||
feature_dither: 0.0
|
||||
|
||||
spec_aug: true
|
||||
spec_aug_conf:
|
||||
#warp_for_time: false
|
||||
num_t_mask: 1
|
||||
num_f_mask: 1
|
||||
max_t: 20
|
||||
max_f: 40
|
||||
|
||||
shuffle: true
|
||||
shuffle_conf:
|
||||
shuffle_size: 1500
|
||||
batch_conf:
|
||||
batch_size: 100
|
||||
|
||||
model:
|
||||
hidden_dim: 64
|
||||
preprocessing:
|
||||
type: none
|
||||
backbone:
|
||||
type: mdtc
|
||||
num_stack: 4
|
||||
stack_size: 4
|
||||
kernel_size: 5
|
||||
hidden_dim: 64
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
#warmup_step: 2000
|
||||
|
||||
training_config:
|
||||
grad_clip: 5
|
||||
max_epoch: 100
|
||||
log_interval: 10
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -3,24 +3,24 @@
|
||||
|
||||
. ./path.sh
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
stage=0
|
||||
stop_stage=4
|
||||
stage=2
|
||||
stop_stage=2
|
||||
num_keywords=2
|
||||
|
||||
config=conf/ds_tcn.yaml
|
||||
norm_mean=true
|
||||
norm_var=true
|
||||
config=conf/mdtc.yaml
|
||||
norm_mean=false
|
||||
norm_var=false
|
||||
gpu_id=0
|
||||
|
||||
checkpoint=
|
||||
dir=exp/ds_tcn
|
||||
dir=exp/mdtc
|
||||
|
||||
num_average=30
|
||||
num_average=10
|
||||
score_checkpoint=$dir/avg_${num_average}.pt
|
||||
|
||||
download_dir=/export/expts6/binbinzhang/data/
|
||||
download_dir=./data/local # your data dir
|
||||
|
||||
. tools/parse_options.sh || exit 1;
|
||||
|
||||
@ -34,19 +34,16 @@ fi
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "Preparing datasets..."
|
||||
mkdir dict
|
||||
mkdir -p dict
|
||||
echo "<filler> -1" > dict/words.txt
|
||||
echo "Hi_Xiaowen 0" >> dict/words.txt
|
||||
echo "Nihao_Wenwen 1" >> dict/words.txt
|
||||
|
||||
for folder in train dev eval; do
|
||||
for folder in train dev test; do
|
||||
mkdir -p data/$folder
|
||||
for prefix in p n; do
|
||||
mkdir -p data/${prefix}_$folder
|
||||
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
|
||||
if [ $folder = "eval" ]; then
|
||||
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_test.json
|
||||
fi
|
||||
local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
|
||||
data/${prefix}_$folder
|
||||
done
|
||||
@ -63,7 +60,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
--in_scp data/train/wav.scp \
|
||||
--out_cmvn data/train/global_cmvn
|
||||
|
||||
for x in train dev eval; do
|
||||
for x in train dev test; do
|
||||
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
|
||||
tools/make_list.py data/$x/wav.scp data/$x/text \
|
||||
data/$x/wav.dur data/$x/data.list
|
||||
@ -100,27 +97,31 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
# Compute posterior score
|
||||
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||
mkdir -p $result_dir
|
||||
python kws/bin/score.py --gpu -1 \
|
||||
python kws/bin/score.py --gpu 1 \
|
||||
--config $dir/config.yaml \
|
||||
--test_data data/eval/data.list \
|
||||
--test_data data/test/data.list \
|
||||
--batch_size 256 \
|
||||
--checkpoint $score_checkpoint \
|
||||
--score_file $result_dir/score.txt
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
# Compute detection error tradeoff
|
||||
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||
for keyword in 0 1; do
|
||||
python kws/bin/compute_det.py \
|
||||
--keyword $keyword \
|
||||
--test_data data/eval/data.list \
|
||||
--test_data data/test/data.list \
|
||||
--score_file $result_dir/score.txt \
|
||||
--stats_file $result_dir/stats.${keyword}.txt
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
python kws/bin/export_jit.py --config $dir/config.yaml \
|
||||
--checkpoint $score_checkpoint \
|
||||
--output_file $dir/final.zip \
|
||||
--output_quant_file $dir/final.quant.zip
|
||||
fi
|
||||
|
||||
|
||||
@ -135,7 +135,7 @@ def main():
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch)
|
||||
|
||||
input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
|
||||
input_dim = configs['input_dim']
|
||||
output_dim = args.num_keywords
|
||||
|
||||
# Write model_dir/config.yaml for inference and export
|
||||
@ -160,9 +160,9 @@ def main():
|
||||
# !!!IMPORTANT!!!
|
||||
# Try to export the model by script, if fails, we should refine
|
||||
# the code to satisfy the script export requirements
|
||||
if args.rank == 0:
|
||||
script_model = torch.jit.script(model)
|
||||
script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||
#if args.rank == 0:
|
||||
#script_model = torch.jit.script(model)
|
||||
#script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||
executor = Executor()
|
||||
# If specify checkpoint, load some info from checkpoint
|
||||
if args.checkpoint is not None:
|
||||
|
||||
@ -136,10 +136,13 @@ def Dataset(data_list_file, conf, partition=True):
|
||||
speed_perturb = conf.get('speed_perturb', False)
|
||||
if speed_perturb:
|
||||
dataset = Processor(dataset, processor.speed_perturb)
|
||||
|
||||
fbank_conf = conf.get('fbank_conf', {})
|
||||
dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
|
||||
|
||||
feature_extraction_conf = conf.get('feature_extraction_conf', {})
|
||||
if feature_extraction_conf['feature_type'] == 'mfcc':
|
||||
dataset = Processor(dataset, processor.compute_mfcc,
|
||||
**feature_extraction_conf)
|
||||
elif feature_extraction_conf['feature_type'] == 'fbank':
|
||||
dataset = Processor(dataset, processor.compute_fbank,
|
||||
**feature_extraction_conf)
|
||||
spec_aug = conf.get('spec_aug', True)
|
||||
if spec_aug:
|
||||
spec_aug_conf = conf.get('spec_aug_conf', {})
|
||||
|
||||
@ -126,6 +126,43 @@ def speed_perturb(data, speeds=None):
|
||||
|
||||
yield sample
|
||||
|
||||
def compute_mfcc(
    data,
    feature_type='mfcc',
    num_ceps=80,
    num_mel_bins=80,
    frame_length=25,
    frame_shift=10,
    dither=0.0,
):
    """Turn a stream of waveform samples into a stream of MFCC samples.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        feature_type: not read by this function.
            NOTE(review): presumably accepted only so that a whole
            feature-extraction config mapping can be forwarded as keyword
            arguments - confirm against the caller.
        num_ceps: forwarded to ``kaldi.mfcc`` as ``num_ceps``.
        num_mel_bins: forwarded to ``kaldi.mfcc`` as ``num_mel_bins``.
        frame_length: forwarded to ``kaldi.mfcc`` as ``frame_length``.
        frame_shift: forwarded to ``kaldi.mfcc`` as ``frame_shift``.
        dither: forwarded to ``kaldi.mfcc`` as ``dither``.

    Returns:
        Iterable[{key, feat, label}]
    """
    required_fields = ('sample_rate', 'wav', 'key', 'label')
    for sample in data:
        for field in required_fields:
            assert field in sample
        # Scale the waveform by 2**15 before feature extraction.
        # NOTE(review): presumably maps float audio to the 16-bit PCM range
        # that Kaldi-style features expect - confirm.
        scaled_wav = sample['wav'] * (1 << 15)
        feat = kaldi.mfcc(
            scaled_wav,
            num_ceps=num_ceps,
            num_mel_bins=num_mel_bins,
            frame_length=frame_length,
            frame_shift=frame_shift,
            dither=dither,
            energy_floor=0.0,
            sample_frequency=sample['sample_rate'],
        )
        # Only keep key, feat, label; the raw waveform is dropped here.
        yield dict(key=sample['key'], label=sample['label'], feat=feat)
|
||||
|
||||
def compute_fbank(data,
|
||||
num_mel_bins=23,
|
||||
|
||||
@ -20,45 +20,55 @@ import torch
|
||||
from kws.model.cmvn import GlobalCMVN
|
||||
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
|
||||
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
||||
from kws.model.mdtc import MDTC
|
||||
from kws.utils.cmvn import load_cmvn
|
||||
|
||||
|
||||
class KwsModel(torch.nn.Module):
|
||||
""" Our model consists of four parts:
|
||||
class KWSModel(torch.nn.Module):
|
||||
"""Our model consists of four parts:
|
||||
1. global_cmvn: Optional, (idim, idim)
|
||||
2. subsampling: subsampling the input, (idim, hdim)
|
||||
3. body: body of the whole network, (hdim, hdim)
|
||||
4. linear: a linear layer, (hdim, odim)
|
||||
2. preprocessing: feature dimension projection, (idim, hdim)
|
||||
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
|
||||
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
||||
"""
|
||||
def __init__(self, idim: int, odim: int, hdim: int,
|
||||
global_cmvn: Optional[torch.nn.Module],
|
||||
subsampling: torch.nn.Module, body: torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
idim: int,
|
||||
odim: int,
|
||||
hdim: int,
|
||||
global_cmvn: Optional[torch.nn.Module],
|
||||
preprocessing: Optional[torch.nn.Module],
|
||||
backbone: torch.nn.Module,
|
||||
):
|
||||
super().__init__()
|
||||
self.idim = idim
|
||||
self.odim = odim
|
||||
self.hdim = hdim
|
||||
self.global_cmvn = global_cmvn
|
||||
self.subsampling = subsampling
|
||||
self.body = body
|
||||
self.linear = torch.nn.Linear(hdim, odim)
|
||||
self.preprocessing = preprocessing
|
||||
self.backbone = backbone
|
||||
self.classifier = torch.nn.Linear(hdim, odim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.global_cmvn is not None:
|
||||
x = self.global_cmvn(x)
|
||||
x = self.subsampling(x)
|
||||
x, _ = self.body(x)
|
||||
x = self.linear(x)
|
||||
if self.preprocessing:
|
||||
x = self.preprocessing(x)
|
||||
x, _ = self.backbone(x)
|
||||
x = self.classifier(x)
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
def init_model(configs):
|
||||
cmvn = configs.get('cmvn', {})
|
||||
if cmvn['cmvn_file'] is not None:
|
||||
if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None:
|
||||
mean, istd = load_cmvn(cmvn['cmvn_file'])
|
||||
global_cmvn = GlobalCMVN(
|
||||
torch.from_numpy(mean).float(),
|
||||
torch.from_numpy(istd).float(), cmvn['norm_var'])
|
||||
torch.from_numpy(istd).float(),
|
||||
cmvn['norm_var'],
|
||||
)
|
||||
else:
|
||||
global_cmvn = None
|
||||
|
||||
@ -66,36 +76,52 @@ def init_model(configs):
|
||||
output_dim = configs['output_dim']
|
||||
hidden_dim = configs['hidden_dim']
|
||||
|
||||
subsampling_type = configs['subsampling']['type']
|
||||
if subsampling_type == 'linear':
|
||||
subsampling = LinearSubsampling1(input_dim, hidden_dim)
|
||||
elif subsampling_type == 'cnn1d_s1':
|
||||
subsampling = Conv1dSubsampling1(input_dim, hidden_dim)
|
||||
prep_type = configs['preprocessing']['type']
|
||||
if prep_type == 'linear':
|
||||
preprocessing = LinearSubsampling1(input_dim, hidden_dim)
|
||||
elif prep_type == 'cnn1d_s1':
|
||||
preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
|
||||
elif prep_type == 'none':
|
||||
preprocessing = None
|
||||
else:
|
||||
print('Unknown subsampling type {}'.format(subsampling_type))
|
||||
print('Unknown preprocessing type {}'.format(prep_type))
|
||||
sys.exit(1)
|
||||
|
||||
body_type = configs['body']['type']
|
||||
num_layers = configs['body']['num_layers']
|
||||
if body_type == 'gru':
|
||||
body = torch.nn.GRU(hidden_dim,
|
||||
hidden_dim,
|
||||
num_layers=num_layers,
|
||||
batch_first=True)
|
||||
elif body_type == 'tcn':
|
||||
backbone_type = configs['backbone']['type']
|
||||
if backbone_type == 'gru':
|
||||
num_layers = configs['backbone']['num_layers']
|
||||
backbone = torch.nn.GRU(hidden_dim,
|
||||
hidden_dim,
|
||||
num_layers=num_layers,
|
||||
batch_first=True)
|
||||
elif backbone_type == 'tcn':
|
||||
# Depthwise Separable
|
||||
ds = configs['body'].get('ds', False)
|
||||
num_layers = configs['backbone']['num_layers']
|
||||
ds = configs['backbone'].get('ds', False)
|
||||
if ds:
|
||||
block_class = DsCnnBlock
|
||||
else:
|
||||
block_class = CnnBlock
|
||||
kernel_size = configs['body'].get('kernel_size', 8)
|
||||
dropout = configs['body'].get('drouput', 0.1)
|
||||
body = TCN(num_layers, hidden_dim, kernel_size, dropout, block_class)
|
||||
kernel_size = configs['backbone'].get('kernel_size', 8)
|
||||
dropout = configs['backbone'].get('drouput', 0.1)
|
||||
backbone = TCN(num_layers, hidden_dim, kernel_size, dropout,
|
||||
block_class)
|
||||
elif backbone_type == 'mdtc':
|
||||
stack_size = configs['backbone']['stack_size']
|
||||
num_stack = configs['backbone']['num_stack']
|
||||
kernel_size = configs['backbone']['kernel_size']
|
||||
hidden_dim = configs['backbone']['hidden_dim']
|
||||
|
||||
backbone = MDTC(num_stack,
|
||||
stack_size,
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
kernel_size,
|
||||
causal=True)
|
||||
else:
|
||||
print('Unknown body type {}'.format(body_type))
|
||||
print('Unknown body type {}'.format(backbone_type))
|
||||
sys.exit(1)
|
||||
|
||||
kws_model = KwsModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||
subsampling, body)
|
||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||
preprocessing, backbone)
|
||||
return kws_model
|
||||
|
||||
260
kws/model/mdtc.py
Normal file
260
kws/model/mdtc.py
Normal file
@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class DSDilatedConv1d(nn.Module):
    """Dilated depthwise-separable 1-D convolution.

    The input goes through three layers in order:

    1. ``conv``: a depthwise dilated ``nn.Conv1d`` (``groups=in_channels``,
       no padding) that keeps the channel count,
    2. ``bn``: batch normalization over ``in_channels``,
    3. ``pointwise``: a 1x1 ``nn.Conv1d`` mapping ``in_channels`` to
       ``out_channels``.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the pointwise conv.
        kernel_size: kernel size of the depthwise convolution.
        dilation: dilation of the depthwise convolution.
        stride: stride of the depthwise convolution.
        bias: whether the two convolutions have a bias term.

    Attributes:
        receptive_fields: ``dilation * (kernel_size - 1)``; with stride 1
            and no padding this is how many frames shorter the output is
            than the input.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int = 1,
        stride: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        self.receptive_fields = dilation * (kernel_size - 1)
        # Depthwise: one filter per channel. padding=0 so the caller is
        # responsible for padding the sequence beforehand.
        self.conv = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            padding=0,
            dilation=dilation,
            stride=stride,
            groups=in_channels,
            bias=bias,
        )
        self.bn = nn.BatchNorm1d(in_channels)
        # Pointwise: mixes channels without touching the time axis.
        self.pointwise = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            dilation=1,
            bias=bias,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv, batch norm and pointwise conv.

        Args:
            inputs: tensor laid out as (batch, in_channels, time), the
                layout ``nn.Conv1d`` operates on.

        Returns:
            Tensor with ``out_channels`` channels and a shortened time axis.
        """
        outputs = self.conv(inputs)
        outputs = self.bn(outputs)
        outputs = self.pointwise(outputs)
        return outputs
|
||||
|
||||
|
||||
class TCNBlock(nn.Module):
    """Residual temporal-convolution block.

    Layer order: ``DSDilatedConv1d -> BatchNorm -> ReLU -> 1x1 Conv1d ->
    BatchNorm``, then the (trimmed) input is added back when
    ``in_channels == res_channels``, followed by a final ReLU.

    Args:
        in_channels: number of input channels.
        res_channels: number of output channels of the block.
        kernel_size: kernel size of the dilated depthwise convolution.
        dilation: dilation of the dilated depthwise convolution.
        causal: if True the residual input is trimmed on the left only,
            otherwise it is trimmed on both sides.

    Attributes:
        receptive_fields: ``dilation * (kernel_size - 1)``, the number of
            frames removed from the time axis by the unpadded convolution.
        half_receptive_fields: ``receptive_fields // 2``.
    """
    def __init__(
        self,
        in_channels: int,
        res_channels: int,
        kernel_size: int,
        dilation: int,
        causal: bool,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.res_channels = res_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.causal = causal
        self.receptive_fields = dilation * (kernel_size - 1)
        self.half_receptive_fields = self.receptive_fields // 2
        self.conv1 = DSDilatedConv1d(
            in_channels=in_channels,
            out_channels=res_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        self.bn1 = nn.BatchNorm1d(res_channels)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv1d(in_channels=res_channels,
                               out_channels=res_channels,
                               kernel_size=1)
        self.bn2 = nn.BatchNorm1d(res_channels)
        self.relu2 = nn.ReLU()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the block.

        Args:
            inputs: tensor laid out as (batch, in_channels, time).

        Returns:
            Tensor with ``res_channels`` channels whose time axis is
            ``receptive_fields`` frames shorter than the input.
        """
        outputs = self.relu1(self.bn1(self.conv1(inputs)))
        outputs = self.bn2(self.conv2(outputs))

        # Without matching channel counts there is no residual connection.
        if self.in_channels != self.res_channels:
            return self.relu2(outputs)

        # Trim the input to the length of the convolved output so the two
        # can be added. An explicit end index is used instead of a negative
        # one: ``x[..., h:-h]`` is empty when ``h == 0`` and trims one frame
        # too few when the receptive field is odd.
        if self.causal:
            left, right = self.receptive_fields, 0
        else:
            left = self.half_receptive_fields
            right = self.receptive_fields - left
        residual = inputs[:, :, left:inputs.shape[-1] - right]
        return self.relu2(outputs + residual)
|
||||
|
||||
|
||||
class TCNStack(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
stack_num: int,
|
||||
stack_size: int,
|
||||
res_channels: int,
|
||||
kernel_size: int,
|
||||
causal: bool,
|
||||
):
|
||||
super(TCNStack, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.stack_num = stack_num
|
||||
self.stack_size = stack_size
|
||||
self.res_channels = res_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.causal = causal
|
||||
self.res_blocks = self.stack_tcn_blocks()
|
||||
self.receptive_fields = self.calculate_receptive_fields()
|
||||
self.res_blocks = nn.Sequential(*self.res_blocks)
|
||||
|
||||
def calculate_receptive_fields(self):
|
||||
receptive_fields = 0
|
||||
for block in self.res_blocks:
|
||||
receptive_fields += block.receptive_fields
|
||||
return receptive_fields
|
||||
|
||||
def build_dilations(self):
|
||||
dilations = []
|
||||
for s in range(0, self.stack_size):
|
||||
for l in range(0, self.stack_num):
|
||||
dilations.append(2**l)
|
||||
return dilations
|
||||
|
||||
def stack_tcn_blocks(self):
|
||||
dilations = self.build_dilations()
|
||||
res_blocks = nn.ModuleList()
|
||||
|
||||
res_blocks.append(
|
||||
TCNBlock(
|
||||
self.in_channels,
|
||||
self.res_channels,
|
||||
self.kernel_size,
|
||||
dilations[0],
|
||||
self.causal,
|
||||
))
|
||||
for dilation in dilations[1:]:
|
||||
res_blocks.append(
|
||||
TCNBlock(
|
||||
self.res_channels,
|
||||
self.res_channels,
|
||||
self.kernel_size,
|
||||
dilation,
|
||||
self.causal,
|
||||
))
|
||||
return res_blocks
|
||||
|
||||
def forward(self, inputs: torch.Tensor):
|
||||
outputs = inputs
|
||||
outputs = self.res_blocks(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class MDTC(nn.Module):
    """MDTC backbone: a preprocessing block followed by several TCN stacks.

    The input is padded along time by the total receptive field, passed
    through a ``TCNBlock`` preprocessor and then through ``stack_num``
    ``TCNStack`` modules. The output of every stack is kept, cropped to the
    length of the last one, and all of them are summed.

    Args:
        stack_num: number of ``TCNStack`` modules.
        stack_size: number of dilation levels inside each ``TCNStack``.
            Note it is passed as that class's ``stack_num`` argument, with
            its ``stack_size`` fixed to 1.
        in_channels: feature dimension of the input.
        res_channels: number of channels used inside the network and
            feature dimension of the output.
        kernel_size: kernel size of every dilated convolution.
        causal: if True all padding is put on the left of the time axis,
            otherwise it is split between left and right.

    Attributes:
        receptive_fields: summed receptive field of the preprocessor and
            all stacks.
        half_receptive_fields: ``receptive_fields // 2``.
    """
    def __init__(
        self,
        stack_num: int,
        stack_size: int,
        in_channels: int,
        res_channels: int,
        kernel_size: int,
        causal: bool,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.causal = causal
        self.preprocessor = TCNBlock(in_channels,
                                     res_channels,
                                     kernel_size,
                                     dilation=1,
                                     causal=causal)
        self.relu = nn.ReLU()
        self.blocks = nn.ModuleList()
        self.receptive_fields = self.preprocessor.receptive_fields
        for _ in range(stack_num):
            self.blocks.append(
                TCNStack(in_channels=res_channels,
                         stack_num=stack_size,
                         stack_size=1,
                         res_channels=res_channels,
                         kernel_size=kernel_size,
                         causal=causal))
            self.receptive_fields += self.blocks[-1].receptive_fields
        self.half_receptive_fields = self.receptive_fields // 2
        print('Receptive Fields: %d' % self.receptive_fields)

    def normalize_length_causal(self, skip_connections: list):
        """Crop every tensor on the left to the length of the last one.

        Args:
            skip_connections: non-empty list of tensors whose last axis is
                time; the last entry must be the shortest.

        Returns:
            List of tensors that all share the last entry's time length.
        """
        output_size = skip_connections[-1].shape[-1]
        normalized_outputs = []
        for x in skip_connections:
            remove_length = x.shape[-1] - output_size
            if remove_length != 0:
                normalized_outputs.append(x[:, :, remove_length:])
            else:
                normalized_outputs.append(x)
        return normalized_outputs

    def normalize_length(self, skip_connections: list):
        """Crop every tensor on both sides to the length of the last one.

        An odd surplus cannot be split evenly; the extra frame is removed
        on the right so that the result always has the target length.

        Args:
            skip_connections: non-empty list of tensors whose last axis is
                time; the last entry must be the shortest.

        Returns:
            List of tensors that all share the last entry's time length.
        """
        output_size = skip_connections[-1].shape[-1]
        normalized_outputs = []
        for x in skip_connections:
            surplus = x.shape[-1] - output_size
            if surplus != 0:
                left = surplus // 2
                right = surplus - left
                normalized_outputs.append(x[:, :, left:x.shape[-1] - right])
            else:
                normalized_outputs.append(x)
        return normalized_outputs

    def forward(self, x: torch.Tensor):
        """Compute backbone features.

        Args:
            x: input laid out as (batch, time, dim).

        Returns:
            Tuple ``(outputs, None)`` where ``outputs`` is laid out as
            (batch, time, res_channels). The second element is always None.
        """
        # Pad the time axis by the full receptive field so the unpadded
        # convolutions give back the original number of frames.
        if self.causal:
            left, right = self.receptive_fields, 0
        else:
            left = self.half_receptive_fields
            # Equals `left` for an even receptive field; takes the extra
            # frame when it is odd.
            right = self.receptive_fields - left
        outputs = F.pad(x, (0, 0, left, right, 0, 0), 'constant')

        # Conv1d works on (batch, channels, time).
        outputs = outputs.transpose(1, 2)
        outputs_list = []
        outputs = self.relu(self.preprocessor(outputs))
        for block in self.blocks:
            outputs = block(outputs)
            outputs_list.append(outputs)

        # Earlier stacks produce longer sequences; align before summing.
        if self.causal:
            outputs_list = self.normalize_length_causal(outputs_list)
        else:
            outputs_list = self.normalize_length(outputs_list)

        outputs = sum(outputs_list)
        outputs = outputs.transpose(1, 2)
        return outputs, None
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build a causal model, report its size and push a dummy
    # batch through it to show the input/output shapes.
    model = MDTC(3, 4, 80, 64, 5, causal=True)
    print(model)

    param_count = sum(weight.numel() for weight in model.parameters())
    print('the number of model params: {}'.format(param_count))
    feats = torch.zeros(128, 200, 80)  # batch-size * time * dim
    hidden, _ = model(feats)  # batch-size * time * dim
    print('input shape: {}'.format(feats.shape))
    print('output shape: {}'.format(hidden.shape))
|
||||
@ -95,7 +95,7 @@ if __name__ == '__main__':
|
||||
|
||||
with open(args.train_config, 'r') as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
feat_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
|
||||
feat_dim = configs['input_dim']
|
||||
resample_rate = 0
|
||||
if 'resample_conf' in configs['dataset_conf']:
|
||||
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user