format the code

This commit is contained in:
Jingyong Hou 2021-11-10 22:49:53 +08:00
parent 7df9ced666
commit 0942092426
5 changed files with 28 additions and 20 deletions

View File

@ -56,8 +56,7 @@ def main():
# Export quantized jit torch script model
if args.output_quant_file:
quantized_model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
model, {torch.nn.Linear}, dtype=torch.qint8)
print(quantized_model)
script_quant_model = torch.jit.script(quantized_model)
script_quant_model.save(args.output_quant_file)

View File

@ -135,7 +135,8 @@ def main():
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
input_dim = configs['dataset_conf']['feature_extraction_conf'][
'num_mel_bins']
output_dim = args.num_keywords
# Write model_dir/config.yaml for inference and export
@ -161,8 +162,8 @@ def main():
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
#if args.rank == 0:
#script_model = torch.jit.script(model)
#script_model.save(os.path.join(args.model_dir, 'init.zip'))
#script_model = torch.jit.script(model)
#script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor()
# If specify checkpoint, load some info from checkpoint
if args.checkpoint is not None:

View File

@ -126,6 +126,7 @@ def speed_perturb(data, speeds=None):
yield sample
def compute_mfcc(
data,
feature_type='mfcc',
@ -164,6 +165,7 @@ def compute_mfcc(
)
yield dict(key=sample['key'], label=sample['label'], feat=mat)
def compute_fbank(data,
feature_type='fbank',
num_mel_bins=23,

View File

@ -32,7 +32,8 @@ class CollateFunc(object):
value = item[1].strip().split(",")
assert len(value) == 3 or len(value) == 1
wav_path = value[0]
sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate
sample_rate = torchaudio.backend.sox_io_backend.info(
wav_path).sample_rate
resample_rate = sample_rate
# len(value) == 3 means segmented wav.scp,
# len(value) == 1 means original wav.scp
@ -53,18 +54,19 @@ class CollateFunc(object):
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
if self.feat_type == 'fbank':
mat = kaldi.fbank(waveform,
num_mel_bins=self.feat_dim,
dither=0.0,
energy_floor=0.0,
sample_frequency=resample_rate)
num_mel_bins=self.feat_dim,
dither=0.0,
energy_floor=0.0,
sample_frequency=resample_rate)
elif self.feat_type == 'mfcc':
mat = kaldi.mfcc(waveform,
num_ceps=self.feat_dim,
num_mel_bins=self.feat_dim,
dither=0.0,
energy_floor=0.0,
sample_frequency=resample_rate,
)
mat = kaldi.mfcc(
waveform,
num_ceps=self.feat_dim,
num_mel_bins=self.feat_dim,
dither=0.0,
energy_floor=0.0,
sample_frequency=resample_rate,
)
mean_stat += torch.sum(mat, axis=0)
var_stat += torch.sum(torch.square(mat), axis=0)
number += mat.shape[0]
@ -104,11 +106,14 @@ if __name__ == '__main__':
with open(args.train_config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type']
feat_dim = configs['dataset_conf']['feature_extraction_conf'][
'num_mel_bins']
feat_type = configs['dataset_conf']['feature_extraction_conf'][
'feature_type']
resample_rate = 0
if 'resample_conf' in configs['dataset_conf']:
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
resample_rate = configs['dataset_conf']['resample_conf'][
'resample_rate']
print('using resample and new sample rate is {}'.format(resample_rate))
collate_func = CollateFunc(feat_dim, feat_type, resample_rate)

View File

@ -4,6 +4,7 @@
import sys
import torchaudio
torchaudio.set_audio_backend("sox_io")
scp = sys.argv[1]