diff --git a/.circleci/test.yml b/.circleci/test.yml index 809c1f311f1..a4333a7bc09 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -68,6 +68,7 @@ jobs: # force reinstall pycocotools to ensure pycocotools being built under the currenct numpy command: | python -m pip install git+ssh://git@github.com/open-mmlab/mmengine.git@main + python -m pip install git+ssh://git@github.com/open-mmlab/mmeval.git@main pip install -U openmim mim install 'mmcv >= 2.0.0rc0' pip install -r requirements/tests.txt -r requirements/optional.txt @@ -106,16 +107,18 @@ jobs: name: Clone Repos command: | git clone -b main --depth 1 ssh://git@github.com/open-mmlab/mmengine.git /home/circleci/mmengine + git clone -b main --depth 1 ssh://git@github.com/open-mmlab/mmeval.git /home/circleci/mmeval - run: name: Build Docker image command: | docker build .circleci/docker -t mmdetection:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >> - docker run --gpus all -t -d -v /home/circleci/project:/mmdetection -v /home/circleci/mmengine:/mmengine -w /mmdetection --name mmdetection mmdetection:gpu + docker run --gpus all -t -d -v /home/circleci/project:/mmdetection -v /home/circleci/mmengine:/mmengine -v /home/circleci/mmeval:/mmeval -w /mmdetection --name mmdetection mmdetection:gpu docker exec mmdetection apt-get install -y git - run: name: Install mmdet dependencies command: | docker exec mmdetection pip install -e /mmengine + docker exec mmdetection pip install -e /mmeval docker exec mmdetection pip install -U openmim docker exec mmdetection mim install 'mmcv >= 2.0.0rc0' docker exec mmdetection pip install -r requirements/tests.txt -r requirements/optional.txt diff --git a/configs/_base_/datasets/coco_detection.py b/configs/_base_/datasets/coco_detection.py index fcd9859f135..97004f32255 100644 --- a/configs/_base_/datasets/coco_detection.py +++ b/configs/_base_/datasets/coco_detection.py @@ -59,6 +59,7 @@ type='CocoMetric', ann_file=data_root + 'annotations/instances_val2017.json', metric='bbox', + classwise=False, format_only=False) test_evaluator = val_evaluator diff --git a/configs/_base_/datasets/coco_instance.py b/configs/_base_/datasets/coco_instance.py index 878d8b4915e..79276b7728a 100644 --- a/configs/_base_/datasets/coco_instance.py +++ b/configs/_base_/datasets/coco_instance.py @@ -59,6 +59,7 @@ type='CocoMetric', ann_file=data_root + 'annotations/instances_val2017.json', metric=['bbox', 'segm'], + classwise=False, format_only=False) test_evaluator = val_evaluator diff --git a/configs/_base_/datasets/coco_instance_semantic.py b/configs/_base_/datasets/coco_instance_semantic.py index 12652d02c6b..44b7b2da46b 100644 --- a/configs/_base_/datasets/coco_instance_semantic.py +++ b/configs/_base_/datasets/coco_instance_semantic.py @@ -64,5 +64,6 @@ type='CocoMetric', ann_file=data_root + 'annotations/instances_val2017.json', metric=['bbox', 'segm'], + classwise=False, format_only=False) test_evaluator = val_evaluator diff --git a/configs/rpn/rpn_r50_fpn_1x_coco.py b/configs/rpn/rpn_r50_fpn_1x_coco.py index 692ff9e6650..6504d90453a 100644 --- a/configs/rpn/rpn_r50_fpn_1x_coco.py +++ b/configs/rpn/rpn_r50_fpn_1x_coco.py @@ -3,7 +3,13 @@ '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' ] -val_evaluator = dict(metric='proposal_fast') +val_evaluator = dict( + _delete_=True, + type='ProposalRecallMetric', + proposal_nums=(1, 10, 100, 1000), + prefix='coco', + use_legacy_coordinate=False, # VOCDataset should set True, else 
False +) test_evaluator = val_evaluator # inference on val dataset and dump the proposals with evaluate metric @@ -14,11 +20,9 @@ # output_dir=data_root + 'proposals/', # proposals_file='rpn_r50_fpn_1x_val2017.pkl'), # dict( -# type='CocoMetric', -# ann_file=data_root + 'annotations/instances_val2017.json', -# metric='proposal_fast', -# file_client_args={{_base_.file_client_args}}, -# format_only=False) +# type='ProposalRecallMetric', +# proposal_nums=(1, 10, 100, 1000), +# use_legacy_coordinate=False, # ] # inference on training dataset and dump the proposals without evaluate metric diff --git a/mmdet/datasets/api_wrappers/coco_api.py b/mmdet/datasets/api_wrappers/coco_api.py index 40f7f2c9b93..63d79555c88 100644 --- a/mmdet/datasets/api_wrappers/coco_api.py +++ b/mmdet/datasets/api_wrappers/coco_api.py @@ -15,33 +15,102 @@ class COCO(_COCO): It implements some snake case function aliases. So that the COCO class has the same interface as LVIS class. + + Args: + annotation_file (str, optional): Path of annotation file. + Defaults to None. """ - def __init__(self, annotation_file=None): + def __init__(self, annotation_file: Optional[str] = None) -> None: if getattr(pycocotools, '__version__', '0') >= '12.0.2': warnings.warn( - 'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501 - UserWarning) + 'mmpycocotools is deprecated. ' + 'Please install official pycocotools by ' + '"pip install pycocotools"', UserWarning) super().__init__(annotation_file=annotation_file) self.img_ann_map = self.imgToAnns self.cat_img_map = self.catToImgs - def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None): + def get_ann_ids(self, + img_ids: Union[list, int] = [], + cat_ids: Union[list, int] = [], + area_rng: Union[list, int] = [], + iscrowd: Optional[bool] = None) -> list: + """Get annotation ids that satisfy given filter conditions. + + Args: + img_ids (list | int): Get annotations for given images. + cat_ids (list | int): Get categories for given images. + area_rng (list | int): Get annotations for given area range. + iscrowd (bool, optional): Get annotations for given crowd label. + + Returns: + List: Integer array of annotation ids. + """ return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd) - def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]): + def get_cat_ids(self, + cat_names: Union[list, int] = [], + sup_names: Union[list, int] = [], + cat_ids: Union[list, int] = []) -> list: + """Get category ids that satisfy given filter conditions. + + Args: + cat_names (list | int): Get categories for given category names. + sup_names (list | int): Get categories for given supercategory + names. + cat_ids (list | int): Get categories for given category ids. + + Returns: + List: Integer array of category ids. + """ return self.getCatIds(cat_names, sup_names, cat_ids) - def get_img_ids(self, img_ids=[], cat_ids=[]): + def get_img_ids(self, + img_ids: Union[list, int] = [], + cat_ids: Union[list, int] = []) -> list: + """Get image ids that satisfy given filter conditions. + + Args: + img_ids (list | int): Get images for given ids + cat_ids (list | int): Get images with all given cats + + Returns: + List: Integer array of image ids. + """ return self.getImgIds(img_ids, cat_ids) - def load_anns(self, ids): + def load_anns(self, ids: Union[list, int] = []) -> list: + """Load annotations with the specified ids. + + Args: + ids (list | int): Integer ids specifying annotations. + + Returns: + List[dict]: Loaded annotation objects. 
+ """ return self.loadAnns(ids) - def load_cats(self, ids): + def load_cats(self, ids: Union[list, int] = []) -> list: + """Load categories with the specified ids. + + Args: + ids (list | int): Integer ids specifying categories. + + Returns: + List[dict]: loaded category objects. + """ return self.loadCats(ids) - def load_imgs(self, ids): + def load_imgs(self, ids: Union[list, int] = []) -> list: + """Load annotations with the specified ids. + + Args: + ids (list): integer ids specifying image. + + Returns: + List[dict]: Loaded image objects. + """ return self.loadImgs(ids) @@ -53,15 +122,8 @@ class COCOPanoptic(COCO): """This wrapper is for loading the panoptic style annotation file. The format is shown in the CocoPanopticDataset class. - - Args: - annotation_file (str, optional): Path of annotation file. - Defaults to None. """ - def __init__(self, annotation_file: Optional[str] = None) -> None: - super(COCOPanoptic, self).__init__(annotation_file) - def createIndex(self) -> None: """Create index.""" # create index @@ -114,16 +176,16 @@ def createIndex(self) -> None: def load_anns(self, ids: Union[List[int], int] = []) -> Optional[List[dict]]: - """Load anns with the specified ids. + """Load annotations with the specified ids. ``self.anns`` is a list of annotation lists instead of a list of annotations. Args: - ids (Union[List[int], int]): Integer ids specifying anns. + ids (Union[List[int], int]): Integer ids specifying annotations. Returns: - anns (List[dict], optional): Loaded ann objects. + anns (List[dict], optional): Loaded annotation objects. """ anns = [] diff --git a/mmdet/evaluation/metrics/__init__.py b/mmdet/evaluation/metrics/__init__.py index da000e0d535..4799e2f38c3 100644 --- a/mmdet/evaluation/metrics/__init__.py +++ b/mmdet/evaluation/metrics/__init__.py @@ -8,10 +8,11 @@ from .dump_proposals_metric import DumpProposals from .lvis_metric import LVISMetric from .openimages_metric import OpenImagesMetric +from .proposal_recall_metric import ProposalRecallMetric from .voc_metric import VOCMetric __all__ = [ 'CityScapesMetric', 'CocoMetric', 'CocoPanopticMetric', 'OpenImagesMetric', 'VOCMetric', 'LVISMetric', 'CrowdHumanMetric', 'DumpProposals', - 'CocoOccludedSeparatedMetric', 'DumpDetResults' + 'CocoOccludedSeparatedMetric', 'DumpDetResults', 'ProposalRecallMetric' ] diff --git a/mmdet/evaluation/metrics/coco_metric.py b/mmdet/evaluation/metrics/coco_metric.py index bd56803da3d..2439ad48e44 100644 --- a/mmdet/evaluation/metrics/coco_metric.py +++ b/mmdet/evaluation/metrics/coco_metric.py @@ -1,27 +1,114 @@ # Copyright (c) OpenMMLab. All rights reserved. -import datetime -import itertools -import os.path as osp -import tempfile -from collections import OrderedDict -from typing import Dict, List, Optional, Sequence, Union +import warnings +from typing import List, Optional, Sequence, Union import numpy as np -import torch -from mmengine.evaluator import BaseMetric -from mmengine.fileio import FileClient, dump, load from mmengine.logging import MMLogger -from terminaltables import AsciiTable +from mmeval import COCODetection +from torch import Tensor -from mmdet.datasets.api_wrappers import COCO, COCOeval from mmdet.registry import METRICS -from mmdet.structures.mask import encode_mask_results -from ..functional import eval_recalls +from mmdet.structures.mask import (BitmapMasks, PolygonMasks, + encode_mask_results) + + +def parse_coco_groundtruth(data_sample: dict) -> dict: + """Parse coco groundtruth if :obj:`COCODetection._coco_api` is None. 
+ + Args: + data_sample (dict): Data samples that contain annotations + and predictions. + + Returns: + dict: Represents a groundtruths for an image, with the following + keys: + + - img_id (int): Image id. + - width (int): The width of the image. + - height (int): The height of the image. + - bboxes (numpy.ndarray): Shape (K, 4), the ground truth + bounding bboxes of this image, in 'xyxy' foramrt. + - labels (numpy.ndarray): Shape (K, ), the ground truth + labels of bounding boxes. + - masks (list[RLE], optional): The predicted masks. + - ignore_flags (numpy.ndarray, optional): Shape (K, ), + the ignore flags. + """ + ann = dict() + ann['width'] = data_sample['ori_shape'][1] + ann['height'] = data_sample['ori_shape'][0] + ann['img_id'] = data_sample['img_id'] + + gt_instances = data_sample['gt_instances'] + ignored_instances = data_sample['ignored_instances'] + + ann['bboxes'] = np.concatenate((gt_instances['bboxes'].cpu().numpy(), + ignored_instances['bboxes'].cpu().numpy()), + axis=0) + ann['labels'] = np.concatenate((gt_instances['labels'].cpu().numpy(), + ignored_instances['labels'].cpu().numpy()), + axis=0) + ann['ignore_flags'] = np.concatenate( + (np.zeros(len(gt_instances['labels'])), + np.ones(len(ignored_instances['labels']))), + axis=0) + assert len(ann['bboxes']) == len(ann['labels']) + if 'masks' in gt_instances: + assert isinstance(gt_instances['masks'], + (PolygonMasks, BitmapMasks)) and \ + isinstance(ignored_instances['masks'], + (PolygonMasks, BitmapMasks)) + ann['masks']: list = [] + ann['masks'].extend( + encode_mask_results(gt_instances['masks'].to_ndarray())) + ann['masks'].extend( + encode_mask_results(ignored_instances['masks'].to_ndarray())) + assert len(ann['bboxes']) == len(ann['masks']) + return ann + + +def parse_coco_prediction(data_sample: dict) -> dict: + """Parse coco prediction. + + Args: + data_sample (dict): Data samples that contain annotations + and predictions. + + Returns: + dict: Represents a detection result for an image, with the + following keys: + + - img_id (int): Image id. + - bboxes (numpy.ndarray): Shape (N, 4), the predicted + bounding bboxes of this image, in 'xyxy' foramrt. + - scores (numpy.ndarray): Shape (N, ), the predicted scores + of bounding boxes. + - labels (numpy.ndarray): Shape (N, ), the predicted labels + of bounding boxes. + - masks (list[RLE], optional): The predicted masks. + - mask_scores (np.array, optional): Shape (N, ), the predicted + scores of masks. + """ + pred = dict() + pred_instances = data_sample['pred_instances'] + pred['img_id'] = data_sample['img_id'] + pred['bboxes'] = pred_instances['bboxes'].cpu().numpy() + pred['scores'] = pred_instances['scores'].cpu().numpy() + pred['labels'] = pred_instances['labels'].cpu().numpy() + if 'masks' in pred_instances: + pred['masks'] = encode_mask_results( + pred_instances['masks'].detach().cpu().numpy()) if isinstance( + pred_instances['masks'], Tensor) else pred_instances['masks'] + # some detectors use different scores for bbox and mask + if 'mask_scores' in pred_instances: + pred['mask_scores'] = \ + pred_instances['mask_scores'].cpu().numpy() + return pred @METRICS.register_module() -class CocoMetric(BaseMetric): - """COCO evaluation metric. +class CocoMetric(COCODetection): + """A wrapper of :class:`mmeval.COCODetection`. Evaluate AR, AP, and mAP for detection tasks including proposal/box detection and instance segmentation. 
Please refer to @@ -32,15 +119,15 @@ class CocoMetric(BaseMetric): If not specified, ground truth annotations from the dataset will be converted to coco format. Defaults to None. metric (str | List[str]): Metrics to be evaluated. Valid metrics - include 'bbox', 'segm', 'proposal', and 'proposal_fast'. - Defaults to 'bbox'. - classwise (bool): Whether to evaluate the metric class-wise. - Defaults to False. - proposal_nums (Sequence[int]): Numbers of proposals to be evaluated. - Defaults to (100, 300, 1000). + include 'bbox', 'segm', and 'proposal'. Defaults to 'bbox'. iou_thrs (float | List[float], optional): IoU threshold to compute AP and AR. If not specified, IoUs from 0.5 to 0.95 will be used. Defaults to None. + classwise (bool):Whether to return the computed + results of each class. Defaults to False. + proposal_nums (Sequence[int]): Numbers of proposals to be evaluated. + Defaults to (1, 10, 100). + Note: it defaults to (100, 300, 1000) in MMDet 2.x. metric_items (List[str], optional): Metric result names to be recorded in the evaluation result. Defaults to None. format_only (bool): Format the output results without perform @@ -50,284 +137,61 @@ class CocoMetric(BaseMetric): outfile_prefix (str, optional): The prefix of json files. It includes the file path and the prefix of filename, e.g., "a/b/prefix". If not specified, a temp file will be created. Defaults to None. - file_client_args (dict): Arguments to instantiate a FileClient. - See :class:`mmengine.fileio.FileClient` for details. - Defaults to ``dict(backend='disk')``. - collect_device (str): Device name used for collecting results from - different ranks during distributed training. Must be 'cpu' or - 'gpu'. Defaults to 'cpu'. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. + gt_mask_area (bool): Whether calculate GT mask area when not loading + ann_file. If True, the GT instance area will be the mask area, + else the bounding box area. It will not be used when loading + ann_file. Defaults to True. prefix (str, optional): The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None. - sort_categories (bool): Whether sort categories in annotations. Only - used for `Objects365V1Dataset`. Defaults to False. + dist_backend (str | None): The name of the distributed communication + backend. Refer to :class:`mmeval.BaseMetric`. + Defaults to 'torch_cuda'. + **kwargs: Keyword parameters passed to :class:`COCODetection`. 
""" default_prefix: Optional[str] = 'coco' def __init__(self, ann_file: Optional[str] = None, metric: Union[str, List[str]] = 'bbox', - classwise: bool = False, - proposal_nums: Sequence[int] = (100, 300, 1000), iou_thrs: Optional[Union[float, Sequence[float]]] = None, + classwise: bool = False, + proposal_nums: Sequence[int] = (1, 10, 100), metric_items: Optional[Sequence[str]] = None, format_only: bool = False, outfile_prefix: Optional[str] = None, - file_client_args: dict = dict(backend='disk'), - collect_device: str = 'cpu', + backend_args: Optional[dict] = None, + gt_mask_area: bool = True, prefix: Optional[str] = None, - sort_categories: bool = False) -> None: - super().__init__(collect_device=collect_device, prefix=prefix) - # coco evaluation metrics - self.metrics = metric if isinstance(metric, list) else [metric] - allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast'] - for metric in self.metrics: - if metric not in allowed_metrics: - raise KeyError( - "metric should be one of 'bbox', 'segm', 'proposal', " - f"'proposal_fast', but got {metric}.") - - # do class wise evaluation, default False - self.classwise = classwise - - # proposal_nums used to compute recall or precision. - self.proposal_nums = list(proposal_nums) - - # iou_thrs used to compute recall or precision. - if iou_thrs is None: - iou_thrs = np.linspace( - .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) - self.iou_thrs = iou_thrs - self.metric_items = metric_items - self.format_only = format_only - if self.format_only: - assert outfile_prefix is not None, 'outfile_prefix must be not' - 'None when format_only is True, otherwise the result files will' - 'be saved to a temp directory which will be cleaned up at the end.' - - self.outfile_prefix = outfile_prefix - - self.file_client_args = file_client_args - self.file_client = FileClient(**file_client_args) - - # if ann_file is not specified, - # initialize coco api with the converted dataset - if ann_file is not None: - with self.file_client.get_local_path(ann_file) as local_path: - self._coco_api = COCO(local_path) - if sort_categories: - # 'categories' list in objects365_train.json and - # objects365_val.json is inconsistent, need sort - # list(or dict) before get cat_ids. - cats = self._coco_api.cats - sorted_cats = {i: cats[i] for i in sorted(cats)} - self._coco_api.cats = sorted_cats - categories = self._coco_api.dataset['categories'] - sorted_categories = sorted( - categories, key=lambda i: i['id']) - self._coco_api.dataset['categories'] = sorted_categories - else: - self._coco_api = None - - # handle dataset lazy init - self.cat_ids = None - self.img_ids = None - - def fast_eval_recall(self, - results: List[dict], - proposal_nums: Sequence[int], - iou_thrs: Sequence[float], - logger: Optional[MMLogger] = None) -> np.ndarray: - """Evaluate proposal recall with COCO's fast_eval_recall. - - Args: - results (List[dict]): Results of the dataset. - proposal_nums (Sequence[int]): Proposal numbers used for - evaluation. - iou_thrs (Sequence[float]): IoU thresholds used for evaluation. - logger (MMLogger, optional): Logger used for logging the recall - summary. - Returns: - np.ndarray: Averaged recall results. 
- """ - gt_bboxes = [] - pred_bboxes = [result['bboxes'] for result in results] - for i in range(len(self.img_ids)): - ann_ids = self._coco_api.get_ann_ids(img_ids=self.img_ids[i]) - ann_info = self._coco_api.load_anns(ann_ids) - if len(ann_info) == 0: - gt_bboxes.append(np.zeros((0, 4))) - continue - bboxes = [] - for ann in ann_info: - if ann.get('ignore', False) or ann['iscrowd']: - continue - x1, y1, w, h = ann['bbox'] - bboxes.append([x1, y1, x1 + w, y1 + h]) - bboxes = np.array(bboxes, dtype=np.float32) - if bboxes.shape[0] == 0: - bboxes = np.zeros((0, 4)) - gt_bboxes.append(bboxes) - - recalls = eval_recalls( - gt_bboxes, pred_bboxes, proposal_nums, iou_thrs, logger=logger) - ar = recalls.mean(axis=1) - return ar - - def xyxy2xywh(self, bbox: np.ndarray) -> list: - """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO - evaluation. - - Args: - bbox (numpy.ndarray): The bounding boxes, shape (4, ), in - ``xyxy`` order. - - Returns: - list[float]: The converted bounding boxes, in ``xywh`` order. - """ - - _bbox: List = bbox.tolist() - return [ - _bbox[0], - _bbox[1], - _bbox[2] - _bbox[0], - _bbox[3] - _bbox[1], - ] - - def results2json(self, results: Sequence[dict], - outfile_prefix: str) -> dict: - """Dump the detection results to a COCO style json file. - - There are 3 types of results: proposals, bbox predictions, mask - predictions, and they have different data types. This method will - automatically recognize the type, and dump them to json files. - - Args: - results (Sequence[dict]): Testing results of the - dataset. - outfile_prefix (str): The filename prefix of the json files. If the - prefix is "somepath/xxx", the json files will be named - "somepath/xxx.bbox.json", "somepath/xxx.segm.json", - "somepath/xxx.proposal.json". - - Returns: - dict: Possible keys are "bbox", "segm", "proposal", and - values are corresponding filenames. - """ - bbox_json_results = [] - segm_json_results = [] if 'masks' in results[0] else None - for idx, result in enumerate(results): - image_id = result.get('img_id', idx) - labels = result['labels'] - bboxes = result['bboxes'] - scores = result['scores'] - # bbox results - for i, label in enumerate(labels): - data = dict() - data['image_id'] = image_id - data['bbox'] = self.xyxy2xywh(bboxes[i]) - data['score'] = float(scores[i]) - data['category_id'] = self.cat_ids[label] - bbox_json_results.append(data) - - if segm_json_results is None: - continue - - # segm results - masks = result['masks'] - mask_scores = result.get('mask_scores', scores) - for i, label in enumerate(labels): - data = dict() - data['image_id'] = image_id - data['bbox'] = self.xyxy2xywh(bboxes[i]) - data['score'] = float(mask_scores[i]) - data['category_id'] = self.cat_ids[label] - if isinstance(masks[i]['counts'], bytes): - masks[i]['counts'] = masks[i]['counts'].decode() - data['segmentation'] = masks[i] - segm_json_results.append(data) - - result_files = dict() - result_files['bbox'] = f'{outfile_prefix}.bbox.json' - result_files['proposal'] = f'{outfile_prefix}.bbox.json' - dump(bbox_json_results, result_files['bbox']) - - if segm_json_results is not None: - result_files['segm'] = f'{outfile_prefix}.segm.json' - dump(segm_json_results, result_files['segm']) - - return result_files - - def gt_to_coco_json(self, gt_dicts: Sequence[dict], - outfile_prefix: str) -> str: - """Convert ground truth to coco format json file. - - Args: - gt_dicts (Sequence[dict]): Ground truth of the dataset. - outfile_prefix (str): The filename prefix of the json files. 
If the - prefix is "somepath/xxx", the json file will be named - "somepath/xxx.gt.json". - Returns: - str: The filename of the json file. - """ - categories = [ - dict(id=id, name=name) - for id, name in enumerate(self.dataset_meta['classes']) - ] - image_infos = [] - annotations = [] - - for idx, gt_dict in enumerate(gt_dicts): - img_id = gt_dict.get('img_id', idx) - image_info = dict( - id=img_id, - width=gt_dict['width'], - height=gt_dict['height'], - file_name='') - image_infos.append(image_info) - for ann in gt_dict['anns']: - label = ann['bbox_label'] - bbox = ann['bbox'] - coco_bbox = [ - bbox[0], - bbox[1], - bbox[2] - bbox[0], - bbox[3] - bbox[1], - ] - - annotation = dict( - id=len(annotations) + - 1, # coco api requires id starts with 1 - image_id=img_id, - bbox=coco_bbox, - iscrowd=ann.get('ignore_flag', 0), - category_id=int(label), - area=coco_bbox[2] * coco_bbox[3]) - if ann.get('mask', None): - mask = ann['mask'] - # area = mask_util.area(mask) - if isinstance(mask, dict) and isinstance( - mask['counts'], bytes): - mask['counts'] = mask['counts'].decode() - annotation['segmentation'] = mask - # annotation['area'] = float(area) - annotations.append(annotation) - - info = dict( - date_created=str(datetime.datetime.now()), - description='Coco json file converted by mmdet CocoMetric.') - coco_json = dict( - info=info, - images=image_infos, - categories=categories, - licenses=None, - ) - if len(annotations) > 0: - coco_json['annotations'] = annotations - converted_json_path = f'{outfile_prefix}.gt.json' - dump(coco_json, converted_json_path) - return converted_json_path + dist_backend: str = 'torch_cuda', + **kwargs) -> None: + + collect_device = kwargs.pop('collect_device', None) + if collect_device is not None: + warnings.warn( + 'DeprecationWarning: The `collect_device` parameter of ' + '`CocoMetric` is deprecated, use `dist_backend` instead.') + + logger = MMLogger.get_current_instance() + super().__init__( + ann_file=ann_file, + metric=metric, + iou_thrs=iou_thrs, + classwise=classwise, + proposal_nums=proposal_nums, + metric_items=metric_items, + format_only=format_only, + outfile_prefix=outfile_prefix, + backend_args=backend_args, + gt_mask_area=gt_mask_area, + dist_backend=dist_backend, + logger=logger, + **kwargs) + + self.prefix = prefix or self.default_prefix # TODO: data_batch is no longer needed, consider adjusting the # parameter position @@ -341,217 +205,34 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: data_samples (Sequence[dict]): A batch of data samples that contain annotations and predictions. 
""" + predictions, groundtruths = [], [] for data_sample in data_samples: - result = dict() - pred = data_sample['pred_instances'] - result['img_id'] = data_sample['img_id'] - result['bboxes'] = pred['bboxes'].cpu().numpy() - result['scores'] = pred['scores'].cpu().numpy() - result['labels'] = pred['labels'].cpu().numpy() - # encode mask to RLE - if 'masks' in pred: - result['masks'] = encode_mask_results( - pred['masks'].detach().cpu().numpy()) if isinstance( - pred['masks'], torch.Tensor) else pred['masks'] - # some detectors use different scores for bbox and mask - if 'mask_scores' in pred: - result['mask_scores'] = pred['mask_scores'].cpu().numpy() + # parse prediction + pred = parse_coco_prediction(data_sample) + predictions.append(pred) - # parse gt - gt = dict() - gt['width'] = data_sample['ori_shape'][1] - gt['height'] = data_sample['ori_shape'][0] - gt['img_id'] = data_sample['img_id'] + # parse groundtruth if self._coco_api is None: - # TODO: Need to refactor to support LoadAnnotations - assert 'instances' in data_sample, \ - 'ground truth is required for evaluation when ' \ - '`ann_file` is not provided' - gt['anns'] = data_sample['instances'] - # add converted result to the results list - self.results.append((gt, result)) + ann = parse_coco_groundtruth(data_sample) + else: + ann = dict() + groundtruths.append(ann) - def compute_metrics(self, results: list) -> Dict[str, float]: - """Compute the metrics from processed results. + self.add(predictions, groundtruths) - Args: - results (list): The processed results of each batch. + def evaluate(self, *args, **kwargs) -> dict: + """Returns metric results and print pretty table of metrics per class. - Returns: - Dict[str, float]: The computed metrics. The keys are the names of - the metrics, and the values are corresponding results. + This method would be invoked by ``mmengine.Evaluator``. """ - logger: MMLogger = MMLogger.get_current_instance() - - # split gt and prediction list - gts, preds = zip(*results) - - tmp_dir = None - if self.outfile_prefix is None: - tmp_dir = tempfile.TemporaryDirectory() - outfile_prefix = osp.join(tmp_dir.name, 'results') - else: - outfile_prefix = self.outfile_prefix + metric_results = self.compute(*args, **kwargs) + self.reset() - if self._coco_api is None: - # use converted gt json file to initialize coco api - logger.info('Converting ground truth to coco format...') - coco_json_path = self.gt_to_coco_json( - gt_dicts=gts, outfile_prefix=outfile_prefix) - self._coco_api = COCO(coco_json_path) - - # handle lazy init - if self.cat_ids is None: - self.cat_ids = self._coco_api.get_cat_ids( - cat_names=self.dataset_meta['classes']) - if self.img_ids is None: - self.img_ids = self._coco_api.get_img_ids() - - # convert predictions to coco format and dump to json file - result_files = self.results2json(preds, outfile_prefix) - - eval_results = OrderedDict() if self.format_only: - logger.info('results are saved in ' - f'{osp.dirname(outfile_prefix)}') - return eval_results - - for metric in self.metrics: - logger.info(f'Evaluating {metric}...') - - # TODO: May refactor fast_eval_recall to an independent metric? 
- # fast eval recall - if metric == 'proposal_fast': - ar = self.fast_eval_recall( - preds, self.proposal_nums, self.iou_thrs, logger=logger) - log_msg = [] - for i, num in enumerate(self.proposal_nums): - eval_results[f'AR@{num}'] = ar[i] - log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}') - log_msg = ''.join(log_msg) - logger.info(log_msg) - continue - - # evaluate proposal, bbox and segm - iou_type = 'bbox' if metric == 'proposal' else metric - if metric not in result_files: - raise KeyError(f'{metric} is not in results') - try: - predictions = load(result_files[metric]) - if iou_type == 'segm': - # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa - # When evaluating mask AP, if the results contain bbox, - # cocoapi will use the box area instead of the mask area - # for calculating the instance area. Though the overall AP - # is not affected, this leads to different - # small/medium/large mask AP results. - for x in predictions: - x.pop('bbox') - coco_dt = self._coco_api.loadRes(predictions) - - except IndexError: - logger.error( - 'The testing results of the whole dataset is empty.') - break - - coco_eval = COCOeval(self._coco_api, coco_dt, iou_type) - - coco_eval.params.catIds = self.cat_ids - coco_eval.params.imgIds = self.img_ids - coco_eval.params.maxDets = list(self.proposal_nums) - coco_eval.params.iouThrs = self.iou_thrs - - # mapping of cocoEval.stats - coco_metric_names = { - 'mAP': 0, - 'mAP_50': 1, - 'mAP_75': 2, - 'mAP_s': 3, - 'mAP_m': 4, - 'mAP_l': 5, - 'AR@100': 6, - 'AR@300': 7, - 'AR@1000': 8, - 'AR_s@1000': 9, - 'AR_m@1000': 10, - 'AR_l@1000': 11 - } - metric_items = self.metric_items - if metric_items is not None: - for metric_item in metric_items: - if metric_item not in coco_metric_names: - raise KeyError( - f'metric item "{metric_item}" is not supported') - - if metric == 'proposal': - coco_eval.params.useCats = 0 - coco_eval.evaluate() - coco_eval.accumulate() - coco_eval.summarize() - if metric_items is None: - metric_items = [ - 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000', - 'AR_m@1000', 'AR_l@1000' - ] - - for item in metric_items: - val = float( - f'{coco_eval.stats[coco_metric_names[item]]:.3f}') - eval_results[item] = val - else: - coco_eval.evaluate() - coco_eval.accumulate() - coco_eval.summarize() - if self.classwise: # Compute per-category AP - # Compute per-category AP - # from https://github.com/facebookresearch/detectron2/ - precisions = coco_eval.eval['precision'] - # precision: (iou, recall, cls, area range, max dets) - assert len(self.cat_ids) == precisions.shape[2] - - results_per_category = [] - for idx, cat_id in enumerate(self.cat_ids): - # area range index 0: all area ranges - # max dets index -1: typically 100 per image - nm = self._coco_api.loadCats(cat_id)[0] - precision = precisions[:, :, idx, 0, -1] - precision = precision[precision > -1] - if precision.size: - ap = np.mean(precision) - else: - ap = float('nan') - results_per_category.append( - (f'{nm["name"]}', f'{round(ap, 3)}')) - eval_results[f'{nm["name"]}_precision'] = round(ap, 3) - - num_columns = min(6, len(results_per_category) * 2) - results_flatten = list( - itertools.chain(*results_per_category)) - headers = ['category', 'AP'] * (num_columns // 2) - results_2d = itertools.zip_longest(*[ - results_flatten[i::num_columns] - for i in range(num_columns) - ]) - table_data = [headers] - table_data += [result for result in results_2d] - table = AsciiTable(table_data) - logger.info('\n' + table.table) - - if metric_items is None: - 
metric_items = [ - 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' - ] - - for metric_item in metric_items: - key = f'{metric}_{metric_item}' - val = coco_eval.stats[coco_metric_names[metric_item]] - eval_results[key] = float(f'{round(val, 3)}') - - ap = coco_eval.stats[:6] - logger.info(f'{metric}_mAP_copypaste: {ap[0]:.3f} ' - f'{ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' - f'{ap[4]:.3f} {ap[5]:.3f}') + return metric_results - if tmp_dir is not None: - tmp_dir.cleanup() - return eval_results + evaluate_results = { + f'{self.prefix}/{k}(%)': round(float(v) * 100, 4) + for k, v in metric_results.items() + } + return evaluate_results diff --git a/mmdet/evaluation/metrics/coco_occluded_metric.py b/mmdet/evaluation/metrics/coco_occluded_metric.py index 544ff4426ba..ea27e4af3bb 100644 --- a/mmdet/evaluation/metrics/coco_occluded_metric.py +++ b/mmdet/evaluation/metrics/coco_occluded_metric.py @@ -1,24 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import List, Optional, Sequence, Union -import os.path as osp -from typing import Dict, List, Optional, Union - -import mmengine -import numpy as np -from mmengine.fileio import load -from mmengine.logging import print_log -from pycocotools import mask as coco_mask -from terminaltables import AsciiTable +from mmengine.logging import MMLogger +from mmeval import CocoOccludedSeparated from mmdet.registry import METRICS -from .coco_metric import CocoMetric +from .coco_metric import parse_coco_groundtruth, parse_coco_prediction @METRICS.register_module() -class CocoOccludedSeparatedMetric(CocoMetric): - """Metric of separated and occluded masks which presented in paper `A Tri- - Layer Plugin to Improve Occluded Detection. +class CocoOccludedSeparatedMetric(CocoOccludedSeparated): + """A wrapper of :class:`mmeval.CocoOccludedSeparated`. + Metric of separated and occluded masks which presented in paper `A Tri- + Layer Plugin to Improve Occluded Detection. `_. Separated COCO and Occluded COCO are automatically generated subsets of @@ -44,168 +40,140 @@ class CocoOccludedSeparatedMetric(CocoMetric): } Args: + ann_file (str, optional): Path to the coco format annotation file. + If not specified, ground truth annotations from the dataset will + be converted to coco format. Defaults to None. + metric (str | List[str]): Metrics to be evaluated. Valid metrics + include 'bbox', 'segm', and 'proposal'. + Defaults to ['bbox', 'segm']. + iou_thrs (float | List[float], optional): IoU threshold to compute AP + and AR. If not specified, IoUs from 0.5 to 0.95 will be used. + Defaults to None. + classwise (bool):Whether to return the computed + results of each class. Defaults to False. + proposal_nums (Sequence[int]): Numbers of proposals to be evaluated. + Defaults to (1, 10, 100). + Note: it defaults to (100, 300, 1000) in MMDet 2.x. + metric_items (List[str], optional): Metric result names to be + recorded in the evaluation result. Defaults to None. + format_only (bool): Format the output results without perform + evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. + Defaults to False. + outfile_prefix (str, optional): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Defaults to None. + backend_args (dict, optional): Arguments to instantiate the + preifx of uri corresponding backend. Defaults to None. 
+ gt_mask_area (bool): Whether calculate GT mask area when not loading + ann_file. If True, the GT instance area will be the mask area, + else the bounding box area. It will not be used when loading + ann_file. Defaults to True. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + dist_backend (str | None): The name of the distributed communication + backend. Refer to :class:`mmeval.BaseMetric`. + Defaults to 'torch_cuda'. occluded_ann (str): Path to the occluded coco annotation file. separated_ann (str): Path to the separated coco annotation file. score_thr (float): Score threshold of the detection masks. Defaults to 0.3. iou_thr (float): IoU threshold for the recall calculation. Defaults to 0.75. - metric (str | List[str]): Metrics to be evaluated. Valid metrics - include 'bbox', 'segm', 'proposal', and 'proposal_fast'. - Defaults to 'bbox'. + **kwargs: Keyword parameters passed to :class:`CocoOccludedSeparated`. """ default_prefix: Optional[str] = 'coco' def __init__( self, - *args, + ann_file: Optional[str] = None, + metric: Union[str, List[str]] = ['bbox', 'segm'], + iou_thrs: Optional[Union[float, Sequence[float]]] = None, + classwise: bool = False, + proposal_nums: Sequence[int] = (1, 10, 100), + metric_items: Optional[Sequence[str]] = None, + format_only: bool = False, + outfile_prefix: Optional[str] = None, + backend_args: Optional[dict] = None, + gt_mask_area: bool = True, occluded_ann: str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/occluded_coco.pkl', # noqa separated_ann: str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/separated_coco.pkl', # noqa score_thr: float = 0.3, iou_thr: float = 0.75, - metric: Union[str, List[str]] = ['bbox', 'segm'], + prefix: Optional[str] = None, + dist_backend: str = 'torch_cuda', **kwargs) -> None: - super().__init__(*args, metric=metric, **kwargs) - # load from local file - if osp.isfile(occluded_ann) and not osp.isabs(occluded_ann): - occluded_ann = osp.join(self.data_root, occluded_ann) - if osp.isfile(separated_ann) and not osp.isabs(separated_ann): - separated_ann = osp.join(self.data_root, separated_ann) - self.occluded_ann = load(occluded_ann) - self.separated_ann = load(separated_ann) - self.score_thr = score_thr - self.iou_thr = iou_thr - - def compute_metrics(self, results: list) -> Dict[str, float]: - """Compute the metrics from processed results. - Args: - results (list): The processed results of each batch. 
+ collect_device = kwargs.pop('collect_device', None) + if collect_device is not None: + warnings.warn( + 'DeprecationWarning: The `collect_device` parameter of ' + '`CocoOccludedSeparatedMetric` is deprecated, ' + 'use `dist_backend` instead.') + + logger = MMLogger.get_current_instance() + super().__init__( + ann_file=ann_file, + metric=metric, + iou_thrs=iou_thrs, + classwise=classwise, + proposal_nums=proposal_nums, + metric_items=metric_items, + format_only=format_only, + outfile_prefix=outfile_prefix, + backend_args=backend_args, + gt_mask_area=gt_mask_area, + occluded_ann=occluded_ann, + separated_ann=separated_ann, + score_thr=score_thr, + iou_thr=iou_thr, + dist_backend=dist_backend, + logger=logger, + **kwargs) + + self.prefix = prefix or self.default_prefix + + # TODO: data_batch is no longer needed, consider adjusting the + # parameter position + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. - Returns: - Dict[str, float]: The computed metrics. The keys are the names of - the metrics, and the values are corresponding results. + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. """ - coco_metric_res = super().compute_metrics(results) - eval_res = self.evaluate_occluded_separated(results) - coco_metric_res.update(eval_res) - return coco_metric_res + predictions, groundtruths = [], [] + for data_sample in data_samples: + # parse prediction + pred = parse_coco_prediction(data_sample) + predictions.append(pred) - def evaluate_occluded_separated(self, results: List[tuple]) -> dict: - """Compute the recall of occluded and separated masks. + # parse groundtruth + if self._coco_api is None: + ann = parse_coco_groundtruth(data_sample) + else: + ann = dict() + groundtruths.append(ann) - Args: - results (list[tuple]): Testing results of the dataset. + self.add(predictions, groundtruths) - Returns: - dict[str, float]: The recall of occluded and separated masks. 
- """ - dict_det = {} - print_log('processing detection results...') - prog_bar = mmengine.ProgressBar(len(results)) - for i in range(len(results)): - gt, dt = results[i] - img_id = dt['img_id'] - cur_img_name = self._coco_api.imgs[img_id]['file_name'] - if cur_img_name not in dict_det.keys(): - dict_det[cur_img_name] = [] - - for bbox, score, label, mask in zip(dt['bboxes'], dt['scores'], - dt['labels'], dt['masks']): - cur_binary_mask = coco_mask.decode(mask) - dict_det[cur_img_name].append([ - score, self.dataset_meta['classes'][label], - cur_binary_mask, bbox - ]) - dict_det[cur_img_name].sort( - key=lambda x: (-x[0], x[3][0], x[3][1]) - ) # rank by confidence from high to low, avoid same confidence - prog_bar.update() - print_log('\ncomputing occluded mask recall...', logger='current') - occluded_correct_num, occluded_recall = self.compute_recall( - dict_det, gt_ann=self.occluded_ann, is_occ=True) - print_log( - f'\nCOCO occluded mask recall: {occluded_recall:.2f}%', - logger='current') - print_log( - f'COCO occluded mask success num: {occluded_correct_num}', - logger='current') - print_log('computing separated mask recall...', logger='current') - separated_correct_num, separated_recall = self.compute_recall( - dict_det, gt_ann=self.separated_ann, is_occ=False) - print_log( - f'\nCOCO separated mask recall: {separated_recall:.2f}%', - logger='current') - print_log( - f'COCO separated mask success num: {separated_correct_num}', - logger='current') - table_data = [ - ['mask type', 'recall', 'num correct'], - ['occluded', f'{occluded_recall:.2f}%', occluded_correct_num], - ['separated', f'{separated_recall:.2f}%', separated_correct_num] - ] - table = AsciiTable(table_data) - print_log('\n' + table.table, logger='current') - return dict( - occluded_recall=occluded_recall, separated_recall=separated_recall) - - def compute_recall(self, - result_dict: dict, - gt_ann: list, - is_occ: bool = True) -> tuple: - """Compute the recall of occluded or separated masks. + def evaluate(self, *args, **kwargs) -> dict: + """Returns metric results and print pretty table of metrics per class. - Args: - result_dict (dict): Processed mask results. - gt_ann (list): Occluded or separated coco annotations. - is_occ (bool): Whether the annotation is occluded mask. - Defaults to True. - Returns: - tuple: number of correct masks and the recall. + This method would be invoked by ``mmengine.Evaluator``. 
""" - correct = 0 - prog_bar = mmengine.ProgressBar(len(gt_ann)) - for iter_i in range(len(gt_ann)): - cur_item = gt_ann[iter_i] - cur_img_name = cur_item[0] - cur_gt_bbox = cur_item[3] - if is_occ: - cur_gt_bbox = [ - cur_gt_bbox[0], cur_gt_bbox[1], - cur_gt_bbox[0] + cur_gt_bbox[2], - cur_gt_bbox[1] + cur_gt_bbox[3] - ] - cur_gt_class = cur_item[1] - cur_gt_mask = coco_mask.decode(cur_item[4]) - - assert cur_img_name in result_dict.keys() - cur_detections = result_dict[cur_img_name] - - correct_flag = False - for i in range(len(cur_detections)): - cur_det_confidence = cur_detections[i][0] - if cur_det_confidence < self.score_thr: - break - cur_det_class = cur_detections[i][1] - if cur_det_class != cur_gt_class: - continue - cur_det_mask = cur_detections[i][2] - cur_iou = self.mask_iou(cur_det_mask, cur_gt_mask) - if cur_iou >= self.iou_thr: - correct_flag = True - break - if correct_flag: - correct += 1 - prog_bar.update() - recall = correct / len(gt_ann) * 100 - return correct, recall - - def mask_iou(self, mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray: - """Compute IoU between two masks.""" - mask1_area = np.count_nonzero(mask1 == 1) - mask2_area = np.count_nonzero(mask2 == 1) - intersection = np.count_nonzero(np.logical_and(mask1 == 1, mask2 == 1)) - iou = intersection / (mask1_area + mask2_area - intersection) - return iou + metric_results = self.compute(*args, **kwargs) + self.reset() + + evaluate_results = { + f'{self.prefix}/{k}(%)': round(float(v) * 100, 4) + for k, v in metric_results.items() + } + return evaluate_results diff --git a/mmdet/evaluation/metrics/proposal_recall_metric.py b/mmdet/evaluation/metrics/proposal_recall_metric.py new file mode 100644 index 00000000000..f508422fabd --- /dev/null +++ b/mmdet/evaluation/metrics/proposal_recall_metric.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from typing import Optional, Sequence, Union + +from mmengine.logging import MMLogger +from mmeval.metrics import ProposalRecall + +from mmdet.registry import METRICS + + +@METRICS.register_module() +class ProposalRecallMetric(ProposalRecall): + """A wrapper of :class:`mmeval.ProposalRecall`. + + The speed of calculating recall is faster than COCO Detection metric. + + Args: + iou_thrs (float | List[float], optional): IoU thresholds. + If not specified, IoUs from 0.5 to 0.95 will be used. + Defaults to None. + proposal_nums (Sequence[int]): Numbers of proposals to be evaluated. + Defaults to (1, 10, 100, 1000). + Note: it defaults to (100, 300, 1000) in MMDet 2.x. + use_legacy_coordinate (bool): Whether to use coordinate + system in mmdet v1.x. which means width, height should be + calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. + Please set `True` when using VOCDataset. Defaults to False. + nproc (int): Processes used for computing TP and FP. If nproc + is less than or equal to 1, multiprocessing will not be used. + Defaults to 4. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + dist_backend (str | None): The name of the distributed communication + backend. Refer to :class:`mmeval.BaseMetric`. + Defaults to 'torch_cuda'. + **kwargs: Keyword parameters passed to :class:`BaseMetric`. 
+ """ + default_prefix: Optional[str] = 'proposals' + + def __init__(self, + iou_thrs: Optional[Union[float, Sequence[float]]] = None, + proposal_nums: Union[int, Sequence[int]] = (1, 10, 100, 1000), + use_legacy_coordinate: bool = False, + nproc: int = 4, + prefix: Optional[str] = None, + dist_backend: str = 'torch_cuda', + **kwargs) -> None: + + collect_device = kwargs.pop('collect_device', None) + if collect_device is not None: + warnings.warn( + 'DeprecationWarning: The `collect_device` parameter of ' + '`ProposalRecallMetric` is deprecated, ' + 'use `dist_backend` instead.') + + logger = MMLogger.get_current_instance() + + super().__init__( + iou_thrs=iou_thrs, + proposal_nums=proposal_nums, + use_legacy_coordinate=use_legacy_coordinate, + nproc=nproc, + dist_backend=dist_backend, + logger=logger, + **kwargs) + + self.prefix = prefix or self.default_prefix + + # TODO: data_batch is no longer needed, consider adjusting the + # parameter position + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. Parse predictions + and ground truths from ``data_samples`` and invoke ``self.add``. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + predictions, groundtruths = [], [] + for data_sample in data_samples: + cp_data_sample = copy.deepcopy(data_sample) + gt_instances = cp_data_sample['gt_instances'] + ann = dict( + labels=gt_instances['labels'].cpu().numpy(), + bboxes=gt_instances['bboxes'].cpu().numpy()) + groundtruths.append(ann) + + pred_instances = cp_data_sample['pred_instances'] + pred = dict( + bboxes=pred_instances['bboxes'].cpu().numpy(), + scores=pred_instances['scores'].cpu().numpy()) + predictions.append(pred) + + self.add(predictions, groundtruths) + + def evaluate(self, *args, **kwargs) -> dict: + """Returns metric results and print pretty table of metrics per class. + + This method would be invoked by ``mmengine.Evaluator``. 
+ """ + metric_results = self.compute(*args, **kwargs) + self.reset() + + evaluate_results = { + f'{self.prefix}/{k}(%)': round(float(v) * 100, 4) + for k, v in metric_results.items() + } + return evaluate_results diff --git a/tests/test_evaluation/test_metrics/test_coco_metric.py b/tests/test_evaluation/test_metrics/test_coco_metric.py index 63611a1c3cb..2877139ad45 100644 --- a/tests/test_evaluation/test_metrics/test_coco_metric.py +++ b/tests/test_evaluation/test_metrics/test_coco_metric.py @@ -8,6 +8,7 @@ from mmengine.fileio import dump from mmdet.evaluation import CocoMetric +from mmdet.structures.mask import BitmapMasks class TestCocoMetric(TestCase): @@ -89,7 +90,7 @@ def _create_dummy_coco_json(self, json_name): def _create_dummy_results(self): bboxes = np.array([[50, 60, 70, 80], [100, 120, 130, 150], [150, 160, 190, 200], [250, 260, 350, 360]]) - scores = np.array([1.0, 0.98, 0.96, 0.95]) + scores = np.array([100.0, 0.98, 0.96, 0.95]) labels = np.array([0, 0, 1, 0]) dummy_mask = np.zeros((4, 10, 10), dtype=np.uint8) dummy_mask[:, :5, :5] = 1 @@ -121,19 +122,19 @@ def test_evaluate(self): coco_metric = CocoMetric( ann_file=fake_json_file, classwise=False, + dataset_meta=dict(classes=['car', 'bicycle']), outfile_prefix=f'{self.tmp_dir.name}/test') - coco_metric.dataset_meta = dict(classes=['car', 'bicycle']) coco_metric.process( {}, [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))]) - eval_results = coco_metric.evaluate(size=1) + eval_results = coco_metric.evaluate() target = { - 'coco/bbox_mAP': 1.0, - 'coco/bbox_mAP_50': 1.0, - 'coco/bbox_mAP_75': 1.0, - 'coco/bbox_mAP_s': 1.0, - 'coco/bbox_mAP_m': 1.0, - 'coco/bbox_mAP_l': 1.0, + 'coco/bbox_mAP(%)': 100.0, + 'coco/bbox_mAP_50(%)': 100.0, + 'coco/bbox_mAP_75(%)': 100.0, + 'coco/bbox_mAP_s(%)': 100.0, + 'coco/bbox_mAP_m(%)': 100.0, + 'coco/bbox_mAP_l(%)': 100.0, } self.assertDictEqual(eval_results, target) self.assertTrue( @@ -144,25 +145,25 @@ def test_evaluate(self): ann_file=fake_json_file, metric=['bbox', 'segm'], classwise=False, + dataset_meta=dict(classes=['car', 'bicycle']), outfile_prefix=f'{self.tmp_dir.name}/test') - coco_metric.dataset_meta = dict(classes=['car', 'bicycle']) coco_metric.process( {}, [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))]) - eval_results = coco_metric.evaluate(size=1) + eval_results = coco_metric.evaluate() target = { - 'coco/bbox_mAP': 1.0, - 'coco/bbox_mAP_50': 1.0, - 'coco/bbox_mAP_75': 1.0, - 'coco/bbox_mAP_s': 1.0, - 'coco/bbox_mAP_m': 1.0, - 'coco/bbox_mAP_l': 1.0, - 'coco/segm_mAP': 1.0, - 'coco/segm_mAP_50': 1.0, - 'coco/segm_mAP_75': 1.0, - 'coco/segm_mAP_s': 1.0, - 'coco/segm_mAP_m': 1.0, - 'coco/segm_mAP_l': 1.0, + 'coco/bbox_mAP(%)': 100.0, + 'coco/bbox_mAP_50(%)': 100.0, + 'coco/bbox_mAP_75(%)': 100.0, + 'coco/bbox_mAP_s(%)': 100.0, + 'coco/bbox_mAP_m(%)': 100.0, + 'coco/bbox_mAP_l(%)': 100.0, + 'coco/segm_mAP(%)': 100.0, + 'coco/segm_mAP_50(%)': 100.0, + 'coco/segm_mAP_75(%)': 100.0, + 'coco/segm_mAP_s(%)': 100.0, + 'coco/segm_mAP_m(%)': 100.0, + 'coco/segm_mAP_l(%)': 100.0, } self.assertDictEqual(eval_results, target) self.assertTrue( @@ -174,24 +175,26 @@ def test_evaluate(self): with self.assertRaisesRegex(KeyError, 'metric item "invalid" is not supported'): coco_metric = CocoMetric( - ann_file=fake_json_file, metric_items=['invalid']) - coco_metric.dataset_meta = dict(classes=['car', 'bicycle']) + ann_file=fake_json_file, + dataset_meta=dict(classes=['car', 'bicycle']), + metric_items=['invalid']) coco_metric.process({}, [ dict( 
                pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))
         ])
-        coco_metric.evaluate(size=1)
+        coco_metric.evaluate()
 
         # test custom metric_items
         coco_metric = CocoMetric(
-            ann_file=fake_json_file, metric_items=['mAP_m'])
-        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
+            ann_file=fake_json_file,
+            dataset_meta=dict(classes=['car', 'bicycle']),
+            metric_items=['mAP_m'])
         coco_metric.process(
             {},
             [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
-        eval_results = coco_metric.evaluate(size=1)
+        eval_results = coco_metric.evaluate()
         target = {
-            'coco/bbox_mAP_m': 1.0,
+            'coco/bbox_mAP_m(%)': 100.0,
         }
         self.assertDictEqual(eval_results, target)
 
@@ -203,21 +206,23 @@ def test_classwise_evaluate(self):
 
         # test single coco dataset evaluation
         coco_metric = CocoMetric(
-            ann_file=fake_json_file, metric='bbox', classwise=True)
-        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
+            ann_file=fake_json_file,
+            metric='bbox',
+            dataset_meta=dict(classes=['car', 'bicycle']),
+            classwise=True)
         coco_metric.process(
             {},
             [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
-        eval_results = coco_metric.evaluate(size=1)
+        eval_results = coco_metric.evaluate()
         target = {
-            'coco/bbox_mAP': 1.0,
-            'coco/bbox_mAP_50': 1.0,
-            'coco/bbox_mAP_75': 1.0,
-            'coco/bbox_mAP_s': 1.0,
-            'coco/bbox_mAP_m': 1.0,
-            'coco/bbox_mAP_l': 1.0,
-            'coco/car_precision': 1.0,
-            'coco/bicycle_precision': 1.0,
+            'coco/bbox_mAP(%)': 100.0,
+            'coco/bbox_mAP_50(%)': 100.0,
+            'coco/bbox_mAP_75(%)': 100.0,
+            'coco/bbox_mAP_s(%)': 100.0,
+            'coco/bbox_mAP_m(%)': 100.0,
+            'coco/bbox_mAP_l(%)': 100.0,
+            'coco/bbox_car_precision(%)': 100.0,
+            'coco/bbox_bicycle_precision(%)': 100.0,
         }
         self.assertDictEqual(eval_results, target)
 
@@ -227,40 +232,47 @@ def test_manually_set_iou_thrs(self):
         self._create_dummy_coco_json(fake_json_file)
 
         # test single coco dataset evaluation
-        coco_metric = CocoMetric(
-            ann_file=fake_json_file, metric='bbox', iou_thrs=[0.3, 0.6])
-        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
-        self.assertEqual(coco_metric.iou_thrs, [0.3, 0.6])
-
-    def test_fast_eval_recall(self):
-        # create dummy data
-        fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
-        self._create_dummy_coco_json(fake_json_file)
-        dummy_pred = self._create_dummy_results()
-
-        # test default proposal nums
-        coco_metric = CocoMetric(
-            ann_file=fake_json_file, metric='proposal_fast')
-        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
-        coco_metric.process(
-            {},
-            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
-        eval_results = coco_metric.evaluate(size=1)
-        target = {'coco/AR@100': 1.0, 'coco/AR@300': 1.0, 'coco/AR@1000': 1.0}
-        self.assertDictEqual(eval_results, target)
-
-        # test manually set proposal nums
         coco_metric = CocoMetric(
             ann_file=fake_json_file,
-            metric='proposal_fast',
-            proposal_nums=(2, 4))
-        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
-        coco_metric.process(
-            {},
-            [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
-        eval_results = coco_metric.evaluate(size=1)
-        target = {'coco/AR@2': 0.5, 'coco/AR@4': 1.0}
-        self.assertDictEqual(eval_results, target)
+            metric='bbox',
+            iou_thrs=[0.3, 0.6],
+            dataset_meta=dict(classes=['car', 'bicycle']))
+        self.assertTrue(
+            np.array_equal(coco_metric.iou_thrs, np.array([0.3, 0.6])))
+
+    # TODO: move to fast recall metric
+    # def test_fast_eval_recall(self):
+    #     # create dummy data
+    #     fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
+    #     self._create_dummy_coco_json(fake_json_file)
+    #     dummy_pred = self._create_dummy_results()
+    #
+    #     # test default proposal nums
+    #     coco_metric = CocoMetric(
+    #         ann_file=fake_json_file, metric='proposal_fast')
+    #     coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
+    #     coco_metric.process(
+    #         {},
+    #         [dict(pred_instances=dummy_pred, img_id=0,
+    #               ori_shape=(640, 640))])
+    #     eval_results = coco_metric.evaluate()
+    #     target = {'coco/AR@100': 100.0, 'coco/AR@300': 100.0,
+    #               'coco/AR@1000': 1.0}
+    #     self.assertDictEqual(eval_results, target)
+    #
+    #     # test manually set proposal nums
+    #     coco_metric = CocoMetric(
+    #         ann_file=fake_json_file,
+    #         metric='proposal_fast',
+    #         proposal_nums=(2, 4))
+    #     coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
+    #     coco_metric.process(
+    #         {},
+    #         [dict(pred_instances=dummy_pred, img_id=0,
+    #               ori_shape=(640, 640))])
+    #     eval_results = coco_metric.evaluate()
+    #     target = {'coco/AR@2': 0.5, 'coco/AR@4': 1.0}
+    #     self.assertDictEqual(eval_results, target)
 
     def test_evaluate_proposal(self):
         # create dummy data
@@ -268,20 +280,21 @@ def test_evaluate_proposal(self):
         self._create_dummy_coco_json(fake_json_file)
         dummy_pred = self._create_dummy_results()
 
-        coco_metric = CocoMetric(ann_file=fake_json_file, metric='proposal')
-        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
+        coco_metric = CocoMetric(
+            ann_file=fake_json_file,
+            metric='proposal',
+            dataset_meta=dict(classes=['car', 'bicycle']))
         coco_metric.process(
             {},
             [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
-        eval_results = coco_metric.evaluate(size=1)
-        print(eval_results)
+        eval_results = coco_metric.evaluate()
         target = {
-            'coco/AR@100': 1,
-            'coco/AR@300': 1.0,
-            'coco/AR@1000': 1.0,
-            'coco/AR_s@1000': 1.0,
-            'coco/AR_m@1000': 1.0,
-            'coco/AR_l@1000': 1.0
+            'coco/AR@1(%)': 25.0,
+            'coco/AR@10(%)': 100.0,
+            'coco/AR@100(%)': 100.0,
+            'coco/AR_s@100(%)': 100.0,
+            'coco/AR_m@100(%)': 100.0,
+            'coco/AR_l@100(%)': 100.0
         }
         self.assertDictEqual(eval_results, target)
 
@@ -289,8 +302,10 @@ def test_empty_results(self):
         # create dummy data
         fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
         self._create_dummy_coco_json(fake_json_file)
-        coco_metric = CocoMetric(ann_file=fake_json_file, metric='bbox')
-        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
+        coco_metric = CocoMetric(
+            ann_file=fake_json_file,
+            dataset_meta=dict(classes=['car', 'bicycle']),
+            metric='bbox')
         bboxes = np.zeros((0, 4))
         labels = np.array([])
         scores = np.array([])
@@ -304,64 +319,67 @@ def test_empty_results(self):
             {},
             [dict(pred_instances=empty_pred, img_id=0, ori_shape=(640, 640))])
         # coco api Index error will be caught
-        coco_metric.evaluate(size=1)
+        coco_metric.evaluate()
 
     def test_evaluate_without_json(self):
         dummy_pred = self._create_dummy_results()
-        dummy_mask = np.zeros((10, 10), order='F', dtype=np.uint8)
-        dummy_mask[:5, :5] = 1
-        rle_mask = mask_util.encode(dummy_mask)
-        rle_mask['counts'] = rle_mask['counts'].decode('utf-8')
-        instances = [{
-            'bbox_label': 0,
-            'bbox': [50, 60, 70, 80],
-            'ignore_flag': 0,
-            'mask': rle_mask,
-        }, {
-            'bbox_label': 0,
-            'bbox': [100, 120, 130, 150],
-            'ignore_flag': 0,
-            'mask': rle_mask,
-        }, {
-            'bbox_label': 1,
-            'bbox': [150, 160, 190, 200],
-            'ignore_flag': 0,
-            'mask': rle_mask,
-        }, {
-            'bbox_label': 0,
-            'bbox': [250, 260, 350, 360],
-            'ignore_flag': 0,
-            'mask': rle_mask,
-        }]
+        # create fake gts
+        bboxes = torch.Tensor([[50, 60, 70, 80], [100, 120, 130, 150],
+                               [150, 160, 190, 200], [250, 260, 350, 360]])
+        labels = torch.Tensor([0, 0, 1, 0])
+        mask = np.zeros((10, 10), dtype=np.uint8)
+        mask[:5, :5] = 1
+
+        dummy_mask = BitmapMasks(
+            masks=[mask for _ in range(4)], height=10, width=10)
+
+        dummy_gt = dict(
+            img_id=0,
+            width=640,
+            height=640,
+            bboxes=bboxes,
+            labels=labels,
+            masks=dummy_mask)
+
+        dummy_gt_ignore = dict(
+            img_id=0,
+            width=640,
+            height=640,
+            bboxes=torch.zeros((0, 4)),
+            labels=torch.zeros((0, )),
+            masks=BitmapMasks(masks=[], height=640, width=640))
+
+        # gt area based on bboxes
         coco_metric = CocoMetric(
             ann_file=None,
             metric=['bbox', 'segm'],
             classwise=False,
+            gt_mask_area=False,
+            dataset_meta=dict(classes=['car', 'bicycle']),
             outfile_prefix=f'{self.tmp_dir.name}/test')
-        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         coco_metric.process({}, [
             dict(
                 pred_instances=dummy_pred,
                 img_id=0,
                 ori_shape=(640, 640),
-                instances=instances)
+                gt_instances=dummy_gt,
+                ignored_instances=dummy_gt_ignore)
         ])
-        eval_results = coco_metric.evaluate(size=1)
-        print(eval_results)
+        eval_results = coco_metric.evaluate()
         target = {
-            'coco/bbox_mAP': 1.0,
-            'coco/bbox_mAP_50': 1.0,
-            'coco/bbox_mAP_75': 1.0,
-            'coco/bbox_mAP_s': 1.0,
-            'coco/bbox_mAP_m': 1.0,
-            'coco/bbox_mAP_l': 1.0,
-            'coco/segm_mAP': 1.0,
-            'coco/segm_mAP_50': 1.0,
-            'coco/segm_mAP_75': 1.0,
-            'coco/segm_mAP_s': 1.0,
-            'coco/segm_mAP_m': 1.0,
-            'coco/segm_mAP_l': 1.0,
+            'coco/bbox_mAP(%)': 100.0,
+            'coco/bbox_mAP_50(%)': 100.0,
+            'coco/bbox_mAP_75(%)': 100.0,
+            'coco/bbox_mAP_s(%)': 100.0,
+            'coco/bbox_mAP_m(%)': 100.0,
+            'coco/bbox_mAP_l(%)': 100.0,
+            'coco/segm_mAP(%)': 100.0,
+            'coco/segm_mAP_50(%)': 100.0,
+            'coco/segm_mAP_75(%)': 100.0,
+            'coco/segm_mAP_s(%)': 100.0,
+            'coco/segm_mAP_m(%)': 100.0,
+            'coco/segm_mAP_l(%)': 100.0,
         }
         self.assertDictEqual(eval_results, target)
         self.assertTrue(
@@ -371,6 +389,39 @@ def test_evaluate_without_json(self):
         self.assertTrue(
             osp.isfile(osp.join(self.tmp_dir.name, 'test.gt.json')))
 
+        # gt area based on masks
+        coco_metric = CocoMetric(
+            ann_file=None,
+            metric=['bbox', 'segm'],
+            classwise=False,
+            gt_mask_area=True,
+            dataset_meta=dict(classes=['car', 'bicycle']),
+            outfile_prefix=f'{self.tmp_dir.name}/test')
+        coco_metric.process({}, [
+            dict(
+                pred_instances=dummy_pred,
+                img_id=0,
+                ori_shape=(640, 640),
+                gt_instances=dummy_gt,
+                ignored_instances=dummy_gt_ignore)
+        ])
+        eval_results = coco_metric.evaluate()
+        target = {
+            'coco/bbox_mAP(%)': 100.0,
+            'coco/bbox_mAP_50(%)': 100.0,
+            'coco/bbox_mAP_75(%)': 100.0,
+            'coco/bbox_mAP_s(%)': 100.0,
+            'coco/bbox_mAP_m(%)': -100.0,
+            'coco/bbox_mAP_l(%)': -100.0,
+            'coco/segm_mAP(%)': 100.0,
+            'coco/segm_mAP_50(%)': 100.0,
+            'coco/segm_mAP_75(%)': 100.0,
+            'coco/segm_mAP_s(%)': 100.0,
+            'coco/segm_mAP_m(%)': -100.0,
+            'coco/segm_mAP_l(%)': -100.0,
+        }
+        self.assertDictEqual(eval_results, target)
+
     def test_format_only(self):
         # create dummy data
         fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
@@ -389,11 +440,11 @@ def test_format_only(self):
             metric='bbox',
             classwise=False,
             format_only=True,
+            dataset_meta=dict(classes=['car', 'bicycle']),
             outfile_prefix=f'{self.tmp_dir.name}/test')
-        coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
        coco_metric.process(
             {},
             [dict(pred_instances=dummy_pred, img_id=0, ori_shape=(640, 640))])
-        eval_results = coco_metric.evaluate(size=1)
+        eval_results = coco_metric.evaluate()
         self.assertDictEqual(eval_results, dict())
         self.assertTrue(osp.exists(f'{self.tmp_dir.name}/test.bbox.json'))
diff --git a/tests/test_evaluation/test_metrics/test_proposal_recall_metric.py b/tests/test_evaluation/test_metrics/test_proposal_recall_metric.py
new file mode 100644
index 00000000000..8bc160d5b91
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_proposal_recall_metric.py
@@ -0,0 +1,57 @@
+import tempfile
+from unittest import TestCase
+
+import numpy as np
+import torch
+
+from mmdet.evaluation import ProposalRecallMetric
+
+
+class TestProposalRecallMetric(TestCase):
+
+    def _create_dummy_gt(self):
+        bboxes = np.array([[50, 60, 70, 80], [100, 120, 130, 150],
+                           [150, 160, 190, 200], [250, 260, 350, 360]])
+        labels = np.array([0, 0, 1, 0])
+        return dict(
+            bboxes=torch.from_numpy(bboxes), labels=torch.from_numpy(labels))
+
+    def _create_dummy_results(self):
+        bboxes = np.array([[50, 60, 70, 80], [100, 120, 130, 150],
+                           [150, 160, 190, 200], [250, 260, 350, 360]])
+        scores = np.array([1.0, 0.98, 0.96, 0.95])
+        return dict(
+            bboxes=torch.from_numpy(bboxes), scores=torch.from_numpy(scores))
+
+    def setUp(self):
+        self.tmp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        self.tmp_dir.cleanup()
+
+    def test_init(self):
+        # test invalid iou_thrs
+        with self.assertRaises(TypeError):
+            ProposalRecallMetric(iou_thrs={'a', 0.5})
+
+        metric = ProposalRecallMetric(iou_thrs=0.6)
+        self.assertTrue(np.array_equal(metric.iou_thrs, np.array([0.6])))
+
+    def test_evaluate(self):
+        # create dummy data
+        dummy_gt = self._create_dummy_gt()
+        dummy_pred = self._create_dummy_results()
+
+        # test proposal recall evaluation
+        proposal_metric = ProposalRecallMetric(
+            proposal_nums=(1, 10, 100, 1000))
+        proposal_metric.process(
+            {}, [dict(gt_instances=dummy_gt, pred_instances=dummy_pred)])
+        eval_results = proposal_metric.evaluate()
+        target = {
+            'proposals/AR@1(%)': 25.0,
+            'proposals/AR@10(%)': 100.0,
+            'proposals/AR@100(%)': 100.0,
+            'proposals/AR@1000(%)': 100.0
+        }
+        self.assertDictEqual(eval_results, target)
diff --git a/tools/analysis_tools/coco_occluded_separated_recall.py b/tools/analysis_tools/coco_occluded_separated_recall.py
index e61f2ccd945..ee434f2aa5d 100644
--- a/tools/analysis_tools/coco_occluded_separated_recall.py
+++ b/tools/analysis_tools/coco_occluded_separated_recall.py
@@ -38,7 +38,7 @@ def main():
     metric.dataset_meta = CocoDataset.METAINFO
     for datasample in results:
         metric.process(data_batch=None, data_samples=[datasample])
-    metric_res = metric.compute_metrics(metric.results)
+    metric_res = metric.evaluate()
     if args.out is not None:
         mmengine.dump(metric_res, args.out)
         print_log(f'Evaluation results have been saved to {args.out}.')
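Note: the snippet below is a minimal standalone sketch of how the new ProposalRecallMetric is exercised by the test file added above. The import path and the gt_instances/pred_instances dict format are taken directly from that test; running it as a plain script and printing the result are illustrative assumptions rather than documented usage. With four ground-truth boxes and only the single top-scoring proposal kept, at most one box can be matched, which is why AR@1 comes out at 25% while AR@10/100/1000 reach 100%.

import numpy as np
import torch

# same import path as the new test file
from mmdet.evaluation import ProposalRecallMetric

# four ground-truth boxes and four perfectly overlapping proposals,
# mirroring _create_dummy_gt / _create_dummy_results in the test
gt = dict(
    bboxes=torch.from_numpy(
        np.array([[50, 60, 70, 80], [100, 120, 130, 150],
                  [150, 160, 190, 200], [250, 260, 350, 360]])),
    labels=torch.from_numpy(np.array([0, 0, 1, 0])))
pred = dict(
    bboxes=torch.from_numpy(
        np.array([[50, 60, 70, 80], [100, 120, 130, 150],
                  [150, 160, 190, 200], [250, 260, 350, 360]])),
    scores=torch.from_numpy(np.array([1.0, 0.98, 0.96, 0.95])))

metric = ProposalRecallMetric(proposal_nums=(1, 10, 100, 1000))
metric.process({}, [dict(gt_instances=gt, pred_instances=pred)])

# keys follow the 'proposals/AR@N(%)' naming asserted in the test:
# AR@1 is 25.0 (only one of four boxes can be recalled with a single
# proposal), while AR@10, AR@100 and AR@1000 are 100.0
print(metric.evaluate())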