# Source code for face01lib.damo_yolo.damo_internal.utils.model_utils

#!/usr/bin/env python3
# Copyright (c) Megvii Inc. All rights reserved.
# Copyright (C) Alibaba Group Holding Limited. All rights reserved.

import time
from copy import deepcopy

import torch
import torch.nn as nn
from thop import profile

# Public names exported by ``from <this module> import *``.
# ``get_latency`` is not listed, so star-imports do not pick it up.
__all__ = [
    'fuse_conv_and_bn',
    'fuse_model',
    'get_model_info',
    'replace_module',
    'make_divisible'
]

def make_divisible(v, divisor=8, min_value=None):
    """Snap ``v`` onto a multiple of ``divisor``.

    ``v + divisor / 2`` is truncated with ``int`` and floored to a multiple
    of ``divisor`` (round-half-up for non-negative ``v``).  The result is
    clamped from below by ``min_value`` (``divisor`` when not given).  If the
    clamped value is still below 90% of ``v``, one more ``divisor`` is added.

    Args:
        v: value to adjust.
        divisor: rounding step.
        min_value: lower bound for the result; ``None`` means ``divisor``.

    Returns:
        The adjusted value.
    """
    lower_bound = divisor if min_value is None else min_value
    snapped = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, snapped)
    # Do not let rounding shrink the value by more than 10%.
    if result < 0.9 * v:
        return result + divisor
    return result
def get_latency(model, inp, iters=500, warmup=2, timer=None):
    """Time repeated forward passes of ``model`` on ``inp``.

    ``model(inp)`` is called ``iters`` times under ``torch.no_grad()``.  The
    first ``warmup`` calls are excluded from timing; the remaining
    ``iters - warmup`` calls are timed and averaged.  When CUDA is available,
    ``torch.cuda.synchronize()`` runs after every pass so that queued GPU
    work is included in the measurement.

    Args:
        model: callable invoked as ``model(inp)`` (e.g. an ``nn.Module``).
        inp: argument passed to ``model``.
        iters: total number of forward passes, including warm-up passes.
        warmup: number of leading passes excluded from timing.
        timer: zero-argument clock returning seconds; defaults to
            ``time.perf_counter``.

    Returns:
        tuple: ``(out, latency)`` -- the output of the last pass and the mean
        number of seconds per timed pass.

    Raises:
        ValueError: if ``warmup`` is negative or ``iters <= warmup`` (no timed
            pass would remain).
    """
    if timer is None:
        timer = time.perf_counter
    if warmup < 0:
        raise ValueError(f'warmup must be non-negative, got {warmup}')
    if iters <= warmup:
        raise ValueError(
            f'iters ({iters}) must be greater than warmup ({warmup})')

    sync = torch.cuda.is_available()
    with torch.no_grad():
        start = timer()
        for i in range(iters):
            out = model(inp)
            if sync:
                torch.cuda.synchronize()
            # Restart the clock after each warm-up pass so that exactly
            # `warmup` passes are excluded and the divisor below equals the
            # number of timed passes.
            if i < warmup:
                start = timer()
        latency = (timer() - start) / (iters - warmup)
    return out, latency
def get_model_info(model, tsize):
    """Summarise parameters, FLOPs and latency of a detector's stages.

    The model is switched to eval mode and left in it.  A random
    ``(1, 3, 640, 640)`` tensor is created on the device of the model's first
    parameter and passed through ``backbone``, ``neck`` and ``head`` in turn,
    each stage receiving the previous stage's output.  Each stage is profiled
    with ``thop.profile`` on a deep copy and timed with ``get_latency``
    (10 passes).  The whole model is timed with ``get_latency``'s defaults.
    All forward passes run under ``torch.no_grad()``.

    FLOPs measured at 640x640 are rescaled to ``tsize`` by area and doubled.
    thop counts multiply-accumulates, hence the factor of 2.

    Args:
        model (nn.Module): must expose ``backbone``, ``neck`` and ``head``
            submodules that are chained in that order.
        tsize (tuple[int, int]): target input size ``(h, w)``.

    Returns:
        str: one line per stage with params (M), FLOPs (G) and latency (ms),
        followed by a totals line.
    """
    profile_size = 640  # square resolution at which FLOPs are profiled
    model = model.eval()
    img = torch.randn((1, 3, profile_size, profile_size),
                      device=next(model.parameters()).device)
    # Area ratio from the profiled resolution to ``tsize``; the factor 2
    # converts multiply-accumulates to FLOPs.
    flops_scale = tsize[0] * tsize[1] / profile_size / profile_size * 2

    # Profile without autograd so timings and memory reflect inference only.
    stage_stats = []
    with torch.no_grad():
        stage_input = img
        for name in ('backbone', 'neck', 'head'):
            stage = getattr(model, name)
            macs, n_params = profile(deepcopy(stage), inputs=(stage_input, ),
                                     verbose=False)
            # The timed stage output becomes the next stage's input.
            stage_input, latency = get_latency(stage, stage_input, iters=10)
            stage_stats.append((name, macs, n_params, latency))
        _, total_latency = get_latency(model, img)

    total_flops = 0
    total_params = 0
    lines = []
    for name, macs, n_params, latency in stage_stats:
        params_m = n_params / 1e6
        gflops = macs / 1e9 * flops_scale
        total_flops += gflops
        total_params += params_m
        lines.append(f"{name}'s params(M): {params_m:.2f}, "
                     f'flops(G): {gflops:.2f}, '
                     f'latency(ms): {latency * 1000:.3f}\n')
    lines.append(f'total latency(ms): {total_latency * 1000:.3f}, '
                 f'total flops(G): {total_flops:.2f}, '
                 f'total params(M): {total_params:.2f}\n')
    return ''.join(lines)
