Skip to content

Commit

Permalink
Prune CUB's ChainedPolicy by __CUDA_ARCH_LIST__
Browse files Browse the repository at this point in the history
Co-authored-by: Elias Stehle <[email protected]>
  • Loading branch information
bernhardmgruber and elstehle committed Aug 19, 2024
1 parent 51c1b22 commit bedf081
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 1 deletion.
77 changes: 76 additions & 1 deletion cub/cub/util_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ struct SmVersionCacheTag
{};

/**
* \brief Retrieves the PTX virtual architecture that will be used on \p device (major * 100 + minor * 10).
* \brief Retrieves the PTX virtual architecture that will be used on \p device (major * 100 + minor * 10). If
* __CUDA_ARCH_LIST__ is defined, this value is one of __CUDA_ARCH_LIST__.
*
* \note This function may cache the result internally.
* \note This function is thread safe.
Expand Down Expand Up @@ -635,18 +636,79 @@ struct ChainedPolicy
template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int device_ptx_version, FunctorT& op)
{
// __CUDA_ARCH_LIST__ is only available from CTK 11.5 onwards
#ifdef __CUDA_ARCH_LIST__
return runtime_to_compiletime<__CUDA_ARCH_LIST__>(device_ptx_version, op);
#else
if (device_ptx_version < PolicyPtxVersion)
{
return PrevPolicyT::Invoke(device_ptx_version, op);
}
return op.template Invoke<PolicyT>();
#endif
}

private:
template <int, typename, typename>
friend struct ChainedPolicy; // let us call invoke_static of other ChainedPolicy instantiations

template <int... CudaArches, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t runtime_to_compiletime(int device_ptx_version, FunctorT& op)
{
// we instantiate invoke_static for each CudaArches, but only call the one matching device_ptx_version
cudaError_t e = cudaSuccess;
const cudaError_t dummy[] = {
(device_ptx_version == CudaArches ? (e = invoke_static<CudaArches>(op, ::cuda::std::true_type{}))
: cudaSuccess)...};
(void) dummy;
return e;
}

template <int DevicePtxVersion, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op, ::cuda::std::true_type)
{
// TODO(bgruber): drop diagnostic suppression in C++17
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4127) // suppress Conditional Expression is Constant
_CCCL_IF_CONSTEXPR (DevicePtxVersion < PolicyPtxVersion)
{
// TODO(bgruber): drop boolean tag dispatches in C++17, since _CCCL_IF_CONSTEXPR will discard this branch properly
return PrevPolicyT::template invoke_static<DevicePtxVersion>(
op, ::cuda::std::bool_constant<(DevicePtxVersion < PolicyPtxVersion)>{});
}
else
{
return do_invoke(op, ::cuda::std::bool_constant<DevicePtxVersion >= PolicyPtxVersion>{});
}
_CCCL_DIAG_POP
}

template <int, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT&, ::cuda::std::false_type)
{
_LIBCUDACXX_UNREACHABLE();
}

template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t do_invoke(FunctorT& op, ::cuda::std::true_type)
{
return op.template Invoke<PolicyT>();
}

template <typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t do_invoke(FunctorT&, ::cuda::std::false_type)
{
_LIBCUDACXX_UNREACHABLE();
}
};

/// Helper for dispatching into a policy chain (end-of-chain specialization)
template <int PTX_VERSION, typename PolicyT>
struct ChainedPolicy<PTX_VERSION, PolicyT, PolicyT>
{
template <int, typename, typename>
friend struct ChainedPolicy; // befriend primary template, so it can call invoke_static

/// The policy for the active compiler pass
using ActivePolicy = PolicyT;

Expand All @@ -656,6 +718,19 @@ struct ChainedPolicy<PTX_VERSION, PolicyT, PolicyT>
{
return op.template Invoke<PolicyT>();
}

private:
template <int, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op, ::cuda::std::true_type)
{
return op.template Invoke<PolicyT>();
}

template <int, typename FunctorT>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT&, ::cuda::std::false_type)
{
_LIBCUDACXX_UNREACHABLE();
}
};

CUB_NAMESPACE_END
173 changes: 173 additions & 0 deletions cub/test/catch2_test_util_device.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/device_vector.h>

#include <cuda/std/__algorithm/lower_bound.h>
#include <cuda/std/array>

#include "catch2_test_helper.h"
#include "catch2_test_launch_helper.h"

Expand Down Expand Up @@ -87,3 +90,173 @@ CUB_TEST("CUB correctly identifies the ptx version the kernel was compiled for",
REQUIRE(ptx_version == kernel_cuda_arch);
REQUIRE(host_ptx_version == kernel_cuda_arch);
}

#ifdef __CUDA_ARCH_LIST__
CUB_TEST("PtxVersion returns a value from __CUDA_ARCH_LIST__", "[util][dispatch]")
{
int ptx_version = 0;
cub::PtxVersion(ptx_version);
const auto arch_list = std::vector<int>{__CUDA_ARCH_LIST__};
REQUIRE(std::find(arch_list.begin(), arch_list.end(), ptx_version) != arch_list.end());
}
#endif

#define GEN_POLICY(cur, prev) \
struct policy##cur : cub::ChainedPolicy<cur, policy##cur, policy##prev> \
{ \
static constexpr int value = cur; \
}

#ifdef __CUDA_ARCH_LIST__
// We list policies for all virtual architectures that __CUDA_ARCH_LIST__ can contain, so the actual architectures the
// tests are compiled for should match to one of those
struct policy_hub_all
{
// for the list of supported architectures, see libcudacxx/include/nv/target
GEN_POLICY(350, 350);
GEN_POLICY(370, 350);
GEN_POLICY(500, 370);
GEN_POLICY(520, 500);
GEN_POLICY(530, 520);
GEN_POLICY(600, 530);
GEN_POLICY(610, 600);
GEN_POLICY(620, 610);
GEN_POLICY(700, 620);
GEN_POLICY(720, 700);
GEN_POLICY(750, 720);
GEN_POLICY(800, 750);
GEN_POLICY(860, 800);
GEN_POLICY(870, 860);
GEN_POLICY(890, 870);
GEN_POLICY(900, 890);
GEN_POLICY(1000, 900);
// add more policies here when new architectures emerge
GEN_POLICY(2000, 1000); // non-existing architecture, just to test pruning

using max_policy = policy2000;
};

// Check that selected is one of arches
template <int Selected, int... ArchList>
struct check
{
static_assert(::cuda::std::_Or<::cuda::std::bool_constant<Selected == ArchList>...>::value, "");
using type = cudaError_t;
};

struct closure_all
{
int ptx_version;

// We need to fail template instantiation if ActivePolicy::value is not one from the __CUDA_ARCH_LIST__
template <typename ActivePolicy>
CUB_RUNTIME_FUNCTION auto Invoke() const -> typename check<ActivePolicy::value, __CUDA_ARCH_LIST__>::type
{
// policy_hub_all must list all PTX virtual architectures, so we can do an exact comparison here
# if TEST_LAUNCH == 0
REQUIRE(+ActivePolicy::value == ptx_version);
# endif // TEST_LAUNCH == 0
return +ActivePolicy::value == ptx_version ? cudaSuccess : cudaErrorInvalidValue;
}
};

CUB_RUNTIME_FUNCTION cudaError_t
check_chained_policy_prunes_to_arch_list(void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t = 0)
{
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
return cudaSuccess;
}
int ptx_version = 0;
cub::PtxVersion(ptx_version);
closure_all c{ptx_version};
return policy_hub_all::max_policy::Invoke(ptx_version, c);
}

DECLARE_LAUNCH_WRAPPER(check_chained_policy_prunes_to_arch_list, check_wrapper_all);

CUB_TEST("ChainedPolicy prunes based on __CUDA_ARCH_LIST__", "[util][dispatch]")
{
check_wrapper_all();
}
#endif

template <int NumPolicies>
struct check_policy_closure
{
int ptx_version;
::cuda::std::array<int, NumPolicies> policies;

template <typename ActivePolicy>
CUB_RUNTIME_FUNCTION cudaError_t Invoke() const
{
#define CHECK_EXPR +ActivePolicy::value == ::cuda::std::lower_bound(policies.begin(), policies.end(), ptx_version)[-1]
#if TEST_LAUNCH == 0
CAPTURE(ptx_version, policies);
REQUIRE(CHECK_EXPR);
#endif // TEST_LAUNCH == 0
return CHECK_EXPR ? cudaSuccess : cudaErrorInvalidValue;
#undef CHECK_EXPR
}
};

template <typename PolicyHub, int NumPolicies>
CUB_RUNTIME_FUNCTION cudaError_t check_chained_policy_selects_correct_policy(
void* d_temp_storage, size_t& temp_storage_bytes, ::cuda::std::array<int, NumPolicies> policies, cudaStream_t = 0)
{
if (d_temp_storage == nullptr)
{
temp_storage_bytes = 1;
return cudaSuccess;
}
int ptx_version = 0;
cub::PtxVersion(ptx_version);
check_policy_closure<NumPolicies> c{ptx_version, std::move(policies)};
return PolicyHub::max_policy::Invoke(ptx_version, c);
}

DECLARE_TMPL_LAUNCH_WRAPPER(check_chained_policy_selects_correct_policy,
check_wrapper_some,
ESCAPE_LIST(typename PolicyHub, int NumPolicies),
ESCAPE_LIST(PolicyHub, NumPolicies));

struct policy_hub_some
{
GEN_POLICY(350, 350);
GEN_POLICY(500, 350);
GEN_POLICY(700, 500);
GEN_POLICY(900, 700);
GEN_POLICY(2000, 900); // non-existing architecture, just to test
using max_policy = policy2000;
};

struct policy_hub_few
{
GEN_POLICY(350, 350);
GEN_POLICY(600, 350);
GEN_POLICY(2000, 600); // non-existing architecture, just to test
using max_policy = policy2000;
};

struct policy_hub_minimal
{
GEN_POLICY(350, 350);
using max_policy = policy350;
};

CUB_TEST("ChainedPolicy invokes correct policy", "[util][dispatch]")
{
SECTION("policy_hub_some")
{
check_wrapper_some<policy_hub_some, 5>(::cuda::std::array<int, 5>{350, 500, 700, 900, 2000});
}
SECTION("policy_hub_few")
{
check_wrapper_some<policy_hub_few, 3>(::cuda::std::array<int, 3>{350, 600, 2000});
}
SECTION("policy_hub_minimal")
{
check_wrapper_some<policy_hub_minimal, 1>(::cuda::std::array<int, 1>{350});
}
}

0 comments on commit bedf081

Please sign in to comment.