-
Notifications
You must be signed in to change notification settings - Fork 135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement cub::DeviceFind::FindIf
#2405
base: main
Are you sure you want to change the base?
Changes from 5 commits
23a1b9a
ac0cfca
0a44c91
fe5316d
6f9fba8
612a9c2
e0bef31
0a9d815
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/****************************************************************************** | ||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
* | ||
* Redistribution and use in source and binary forms, with or without | ||
* modification, are permitted provided that the following conditions are met: | ||
* * Redistributions of source code must retain the above copyright | ||
* notice, this list of conditions and the following disclaimer. | ||
* * Redistributions in binary form must reproduce the above copyright | ||
* notice, this list of conditions and the following disclaimer in the | ||
* documentation and/or other materials provided with the distribution. | ||
* * Neither the name of the NVIDIA CORPORATION nor the | ||
* names of its contributors may be used to endorse or promote products | ||
* derived from this software without specific prior written permission. | ||
* | ||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | ||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | ||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | ||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
* | ||
******************************************************************************/ | ||
|
||
#include <cub/device/device_find_if.cuh> | ||
|
||
#include <thrust/count.h> | ||
#include <thrust/find.h> | ||
|
||
#include <nvbench_helper.cuh> | ||
|
||
template <typename T> | ||
struct equals_100 | ||
{ | ||
__device__ bool operator()(T i) | ||
{ | ||
return i == 1; | ||
} // @amd you 'll never find out the secret sauce | ||
}; | ||
|
||
template <typename T> | ||
void find_if(nvbench::state& state, nvbench::type_list<T>) | ||
{ | ||
// set up input | ||
const auto elements = state.get_int64("Elements"); | ||
const auto common_prefix = state.get_float64("CommonPrefixRatio"); | ||
const auto same_elements = elements * common_prefix; | ||
|
||
thrust::device_vector<T> dinput(elements, 0); | ||
thrust::fill(dinput.begin() + same_elements, dinput.end(), 1); | ||
thrust::device_vector<T> d_result(1); | ||
/// | ||
|
||
void* d_temp_storage = nullptr; | ||
size_t temp_storage_bytes{}; | ||
|
||
cub::DeviceFind::FindIf( | ||
d_temp_storage, temp_storage_bytes, dinput.begin(), d_result.begin(), equals_100<int>{}, dinput.size(), 0); | ||
|
||
thrust::device_vector<uint8_t> temp_storage(temp_storage_bytes); | ||
d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); | ||
|
||
state.exec(nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { | ||
cub::DeviceFind::FindIf( | ||
d_temp_storage, | ||
temp_storage_bytes, | ||
dinput.begin(), | ||
d_result.begin(), | ||
equals_100<int>{}, | ||
dinput.size(), | ||
launch.get_stream()); | ||
}); | ||
} | ||
NVBENCH_BENCH_TYPES(find_if, NVBENCH_TYPE_AXES(nvbench::type_list<int32_t> /*integral_types*/)) | ||
.add_int64_power_of_two_axis("Elements", nvbench::range(16, 28, 4)) | ||
.add_float64_axis("CommonPrefixRatio", std::vector{1.0, 0.5, 0.0}); | ||
|
||
////////////////////////////////////////////////////// | ||
template <typename T> | ||
void thrust_find_if(nvbench::state& state, nvbench::type_list<T>) | ||
{ | ||
// set up input | ||
const auto elements = state.get_int64("Elements"); | ||
const auto common_prefix = state.get_float64("CommonPrefixRatio"); | ||
const auto same_elements = elements * common_prefix; | ||
|
||
thrust::device_vector<T> dinput(elements, 0); | ||
thrust::fill(dinput.begin() + same_elements, dinput.end(), 1); | ||
/// | ||
|
||
caching_allocator_t alloc; | ||
thrust::find_if(policy(alloc), dinput.begin(), dinput.end(), equals_100<int>{}); | ||
|
||
state.exec(nvbench::exec_tag::no_batch | nvbench::exec_tag::sync, [&](nvbench::launch& launch) { | ||
thrust::find_if(policy(alloc, launch), dinput.begin(), dinput.end(), equals_100<int>{}); | ||
}); | ||
} | ||
NVBENCH_BENCH_TYPES(thrust_find_if, NVBENCH_TYPE_AXES(nvbench::type_list<int32_t> /*integral_types*/)) | ||
.add_int64_power_of_two_axis("Elements", nvbench::range(16, 28, 4)) | ||
.add_float64_axis("CommonPrefixRatio", std::vector{1.0, 0.5, 0.0}); | ||
|
||
////////////////////////////////////////////////////// | ||
template <typename T> | ||
void thrust_count_if(nvbench::state& state, nvbench::type_list<T>) | ||
{ | ||
// set up input | ||
const auto elements = state.get_int64("Elements"); | ||
const auto common_prefix = state.get_float64("CommonPrefixRatio"); | ||
const auto same_elements = elements * common_prefix; | ||
|
||
thrust::device_vector<T> dinput(elements, 0); | ||
thrust::fill(dinput.begin() + same_elements, dinput.end(), 1); | ||
/// | ||
|
||
caching_allocator_t alloc; | ||
thrust::count_if(policy(alloc), dinput.begin(), dinput.end(), equals_100<int>{}); | ||
|
||
state.exec(nvbench::exec_tag::no_batch | nvbench::exec_tag::sync, [&](nvbench::launch& launch) { | ||
thrust::count_if(policy(alloc, launch), dinput.begin(), dinput.end(), equals_100<int>{}); | ||
}); | ||
} | ||
NVBENCH_BENCH_TYPES(thrust_count_if, NVBENCH_TYPE_AXES(nvbench::type_list<int32_t> /*integral_types*/)) | ||
.add_int64_power_of_two_axis("Elements", nvbench::range(16, 28, 4)) | ||
.add_float64_axis("CommonPrefixRatio", std::vector{1.0, 0.5, 0.0}); |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,187 @@ | ||||||||||
/****************************************************************************** | ||||||||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||||||||||
* | ||||||||||
* Redistribution and use in source and binary forms, with or without | ||||||||||
* modification, are permitted provided that the following conditions are met: | ||||||||||
* * Redistributions of source code must retain the above copyright | ||||||||||
* notice, this list of conditions and the following disclaimer. | ||||||||||
* * Redistributions in binary form must reproduce the above copyright | ||||||||||
* notice, this list of conditions and the following disclaimer in the | ||||||||||
* documentation and/or other materials provided with the distribution. | ||||||||||
* * Neither the name of the NVIDIA CORPORATION nor the | ||||||||||
* names of its contributors may be used to endorse or promote products | ||||||||||
* derived from this software without specific prior written permission. | ||||||||||
* | ||||||||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||||||||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||||||||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||||||||||
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY | ||||||||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||||||||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||||||||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||||||||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||||||||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||||||||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||||||||||
* | ||||||||||
******************************************************************************/ | ||||||||||
|
||||||||||
//! @file | ||||||||||
//! cub::DeviceScan provides device-wide, parallel operations for computing a prefix scan across a sequence of data | ||||||||||
//! items residing within device-accessible memory. | ||||||||||
|
||||||||||
#pragma once | ||||||||||
|
||||||||||
#include <cub/config.cuh> | ||||||||||
|
||||||||||
#include "device_launch_parameters.h" | ||||||||||
|
||||||||||
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||||||||||
# pragma GCC system_header | ||||||||||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) | ||||||||||
# pragma clang system_header | ||||||||||
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) | ||||||||||
# pragma system_header | ||||||||||
#endif // no system header | ||||||||||
|
||||||||||
#include <cub/detail/choose_offset.cuh> | ||||||||||
#include <cub/detail/nvtx.cuh> | ||||||||||
#include <cub/device/dispatch/dispatch_scan.cuh> | ||||||||||
#include <cub/device/dispatch/dispatch_scan_by_key.cuh> | ||||||||||
#include <cub/thread/thread_operators.cuh> | ||||||||||
#include <cub/util_deprecated.cuh> | ||||||||||
|
||||||||||
#define elements_per_thread 16 | ||||||||||
|
||||||||||
CUB_NAMESPACE_BEGIN | ||||||||||
|
||||||||||
template <typename IterBegin, typename IterEnd, typename Pred> | ||||||||||
Comment on lines
+61
to
+62
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing some documentation here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code is super early draft on the performance analysis stage still. Until we make sure that this is the right implementation to go (not sure yet) I won't be adding any docs - unless it makes it easier for the crowd to understand the new algo - but I will add a detailed explanation on the description for that. |
||||||||||
__global__ void find_if(IterBegin begin, IterEnd end, Pred pred, int* result, std::size_t num_items) | ||||||||||
{ | ||||||||||
// int elements_per_thread = 32; | ||||||||||
auto tile_size = blockDim.x * elements_per_thread; | ||||||||||
__shared__ int sresult; | ||||||||||
|
||||||||||
for (int tile_offset = blockIdx.x * tile_size; tile_offset < num_items; tile_offset += tile_size * gridDim.x) | ||||||||||
{ | ||||||||||
// Only one thread reads atomically and propagates it to the | ||||||||||
// the rest threads of the block through shared memory | ||||||||||
gonidelis marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
if (threadIdx.x == 0) | ||||||||||
{ | ||||||||||
sresult = atomicAdd(result, 0); | ||||||||||
} | ||||||||||
__syncthreads(); | ||||||||||
|
||||||||||
for (int i = 0; i < elements_per_thread; ++i) | ||||||||||
{ | ||||||||||
auto index = tile_offset + threadIdx.x + i * blockDim.x; | ||||||||||
|
||||||||||
if (index < num_items) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider turning this into an early return
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The whole loob body is within the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok ok got it. I had to look it with a fresh eye and some morning coffee to get it. |
||||||||||
{ | ||||||||||
// early exit | ||||||||||
if (sresult < index) | ||||||||||
{ | ||||||||||
return; | ||||||||||
} | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: I'd hoist this before elements per thread loop. The invariant here is that |
||||||||||
|
||||||||||
if (pred(*(begin + index))) | ||||||||||
{ | ||||||||||
atomicMin(result, index); | ||||||||||
return; | ||||||||||
} | ||||||||||
} | ||||||||||
} | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: a few observations about this code:
Instead, you could experiment with an approach where threads first issue atomic min for shared memory, then outside of the loop over elements per thread whole block synchronizes, and if a minimal This could make an iteration over a single tile a bit more expensive for the case when no predicate where found, but on the other hand we'll address (1) because now remaining threads will see that they have to exit without accessing global memory, and (2) because now at most one thread of a given CTA issues an atomic. |
||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
template <typename ValueType, typename OutputIteratorT> | ||||||||||
__global__ void write_final_result_in_output_iterator_already(ValueType* d_temp_storage, OutputIteratorT d_out) | ||||||||||
{ | ||||||||||
*d_out = *d_temp_storage; | ||||||||||
} | ||||||||||
|
||||||||||
template <typename ValueType, typename NumItemsT> | ||||||||||
__global__ void cuda_mem_set_async_dtemp_storage(ValueType* d_temp_storage, NumItemsT num_items) | ||||||||||
{ | ||||||||||
*d_temp_storage = num_items; | ||||||||||
} | ||||||||||
|
||||||||||
struct DeviceFind | ||||||||||
{ | ||||||||||
template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename NumItemsT> | ||||||||||
CUB_RUNTIME_FUNCTION static void FindIf( | ||||||||||
void* d_temp_storage, | ||||||||||
size_t& temp_storage_bytes, | ||||||||||
InputIteratorT d_in, | ||||||||||
OutputIteratorT d_out, | ||||||||||
ScanOpT op, | ||||||||||
NumItemsT num_items, | ||||||||||
cudaStream_t stream = 0) | ||||||||||
{ | ||||||||||
int block_threads = 128; | ||||||||||
// int items_per_thread = 2; | ||||||||||
int tile_size = block_threads * elements_per_thread; | ||||||||||
int num_tiles = static_cast<int>(cub::DivideAndRoundUp(num_items, tile_size)); | ||||||||||
|
||||||||||
// Get device ordinal | ||||||||||
int device_ordinal; | ||||||||||
cudaError error = CubDebug(cudaGetDevice(&device_ordinal)); | ||||||||||
if (cudaSuccess != error) | ||||||||||
{ | ||||||||||
return; | ||||||||||
} | ||||||||||
|
||||||||||
// Get SM count | ||||||||||
int sm_count; | ||||||||||
error = CubDebug(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal)); | ||||||||||
if (cudaSuccess != error) | ||||||||||
{ | ||||||||||
return; | ||||||||||
} | ||||||||||
|
||||||||||
int find_if_sm_occupancy; | ||||||||||
error = CubDebug( | ||||||||||
cub::MaxSmOccupancy(find_if_sm_occupancy, find_if<InputIteratorT, InputIteratorT, ScanOpT>, block_threads)); | ||||||||||
if (cudaSuccess != error) | ||||||||||
{ | ||||||||||
return; | ||||||||||
} | ||||||||||
|
||||||||||
int findif_device_occupancy = find_if_sm_occupancy * sm_count; | ||||||||||
|
||||||||||
// Even-share work distribution | ||||||||||
int max_blocks = findif_device_occupancy; // no * CUB_SUBSCRIPTION_FACTOR(0) because max_blocks gets too big | ||||||||||
|
||||||||||
int findif_grid_size = CUB_MIN(num_tiles, max_blocks); | ||||||||||
|
||||||||||
// Temporary storage allocation requirements | ||||||||||
void* allocations[1] = {}; | ||||||||||
size_t allocation_sizes[1] = {sizeof(int)}; | ||||||||||
|
||||||||||
// Alias the temporary allocations from the single storage blob (or | ||||||||||
// compute the necessary size of the blob) | ||||||||||
error = CubDebug(AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes)); | ||||||||||
if (cudaSuccess != error) | ||||||||||
{ | ||||||||||
return; | ||||||||||
} | ||||||||||
|
||||||||||
int* int_temp_storage = static_cast<int*>(allocations[0]); // this shouldn't be just int | ||||||||||
|
||||||||||
if (d_temp_storage == nullptr) | ||||||||||
{ | ||||||||||
return; | ||||||||||
} | ||||||||||
|
||||||||||
// use d_temp_storage as the intermediate device result | ||||||||||
// to read and write from. Then store the final result in the output iterator. | ||||||||||
cuda_mem_set_async_dtemp_storage<<<1, 1>>>(int_temp_storage, num_items); | ||||||||||
|
||||||||||
find_if<<<findif_grid_size, block_threads, 0, stream>>>(d_in, d_in + num_items, op, int_temp_storage, num_items); | ||||||||||
|
||||||||||
write_final_result_in_output_iterator_already<int><<<1, 1>>>(int_temp_storage, d_out); | ||||||||||
|
||||||||||
return; | ||||||||||
} | ||||||||||
}; | ||||||||||
|
||||||||||
CUB_NAMESPACE_END |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That should come after the other includes I assume?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah still a silly draft, but why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because any header other than the config one should come after the system header guard so we do not accidentally emit warnings to users