format the code
This commit is contained in:
parent
7df9ced666
commit
0942092426
@ -56,8 +56,7 @@ def main():
|
|||||||
# Export quantized jit torch script model
|
# Export quantized jit torch script model
|
||||||
if args.output_quant_file:
|
if args.output_quant_file:
|
||||||
quantized_model = torch.quantization.quantize_dynamic(
|
quantized_model = torch.quantization.quantize_dynamic(
|
||||||
model, {torch.nn.Linear}, dtype=torch.qint8
|
model, {torch.nn.Linear}, dtype=torch.qint8)
|
||||||
)
|
|
||||||
print(quantized_model)
|
print(quantized_model)
|
||||||
script_quant_model = torch.jit.script(quantized_model)
|
script_quant_model = torch.jit.script(quantized_model)
|
||||||
script_quant_model.save(args.output_quant_file)
|
script_quant_model.save(args.output_quant_file)
|
||||||
|
|||||||
@ -135,7 +135,8 @@ def main():
|
|||||||
num_workers=args.num_workers,
|
num_workers=args.num_workers,
|
||||||
prefetch_factor=args.prefetch)
|
prefetch_factor=args.prefetch)
|
||||||
|
|
||||||
input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
|
input_dim = configs['dataset_conf']['feature_extraction_conf'][
|
||||||
|
'num_mel_bins']
|
||||||
output_dim = args.num_keywords
|
output_dim = args.num_keywords
|
||||||
|
|
||||||
# Write model_dir/config.yaml for inference and export
|
# Write model_dir/config.yaml for inference and export
|
||||||
@ -161,8 +162,8 @@ def main():
|
|||||||
# Try to export the model by script, if fails, we should refine
|
# Try to export the model by script, if fails, we should refine
|
||||||
# the code to satisfy the script export requirements
|
# the code to satisfy the script export requirements
|
||||||
#if args.rank == 0:
|
#if args.rank == 0:
|
||||||
#script_model = torch.jit.script(model)
|
#script_model = torch.jit.script(model)
|
||||||
#script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
#script_model.save(os.path.join(args.model_dir, 'init.zip'))
|
||||||
executor = Executor()
|
executor = Executor()
|
||||||
# If specify checkpoint, load some info from checkpoint
|
# If specify checkpoint, load some info from checkpoint
|
||||||
if args.checkpoint is not None:
|
if args.checkpoint is not None:
|
||||||
|
|||||||
@ -126,6 +126,7 @@ def speed_perturb(data, speeds=None):
|
|||||||
|
|
||||||
yield sample
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
def compute_mfcc(
|
def compute_mfcc(
|
||||||
data,
|
data,
|
||||||
feature_type='mfcc',
|
feature_type='mfcc',
|
||||||
@ -164,6 +165,7 @@ def compute_mfcc(
|
|||||||
)
|
)
|
||||||
yield dict(key=sample['key'], label=sample['label'], feat=mat)
|
yield dict(key=sample['key'], label=sample['label'], feat=mat)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank(data,
|
def compute_fbank(data,
|
||||||
feature_type='fbank',
|
feature_type='fbank',
|
||||||
num_mel_bins=23,
|
num_mel_bins=23,
|
||||||
|
|||||||
@ -32,7 +32,8 @@ class CollateFunc(object):
|
|||||||
value = item[1].strip().split(",")
|
value = item[1].strip().split(",")
|
||||||
assert len(value) == 3 or len(value) == 1
|
assert len(value) == 3 or len(value) == 1
|
||||||
wav_path = value[0]
|
wav_path = value[0]
|
||||||
sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate
|
sample_rate = torchaudio.backend.sox_io_backend.info(
|
||||||
|
wav_path).sample_rate
|
||||||
resample_rate = sample_rate
|
resample_rate = sample_rate
|
||||||
# len(value) == 3 means segmented wav.scp,
|
# len(value) == 3 means segmented wav.scp,
|
||||||
# len(value) == 1 means original wav.scp
|
# len(value) == 1 means original wav.scp
|
||||||
@ -53,18 +54,19 @@ class CollateFunc(object):
|
|||||||
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
||||||
if self.feat_type == 'fbank':
|
if self.feat_type == 'fbank':
|
||||||
mat = kaldi.fbank(waveform,
|
mat = kaldi.fbank(waveform,
|
||||||
num_mel_bins=self.feat_dim,
|
num_mel_bins=self.feat_dim,
|
||||||
dither=0.0,
|
dither=0.0,
|
||||||
energy_floor=0.0,
|
energy_floor=0.0,
|
||||||
sample_frequency=resample_rate)
|
sample_frequency=resample_rate)
|
||||||
elif self.feat_type == 'mfcc':
|
elif self.feat_type == 'mfcc':
|
||||||
mat = kaldi.mfcc(waveform,
|
mat = kaldi.mfcc(
|
||||||
num_ceps=self.feat_dim,
|
waveform,
|
||||||
num_mel_bins=self.feat_dim,
|
num_ceps=self.feat_dim,
|
||||||
dither=0.0,
|
num_mel_bins=self.feat_dim,
|
||||||
energy_floor=0.0,
|
dither=0.0,
|
||||||
sample_frequency=resample_rate,
|
energy_floor=0.0,
|
||||||
)
|
sample_frequency=resample_rate,
|
||||||
|
)
|
||||||
mean_stat += torch.sum(mat, axis=0)
|
mean_stat += torch.sum(mat, axis=0)
|
||||||
var_stat += torch.sum(torch.square(mat), axis=0)
|
var_stat += torch.sum(torch.square(mat), axis=0)
|
||||||
number += mat.shape[0]
|
number += mat.shape[0]
|
||||||
@ -104,11 +106,14 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
with open(args.train_config, 'r') as fin:
|
with open(args.train_config, 'r') as fin:
|
||||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||||
feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
|
feat_dim = configs['dataset_conf']['feature_extraction_conf'][
|
||||||
feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type']
|
'num_mel_bins']
|
||||||
|
feat_type = configs['dataset_conf']['feature_extraction_conf'][
|
||||||
|
'feature_type']
|
||||||
resample_rate = 0
|
resample_rate = 0
|
||||||
if 'resample_conf' in configs['dataset_conf']:
|
if 'resample_conf' in configs['dataset_conf']:
|
||||||
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
|
resample_rate = configs['dataset_conf']['resample_conf'][
|
||||||
|
'resample_rate']
|
||||||
print('using resample and new sample rate is {}'.format(resample_rate))
|
print('using resample and new sample rate is {}'.format(resample_rate))
|
||||||
|
|
||||||
collate_func = CollateFunc(feat_dim, feat_type, resample_rate)
|
collate_func = CollateFunc(feat_dim, feat_type, resample_rate)
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
torchaudio.set_audio_backend("sox_io")
|
torchaudio.set_audio_backend("sox_io")
|
||||||
|
|
||||||
scp = sys.argv[1]
|
scp = sys.argv[1]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user