Merge pull request #11 from wenet-e2e/jingyong-random-seed
add manual random seed so we can reproduce the experimental results
commit 87c42add2a
@@ -41,8 +41,7 @@ model:
 optim: adam
 optim_conf:
   lr: 0.001
+  weight_decay: 5e-5
-  warm_up_step: 2500
-  weight_decay: 0.00005
 
 training_config:
   grad_clip: 5
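For reference, the keys under optim_conf are consumed as keyword arguments of torch.optim.Adam (see the optimizer hunk further down). A minimal sketch of that mapping, assuming the trimmed config above; the inline YAML string and the Linear stand-in model are illustrative only, not code from this repository:

import yaml
import torch
import torch.optim as optim

# Illustrative config text mirroring the YAML hunk above.
config_text = """
optim: adam
optim_conf:
  lr: 0.001
  weight_decay: 5e-5
"""
configs = yaml.load(config_text, Loader=yaml.FullLoader)

model = torch.nn.Linear(8, 2)  # stand-in for the KWS model
# Each key under optim_conf maps onto an Adam keyword argument,
# matching the explicit lr/weight_decay call in the training script below.
optimizer = optim.Adam(model.parameters(),
                       lr=configs['optim_conf']['lr'],
                       weight_decay=configs['optim_conf']['weight_decay'])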
@@ -2,4 +2,4 @@ export PATH=$PWD:$PATH
 
 # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
 export PYTHONIOENCODING=UTF-8
-export PYTHONPATH=../../:$PYTHONPATH
+export PYTHONPATH=../../../:$PYTHONPATH
@@ -30,6 +30,7 @@ from kws.dataset.dataset import Dataset
 from kws.utils.checkpoint import load_checkpoint, save_checkpoint
 from kws.model.kws_model import init_model
 from kws.utils.executor import Executor
+from kws.utils.train_utils import count_parameters, set_mannul_seed
 
 
 def get_args():
@@ -42,6 +43,7 @@ def get_args():
                         default=-1,
                         help='gpu id for this local rank, -1 for cpu')
     parser.add_argument('--model_dir', required=True, help='save model dir')
+    parser.add_argument('--seed', dest='seed', default=777, help='random seed')
     parser.add_argument('--checkpoint', help='checkpoint model')
     parser.add_argument('--tensorboard_dir',
                         default='tensorboard',
@@ -101,8 +103,9 @@ def main():
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
 
+    # Set random seed
-    torch.manual_seed(777)
+    set_mannul_seed(args.seed)
     print(args)
     with open(args.config, 'r') as fin:
         configs = yaml.load(fin, Loader=yaml.FullLoader)
@@ -155,7 +158,7 @@ def main():
     # Init asr model from configs
     model = init_model(configs['model'])
     print(model)
-    num_params = sum(p.numel() for p in model.parameters())
+    num_params = count_parameters(model)
     print('the number of model params: {}'.format(num_params))
 
     # !!!IMPORTANT!!!
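Note that count_parameters (defined in the new kws/utils/train_utils.py further down) counts only parameters with requires_grad=True, whereas the replaced expression counted every parameter. A small illustrative sketch of the difference, using a throwaway model that is not part of this repository:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].weight.requires_grad = False  # freeze one weight matrix (16 values)

total = sum(p.numel() for p in model.parameters())                         # old behaviour: 30
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)  # new behaviour: 14
print(total, trainable)  # 30 14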
@@ -192,7 +195,9 @@ def main():
     device = torch.device('cuda' if use_cuda else 'cpu')
     model = model.to(device)
 
-    optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
+    optimizer = optim.Adam(model.parameters(),
+                           lr=configs['optim_conf']['lr'],
+                           weight_decay=configs['optim_conf']['weight_decay'])
     scheduler = optim.lr_scheduler.ReduceLROnPlateau(
         optimizer,
         mode='min',
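For context, ReduceLROnPlateau is driven by a validation metric rather than by a step count. A hedged sketch of the usual pattern; the loop and the cv_loss placeholder below are assumptions for illustration, not code from this repository:

import torch
import torch.optim as optim

model = torch.nn.Linear(8, 2)  # stand-in model
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

for epoch in range(3):
    cv_loss = 1.0 / (epoch + 1)  # placeholder for the real cross-validation loss
    # The scheduler lowers the learning rate when cv_loss stops improving.
    scheduler.step(cv_loss)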
kws/utils/train_utils.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import numpy as np
+import random
+
+
+def set_mannul_seed(seed):
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
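A quick way to sanity-check the new helper is to seed twice and compare draws from each RNG it touches. This is an illustrative snippet, not part of the commit; it assumes the repository root is on PYTHONPATH, as path.sh arranges above:

import random
import numpy as np
import torch

from kws.utils.train_utils import set_mannul_seed

def draw():
    # One sample from each RNG that set_mannul_seed controls.
    return (random.random(), float(np.random.rand()), float(torch.rand(1)))

set_mannul_seed(777)
a = draw()
set_mannul_seed(777)
b = draw()
assert a == b  # identical draws after re-seeding with the same value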