def fuse_conv_and_bn(conv, bn):
    """Fold a batch-norm layer into the preceding convolution.

    Returns a new ``nn.Conv2d`` with a bias and ``requires_grad`` disabled.
    When ``bn`` normalises with its running statistics (eval mode), its
    output equals ``bn(conv(x))``.
    Reference: https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    With ``s = bn.weight / sqrt(bn.running_var + bn.eps)``, for each output
    channel ``c``::

        W'[c] = s[c] * conv.weight[c]
        b'[c] = s[c] * conv.bias[c] + bn.bias[c] - s[c] * bn.running_mean[c]

    Args:
        conv (nn.Conv2d): convolution to fuse.  Its kernel size, stride,
            padding, dilation, groups and padding mode are carried over.
        bn: batch-norm applied to ``conv``'s output.  It must provide
            ``weight``, ``bias``, ``running_mean``, ``running_var`` and
            ``eps``.

    Returns:
        nn.Conv2d: the fused convolution, on ``conv``'s device and dtype.
    """
    fusedconv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
        padding_mode=getattr(conv, 'padding_mode', 'zeros'),
    ).requires_grad_(False).to(device=conv.weight.device,
                               dtype=conv.weight.dtype)

    # Fusion is plain parameter arithmetic; keep autograd out of it so the
    # fused parameters remain leaf tensors that do not require grad.
    with torch.no_grad():
        # Filters: scale each output channel's kernel by the BN factor s[c],
        # expressed as a diagonal matrix times the flattened kernels.
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(
            torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

        # Bias: scaled conv bias (zero if absent) plus the BN shift.
        if conv.bias is None:
            b_conv = torch.zeros(conv.weight.size(0),
                                 device=conv.weight.device,
                                 dtype=conv.weight.dtype)
        else:
            b_conv = conv.bias
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
            torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(
            torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
def fuse_model(model):
    """Fuse conv + batch-norm pairs inside the project's conv blocks, in place.

    Every submodule whose exact type is ``ConvBNAct`` (and still has ``bn``)
    or ``ConvKXBN`` (and still has ``bn1``) is processed as follows:
    - its convolution is replaced by ``fuse_conv_and_bn(conv, bn)``;
    - its batch-norm attribute is deleted;
    - its ``forward`` is rebound to its ``fuseforward`` method.

    Modules without the batch-norm attribute (e.g. already fused) are
    skipped.

    Args:
        model (nn.Module): model to fuse; its submodules are modified in place.

    Returns:
        nn.Module: the same ``model`` object.
    """
    from .base_models.core.ops import ConvBNAct
    from .base_models.backbones.tinynas_res import ConvKXBN

    # exact module type -> (conv attribute name, batch-norm attribute name)
    fusable = {
        ConvBNAct: ('conv', 'bn'),
        ConvKXBN: ('conv1', 'bn1'),
    }
    for module in model.modules():
        names = fusable.get(type(module))
        if names is None:
            continue
        conv_name, bn_name = names
        if not hasattr(module, bn_name):
            continue
        fused = fuse_conv_and_bn(getattr(module, conv_name),
                                 getattr(module, bn_name))
        setattr(module, conv_name, fused)  # update conv
        delattr(module, bn_name)  # remove batchnorm
        module.forward = module.fuseforward  # update forward
    return model
def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
    """Recursively replace submodules of a given type with a new type.

    Mostly used in deployment.  If ``module`` itself is an instance of
    ``replaced_module_type``, the replacement is returned directly.
    Otherwise its children are visited recursively, and every replaced child
    is re-registered on ``module`` (modified in place) under the same name.

    Args:
        module (nn.Module): model to apply the replace operation to.
        replaced_module_type (Type): module type to be replaced.
        new_module_type (Type): module type to replace it with.
        replace_func (function): called as
            ``replace_func(replaced_module_type, new_module_type)`` to build
            each replacement.  The same function is used at every level of
            the recursion.  Default ``None`` instantiates
            ``new_module_type()``.

    Returns:
        model (nn.Module): the replacement if ``module`` itself matched,
        otherwise ``module`` with its matching descendants replaced.
    """
    def default_replace_func(replaced_module_type, new_module_type):
        return new_module_type()

    if replace_func is None:
        replace_func = default_replace_func

    if isinstance(module, replaced_module_type):
        return replace_func(replaced_module_type, new_module_type)

    for name, child in module.named_children():
        # Forward replace_func so custom replacement logic also applies to
        # nested modules, not just to the top-level call.
        new_child = replace_module(child, replaced_module_type,
                                   new_module_type, replace_func)
        if new_child is not child:  # child was replaced
            module.add_module(name, new_child)
    return module