diff --git a/cub/cub/device/dispatch/dispatch_transform.cuh b/cub/cub/device/dispatch/dispatch_transform.cuh index 583ac0e78b2..f989f5fa4e3 100644 --- a/cub/cub/device/dispatch/dispatch_transform.cuh +++ b/cub/cub/device/dispatch/dispatch_transform.cuh @@ -542,19 +542,18 @@ struct policy_hub>...>::value; static constexpr bool can_memcpy = all_contiguous && all_values_trivially_reloc; - // no_input_streams || !all_contiguous ? Algorithm::prefetch - // : !RequiresStableAddress && all_values_trivially_reloc - // ? ActivePolicy::alg_addr_unstable - // : ActivePolicy::alg_addr_stable; + + // TODO(bgruber): consider a separate kernel for just filling // below A100 struct policy300 : ChainedPolicy<300, policy300, policy300> { static constexpr int min_bif = arch_to_min_bif(300); // TODO(bgruber): we don't need algo, because we can just detect the type of algo_policy - static constexpr auto algorithm = RequiresStableAddress ? Algorithm::prefetch : Algorithm::unrolled_staged; + static constexpr auto algorithm = + RequiresStableAddress || no_input_streams ? Algorithm::prefetch : Algorithm::unrolled_staged; using algo_policy = - ::cuda::std::_If, unrolled_policy_t<256, items_per_thread_from_occupancy(256, 8, min_bif, loaded_bytes_per_iter)>>; }; @@ -567,9 +566,9 @@ struct policy_hub, async_copy_policy_t<256>>; + (RequiresStableAddress || !can_memcpy || no_input_streams) ? Algorithm::prefetch : Algorithm::memcpy_async; + using algo_policy = ::cuda::std:: + _If, async_copy_policy_t<256>>; }; // TODO(bgruber): should we add a tuning for 860? They should have items_per_thread_from_occupancy(256, 6, ...) @@ -577,10 +576,11 @@ struct policy_hub { - static constexpr int min_bif = arch_to_min_bif(900); - static constexpr auto algorithm = (RequiresStableAddress || !can_memcpy) ? Algorithm::prefetch : Algorithm::ublkcp; - using algo_policy = - ::cuda::std::_If, async_copy_policy_t<256>>; + static constexpr int min_bif = arch_to_min_bif(900); + static constexpr auto algorithm = + (RequiresStableAddress || !can_memcpy || no_input_streams) ? Algorithm::prefetch : Algorithm::ublkcp; + using algo_policy = ::cuda::std:: + _If, async_copy_policy_t<256>>; }; using max_policy = policy900;