diff --git a/examples/hi_xiaowen/s0/conf/mdtc_small.yaml b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml
index 937f1ce..3977f90 100644
--- a/examples/hi_xiaowen/s0/conf/mdtc_small.yaml
+++ b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml
@@ -41,8 +41,7 @@ model:
 optim: adam
 optim_conf:
   lr: 0.001
-  weight_decay: 5e-5
-  warm_up_step: 2500
+  weight_decay: 0.00005
 
 training_config:
   grad_clip: 5
diff --git a/examples/hi_xiaowen/s0/path.sh b/examples/hi_xiaowen/s0/path.sh
index cf09584..b90a515 100755
--- a/examples/hi_xiaowen/s0/path.sh
+++ b/examples/hi_xiaowen/s0/path.sh
@@ -2,4 +2,4 @@ export PATH=$PWD:$PATH
 
 # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
 export PYTHONIOENCODING=UTF-8
-export PYTHONPATH=../../:$PYTHONPATH
+export PYTHONPATH=../../../:$PYTHONPATH
diff --git a/kws/bin/train.py b/kws/bin/train.py
index 915f176..df5e0fb 100644
--- a/kws/bin/train.py
+++ b/kws/bin/train.py
@@ -30,6 +30,7 @@ from kws.dataset.dataset import Dataset
 from kws.utils.checkpoint import load_checkpoint, save_checkpoint
 from kws.model.kws_model import init_model
 from kws.utils.executor import Executor
+from kws.utils.train_utils import count_parameters, set_mannul_seed
 
 
 def get_args():
@@ -42,6 +43,7 @@ def get_args():
                         default=-1,
                         help='gpu id for this local rank, -1 for cpu')
     parser.add_argument('--model_dir', required=True, help='save model dir')
+    parser.add_argument('--seed', dest='seed', type=int, default=777, help='random seed')
     parser.add_argument('--checkpoint', help='checkpoint model')
     parser.add_argument('--tensorboard_dir',
                         default='tensorboard',
@@ -101,6 +103,7 @@ def main():
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+    set_mannul_seed(args.seed)
     # Set random seed
     torch.manual_seed(777)
     print(args)
@@ -155,7 +158,7 @@ def main():
     # Init asr model from configs
     model = init_model(configs['model'])
     print(model)
-    num_params = sum(p.numel() for p in model.parameters())
+    num_params = count_parameters(model)
     print('the number of model params: {}'.format(num_params))
 
     # !!!IMPORTANT!!!
@@ -192,7 +195,9 @@ def main():
 
     device = torch.device('cuda' if use_cuda else 'cpu')
     model = model.to(device)
-    optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
+    optimizer = optim.Adam(model.parameters(),
+                           lr=configs['optim_conf']['lr'],
+                           weight_decay=configs['optim_conf']['weight_decay'])
     scheduler = optim.lr_scheduler.ReduceLROnPlateau(
         optimizer,
         mode='min',
diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py
index ab543dd..454aff6 100644
--- a/kws/model/kws_model.py
+++ b/kws/model/kws_model.py
@@ -56,7 +56,6 @@ class KWSModel(torch.nn.Module):
         x = self.preprocessing(x)
         x, _ = self.backbone(x)
         x = self.classifier(x)
-        x = torch.sigmoid(x)
         return x
 
 
diff --git a/kws/model/loss.py b/kws/model/loss.py
index d73722c..57f2670 100644
--- a/kws/model/loss.py
+++ b/kws/model/loss.py
@@ -37,6 +37,7 @@ def max_polling_loss(logits: torch.Tensor,
         (float): loss of current batch
         (float): accuracy of current batch
     '''
+    logits = torch.sigmoid(logits)
     mask = padding_mask(lengths)
     num_utts = logits.size(0)
     num_keywords = logits.size(2)
diff --git a/kws/utils/train_utils.py b/kws/utils/train_utils.py
new file mode 100644
index 0000000..9641d24
--- /dev/null
+++ b/kws/utils/train_utils.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import numpy as np
+import random
+
+
+def set_mannul_seed(seed):
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
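For reference, a minimal usage sketch of the two helpers introduced above in kws/utils/train_utils.py; the torch.nn.Linear module and the literal seed are placeholders for illustration only, while the actual caller is kws/bin/train.py, which passes the parsed --seed value and the model built by init_model:

    import torch
    from kws.utils.train_utils import count_parameters, set_mannul_seed

    # Seed Python's random, NumPy and torch (CPU and CUDA) and enable deterministic cuDNN.
    set_mannul_seed(777)

    # Stand-in module; train.py uses init_model(configs['model']) instead.
    model = torch.nn.Linear(80, 4)
    print('the number of model params: {}'.format(count_parameters(model)))

Note that count_parameters only counts parameters with requires_grad=True, unlike the removed inline expression, which counted every parameter, so the reported number can be smaller for models with frozen weights.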