commit
531d795bce
@ -5,11 +5,12 @@ dataset_conf:
|
|||||||
resample_conf:
|
resample_conf:
|
||||||
resample_rate: 16000
|
resample_rate: 16000
|
||||||
speed_perturb: false
|
speed_perturb: false
|
||||||
fbank_conf:
|
feature_extraction_conf:
|
||||||
|
feature_type: 'fbank'
|
||||||
num_mel_bins: 40
|
num_mel_bins: 40
|
||||||
frame_shift: 10
|
frame_shift: 10
|
||||||
frame_length: 25
|
frame_length: 25
|
||||||
dither: 0.1
|
dither: 1.0
|
||||||
spec_aug: true
|
spec_aug: true
|
||||||
spec_aug_conf:
|
spec_aug_conf:
|
||||||
num_t_mask: 1
|
num_t_mask: 1
|
||||||
@ -24,9 +25,9 @@ dataset_conf:
|
|||||||
|
|
||||||
model:
|
model:
|
||||||
hidden_dim: 64
|
hidden_dim: 64
|
||||||
subsampling:
|
preprocessing:
|
||||||
type: linear
|
type: linear
|
||||||
body:
|
backbone:
|
||||||
type: tcn
|
type: tcn
|
||||||
ds: true
|
ds: true
|
||||||
num_layers: 4
|
num_layers: 4
|
||||||
|
|||||||
@ -5,11 +5,12 @@ dataset_conf:
|
|||||||
resample_conf:
|
resample_conf:
|
||||||
resample_rate: 16000
|
resample_rate: 16000
|
||||||
speed_perturb: false
|
speed_perturb: false
|
||||||
fbank_conf:
|
feature_extraction_conf:
|
||||||
|
feature_type: 'fbank'
|
||||||
num_mel_bins: 40
|
num_mel_bins: 40
|
||||||
frame_shift: 10
|
frame_shift: 10
|
||||||
frame_length: 25
|
frame_length: 25
|
||||||
dither: 0.1
|
dither: 1.0
|
||||||
spec_aug: false
|
spec_aug: false
|
||||||
spec_aug_conf:
|
spec_aug_conf:
|
||||||
num_t_mask: 2
|
num_t_mask: 2
|
||||||
@ -24,9 +25,9 @@ dataset_conf:
|
|||||||
|
|
||||||
model:
|
model:
|
||||||
hidden_dim: 128
|
hidden_dim: 128
|
||||||
subsampling:
|
preprocessing:
|
||||||
type: linear
|
type: linear
|
||||||
body:
|
backbone:
|
||||||
type: gru
|
type: gru
|
||||||
num_layers: 2
|
num_layers: 2
|
||||||
|
|
||||||
|
|||||||
46
examples/hi_xiaowen/s0/conf/mdtc.yaml
Normal file
46
examples/hi_xiaowen/s0/conf/mdtc.yaml
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
dataset_conf:
|
||||||
|
filter_conf:
|
||||||
|
max_length: 2048
|
||||||
|
min_length: 0
|
||||||
|
resample_conf:
|
||||||
|
resample_rate: 16000
|
||||||
|
speed_perturb: false
|
||||||
|
feature_extraction_conf:
|
||||||
|
feature_type: 'mfcc'
|
||||||
|
num_ceps: 80
|
||||||
|
num_mel_bins: 80
|
||||||
|
frame_shift: 10
|
||||||
|
frame_length: 25
|
||||||
|
dither: 1.0
|
||||||
|
feature_dither: 0.0
|
||||||
|
spec_aug: true
|
||||||
|
spec_aug_conf:
|
||||||
|
num_t_mask: 1
|
||||||
|
num_f_mask: 1
|
||||||
|
max_t: 20
|
||||||
|
max_f: 40
|
||||||
|
shuffle: true
|
||||||
|
shuffle_conf:
|
||||||
|
shuffle_size: 1500
|
||||||
|
batch_conf:
|
||||||
|
batch_size: 100
|
||||||
|
|
||||||
|
model:
|
||||||
|
hidden_dim: 64
|
||||||
|
preprocessing:
|
||||||
|
type: none
|
||||||
|
backbone:
|
||||||
|
type: mdtc
|
||||||
|
num_stack: 4
|
||||||
|
stack_size: 4
|
||||||
|
kernel_size: 5
|
||||||
|
hidden_dim: 64
|
||||||
|
|
||||||
|
optim: adam
|
||||||
|
optim_conf:
|
||||||
|
lr: 0.001
|
||||||
|
|
||||||
|
training_config:
|
||||||
|
grad_clip: 5
|
||||||
|
max_epoch: 100
|
||||||
|
log_interval: 10
|
||||||
@ -5,11 +5,12 @@ dataset_conf:
|
|||||||
resample_conf:
|
resample_conf:
|
||||||
resample_rate: 16000
|
resample_rate: 16000
|
||||||
speed_perturb: false
|
speed_perturb: false
|
||||||
fbank_conf:
|
feature_extraction_conf:
|
||||||
|
feature_type: 'fbank'
|
||||||
num_mel_bins: 40
|
num_mel_bins: 40
|
||||||
frame_shift: 10
|
frame_shift: 10
|
||||||
frame_length: 25
|
frame_length: 25
|
||||||
dither: 0.1
|
dither: 1.0
|
||||||
spec_aug: false
|
spec_aug: false
|
||||||
spec_aug_conf:
|
spec_aug_conf:
|
||||||
num_t_mask: 2
|
num_t_mask: 2
|
||||||
@ -24,9 +25,9 @@ dataset_conf:
|
|||||||
|
|
||||||
model:
|
model:
|
||||||
hidden_dim: 64
|
hidden_dim: 64
|
||||||
subsampling:
|
preprocessing:
|
||||||
type: linear
|
type: linear
|
||||||
body:
|
backbone:
|
||||||
type: tcn
|
type: tcn
|
||||||
ds: false
|
ds: false
|
||||||
num_layers: 4
|
num_layers: 4
|
||||||
|
|||||||
@ -3,24 +3,24 @@
|
|||||||
|
|
||||||
. ./path.sh
|
. ./path.sh
|
||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
export CUDA_VISIBLE_DEVICES="0"
|
||||||
|
|
||||||
stage=0
|
stage=0
|
||||||
stop_stage=4
|
stop_stage=4
|
||||||
num_keywords=2
|
num_keywords=2
|
||||||
|
|
||||||
config=conf/ds_tcn.yaml
|
config=conf/mdtc.yaml
|
||||||
norm_mean=true
|
norm_mean=false
|
||||||
norm_var=true
|
norm_var=false
|
||||||
gpu_id=0
|
gpu_id=0
|
||||||
|
|
||||||
checkpoint=
|
checkpoint=
|
||||||
dir=exp/ds_tcn
|
dir=exp/mdtc
|
||||||
|
|
||||||
num_average=30
|
num_average=10
|
||||||
score_checkpoint=$dir/avg_${num_average}.pt
|
score_checkpoint=$dir/avg_${num_average}.pt
|
||||||
|
|
||||||
download_dir=/export/expts6/binbinzhang/data/
|
download_dir=./data/local # your data dir
|
||||||
|
|
||||||
. tools/parse_options.sh || exit 1;
|
. tools/parse_options.sh || exit 1;
|
||||||
|
|
||||||
@ -34,19 +34,16 @@ fi
|
|||||||
|
|
||||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
echo "Preparing datasets..."
|
echo "Preparing datasets..."
|
||||||
mkdir dict
|
mkdir -p dict
|
||||||
echo "<filler> -1" > dict/words.txt
|
echo "<filler> -1" > dict/words.txt
|
||||||
echo "Hi_Xiaowen 0" >> dict/words.txt
|
echo "Hi_Xiaowen 0" >> dict/words.txt
|
||||||
echo "Nihao_Wenwen 1" >> dict/words.txt
|
echo "Nihao_Wenwen 1" >> dict/words.txt
|
||||||
|
|
||||||
for folder in train dev eval; do
|
for folder in train dev test; do
|
||||||
mkdir -p data/$folder
|
mkdir -p data/$folder
|
||||||
for prefix in p n; do
|
for prefix in p n; do
|
||||||
mkdir -p data/${prefix}_$folder
|
mkdir -p data/${prefix}_$folder
|
||||||
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
|
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
|
||||||
if [ $folder = "eval" ]; then
|
|
||||||
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_test.json
|
|
||||||
fi
|
|
||||||
local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
|
local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
|
||||||
data/${prefix}_$folder
|
data/${prefix}_$folder
|
||||||
done
|
done
|
||||||
@ -63,7 +60,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
|||||||
--in_scp data/train/wav.scp \
|
--in_scp data/train/wav.scp \
|
||||||
--out_cmvn data/train/global_cmvn
|
--out_cmvn data/train/global_cmvn
|
||||||
|
|
||||||
for x in train dev eval; do
|
for x in train dev test; do
|
||||||
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
|
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
|
||||||
tools/make_list.py data/$x/wav.scp data/$x/text \
|
tools/make_list.py data/$x/wav.scp data/$x/text \
|
||||||
data/$x/wav.dur data/$x/data.list
|
data/$x/wav.dur data/$x/data.list
|
||||||
@ -100,27 +97,31 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
# Compute posterior score
|
# Compute posterior score
|
||||||
result_dir=$dir/test_$(basename $score_checkpoint)
|
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||||
mkdir -p $result_dir
|
mkdir -p $result_dir
|
||||||
python kws/bin/score.py --gpu -1 \
|
python kws/bin/score.py --gpu 1 \
|
||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--test_data data/eval/data.list \
|
--test_data data/test/data.list \
|
||||||
--batch_size 256 \
|
--batch_size 256 \
|
||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
--score_file $result_dir/score.txt
|
--score_file $result_dir/score.txt
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||||
# Compute detection error tradeoff
|
# Compute detection error tradeoff
|
||||||
|
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||||
for keyword in 0 1; do
|
for keyword in 0 1; do
|
||||||
python kws/bin/compute_det.py \
|
python kws/bin/compute_det.py \
|
||||||
--keyword $keyword \
|
--keyword $keyword \
|
||||||
--test_data data/eval/data.list \
|
--test_data data/test/data.list \
|
||||||
--score_file $result_dir/score.txt \
|
--score_file $result_dir/score.txt \
|
||||||
--stats_file $result_dir/stats.${keyword}.txt
|
--stats_file $result_dir/stats.${keyword}.txt
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||||
python kws/bin/export_jit.py --config $dir/config.yaml \
|
python kws/bin/export_jit.py --config $dir/config.yaml \
|
||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
--output_file $dir/final.zip \
|
--output_file $dir/final.zip \
|
||||||
--output_quant_file $dir/final.quant.zip
|
--output_quant_file $dir/final.quant.zip
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@ -56,8 +56,7 @@ def main():
|
|||||||
# Export quantized jit torch script model
|
# Export quantized jit torch script model
|
||||||
if args.output_quant_file:
|
if args.output_quant_file:
|
||||||
quantized_model = torch.quantization.quantize_dynamic(
|
quantized_model = torch.quantization.quantize_dynamic(
|
||||||
model, {torch.nn.Linear}, dtype=torch.qint8
|
model, {torch.nn.Linear}, dtype=torch.qint8)
|
||||||
)
|
|
||||||
print(quantized_model)
|
print(quantized_model)
|
||||||
script_quant_model = torch.jit.script(quantized_model)
|
script_quant_model = torch.jit.script(quantized_model)
|
||||||
script_quant_model.save(args.output_quant_file)
|
script_quant_model.save(args.output_quant_file)
|
||||||
|
|||||||
@ -135,7 +135,8 @@ def main():
|
|||||||
num_workers=args.num_workers,
|
num_workers=args.num_workers,
|
||||||
prefetch_factor=args.prefetch)
|
prefetch_factor=args.prefetch)
|
||||||
|
|
||||||
input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
|
input_dim = configs['dataset_conf']['feature_extraction_conf'][
|
||||||
|
'num_mel_bins']
|
||||||
output_dim = args.num_keywords
|
output_dim = args.num_keywords
|
||||||
|
|
||||||
# Write model_dir/config.yaml for inference and export
|
# Write model_dir/config.yaml for inference and export
|
||||||
@ -160,9 +161,9 @@ def main():
|
|||||||
# !!!IMPORTANT!!!
|
# !!!IMPORTANT!!!
|
||||||
# Try to export the model by script, if fails, we should refine
|
# Try to export the model by script, if fails, we should refine
|
||||||
# the code to satisfy the script export requirements
|
# the code to satisfy the script export requirements
|
||||||
if args.rank == 0:
|
# if args.rank == 0:
|
||||||
script_model = torch.jit.script(model)
|
# script_model = torch.jit.script(model)
|
||||||
script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||||
executor = Executor()
|
executor = Executor()
|
||||||
# If specify checkpoint, load some info from checkpoint
|
# If specify checkpoint, load some info from checkpoint
|
||||||
if args.checkpoint is not None:
|
if args.checkpoint is not None:
|
||||||
|
|||||||
@ -136,10 +136,13 @@ def Dataset(data_list_file, conf, partition=True):
|
|||||||
speed_perturb = conf.get('speed_perturb', False)
|
speed_perturb = conf.get('speed_perturb', False)
|
||||||
if speed_perturb:
|
if speed_perturb:
|
||||||
dataset = Processor(dataset, processor.speed_perturb)
|
dataset = Processor(dataset, processor.speed_perturb)
|
||||||
|
feature_extraction_conf = conf.get('feature_extraction_conf', {})
|
||||||
fbank_conf = conf.get('fbank_conf', {})
|
if feature_extraction_conf['feature_type'] == 'mfcc':
|
||||||
dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
|
dataset = Processor(dataset, processor.compute_mfcc,
|
||||||
|
**feature_extraction_conf)
|
||||||
|
elif feature_extraction_conf['feature_type'] == 'fbank':
|
||||||
|
dataset = Processor(dataset, processor.compute_fbank,
|
||||||
|
**feature_extraction_conf)
|
||||||
spec_aug = conf.get('spec_aug', True)
|
spec_aug = conf.get('spec_aug', True)
|
||||||
if spec_aug:
|
if spec_aug:
|
||||||
spec_aug_conf = conf.get('spec_aug_conf', {})
|
spec_aug_conf = conf.get('spec_aug_conf', {})
|
||||||
|
|||||||
@ -127,7 +127,47 @@ def speed_perturb(data, speeds=None):
|
|||||||
yield sample
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
def compute_mfcc(
|
||||||
|
data,
|
||||||
|
feature_type='mfcc',
|
||||||
|
num_ceps=80,
|
||||||
|
num_mel_bins=80,
|
||||||
|
frame_length=25,
|
||||||
|
frame_shift=10,
|
||||||
|
dither=0.0,
|
||||||
|
):
|
||||||
|
"""Extract mfcc
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Iterable[{key, wav, label, sample_rate}]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterable[{key, feat, label}]
|
||||||
|
"""
|
||||||
|
for sample in data:
|
||||||
|
assert 'sample_rate' in sample
|
||||||
|
assert 'wav' in sample
|
||||||
|
assert 'key' in sample
|
||||||
|
assert 'label' in sample
|
||||||
|
sample_rate = sample['sample_rate']
|
||||||
|
waveform = sample['wav']
|
||||||
|
waveform = waveform * (1 << 15)
|
||||||
|
# Only keep key, feat, label
|
||||||
|
mat = kaldi.mfcc(
|
||||||
|
waveform,
|
||||||
|
num_ceps=num_ceps,
|
||||||
|
num_mel_bins=num_mel_bins,
|
||||||
|
frame_length=frame_length,
|
||||||
|
frame_shift=frame_shift,
|
||||||
|
dither=dither,
|
||||||
|
energy_floor=0.0,
|
||||||
|
sample_frequency=sample_rate,
|
||||||
|
)
|
||||||
|
yield dict(key=sample['key'], label=sample['label'], feat=mat)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank(data,
|
def compute_fbank(data,
|
||||||
|
feature_type='fbank',
|
||||||
num_mel_bins=23,
|
num_mel_bins=23,
|
||||||
frame_length=25,
|
frame_length=25,
|
||||||
frame_shift=10,
|
frame_shift=10,
|
||||||
|
|||||||
@ -20,45 +20,55 @@ import torch
|
|||||||
from kws.model.cmvn import GlobalCMVN
|
from kws.model.cmvn import GlobalCMVN
|
||||||
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
|
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
|
||||||
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
||||||
|
from kws.model.mdtc import MDTC
|
||||||
from kws.utils.cmvn import load_cmvn
|
from kws.utils.cmvn import load_cmvn
|
||||||
|
|
||||||
|
|
||||||
class KwsModel(torch.nn.Module):
|
class KWSModel(torch.nn.Module):
|
||||||
""" Our model consists of four parts:
|
"""Our model consists of four parts:
|
||||||
1. global_cmvn: Optional, (idim, idim)
|
1. global_cmvn: Optional, (idim, idim)
|
||||||
2. subsampling: subsampling the input, (idim, hdim)
|
2. preprocessing: feature dimention projection, (idim, hdim)
|
||||||
3. body: body of the whole network, (hdim, hdim)
|
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
|
||||||
4. linear: a linear layer, (hdim, odim)
|
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
||||||
"""
|
"""
|
||||||
def __init__(self, idim: int, odim: int, hdim: int,
|
def __init__(
|
||||||
global_cmvn: Optional[torch.nn.Module],
|
self,
|
||||||
subsampling: torch.nn.Module, body: torch.nn.Module):
|
idim: int,
|
||||||
|
odim: int,
|
||||||
|
hdim: int,
|
||||||
|
global_cmvn: Optional[torch.nn.Module],
|
||||||
|
preprocessing: Optional[torch.nn.Module],
|
||||||
|
backbone: torch.nn.Module,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.idim = idim
|
self.idim = idim
|
||||||
self.odim = odim
|
self.odim = odim
|
||||||
self.hdim = hdim
|
self.hdim = hdim
|
||||||
self.global_cmvn = global_cmvn
|
self.global_cmvn = global_cmvn
|
||||||
self.subsampling = subsampling
|
self.preprocessing = preprocessing
|
||||||
self.body = body
|
self.backbone = backbone
|
||||||
self.linear = torch.nn.Linear(hdim, odim)
|
self.classifier = torch.nn.Linear(hdim, odim)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if self.global_cmvn is not None:
|
if self.global_cmvn is not None:
|
||||||
x = self.global_cmvn(x)
|
x = self.global_cmvn(x)
|
||||||
x = self.subsampling(x)
|
if self.preprocessing:
|
||||||
x, _ = self.body(x)
|
x = self.preprocessing(x)
|
||||||
x = self.linear(x)
|
x, _ = self.backbone(x)
|
||||||
|
x = self.classifier(x)
|
||||||
x = torch.sigmoid(x)
|
x = torch.sigmoid(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def init_model(configs):
|
def init_model(configs):
|
||||||
cmvn = configs.get('cmvn', {})
|
cmvn = configs.get('cmvn', {})
|
||||||
if cmvn['cmvn_file'] is not None:
|
if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None:
|
||||||
mean, istd = load_cmvn(cmvn['cmvn_file'])
|
mean, istd = load_cmvn(cmvn['cmvn_file'])
|
||||||
global_cmvn = GlobalCMVN(
|
global_cmvn = GlobalCMVN(
|
||||||
torch.from_numpy(mean).float(),
|
torch.from_numpy(mean).float(),
|
||||||
torch.from_numpy(istd).float(), cmvn['norm_var'])
|
torch.from_numpy(istd).float(),
|
||||||
|
cmvn['norm_var'],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
global_cmvn = None
|
global_cmvn = None
|
||||||
|
|
||||||
@ -66,36 +76,52 @@ def init_model(configs):
|
|||||||
output_dim = configs['output_dim']
|
output_dim = configs['output_dim']
|
||||||
hidden_dim = configs['hidden_dim']
|
hidden_dim = configs['hidden_dim']
|
||||||
|
|
||||||
subsampling_type = configs['subsampling']['type']
|
prep_type = configs['preprocessing']['type']
|
||||||
if subsampling_type == 'linear':
|
if prep_type == 'linear':
|
||||||
subsampling = LinearSubsampling1(input_dim, hidden_dim)
|
preprocessing = LinearSubsampling1(input_dim, hidden_dim)
|
||||||
elif subsampling_type == 'cnn1d_s1':
|
elif prep_type == 'cnn1d_s1':
|
||||||
subsampling = Conv1dSubsampling1(input_dim, hidden_dim)
|
preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
|
||||||
|
elif prep_type == 'none':
|
||||||
|
preprocessing = None
|
||||||
else:
|
else:
|
||||||
print('Unknown subsampling type {}'.format(subsampling_type))
|
print('Unknown preprocessing type {}'.format(prep_type))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
body_type = configs['body']['type']
|
backbone_type = configs['backbone']['type']
|
||||||
num_layers = configs['body']['num_layers']
|
if backbone_type == 'gru':
|
||||||
if body_type == 'gru':
|
num_layers = configs['backbone']['num_layers']
|
||||||
body = torch.nn.GRU(hidden_dim,
|
backbone = torch.nn.GRU(hidden_dim,
|
||||||
hidden_dim,
|
hidden_dim,
|
||||||
num_layers=num_layers,
|
num_layers=num_layers,
|
||||||
batch_first=True)
|
batch_first=True)
|
||||||
elif body_type == 'tcn':
|
elif backbone_type == 'tcn':
|
||||||
# Depthwise Separable
|
# Depthwise Separable
|
||||||
ds = configs['body'].get('ds', False)
|
num_layers = configs['backbone']['num_layers']
|
||||||
|
ds = configs['backbone'].get('ds', False)
|
||||||
if ds:
|
if ds:
|
||||||
block_class = DsCnnBlock
|
block_class = DsCnnBlock
|
||||||
else:
|
else:
|
||||||
block_class = CnnBlock
|
block_class = CnnBlock
|
||||||
kernel_size = configs['body'].get('kernel_size', 8)
|
kernel_size = configs['backbone'].get('kernel_size', 8)
|
||||||
dropout = configs['body'].get('drouput', 0.1)
|
dropout = configs['backbone'].get('drouput', 0.1)
|
||||||
body = TCN(num_layers, hidden_dim, kernel_size, dropout, block_class)
|
backbone = TCN(num_layers, hidden_dim, kernel_size, dropout,
|
||||||
|
block_class)
|
||||||
|
elif backbone_type == 'mdtc':
|
||||||
|
stack_size = configs['backbone']['stack_size']
|
||||||
|
num_stack = configs['backbone']['num_stack']
|
||||||
|
kernel_size = configs['backbone']['kernel_size']
|
||||||
|
hidden_dim = configs['backbone']['hidden_dim']
|
||||||
|
|
||||||
|
backbone = MDTC(num_stack,
|
||||||
|
stack_size,
|
||||||
|
input_dim,
|
||||||
|
hidden_dim,
|
||||||
|
kernel_size,
|
||||||
|
causal=True)
|
||||||
else:
|
else:
|
||||||
print('Unknown body type {}'.format(body_type))
|
print('Unknown body type {}'.format(backbone_type))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
kws_model = KwsModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||||
subsampling, body)
|
preprocessing, backbone)
|
||||||
return kws_model
|
return kws_model
|
||||||
|
|||||||
@ -13,7 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from kws.utils.mask import padding_mask
|
from kws.utils.mask import padding_mask
|
||||||
|
|
||||||
|
|||||||
267
kws/model/mdtc.py
Normal file
267
kws/model/mdtc.py
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class DSDilatedConv1d(nn.Module):
|
||||||
|
"""Dilated Depthwise-Separable Convolution"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
dilation: int = 1,
|
||||||
|
stride: int = 1,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
super(DSDilatedConv1d, self).__init__()
|
||||||
|
self.receptive_fields = dilation * (kernel_size - 1)
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size,
|
||||||
|
padding=0,
|
||||||
|
dilation=dilation,
|
||||||
|
stride=stride,
|
||||||
|
groups=in_channels,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self.bn = nn.BatchNorm1d(in_channels)
|
||||||
|
self.pointwise = nn.Conv1d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
bias=bias)
|
||||||
|
|
||||||
|
def forward(self, inputs: torch.Tensor):
|
||||||
|
outputs = self.conv(inputs)
|
||||||
|
outputs = self.bn(outputs)
|
||||||
|
outputs = self.pointwise(outputs)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class TCNBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
res_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
dilation: int,
|
||||||
|
causal: bool,
|
||||||
|
):
|
||||||
|
super(TCNBlock, self).__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.res_channels = res_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.dilation = dilation
|
||||||
|
self.causal = causal
|
||||||
|
self.receptive_fields = dilation * (kernel_size - 1)
|
||||||
|
self.half_receptive_fields = self.receptive_fields // 2
|
||||||
|
self.conv1 = DSDilatedConv1d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=res_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
|
self.bn1 = nn.BatchNorm1d(res_channels)
|
||||||
|
self.relu1 = nn.ReLU()
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv1d(in_channels=res_channels,
|
||||||
|
out_channels=res_channels,
|
||||||
|
kernel_size=1)
|
||||||
|
self.bn2 = nn.BatchNorm1d(res_channels)
|
||||||
|
self.relu2 = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, inputs: torch.Tensor):
|
||||||
|
outputs = self.relu1(self.bn1(self.conv1(inputs)))
|
||||||
|
outputs = self.bn2(self.conv2(outputs))
|
||||||
|
if self.causal:
|
||||||
|
inputs = inputs[:, :, self.receptive_fields:]
|
||||||
|
else:
|
||||||
|
inputs = inputs[:, :, self.
|
||||||
|
half_receptive_fields:-self.half_receptive_fields]
|
||||||
|
if self.in_channels == self.res_channels:
|
||||||
|
res_out = self.relu2(outputs + inputs)
|
||||||
|
else:
|
||||||
|
res_out = self.relu2(outputs)
|
||||||
|
return res_out
|
||||||
|
|
||||||
|
|
||||||
|
class TCNStack(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
stack_num: int,
|
||||||
|
stack_size: int,
|
||||||
|
res_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
causal: bool,
|
||||||
|
):
|
||||||
|
super(TCNStack, self).__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.stack_num = stack_num
|
||||||
|
self.stack_size = stack_size
|
||||||
|
self.res_channels = res_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.causal = causal
|
||||||
|
self.res_blocks = self.stack_tcn_blocks()
|
||||||
|
self.receptive_fields = self.calculate_receptive_fields()
|
||||||
|
self.res_blocks = nn.Sequential(*self.res_blocks)
|
||||||
|
|
||||||
|
def calculate_receptive_fields(self):
|
||||||
|
receptive_fields = 0
|
||||||
|
for block in self.res_blocks:
|
||||||
|
receptive_fields += block.receptive_fields
|
||||||
|
return receptive_fields
|
||||||
|
|
||||||
|
def build_dilations(self):
|
||||||
|
dilations = []
|
||||||
|
for s in range(0, self.stack_size):
|
||||||
|
for l in range(0, self.stack_num):
|
||||||
|
dilations.append(2**l)
|
||||||
|
return dilations
|
||||||
|
|
||||||
|
def stack_tcn_blocks(self):
|
||||||
|
dilations = self.build_dilations()
|
||||||
|
res_blocks = nn.ModuleList()
|
||||||
|
|
||||||
|
res_blocks.append(
|
||||||
|
TCNBlock(
|
||||||
|
self.in_channels,
|
||||||
|
self.res_channels,
|
||||||
|
self.kernel_size,
|
||||||
|
dilations[0],
|
||||||
|
self.causal,
|
||||||
|
))
|
||||||
|
for dilation in dilations[1:]:
|
||||||
|
res_blocks.append(
|
||||||
|
TCNBlock(
|
||||||
|
self.res_channels,
|
||||||
|
self.res_channels,
|
||||||
|
self.kernel_size,
|
||||||
|
dilation,
|
||||||
|
self.causal,
|
||||||
|
))
|
||||||
|
return res_blocks
|
||||||
|
|
||||||
|
def forward(self, inputs: torch.Tensor):
|
||||||
|
outputs = inputs
|
||||||
|
outputs = self.res_blocks(outputs)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class MDTC(nn.Module):
|
||||||
|
"""Multi-scale Depthwise Temporal Convolution (MDTC).
|
||||||
|
In MDTC, stacked depthwise one-dimensional (1-D) convolution with
|
||||||
|
dilated connections is adopted to efficiently model long-range
|
||||||
|
dependency of speech. With a large receptive field while
|
||||||
|
keeping a small number of model parameters, the structure
|
||||||
|
can model temporal context of speech effectively. It aslo
|
||||||
|
extracts multi-scale features from different hidden layers
|
||||||
|
of MDTC with different receptive fields.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stack_num: int,
|
||||||
|
stack_size: int,
|
||||||
|
in_channels: int,
|
||||||
|
res_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
causal: bool,
|
||||||
|
):
|
||||||
|
super(MDTC, self).__init__()
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.causal = causal
|
||||||
|
self.preprocessor = TCNBlock(in_channels,
|
||||||
|
res_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation=1,
|
||||||
|
causal=causal)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.blocks = nn.ModuleList()
|
||||||
|
self.receptive_fields = self.preprocessor.receptive_fields
|
||||||
|
for i in range(stack_num):
|
||||||
|
self.blocks.append(
|
||||||
|
TCNStack(res_channels, stack_size, 1, res_channels,
|
||||||
|
kernel_size, causal))
|
||||||
|
self.receptive_fields += self.blocks[-1].receptive_fields
|
||||||
|
self.half_receptive_fields = self.receptive_fields // 2
|
||||||
|
print('Receptive Fields: %d' % self.receptive_fields)
|
||||||
|
|
||||||
|
def normalize_length_causal(self, skip_connections: list):
|
||||||
|
output_size = skip_connections[-1].shape[-1]
|
||||||
|
normalized_outputs = []
|
||||||
|
for x in skip_connections:
|
||||||
|
remove_length = x.shape[-1] - output_size
|
||||||
|
if remove_length != 0:
|
||||||
|
normalized_outputs.append(x[:, :, remove_length:])
|
||||||
|
else:
|
||||||
|
normalized_outputs.append(x)
|
||||||
|
return normalized_outputs
|
||||||
|
|
||||||
|
def normalize_length(self, skip_connections: list):
|
||||||
|
output_size = skip_connections[-1].shape[-1]
|
||||||
|
normalized_outputs = []
|
||||||
|
for x in skip_connections:
|
||||||
|
remove_length = (x.shape[-1] - output_size) // 2
|
||||||
|
if remove_length != 0:
|
||||||
|
normalized_outputs.append(x[:, :,
|
||||||
|
remove_length:-remove_length])
|
||||||
|
else:
|
||||||
|
normalized_outputs.append(x)
|
||||||
|
return normalized_outputs
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
if self.causal:
|
||||||
|
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
|
||||||
|
'constant')
|
||||||
|
else:
|
||||||
|
outputs = F.pad(
|
||||||
|
x,
|
||||||
|
(0, 0, self.half_receptive_fields, self.half_receptive_fields,
|
||||||
|
0, 0),
|
||||||
|
'constant',
|
||||||
|
)
|
||||||
|
outputs = outputs.transpose(1, 2)
|
||||||
|
outputs_list = []
|
||||||
|
outputs = self.relu(self.preprocessor(outputs))
|
||||||
|
for i in range(len(self.blocks)):
|
||||||
|
outputs = self.blocks[i](outputs)
|
||||||
|
outputs_list.append(outputs)
|
||||||
|
|
||||||
|
if self.causal:
|
||||||
|
outputs_list = self.normalize_length_causal(outputs_list)
|
||||||
|
else:
|
||||||
|
outputs_list = self.normalize_length(outputs_list)
|
||||||
|
|
||||||
|
outputs = sum(outputs_list)
|
||||||
|
outputs = outputs.transpose(1, 2)
|
||||||
|
return outputs, None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
mdtc = MDTC(3, 4, 80, 64, 5, causal=True)
|
||||||
|
print(mdtc)
|
||||||
|
|
||||||
|
num_params = sum(p.numel() for p in mdtc.parameters())
|
||||||
|
print('the number of model params: {}'.format(num_params))
|
||||||
|
x = torch.zeros(128, 200, 80) # batch-size * time * dim
|
||||||
|
y, _ = mdtc(x) # batch-size * time * dim
|
||||||
|
print('input shape: {}'.format(x.shape))
|
||||||
|
print('output shape: {}'.format(y.shape))
|
||||||
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|||||||
@ -18,9 +18,10 @@ torchaudio.set_audio_backend("sox_io")
|
|||||||
class CollateFunc(object):
|
class CollateFunc(object):
|
||||||
''' Collate function for AudioDataset
|
''' Collate function for AudioDataset
|
||||||
'''
|
'''
|
||||||
def __init__(self, feat_dim, resample_rate):
|
def __init__(self, feat_dim, feat_type, resample_rate):
|
||||||
self.feat_dim = feat_dim
|
self.feat_dim = feat_dim
|
||||||
self.resample_rate = resample_rate
|
self.resample_rate = resample_rate
|
||||||
|
self.feat_type = feat_type
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
@ -31,7 +32,8 @@ class CollateFunc(object):
|
|||||||
value = item[1].strip().split(",")
|
value = item[1].strip().split(",")
|
||||||
assert len(value) == 3 or len(value) == 1
|
assert len(value) == 3 or len(value) == 1
|
||||||
wav_path = value[0]
|
wav_path = value[0]
|
||||||
sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate
|
sample_rate = torchaudio.backend.sox_io_backend.info(
|
||||||
|
wav_path).sample_rate
|
||||||
resample_rate = sample_rate
|
resample_rate = sample_rate
|
||||||
# len(value) == 3 means segmented wav.scp,
|
# len(value) == 3 means segmented wav.scp,
|
||||||
# len(value) == 1 means original wav.scp
|
# len(value) == 1 means original wav.scp
|
||||||
@ -50,12 +52,21 @@ class CollateFunc(object):
|
|||||||
resample_rate = self.resample_rate
|
resample_rate = self.resample_rate
|
||||||
waveform = torchaudio.transforms.Resample(
|
waveform = torchaudio.transforms.Resample(
|
||||||
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
||||||
|
if self.feat_type == 'fbank':
|
||||||
mat = kaldi.fbank(waveform,
|
mat = kaldi.fbank(waveform,
|
||||||
num_mel_bins=self.feat_dim,
|
num_mel_bins=self.feat_dim,
|
||||||
dither=0.0,
|
dither=0.0,
|
||||||
energy_floor=0.0,
|
energy_floor=0.0,
|
||||||
sample_frequency=resample_rate)
|
sample_frequency=resample_rate)
|
||||||
|
elif self.feat_type == 'mfcc':
|
||||||
|
mat = kaldi.mfcc(
|
||||||
|
waveform,
|
||||||
|
num_ceps=self.feat_dim,
|
||||||
|
num_mel_bins=self.feat_dim,
|
||||||
|
dither=0.0,
|
||||||
|
energy_floor=0.0,
|
||||||
|
sample_frequency=resample_rate,
|
||||||
|
)
|
||||||
mean_stat += torch.sum(mat, axis=0)
|
mean_stat += torch.sum(mat, axis=0)
|
||||||
var_stat += torch.sum(torch.square(mat), axis=0)
|
var_stat += torch.sum(torch.square(mat), axis=0)
|
||||||
number += mat.shape[0]
|
number += mat.shape[0]
|
||||||
@ -95,13 +106,17 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
with open(args.train_config, 'r') as fin:
|
with open(args.train_config, 'r') as fin:
|
||||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||||
feat_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
|
feat_dim = configs['dataset_conf']['feature_extraction_conf'][
|
||||||
|
'num_mel_bins']
|
||||||
|
feat_type = configs['dataset_conf']['feature_extraction_conf'][
|
||||||
|
'feature_type']
|
||||||
resample_rate = 0
|
resample_rate = 0
|
||||||
if 'resample_conf' in configs['dataset_conf']:
|
if 'resample_conf' in configs['dataset_conf']:
|
||||||
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
|
resample_rate = configs['dataset_conf']['resample_conf'][
|
||||||
|
'resample_rate']
|
||||||
print('using resample and new sample rate is {}'.format(resample_rate))
|
print('using resample and new sample rate is {}'.format(resample_rate))
|
||||||
|
|
||||||
collate_func = CollateFunc(feat_dim, resample_rate)
|
collate_func = CollateFunc(feat_dim, feat_type, resample_rate)
|
||||||
dataset = AudioDataset(args.in_scp)
|
dataset = AudioDataset(args.in_scp)
|
||||||
batch_size = 20
|
batch_size = 20
|
||||||
data_loader = DataLoader(dataset,
|
data_loader = DataLoader(dataset,
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
torchaudio.set_audio_backend("sox_io")
|
torchaudio.set_audio_backend("sox_io")
|
||||||
|
|
||||||
scp = sys.argv[1]
|
scp = sys.argv[1]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user