Skip to content

Commit

Permalink
Merge pull request #58 from roboflow/feature/update_to_support_sam
Browse files Browse the repository at this point in the history
feature/update_to_support_sam
  • Loading branch information
SkalskiP committed Apr 10, 2023
2 parents bc12a8e + ddc8a8c commit dba4d9f
Show file tree
Hide file tree
Showing 12 changed files with 292 additions and 45 deletions.
7 changes: 7 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
### 0.5.0 <small>April 10, 2023</small>

- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.mask` to enable segmentation support.
- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `MaskAnnotator` to allow easy `Detections.mask` annotation.
- Added [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.from_sam` to enable native Segment Anything Model (SAM) support.
- Changed [[#58](https://github.com/roboflow/supervision/pull/58)]: `Detections.area` behaviour to work not only with boxes but also with masks.

### 0.4.0 <small>April 5, 2023</small>

- Added [[#46](https://github.com/roboflow/supervision/discussions/48)]: `Detections.empty` to allow easy creation of empty `Detections` objects.
Expand Down
6 changes: 5 additions & 1 deletion docs/detection/annotate.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## BoxAnnotator

:::supervision.detection.annotate.BoxAnnotator
:::supervision.detection.annotate.BoxAnnotator

## MaskAnnotator

:::supervision.detection.annotate.MaskAnnotator
7 changes: 7 additions & 0 deletions docs/detection/tools/polygon_zone.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
## PolygonZone

:::supervision.detection.tools.polygon_zone.PolygonZone

## PolygonZoneAnnotator

:::supervision.detection.tools.polygon_zone.PolygonZoneAnnotator
6 changes: 5 additions & 1 deletion docs/detection/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,8 @@

## non_max_suppression

:::supervision.detection.utils.non_max_suppression
:::supervision.detection.utils.non_max_suppression

## mask_to_xyxy

:::supervision.detection.utils.mask_to_xyxy
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ nav:
- Core: detection/core.md
- Annotate: detection/annotate.md
- Utils: detection/utils.md
- Tools:
- Polygon Zone: detection/tools/polygon_zone.md
- Draw:
- Utils: draw/utils.md
- Annotations:
Expand Down
8 changes: 4 additions & 4 deletions supervision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
__version__ = "0.4.0"
__version__ = "0.5.0"

from supervision.annotation.voc import detections_to_voc_xml
from supervision.detection.annotate import BoxAnnotator
from supervision.detection.annotate import BoxAnnotator, MaskAnnotator
from supervision.detection.core import Detections
from supervision.detection.line_counter import LineZone, LineZoneAnnotator
from supervision.detection.polygon_zone import PolygonZone, PolygonZoneAnnotator
from supervision.detection.utils import generate_2d_mask
from supervision.detection.tools.polygon_zone import PolygonZone, PolygonZoneAnnotator
from supervision.detection.utils import generate_2d_mask, mask_to_xyxy
from supervision.draw.color import Color, ColorPalette
from supervision.draw.utils import draw_filled_rectangle, draw_polygon, draw_text
from supervision.geometry.core import Point, Position, Rect
Expand Down
62 changes: 60 additions & 2 deletions supervision/detection/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ def annotate(
np.ndarray: The image with the bounding boxes drawn on it
"""
font = cv2.FONT_HERSHEY_SIMPLEX
for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections):
x1, y1, x2, y2 = xyxy.astype(int)
for i in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
class_id = (
detections.class_id[i] if detections.class_id is not None else None
)
idx = class_id if class_id is not None else i
color = (
self.color.by_idx(idx)
Expand Down Expand Up @@ -114,3 +117,58 @@ def annotate(
lineType=cv2.LINE_AA,
)
return scene


class MaskAnnotator:
    """
    A class for overlaying masks on an image using detections provided.

    Attributes:
        color (Union[Color, ColorPalette]): The color to fill the mask, can be a single color or a color palette
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.default(),
    ):
        self.color: Union[Color, ColorPalette] = color

    def annotate(
        self, scene: np.ndarray, detections: Detections, opacity: float = 0.5
    ) -> np.ndarray:
        """
        Overlays the masks on the given image based on the provided detections, with a specified opacity.

        The input `scene` array is not written to; every blended mask produces a new array.
        When `detections.mask` is `None` the `scene` is returned as is.

        Parameters:
            scene (np.ndarray): The image on which the masks will be overlaid
            detections (Detections): The detections for which the masks will be overlaid
            opacity (float): The opacity of the masks, between 0 and 1, default is 0.5

        Returns:
            np.ndarray: The image with the masks overlaid
        """
        # Nothing to draw without masks; checked once instead of on every iteration.
        if detections.mask is None:
            return scene

        for i in range(len(detections.xyxy)):
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            # Fall back to the detection index when no class id is available.
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )

            mask = detections.mask[i]
            colored_mask = np.zeros_like(scene, dtype=np.uint8)
            colored_mask[:] = color.as_bgr()

            # Blend only where the mask is set; the trailing axis added to the
            # mask lets it broadcast over the scene's channel axis.
            scene = np.where(
                np.expand_dims(mask, axis=-1),
                np.uint8(opacity * colored_mask + (1 - opacity) * scene),
                scene,
            )

        return scene
141 changes: 112 additions & 29 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,78 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple, Union
from typing import Any, Iterator, List, Optional, Tuple, Union

import numpy as np

from supervision.detection.utils import non_max_suppression
from supervision.detection.utils import non_max_suppression, xywh_to_xyxy
from supervision.geometry.core import Position


def _validate_xyxy(xyxy: Any, n: int) -> None:
    """
    Check that `xyxy` is an `np.ndarray` of shape `(n, 4)`.

    Raises:
        ValueError: If `xyxy` is not an `np.ndarray` or its shape differs from `(n, 4)`.
    """
    if isinstance(xyxy, np.ndarray) and xyxy.shape == (n, 4):
        return
    raise ValueError("xyxy must be 2d np.ndarray with (n, 4) shape")


def _validate_mask(mask: Any, n: int) -> None:
    """
    Check that `mask` is either `None` or a 3d `np.ndarray` holding `n` masks.

    Only the number of dimensions and the size of the first axis are checked;
    the two trailing (spatial) axes may have any size.

    Raises:
        ValueError: If `mask` is neither `None` nor a 3d `np.ndarray` whose first axis has length `n`.
    """
    is_valid = mask is None or (
        isinstance(mask, np.ndarray) and len(mask.shape) == 3 and mask.shape[0] == n
    )
    if not is_valid:
        # Masks are laid out rows-first so they line up with image arrays.
        raise ValueError("mask must be None or 3d np.ndarray with (n, H, W) shape")


def _validate_class_id(class_id: Any, n: int) -> None:
    """
    Check that `class_id` is either `None` or an `np.ndarray` of shape `(n,)`.

    Raises:
        ValueError: If `class_id` is neither `None` nor an `np.ndarray` of shape `(n,)`.
    """
    if class_id is None:
        return
    if isinstance(class_id, np.ndarray) and class_id.shape == (n,):
        return
    raise ValueError("class_id must be None or 1d np.ndarray with (n,) shape")


def _validate_confidence(confidence: Any, n: int) -> None:
    """
    Check that `confidence` is either `None` or an `np.ndarray` of shape `(n,)`.

    Raises:
        ValueError: If `confidence` is neither `None` nor an `np.ndarray` of shape `(n,)`.
    """
    if confidence is None:
        return
    if isinstance(confidence, np.ndarray) and confidence.shape == (n,):
        return
    raise ValueError("confidence must be None or 1d np.ndarray with (n,) shape")


def _validate_tracker_id(tracker_id: Any, n: int) -> None:
    """
    Check that `tracker_id` is either `None` or an `np.ndarray` of shape `(n,)`.

    Raises:
        ValueError: If `tracker_id` is neither `None` nor an `np.ndarray` of shape `(n,)`.
    """
    if tracker_id is None:
        return
    if isinstance(tracker_id, np.ndarray) and tracker_id.shape == (n,):
        return
    raise ValueError("tracker_id must be None or 1d np.ndarray with (n,) shape")


@dataclass
class Detections:
"""
Data class containing information about the detections in a video frame.
Attributes:
xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`
mask: (Optional[np.ndarray]): An array of shape `(n, W, H)` containing the segmentation masks.
class_id (Optional[np.ndarray]): An array of shape `(n,)` containing the class ids of the detections.
confidence (Optional[np.ndarray]): An array of shape `(n,)` containing the confidence scores of the detections.
tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections.
"""

xyxy: np.ndarray
mask: np.Optional[np.ndarray] = None
class_id: Optional[np.ndarray] = None
confidence: Optional[np.ndarray] = None
tracker_id: Optional[np.ndarray] = None

def __post_init__(self):
n = len(self.xyxy)
validators = [
(isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)),
self.class_id is None
or (isinstance(self.class_id, np.ndarray) and self.class_id.shape == (n,)),
self.confidence is None
or (
isinstance(self.confidence, np.ndarray)
and self.confidence.shape == (n,)
),
self.tracker_id is None
or (
isinstance(self.tracker_id, np.ndarray)
and self.tracker_id.shape == (n,)
),
]
if not all(validators):
raise ValueError(
"xyxy must be 2d np.ndarray with (n, 4) shape, "
"class_id must be None or 1d np.ndarray with (n,) shape, "
"confidence must be None or 1d np.ndarray with (n,) shape, "
"tracker_id must be None or 1d np.ndarray with (n,) shape"
)
_validate_xyxy(xyxy=self.xyxy, n=n)
_validate_mask(mask=self.mask, n=n)
_validate_class_id(class_id=self.class_id, n=n)
_validate_confidence(confidence=self.confidence, n=n)
_validate_tracker_id(tracker_id=self.tracker_id, n=n)

def __len__(self):
"""
Expand All @@ -59,13 +82,22 @@ def __len__(self):

def __iter__(
self,
) -> Iterator[Tuple[np.ndarray, Optional[float], int, Optional[Union[str, int]]]]:
) -> Iterator[
Tuple[
np.ndarray,
Optional[np.ndarray],
Optional[float],
Optional[int],
Optional[int],
]
]:
"""
Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection.
Iterates over the Detections object and yield a tuple of `(xyxy, mask, confidence, class_id, tracker_id)` for each detection.
"""
for i in range(len(self.xyxy)):
yield (
self.xyxy[i],
self.mask[i] if self.mask is not None else None,
self.confidence[i] if self.confidence is not None else None,
self.class_id[i] if self.class_id is not None else None,
self.tracker_id[i] if self.tracker_id is not None else None,
Expand All @@ -75,6 +107,12 @@ def __eq__(self, other: Detections):
return all(
[
np.array_equal(self.xyxy, other.xyxy),
any(
[
self.mask is None and other.mask is None,
np.array_equal(self.mask, other.mask),
]
),
any(
[
self.class_id is None and other.class_id is None,
Expand Down Expand Up @@ -113,7 +151,7 @@ def from_yolov5(cls, yolov5_results) -> Detections:
>>> from supervision import Detections
>>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
>>> results = model(frame)
>>> results = model(IMAGE)
>>> detections = Detections.from_yolov5(results)
```
"""
Expand Down Expand Up @@ -141,8 +179,8 @@ def from_yolov8(cls, yolov8_results) -> Detections:
>>> from supervision import Detections
>>> model = YOLO('yolov8s.pt')
>>> results = model(frame)[0]
>>> detections = Detections.from_yolov8(results)
>>> yolov8_results = model(IMAGE)[0]
>>> detections = Detections.from_yolov8(yolov8_results)
```
"""
return cls(
Expand Down Expand Up @@ -201,6 +239,37 @@ def from_roboflow(cls, roboflow_result: dict, class_list: List[str]) -> Detectio
class_id=np.array(class_id).astype(int),
)

@classmethod
def from_sam(cls, sam_result: List[dict]) -> Detections:
    """
    Creates a Detections instance from Segment Anything Model (SAM) by Meta AI.

    Entries are ordered by their `"area"` value, largest first. Boxes are read from
    the `"bbox"` key (converted from `xywh` to `xyxy`) and masks from the
    `"segmentation"` key.

    Args:
        sam_result (List[dict]): The list of mask dictionaries generated by SAM

    Returns:
        Detections: A new Detections object. An empty `sam_result` gives an empty Detections object.

    Example:
        ```python
        >>> from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
        >>> import supervision as sv
        >>> sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
        >>> mask_generator = SamAutomaticMaskGenerator(sam)
        >>> sam_result = mask_generator.generate(IMAGE)
        >>> detections = sv.Detections.from_sam(sam_result=sam_result)
        ```
    """
    # With no masks the arrays built below would be 1d and empty, which can
    # neither be converted to xyxy nor pass the shape validation.
    if not sam_result:
        return cls.empty()

    sorted_generated_masks = sorted(
        sam_result, key=lambda x: x["area"], reverse=True
    )

    xywh = np.array([mask["bbox"] for mask in sorted_generated_masks])
    mask = np.array([mask["segmentation"] for mask in sorted_generated_masks])

    return cls(xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask)

@classmethod
def from_coco_annotations(cls, coco_annotation: dict) -> Detections:
xyxy, class_id = [], []
Expand Down Expand Up @@ -264,6 +333,20 @@ def __getitem__(self, index: np.ndarray) -> Detections:

@property
def area(self) -> np.ndarray:
    """
    Calculate the area of each detection in the set of object detections. If the mask field is defined, the property
    returns the area of each mask. If only boxes are given, the property returns the area of each box.

    Returns:
        np.ndarray: An array containing the area of each detection in the format of `(area_1, area_2, ..., area_n)`, where n is the number of detections.
    """
    if self.mask is not None:
        # Sum over the two trailing axes of the stacked masks, giving one
        # value per detection (the pixel count for boolean masks).
        return np.sum(self.mask, axis=(1, 2))
    return self.box_area

@property
def box_area(self) -> np.ndarray:
"""
Calculate the area of each bounding box in the set of object detections.
Expand Down
Empty file.
Loading

0 comments on commit dba4d9f

Please sign in to comment.