From 398a61129951f3c14bd26676edad666362b113c3 Mon Sep 17 00:00:00 2001
From: BIGWangYuDong
Date: Wed, 15 Mar 2023 13:03:47 +0800
Subject: [PATCH] [Feature] Support CocoOccludedSeparated

---
 mmeval/metrics/__init__.py       |   4 +-
 mmeval/metrics/coco_detection.py | 253 ++++++++++++++++++++++++++++++-
 2 files changed, 253 insertions(+), 4 deletions(-)

diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py
index 0411f1ee..66501414 100644
--- a/mmeval/metrics/__init__.py
+++ b/mmeval/metrics/__init__.py
@@ -7,7 +7,7 @@
 from .average_precision import AveragePrecision
 from .bleu import BLEU
 from .char_recall_precision import CharRecallPrecision
-from .coco_detection import COCODetection
+from .coco_detection import COCODetection, CocoOccludedSeparated
 from .connectivity_error import ConnectivityError
 from .dota_map import DOTAMeanAP
 from .end_point_error import EndPointError
@@ -48,7 +48,7 @@
     'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
     'WordAccuracy', 'PrecisionRecallF1score',
     'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score',
-    'CharRecallPrecision'
+    'CharRecallPrecision', 'CocoOccludedSeparated'
 ]
 
 _deprecated_msg = (
diff --git a/mmeval/metrics/coco_detection.py b/mmeval/metrics/coco_detection.py
index bae8bd92..f45c0446 100644
--- a/mmeval/metrics/coco_detection.py
+++ b/mmeval/metrics/coco_detection.py
@@ -18,7 +18,7 @@
 from mmeval.utils import is_list_of
 
 try:
-    from mmeval.metrics.utils.coco_wrapper import COCO, COCOeval
+    from mmeval.metrics.utils.coco_wrapper import COCO, COCOeval, mask_util
     HAS_COCOAPI = True
 except ImportError:
     HAS_COCOAPI = False
@@ -206,7 +206,7 @@ def __init__(self,
             'be saved to a temp directory which will be cleaned up at the end.'
 
         self.outfile_prefix = outfile_prefix
-
+        self.backend_args = backend_args
         # if ann_file is not specified,
         # initialize coco api with the converted dataset
         self._coco_api: Optional[COCO]  # type: ignore
@@ -750,6 +750,255 @@ def classes(self) -> list:
         return classes
 
 
+class CocoOccludedSeparated(COCODetection):
+    """Metric of separated and occluded masks, as presented in the paper
+    `A Tri-Layer Plugin to Improve Occluded Detection`_.
+
+    Separated COCO and Occluded COCO are automatically generated subsets of
+    the COCO val dataset, collecting separated objects and partially occluded
+    objects for a large variety of categories. In this way, occlusion is
+    divided into two major categories: separated and partially occluded.
+
+    - Separation: the target object's segmentation mask is separated into
+      distinct regions by the occluder.
+    - Partial Occlusion: the target object is partially occluded but its
+      segmentation mask remains connected.
+
+    These two new scalable real-image datasets are used to benchmark a
+    model's capability to detect occluded objects of 80 common categories.
+
+    Please cite the paper if you use this dataset:
+
+    @article{zhan2022triocc,
+        title={A Tri-Layer Plugin to Improve Occluded Detection},
+        author={Zhan, Guanqi and Xie, Weidi and Zisserman, Andrew},
+        journal={British Machine Vision Conference},
+        year={2022}
+    }
+
+    Args:
+        occluded_ann (str): Path to the occluded COCO annotation file.
+        separated_ann (str): Path to the separated COCO annotation file.
+        score_thr (float): Score threshold of the detection masks.
+            Defaults to 0.3.
+        iou_thr (float): IoU threshold for the recall calculation.
+            Defaults to 0.75.
+        metric (str | List[str]): Metrics to be evaluated. Valid metrics
+            include 'bbox', 'segm', and 'proposal'.
+            Defaults to ['bbox', 'segm'].
+        **kwargs: Keyword parameters passed to :class:`COCODetection`.
+    """

+
+    def __init__(
+            self,
+            *args,
+            occluded_ann:  # noqa
+            str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/occluded_coco.pkl',  # noqa
+            separated_ann:  # noqa
+            str = 'https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/separated_coco.pkl',  # noqa
+            score_thr: float = 0.3,
+            iou_thr: float = 0.75,
+            metric: Union[str, List[str]] = ['bbox', 'segm'],
+            **kwargs) -> None:
+        super().__init__(*args, metric=metric, **kwargs)  # type: ignore
+        self.occluded_ann = load(occluded_ann, backend_args=self.backend_args)
+        self.separated_ann = load(
+            separated_ann, backend_args=self.backend_args)
+        self.score_thr = score_thr
+        self.iou_thr = iou_thr
+
+    def compute_metric(self, results: list) -> dict:
+        """Compute the COCO and CocoOccludedSeparated metrics.
+
+        Args:
+            results (List[tuple]): A list of tuples. Each tuple is the
+                prediction and ground truth of an image. This list has
+                already been synced across all ranks.
+
+        Returns:
+            dict: The computed metric. The keys are the names of the metrics,
+            and the values are the corresponding results.
+        """
+        coco_metric_res = super().compute_metric(results)
+        eval_res = self.evaluate_occluded_separated(results)
+        coco_metric_res.update(eval_res)
+        return coco_metric_res
+
+    def evaluate_occluded_separated(self, results: List[tuple]) -> dict:
+        """Compute the recall of occluded and separated masks.
+
+        Args:
+            results (List[tuple]): A list of tuples. Each tuple is the
+                prediction and ground truth of an image. This list has
+                already been synced across all ranks.
+
+        Returns:
+            dict[str, float]: The recall of occluded and separated masks.
+        """
+        dict_det: dict = dict()
+        self.logger.info('processing detection results...')
+        total_results = len(results)
+
+        classes = self.classes
+        for i in range(total_results):
+            dt, gt = results[i]
+            img_id = dt['img_id']
+            cur_img_name = self._coco_api.imgs[img_id]['file_name']  # type: ignore # yapf: disable # noqa: E501
+            if cur_img_name not in dict_det.keys():
+                dict_det[cur_img_name] = []
+
+            for bbox, score, label, mask in zip(dt['bboxes'], dt['scores'],
+                                                dt['labels'], dt['masks']):
+                cur_binary_mask = mask_util.decode(mask)
+                dict_det[cur_img_name].append(
+                    [score, classes[label], cur_binary_mask, bbox])
+            dict_det[cur_img_name].sort(
+                key=lambda x: (-x[0], x[3][0], x[3][1])
+            )  # sort by confidence from high to low; tie-break on bbox coords
+            print(
+                f'\rProcessing results {i + 1}/{total_results}',
+                end='',
+                flush=True)
+        print('\nFinished processing results')
+        eval_results: OrderedDict = OrderedDict()
+        table_results: OrderedDict = OrderedDict()
+
+        self.logger.info('\nComputing occluded mask recall...')
+        occluded_correct_num, occluded_recall = self.compute_recall(
+            dict_det, gt_ann=self.occluded_ann, is_occ=True)
+        self.logger.info(
+            f'COCO occluded mask success num: {occluded_correct_num}')
+        self.logger.info('COCO occluded mask recall: '
+                         f'{round(occluded_recall * 100, 2):.2f}%')
+        eval_results['occluded_recall'] = occluded_recall
+        table_results['occluded_recall'] = \
+            f'{round(occluded_recall * 100, 2):.2f}%'
+        table_results['occluded_correct_num'] = f'{occluded_correct_num}'
+
+        self.logger.info('Computing separated mask recall...')
+        separated_correct_num, separated_recall = self.compute_recall(
+            dict_det, gt_ann=self.separated_ann, is_occ=False)
+        self.logger.info(
+            f'COCO separated mask success num: {separated_correct_num}')
+        self.logger.info('COCO separated mask recall: '
+                         f'{round(separated_recall * 100, 2):.2f}%')
+        eval_results['separated_recall'] = separated_recall
+        table_results['separated_recall'] = \
+            f'{round(separated_recall * 100, 2):.2f}%'
+        table_results['separated_correct_num'] = f'{separated_correct_num}'
+
+        if self.print_results:
+            self._print_occluded_separated_recall(table_results)
+
+        return eval_results
+
+    def compute_recall(self,
+                       result_dict: dict,
+                       gt_ann: list,
+                       is_occ: bool = True) -> tuple:
+        """Compute the recall of occluded or separated masks.
+
+        Args:
+            result_dict (dict): Processed mask results.
+            gt_ann (list): Occluded or separated COCO annotations.
+            is_occ (bool): Whether the annotations are occluded masks.
+                Defaults to True.
+
+        Returns:
+            tuple: The number of correctly detected masks and the recall.
+        """
+        correct = 0
+        total_ann = len(gt_ann)
+        for iter_i in range(total_ann):
+            cur_item = gt_ann[iter_i]
+            cur_img_name = cur_item[0]
+            cur_gt_bbox = cur_item[3]
+            if is_occ:
+                # convert the [x, y, w, h] box to [x1, y1, x2, y2]
+                cur_gt_bbox = [
+                    cur_gt_bbox[0], cur_gt_bbox[1],
+                    cur_gt_bbox[0] + cur_gt_bbox[2],
+                    cur_gt_bbox[1] + cur_gt_bbox[3]
+                ]
+            cur_gt_class = cur_item[1]
+            cur_gt_mask = mask_util.decode(cur_item[4])
+
+            assert cur_img_name in result_dict.keys()
+            cur_detections = result_dict[cur_img_name]
+
+            correct_flag = False
+            for i in range(len(cur_detections)):
+                cur_det_confidence = cur_detections[i][0]
+                if cur_det_confidence < self.score_thr:
+                    # detections are sorted by score, so the rest are lower
+                    break
+                cur_det_class = cur_detections[i][1]
+                if cur_det_class != cur_gt_class:
+                    continue
+                cur_det_mask = cur_detections[i][2]
+                cur_iou = self.mask_iou(cur_det_mask, cur_gt_mask)
+                if cur_iou >= self.iou_thr:
+                    correct_flag = True
+                    break
+            if correct_flag:
+                correct += 1
+            print(
+                f'\rComputing Recall {iter_i + 1}/{total_ann}',
+                end='',
+                flush=True)
+        if is_occ:
+            print('\nFinished computing occluded recall')
+        else:
+            print('\nFinished computing separated recall')
+        recall = correct / len(gt_ann)
+        return correct, recall
+
+    def mask_iou(self, pred_mask: np.ndarray,
+                 gt_mask: np.ndarray) -> float:
+        """Compute the IoU between two binary masks.
+
+        Args:
+            pred_mask (np.ndarray): The predicted mask.
+            gt_mask (np.ndarray): The ground truth mask.
+
+        Returns:
+            float: The IoU of the two masks.
+        """
+        mask1_area = np.count_nonzero(pred_mask == 1)
+        mask2_area = np.count_nonzero(gt_mask == 1)
+        intersection = np.count_nonzero(
+            np.logical_and(pred_mask == 1, gt_mask == 1))
+        iou = intersection / (mask1_area + mask2_area - intersection)
+        return iou
+
+    def _print_occluded_separated_recall(self, table_results: dict) -> None:
+        """Print the evaluation results table.
+
+        Args:
+            table_results (dict): The formatted metric results to print.
+        """
+        table_title = 'Occluded and Separated COCO Results'
+        headers = ['mask type', 'recall', 'num correct']
+        table = Table(title=table_title)
+        console = Console()
+
+        result_list = [
+            ['occluded', table_results['occluded_recall'],
+             table_results['occluded_correct_num']],
+            ['separated', table_results['separated_recall'],
+             table_results['separated_correct_num']],
+        ]
+
+        for name in headers:
+            table.add_column(name, justify='left')
+        for result in result_list:
+            table.add_row(*result)
+        with console.capture() as capture:
+            console.print(table, end='')
+        self.logger.info('\n' + capture.get())
+
+
 # Keep the deprecated metric name as an alias.
 # The deprecated Metric names will be removed in 1.0.0!
 COCODetectionMetric = COCODetection
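
Usage note (reviewer sketch, not part of the patch): the snippet below shows how the new metric might be driven, assuming mmeval's usual callable-metric API (metric(predictions=..., groundtruths=...)) and the prediction dict layout that evaluate_occluded_separated reads ('img_id', 'bboxes', 'scores', 'labels', and RLE-encoded 'masks'). The annotation path is a placeholder; the occluded/separated .pkl files default to the URLs hard-coded in the constructor.

    from mmeval import CocoOccludedSeparated

    # Placeholder path; point this at the COCO val annotation file.
    metric = CocoOccludedSeparated(
        ann_file='data/coco/annotations/instances_val2017.json',
        metric=['bbox', 'segm'],
        score_thr=0.3,  # detections below this score are skipped when matching
        iou_thr=0.75,   # mask IoU required to count a GT mask as recalled
    )

    # `predictions` / `groundtruths` come from the detector and the dataset in
    # the same format COCODetection already consumes; the occluded/separated
    # recalls are appended to the regular COCO results dict.
    # results = metric(predictions=predictions, groundtruths=groundtruths)
    # print(results['occluded_recall'], results['separated_recall'])

The matching criterion is plain pixel-wise mask IoU, intersection / (area1 + area2 - intersection). For example, two masks that each cover two pixels and overlap in one pixel score 1 / (2 + 2 - 1) = 1/3, well below the default iou_thr of 0.75, so that ground-truth mask would not count as recalled.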