From 612a9c23404aab5cb8da147bcc568055f06b3d35 Mon Sep 17 00:00:00 2001 From: gonidelis Date: Tue, 17 Sep 2024 10:41:40 -0700 Subject: [PATCH] Optimized solution --- cub/cub/device/device_find_if.cuh | 32 ++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/cub/cub/device/device_find_if.cuh b/cub/cub/device/device_find_if.cuh index f6e1ce95cd..a645680014 100644 --- a/cub/cub/device/device_find_if.cuh +++ b/cub/cub/device/device_find_if.cuh @@ -60,6 +60,12 @@ __global__ void find_if(IterBegin begin, IterEnd end, Pred pred, int* result, st // int elements_per_thread = 32; auto tile_size = blockDim.x * elements_per_thread; __shared__ int sresult; + __shared__ int block_result; + + if (threadIdx.x == 0) + { + block_result = num_items; + } for (int tile_offset = blockIdx.x * tile_size; tile_offset < num_items; tile_offset += tile_size * gridDim.x) { @@ -71,22 +77,34 @@ __global__ void find_if(IterBegin begin, IterEnd end, Pred pred, int* result, st } __syncthreads(); + // early exit + if (sresult < tile_offset) + { + return; + } + + bool found = false; for (int i = 0; i < elements_per_thread; ++i) { auto index = tile_offset + threadIdx.x + i * blockDim.x; if (index < num_items) { - // early exit - if (sresult < index) + if (pred(*(begin + index))) { - return; + found = true; + atomicMin(&block_result, index); + break; } - - if (pred(*(begin + index))) + } + } + if (syncthreads_or(found)) + { + if (threadIdx.x == 0) + { + if (block_result < num_items) { - atomicMin(result, index); - return; + atomicMin(result, block_result); } } }