# Copyright (c) 2020 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os

import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

from kws.dataset.dataset import Dataset
from kws.model.kws_model import init_model
from kws.utils.checkpoint import load_checkpoint, save_checkpoint
from kws.utils.executor import Executor


def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this local rank, -1 for cpu')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--ddp.rank',
                        dest='rank',
                        default=0,
                        type=int,
                        help='global rank for distributed training')
    parser.add_argument('--ddp.world_size',
                        dest='world_size',
                        default=-1,
                        type=int,
                        help='number of total processes/gpus for '
                        'distributed training')
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    parser.add_argument('--ddp.init_method',
                        dest='init_method',
                        default=None,
                        help='ddp init method')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='use pinned memory buffers for reading')
    parser.add_argument('--cmvn_file', default=None, help='global cmvn file')
    parser.add_argument('--norm_var',
                        action='store_true',
                        default=False,
                        help='normalize variance in cmvn')
    parser.add_argument('--num_keywords',
                        default=1,
                        type=int,
                        help='number of keywords')
    parser.add_argument('--min_duration',
                        default=50,
                        type=int,
                        help='min duration frames of the keyword')
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')

    args = parser.parse_args()
    return args
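

# A typical single-GPU invocation looks like the following; the config and
# data paths are illustrative, not shipped with this script:
#
#   python train.py --config conf/config.yaml \
#       --train_data data/train/data.list \
#       --cv_data data/dev/data.list \
#       --model_dir exp/kws --gpu 0 --num_workers 8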


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # Bind this process to the selected GPU (-1 hides all GPUs, i.e. CPU)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Set random seed
    torch.manual_seed(777)
    print(args)
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
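
    # world_size defaults to -1, so DDP is only engaged when the launcher
    # explicitly passes --ddp.world_size > 1.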
    distributed = args.world_size > 1
    if distributed:
        logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
        dist.init_process_group(args.dist_backend,
                                init_method=args.init_method,
                                world_size=args.world_size,
                                rank=args.rank)
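    # One process per GPU is assumed: each process passes a unique
    # --ddp.rank, and --ddp.init_method may be any rendezvous URL accepted
    # by torch.distributed, e.g. tcp:// or file://.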

    train_conf = configs['dataset_conf']
    cv_conf = copy.deepcopy(train_conf)
    cv_conf['speed_perturb'] = False
    cv_conf['spec_aug'] = False
    cv_conf['shuffle'] = False

    train_dataset = Dataset(args.train_data, train_conf)
    cv_dataset = Dataset(args.cv_data, cv_conf)

    # prefetch_factor is only legal when multiprocessing is enabled, so pass
    # it to the DataLoader only when num_workers > 0 (PyTorch raises a
    # ValueError otherwise).
    loader_conf = {}
    if args.num_workers > 0:
        loader_conf['prefetch_factor'] = args.prefetch
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=None,
                                   pin_memory=args.pin_memory,
                                   num_workers=args.num_workers,
                                   **loader_conf)
    cv_data_loader = DataLoader(cv_dataset,
                                batch_size=None,
                                pin_memory=args.pin_memory,
                                num_workers=args.num_workers,
                                **loader_conf)
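    # batch_size=None disables automatic batching here: the Dataset pipeline
    # is expected to yield already-batched tensors, so the DataLoader only
    # adds worker processes, memory pinning and prefetching on top.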

    input_dim = configs['dataset_conf']['feature_extraction_conf'][
        'num_mel_bins']
    output_dim = args.num_keywords

    # Write model_dir/config.yaml for inference and export
    configs['model']['input_dim'] = input_dim
    configs['model']['output_dim'] = output_dim
    if args.cmvn_file is not None:
        configs['model']['cmvn'] = {}
        configs['model']['cmvn']['norm_var'] = args.norm_var
        configs['model']['cmvn']['cmvn_file'] = args.cmvn_file
    if args.rank == 0:
        # Make sure the output dir exists before the first write into it
        os.makedirs(args.model_dir, exist_ok=True)
        saved_config_path = os.path.join(args.model_dir, 'config.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)
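    # Only rank 0 writes config.yaml, so concurrent DDP processes never race
    # on the file; the saved copy records input_dim, output_dim and the CMVN
    # settings that inference and export need later.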

    # Init kws model from configs
    model = init_model(configs['model'])
    print(model)
    num_params = sum(p.numel() for p in model.parameters())
    print('the number of model params: {}'.format(num_params))

    # !!!IMPORTANT!!!
    # Try to export the model by script. If it fails, we should refine
    # the code to satisfy the script export requirements.
    # if args.rank == 0:
    #     script_model = torch.jit.script(model)
    #     script_model.save(os.path.join(args.model_dir, 'init.zip'))
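    # Scripting the freshly initialized model is a cheap smoke test: if
    # torch.jit.script() fails here, it will also fail at deployment time.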

    executor = Executor()
    # If a checkpoint is specified, resume training info from it
    if args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    else:
        infos = {}
    start_epoch = infos.get('epoch', -1) + 1
    cv_loss = infos.get('cv_loss', 0.0)
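    # Resuming continues from the epoch after the checkpointed one; without
    # a checkpoint, start_epoch falls back to 0 and training starts fresh.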

    model_dir = args.model_dir
    writer = None
    if args.rank == 0:
        os.makedirs(model_dir, exist_ok=True)
        exp_id = os.path.basename(model_dir)
        writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))

    if distributed:
        assert torch.cuda.is_available()
        # cuda model is required for nn.parallel.DistributedDataParallel
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
        device = torch.device('cuda')
    else:
        use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
        model = model.to(device)
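    # In the DDP branch above, find_unused_parameters=True lets the wrapper
    # tolerate model branches that do not contribute to the loss on every
    # step, at some extra synchronization cost.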

    optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        threshold=0.01,
    )
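    # ReduceLROnPlateau halves the learning rate once cv_loss has failed to
    # improve by the 1% relative threshold for more than 3 epochs, bottoming
    # out at min_lr.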

    training_config = configs['training_config']
    training_config['min_duration'] = args.min_duration
    num_epochs = training_config.get('max_epoch', 100)
    final_epoch = None
    if start_epoch == 0 and args.rank == 0:
        save_model_path = os.path.join(model_dir, 'init.pt')
        save_checkpoint(model, save_model_path)

    # Start training loop
    for epoch in range(start_epoch, num_epochs):
        train_dataset.set_epoch(epoch)
        training_config['epoch'] = epoch
        lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
        executor.train(model, optimizer, train_data_loader, device, writer,
                       training_config)
        cv_loss = executor.cv(model, cv_data_loader, device, training_config)
        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))

        if args.rank == 0:
            save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
            save_checkpoint(model, save_model_path, {
                'epoch': epoch,
                'lr': lr,
                'cv_loss': cv_loss,
            })
            writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
            writer.add_scalar('epoch/lr', lr, epoch)
        final_epoch = epoch
        scheduler.step(cv_loss)
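    # Note: scheduler.step(cv_loss) runs on every rank, so learning rates
    # stay in sync across processes even though only rank 0 writes
    # checkpoints and TensorBoard scalars.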

    if final_epoch is not None and args.rank == 0:
        # Point final.pt at the last epoch checkpoint. The target is
        # relative, so model_dir stays relocatable; note that os.symlink
        # fails if final.pt already exists from a previous run.
        final_model_path = os.path.join(model_dir, 'final.pt')
        os.symlink('{}.pt'.format(final_epoch), final_model_path)
        writer.close()


if __name__ == '__main__':
    main()