format the code

This commit is contained in:
Jingyong Hou 2021-11-10 22:49:53 +08:00
parent 7df9ced666
commit 0942092426
5 changed files with 28 additions and 20 deletions

View File

@ -56,8 +56,7 @@ def main():
# Export quantized jit torch script model
if args.output_quant_file:
quantized_model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
model, {torch.nn.Linear}, dtype=torch.qint8)
print(quantized_model)
script_quant_model = torch.jit.script(quantized_model)
script_quant_model.save(args.output_quant_file)

View File

@ -135,7 +135,8 @@ def main():
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
input_dim = configs['dataset_conf']['feature_extraction_conf'][
'num_mel_bins']
output_dim = args.num_keywords
# Write model_dir/config.yaml for inference and export

View File

@ -126,6 +126,7 @@ def speed_perturb(data, speeds=None):
yield sample
def compute_mfcc(
data,
feature_type='mfcc',
@ -164,6 +165,7 @@ def compute_mfcc(
)
yield dict(key=sample['key'], label=sample['label'], feat=mat)
def compute_fbank(data,
feature_type='fbank',
num_mel_bins=23,

View File

@ -32,7 +32,8 @@ class CollateFunc(object):
value = item[1].strip().split(",")
assert len(value) == 3 or len(value) == 1
wav_path = value[0]
sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate
sample_rate = torchaudio.backend.sox_io_backend.info(
wav_path).sample_rate
resample_rate = sample_rate
# len(value) == 3 means segmented wav.scp,
# len(value) == 1 means original wav.scp
@ -58,7 +59,8 @@ class CollateFunc(object):
energy_floor=0.0,
sample_frequency=resample_rate)
elif self.feat_type == 'mfcc':
mat = kaldi.mfcc(waveform,
mat = kaldi.mfcc(
waveform,
num_ceps=self.feat_dim,
num_mel_bins=self.feat_dim,
dither=0.0,
@ -104,11 +106,14 @@ if __name__ == '__main__':
with open(args.train_config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins']
feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type']
feat_dim = configs['dataset_conf']['feature_extraction_conf'][
'num_mel_bins']
feat_type = configs['dataset_conf']['feature_extraction_conf'][
'feature_type']
resample_rate = 0
if 'resample_conf' in configs['dataset_conf']:
resample_rate = configs['dataset_conf']['resample_conf']['resample_rate']
resample_rate = configs['dataset_conf']['resample_conf'][
'resample_rate']
print('using resample and new sample rate is {}'.format(resample_rate))
collate_func = CollateFunc(feat_dim, feat_type, resample_rate)

View File

@ -4,6 +4,7 @@
import sys
import torchaudio
torchaudio.set_audio_backend("sox_io")
scp = sys.argv[1]