add mannul random seed so we can reproduce the experimental results
This commit is contained in:
parent
edfc6de743
commit
9aaa4fc26c
@ -41,8 +41,7 @@ model:
|
|||||||
optim: adam
|
optim: adam
|
||||||
optim_conf:
|
optim_conf:
|
||||||
lr: 0.001
|
lr: 0.001
|
||||||
weight_decay: 5e-5
|
weight_decay: 0.00005
|
||||||
warm_up_step: 2500
|
|
||||||
|
|
||||||
training_config:
|
training_config:
|
||||||
grad_clip: 5
|
grad_clip: 5
|
||||||
|
|||||||
@ -2,4 +2,4 @@ export PATH=$PWD:$PATH
|
|||||||
|
|
||||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||||
export PYTHONIOENCODING=UTF-8
|
export PYTHONIOENCODING=UTF-8
|
||||||
export PYTHONPATH=../../:$PYTHONPATH
|
export PYTHONPATH=../../../:$PYTHONPATH
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from kws.dataset.dataset import Dataset
|
|||||||
from kws.utils.checkpoint import load_checkpoint, save_checkpoint
|
from kws.utils.checkpoint import load_checkpoint, save_checkpoint
|
||||||
from kws.model.kws_model import init_model
|
from kws.model.kws_model import init_model
|
||||||
from kws.utils.executor import Executor
|
from kws.utils.executor import Executor
|
||||||
|
from kws.utils.train_utils import count_parameters, set_mannul_seed
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@ -42,6 +43,7 @@ def get_args():
|
|||||||
default=-1,
|
default=-1,
|
||||||
help='gpu id for this local rank, -1 for cpu')
|
help='gpu id for this local rank, -1 for cpu')
|
||||||
parser.add_argument('--model_dir', required=True, help='save model dir')
|
parser.add_argument('--model_dir', required=True, help='save model dir')
|
||||||
|
parser.add_argument('--seed', dest='seed', default=777, help='random seed')
|
||||||
parser.add_argument('--checkpoint', help='checkpoint model')
|
parser.add_argument('--checkpoint', help='checkpoint model')
|
||||||
parser.add_argument('--tensorboard_dir',
|
parser.add_argument('--tensorboard_dir',
|
||||||
default='tensorboard',
|
default='tensorboard',
|
||||||
@ -101,6 +103,7 @@ def main():
|
|||||||
logging.basicConfig(level=logging.DEBUG,
|
logging.basicConfig(level=logging.DEBUG,
|
||||||
format='%(asctime)s %(levelname)s %(message)s')
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
||||||
|
set_mannul_seed(args.gpu)
|
||||||
# Set random seed
|
# Set random seed
|
||||||
torch.manual_seed(777)
|
torch.manual_seed(777)
|
||||||
print(args)
|
print(args)
|
||||||
@ -155,7 +158,7 @@ def main():
|
|||||||
# Init asr model from configs
|
# Init asr model from configs
|
||||||
model = init_model(configs['model'])
|
model = init_model(configs['model'])
|
||||||
print(model)
|
print(model)
|
||||||
num_params = sum(p.numel() for p in model.parameters())
|
num_params = count_parameters(model)
|
||||||
print('the number of model params: {}'.format(num_params))
|
print('the number of model params: {}'.format(num_params))
|
||||||
|
|
||||||
# !!!IMPORTANT!!!
|
# !!!IMPORTANT!!!
|
||||||
@ -192,7 +195,9 @@ def main():
|
|||||||
device = torch.device('cuda' if use_cuda else 'cpu')
|
device = torch.device('cuda' if use_cuda else 'cpu')
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
|
optimizer = optim.Adam(model.parameters(),
|
||||||
|
lr=configs['optim_conf']['lr'],
|
||||||
|
weight_decay=configs['optim_conf']['weight_decay'])
|
||||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
optimizer,
|
optimizer,
|
||||||
mode='min',
|
mode='min',
|
||||||
|
|||||||
@ -56,7 +56,6 @@ class KWSModel(torch.nn.Module):
|
|||||||
x = self.preprocessing(x)
|
x = self.preprocessing(x)
|
||||||
x, _ = self.backbone(x)
|
x, _ = self.backbone(x)
|
||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
x = torch.sigmoid(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -37,6 +37,7 @@ def max_polling_loss(logits: torch.Tensor,
|
|||||||
(float): loss of current batch
|
(float): loss of current batch
|
||||||
(float): accuracy of current batch
|
(float): accuracy of current batch
|
||||||
'''
|
'''
|
||||||
|
logits = torch.sigmoid(logits)
|
||||||
mask = padding_mask(lengths)
|
mask = padding_mask(lengths)
|
||||||
num_utts = logits.size(0)
|
num_utts = logits.size(0)
|
||||||
num_keywords = logits.size(2)
|
num_keywords = logits.size(2)
|
||||||
|
|||||||
30
kws/utils/train_utils.py
Normal file
30
kws/utils/train_utils.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def set_mannul_seed(seed):
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
Loading…
x
Reference in New Issue
Block a user