diff --git a/cub/cub/util_device.cuh b/cub/cub/util_device.cuh index e1cc4d5372..fe3324e3a1 100644 --- a/cub/cub/util_device.cuh +++ b/cub/cub/util_device.cuh @@ -657,14 +657,61 @@ struct ChainedPolicy /// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version template - CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int device_ptx_version, FunctorT& op) + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t + Invoke(int device_ptx_version, FunctorT& op, ::cuda::std::true_type = {}) { if (device_ptx_version < PolicyPtxVersion) { - return PrevPolicyT::Invoke(device_ptx_version, op); + // only continue traversing if the lowest architecture we are compiling for is smaller than the current policy + // TODO(bgruber): replace dispatching by if constexpr in C++17 + constexpr bool cond = +#ifdef __CUDA_ARCH_LIST__ + _NV_FIRST_ARG(__CUDA_ARCH_LIST__) < PolicyPtxVersion; +#else + true; +#endif + return PrevPolicyT::Invoke(device_ptx_version, op, ::cuda::std::bool_constant{}); + } + + // only invoke the function object for ptx versions smaller/equal to the largest architecture we are compiling for + // TODO(bgruber): replace dispatching by if constexpr in C++17 + constexpr bool cond = +#ifdef __CUDA_ARCH_LIST__ + // TODO(bgruber): alternative to get_last: just use the comma operator inside: (__CUDA_ARCH_LIST__). However, nvcc + // warns on expressions without side effects. Should the warning be suppressed instead?
+ PolicyPtxVersion <= get_last({__CUDA_ARCH_LIST__}); +#else + true; +#endif + { + return DoInvoke(op, ::cuda::std::bool_constant{}); } + ::cuda::std::__libcpp_unreachable(); + } + + template + CUB_RUNTIME_FUNCTION static constexpr int get_last(const int (&a)[N]) + { + return a[N - 1]; + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int, FunctorT&, ::cuda::std::false_type) + { + ::cuda::std::__libcpp_unreachable(); + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t DoInvoke(FunctorT& op, ::cuda::std::true_type) + { return op.template Invoke(); } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t DoInvoke(FunctorT&, ::cuda::std::false_type) + { + ::cuda::std::__libcpp_unreachable(); + } }; /// Helper for dispatching into a policy chain (end-of-chain specialization) @@ -676,10 +723,17 @@ struct ChainedPolicy /// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version template - CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int /*ptx_version*/, FunctorT& op) + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t + Invoke(int /*ptx_version*/, FunctorT& op, ::cuda::std::true_type = {}) { return op.template Invoke(); } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int, FunctorT&, ::cuda::std::false_type) + { + ::cuda::std::__libcpp_unreachable(); + } }; CUB_NAMESPACE_END diff --git a/cub/test/catch2_test_util_device.cu b/cub/test/catch2_test_util_device.cu index c59c076ec5..1021bfbb1f 100644 --- a/cub/test/catch2_test_util_device.cu +++ b/cub/test/catch2_test_util_device.cu @@ -87,3 +87,57 @@ CUB_TEST("CUB correctly identifies the ptx version the kernel was compiled for", REQUIRE(ptx_version == kernel_cuda_arch); REQUIRE(host_ptx_version == kernel_cuda_arch); } + +#ifdef __CUDA_ARCH_LIST__ +// We list all virtual architectures, which is what __CUDA_ARCH_LIST__ contains +struct policy_hub
+{ +# define GEN_POLICY(cur, prev) \ + struct policy##cur : cub::ChainedPolicy \ + { \ + static constexpr int value = cur; \ + } + GEN_POLICY(100, 100); + GEN_POLICY(200, 100); + GEN_POLICY(300, 200); + GEN_POLICY(400, 300); + GEN_POLICY(500, 300); + GEN_POLICY(600, 500); + GEN_POLICY(700, 600); + GEN_POLICY(800, 700); + GEN_POLICY(900, 800); + GEN_POLICY(1000, 900); + // TODO(bgruber): add more policies here when new architectures emerge + GEN_POLICY(2000, 1000); +# undef GEN_POLICY + + using max_policy = policy2000; +}; + +// Check that selected is one of arches +template +struct check +{ + // we just compare the major version (divide by 100), so the tests also pass for e.g. __CUDA_ARCH_LIST__=860 + static_assert(::cuda::std::_Or<::cuda::std::bool_constant...>::value, ""); + using type = cudaError_t; +}; + +struct Closure +{ + // We need to fail template instantiation if ActivePolicy::value is not one of the architectures in __CUDA_ARCH_LIST__ + template + _CCCL_HOST_DEVICE auto Invoke() const -> typename check::type + { + return cudaSuccess; + } +}; + +CUB_TEST("ChainedPolicy prunes based on __CUDA_ARCH_LIST__", "[util][dispatch]") +{ + int ptx_version = 0; + cub::PtxVersion(ptx_version); + Closure c; + policy_hub::max_policy::Invoke(ptx_version, c); +} +#endif