face01lib.damo_yolo.damo_internal.structures.boxlist_ops のソースコード

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch

from .bounding_box import BoxList


def remove_small_boxes(boxlist, min_size):
    """
    Drop every box whose width or height is below ``min_size``.

    Arguments:
        boxlist (BoxList): boxes to filter; it is converted to 'xywh' mode
            to read the widths and heights, and indexed to build the result.
        min_size (int): smallest allowed side length (inclusive).

    Returns:
        The result of indexing ``boxlist`` with the indices of the boxes
        whose width and height are both >= ``min_size``.
    """
    boxes_xywh = boxlist.convert('xywh').bbox
    # Columns 2 and 3 of the xywh layout hold width and height.
    widths, heights = boxes_xywh.unbind(dim=1)[2:]
    large_enough = (widths >= min_size) & (heights >= min_size)
    # nonzero() yields a [K, 1] index tensor; flatten it to [K].
    keep = large_enough.nonzero().squeeze(1)
    return boxlist[keep]
def boxlist_iou(boxlist1, boxlist2):
    """Compute the pairwise intersection over union of two sets of boxes.

    Both inputs are converted to 'xyxy' mode before the overlap is computed,
    so the corner arithmetic below always sees (xmin, ymin, xmax, ymax)
    coordinates regardless of the mode the caller passes in.

    Arguments:
        boxlist1: (BoxList) bounding boxes, sized [N,4].
        boxlist2: (BoxList) bounding boxes, sized [M,4].

    Returns:
        (tensor) iou, sized [N,M].

    Raises:
        RuntimeError: if the two box lists have different ``size`` values.

    Reference:
        https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
    """
    if boxlist1.size != boxlist2.size:
        raise RuntimeError(
            'boxlists should have same image size, got {}, {}'.format(
                boxlist1, boxlist2))

    # The min/max corner logic is only meaningful for corner-encoded boxes;
    # normalise the mode instead of silently returning wrong overlaps.
    boxlist1 = boxlist1.convert('xyxy')
    boxlist2 = boxlist2.convert('xyxy')

    area1 = boxlist1.area()
    area2 = boxlist2.area()

    box1, box2 = boxlist1.bbox, boxlist2.bbox

    # Broadcasting [N,1,2] against [M,2] yields every pairwise combination.
    lt = torch.max(box1[:, None, :2], box2[:, :2])  # [N,M,2]
    rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]

    # NOTE(review): the +1 offset presumably treats coordinates as inclusive
    # pixel indices; it must agree with how BoxList.area() measures boxes.
    TO_REMOVE = 1

    # Non-overlapping pairs give negative extents; clamp them to zero area.
    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    iou = inter / (area1[:, None] + area2 - inter)
    return iou
def _cat(tensors, dim=0):
    """
    Concatenate tensors along ``dim``, skipping the copy for a single tensor.

    Arguments:
        tensors (list or tuple): tensors to join.
        dim (int): dimension passed to ``torch.cat``.

    Returns:
        The sole element itself when ``tensors`` has exactly one entry,
        otherwise the result of ``torch.cat(tensors, dim)``.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) != 1:
        return torch.cat(tensors, dim)
    # A lone tensor is handed back as-is (same object, no copy).
    (only_tensor,) = tensors
    return only_tensor
def cat_boxlist(bboxes):
    """
    Concatenates a list of BoxList (having the same image size) into a
    single BoxList

    Every input must share the same ``size``, ``mode`` and set of field
    names. Fields are added to the result in the order reported by the
    first BoxList, so the output layout is the same on every run.

    Arguments:
        bboxes (list[BoxList])

    Returns:
        A new BoxList holding the concatenated boxes and, for each field,
        the concatenation of that field across all inputs.

    Raises:
        AssertionError: if ``bboxes`` is not a list/tuple of BoxList, or the
            inputs disagree on size, mode or field names.
        IndexError: if ``bboxes`` is empty.
    """
    assert isinstance(bboxes, (list, tuple))
    assert all(isinstance(bbox, BoxList) for bbox in bboxes)

    size = bboxes[0].size
    assert all(bbox.size == size for bbox in bboxes)

    mode = bboxes[0].mode
    assert all(bbox.mode == mode for bbox in bboxes)

    # The set is used only for the order-insensitive equality check.
    fields = set(bboxes[0].fields())
    assert all(set(bbox.fields()) == fields for bbox in bboxes)

    cat_boxes = BoxList(_cat([bbox.bbox for bbox in bboxes], dim=0), size,
                        mode)

    # Iterate the first BoxList's own field listing rather than the set:
    # set iteration order over strings varies with hash randomization.
    for field in bboxes[0].fields():
        data = _cat([bbox.get_field(field) for bbox in bboxes], dim=0)
        cat_boxes.add_field(field, data)

    return cat_boxes