[Feature] Add the implementation of dynamic_scatter with mlu-ops (#2847)
Danielmic committed Jun 29, 2023
1 parent af0aadd commit 10c8b9e
Showing 7 changed files with 248 additions and 31 deletions.
2 changes: 1 addition & 1 deletion docs/en/understand_mmcv/ops.md
@@ -21,7 +21,7 @@ We implement common ops used in detection, segmentation, etc.
| Deformable Convolution v1/v2 | √ | √ |  |  | √ |
| Deformable RoIPool |  | √ | √ |  | √ |
| DiffIoURotated |  | √ |  |  |  |
| DynamicScatter |  | √ |  |  |  |
| DynamicScatter |  | √ | √ |  |  |
| FurthestPointSample |  | √ |  |  |  |
| FurthestPointSampleWithDist |  | √ |  |  |  |
| FusedBiasLeakyrelu |  | √ |  |  | √ |
2 changes: 1 addition & 1 deletion docs/zh_cn/understand_mmcv/ops.md
@@ -21,7 +21,7 @@ MMCV provides operators commonly used in tasks such as detection and segmentation
| Deformable Convolution v1/v2 | √ | √ |  |  | √ |
| Deformable RoIPool |  | √ | √ |  | √ |
| DiffIoURotated |  | √ |  |  |  |
| DynamicScatter |  | √ |  |  |  |
| DynamicScatter |  | √ | √ |  |  |
| FurthestPointSample |  | √ |  |  |  |
| FurthestPointSampleWithDist |  | √ |  |  |  |
| FusedBiasLeakyrelu |  | √ |  |  | √ |
13 changes: 13 additions & 0 deletions mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp
@@ -56,6 +56,19 @@ mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input) {
return layout;
}

mluOpReduceMode_t getMluOpReduceMode(const reduce_t reduce_type) {
const std::map<reduce_t, mluOpReduceMode_t> mapping_type = {
{reduce_t::MAX, MLUOP_REDUCE_DMAX},
{reduce_t::SUM, MLUOP_REDUCE_DSUM},
{reduce_t::MEAN, MLUOP_REDUCE_DMEAN}};
if (mapping_type.find(reduce_type) != mapping_type.end()) {
return mapping_type.find(reduce_type)->second;
} else {
TORCH_CHECK(false, "Unsupported reduce type: ", to_string(reduce_type));
return MLUOP_REDUCE_DSUM;
}
}

void MluOpTensorDescriptor::set(Tensor t) {
mluOpDataType_t data_type = getMluOpDataType(t.dtype());
mluOpTensorLayout_t layout = getMluOpSuggestLayout(t);
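As a quick reference (not part of the diff), a minimal sketch of how the new reduce-type mapping behaves, assuming the declarations added below in `mlu_common_helper.h`:

```cpp
#include "mlu_common_helper.h"

// Illustrative sketch only: exercising the mapping added above.
void reduce_mode_examples() {
  mluOpReduceMode_t max_mode = getMluOpReduceMode(reduce_t::MAX);    // MLUOP_REDUCE_DMAX
  mluOpReduceMode_t sum_mode = getMluOpReduceMode(reduce_t::SUM);    // MLUOP_REDUCE_DSUM
  mluOpReduceMode_t mean_mode = getMluOpReduceMode(reduce_t::MEAN);  // MLUOP_REDUCE_DMEAN
  // Any value outside the map trips the TORCH_CHECK with
  // "Unsupported reduce type: unknown reduce type".
  (void)max_mode; (void)sum_mode; (void)mean_mode;
}
```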
32 changes: 30 additions & 2 deletions mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
@@ -18,11 +18,39 @@
#include "pytorch_device_registry.hpp"

#define MLUOP_MAJOR 0
#define MLUOP_MINOR 6
#define MLUOP_PATCHLEVEL 0
#define MLUOP_MINOR 7
#define MLUOP_PATCHLEVEL 1

/*************************************************************************
* This macro converts a plain tensor into its MLU counterparts:
* NAME##_contigous, NAME##_desc, NAME##_impl and NAME##_ptr are
* generated automatically.
*************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME) \
auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous( \
NAME, NAME.suggest_memory_format()); \
MluOpTensorDescriptor NAME##_desc; \
NAME##_desc.set(NAME##_contigous); \
auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
auto NAME##_ptr = NAME##_impl->cnnlMalloc();

enum class reduce_t { SUM = 0, MEAN = 1, MAX = 2 };

inline std::string to_string(reduce_t reduce_type) {
if (reduce_type == reduce_t::MAX) {
return "max";
} else if (reduce_type == reduce_t::MEAN) {
return "mean";
} else if (reduce_type == reduce_t::SUM) {
return "sum";
} else {
return "unknown reduce type";
}
}

mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
mluOpReduceMode_t getMluOpReduceMode(const reduce_t reduce_type);

class MluOpTensorDescriptor {
public:
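To make the generated names concrete, here is roughly what `INITIAL_MLU_PARAM_WITH_TENSOR(feats)` expands to for an existing `at::Tensor` named `feats` (a sketch derived from the macro body above; note the suffix is spelled `_contigous` in the code):

```cpp
// Preprocessor expansion sketch for INITIAL_MLU_PARAM_WITH_TENSOR(feats).
auto feats_contigous = torch_mlu::cnnl::ops::cnnl_contiguous(
    feats, feats.suggest_memory_format());  // contiguous MLU copy
MluOpTensorDescriptor feats_desc;
feats_desc.set(feats_contigous);            // mlu-ops tensor descriptor
auto feats_impl = torch_mlu::getMluTensorImpl(feats_contigous);
auto feats_ptr = feats_impl->cnnlMalloc();  // device pointer for the kernel
```

The new scatter_points file below relies on exactly these generated names (`feats_desc.desc()`, `feats_ptr`, and so on).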
13 changes: 0 additions & 13 deletions mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
@@ -13,19 +13,6 @@
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"

/*************************************************************************
* This MACRO contains operations of simple tensor to mlu-tensor.
* _contiguous, _desc, _impl, _ptr will be automatically generated in
* this MACRO.
*************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME) \
auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous( \
NAME, NAME.suggest_memory_format()); \
MluOpTensorDescriptor NAME##_desc; \
NAME##_desc.set(NAME##_contigous); \
auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
auto NAME##_ptr = NAME##_impl->cnnlMalloc();

Tensor MsDeformAttnForwardLauncher(const Tensor& value,
const Tensor& spatial_shapes,
const Tensor& level_start_index,
178 changes: 178 additions & 0 deletions mmcv/ops/csrc/pytorch/mlu/scatter_points_mlu.cpp
@@ -0,0 +1,178 @@
/*************************************************************************
* Copyright (C) 2023 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"

std::vector<Tensor> dynamic_point_to_voxel_forward_mlu(const Tensor &feats,
const Tensor &coors,
const reduce_t reduce_type) {
// params check
TORCH_CHECK(feats.scalar_type() == at::kFloat,
"feats type should be Float, got ", feats.scalar_type());
TORCH_CHECK(coors.scalar_type() == at::kInt,
"coors type should be Int32, got ", coors.scalar_type());
TORCH_CHECK(feats.size(0) == coors.size(0),
"feats.dim(0) and coors.dim(0) should be same, got ", feats.size(0), " vs ", coors.size(0));

const int num_input = feats.size(0);
const int num_feats = feats.size(1);
// zero-element check
if (num_input == 0)
return {feats.clone().detach(), coors.clone().detach(),
coors.new_empty({0}, torch::kInt32),
coors.new_empty({0}, torch::kInt32)};

auto mlu_reduce_type = getMluOpReduceMode(reduce_type);
auto reduced_feats = at::empty({num_input, num_feats}, feats.options());
auto out_coors = at::empty({num_input, 3}, coors.options());
auto coors_map = at::empty({num_input}, coors.options());
auto reduce_count = at::empty({num_input}, coors.options());
auto voxel_num = at::empty({1}, coors.options());

INITIAL_MLU_PARAM_WITH_TENSOR(feats);
INITIAL_MLU_PARAM_WITH_TENSOR(coors);
INITIAL_MLU_PARAM_WITH_TENSOR(reduced_feats);
INITIAL_MLU_PARAM_WITH_TENSOR(out_coors);
INITIAL_MLU_PARAM_WITH_TENSOR(coors_map);
INITIAL_MLU_PARAM_WITH_TENSOR(reduce_count);
INITIAL_MLU_PARAM_WITH_TENSOR(voxel_num);

// get compute handle
auto handle = mluOpGetCurrentHandle();

size_t workspace_size;
mluOpGetDynamicPointToVoxelForwardWorkspaceSize(handle,
feats_desc.desc(),
coors_desc.desc(),
&workspace_size);
auto workspace_tensor =
at::empty(workspace_size, feats.options().dtype(at::kByte));
INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

// launch kernel
mluOpDynamicPointToVoxelForward(handle,
mlu_reduce_type,
feats_desc.desc(),
feats_ptr,
coors_desc.desc(),
coors_ptr,
workspace_tensor_ptr,
workspace_size,
reduced_feats_desc.desc(),
reduced_feats_ptr,
out_coors_desc.desc(),
out_coors_ptr,
coors_map_desc.desc(),
coors_map_ptr,
reduce_count_desc.desc(),
reduce_count_ptr,
voxel_num_desc.desc(),
voxel_num_ptr);

int voxel_num_value = *static_cast<int *>(voxel_num.cpu().data_ptr());
TORCH_CHECK(voxel_num_value <= feats.size(0),
"voxel_num should be less than or equal to feats_num, got ", voxel_num_value, " vs ", feats.size(0));
return {reduced_feats.slice(0, 0, voxel_num_value), out_coors.slice(0, 0, voxel_num_value),
coors_map, reduce_count.slice(0, 0, voxel_num_value)};
}

void dynamic_point_to_voxel_backward_mlu(Tensor &grad_feats,
const Tensor &grad_reduced_feats,
const Tensor &feats,
const Tensor &reduced_feats,
const Tensor &coors_idx,
const Tensor &reduce_count,
const reduce_t reduce_type) {
// params check
TORCH_CHECK(grad_reduced_feats.scalar_type() == at::kFloat,
"grad_reduced_feats type should be Float, got ", grad_reduced_feats.scalar_type());
TORCH_CHECK(feats.scalar_type() == at::kFloat,
"feats type should be Float, got ", feats.scalar_type());
TORCH_CHECK(reduced_feats.scalar_type() == at::kFloat,
"reduced_feats type should be Float, got ", reduced_feats.scalar_type());
TORCH_CHECK(coors_idx.scalar_type() == at::kInt,
"coors_idx type should be Int32, got ", coors_idx.scalar_type());
TORCH_CHECK(reduce_count.scalar_type() == at::kInt,
"reduce_count type should be Int32, got ", reduce_count.scalar_type());

const int num_input = feats.size(0);
const int num_reduced = reduced_feats.size(0);
const int num_feats = feats.size(1);

grad_feats.fill_(0);

// zero-element check
if (num_input == 0 || num_reduced == 0) return;

// TODO(miaochen): remove this after mlu-ops supports other reduce modes.
TORCH_CHECK(reduce_type == reduce_t::MAX,
"only supports max reduce in current version, got ", to_string(reduce_type));

int voxel_num_value = reduced_feats.size(0);
auto opts = torch::TensorOptions().dtype(torch::kInt32);
auto voxel_num = torch::from_blob(&voxel_num_value, {1}, opts).clone().to(at::kMLU);
auto mlu_reduce_type = getMluOpReduceMode(reduce_type);

INITIAL_MLU_PARAM_WITH_TENSOR(grad_feats);
INITIAL_MLU_PARAM_WITH_TENSOR(grad_reduced_feats);
INITIAL_MLU_PARAM_WITH_TENSOR(feats);
INITIAL_MLU_PARAM_WITH_TENSOR(reduced_feats);
INITIAL_MLU_PARAM_WITH_TENSOR(coors_idx);
INITIAL_MLU_PARAM_WITH_TENSOR(reduce_count);
INITIAL_MLU_PARAM_WITH_TENSOR(voxel_num);

// get compute handle
auto handle = mluOpGetCurrentHandle();

size_t workspace_size;
mluOpGetDynamicPointToVoxelBackwardWorkspaceSize(
handle, mlu_reduce_type,
grad_feats_desc.desc(),
feats_desc.desc(),
grad_reduced_feats_desc.desc(),
coors_idx_desc.desc(),
reduce_count_desc.desc(),
voxel_num_desc.desc(),
&workspace_size);
auto workspace_tensor =
at::empty(workspace_size, feats.options().dtype(at::kByte));
INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);

// launch kernel
mluOpDynamicPointToVoxelBackward(
handle, mlu_reduce_type,
grad_reduced_feats_desc.desc(),
grad_reduced_feats_ptr,
feats_desc.desc(), feats_ptr,
reduced_feats_desc.desc(), reduced_feats_ptr,
coors_idx_desc.desc(), coors_idx_ptr,
reduce_count_desc.desc(), reduce_count_ptr,
voxel_num_desc.desc(), voxel_num_ptr,
workspace_tensor_ptr, workspace_size,
grad_feats_desc.desc(), grad_feats_ptr);
}

std::vector<Tensor> dynamic_point_to_voxel_forward_impl(const Tensor &feats,
const Tensor &coors,
const reduce_t reduce_type);

void dynamic_point_to_voxel_backward_impl(Tensor &grad_feats,
const Tensor &grad_reduced_feats,
const Tensor &feats,
const Tensor &reduced_feats,
const Tensor &coors_idx,
const Tensor &reduce_count,
const reduce_t reduce_type);

REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MLU,
dynamic_point_to_voxel_forward_mlu);
REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_backward_impl, MLU,
dynamic_point_to_voxel_backward_mlu);
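A hypothetical call sketch (tensor names are illustrative and the output semantics are inferred from the checks and allocations above; not part of the diff):

```cpp
// feats: [N, C] float32 MLU tensor; coors: [N, 3] int32 MLU tensor.
std::vector<Tensor> outs =
    dynamic_point_to_voxel_forward_mlu(feats, coors, reduce_t::MAX);
Tensor reduced_feats = outs[0];  // [voxel_num, C] per-voxel reduced features
Tensor out_coors = outs[1];      // [voxel_num, 3] coordinates of non-empty voxels
Tensor coors_map = outs[2];      // [N] voxel index for each input point
Tensor reduce_count = outs[3];   // [voxel_num] number of points in each voxel
```

In normal use this is reached through `dynamic_point_to_voxel_forward_impl`, which the `REGISTER_DEVICE_IMPL` calls above bind to the MLU backend.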
39 changes: 25 additions & 14 deletions tests/test_ops/test_scatter_points.py
@@ -4,22 +4,31 @@
from torch.autograd import gradcheck

from mmcv.ops import DynamicScatter
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

if torch.__version__ == 'parrots':
pytest.skip('not supported in parrots now', allow_module_level=True)


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_dynamic_scatter():
@pytest.mark.parametrize('device', [
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_dynamic_scatter(device):
dsmean = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], True)
dsmax = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], False)

# test empty input
empty_feats = torch.empty(size=(0, 3), dtype=torch.float32, device='cuda')
empty_coors = torch.empty(size=(0, 3), dtype=torch.int32, device='cuda')
empty_feats = torch.empty(size=(0, 3), dtype=torch.float32, device=device)
empty_coors = torch.empty(size=(0, 3), dtype=torch.int32, device=device)

empty_feats.requires_grad_()
empty_feats_out_mean, empty_coors_out_mean = dsmean(
@@ -35,9 +44,9 @@ def test_dynamic_scatter():

# test empty reduced output
empty_o_feats = torch.rand(
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
size=(200000, 3), dtype=torch.float32, device=device) * 100 - 50
empty_o_coors = torch.randint(
low=-1, high=0, size=(200000, 3), dtype=torch.int32, device='cuda')
low=-1, high=0, size=(200000, 3), dtype=torch.int32, device=device)

empty_o_feats.requires_grad_()
empty_o_feats_out_mean, empty_o_coors_out_mean = dsmean(
@@ -52,9 +61,9 @@ def test_dynamic_scatter():

# test non-empty input
feats = torch.rand(
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
size=(200000, 3), dtype=torch.float32, device=device) * 100 - 50
coors = torch.randint(
low=-1, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
low=-1, high=20, size=(200000, 3), dtype=torch.int32, device=device)

ref_voxel_coors = coors.unique(dim=0, sorted=True)
ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
@@ -88,9 +97,9 @@ def test_dynamic_scatter():

# test non-empty input without any point out of bound
feats = torch.rand(
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
size=(200000, 3), dtype=torch.float32, device=device) * 100 - 50
coors = torch.randint(
low=0, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
low=0, high=20, size=(200000, 3), dtype=torch.int32, device=device)

ref_voxel_coors = coors.unique(dim=0, sorted=True)
ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
@@ -124,9 +133,11 @@ def test_dynamic_scatter():

# test grad #
feats = torch.rand(
size=(100, 4), dtype=torch.float32, device='cuda') * 100 - 50
size=(100, 4), dtype=torch.float32, device=device) * 100 - 50
coors = torch.randint(
low=-1, high=3, size=(100, 3), dtype=torch.int32, device='cuda')
low=-1, high=3, size=(100, 3), dtype=torch.int32, device=device)
feats.requires_grad_()
gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
# TODO(Cambricon): MLU only supports max reduce in the current version.
if not IS_MLU_AVAILABLE:
gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
gradcheck(dsmax, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
