Skip to content

Commit

Permalink
Prune CUB's ChainedPolicy by __CUDA_ARCH_LIST__
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Aug 1, 2024
1 parent 27253d7 commit dfbb1ae
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 3 deletions.
60 changes: 57 additions & 3 deletions cub/cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -657,14 +657,61 @@ struct ChainedPolicy

/// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version
template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int device_ptx_version, FunctorT& op)
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t
Invoke(int device_ptx_version, FunctorT& op, ::cuda::std::true_type = {})
{
if (device_ptx_version < PolicyPtxVersion)
{
return PrevPolicyT::Invoke(device_ptx_version, op);
// only continue traversing if the lowest architecture we are compiling for is smaller than the current policy
// TODO(bgruber): replace dispatching by if constexpr in C++17
constexpr bool cond =
#ifdef __CUDA_ARCH_LIST__
_NV_FIRST_ARG(__CUDA_ARCH_LIST__) < PolicyPtxVersion;
#else
true;
#endif
return PrevPolicyT::Invoke(device_ptx_version, op, ::cuda::std::bool_constant<cond>{});
}

// only invoke the function object for ptx versions smaller/equal to the largest architecture we are compiling for
// TODO(bgruber): replace dispatching by if constexpr in C++17
constexpr bool cond =
#ifdef __CUDA_ARCH_LIST__
// TODO(bgruber): alternative to get_last: just use the comma operator inside: (__CUDA_ARCH_LIST__). however nvcc
// warns on expression without side effects. should I supress the warning instead?
PolicyPtxVersion <= get_last({__CUDA_ARCH_LIST__});
#else
true;
#endif
{
return DoInvoke(op, ::cuda::std::bool_constant<cond>{});
}
::cuda::std::__libcpp_unreachable();
}

template <std::size_t N>
CUB_RUNTIME_FUNCTION static constexpr int get_last(const int (&a)[N])
{
return a[N - 1];
}

template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int, FunctorT&, ::cuda::std::false_type)
{
::cuda::std::__libcpp_unreachable();
}

template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t DoInvoke(FunctorT& op, ::cuda::std::true_type)
{
return op.template Invoke<PolicyT>();
}

template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t DoInvoke(FunctorT&, ::cuda::std::false_type)
{
::cuda::std::__libcpp_unreachable();
}
};

/// Helper for dispatching into a policy chain (end-of-chain specialization)
Expand All @@ -676,10 +723,17 @@ struct ChainedPolicy<PTX_VERSION, PolicyT, PolicyT>

/// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version
template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int /*ptx_version*/, FunctorT& op)
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t
Invoke(int /*ptx_version*/, FunctorT& op, ::cuda::std::true_type = {})
{
return op.template Invoke<PolicyT>();
}

template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int, FunctorT&, ::cuda::std::false_type)
{
::cuda::std::__libcpp_unreachable();
}
};

CUB_NAMESPACE_END
54 changes: 54 additions & 0 deletions cub/test/catch2_test_util_device.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,57 @@ CUB_TEST("CUB correctly identifies the ptx version the kernel was compiled for",
REQUIRE(ptx_version == kernel_cuda_arch);
REQUIRE(host_ptx_version == kernel_cuda_arch);
}

#ifdef __CUDA_ARCH_LIST__
// We list all virtual architectures, which is what __CUDA_ARCH_LIST__ contains
struct policy_hub
{
# define GEN_POLICY(cur, prev) \
struct policy##cur : cub::ChainedPolicy<cur, policy##cur, policy##prev> \
{ \
static constexpr int value = cur; \
}
GEN_POLICY(100, 100);
GEN_POLICY(200, 100);
GEN_POLICY(300, 200);
GEN_POLICY(400, 300);
GEN_POLICY(500, 300);
GEN_POLICY(600, 500);
GEN_POLICY(700, 600);
GEN_POLICY(800, 700);
GEN_POLICY(900, 800);
GEN_POLICY(1000, 900);
// TODO(bgruber): add more policies here when new architectures emerge
GEN_POLICY(2000, 1000);
# undef GEN_POLICY

using max_policy = policy2000;
};

// Check that selected is one of arches
template <int selected, int... arch_list>
struct check
{
// we just compare the major version (divide by 100), so the tests also pass for e.g. __CUDA_ARCH_LIST__=860
static_assert(::cuda::std::_Or<::cuda::std::bool_constant<selected / 100 == arch_list / 100>...>::value, "");
using type = cudaError_t;
};

struct Closure
{
// We need to fail template instantiation if ActivePolicy::value is not one from the __CUDA_ARCH_LIST__
template <typename ActivePolicy>
_CCCL_HOST_DEVICE auto Invoke() const -> typename check<ActivePolicy::value, __CUDA_ARCH_LIST__>::type
{
return cudaSuccess;
}
};

CUB_TEST("ChainedPolicy prunes based on __CUDA_ARCH_LIST__", "[util][dispatch]")
{
int ptx_version = 0;
cub::PtxVersion(ptx_version);
Closure c;
policy_hub::max_policy::Invoke(ptx_version, c);
}
#endif

0 comments on commit dfbb1ae

Please sign in to comment.