Skip to content

Commit

Permalink
[Feature] Add the implementation of diff_iou_rotated with mlu-ops (op…
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielmic committed Jun 30, 2023
1 parent 6976a44 commit 7fc8adc
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ We implement common ops used in detection, segmentation, etc.
| Correlation | || | | |
| Deformable Convolution v1/v2 |||| ||
| Deformable RoIPool | ||| ||
| DiffIoURotated | || | | |
| DiffIoURotated | || | | |
| DynamicScatter | ||| | |
| FurthestPointSample | || | | |
| FurthestPointSampleWithDist | || | | |
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| Correlation | || | | |
| Deformable Convolution v1/v2 |||| ||
| Deformable RoIPool | ||| ||
| DiffIoURotated | || | | |
| DiffIoURotated | || | | |
| DynamicScatter | ||| | |
| FurthestPointSample | || | | |
| FurthestPointSampleWithDist | || | | |
Expand Down
55 changes: 55 additions & 0 deletions mmcv/ops/csrc/pytorch/mlu/diff_iou_rotated_mlu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*************************************************************************
* Copyright (C) 2023 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"

Tensor diff_iou_rotated_sort_vertices_forward_mlu(Tensor vertices, Tensor mask,
Tensor num_valid) {
// params check
TORCH_CHECK(vertices.scalar_type() == at::kFloat,
"vertices type should be Float, got ", vertices.scalar_type());
TORCH_CHECK(mask.scalar_type() == at::kBool, "mask should be Bool, got ",
mask.scalar_type());
TORCH_CHECK(num_valid.scalar_type() == at::kInt,
"num_valid type should be Int32, got ", num_valid.scalar_type());
TORCH_CHECK(vertices.size(2) == 24, "vertices.dim(2) should be 24, got ",
vertices.size(2));
TORCH_CHECK(mask.size(2) == 24, "mask.dim(2) should be 24, got ",
mask.size(2));

// zero-element check
if (vertices.numel() == 0) {
return at::empty({0}, num_valid.options().dtype(at::kInt));
}

auto idx = at::empty({vertices.size(0), vertices.size(1), 9},
num_valid.options().dtype(at::kInt));

INITIAL_MLU_PARAM_WITH_TENSOR(vertices);
INITIAL_MLU_PARAM_WITH_TENSOR(mask);
INITIAL_MLU_PARAM_WITH_TENSOR(num_valid);
INITIAL_MLU_PARAM_WITH_TENSOR(idx);

// get compute handle
auto handle = mluOpGetCurrentHandle();

// launch kernel
mluOpDiffIouRotatedSortVerticesForward(
handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr);
return idx;
}

Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
Tensor num_valid);

REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, MLU,
diff_iou_rotated_sort_vertices_forward_mlu);
40 changes: 30 additions & 10 deletions tests/test_ops/test_diff_iou_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,23 @@
import torch

from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

if IS_MLU_AVAILABLE:
torch.backends.mlu.matmul.allow_tf32 = False

@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_diff_iou_rotated_2d():

@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_diff_iou_rotated_2d(device):
np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0],
[0.5, 0.5, 1., 1., .0]]],
Expand All @@ -19,17 +31,25 @@ def test_diff_iou_rotated_2d():
[1.5, 1.5, 1., 1., .0]]],
dtype=np.float32)

boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)

np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]])
ious = diff_iou_rotated_2d(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_diff_iou_rotated_3d():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_diff_iou_rotated_3d(device):
np_boxes1 = np.asarray(
[[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0],
Expand All @@ -41,8 +61,8 @@ def test_diff_iou_rotated_3d():
[-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]],
dtype=np.float32)

boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)

np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]])
ious = diff_iou_rotated_3d(boxes1, boxes2)
Expand Down

0 comments on commit 7fc8adc

Please sign in to comment.