# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
|
|
import yaml
|
|
import torch
|
|
|
|
|
|
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    '''Load weights from ``path`` into ``model`` and return the saved infos.

    The checkpoint is read with ``torch.load`` (mapped to CPU when CUDA is
    not available) and applied with ``model.load_state_dict``. Extra
    information is read from a YAML file next to the checkpoint, whose name
    is ``path`` with a trailing ``.pt`` replaced by ``.yaml``.

    Args:
        model (torch.nn.Module): module that receives the loaded state dict.
        path (str): checkpoint file to read.

    Returns:
        dict: content of the YAML info file, or an empty dict when that
            file does not exist or is empty.
    '''
    if torch.cuda.is_available():
        logging.info('Checkpoint: loading from checkpoint %s for GPU', path)
        checkpoint = torch.load(path)
    else:
        logging.info('Checkpoint: loading from checkpoint %s for CPU', path)
        checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint)

    # NOTE: the '.' in the pattern is unescaped, so any character followed by
    # 'pt' matches. It is kept as-is so that info files written with this
    # naming rule are still found.
    info_path = re.sub('.pt$', '.yaml', path)
    if info_path == path:
        # The substitution was a no-op (no '.pt' suffix). Without this
        # fallback the binary checkpoint itself would be parsed as YAML.
        info_path = path + '.yaml'

    configs = {}
    if os.path.exists(info_path):
        with open(info_path, 'r') as fin:
            loaded = yaml.load(fin, Loader=yaml.FullLoader)
        # An empty YAML document parses to None; keep the dict contract.
        if loaded is not None:
            configs = loaded
    return configs
|
|
|
|
|
|
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    '''Save the weights of ``model`` to ``path`` and ``infos`` to a YAML file.

    For ``DataParallel`` / ``DistributedDataParallel`` wrappers the state
    dict of the wrapped ``model.module`` is saved. The YAML file name is
    ``path`` with a trailing ``.pt`` replaced by ``.yaml``.

    Args:
        model (torch.nn.Module): module whose state dict is saved.
        path (str): checkpoint file to write.
        infos (dict or None): any info you want to save.
    '''
    logging.info('Checkpoint: save to checkpoint %s', path)
    wrappers = (torch.nn.DataParallel,
                torch.nn.parallel.DistributedDataParallel)
    if isinstance(model, wrappers):
        # Save the wrapped module's own state dict rather than the wrapper's.
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path)

    # NOTE: the '.' in the pattern is unescaped, so any character followed by
    # 'pt' matches. It is kept as-is so the naming rule stays compatible with
    # info files that already exist.
    info_path = re.sub('.pt$', '.yaml', path)
    if info_path == path:
        # The substitution was a no-op (no '.pt' suffix). Without this
        # fallback the YAML dump would overwrite the checkpoint just written.
        info_path = path + '.yaml'

    if infos is None:
        infos = {}
    # Serialize before opening so a dump failure does not leave an empty file.
    data = yaml.dump(infos)
    with open(info_path, 'w') as fout:
        fout.write(data)
|