Skip to content

Commit

Permalink
Extract reduction kernels into NVRTC-compilable header (#2231)
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Aug 14, 2024
1 parent d7c83fe commit 64d28d1
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 232 deletions.
8 changes: 3 additions & 5 deletions cub/cub/agent/agent_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@

#include <cuda/std/type_traits>

#include <iterator>

_CCCL_SUPPRESS_DEPRECATED_PUSH
#include <cuda/std/functional>
_CCCL_SUPPRESS_DEPRECATED_POP
Expand Down Expand Up @@ -147,7 +145,7 @@ struct AgentReduce
// Wrap the native input pointer with CacheModifiedInputIterator
// or directly use the supplied input iterator type
using WrappedInputIteratorT =
::cuda::std::_If<std::is_pointer<InputIteratorT>::value,
::cuda::std::_If<::cuda::std::is_pointer<InputIteratorT>::value,
CacheModifiedInputIterator<AgentReducePolicy::LOAD_MODIFIER, InputT, OffsetT>,
InputIteratorT>;

Expand All @@ -160,8 +158,8 @@ struct AgentReduce
// Can vectorize according to the policy if the input iterator is a native
// pointer to a primitive type
static constexpr bool ATTEMPT_VECTORIZATION =
(VECTOR_LOAD_LENGTH > 1) && (ITEMS_PER_THREAD % VECTOR_LOAD_LENGTH == 0) && (std::is_pointer<InputIteratorT>::value)
&& Traits<InputT>::PRIMITIVE;
(VECTOR_LOAD_LENGTH > 1) && (ITEMS_PER_THREAD % VECTOR_LOAD_LENGTH == 0)
&& (::cuda::std::is_pointer<InputIteratorT>::value) && Traits<InputT>::PRIMITIVE;

static constexpr CacheLoadModifier LOAD_MODIFIER = AgentReducePolicy::LOAD_MODIFIER;

Expand Down
228 changes: 1 addition & 227 deletions cub/cub/device/dispatch/dispatch_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#endif // no system header

#include <cub/agent/agent_reduce.cuh>
#include <cub/device/dispatch/kernels/reduce.cuh>
#include <cub/grid/grid_even_share.cuh>
#include <cub/iterator/arg_index_input_iterator.cuh>
#include <cub/thread/thread_operators.cuh>
Expand All @@ -66,233 +67,6 @@ _CCCL_SUPPRESS_DEPRECATED_POP

CUB_NAMESPACE_BEGIN

namespace detail
{
namespace reduce
{

/**
 * All cub::DeviceReduce::* algorithms share the same implementation. Some of them, however,
 * should use the initial value only for empty problems. If this struct is used as the initial
 * value with one of the `DeviceReduce` algorithms, the wrapped `init` value is only written out
 * for empty problems; it is not incorporated into the aggregate of non-empty problems (see the
 * `finalize_and_store_aggregate` overloads, which dispatch on this wrapper type).
 */
template <class T>
struct empty_problem_init_t
{
  T init; // value emitted when the problem contains zero items

  // Implicit conversion so the wrapper can be used anywhere a plain initial value of type T
  // is expected (e.g. when storing it directly to the output for an empty problem).
  _CCCL_HOST_DEVICE operator T() const
  {
    return init;
  }
};

/**
 * @brief Applies the initial value to the block aggregate and stores the result to the
 *        output iterator.
 *
 * Overload selected for plain (unwrapped) initial values: the initial value participates
 * in the final reduction of non-empty problems.
 *
 * @param d_out Iterator to the output aggregate
 * @param reduction_op Binary reduction functor
 * @param init Initial value
 * @param block_aggregate Aggregate value computed by the block
 */
template <class OutputIteratorT, class ReductionOpT, class InitT, class AccumT>
_CCCL_HOST_DEVICE void
finalize_and_store_aggregate(OutputIteratorT d_out, ReductionOpT reduction_op, InitT init, AccumT block_aggregate)
{
  // Fold the initial value into the aggregate before writing the final result.
  *d_out = reduction_op(init, block_aggregate);
}

/**
 * @brief Ignores the initial value and stores the block aggregate to the output iterator.
 *
 * Overload selected when the initial value is wrapped in `empty_problem_init_t`, i.e. the
 * caller wants `init` applied to empty problems only — so for a non-empty problem reaching
 * this point, the wrapped value is deliberately discarded.
 *
 * @param d_out Iterator to the output aggregate
 * @param block_aggregate Aggregate value computed by the block
 */
template <class OutputIteratorT, class ReductionOpT, class InitT, class AccumT>
_CCCL_HOST_DEVICE void
finalize_and_store_aggregate(OutputIteratorT d_out, ReductionOpT, empty_problem_init_t<InitT>, AccumT block_aggregate)
{
  // Reduction functor and wrapped init are intentionally unused/unnamed here.
  *d_out = block_aggregate;
}
} // namespace reduce
} // namespace detail

/******************************************************************************
* Kernel entry points
*****************************************************************************/

/**
 * @brief Reduce region kernel entry point (multi-block). Each thread block computes one
 *        privatized partial reduction over its even-share of input tiles.
 *
 * @tparam ChainedPolicyT
 *   Chained tuning policy
 *
 * @tparam InputIteratorT
 *   Random-access input iterator type for reading input items @iterator
 *
 * @tparam OffsetT
 *   Signed integer type for global offsets
 *
 * @tparam ReductionOpT
 *   Binary reduction functor type having member
 *   `auto operator()(const T &a, const U &b)`
 *
 * @tparam AccumT
 *   Accumulator type
 *
 * @tparam TransformOpT
 *   Unary transform functor applied to each input item before reduction
 *
 * @param[in] d_in
 *   Pointer to the input sequence of data items
 *
 * @param[out] d_out
 *   Pointer to the output aggregates (one slot per thread block)
 *
 * @param[in] num_items
 *   Total number of input data items (tiling itself is carried by `even_share`)
 *
 * @param[in] even_share
 *   Even-share descriptor for mapping an equal number of tiles onto each
 *   thread block
 *
 * @param[in] reduction_op
 *   Binary reduction functor
 *
 * @param[in] transform_op
 *   Unary transform functor
 */
template <typename ChainedPolicyT,
          typename InputIteratorT,
          typename OffsetT,
          typename ReductionOpT,
          typename AccumT,
          typename TransformOpT>
CUB_DETAIL_KERNEL_ATTRIBUTES
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)) void DeviceReduceKernel(
  InputIteratorT d_in,
  AccumT* d_out,
  OffsetT num_items,
  GridEvenShare<OffsetT> even_share,
  ReductionOpT reduction_op,
  TransformOpT transform_op)
{
  // Cooperative tile-reduction agent used by this thread block
  using agent_t =
    AgentReduce<typename ChainedPolicyT::ActivePolicy::ReducePolicy,
                InputIteratorT,
                AccumT*,
                OffsetT,
                ReductionOpT,
                AccumT,
                TransformOpT>;

  // Block-wide scratch space required by the agent
  __shared__ typename agent_t::TempStorage temp_storage;

  // Reduce this block's share of input tiles down to a single partial aggregate
  agent_t agent(temp_storage, d_in, reduction_op, transform_op);
  const AccumT block_aggregate = agent.ConsumeTiles(even_share);

  // Exactly one thread per block publishes the privatized partial result
  if (threadIdx.x == 0)
  {
    detail::uninitialized_copy_single(d_out + blockIdx.x, block_aggregate);
  }
}

