[wekws] fix log (only one process print model)

This commit is contained in:
di.wu 2023-10-07 15:46:22 +08:00
parent 6ae98ef111
commit d618951e87

View File

@ -141,15 +141,14 @@ def main():
configs['model']['cmvn'] = {} configs['model']['cmvn'] = {}
configs['model']['cmvn']['norm_var'] = args.norm_var configs['model']['cmvn']['norm_var'] = args.norm_var
configs['model']['cmvn']['cmvn_file'] = args.cmvn_file configs['model']['cmvn']['cmvn_file'] = args.cmvn_file
# Init asr model from configs
model = init_model(configs['model'])
if rank == 0: if rank == 0:
saved_config_path = os.path.join(args.model_dir, 'config.yaml') saved_config_path = os.path.join(args.model_dir, 'config.yaml')
with open(saved_config_path, 'w') as fout: with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs) data = yaml.dump(configs)
fout.write(data) fout.write(data)
print(model)
# Init asr model from configs
model = init_model(configs['model'])
print(model)
num_params = count_parameters(model) num_params = count_parameters(model)
print('the number of model params: {}'.format(num_params)) print('the number of model params: {}'.format(num_params))