From 07fef970a33ae120c8ff2a9efea3e83d9d903cff Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Sat, 7 Sep 2024 01:43:00 +0200 Subject: [PATCH] Use a constant for the amount of static SMEM (#2374) --- cub/benchmarks/bench/radix_sort/keys.cu | 3 ++- cub/benchmarks/bench/radix_sort/pairs.cu | 3 ++- cub/cub/util_arch.cuh | 13 +++++++++++-- cub/cub/util_type.cuh | 1 + cub/cub/util_vsmem.cuh | 5 +---- cub/test/catch2_test_block_load.cu | 3 ++- cub/test/catch2_test_block_store.cu | 3 ++- cub/test/test_block_radix_rank.cu | 3 ++- 8 files changed, 23 insertions(+), 11 deletions(-) diff --git a/cub/benchmarks/bench/radix_sort/keys.cu b/cub/benchmarks/bench/radix_sort/keys.cu index f3b7ba3867..b6b9e4fd53 100644 --- a/cub/benchmarks/bench/radix_sort/keys.cu +++ b/cub/benchmarks/bench/radix_sort/keys.cu @@ -26,6 +26,7 @@ ******************************************************************************/ #include +#include #include @@ -123,7 +124,7 @@ constexpr std::size_t max_temp_storage_size() template constexpr bool fits_in_default_shared_memory() { - return max_temp_storage_size() < 48 * 1024; + return max_temp_storage_size() < cub::detail::max_smem_per_block; } #else // TUNE_BASE template diff --git a/cub/benchmarks/bench/radix_sort/pairs.cu b/cub/benchmarks/bench/radix_sort/pairs.cu index 2729ce1b62..4a9f229bca 100644 --- a/cub/benchmarks/bench/radix_sort/pairs.cu +++ b/cub/benchmarks/bench/radix_sort/pairs.cu @@ -26,6 +26,7 @@ ******************************************************************************/ #include +#include #include @@ -121,7 +122,7 @@ constexpr std::size_t max_temp_storage_size() template constexpr bool fits_in_default_shared_memory() { - return max_temp_storage_size() < 48 * 1024; + return max_temp_storage_size() < cub::detail::max_smem_per_block; } #else // TUNE_BASE template diff --git a/cub/cub/util_arch.cuh b/cub/cub/util_arch.cuh index cb32766405..5f8780620f 100644 --- a/cub/cub/util_arch.cuh +++ b/cub/cub/util_arch.cuh @@ -136,13 +136,21 @@ static_assert(CUB_MAX_DEVICES > 0, "CUB_MAX_DEVICES must be greater than 0."); # define CUB_PTX_PREFER_CONFLICT_OVER_PADDING CUB_PREFER_CONFLICT_OVER_PADDING(0) # endif +namespace detail +{ +// The maximum amount of static shared memory available per thread block +// Note that in contrast to dynamic shared memory, static shared memory is still limited to 48 KB +static constexpr ::cuda::std::size_t max_smem_per_block = 48 * 1024; +} // namespace detail + template struct RegBoundScaling { enum { ITEMS_PER_THREAD = CUB_MAX(1, NOMINAL_4B_ITEMS_PER_THREAD * 4 / CUB_MAX(4, sizeof(T))), - BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS, (((1024 * 48) / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32), + BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS, + ((cub::detail::max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32), }; }; @@ -153,7 +161,8 @@ struct MemBoundScaling { ITEMS_PER_THREAD = CUB_MAX(1, CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(T), NOMINAL_4B_ITEMS_PER_THREAD * 2)), - BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS, (((1024 * 48) / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32), + BLOCK_THREADS = CUB_MIN(NOMINAL_4B_BLOCK_THREADS, + ((cub::detail::max_smem_per_block / (sizeof(T) * ITEMS_PER_THREAD)) + 31) / 32 * 32), }; }; diff --git a/cub/cub/util_type.cuh b/cub/cub/util_type.cuh index e23f6e6578..8ae4e2d05b 100644 --- a/cub/cub/util_type.cuh +++ b/cub/cub/util_type.cuh @@ -44,6 +44,7 @@ #endif // no system header #include +#include #include #include diff --git a/cub/cub/util_vsmem.cuh b/cub/cub/util_vsmem.cuh index 6a0d6b9a94..d2e5541c09 100644 --- a/cub/cub/util_vsmem.cuh +++ b/cub/cub/util_vsmem.cuh @@ -42,6 +42,7 @@ # pragma system_header #endif // no system header +#include #include #include #include @@ -67,10 +68,6 @@ struct vsmem_t void* gmem_ptr; }; -// The maximum amount of static shared memory available per thread block -// Note that in contrast to dynamic shared memory, static shared memory is still limited to 48 KB -static constexpr std::size_t max_smem_per_block = 48 * 1024; - /** * @brief Class template that helps to prevent exceeding the available shared memory per thread block. * diff --git a/cub/test/catch2_test_block_load.cu b/cub/test/catch2_test_block_load.cu index 39bccc50c5..43fd75698f 100644 --- a/cub/test/catch2_test_block_load.cu +++ b/cub/test/catch2_test_block_load.cu @@ -28,6 +28,7 @@ #include #include #include +#include #include "catch2_test_helper.h" @@ -113,7 +114,7 @@ void block_load(InputIteratorT input, OutputIteratorT output, int num_items) using input_t = cub::detail::value_t; using block_load_t = cub::BlockLoad; using storage_t = typename block_load_t::TempStorage; - constexpr bool sufficient_resources = sizeof(storage_t) <= 1024 * 48; + constexpr bool sufficient_resources = sizeof(storage_t) <= cub::detail::max_smem_per_block; kernel <<<1, ThreadsInBlock>>>(std::integral_constant{}, input, output, num_items); diff --git a/cub/test/catch2_test_block_store.cu b/cub/test/catch2_test_block_store.cu index f157a28ea0..566dd2e828 100644 --- a/cub/test/catch2_test_block_store.cu +++ b/cub/test/catch2_test_block_store.cu @@ -29,6 +29,7 @@ #include #include #include +#include #include "catch2_test_helper.h" @@ -114,7 +115,7 @@ void block_store(InputIteratorT input, OutputIteratorT output, int num_items) using input_t = cub::detail::value_t; using block_store_t = cub::BlockStore; using storage_t = typename block_store_t::TempStorage; - constexpr bool sufficient_resources = sizeof(storage_t) <= 1024 * 48; + constexpr bool sufficient_resources = sizeof(storage_t) <= cub::detail::max_smem_per_block; kernel <<<1, ThreadsInBlock>>>(std::integral_constant{}, input, output, num_items); diff --git a/cub/test/test_block_radix_rank.cu b/cub/test/test_block_radix_rank.cu index 6d36378882..8c1df1a80c 100644 --- a/cub/test/test_block_radix_rank.cu +++ b/cub/test/test_block_radix_rank.cu @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -240,7 +241,7 @@ void Test() cub::detail::block_radix_rank_t; using storage_t = typename block_radix_rank::TempStorage; - cub::Int2Type<(sizeof(storage_t) <= 48 * 1024)> fits_smem_capacity; + cub::Int2Type<(sizeof(storage_t) <= cub::detail::max_smem_per_block)> fits_smem_capacity; TestValid(fits_smem_capacity); }