
ValueError when loading COCO dataset with multiple segmentation masks for one class #1209

Open

DancinParrot opened this issue May 20, 2024 · 13 comments

Labels: bug (Something isn't working)


DancinParrot commented May 20, 2024

Search before asking

  • I have searched the Supervision issues and found no similar bug report.

Bug

My current COCO dataset includes annotations with more than one segmentation mask for the same class. A rough analogy: one eye of a cat is annotated as a whole, but when exported from Fiftyone, two polygons are produced (each turned into a segmentation mask):

[image: cat]

As a result, when the COCO dataset is loaded into my program using supervision, the program crashes with the following error:

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (7,) + inhomogeneous part.
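
For context, NumPy raises this error whenever it is asked to build an array out of sequences of unequal length, which is exactly what a multi-polygon segmentation produces. A minimal reproduction, independent of supervision (with made-up coordinates):

import numpy as np

# Two polygons of different lengths under one "segmentation" key
ragged = [[694, 560.5, 764.5, 407, 759.5, 400], [713, 534.5]]
np.asarray(ragged, dtype=np.int32)
# ValueError: setting an array element with a sequence. The requested array
# has an inhomogeneous shape after 1 dimensions. ...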

After some research, I discovered Ultralytics' JSON2YOLO repository on GitHub and adapted the library's merge_multi_segment() function (seen here) in supervision's coco.py file, which then allows the COCO dataset to be loaded.

Environment

  • Supervision: 0.20.0
  • OS: OpenSUSE Tumbleweed 20240423
  • Python: 3.12.2

Minimal Reproducible Example

The following code is used to load the COCO dataset, where annotations_path is the path to a .json file containing the paths and annotations for all images in the dataset:

ds = sv.DetectionDataset.from_coco(
    images_directory_path=images_directory_path,
    annotations_path=annotations_path,
    force_masks=True,
)

The following is an example of an annotation containing multiple segmentation polygons:

{
  ...
  "annotations": [
    ...
    {
      "id": 41,
      "image_id": 6,
      "category_id": 0,
      "bbox": [694.801517364719, 278.90263698033465, 161.52883212628387, 282.881946369456],
      "segmentation": [
        [694, 560.5, 764.5, 407, 759.5, 400, 765.5, 397, 765.5, 393, 760.5, 391,
         759.5, 384, 754.5, 381, 763, 376.5, 767.5, 370, 764.5, 363, 768.5, 354,
         735, 278.5, 741.5, 284, 776, 356.5, 782, 359.5, 794.5, 348, 806.5, 321,
         809.5, 321, 799.5, 346, 800, 348.5, 806.5, 349, 798.5, 364, 800.5, 387,
         808.5, 400, 818.5, 408, 811, 413.5, 802, 405.5, 802.5, 416, 855.5, 530,
         851.5, 529, 787.5, 395, 779, 394.5, 714.5, 531, 715.5, 522, 729, 490.5,
         694, 560.5],
        [713, 534.5, 713, 531.5, 713, 534.5]
      ],
      "area": 45693.59042666829,
      "iscrowd": 0,
      "ignore": 0
    }
  ]
}

Additional

No response

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!
DancinParrot added the bug label on May 20, 2024
SkalskiP (Collaborator)

Hi @DancinParrot 👋🏻 Sorry for the late response, but I traveled a lot at the end of last week, and my access to GitHub was limited.

You're correct, we currently do not support loading multi-segment masks. I assume the change you want to make would be in the coco_annotations_to_masks function?


DancinParrot commented May 20, 2024

Hi @SkalskiP! Thanks for your response; I wasn't expecting one this quick, so no worries!

I see, that would explain the error. However, the error was actually raised from this line, which is within the coco_annotations_to_detections function. I could not find any reference to a coco_annotations_to_masks function within this repo; did I miss anything?

I modified the coco_annotations_to_detections function based on JSON2YOLO's implementation:

def coco_annotations_to_detections(
    image_annotations: List[dict], resolution_wh: Tuple[int, int], with_masks: bool
) -> Detections:
    # ...

    if with_masks:
        polygons = []

        for image_annotation in image_annotations:
            segmentations = image_annotation["segmentation"]
            if len(segmentations) > 1:
                # Stitch all segments into one polygon. Note that JSON2YOLO
                # also normalizes the coordinates to [0, 1] here (a YOLO
                # convention), which the int32 cast below then truncates.
                s = merge_multi_segment(segmentations)
                s = (
                    (np.concatenate(s, axis=0) / np.array(resolution_wh))
                    .reshape(-1)
                    .tolist()
                )
                reshaped = np.reshape(np.asarray(s, dtype=np.int32), (-1, 2))
            else:
                reshaped = np.reshape(
                    np.asarray(segmentations, dtype=np.int32), (-1, 2)
                )
            polygons.append(reshaped)

        # ...

    return Detections(xyxy=xyxy, class_id=np.asarray(class_ids, dtype=int))

The aforementioned modification allows the dataset to be loaded. Though, I'm not sure the outcome really fits my use case, since the merge_multi_segment() function included in JSON2YOLO seems to connect all segmentation masks into one with a thin line, which I presume forms one whole mask as opposed to the intended separate masks. Any thoughts on this?

SkalskiP (Collaborator)

However, the error was actually raised from this line [...]

I'm very sorry. You're right, of course. I mentioned coco_annotations_to_masks because there is ongoing work on COCO loading and saving in #1163, and this method will appear in that PR. Don't worry about it.

As for merge_multi_segment, we can't use that implementation since JSON2YOLO is under an AGPL license, which would conflict with our MIT license. Therefore, we need to implement our own version of that function.

I'm not sure if the outcome really fits my use case since the merge_multi_segment() function included in JSON2YOLO seems to connect all segmentation masks into one with a thin line, which I presume would form one whole mask as opposed to the intended separate masks.

If you want to load this as two separate masks, your COCO JSON is incorrectly constructed. You should not have multiple lists under the segmentation key. These should be separate annotations. If you have multiple lists there, they should be loaded as a single mask.
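
Schematically (coordinates elided), the difference looks like this. One annotation with multiple lists under segmentation is loaded as a single mask:

"segmentation": [[x1, y1, x2, y2, ...], [x1, y1, x2, y2, ...]]

whereas separate objects should be separate annotations, each with a single polygon list:

{"id": 41, "image_id": 6, "category_id": 0, "segmentation": [[x1, y1, x2, y2, ...]], ...}
{"id": 42, "image_id": 6, "category_id": 0, "segmentation": [[x1, y1, x2, y2, ...]], ...}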


DancinParrot commented May 21, 2024

I'm very sorry. You're right, of course. I mentioned coco_annotations_to_masks because there is ongoing work on COCO loading and saving in #1163, and this method will appear in that PR. Don't worry about it.

I see, no worries!

As for merge_multi_segment, we can't use that implementation since JSON2YOLO is under an AGPL license, which would conflict with our MIT license. Therefore, we need to implement our own version of that function.

Understood, I'm trying out different implementations currently to fix the issue.

If you want to load this as two separate masks, your COCO JSON is incorrectly constructed. You should not have multiple lists under the segmentation key. These should be separate annotations. If you have multiple lists there, they should be loaded as a single mask.

My apologies for the confusion; I might have misunderstood my dataset. In my current workflow, I export the annotated data from Label Studio as a COCO dataset. Next, I import the dataset into Fiftyone for augmentation with Albumentations, and then export it as a COCODetectionDataset to preserve the segmentation masks along with the bboxes. It seems that during this step, Fiftyone's native COCODetectionDataset exporter modified the structure of my dataset such that a mask is split into multiple parts (perhaps due to overlapping masks), resulting in supervision's inability to parse the dataset. However, from my observation, the dataset is still structured properly, since I was able to import it into Fiftyone and Label Studio again and the annotations remained unchanged. Thus, I doubt it's an issue with Fiftyone's exporter, but rather supervision's inability to merge the list of segmentation masks.

Lastly, I tested JSON2YOLO's merge_multi_segment() function, which appeared to modify the masks to the point where the contour is indecipherable by OpenCV, as findContours() was unable to produce an output. Hence, there is a need for another implementation, which I believe may be adapted from Fiftyone's source code, particularly the _coco_segmentation_to_mask() function (here). Would love to hear your thoughts on this approach. Thanks!

DancinParrot (Author)

Just an update: after messing around with a few implementations, I finally came up with functional code by combining _coco_segmentation_to_mask() from Fiftyone with mask2polygon() from ultralytics/JSON2YOLO#38. The flow is as follows:

  1. If more than one segmentation polygon exists for an annotation, normalize, merge, and encode the list of segmentation polygons obtained from COCO's labels.json file into RLE.
  2. Decode the RLE and pass the result as an argument to the mask2polygon() function.
  3. The mask2polygon() function finds the contours within a mask and obtains the hierarchy of contours in the form of [[Next, Previous, First_Child, Parent], ...] (see the cv2.findContours() docs).
  4. Then, it loops over the list of contours and merges those with a valid parent (a non -1 value in the hierarchy list) into that parent.
  5. However, contours without parents will not be merged. In this case, my method involves either recursively merging all contours until only one is left in the list, or choosing the biggest contour as the resulting polygon based on area (I chose the latter).

Although the resulting product is not perfect, it works well enough (so far) for my use case. I'm most definitely open to feedback and suggestions on ways to improve, as well as possible alternatives to this approach.

Should I open a PR for this? @SkalskiP

Code is roughly as follows:

from typing import List, Tuple

import cv2
import numpy as np
from pycocotools import mask as mask_utils


# From https://github.com/voxel51/fiftyone/blob/8205caf7646e5e7cb38041a94efb97f6524c1db6/fiftyone/utils/coco.py
def normalize_coco_segmentation(segmentation):
    # Filter out empty segmentations.
    # A polygon with only 4 coordinates is degenerate (e.g. a single pixel),
    # so repeat its coordinates to turn it into a polygon pycocotools accepts.
    _segmentation = []
    for seg in segmentation:
        if len(seg) == 0:
            continue

        if len(seg) == 4:
            seg *= 4

        _segmentation.append(seg)

    return _segmentation

# From https://github.com/ultralytics/JSON2YOLO/issues/38
def is_clockwise(contour):
    # Orientation test via the shoelace formula: the sign of the
    # accumulated value determines the winding direction.
    value = 0
    num = len(contour)
    for i, point in enumerate(contour):
        p1 = contour[i]
        if i < num - 1:
            p2 = contour[i + 1]
        else:
            p2 = contour[0]
        value += (p2[0][0] - p1[0][0]) * (p2[0][1] + p1[0][1])
    return value < 0


def get_merge_point_idx(contour1, contour2):
    # Find the pair of indices whose points are closest to each other
    # (by squared Euclidean distance); the two contours will be stitched
    # together at these points.
    idx1 = 0
    idx2 = 0
    distance_min = -1
    for i, p1 in enumerate(contour1):
        for j, p2 in enumerate(contour2):
            distance = pow(p2[0][0] - p1[0][0], 2) + pow(p2[0][1] - p1[0][1], 2)
            if distance_min < 0:
                distance_min = distance
                idx1 = i
                idx2 = j
            elif distance < distance_min:
                distance_min = distance
                idx1 = i
                idx2 = j
    return idx1, idx2


def merge_contours(contour1, contour2, idx1, idx2):
    # Stitch contour2 into contour1 at the merge points: walk contour1 up
    # to idx1, traverse all of contour2 starting at idx2, then finish the
    # remainder of contour1. The crossing forms the thin "bridge" that
    # connects the two parts.
    contour = []
    for i in list(range(0, idx1 + 1)):
        contour.append(contour1[i])
    for i in list(range(idx2, len(contour2))):
        contour.append(contour2[i])
    for i in list(range(0, idx2 + 1)):
        contour.append(contour2[i])
    for i in list(range(idx1, len(contour1))):
        contour.append(contour1[i])
    contour = np.array(contour)
    return contour


def merge_with_parent(contour_parent, contour):
    # Make the parent and child wind in opposite directions before
    # stitching them together at their closest points.
    if not is_clockwise(contour_parent):
        contour_parent = contour_parent[::-1]
    if is_clockwise(contour):
        contour = contour[::-1]
    idx1, idx2 = get_merge_point_idx(contour_parent, contour)
    return merge_contours(contour_parent, contour, idx1, idx2)


def mask2polygon(image):
    # RETR_TREE retrieves the full contour hierarchy; each hierarchy entry
    # is [Next, Previous, First_Child, Parent].
    contours, hierarchies = cv2.findContours(
        image, cv2.RETR_TREE, cv2.CHAIN_APPROX_TC89_KCOS
    )

    # Simplify each contour with a tolerance proportional to its perimeter
    contours_approx = []
    for contour in contours:
        epsilon = 0.001 * cv2.arcLength(contour, True)
        contour_approx = cv2.approxPolyDP(contour, epsilon, True)
        contours_approx.append(contour_approx)

    # Keep top-level contours (no parent) with at least 3 points
    contours_parent = []
    for i, contour in enumerate(contours_approx):
        parent_idx = hierarchies[0][i][3]
        if parent_idx < 0 and len(contour) >= 3:
            contours_parent.append(contour)
        else:
            contours_parent.append([])

    # Stitch each child contour (e.g. a hole) into its parent
    for i, contour in enumerate(contours_approx):
        parent_idx = hierarchies[0][i][3]
        if parent_idx >= 0 and len(contour) >= 3:
            contour_parent = contours_parent[parent_idx]
            if len(contour_parent) == 0:
                continue
            contours_parent[parent_idx] = merge_with_parent(contour_parent, contour)

    contours_parent_tmp = [
        contour for contour in contours_parent if len(contour) > 0
    ]

    # Keep only the largest remaining contour, selected by area
    max_area = 0
    max_contour = None
    for contour in contours_parent_tmp:
        area = cv2.contourArea(contour)
        if area > max_area:
            max_area = area
            max_contour = contour

    if max_contour is not None:
        polygon = max_contour.flatten().tolist()
        return polygon


def coco_segmentation_to_mask(segmentation, bbox, frame_size):
    # bbox is kept for parity with Fiftyone's signature but is unused here
    x, y, w, h = bbox
    width, height = frame_size

    if isinstance(segmentation, list):
        # Polygon -- a single object might consist of multiple parts, so
        # merge all parts into one mask RLE code
        segmentation = normalize_coco_segmentation(segmentation)
        if len(segmentation) == 0:
            return None

        rle = mask_utils.merge(mask_utils.frPyObjects(segmentation, height, width))
    elif isinstance(segmentation["counts"], list):
        # Uncompressed RLE
        rle = mask_utils.frPyObjects(segmentation, height, width)
    else:
        # Compressed RLE
        rle = segmentation

    mask = mask_utils.decode(rle)
    polygon = mask2polygon(mask)

    return polygon


def coco_annotations_to_detections(
    image_annotations: List[dict], resolution_wh: Tuple[int, int], with_masks: bool
) -> Detections:
    if not image_annotations:
        return Detections.empty()

    class_ids = [
        image_annotation["category_id"] for image_annotation in image_annotations
    ]
    xyxy = [image_annotation["bbox"] for image_annotation in image_annotations]
    xyxy = np.asarray(xyxy)
    xyxy[:, 2:4] += xyxy[:, 0:2]  # convert [x, y, w, h] to [x1, y1, x2, y2]

    if with_masks:
        polygons = []

        for image_annotation in image_annotations:
            segmentation = image_annotation["segmentation"]
            if len(segmentation) > 1:
                # Multiple polygons: merge them via RLE into a single polygon
                s = coco_segmentation_to_mask(
                    segmentation, image_annotation["bbox"], resolution_wh
                )
                reshaped = np.reshape(np.asarray(s, dtype=np.int32), (-1, 2))
            else:
                reshaped = np.reshape(
                    np.asarray(segmentation, dtype=np.int32),
                    (-1, 2),
                )

            polygons.append(reshaped)

        mask = _polygons_to_masks(polygons=polygons, resolution_wh=resolution_wh)
        return Detections(
            class_id=np.asarray(class_ids, dtype=int), xyxy=xyxy, mask=mask
        )

    return Detections(xyxy=xyxy, class_id=np.asarray(class_ids, dtype=int))

SkalskiP (Collaborator)

Hi @DancinParrot 👋🏻

I must admit, I am very confused. Initially, I thought it was only about loading COCO annotations consisting of multiple segments. Is that still the case?

I don't quite understand why we need all these extra steps like RLE conversion and polygon conversion.


DancinParrot commented May 21, 2024

I must admit, I am very confused. Initially, I thought it was only about loading COCO annotations consisting of multiple segments. Is that still the case?

Yup, still that. It was my mistake; I misunderstood the issue. It turns out Fiftyone split the mask for one annotation into multiple parts during export. The dataset is still structured properly, though, as it can be read by Fiftyone and Label Studio; it's only supervision that is unable to load it.

I don't quite understand why we need all these extra steps like RLE conversion and polygon conversion.

The conversion to RLE, I suppose, merges all masks within an annotation into one array, which is later used as input for the mask2polygon() function. That function is the one responsible for merging the contours together to form one polygon, which then allows supervision to properly load the dataset.
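
Roughly, the pycocotools calls involved look like this (a minimal sketch; polygons is the list of coordinate lists from one annotation, and height/width come from the image resolution):

from pycocotools import mask as mask_utils

rles = mask_utils.frPyObjects(polygons, height, width)  # one RLE per polygon
rle = mask_utils.merge(rles)                            # union of all parts
mask = mask_utils.decode(rle)                           # (height, width) binary array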

EDIT: I have updated the original issue to include my understanding of the issue.

SkalskiP (Collaborator)

I think we should just be able to update the section of the code here:

def coco_annotations_to_masks(
    image_annotations: List[dict], resolution_wh: Tuple[int, int]
) -> npt.NDArray[np.bool_]:
    return np.array(
        [
            rle_to_mask(
                rle=np.array(image_annotation["segmentation"]["counts"]),
                resolution_wh=resolution_wh,
            )
            if image_annotation["iscrowd"]
            else polygon_to_mask(
                polygon=np.reshape(
                    np.asarray(image_annotation["segmentation"], dtype=np.int32),
                    (-1, 2),
                ),
                resolution_wh=resolution_wh,
            )
            for image_annotation in image_annotations
        ],
        dtype=bool,
    )

It fails because it does not expect multiple lists under image_annotation["segmentation"].

The easiest way (not the most efficient, but still a lot more efficient than the conversion through all of the representations above) is to loop through the lists in image_annotation["segmentation"], create a separate mask for each of them, and then run np.logical_or(mask_1, mask_2).
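
Something along these lines (a rough, untested sketch reusing the existing polygon_to_mask helper):

masks = [
    polygon_to_mask(
        polygon=np.reshape(np.asarray(polygon, dtype=np.int32), (-1, 2)),
        resolution_wh=resolution_wh,
    )
    for polygon in image_annotation["segmentation"]
]
merged_mask = np.logical_or.reduce(masks)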

DancinParrot (Author)

The easiest way (not the most efficient, but still a lot more efficient than the conversion through all of the representations above) is to loop through the lists in image_annotation["segmentation"], create a separate mask for each of them, and then run np.logical_or(mask_1, mask_2).

This seems a lot more efficient. Thanks! I'll try it out tomorrow when I get access to my work laptop and update this thread with the results.

SkalskiP (Collaborator)

@DancinParrot Sure! Let me know how it goes.

DancinParrot (Author)

Hi @SkalskiP! Thank you so much for your help! Here's the code I've implemented based on your recommendation, and it seems to merge all the polygons very well:

def merge_masks(segmentations, resolution_wh):
    # Rasterize each polygon and union the results into a single mask
    parent = None
    for s in segmentations:
        mask = polygon_to_mask(
            polygon=np.reshape(
                np.asarray(s, dtype=np.int32),
                (-1, 2),
            ),
            resolution_wh=resolution_wh,
        )
        if parent is None:
            parent = mask
        else:
            parent = np.logical_or(parent, mask)

    return parent


def coco_annotations_to_masks(
    image_annotations: List[dict], resolution_wh: Tuple[int, int]
) -> npt.NDArray[np.bool_]:
    return np.array(
        [
            (
                rle_to_mask(
                    rle=np.array(image_annotation["segmentation"]["counts"]),
                    resolution_wh=resolution_wh,
                )
                if image_annotation["iscrowd"]
                else (
                    merge_masks(image_annotation["segmentation"], resolution_wh)
                    if len(image_annotation["segmentation"]) > 1
                    else polygon_to_mask(
                        polygon=np.reshape(
                            np.asarray(
                                image_annotation["segmentation"], dtype=np.int32
                            ),
                            (-1, 2),
                        ),
                        resolution_wh=resolution_wh,
                    )
                )
            )
            for image_annotation in image_annotations
        ],
        dtype=bool,
    )

Any feedback for further improvement is much appreciated. Also, should I create a PR for this?

SkalskiP (Collaborator)

Hi @DancinParrot 👋🏻, that seems like a good starting point. Please open a PR proposing the change. 🙏🏻

DancinParrot (Author)

Hi @SkalskiP! Will do, thanks!
