Skip to content

Commit

Permalink
Address reviewer feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jul 18, 2024
1 parent 5220a75 commit 949cc83
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions thrust/thrust/system/tbb/detail/reduce_by_key.inl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cub/detail/type_traits.cuh>

#include <thrust/detail/minmax.h>
#include <thrust/detail/range/tail_flags.h>
#include <thrust/detail/seq.h>
Expand Down Expand Up @@ -60,9 +63,13 @@ inline L divide_ri(const L x, const R y)
return (x + (y - 1)) / y;
}

template <typename BinaryFunction, typename InputIterator>
using partial_sum_type = cub::detail::
accumulator_t<BinaryFunction, thrust::iterator_value_t<InputIterator>, thrust::iterator_value_t<InputIterator>>;

template <typename InputIterator1, typename InputIterator2, typename BinaryPredicate, typename BinaryFunction>
thrust::pair<InputIterator1,
thrust::pair<thrust::iterator_value_t<InputIterator1>, thrust::iterator_value_t<InputIterator2>>>
thrust::pair<thrust::iterator_value_t<InputIterator1>, partial_sum_type<InputIterator2, BinaryFunction>>>
reduce_last_segment_backward(
InputIterator1 keys_first,
InputIterator1 keys_last,
Expand All @@ -77,8 +84,8 @@ reduce_last_segment_backward(
thrust::reverse_iterator<InputIterator1> keys_last_r(keys_first);
thrust::reverse_iterator<InputIterator2> values_first_r(values_first + n);

thrust::iterator_value_t<InputIterator1> result_key = *keys_first_r;
thrust::iterator_value_t<InputIterator2> result_value = *values_first_r;
thrust::iterator_value_t<InputIterator1> result_key = *keys_first_r;
partial_sum_type<InputIterator2, BinaryFunction> result_value = *values_first_r;

// consume the entirety of the first key's sequence
for (++keys_first_r, ++values_first_r; (keys_first_r != keys_last_r) && binary_pred(*keys_first_r, result_key);
Expand All @@ -99,7 +106,7 @@ template <typename InputIterator1,
thrust::tuple<OutputIterator1,
OutputIterator2,
thrust::iterator_value_t<InputIterator1>,
thrust::iterator_value_t<InputIterator2>>
partial_sum_type<InputIterator2, BinaryFunction>>
reduce_by_key_with_carry(
InputIterator1 keys_first,
InputIterator1 keys_last,
Expand All @@ -111,7 +118,7 @@ reduce_by_key_with_carry(
{
// first, consume the last sequence to produce the carry
// XXX is there an elegant way to pose this such that we don't need to default construct carry?
thrust::pair<thrust::iterator_value_t<InputIterator1>, thrust::iterator_value_t<InputIterator2>> carry;
thrust::pair<thrust::iterator_value_t<InputIterator1>, partial_sum_type<InputIterator2, BinaryFunction>> carry;

thrust::tie(keys_last, carry) =
reduce_last_segment_backward(keys_first, keys_last, values_first, binary_pred, binary_op);
Expand Down Expand Up @@ -201,7 +208,7 @@ struct serial_reduce_by_key_body

// consume the rest of the interval with reduce_by_key
using key_type = thrust::iterator_value_t<Iterator1>;
using value_type = thrust::iterator_value_t<Iterator2>;
using value_type = partial_sum_type<Iterator2, BinaryFunction>;

// XXX is there a way to pose this so that we don't require default construction of carry?
thrust::pair<key_type, value_type> carry;
Expand Down Expand Up @@ -345,7 +352,7 @@ thrust::pair<Iterator3, Iterator4> reduce_by_key(

// do a reduce_by_key serially in each thread
// the final interval never has a carry by definition, so don't reserve space for it
using carry_type = thrust::iterator_value_t<Iterator2>;
using carry_type = partial_sum_type<Iterator2, BinaryFunction>;
thrust::detail::temporary_array<carry_type, DerivedPolicy> carries(0, exec, num_intervals - 1);

// force grainsize == 1 with simple_partioner()
Expand Down

0 comments on commit 949cc83

Please sign in to comment.