Skip to content

Commit

Permalink
[Fix] Fix bbox overlap fp16 (#1958)
Browse files Browse the repository at this point in the history
* add __CUDA_ARCH__ >= 530 check

* add the same __CUDA_ARCH__ check in the .cuh kernel header
  • Loading branch information
grimoire committed May 14, 2022
1 parent a3b4640 commit 8708851
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ __global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
}
}

#if __CUDA_ARCH__ >= 530
__device__ __forceinline__ __half __half_area(const __half x1, const __half y1,
const __half x2, const __half y2,
const __half offset) {
Expand Down Expand Up @@ -141,5 +142,6 @@ __device__ void bbox_overlaps_cuda_kernel_half(
ious[index] = __hdiv(interS, baseS);
}
}
#endif // __CUDA_ARCH__ >= 530

#endif // BBOX_OVERLAPS_CUDA_KERNEL_CUH
2 changes: 2 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// Disable fp16 on ROCm device
#ifndef HIP_DIFF
#if __CUDA_ARCH__ >= 530
template <>
__global__ void bbox_overlaps_cuda_kernel<at::Half>(
const at::Half* bbox1, const at::Half* bbox2, at::Half* ious,
Expand All @@ -14,6 +15,7 @@ __global__ void bbox_overlaps_cuda_kernel<at::Half>(
reinterpret_cast<__half*>(ious), num_bbox1,
num_bbox2, mode, aligned, offset);
}
#endif // __CUDA_ARCH__ >= 530
#endif // HIP_DIFF

void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Expand Down

0 comments on commit 8708851

Please sign in to comment.