format the code
This commit is contained in:
parent
7df9ced666
commit
0942092426
@ -56,8 +56,7 @@ def main():
|
||||
# Export quantized jit torch script model
|
||||
if args.output_quant_file:
|
||||
quantized_model = torch.quantization.quantize_dynamic(
|
||||
model, {torch.nn.Linear}, dtype=torch.qint8
|
||||
)
|
||||
model, {torch.nn.Linear}, dtype=torch.qint8)
|
||||
print(quantized_model)
|
||||
script_quant_model = torch.jit.script(quantized_model)
|
||||
script_quant_model.save(args.output_quant_file)
|
||||
|
||||
@ -135,7 +135,8 @@ def main():
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch)
|
||||
|
||||
input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
|
||||
input_dim = configs['dataset_conf']['feature_extraction_conf'][
|
||||
'num_mel_bins']
|
||||
output_dim = args.num_keywords
|
||||
|
||||
# Write model_dir/config.yaml for inference and export
|
||||
|
||||
@ -126,6 +126,7 @@ def speed_perturb(data, speeds=None):
|
||||
|
||||
yield sample
|
||||
|
||||
|
||||
def compute_mfcc(
|
||||
data,
|
||||
feature_type='mfcc',
|
||||
@ -164,6 +165,7 @@ def compute_mfcc(
|
||||
)
|
||||
yield dict(key=sample['key'], label=sample['label'], feat=mat)
|
||||
|
||||
|
||||
def compute_fbank(data,
|
||||
feature_type='fbank',
|
||||
num_mel_bins=23,
|
||||
|
||||
@ -32,7 +32,8 @@ class CollateFunc(object):
|
||||
value = item[1].strip().split(",")
|
||||
assert len(value) == 3 or len(value) == 1
|
||||
wav_path = value[0]
|
||||
sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate
|
||||
sample_rate = torchaudio.backend.sox_io_backend.info(
|
||||
wav_path).sample_rate
|
||||
resample_rate = sample_rate
|
||||
# len(value) == 3 means segmented wav.scp,
|
||||
# len(value) == 1 means original wav.scp
|
||||
@ -58,7 +59,8 @@ class CollateFunc(object):
|
||||
energy_floor=0.0,
|
||||
sample_frequency=resample_rate)
|
||||
elif self.feat_type == 'mfcc':
|
||||
mat = kaldi.mfcc(waveform,
|
||||
mat = kaldi.mfcc(
|
||||
waveform,
|
||||
num_ceps=self.feat_dim,
|
||||
num_mel_bins=self.feat_dim,
|
||||
dither=0.0,
|
||||
@ -104,11 +106,14 @@ if __name__ == '__main__':
|
||||
|
||||
with open(args.train_config, 'r') as fin:
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
|
||||
feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type']
|
||||
feat_dim = configs['dataset_conf']['feature_extraction_conf'][
|
||||
'num_mel_bins']
|
||||
feat_type = configs['dataset_conf']['feature_extraction_conf'][
|
||||
'feature_type']
|
||||
resample_rate = 0
|
||||
if 'resample_conf' in configs['dataset_conf']:
|
||||
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
|
||||
resample_rate = configs['dataset_conf']['resample_conf'][
|
||||
'resample_rate']
|
||||
print('using resample and new sample rate is {}'.format(resample_rate))
|
||||
|
||||
collate_func = CollateFunc(feat_dim, feat_type, resample_rate)
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
import sys
|
||||
|
||||
import torchaudio
|
||||
|
||||
torchaudio.set_audio_backend("sox_io")
|
||||
|
||||
scp = sys.argv[1]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user