-
Notifications
You must be signed in to change notification settings - Fork 135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement cub::DeviceFind::FindIf
#2405
base: main
Are you sure you want to change the base?
Changes from 1 commit
23a1b9a
ac0cfca
0a44c91
fe5316d
6f9fba8
612a9c2
e0bef31
0a9d815
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -33,6 +33,7 @@ | |||||||||
|
||||||||||
#include <cub/config.cuh> | ||||||||||
|
||||||||||
#include "cub/util_type.cuh" | ||||||||||
#include "device_launch_parameters.h" | ||||||||||
|
||||||||||
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) | ||||||||||
|
@@ -50,13 +51,22 @@ | |||||||||
#include <cub/thread/thread_operators.cuh> | ||||||||||
#include <cub/util_deprecated.cuh> | ||||||||||
|
||||||||||
#define elements_per_thread 16 | ||||||||||
static constexpr int elements_per_thread = 16; | ||||||||||
static constexpr int _VECTOR_LOAD_LENGTH = 4; | ||||||||||
static constexpr int block_threads = 128; | ||||||||||
|
||||||||||
CUB_NAMESPACE_BEGIN | ||||||||||
|
||||||||||
template <typename IterBegin, typename IterEnd, typename Pred> | ||||||||||
__global__ void find_if(IterBegin begin, IterEnd end, Pred pred, int* result, std::size_t num_items) | ||||||||||
{ | ||||||||||
using InputT = cub::detail::value_t<IterBegin>; | ||||||||||
using VectorT = typename CubVector<InputT, _VECTOR_LOAD_LENGTH>::Type; | ||||||||||
enum | ||||||||||
{ | ||||||||||
WORDS = elements_per_thread / _VECTOR_LOAD_LENGTH | ||||||||||
}; | ||||||||||
|
||||||||||
// int elements_per_thread = 32; | ||||||||||
auto tile_size = blockDim.x * elements_per_thread; | ||||||||||
__shared__ int sresult; | ||||||||||
|
@@ -83,14 +93,35 @@ __global__ void find_if(IterBegin begin, IterEnd end, Pred pred, int* result, st | |||||||||
return; | ||||||||||
} | ||||||||||
|
||||||||||
//// vectorized loads begin | ||||||||||
InputT* d_in_unqualified = const_cast<InputT*>(begin) + tile_offset + (threadIdx.x * _VECTOR_LOAD_LENGTH); | ||||||||||
|
||||||||||
cub::CacheModifiedInputIterator<cub::CacheLoadModifier::LOAD_DEFAULT, VectorT> d_vec_in( | ||||||||||
reinterpret_cast<VectorT*>(d_in_unqualified)); | ||||||||||
|
||||||||||
InputT input_items[elements_per_thread]; | ||||||||||
VectorT* vec_items = reinterpret_cast<VectorT*>(input_items); | ||||||||||
|
||||||||||
#pragma unroll | ||||||||||
for (int i = 0; i < WORDS; ++i) | ||||||||||
{ | ||||||||||
vec_items[i] = d_vec_in[block_threads * i]; | ||||||||||
} | ||||||||||
//// vectorized loads end | ||||||||||
// what about when input is not devisible by _VECTOR_LOAD_LENGTH? --> case for not full tile to be added | ||||||||||
|
||||||||||
bool found = false; | ||||||||||
for (int i = 0; i < elements_per_thread; ++i) | ||||||||||
{ | ||||||||||
auto index = tile_offset + threadIdx.x + i * blockDim.x; | ||||||||||
int index = i % WORDS + (i / WORDS) * block_threads * WORDS + threadIdx.x * WORDS + tile_offset; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The expression for the evaluation of this index can of course be simplified further. It is as is for educational purposes. |
||||||||||
// i % WORDS = + 0 1 2 3, 0 1 2 3, 0 1 2 3, ... (static) | ||||||||||
// (i / WORDS) * block_threads * WORDS = + 0 , 64 , 128, ... (static) | ||||||||||
// threadIdx.x * WORDS = + 0, 4, 8, ... offset of the thread within working tile | ||||||||||
// tile_offset = + just start at the beginning of the block | ||||||||||
|
||||||||||
if (index < num_items) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider turning this into an early return
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The whole loob body is within the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok ok got it. I had to look it with a fresh eye and some morning coffee to get it. |
||||||||||
{ | ||||||||||
if (pred(*(begin + index))) | ||||||||||
if (pred(input_items[i])) | ||||||||||
{ | ||||||||||
found = true; | ||||||||||
atomicMin(&block_result, index); | ||||||||||
|
@@ -135,7 +166,6 @@ struct DeviceFind | |||||||||
NumItemsT num_items, | ||||||||||
cudaStream_t stream = 0) | ||||||||||
{ | ||||||||||
int block_threads = 128; | ||||||||||
// int items_per_thread = 2; | ||||||||||
int tile_size = block_threads * elements_per_thread; | ||||||||||
int num_tiles = static_cast<int>(cub::DivideAndRoundUp(num_items, tile_size)); | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That should come after the other includes I assume?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah still a silly draft, but why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because any header other than the config one should come after the system header guard so we do not accidentally emit warnings to users