From ebea6bdf9786a3cbd2a7767d14b266553b3b26a7 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Mon, 19 Aug 2024 08:57:36 -0700 Subject: [PATCH] uses dynamic dispatch to unsigned type --- thrust/thrust/system/cuda/detail/dispatch.h | 14 ++++++++++++++ thrust/thrust/system/cuda/detail/scan.h | 8 ++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/thrust/thrust/system/cuda/detail/dispatch.h b/thrust/thrust/system/cuda/detail/dispatch.h index 90c99688f7..f1f3090f8e 100644 --- a/thrust/thrust/system/cuda/detail/dispatch.h +++ b/thrust/thrust/system/cuda/detail/dispatch.h @@ -90,6 +90,20 @@ status = call_64 arguments; \ } +/// Like \ref THRUST_INDEX_TYPE_DISPATCH2 but dispatching to uint32_t and uint64_t, respectively, depending on the +/// `count` argument. +#define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \ + if (count <= thrust::detail::integer_traits::const_max) \ + { \ + auto THRUST_PP_CAT2(count, _fixed) = static_cast(count); \ + status = call_32 arguments; \ + } \ + else \ + { \ + auto THRUST_PP_CAT2(count, _fixed) = static_cast(count); \ + status = call_64 arguments; \ + } + /// Like \ref THRUST_INDEX_TYPE_DISPATCH2 but uses two counts. #define THRUST_DOUBLE_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count1, count2, arguments) \ if (count1 + count2 <= thrust::detail::integer_traits::const_max) \ diff --git a/thrust/thrust/system/cuda/detail/scan.h b/thrust/thrust/system/cuda/detail/scan.h index e2530691db..418bbc2f2e 100644 --- a/thrust/thrust/system/cuda/detail/scan.h +++ b/thrust/thrust/system/cuda/detail/scan.h @@ -72,7 +72,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl( // Determine temporary storage requirements: size_t tmp_size = 0; { - THRUST_INDEX_TYPE_DISPATCH2( + THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2( status, Dispatch32::Dispatch, Dispatch64::Dispatch, @@ -88,7 +88,7 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl( { // Allocate temporary storage: thrust::detail::temporary_array tmp{policy, tmp_size}; - THRUST_INDEX_TYPE_DISPATCH2( + THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2( status, Dispatch32::Dispatch, Dispatch64::Dispatch, @@ -122,7 +122,7 @@ _CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl( // Determine temporary storage requirements: size_t tmp_size = 0; { - THRUST_INDEX_TYPE_DISPATCH2( + THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2( status, Dispatch32::Dispatch, Dispatch64::Dispatch, @@ -138,7 +138,7 @@ _CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl( { // Allocate temporary storage: thrust::detail::temporary_array tmp{policy, tmp_size}; - THRUST_INDEX_TYPE_DISPATCH2( + THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2( status, Dispatch32::Dispatch, Dispatch64::Dispatch,