diff --git a/kws/bin/export_jit.py b/kws/bin/export_jit.py index 7c1846e..81f48f9 100644 --- a/kws/bin/export_jit.py +++ b/kws/bin/export_jit.py @@ -56,8 +56,7 @@ def main(): # Export quantized jit torch script model if args.output_quant_file: quantized_model = torch.quantization.quantize_dynamic( - model, {torch.nn.Linear}, dtype=torch.qint8 - ) + model, {torch.nn.Linear}, dtype=torch.qint8) print(quantized_model) script_quant_model = torch.jit.script(quantized_model) script_quant_model.save(args.output_quant_file) diff --git a/kws/bin/train.py b/kws/bin/train.py index 55c87ec..00c367b 100644 --- a/kws/bin/train.py +++ b/kws/bin/train.py @@ -135,7 +135,8 @@ def main(): num_workers=args.num_workers, prefetch_factor=args.prefetch) - input_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins'] + input_dim = configs['dataset_conf']['feature_extraction_conf'][ + 'num_mel_bins'] output_dim = args.num_keywords # Write model_dir/config.yaml for inference and export @@ -161,8 +162,8 @@ def main(): # Try to export the model by script, if fails, we should refine # the code to satisfy the script export requirements #if args.rank == 0: - #script_model = torch.jit.script(model) - #script_model.save(os.path.join(args.model_dir, 'init.zip')) + #script_model = torch.jit.script(model) + #script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() # If specify checkpoint, load some info from checkpoint if args.checkpoint is not None: diff --git a/kws/dataset/processor.py b/kws/dataset/processor.py index 9108b23..0fd1d84 100644 --- a/kws/dataset/processor.py +++ b/kws/dataset/processor.py @@ -126,6 +126,7 @@ def speed_perturb(data, speeds=None): yield sample + def compute_mfcc( data, feature_type='mfcc', @@ -164,6 +165,7 @@ def compute_mfcc( ) yield dict(key=sample['key'], label=sample['label'], feat=mat) + def compute_fbank(data, feature_type='fbank', num_mel_bins=23, diff --git a/tools/compute_cmvn_stats.py b/tools/compute_cmvn_stats.py index be8df99..95b7e0c 100755 --- a/tools/compute_cmvn_stats.py +++ b/tools/compute_cmvn_stats.py @@ -32,7 +32,8 @@ class CollateFunc(object): value = item[1].strip().split(",") assert len(value) == 3 or len(value) == 1 wav_path = value[0] - sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate + sample_rate = torchaudio.backend.sox_io_backend.info( + wav_path).sample_rate resample_rate = sample_rate # len(value) == 3 means segmented wav.scp, # len(value) == 1 means original wav.scp @@ -53,18 +54,19 @@ class CollateFunc(object): orig_freq=sample_rate, new_freq=resample_rate)(waveform) if self.feat_type == 'fbank': mat = kaldi.fbank(waveform, - num_mel_bins=self.feat_dim, - dither=0.0, - energy_floor=0.0, - sample_frequency=resample_rate) + num_mel_bins=self.feat_dim, + dither=0.0, + energy_floor=0.0, + sample_frequency=resample_rate) elif self.feat_type == 'mfcc': - mat = kaldi.mfcc(waveform, - num_ceps=self.feat_dim, - num_mel_bins=self.feat_dim, - dither=0.0, - energy_floor=0.0, - sample_frequency=resample_rate, - ) + mat = kaldi.mfcc( + waveform, + num_ceps=self.feat_dim, + num_mel_bins=self.feat_dim, + dither=0.0, + energy_floor=0.0, + sample_frequency=resample_rate, + ) mean_stat += torch.sum(mat, axis=0) var_stat += torch.sum(torch.square(mat), axis=0) number += mat.shape[0] @@ -104,11 +106,14 @@ if __name__ == '__main__': with open(args.train_config, 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) - feat_dim = configs['dataset_conf']['feature_extraction_conf']['num_mel_bins'] - feat_type = configs['dataset_conf']['feature_extraction_conf']['feature_type'] + feat_dim = configs['dataset_conf']['feature_extraction_conf'][ + 'num_mel_bins'] + feat_type = configs['dataset_conf']['feature_extraction_conf'][ + 'feature_type'] resample_rate = 0 if 'resample_conf' in configs['dataset_conf']: - resample_rate = configs['dataset_conf']['resample_conf']['resample_rate'] + resample_rate = configs['dataset_conf']['resample_conf'][ + 'resample_rate'] print('using resample and new sample rate is {}'.format(resample_rate)) collate_func = CollateFunc(feat_dim, feat_type, resample_rate) diff --git a/tools/wav2dur.py b/tools/wav2dur.py index 1bcc1b6..b53a7fe 100755 --- a/tools/wav2dur.py +++ b/tools/wav2dur.py @@ -4,6 +4,7 @@ import sys import torchaudio + torchaudio.set_audio_backend("sox_io") scp = sys.argv[1]