diff --git a/examples/hi_xiaowen/s0/conf/ds_tcn.yaml b/examples/hi_xiaowen/s0/conf/ds_tcn.yaml index 7da0181..88b0c78 100644 --- a/examples/hi_xiaowen/s0/conf/ds_tcn.yaml +++ b/examples/hi_xiaowen/s0/conf/ds_tcn.yaml @@ -5,11 +5,12 @@ dataset_conf: resample_conf: resample_rate: 16000 speed_perturb: false - fbank_conf: + feature_extraction_conf: + feature_type: 'fbank' num_mel_bins: 40 frame_shift: 10 frame_length: 25 - dither: 0.1 + dither: 1.0 spec_aug: true spec_aug_conf: num_t_mask: 1 @@ -24,9 +25,9 @@ dataset_conf: model: hidden_dim: 64 - subsampling: + preprocessing: type: linear - body: + backbone: type: tcn ds: true num_layers: 4 diff --git a/examples/hi_xiaowen/s0/conf/gru.yaml b/examples/hi_xiaowen/s0/conf/gru.yaml index e664319..8e63900 100644 --- a/examples/hi_xiaowen/s0/conf/gru.yaml +++ b/examples/hi_xiaowen/s0/conf/gru.yaml @@ -5,11 +5,12 @@ dataset_conf: resample_conf: resample_rate: 16000 speed_perturb: false - fbank_conf: + feature_extraction_conf: + feature_type: 'fbank' num_mel_bins: 40 frame_shift: 10 frame_length: 25 - dither: 0.1 + dither: 1.0 spec_aug: false spec_aug_conf: num_t_mask: 2 @@ -24,9 +25,9 @@ dataset_conf: model: hidden_dim: 128 - subsampling: + preprocessing: type: linear - body: + backbone: type: gru num_layers: 2 diff --git a/examples/hi_xiaowen/s0/conf/mdtc.yaml b/examples/hi_xiaowen/s0/conf/mdtc.yaml new file mode 100644 index 0000000..4ec5f7e --- /dev/null +++ b/examples/hi_xiaowen/s0/conf/mdtc.yaml @@ -0,0 +1,46 @@ +dataset_conf: + filter_conf: + max_length: 2048 + min_length: 0 + resample_conf: + resample_rate: 16000 + speed_perturb: false + feature_extraction_conf: + feature_type: 'mfcc' + num_ceps: 80 + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + feature_dither: 0.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 1 + num_f_mask: 1 + max_t: 20 + max_f: 40 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + batch_conf: + batch_size: 100 + +model: + hidden_dim: 64 + preprocessing: + type: none + backbone: + type: mdtc + num_stack: 4 + stack_size: 4 + kernel_size: 5 + hidden_dim: 64 + +optim: adam +optim_conf: + lr: 0.001 + +training_config: + grad_clip: 5 + max_epoch: 100 + log_interval: 10 diff --git a/examples/hi_xiaowen/s0/conf/tcn.yaml b/examples/hi_xiaowen/s0/conf/tcn.yaml index 0612634..517b94e 100644 --- a/examples/hi_xiaowen/s0/conf/tcn.yaml +++ b/examples/hi_xiaowen/s0/conf/tcn.yaml @@ -5,11 +5,12 @@ dataset_conf: resample_conf: resample_rate: 16000 speed_perturb: false - fbank_conf: + feature_extraction_conf: + feature_type: 'fbank' num_mel_bins: 40 frame_shift: 10 frame_length: 25 - dither: 0.1 + dither: 1.0 spec_aug: false spec_aug_conf: num_t_mask: 2 @@ -24,9 +25,9 @@ dataset_conf: model: hidden_dim: 64 - subsampling: + preprocessing: type: linear - body: + backbone: type: tcn ds: false num_layers: 4 diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 650532f..27987f8 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -3,24 +3,24 @@ . ./path.sh -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +export CUDA_VISIBLE_DEVICES="0" stage=0 stop_stage=4 num_keywords=2 -config=conf/ds_tcn.yaml -norm_mean=true -norm_var=true +config=conf/mdtc.yaml +norm_mean=false +norm_var=false gpu_id=0 checkpoint= -dir=exp/ds_tcn +dir=exp/mdtc -num_average=30 +num_average=10 score_checkpoint=$dir/avg_${num_average}.pt -download_dir=/export/expts6/binbinzhang/data/ +download_dir=./data/local # your data dir . 
tools/parse_options.sh || exit 1;
@@ -34,19 +34,16 @@ fi
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
   echo "Preparing datasets..."
-  mkdir dict
+  mkdir -p dict
   echo "<filler> -1" > dict/words.txt
   echo "Hi_Xiaowen 0" >> dict/words.txt
   echo "Nihao_Wenwen 1" >> dict/words.txt
-  for folder in train dev eval; do
+  for folder in train dev test; do
     mkdir -p data/$folder
     for prefix in p n; do
       mkdir -p data/${prefix}_$folder
       json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
-      if [ $folder = "eval" ]; then
-        json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_test.json
-      fi
       local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
         data/${prefix}_$folder
     done
@@ -63,7 +60,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     --in_scp data/train/wav.scp \
     --out_cmvn data/train/global_cmvn
 
-  for x in train dev eval; do
+  for x in train dev test; do
    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
    tools/make_list.py data/$x/wav.scp data/$x/text \
      data/$x/wav.dur data/$x/data.list
@@ -100,27 +97,31 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   # Compute posterior score
   result_dir=$dir/test_$(basename $score_checkpoint)
   mkdir -p $result_dir
-  python kws/bin/score.py --gpu -1 \
+  python kws/bin/score.py --gpu 1 \
     --config $dir/config.yaml \
-    --test_data data/eval/data.list \
+    --test_data data/test/data.list \
     --batch_size 256 \
     --checkpoint $score_checkpoint \
     --score_file $result_dir/score.txt
+fi
 
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
   # Compute detection error tradeoff
+  result_dir=$dir/test_$(basename $score_checkpoint)
   for keyword in 0 1; do
     python kws/bin/compute_det.py \
       --keyword $keyword \
-      --test_data data/eval/data.list \
+      --test_data data/test/data.list \
       --score_file $result_dir/score.txt \
       --stats_file $result_dir/stats.${keyword}.txt
   done
 fi
 
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   python kws/bin/export_jit.py --config $dir/config.yaml \
     --checkpoint $score_checkpoint \
     --output_file $dir/final.zip \
     --output_quant_file $dir/final.quant.zip
 fi
+
diff --git a/kws/bin/export_jit.py b/kws/bin/export_jit.py
index 7c1846e..81f48f9 100644
--- a/kws/bin/export_jit.py
+++ b/kws/bin/export_jit.py
@@ -56,8 +56,7 @@ def main():
     # Export quantized jit torch script model
     if args.output_quant_file:
         quantized_model = torch.quantization.quantize_dynamic(
-            model, {torch.nn.Linear}, dtype=torch.qint8
-        )
+            model, {torch.nn.Linear}, dtype=torch.qint8)
         print(quantized_model)
         script_quant_model = torch.jit.script(quantized_model)
         script_quant_model.save(args.output_quant_file)
diff --git a/kws/bin/train.py b/kws/bin/train.py
index 150d333..915f176 100644
--- a/kws/bin/train.py
+++ b/kws/bin/train.py
@@ -135,7 +135,8 @@ def main():
         num_workers=args.num_workers,
         prefetch_factor=args.prefetch)
 
-    input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
+    input_dim = configs['dataset_conf']['feature_extraction_conf'][
+        'num_mel_bins']
     output_dim = args.num_keywords
 
     # Write model_dir/config.yaml for inference and export
@@ -160,9 +161,9 @@ def main():
     # !!!IMPORTANT!!!
    # Try to export the model by script, if fails, we should refine
     # the code to satisfy the script export requirements
-    if args.rank == 0:
-        script_model = torch.jit.script(model)
-        script_model.save(os.path.join(args.model_dir, 'init.zip'))
+    # if args.rank == 0:
+    #     script_model = torch.jit.script(model)
+    #     script_model.save(os.path.join(args.model_dir, 'init.zip'))
     executor = Executor()
     # If specify checkpoint, load some info from checkpoint
     if args.checkpoint is not None:
diff --git a/kws/dataset/dataset.py b/kws/dataset/dataset.py
index ffff668..f301b75 100644
--- a/kws/dataset/dataset.py
+++ b/kws/dataset/dataset.py
@@ -136,10 +136,13 @@ def Dataset(data_list_file, conf, partition=True):
     speed_perturb = conf.get('speed_perturb', False)
     if speed_perturb:
         dataset = Processor(dataset, processor.speed_perturb)
-
-    fbank_conf = conf.get('fbank_conf', {})
-    dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
-
+    feature_extraction_conf = conf.get('feature_extraction_conf', {})
+    if feature_extraction_conf['feature_type'] == 'mfcc':
+        dataset = Processor(dataset, processor.compute_mfcc,
+                            **feature_extraction_conf)
+    elif feature_extraction_conf['feature_type'] == 'fbank':
+        dataset = Processor(dataset, processor.compute_fbank,
+                            **feature_extraction_conf)
     spec_aug = conf.get('spec_aug', True)
     if spec_aug:
         spec_aug_conf = conf.get('spec_aug_conf', {})
diff --git a/kws/dataset/processor.py b/kws/dataset/processor.py
index 65aba99..0fd1d84 100644
--- a/kws/dataset/processor.py
+++ b/kws/dataset/processor.py
@@ -127,7 +127,47 @@ def speed_perturb(data, speeds=None):
         yield sample
 
 
+def compute_mfcc(
+    data,
+    feature_type='mfcc',
+    num_ceps=80,
+    num_mel_bins=80,
+    frame_length=25,
+    frame_shift=10,
+    dither=0.0,
+):
+    """Extract mfcc
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+
+    Returns:
+        Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'wav' in sample
+        assert 'key' in sample
+        assert 'label' in sample
+        sample_rate = sample['sample_rate']
+        waveform = sample['wav']
+        waveform = waveform * (1 << 15)
+        # Only keep key, feat, label
+        mat = kaldi.mfcc(
+            waveform,
+            num_ceps=num_ceps,
+            num_mel_bins=num_mel_bins,
+            frame_length=frame_length,
+            frame_shift=frame_shift,
+            dither=dither,
+            energy_floor=0.0,
+            sample_frequency=sample_rate,
+        )
+        yield dict(key=sample['key'], label=sample['label'], feat=mat)
+
+
 def compute_fbank(data,
+                  feature_type='fbank',
                   num_mel_bins=23,
                   frame_length=25,
                   frame_shift=10,
diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py
index 314822c..ab543dd 100644
--- a/kws/model/kws_model.py
+++ b/kws/model/kws_model.py
@@ -20,45 +20,55 @@ import torch
 from kws.model.cmvn import GlobalCMVN
 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
 from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
+from kws.model.mdtc import MDTC
 from kws.utils.cmvn import load_cmvn
 
 
-class KwsModel(torch.nn.Module):
-    """ Our model consists of four parts:
+class KWSModel(torch.nn.Module):
+    """Our model consists of four parts:
     1. global_cmvn: Optional, (idim, idim)
-    2. subsampling: subsampling the input, (idim, hdim)
-    3. body: body of the whole network, (hdim, hdim)
-    4. linear: a linear layer, (hdim, odim)
+    2. preprocessing: feature dimension projection, (idim, hdim)
+    3. backbone: feature extractor of the whole network, (hdim, hdim)
+    4. classifier: output layer of the KWS model, (hdim, odim)
     """
-    def __init__(self, idim: int, odim: int, hdim: int,
-                 global_cmvn: Optional[torch.nn.Module],
-                 subsampling: torch.nn.Module, body: torch.nn.Module):
+    def __init__(
+        self,
+        idim: int,
+        odim: int,
+        hdim: int,
+        global_cmvn: Optional[torch.nn.Module],
+        preprocessing: Optional[torch.nn.Module],
+        backbone: torch.nn.Module,
+    ):
         super().__init__()
         self.idim = idim
         self.odim = odim
         self.hdim = hdim
         self.global_cmvn = global_cmvn
-        self.subsampling = subsampling
-        self.body = body
-        self.linear = torch.nn.Linear(hdim, odim)
+        self.preprocessing = preprocessing
+        self.backbone = backbone
+        self.classifier = torch.nn.Linear(hdim, odim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.global_cmvn is not None:
             x = self.global_cmvn(x)
-        x = self.subsampling(x)
-        x, _ = self.body(x)
-        x = self.linear(x)
+        if self.preprocessing is not None:
+            x = self.preprocessing(x)
+        x, _ = self.backbone(x)
+        x = self.classifier(x)
         x = torch.sigmoid(x)
         return x
 
 
 def init_model(configs):
     cmvn = configs.get('cmvn', {})
-    if cmvn['cmvn_file'] is not None:
+    if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None:
         mean, istd = load_cmvn(cmvn['cmvn_file'])
         global_cmvn = GlobalCMVN(
             torch.from_numpy(mean).float(),
-            torch.from_numpy(istd).float(), cmvn['norm_var'])
+            torch.from_numpy(istd).float(),
+            cmvn['norm_var'],
+        )
     else:
         global_cmvn = None
 
@@ -66,36 +76,52 @@ def init_model(configs):
     output_dim = configs['output_dim']
     hidden_dim = configs['hidden_dim']
 
-    subsampling_type = configs['subsampling']['type']
-    if subsampling_type == 'linear':
-        subsampling = LinearSubsampling1(input_dim, hidden_dim)
-    elif subsampling_type == 'cnn1d_s1':
-        subsampling = Conv1dSubsampling1(input_dim, hidden_dim)
+    prep_type = configs['preprocessing']['type']
+    if prep_type == 'linear':
+        preprocessing = LinearSubsampling1(input_dim, hidden_dim)
+    elif prep_type == 'cnn1d_s1':
+        preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
+    elif prep_type == 'none':
+        preprocessing = None
     else:
-        print('Unknown subsampling type {}'.format(subsampling_type))
+        print('Unknown preprocessing type {}'.format(prep_type))
         sys.exit(1)
 
-    body_type = configs['body']['type']
-    num_layers = configs['body']['num_layers']
-    if body_type == 'gru':
-        body = torch.nn.GRU(hidden_dim,
-                            hidden_dim,
-                            num_layers=num_layers,
-                            batch_first=True)
-    elif body_type == 'tcn':
+    backbone_type = configs['backbone']['type']
+    if backbone_type == 'gru':
+        num_layers = configs['backbone']['num_layers']
+        backbone = torch.nn.GRU(hidden_dim,
+                                hidden_dim,
+                                num_layers=num_layers,
+                                batch_first=True)
+    elif backbone_type == 'tcn':
         # Depthwise Separable
-        ds = configs['body'].get('ds', False)
+        num_layers = configs['backbone']['num_layers']
+        ds = configs['backbone'].get('ds', False)
         if ds:
             block_class = DsCnnBlock
         else:
             block_class = CnnBlock
-        kernel_size = configs['body'].get('kernel_size', 8)
-        dropout = configs['body'].get('drouput', 0.1)
-        body = TCN(num_layers, hidden_dim, kernel_size, dropout, block_class)
+        kernel_size = configs['backbone'].get('kernel_size', 8)
+        dropout = configs['backbone'].get('dropout', 0.1)
+        backbone = TCN(num_layers, hidden_dim, kernel_size, dropout,
+                       block_class)
+    elif backbone_type == 'mdtc':
+        stack_size = configs['backbone']['stack_size']
+        num_stack = configs['backbone']['num_stack']
+        kernel_size = configs['backbone']['kernel_size']
+        hidden_dim = configs['backbone']['hidden_dim']
+
+        backbone = MDTC(num_stack,
+                        stack_size,
+                        input_dim,
+                        hidden_dim,
+                        kernel_size,
+                        causal=True)
     else:
-        print('Unknown body type {}'.format(body_type))
+        print('Unknown backbone type {}'.format(backbone_type))
         sys.exit(1)
 
-    kws_model = KwsModel(input_dim, output_dim, hidden_dim, global_cmvn,
-                         subsampling, body)
+    kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
+                         preprocessing, backbone)
     return kws_model
diff --git a/kws/model/loss.py b/kws/model/loss.py
index e871ef1..bd34951 100644
--- a/kws/model/loss.py
+++ b/kws/model/loss.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import torch
-import torch.nn.functional as F
 
 from kws.utils.mask import padding_mask
 
diff --git a/kws/model/mdtc.py b/kws/model/mdtc.py
new file mode 100644
index 0000000..a5ab929
--- /dev/null
+++ b/kws/model/mdtc.py
@@ -0,0 +1,267 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DSDilatedConv1d(nn.Module):
+    """Dilated Depthwise-Separable Convolution"""
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        dilation: int = 1,
+        stride: int = 1,
+        bias: bool = True,
+    ):
+        super(DSDilatedConv1d, self).__init__()
+        self.receptive_fields = dilation * (kernel_size - 1)
+        self.conv = nn.Conv1d(
+            in_channels,
+            in_channels,
+            kernel_size,
+            padding=0,
+            dilation=dilation,
+            stride=stride,
+            groups=in_channels,
+            bias=bias,
+        )
+        self.bn = nn.BatchNorm1d(in_channels)
+        self.pointwise = nn.Conv1d(in_channels,
+                                   out_channels,
+                                   kernel_size=1,
+                                   padding=0,
+                                   dilation=1,
+                                   bias=bias)
+
+    def forward(self, inputs: torch.Tensor):
+        outputs = self.conv(inputs)
+        outputs = self.bn(outputs)
+        outputs = self.pointwise(outputs)
+        return outputs
+
+
+class TCNBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        res_channels: int,
+        kernel_size: int,
+        dilation: int,
+        causal: bool,
+    ):
+        super(TCNBlock, self).__init__()
+        self.in_channels = in_channels
+        self.res_channels = res_channels
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.causal = causal
+        self.receptive_fields = dilation * (kernel_size - 1)
+        self.half_receptive_fields = self.receptive_fields // 2
+        self.conv1 = DSDilatedConv1d(
+            in_channels=in_channels,
+            out_channels=res_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+        )
+        self.bn1 = nn.BatchNorm1d(res_channels)
+        self.relu1 = nn.ReLU()
+
+        self.conv2 = nn.Conv1d(in_channels=res_channels,
+                               out_channels=res_channels,
+                               kernel_size=1)
+        self.bn2 = nn.BatchNorm1d(res_channels)
+        self.relu2 = nn.ReLU()
+
+    def forward(self, inputs: torch.Tensor):
+        outputs = self.relu1(self.bn1(self.conv1(inputs)))
+        outputs = self.bn2(self.conv2(outputs))
+        if self.causal:
+            inputs = inputs[:, :, self.receptive_fields:]
+        else:
+            inputs = inputs[:, :, self.half_receptive_fields:
+                            -self.half_receptive_fields]
+        if self.in_channels == self.res_channels:
+            res_out = self.relu2(outputs + inputs)
+        else:
+            res_out = self.relu2(outputs)
+        return res_out
+
+
+class TCNStack(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        stack_num: int,
+        stack_size: int,
+        res_channels: int,
+        kernel_size: int,
+        causal: bool,
+    ):
+        super(TCNStack, self).__init__()
+        self.in_channels = in_channels
+        self.stack_num = stack_num
+        self.stack_size = stack_size
+        self.res_channels = res_channels
+        self.kernel_size = kernel_size
+        self.causal = causal
+        self.res_blocks = self.stack_tcn_blocks()
+        self.receptive_fields = self.calculate_receptive_fields()
+        self.res_blocks = nn.Sequential(*self.res_blocks)
+
+    def calculate_receptive_fields(self):
+        receptive_fields = 0
+        for block in self.res_blocks:
+            receptive_fields += block.receptive_fields
+        return receptive_fields
+
+    def build_dilations(self):
+        dilations = []
+        for s in range(0, self.stack_size):
+            for l in range(0, self.stack_num):
+                dilations.append(2**l)
+        return dilations
+
+    def stack_tcn_blocks(self):
+        dilations = self.build_dilations()
+        res_blocks = nn.ModuleList()
+
+        res_blocks.append(
+            TCNBlock(
+                self.in_channels,
+                self.res_channels,
+                self.kernel_size,
+                dilations[0],
+                self.causal,
+            ))
+        for dilation in dilations[1:]:
+            res_blocks.append(
+                TCNBlock(
+                    self.res_channels,
+                    self.res_channels,
+                    self.kernel_size,
+                    dilation,
+                    self.causal,
+                ))
+        return res_blocks
+
+    def forward(self, inputs: torch.Tensor):
+        outputs = inputs
+        outputs = self.res_blocks(outputs)
+        return outputs
+
+
+class MDTC(nn.Module):
+    """Multi-scale Depthwise Temporal Convolution (MDTC).
+    In MDTC, stacked depthwise one-dimensional (1-D) convolution with
+    dilated connections is adopted to efficiently model long-range
+    dependency of speech. With a large receptive field while
+    keeping a small number of model parameters, the structure
+    can model temporal context of speech effectively. It also
+    extracts multi-scale features from different hidden layers
+    of MDTC with different receptive fields.
+ """ + def __init__( + self, + stack_num: int, + stack_size: int, + in_channels: int, + res_channels: int, + kernel_size: int, + causal: bool, + ): + super(MDTC, self).__init__() + self.kernel_size = kernel_size + self.causal = causal + self.preprocessor = TCNBlock(in_channels, + res_channels, + kernel_size, + dilation=1, + causal=causal) + self.relu = nn.ReLU() + self.blocks = nn.ModuleList() + self.receptive_fields = self.preprocessor.receptive_fields + for i in range(stack_num): + self.blocks.append( + TCNStack(res_channels, stack_size, 1, res_channels, + kernel_size, causal)) + self.receptive_fields += self.blocks[-1].receptive_fields + self.half_receptive_fields = self.receptive_fields // 2 + print('Receptive Fields: %d' % self.receptive_fields) + + def normalize_length_causal(self, skip_connections: list): + output_size = skip_connections[-1].shape[-1] + normalized_outputs = [] + for x in skip_connections: + remove_length = x.shape[-1] - output_size + if remove_length != 0: + normalized_outputs.append(x[:, :, remove_length:]) + else: + normalized_outputs.append(x) + return normalized_outputs + + def normalize_length(self, skip_connections: list): + output_size = skip_connections[-1].shape[-1] + normalized_outputs = [] + for x in skip_connections: + remove_length = (x.shape[-1] - output_size) // 2 + if remove_length != 0: + normalized_outputs.append(x[:, :, + remove_length:-remove_length]) + else: + normalized_outputs.append(x) + return normalized_outputs + + def forward(self, x: torch.Tensor): + if self.causal: + outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0), + 'constant') + else: + outputs = F.pad( + x, + (0, 0, self.half_receptive_fields, self.half_receptive_fields, + 0, 0), + 'constant', + ) + outputs = outputs.transpose(1, 2) + outputs_list = [] + outputs = self.relu(self.preprocessor(outputs)) + for i in range(len(self.blocks)): + outputs = self.blocks[i](outputs) + outputs_list.append(outputs) + + if self.causal: + outputs_list = self.normalize_length_causal(outputs_list) + else: + outputs_list = self.normalize_length(outputs_list) + + outputs = sum(outputs_list) + outputs = outputs.transpose(1, 2) + return outputs, None + + +if __name__ == '__main__': + mdtc = MDTC(3, 4, 80, 64, 5, causal=True) + print(mdtc) + + num_params = sum(p.numel() for p in mdtc.parameters()) + print('the number of model params: {}'.format(num_params)) + x = torch.zeros(128, 200, 80) # batch-size * time * dim + y, _ = mdtc(x) # batch-size * time * dim + print('input shape: {}'.format(x.shape)) + print('output shape: {}'.format(y.shape)) diff --git a/kws/model/tcn.py b/kws/model/tcn.py index 7dffa5b..6a002b1 100644 --- a/kws/model/tcn.py +++ b/kws/model/tcn.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional +from typing import Optional import torch import torch.nn as nn diff --git a/tools/compute_cmvn_stats.py b/tools/compute_cmvn_stats.py index 25ae546..95b7e0c 100755 --- a/tools/compute_cmvn_stats.py +++ b/tools/compute_cmvn_stats.py @@ -18,9 +18,10 @@ torchaudio.set_audio_backend("sox_io") class CollateFunc(object): ''' Collate function for AudioDataset ''' - def __init__(self, feat_dim, resample_rate): + def __init__(self, feat_dim, feat_type, resample_rate): self.feat_dim = feat_dim self.resample_rate = resample_rate + self.feat_type = feat_type pass def __call__(self, batch): @@ -31,7 +32,8 @@ class CollateFunc(object): value = item[1].strip().split(",") assert len(value) == 3 or len(value) == 1 wav_path = value[0] - sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate + sample_rate = torchaudio.backend.sox_io_backend.info( + wav_path).sample_rate resample_rate = sample_rate # len(value) == 3 means segmented wav.scp, # len(value) == 1 means original wav.scp @@ -50,12 +52,21 @@ class CollateFunc(object): resample_rate = self.resample_rate waveform = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=resample_rate)(waveform) - - mat = kaldi.fbank(waveform, - num_mel_bins=self.feat_dim, - dither=0.0, - energy_floor=0.0, - sample_frequency=resample_rate) + if self.feat_type == 'fbank': + mat = kaldi.fbank(waveform, + num_mel_bins=self.feat_dim, + dither=0.0, + energy_floor=0.0, + sample_frequency=resample_rate) + elif self.feat_type == 'mfcc': + mat = kaldi.mfcc( + waveform, + num_ceps=self.feat_dim, + num_mel_bins=self.feat_dim, + dither=0.0, + energy_floor=0.0, + sample_frequency=resample_rate, + ) mean_stat += torch.sum(mat, axis=0) var_stat += torch.sum(torch.square(mat), axis=0) number += mat.shape[0] @@ -95,13 +106,17 @@ if __name__ == '__main__': with open(args.train_config, 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) - feat_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] + feat_dim = configs['dataset_conf']['feature_extraction_conf'][ + 'num_mel_bins'] + feat_type = configs['dataset_conf']['feature_extraction_conf'][ + 'feature_type'] resample_rate = 0 if 'resample_conf' in configs['dataset_conf']: - resample_rate = configs['dataset_conf']['resample_conf']['resample_rate'] + resample_rate = configs['dataset_conf']['resample_conf'][ + 'resample_rate'] print('using resample and new sample rate is {}'.format(resample_rate)) - collate_func = CollateFunc(feat_dim, resample_rate) + collate_func = CollateFunc(feat_dim, feat_type, resample_rate) dataset = AudioDataset(args.in_scp) batch_size = 20 data_loader = DataLoader(dataset, diff --git a/tools/wav2dur.py b/tools/wav2dur.py index 1bcc1b6..b53a7fe 100755 --- a/tools/wav2dur.py +++ b/tools/wav2dur.py @@ -4,6 +4,7 @@ import sys import torchaudio + torchaudio.set_audio_backend("sox_io") scp = sys.argv[1]
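
Reviewer note: the receptive-field bookkeeping that `init_model()` now relies on can be sanity-checked with a short standalone script. This is a sketch, not part of the patch; it assumes the patched tree is importable as `kws` and hard-codes the conf/mdtc.yaml values (num_stack=4, stack_size=4, kernel_size=5, hidden_dim=64, 80-dim MFCC input):

    import torch

    from kws.model.mdtc import MDTC

    # init_model() builds MDTC(num_stack, stack_size, input_dim, hidden_dim,
    # kernel_size, causal=True); with conf/mdtc.yaml this is:
    mdtc = MDTC(4, 4, 80, 64, 5, causal=True)

    # Preprocessor TCNBlock: dilation 1 -> 1 * (5 - 1) = 4 frames.
    # Each of the 4 stacks runs dilations [1, 2, 4, 8], contributing
    # (1 + 2 + 4 + 8) * (5 - 1) = 60 frames per stack.
    assert mdtc.receptive_fields == 4 + 4 * 60  # 244

    # Causal left-padding keeps the time axis intact: (B, T, 80) -> (B, T, 64).
    x = torch.zeros(8, 200, 80)  # batch * time * feature dim
    y, _ = mdtc(x)
    assert y.shape == (8, 200, 64)

At a 10 ms frame shift, 244 frames is roughly 2.4 s of left context, consistent with the 'Receptive Fields: 244' line that MDTC.__init__ prints for this config.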