# Source code for face01lib.damo_yolo.damo_internal.utils.checkpoint

#!/usr/bin/env python3
# Copyright (c) Megvii Inc. All rights reserved.
import os
import shutil

import torch
from loguru import logger


def load_ckpt(model, ckpt):
    """Load the compatible subset of ``ckpt`` into ``model`` and return it.

    Iterates over the keys of ``model.state_dict()``:

    * a key missing from ``ckpt`` is skipped with a warning;
    * a key whose checkpoint tensor has a different ``.shape`` than the
      model's tensor is skipped with a warning;
    * every other key is collected and loaded via
      ``model.load_state_dict(..., strict=False)``.

    Keys that exist only in ``ckpt`` are never looked at, so they are ignored
    silently.

    Args:
        model: module exposing ``state_dict()`` and
            ``load_state_dict(state, strict=...)`` (presumably a
            ``torch.nn.Module``).
        ckpt: mapping from parameter/buffer name to a tensor-like value with
            a ``.shape`` attribute (i.e. a state dict).

    Returns:
        The same ``model`` object, updated in place by ``load_state_dict``.
    """
    model_state_dict = model.state_dict()
    load_dict = {}
    for key_model, v in model_state_dict.items():
        if key_model not in ckpt:
            # Implicit literal concatenation keeps the logged text free of
            # source indentation (a backslash continuation inside the literal
            # would embed it). Arguments are passed separately so loguru
            # formats them with str.format.
            logger.warning(
                '{} is not in the ckpt. '
                'Please double check and see if this is desired.',
                key_model)
            continue
        v_ckpt = ckpt[key_model]
        if v.shape != v_ckpt.shape:
            logger.warning(
                'Shape of {} in checkpoint is {}, '
                'while shape of {} in model is {}.',
                key_model, v_ckpt.shape, key_model, v.shape)
            continue
        load_dict[key_model] = v_ckpt

    # strict=False: the keys skipped above are intentionally absent from
    # load_dict and must not make loading fail.
    model.load_state_dict(load_dict, strict=False)
    return model


def save_checkpoint(state, is_best, save_dir, model_name=''):
    """Write ``state`` to ``save_dir`` and optionally refresh the best copy.

    The object is written with ``torch.save`` to
    ``<save_dir>/<model_name>_ckpt.pth``. When ``is_best`` is true, that file
    is also copied to ``<save_dir>/best_ckpt.pth``.

    Args:
        state: any object ``torch.save`` can serialize; not inspected here.
        is_best: if true, also copy the checkpoint to ``best_ckpt.pth``.
        save_dir: target directory, created (including parents) if missing.
            An empty string means the current working directory.
        model_name: filename prefix. The default ``''`` produces
            ``_ckpt.pth``.
    """
    # exist_ok removes the exists()/makedirs() race (another process may
    # create the directory in between). makedirs('') would raise, so an empty
    # save_dir is treated as "current directory".
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, model_name + '_ckpt.pth')
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(save_dir, 'best_ckpt.pth')
        # With model_name == 'best' the checkpoint already is best_ckpt.pth;
        # copying a file onto itself raises shutil.SameFileError.
        if filename != best_filename:
            shutil.copyfile(filename, best_filename)