/**
 * @brief Reduce a single tile kernel entry point (single-block). Can be used
 *        to aggregate privatized thread block reductions from a previous
 *        multi-block reduction pass.
 *
 * @tparam ChainedPolicyT
 *   Chained tuning policy
 *
 * @tparam InputIteratorT
 *   Random-access input iterator type for reading input items @iterator
 *
 * @tparam OutputIteratorT
 *   Output iterator type for recording the reduced aggregate @iterator
 *
 * @tparam OffsetT
 *   Signed integer type for global offsets
 *
 * @tparam ReductionOpT
 *   Binary reduction functor type having member
 *   `T operator()(const T &a, const U &b)`
 *
 * @tparam InitT
 *   Initial value type
 *
 * @tparam AccumT
 *   Accumulator type
 *
 * @tparam TransformOpT
 *   Unary transform functor applied to each input item before reduction
 *
 * @param[in] d_in
 *   Pointer to the input sequence of data items
 *
 * @param[out] d_out
 *   Pointer to the output aggregate
 *
 * @param[in] num_items
 *   Total number of input data items
 *
 * @param[in] reduction_op
 *   Binary reduction functor
 *
 * @param[in] init
 *   The initial value of the reduction
 *
 * @param[in] transform_op
 *   Unary transform functor
 */
template <typename ChainedPolicyT,
          typename InputIteratorT,
          typename OutputIteratorT,
          typename OffsetT,
          typename ReductionOpT,
          typename InitT,
          typename AccumT,
          typename TransformOpT = ::cuda::std::__identity>
CUB_DETAIL_KERNEL_ATTRIBUTES __launch_bounds__(
  int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS),
  1) void DeviceReduceSingleTileKernel(InputIteratorT d_in,
                                       OutputIteratorT d_out,
                                       OffsetT num_items,
                                       ReductionOpT reduction_op,
                                       InitT init,
                                       TransformOpT transform_op)
{
  // Cooperative tile-reduction agent used by this (single) thread block
  using agent_t =
    AgentReduce<typename ChainedPolicyT::ActivePolicy::SingleTilePolicy,
                InputIteratorT,
                OutputIteratorT,
                OffsetT,
                ReductionOpT,
                AccumT,
                TransformOpT>;

  // Block-wide scratch space required by the agent
  __shared__ typename agent_t::TempStorage temp_storage;

  // Empty problem: the initial value alone is the answer; every thread exits
  if (num_items == 0)
  {
    if (threadIdx.x == 0)
    {
      *d_out = init;
    }

    return;
  }

  // Reduce the entire input range within this one block
  agent_t agent(temp_storage, d_in, reduction_op, transform_op);
  const AccumT block_aggregate = agent.ConsumeRange(OffsetT(0), num_items);

  // One thread finalizes (folds in `init` unless it is an empty_problem_init_t) and stores
  if (threadIdx.x == 0)
  {
    detail::reduce::finalize_and_store_aggregate(d_out, reduction_op, init, block_aggregate);
  }
}

/// Normalize input iterator to segment offset
template <typename T, typename OffsetT, typename IteratorT>
_CCCL_DEVICE _CCCL_FORCEINLINE void NormalizeReductionOutput(T& /*val*/, OffsetT /*base_offset*/, IteratorT /*itr*/)
Expand Down
Loading

0 comments on commit 64d28d1

Please sign in to comment.