format the code

This commit is contained in:
Jingyong Hou 2021-11-10 22:49:53 +08:00
parent 7df9ced666
commit 0942092426
5 changed files with 28 additions and 20 deletions

View File

@@ -56,8 +56,7 @@ def main():
# Export quantized jit torch script model # Export quantized jit torch script model
if args.output_quant_file: if args.output_quant_file:
quantized_model = torch.quantization.quantize_dynamic( quantized_model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8 model, {torch.nn.Linear}, dtype=torch.qint8)
)
print(quantized_model) print(quantized_model)
script_quant_model = torch.jit.script(quantized_model) script_quant_model = torch.jit.script(quantized_model)
script_quant_model.save(args.output_quant_file) script_quant_model.save(args.output_quant_file)

View File

@@ -135,7 +135,8 @@ def main():
num_workers=args.num_workers, num_workers=args.num_workers,
prefetch_factor=args.prefetch) prefetch_factor=args.prefetch)
input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins'] input_dim = configs['dataset_conf']['feature_extraction_conf'][
'num_mel_bins']
output_dim = args.num_keywords output_dim = args.num_keywords
# Write model_dir/config.yaml for inference and export # Write model_dir/config.yaml for inference and export
@@ -161,8 +162,8 @@ def main():
# Try to export the model by script, if fails, we should refine # Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements # the code to satisfy the script export requirements
#if args.rank == 0: #if args.rank == 0:
#script_model = torch.jit.script(model) #script_model = torch.jit.script(model)
#script_model.save(os.path.join(args.model_dir, 'init.zip')) #script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor() executor = Executor()
# If specify checkpoint, load some info from checkpoint # If specify checkpoint, load some info from checkpoint
if args.checkpoint is not None: if args.checkpoint is not None:

View File

@@ -126,6 +126,7 @@ def speed_perturb(data, speeds=None):
yield sample yield sample
def compute_mfcc( def compute_mfcc(
data, data,
feature_type='mfcc', feature_type='mfcc',
@@ -164,6 +165,7 @@ def compute_mfcc(
) )
yield dict(key=sample['key'], label=sample['label'], feat=mat) yield dict(key=sample['key'], label=sample['label'], feat=mat)
def compute_fbank(data, def compute_fbank(data,
feature_type='fbank', feature_type='fbank',
num_mel_bins=23, num_mel_bins=23,

View File

@@ -32,7 +32,8 @@ class CollateFunc(object):
value = item[1].strip().split(",") value = item[1].strip().split(",")
assert len(value) == 3 or len(value) == 1 assert len(value) == 3 or len(value) == 1
wav_path = value[0] wav_path = value[0]
sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate sample_rate = torchaudio.backend.sox_io_backend.info(
wav_path).sample_rate
resample_rate = sample_rate resample_rate = sample_rate
# len(value) == 3 means segmented wav.scp, # len(value) == 3 means segmented wav.scp,
# len(value) == 1 means original wav.scp # len(value) == 1 means original wav.scp
@@ -53,18 +54,19 @@ class CollateFunc(object):
orig_freq=sample_rate, new_freq=resample_rate)(waveform) orig_freq=sample_rate, new_freq=resample_rate)(waveform)
if self.feat_type == 'fbank': if self.feat_type == 'fbank':
mat = kaldi.fbank(waveform, mat = kaldi.fbank(waveform,
num_mel_bins=self.feat_dim, num_mel_bins=self.feat_dim,
dither=0.0, dither=0.0,
energy_floor=0.0, energy_floor=0.0,
sample_frequency=resample_rate) sample_frequency=resample_rate)
elif self.feat_type == 'mfcc': elif self.feat_type == 'mfcc':
mat = kaldi.mfcc(waveform, mat = kaldi.mfcc(
num_ceps=self.feat_dim, waveform,
num_mel_bins=self.feat_dim, num_ceps=self.feat_dim,
dither=0.0, num_mel_bins=self.feat_dim,
energy_floor=0.0, dither=0.0,
sample_frequency=resample_rate, energy_floor=0.0,
) sample_frequency=resample_rate,
)
mean_stat += torch.sum(mat, axis=0) mean_stat += torch.sum(mat, axis=0)
var_stat += torch.sum(torch.square(mat), axis=0) var_stat += torch.sum(torch.square(mat), axis=0)
number += mat.shape[0] number += mat.shape[0]
@@ -104,11 +106,14 @@ if __name__ == '__main__':
with open(args.train_config, 'r') as fin: with open(args.train_config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader) configs = yaml.load(fin, Loader=yaml.FullLoader)
feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins'] feat_dim = configs['dataset_conf']['feature_extraction_conf'][
feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type'] 'num_mel_bins']
feat_type = configs['dataset_conf']['feature_extraction_conf'][
'feature_type']
resample_rate = 0 resample_rate = 0
if 'resample_conf' in configs['dataset_conf']: if 'resample_conf' in configs['dataset_conf']:
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate'] resample_rate = configs['dataset_conf']['resample_conf'][
'resample_rate']
print('using resample and new sample rate is {}'.format(resample_rate)) print('using resample and new sample rate is {}'.format(resample_rate))
collate_func = CollateFunc(feat_dim, feat_type, resample_rate) collate_func = CollateFunc(feat_dim, feat_type, resample_rate)

View File

@@ -4,6 +4,7 @@
import sys import sys
import torchaudio import torchaudio
torchaudio.set_audio_backend("sox_io") torchaudio.set_audio_backend("sox_io")
scp = sys.argv[1] scp = sys.argv[1]