From 8708851eca17ad5d61307d7a08b702ad3e77bb4e Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Sat, 14 May 2022 19:46:27 +0800 Subject: [PATCH] [Fix] Fix bbox overlap fp16 (#1958) * add CUDA_ARCH check * add check in cuh --- mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh | 2 ++ mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu | 2 ++ 2 files changed, 4 insertions(+) diff --git a/mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh index 0f7be5e4de..15bd91eca6 100644 --- a/mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh @@ -88,6 +88,7 @@ __global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2, } } +#if __CUDA_ARCH__ >= 530 __device__ __forceinline__ __half __half_area(const __half x1, const __half y1, const __half x2, const __half y2, const __half offset) { @@ -141,5 +142,6 @@ __device__ void bbox_overlaps_cuda_kernel_half( ious[index] = __hdiv(interS, baseS); } } +#endif // __CUDA_ARCH__ >= 530 #endif // BBOX_OVERLAPS_CUDA_KERNEL_CUH diff --git a/mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu index 3aa7fd2588..b3272539bf 100644 --- a/mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu @@ -4,6 +4,7 @@ // Disable fp16 on ROCm device #ifndef HIP_DIFF +#if __CUDA_ARCH__ >= 530 template <> __global__ void bbox_overlaps_cuda_kernel( const at::Half* bbox1, const at::Half* bbox2, at::Half* ious, @@ -14,6 +15,7 @@ __global__ void bbox_overlaps_cuda_kernel( reinterpret_cast<__half*>(ious), num_bbox1, num_bbox2, mode, aligned, offset); } +#endif // __CUDA_ARCH__ >= 530 #endif // HIP_DIFF void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,