diff --git a/examples/hi_xiaowen/s0/conf/ds_tcn.yaml b/examples/hi_xiaowen/s0/conf/ds_tcn.yaml index 7da0181..88b0c78 100644 --- a/examples/hi_xiaowen/s0/conf/ds_tcn.yaml +++ b/examples/hi_xiaowen/s0/conf/ds_tcn.yaml @@ -5,11 +5,12 @@ dataset_conf: resample_conf: resample_rate: 16000 speed_perturb: false - fbank_conf: + feature_extraction_conf: + feature_type: 'fbank' num_mel_bins: 40 frame_shift: 10 frame_length: 25 - dither: 0.1 + dither: 1.0 spec_aug: true spec_aug_conf: num_t_mask: 1 @@ -24,9 +25,9 @@ dataset_conf: model: hidden_dim: 64 - subsampling: + preprocessing: type: linear - body: + backbone: type: tcn ds: true num_layers: 4 diff --git a/examples/hi_xiaowen/s0/conf/gru.yaml b/examples/hi_xiaowen/s0/conf/gru.yaml index e664319..8e63900 100644 --- a/examples/hi_xiaowen/s0/conf/gru.yaml +++ b/examples/hi_xiaowen/s0/conf/gru.yaml @@ -5,11 +5,12 @@ dataset_conf: resample_conf: resample_rate: 16000 speed_perturb: false - fbank_conf: + feature_extraction_conf: + feature_type: 'fbank' num_mel_bins: 40 frame_shift: 10 frame_length: 25 - dither: 0.1 + dither: 1.0 spec_aug: false spec_aug_conf: num_t_mask: 2 @@ -24,9 +25,9 @@ dataset_conf: model: hidden_dim: 128 - subsampling: + preprocessing: type: linear - body: + backbone: type: gru num_layers: 2 diff --git a/examples/hi_xiaowen/s0/conf/mdtc.yaml b/examples/hi_xiaowen/s0/conf/mdtc.yaml index 4549c6e..8f0bcc1 100644 --- a/examples/hi_xiaowen/s0/conf/mdtc.yaml +++ b/examples/hi_xiaowen/s0/conf/mdtc.yaml @@ -1,7 +1,4 @@ -debug: false -input_dim: 80 -output_dim: 2 dataset_conf: filter_conf: @@ -38,6 +35,8 @@ dataset_conf: batch_size: 100 model: + input_dim: 80 + output_dim: 2 hidden_dim: 64 preprocessing: type: none diff --git a/examples/hi_xiaowen/s0/conf/tcn.yaml b/examples/hi_xiaowen/s0/conf/tcn.yaml index 0612634..517b94e 100644 --- a/examples/hi_xiaowen/s0/conf/tcn.yaml +++ b/examples/hi_xiaowen/s0/conf/tcn.yaml @@ -5,11 +5,12 @@ dataset_conf: resample_conf: resample_rate: 16000 speed_perturb: false - fbank_conf: + feature_extraction_conf: + feature_type: 'fbank' num_mel_bins: 40 frame_shift: 10 frame_length: 25 - dither: 0.1 + dither: 1.0 spec_aug: false spec_aug_conf: num_t_mask: 2 @@ -24,9 +25,9 @@ dataset_conf: model: hidden_dim: 64 - subsampling: + preprocessing: type: linear - body: + backbone: type: tcn ds: false num_layers: 4 diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 3d587fd..15af9d3 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -5,8 +5,8 @@ export CUDA_VISIBLE_DEVICES="0" -stage=2 -stop_stage=2 +stage=0 +stop_stage=4 num_keywords=2 config=conf/mdtc.yaml diff --git a/kws/bin/train.py b/kws/bin/train.py index ddb6dfd..55c87ec 100644 --- a/kws/bin/train.py +++ b/kws/bin/train.py @@ -135,7 +135,7 @@ def main(): num_workers=args.num_workers, prefetch_factor=args.prefetch) - input_dim = configs['input_dim'] + input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins'] output_dim = args.num_keywords # Write model_dir/config.yaml for inference and export diff --git a/kws/dataset/processor.py b/kws/dataset/processor.py index cfe5df3..9108b23 100644 --- a/kws/dataset/processor.py +++ b/kws/dataset/processor.py @@ -165,6 +165,7 @@ def compute_mfcc( yield dict(key=sample['key'], label=sample['label'], feat=mat) def compute_fbank(data, + feature_type='fbank', num_mel_bins=23, frame_length=25, frame_shift=10, diff --git a/tools/compute_cmvn_stats.py b/tools/compute_cmvn_stats.py index 301a53c..be8df99 100755 --- a/tools/compute_cmvn_stats.py +++ b/tools/compute_cmvn_stats.py @@ -18,9 +18,10 @@ torchaudio.set_audio_backend("sox_io") class CollateFunc(object): ''' Collate function for AudioDataset ''' - def __init__(self, feat_dim, resample_rate): + def __init__(self, feat_dim, feat_type, resample_rate): self.feat_dim = feat_dim self.resample_rate = resample_rate + self.feat_type = feat_type pass def __call__(self, batch): @@ -50,12 +51,20 @@ class CollateFunc(object): resample_rate = self.resample_rate waveform = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=resample_rate)(waveform) - - mat = kaldi.fbank(waveform, + if self.feat_type == 'fbank': + mat = kaldi.fbank(waveform, num_mel_bins=self.feat_dim, dither=0.0, energy_floor=0.0, sample_frequency=resample_rate) + elif self.feat_type == 'mfcc': + mat = kaldi.mfcc(waveform, + num_ceps=self.feat_dim, + num_mel_bins=self.feat_dim, + dither=0.0, + energy_floor=0.0, + sample_frequency=resample_rate, + ) mean_stat += torch.sum(mat, axis=0) var_stat += torch.sum(torch.square(mat), axis=0) number += mat.shape[0] @@ -95,13 +104,14 @@ if __name__ == '__main__': with open(args.train_config, 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) - feat_dim = configs['input_dim'] + feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins'] + feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type'] resample_rate = 0 if 'resample_conf' in configs['dataset_conf']: resample_rate = configs['dataset_conf']['resample_conf']['resample_rate'] print('using resample and new sample rate is {}'.format(resample_rate)) - collate_func = CollateFunc(feat_dim, resample_rate) + collate_func = CollateFunc(feat_dim, feat_type, resample_rate) dataset = AudioDataset(args.in_scp) batch_size = 20 data_loader = DataLoader(dataset,