fixed bug of compute_cmvn_stats.py

This commit is contained in:
Jingyong Hou 2021-11-10 22:40:21 +08:00
parent 4db050eb67
commit 7df9ced666
8 changed files with 36 additions and 23 deletions

View File

@ -5,11 +5,12 @@ dataset_conf:
resample_conf: resample_conf:
resample_rate: 16000 resample_rate: 16000
speed_perturb: false speed_perturb: false
fbank_conf: feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40 num_mel_bins: 40
frame_shift: 10 frame_shift: 10
frame_length: 25 frame_length: 25
dither: 0.1 dither: 1.0
spec_aug: true spec_aug: true
spec_aug_conf: spec_aug_conf:
num_t_mask: 1 num_t_mask: 1
@ -24,9 +25,9 @@ dataset_conf:
model: model:
hidden_dim: 64 hidden_dim: 64
subsampling: preprocessing:
type: linear type: linear
body: backbone:
type: tcn type: tcn
ds: true ds: true
num_layers: 4 num_layers: 4

View File

@ -5,11 +5,12 @@ dataset_conf:
resample_conf: resample_conf:
resample_rate: 16000 resample_rate: 16000
speed_perturb: false speed_perturb: false
fbank_conf: feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40 num_mel_bins: 40
frame_shift: 10 frame_shift: 10
frame_length: 25 frame_length: 25
dither: 0.1 dither: 1.0
spec_aug: false spec_aug: false
spec_aug_conf: spec_aug_conf:
num_t_mask: 2 num_t_mask: 2
@ -24,9 +25,9 @@ dataset_conf:
model: model:
hidden_dim: 128 hidden_dim: 128
subsampling: preprocessing:
type: linear type: linear
body: backbone:
type: gru type: gru
num_layers: 2 num_layers: 2

View File

@ -1,7 +1,4 @@
debug: false
input_dim: 80
output_dim: 2
dataset_conf: dataset_conf:
filter_conf: filter_conf:
@ -38,6 +35,8 @@ dataset_conf:
batch_size: 100 batch_size: 100
model: model:
input_dim: 80
output_dim: 2
hidden_dim: 64 hidden_dim: 64
preprocessing: preprocessing:
type: none type: none

View File

@ -5,11 +5,12 @@ dataset_conf:
resample_conf: resample_conf:
resample_rate: 16000 resample_rate: 16000
speed_perturb: false speed_perturb: false
fbank_conf: feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40 num_mel_bins: 40
frame_shift: 10 frame_shift: 10
frame_length: 25 frame_length: 25
dither: 0.1 dither: 1.0
spec_aug: false spec_aug: false
spec_aug_conf: spec_aug_conf:
num_t_mask: 2 num_t_mask: 2
@ -24,9 +25,9 @@ dataset_conf:
model: model:
hidden_dim: 64 hidden_dim: 64
subsampling: preprocessing:
type: linear type: linear
body: backbone:
type: tcn type: tcn
ds: false ds: false
num_layers: 4 num_layers: 4

View File

@ -5,8 +5,8 @@
export CUDA_VISIBLE_DEVICES="0" export CUDA_VISIBLE_DEVICES="0"
stage=2 stage=0
stop_stage=2 stop_stage=4
num_keywords=2 num_keywords=2
config=conf/mdtc.yaml config=conf/mdtc.yaml

View File

@ -135,7 +135,7 @@ def main():
num_workers=args.num_workers, num_workers=args.num_workers,
prefetch_factor=args.prefetch) prefetch_factor=args.prefetch)
input_dim = configs['input_dim'] input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
output_dim = args.num_keywords output_dim = args.num_keywords
# Write model_dir/config.yaml for inference and export # Write model_dir/config.yaml for inference and export

View File

@ -165,6 +165,7 @@ def compute_mfcc(
yield dict(key=sample['key'], label=sample['label'], feat=mat) yield dict(key=sample['key'], label=sample['label'], feat=mat)
def compute_fbank(data, def compute_fbank(data,
feature_type='fbank',
num_mel_bins=23, num_mel_bins=23,
frame_length=25, frame_length=25,
frame_shift=10, frame_shift=10,

View File

@ -18,9 +18,10 @@ torchaudio.set_audio_backend("sox_io")
class CollateFunc(object): class CollateFunc(object):
''' Collate function for AudioDataset ''' Collate function for AudioDataset
''' '''
def __init__(self, feat_dim, resample_rate): def __init__(self, feat_dim, feat_type, resample_rate):
self.feat_dim = feat_dim self.feat_dim = feat_dim
self.resample_rate = resample_rate self.resample_rate = resample_rate
self.feat_type = feat_type
pass pass
def __call__(self, batch): def __call__(self, batch):
@ -50,12 +51,20 @@ class CollateFunc(object):
resample_rate = self.resample_rate resample_rate = self.resample_rate
waveform = torchaudio.transforms.Resample( waveform = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform) orig_freq=sample_rate, new_freq=resample_rate)(waveform)
if self.feat_type == 'fbank':
mat = kaldi.fbank(waveform, mat = kaldi.fbank(waveform,
num_mel_bins=self.feat_dim, num_mel_bins=self.feat_dim,
dither=0.0, dither=0.0,
energy_floor=0.0, energy_floor=0.0,
sample_frequency=resample_rate) sample_frequency=resample_rate)
elif self.feat_type == 'mfcc':
mat = kaldi.mfcc(waveform,
num_ceps=self.feat_dim,
num_mel_bins=self.feat_dim,
dither=0.0,
energy_floor=0.0,
sample_frequency=resample_rate,
)
mean_stat += torch.sum(mat, axis=0) mean_stat += torch.sum(mat, axis=0)
var_stat += torch.sum(torch.square(mat), axis=0) var_stat += torch.sum(torch.square(mat), axis=0)
number += mat.shape[0] number += mat.shape[0]
@ -95,13 +104,14 @@ if __name__ == '__main__':
with open(args.train_config, 'r') as fin: with open(args.train_config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader) configs = yaml.load(fin, Loader=yaml.FullLoader)
feat_dim = configs['input_dim'] feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type']
resample_rate = 0 resample_rate = 0
if 'resample_conf' in configs['dataset_conf']: if 'resample_conf' in configs['dataset_conf']:
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate'] resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
print('using resample and new sample rate is {}'.format(resample_rate)) print('using resample and new sample rate is {}'.format(resample_rate))
collate_func = CollateFunc(feat_dim, resample_rate) collate_func = CollateFunc(feat_dim, feat_type, resample_rate)
dataset = AudioDataset(args.in_scp) dataset = AudioDataset(args.in_scp)
batch_size = 20 batch_size = 20
data_loader = DataLoader(dataset, data_loader = DataLoader(dataset,