diff --git a/thrust/thrust/system/tbb/detail/reduce_by_key.inl b/thrust/thrust/system/tbb/detail/reduce_by_key.inl index 385f4e59799..944784991e6 100644 --- a/thrust/thrust/system/tbb/detail/reduce_by_key.inl +++ b/thrust/thrust/system/tbb/detail/reduce_by_key.inl @@ -25,6 +25,9 @@ #elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) # pragma system_header #endif // no system header + +#include + #include #include #include @@ -60,9 +63,13 @@ inline L divide_ri(const L x, const R y) return (x + (y - 1)) / y; } +template +using partial_sum_type = cub::detail:: + accumulator_t, thrust::iterator_value_t>; + template thrust::pair, thrust::iterator_value_t>> + thrust::pair, partial_sum_type>> reduce_last_segment_backward( InputIterator1 keys_first, InputIterator1 keys_last, @@ -77,8 +84,8 @@ reduce_last_segment_backward( thrust::reverse_iterator keys_last_r(keys_first); thrust::reverse_iterator values_first_r(values_first + n); - thrust::iterator_value_t result_key = *keys_first_r; - thrust::iterator_value_t result_value = *values_first_r; + thrust::iterator_value_t result_key = *keys_first_r; + partial_sum_type result_value = *values_first_r; // consume the entirety of the first key's sequence for (++keys_first_r, ++values_first_r; (keys_first_r != keys_last_r) && binary_pred(*keys_first_r, result_key); @@ -99,7 +106,7 @@ template , - thrust::iterator_value_t> + partial_sum_type> reduce_by_key_with_carry( InputIterator1 keys_first, InputIterator1 keys_last, @@ -111,7 +118,7 @@ reduce_by_key_with_carry( { // first, consume the last sequence to produce the carry // XXX is there an elegant way to pose this such that we don't need to default construct carry? - thrust::pair, thrust::iterator_value_t> carry; + thrust::pair, partial_sum_type> carry; thrust::tie(keys_last, carry) = reduce_last_segment_backward(keys_first, keys_last, values_first, binary_pred, binary_op); @@ -201,7 +208,7 @@ struct serial_reduce_by_key_body // consume the rest of the interval with reduce_by_key using key_type = thrust::iterator_value_t; - using value_type = thrust::iterator_value_t; + using value_type = partial_sum_type; // XXX is there a way to pose this so that we don't require default construction of carry? thrust::pair carry; @@ -345,7 +352,7 @@ thrust::pair reduce_by_key( // do a reduce_by_key serially in each thread // the final interval never has a carry by definition, so don't reserve space for it - using carry_type = thrust::iterator_value_t; + using carry_type = partial_sum_type; thrust::detail::temporary_array carries(0, exec, num_intervals - 1); // force grainsize == 1 with simple_partioner()