Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Refactor COCOMetric, CocoOccludedSeparatedMetric, and ProposalRecallMetric by using MMEval #9079

Open
wants to merge 19 commits into
base: refactor_metrics
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configs/_base_/datasets/coco_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
type='CocoMetric',
ann_file=data_root + 'annotations/instances_val2017.json',
metric='bbox',
classwise=False,
format_only=False)
test_evaluator = val_evaluator

Expand Down
1 change: 1 addition & 0 deletions configs/_base_/datasets/coco_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
type='CocoMetric',
ann_file=data_root + 'annotations/instances_val2017.json',
metric=['bbox', 'segm'],
classwise=False,
format_only=False)
test_evaluator = val_evaluator

Expand Down
1 change: 1 addition & 0 deletions configs/_base_/datasets/coco_instance_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,6 @@
type='CocoMetric',
ann_file=data_root + 'annotations/instances_val2017.json',
metric=['bbox', 'segm'],
classwise=False,
format_only=False)
test_evaluator = val_evaluator
16 changes: 10 additions & 6 deletions configs/rpn/rpn_r50_fpn_1x_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

val_evaluator = dict(metric='proposal_fast')
val_evaluator = dict(
_delete_=True,
type='ProposalRecallMetric',
proposal_nums=(1, 10, 100, 1000),
prefix='coco',
use_legacy_coordinate=False, # VOCDataset should set True, else False
)
test_evaluator = val_evaluator

# inference on val dataset and dump the proposals with evaluate metric
Expand All @@ -14,11 +20,9 @@
# output_dir=data_root + 'proposals/',
# proposals_file='rpn_r50_fpn_1x_val2017.pkl'),
# dict(
# type='CocoMetric',
# ann_file=data_root + 'annotations/instances_val2017.json',
# metric='proposal_fast',
# file_client_args={{_base_.file_client_args}},
# format_only=False)
# type='ProposalRecallMetric',
# proposal_nums=(1, 10, 100, 1000),
# use_legacy_coordinate=False,
# ]

# inference on training dataset and dump the proposals without evaluate metric
Expand Down
100 changes: 81 additions & 19 deletions mmdet/datasets/api_wrappers/coco_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,102 @@ class COCO(_COCO):

It implements some snake case function aliases. So that the COCO class has
the same interface as LVIS class.

Args:
annotation_file (str, optional): Path of annotation file.
Defaults to None.
"""

def __init__(self, annotation_file=None):
def __init__(self, annotation_file: Optional[str] = None) -> None:
if getattr(pycocotools, '__version__', '0') >= '12.0.2':
warnings.warn(
'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501
UserWarning)
'mmpycocotools is deprecated. '
'Please install official pycocotools by '
'"pip install pycocotools"', UserWarning)
super().__init__(annotation_file=annotation_file)
self.img_ann_map = self.imgToAnns
self.cat_img_map = self.catToImgs

def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
def get_ann_ids(self,
img_ids: Union[list, int] = [],
cat_ids: Union[list, int] = [],
area_rng: Union[list, int] = [],
iscrowd: Optional[bool] = None) -> list:
"""Get annotation ids that satisfy given filter conditions.

Args:
img_ids (list | int): Get annotations for given images.
cat_ids (list | int): Get categories for given images.
area_rng (list | int): Get annotations for given area range.
iscrowd (bool, optional): Get annotations for given crowd label.

Returns:
List: Integer array of annotation ids.
"""
return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)

def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
def get_cat_ids(self,
cat_names: Union[list, int] = [],
sup_names: Union[list, int] = [],
cat_ids: Union[list, int] = []) -> list:
"""Get category ids that satisfy given filter conditions.

Args:
cat_names (list | int): Get categories for given category names.
sup_names (list | int): Get categories for given supercategory
names.
cat_ids (list | int): Get categories for given category ids.

Returns:
List: Integer array of category ids.
"""
return self.getCatIds(cat_names, sup_names, cat_ids)

def get_img_ids(self, img_ids=[], cat_ids=[]):
def get_img_ids(self,
img_ids: Union[list, int] = [],
cat_ids: Union[list, int] = []) -> list:
"""Get image ids that satisfy given filter conditions.

Args:
img_ids (list | int): Get images for given ids
cat_ids (list | int): Get images with all given cats

Returns:
List: Integer array of image ids.
"""
return self.getImgIds(img_ids, cat_ids)

def load_anns(self, ids):
def load_anns(self, ids: Union[list, int] = []) -> list:
"""Load annotations with the specified ids.

Args:
ids (list | int): Integer ids specifying annotations.

Returns:
List[dict]: Loaded annotation objects.
"""
return self.loadAnns(ids)

def load_cats(self, ids):
def load_cats(self, ids: Union[list, int] = []) -> list:
"""Load categories with the specified ids.

Args:
ids (list | int): Integer ids specifying categories.

Returns:
List[dict]: loaded category objects.
"""
return self.loadCats(ids)

def load_imgs(self, ids):
def load_imgs(self, ids: Union[list, int] = []) -> list:
"""Load annotations with the specified ids.

Args:
ids (list): integer ids specifying image.

Returns:
List[dict]: Loaded image objects.
"""
return self.loadImgs(ids)


Expand All @@ -53,15 +122,8 @@ class COCOPanoptic(COCO):
"""This wrapper is for loading the panoptic style annotation file.

The format is shown in the CocoPanopticDataset class.

Args:
annotation_file (str, optional): Path of annotation file.
Defaults to None.
"""

def __init__(self, annotation_file: Optional[str] = None) -> None:
super(COCOPanoptic, self).__init__(annotation_file)

def createIndex(self) -> None:
"""Create index."""
# create index
Expand Down Expand Up @@ -114,16 +176,16 @@ def createIndex(self) -> None:

def load_anns(self,
ids: Union[List[int], int] = []) -> Optional[List[dict]]:
"""Load anns with the specified ids.
"""Load annotations with the specified ids.

``self.anns`` is a list of annotation lists instead of a
list of annotations.

Args:
ids (Union[List[int], int]): Integer ids specifying anns.
ids (Union[List[int], int]): Integer ids specifying annotations.

Returns:
anns (List[dict], optional): Loaded ann objects.
anns (List[dict], optional): Loaded annotation objects.
"""
anns = []

Expand Down
3 changes: 2 additions & 1 deletion mmdet/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from .dump_proposals_metric import DumpProposals
from .lvis_metric import LVISMetric
from .openimages_metric import OpenImagesMetric
from .proposal_recall_metric import ProposalRecallMetric
from .voc_metric import VOCMetric

__all__ = [
'CityScapesMetric', 'CocoMetric', 'CocoPanopticMetric', 'OpenImagesMetric',
'VOCMetric', 'LVISMetric', 'CrowdHumanMetric', 'DumpProposals',
'CocoOccludedSeparatedMetric', 'DumpDetResults'
'CocoOccludedSeparatedMetric', 'DumpDetResults', 'ProposalRecallMetric'
]
Loading