Fixed a bug in compute_cmvn_stats.py
This commit is contained in:
parent
4db050eb67
commit
7df9ced666
@ -5,11 +5,12 @@ dataset_conf:
|
||||
resample_conf:
|
||||
resample_rate: 16000
|
||||
speed_perturb: false
|
||||
fbank_conf:
|
||||
feature_extraction_conf:
|
||||
feature_type: 'fbank'
|
||||
num_mel_bins: 40
|
||||
frame_shift: 10
|
||||
frame_length: 25
|
||||
dither: 0.1
|
||||
dither: 1.0
|
||||
spec_aug: true
|
||||
spec_aug_conf:
|
||||
num_t_mask: 1
|
||||
@ -24,9 +25,9 @@ dataset_conf:
|
||||
|
||||
model:
|
||||
hidden_dim: 64
|
||||
subsampling:
|
||||
preprocessing:
|
||||
type: linear
|
||||
body:
|
||||
backbone:
|
||||
type: tcn
|
||||
ds: true
|
||||
num_layers: 4
|
||||
|
||||
@ -5,11 +5,12 @@ dataset_conf:
|
||||
resample_conf:
|
||||
resample_rate: 16000
|
||||
speed_perturb: false
|
||||
fbank_conf:
|
||||
feature_extraction_conf:
|
||||
feature_type: 'fbank'
|
||||
num_mel_bins: 40
|
||||
frame_shift: 10
|
||||
frame_length: 25
|
||||
dither: 0.1
|
||||
dither: 1.0
|
||||
spec_aug: false
|
||||
spec_aug_conf:
|
||||
num_t_mask: 2
|
||||
@ -24,9 +25,9 @@ dataset_conf:
|
||||
|
||||
model:
|
||||
hidden_dim: 128
|
||||
subsampling:
|
||||
preprocessing:
|
||||
type: linear
|
||||
body:
|
||||
backbone:
|
||||
type: gru
|
||||
num_layers: 2
|
||||
|
||||
|
||||
@ -1,7 +1,4 @@
|
||||
debug: false
|
||||
|
||||
input_dim: 80
|
||||
output_dim: 2
|
||||
|
||||
dataset_conf:
|
||||
filter_conf:
|
||||
@ -38,6 +35,8 @@ dataset_conf:
|
||||
batch_size: 100
|
||||
|
||||
model:
|
||||
input_dim: 80
|
||||
output_dim: 2
|
||||
hidden_dim: 64
|
||||
preprocessing:
|
||||
type: none
|
||||
|
||||
@ -5,11 +5,12 @@ dataset_conf:
|
||||
resample_conf:
|
||||
resample_rate: 16000
|
||||
speed_perturb: false
|
||||
fbank_conf:
|
||||
feature_extraction_conf:
|
||||
feature_type: 'fbank'
|
||||
num_mel_bins: 40
|
||||
frame_shift: 10
|
||||
frame_length: 25
|
||||
dither: 0.1
|
||||
dither: 1.0
|
||||
spec_aug: false
|
||||
spec_aug_conf:
|
||||
num_t_mask: 2
|
||||
@ -24,9 +25,9 @@ dataset_conf:
|
||||
|
||||
model:
|
||||
hidden_dim: 64
|
||||
subsampling:
|
||||
preprocessing:
|
||||
type: linear
|
||||
body:
|
||||
backbone:
|
||||
type: tcn
|
||||
ds: false
|
||||
num_layers: 4
|
||||
|
||||
@ -5,8 +5,8 @@
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
stage=2
|
||||
stop_stage=2
|
||||
stage=0
|
||||
stop_stage=4
|
||||
num_keywords=2
|
||||
|
||||
config=conf/mdtc.yaml
|
||||
|
||||
@ -135,7 +135,7 @@ def main():
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch)
|
||||
|
||||
input_dim = configs['input_dim']
|
||||
input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
|
||||
output_dim = args.num_keywords
|
||||
|
||||
# Write model_dir/config.yaml for inference and export
|
||||
|
||||
@ -165,6 +165,7 @@ def compute_mfcc(
|
||||
yield dict(key=sample['key'], label=sample['label'], feat=mat)
|
||||
|
||||
def compute_fbank(data,
|
||||
feature_type='fbank',
|
||||
num_mel_bins=23,
|
||||
frame_length=25,
|
||||
frame_shift=10,
|
||||
|
||||
@ -18,9 +18,10 @@ torchaudio.set_audio_backend("sox_io")
|
||||
class CollateFunc(object):
|
||||
''' Collate function for AudioDataset
|
||||
'''
|
||||
def __init__(self, feat_dim, resample_rate):
|
||||
def __init__(self, feat_dim, feat_type, resample_rate):
|
||||
self.feat_dim = feat_dim
|
||||
self.resample_rate = resample_rate
|
||||
self.feat_type = feat_type
|
||||
pass
|
||||
|
||||
def __call__(self, batch):
|
||||
@ -50,12 +51,20 @@ class CollateFunc(object):
|
||||
resample_rate = self.resample_rate
|
||||
waveform = torchaudio.transforms.Resample(
|
||||
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
||||
|
||||
mat = kaldi.fbank(waveform,
|
||||
if self.feat_type == 'fbank':
|
||||
mat = kaldi.fbank(waveform,
|
||||
num_mel_bins=self.feat_dim,
|
||||
dither=0.0,
|
||||
energy_floor=0.0,
|
||||
sample_frequency=resample_rate)
|
||||
elif self.feat_type == 'mfcc':
|
||||
mat = kaldi.mfcc(waveform,
|
||||
num_ceps=self.feat_dim,
|
||||
num_mel_bins=self.feat_dim,
|
||||
dither=0.0,
|
||||
energy_floor=0.0,
|
||||
sample_frequency=resample_rate,
|
||||
)
|
||||
mean_stat += torch.sum(mat, axis=0)
|
||||
var_stat += torch.sum(torch.square(mat), axis=0)
|
||||
number += mat.shape[0]
|
||||
@ -95,13 +104,14 @@ if __name__ == '__main__':
|
||||
|
||||
with open(args.train_config, 'r') as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
feat_dim = configs['input_dim']
|
||||
feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
|
||||
feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type']
|
||||
resample_rate = 0
|
||||
if 'resample_conf' in configs['dataset_conf']:
|
||||
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
|
||||
print('using resample and new sample rate is {}'.format(resample_rate))
|
||||
|
||||
collate_func = CollateFunc(feat_dim, resample_rate)
|
||||
collate_func = CollateFunc(feat_dim, feat_type, resample_rate)
|
||||
dataset = AudioDataset(args.in_scp)
|
||||
batch_size = 20
|
||||
data_loader = DataLoader(dataset,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user