fixed bug of compute_cmvn_stats.py
This commit is contained in:
parent
4db050eb67
commit
7df9ced666
@ -5,11 +5,12 @@ dataset_conf:
|
|||||||
resample_conf:
|
resample_conf:
|
||||||
resample_rate: 16000
|
resample_rate: 16000
|
||||||
speed_perturb: false
|
speed_perturb: false
|
||||||
fbank_conf:
|
feature_extraction_conf:
|
||||||
|
feature_type: 'fbank'
|
||||||
num_mel_bins: 40
|
num_mel_bins: 40
|
||||||
frame_shift: 10
|
frame_shift: 10
|
||||||
frame_length: 25
|
frame_length: 25
|
||||||
dither: 0.1
|
dither: 1.0
|
||||||
spec_aug: true
|
spec_aug: true
|
||||||
spec_aug_conf:
|
spec_aug_conf:
|
||||||
num_t_mask: 1
|
num_t_mask: 1
|
||||||
@ -24,9 +25,9 @@ dataset_conf:
|
|||||||
|
|
||||||
model:
|
model:
|
||||||
hidden_dim: 64
|
hidden_dim: 64
|
||||||
subsampling:
|
preprocessing:
|
||||||
type: linear
|
type: linear
|
||||||
body:
|
backbone:
|
||||||
type: tcn
|
type: tcn
|
||||||
ds: true
|
ds: true
|
||||||
num_layers: 4
|
num_layers: 4
|
||||||
|
|||||||
@ -5,11 +5,12 @@ dataset_conf:
|
|||||||
resample_conf:
|
resample_conf:
|
||||||
resample_rate: 16000
|
resample_rate: 16000
|
||||||
speed_perturb: false
|
speed_perturb: false
|
||||||
fbank_conf:
|
feature_extraction_conf:
|
||||||
|
feature_type: 'fbank'
|
||||||
num_mel_bins: 40
|
num_mel_bins: 40
|
||||||
frame_shift: 10
|
frame_shift: 10
|
||||||
frame_length: 25
|
frame_length: 25
|
||||||
dither: 0.1
|
dither: 1.0
|
||||||
spec_aug: false
|
spec_aug: false
|
||||||
spec_aug_conf:
|
spec_aug_conf:
|
||||||
num_t_mask: 2
|
num_t_mask: 2
|
||||||
@ -24,9 +25,9 @@ dataset_conf:
|
|||||||
|
|
||||||
model:
|
model:
|
||||||
hidden_dim: 128
|
hidden_dim: 128
|
||||||
subsampling:
|
preprocessing:
|
||||||
type: linear
|
type: linear
|
||||||
body:
|
backbone:
|
||||||
type: gru
|
type: gru
|
||||||
num_layers: 2
|
num_layers: 2
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,4 @@
|
|||||||
debug: false
|
|
||||||
|
|
||||||
input_dim: 80
|
|
||||||
output_dim: 2
|
|
||||||
|
|
||||||
dataset_conf:
|
dataset_conf:
|
||||||
filter_conf:
|
filter_conf:
|
||||||
@ -38,6 +35,8 @@ dataset_conf:
|
|||||||
batch_size: 100
|
batch_size: 100
|
||||||
|
|
||||||
model:
|
model:
|
||||||
|
input_dim: 80
|
||||||
|
output_dim: 2
|
||||||
hidden_dim: 64
|
hidden_dim: 64
|
||||||
preprocessing:
|
preprocessing:
|
||||||
type: none
|
type: none
|
||||||
|
|||||||
@ -5,11 +5,12 @@ dataset_conf:
|
|||||||
resample_conf:
|
resample_conf:
|
||||||
resample_rate: 16000
|
resample_rate: 16000
|
||||||
speed_perturb: false
|
speed_perturb: false
|
||||||
fbank_conf:
|
feature_extraction_conf:
|
||||||
|
feature_type: 'fbank'
|
||||||
num_mel_bins: 40
|
num_mel_bins: 40
|
||||||
frame_shift: 10
|
frame_shift: 10
|
||||||
frame_length: 25
|
frame_length: 25
|
||||||
dither: 0.1
|
dither: 1.0
|
||||||
spec_aug: false
|
spec_aug: false
|
||||||
spec_aug_conf:
|
spec_aug_conf:
|
||||||
num_t_mask: 2
|
num_t_mask: 2
|
||||||
@ -24,9 +25,9 @@ dataset_conf:
|
|||||||
|
|
||||||
model:
|
model:
|
||||||
hidden_dim: 64
|
hidden_dim: 64
|
||||||
subsampling:
|
preprocessing:
|
||||||
type: linear
|
type: linear
|
||||||
body:
|
backbone:
|
||||||
type: tcn
|
type: tcn
|
||||||
ds: false
|
ds: false
|
||||||
num_layers: 4
|
num_layers: 4
|
||||||
|
|||||||
@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0"
|
export CUDA_VISIBLE_DEVICES="0"
|
||||||
|
|
||||||
stage=2
|
stage=0
|
||||||
stop_stage=2
|
stop_stage=4
|
||||||
num_keywords=2
|
num_keywords=2
|
||||||
|
|
||||||
config=conf/mdtc.yaml
|
config=conf/mdtc.yaml
|
||||||
|
|||||||
@ -135,7 +135,7 @@ def main():
|
|||||||
num_workers=args.num_workers,
|
num_workers=args.num_workers,
|
||||||
prefetch_factor=args.prefetch)
|
prefetch_factor=args.prefetch)
|
||||||
|
|
||||||
input_dim = configs['input_dim']
|
input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
|
||||||
output_dim = args.num_keywords
|
output_dim = args.num_keywords
|
||||||
|
|
||||||
# Write model_dir/config.yaml for inference and export
|
# Write model_dir/config.yaml for inference and export
|
||||||
|
|||||||
@ -165,6 +165,7 @@ def compute_mfcc(
|
|||||||
yield dict(key=sample['key'], label=sample['label'], feat=mat)
|
yield dict(key=sample['key'], label=sample['label'], feat=mat)
|
||||||
|
|
||||||
def compute_fbank(data,
|
def compute_fbank(data,
|
||||||
|
feature_type='fbank',
|
||||||
num_mel_bins=23,
|
num_mel_bins=23,
|
||||||
frame_length=25,
|
frame_length=25,
|
||||||
frame_shift=10,
|
frame_shift=10,
|
||||||
|
|||||||
@ -18,9 +18,10 @@ torchaudio.set_audio_backend("sox_io")
|
|||||||
class CollateFunc(object):
|
class CollateFunc(object):
|
||||||
''' Collate function for AudioDataset
|
''' Collate function for AudioDataset
|
||||||
'''
|
'''
|
||||||
def __init__(self, feat_dim, resample_rate):
|
def __init__(self, feat_dim, feat_type, resample_rate):
|
||||||
self.feat_dim = feat_dim
|
self.feat_dim = feat_dim
|
||||||
self.resample_rate = resample_rate
|
self.resample_rate = resample_rate
|
||||||
|
self.feat_type = feat_type
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
@ -50,12 +51,20 @@ class CollateFunc(object):
|
|||||||
resample_rate = self.resample_rate
|
resample_rate = self.resample_rate
|
||||||
waveform = torchaudio.transforms.Resample(
|
waveform = torchaudio.transforms.Resample(
|
||||||
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
||||||
|
if self.feat_type == 'fbank':
|
||||||
mat = kaldi.fbank(waveform,
|
mat = kaldi.fbank(waveform,
|
||||||
num_mel_bins=self.feat_dim,
|
num_mel_bins=self.feat_dim,
|
||||||
dither=0.0,
|
dither=0.0,
|
||||||
energy_floor=0.0,
|
energy_floor=0.0,
|
||||||
sample_frequency=resample_rate)
|
sample_frequency=resample_rate)
|
||||||
|
elif self.feat_type == 'mfcc':
|
||||||
|
mat = kaldi.mfcc(waveform,
|
||||||
|
num_ceps=self.feat_dim,
|
||||||
|
num_mel_bins=self.feat_dim,
|
||||||
|
dither=0.0,
|
||||||
|
energy_floor=0.0,
|
||||||
|
sample_frequency=resample_rate,
|
||||||
|
)
|
||||||
mean_stat += torch.sum(mat, axis=0)
|
mean_stat += torch.sum(mat, axis=0)
|
||||||
var_stat += torch.sum(torch.square(mat), axis=0)
|
var_stat += torch.sum(torch.square(mat), axis=0)
|
||||||
number += mat.shape[0]
|
number += mat.shape[0]
|
||||||
@ -95,13 +104,14 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
with open(args.train_config, 'r') as fin:
|
with open(args.train_config, 'r') as fin:
|
||||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||||
feat_dim = configs['input_dim']
|
feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
|
||||||
|
feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type']
|
||||||
resample_rate = 0
|
resample_rate = 0
|
||||||
if 'resample_conf' in configs['dataset_conf']:
|
if 'resample_conf' in configs['dataset_conf']:
|
||||||
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
|
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
|
||||||
print('using resample and new sample rate is {}'.format(resample_rate))
|
print('using resample and new sample rate is {}'.format(resample_rate))
|
||||||
|
|
||||||
collate_func = CollateFunc(feat_dim, resample_rate)
|
collate_func = CollateFunc(feat_dim, feat_type, resample_rate)
|
||||||
dataset = AudioDataset(args.in_scp)
|
dataset = AudioDataset(args.in_scp)
|
||||||
batch_size = 20
|
batch_size = 20
|
||||||
data_loader = DataLoader(dataset,
|
data_loader = DataLoader(dataset,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user