Skip to content

Commit

Permalink
Fix docs bugs and use correct accumulator and intermediate result types
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed Sep 5, 2024
1 parent 97f6cb3 commit 1ead3e2
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion thrust/testing/cuda/transform_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void TestTransformScanDevice(ExecutionPolicy exec)
}

iter = iter_vec[0];
ref = {2, -1, 1, -3, 2};
ref = {3, 2, -1, 1, -3};
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(ref, output);
Expand Down
6 changes: 4 additions & 2 deletions thrust/thrust/system/cuda/detail/transform_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ OutputIt _CCCL_HOST_DEVICE transform_inclusive_scan(
InitialValueType init,
ScanOp scan_op)
{
using result_type = ::cuda::std::__accumulator_t<ScanOp, cub::detail::value_t<InputIt>, InitialValueType>;
using input_type = typename thrust::iterator_value<InputIt>::type;
using result_type = thrust::detail::invoke_result_t<TransformOp, input_type>;
using value_type = thrust::remove_cvref_t<result_type>;

using size_type = typename iterator_traits<InputIt>::difference_type;
size_type num_items = static_cast<size_type>(thrust::distance(first, last));
using transformed_iterator_t = transform_input_iterator_t<result_type, InputIt, TransformOp>;
using transformed_iterator_t = transform_input_iterator_t<value_type, InputIt, TransformOp>;

return cuda_cub::inclusive_scan_n(
policy, transformed_iterator_t(first, transform_op), num_items, result, init, scan_op);
Expand Down
4 changes: 3 additions & 1 deletion thrust/thrust/system/detail/generic/transform_scan.inl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ _CCCL_HOST_DEVICE OutputIterator transform_inclusive_scan(
InitialValueType init,
BinaryFunction binary_op)
{
using ValueType = thrust::remove_cvref_t<InitialValueType>; // cuda::std::__accumulator_t?
using InputType = typename thrust::iterator_value<InputIterator>::type;
using ResultType = thrust::detail::invoke_result_t<UnaryFunction, InputType>;
using ValueType = thrust::remove_cvref_t<ResultType>;

thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _first(first, unary_op);
thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _last(last, unary_op);
Expand Down
15 changes: 8 additions & 7 deletions thrust/thrust/transform_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ OutputIterator transform_inclusive_scan(
* performing an \p inclusive_scan on the transformed sequence. In most
* cases, fusing these two operations together is more efficient, since
* fewer memory reads and writes are required. In \p transform_inclusive_scan,
* <tt>binary_op(unary_op(\*first), init)</tt> is assigned to
* <tt>\*result</tt> and the result of <tt>binary_op(unary_op(\*result),
* if <tt>binary_op(init, unary_op(\*first))</tt> is <tt>accum</tt>, it is assigned to
* <tt>\*result</tt> and the result of <tt>binary_op(accum,
* unary_op(\*(first + 1)))</tt> is assigned to <tt>\*(result + 1)</tt>,
* and so on. The transform scan operation is permitted to be in-place.
*
Expand All @@ -186,7 +186,7 @@ OutputIterator transform_inclusive_scan(
* \param last The end of the input sequence.
* \param result The beginning of the output sequence.
* \param unary_op The function used to transform the input sequence.
* \param init The initial value of the \p inclusive_scan
* \param init The initial value of the \p transform_inclusive_scan
* \param binary_op The associative operator used to 'sum' transformed values.
* \return The end of the output sequence.
*
Expand Down Expand Up @@ -239,6 +239,7 @@ _CCCL_HOST_DEVICE OutputIterator transform_inclusive_scan(
InputIterator first,
InputIterator last,
OutputIterator result,
UnaryFunction unary_op,
T init,
AssociativeOperator binary_op);

Expand All @@ -247,17 +248,17 @@ _CCCL_HOST_DEVICE OutputIterator transform_inclusive_scan(
* transformation defined by \p unary_op into a temporary sequence and then
* performing an \p inclusive_scan on the transformed sequence. In most
* cases, fusing these two operations together is more efficient, since
* fewer memory reads and writes are required. In \p transform_inclusive_scan,
* <tt>binary_op(unary_op(\*first), init)</tt> is assigned to
* <tt>\*result</tt> and the result of <tt>binary_op(unary_op(\*result),
* fewer memory reads and writes are required. In \p transform_inclusive_scan,
* if <tt>binary_op(init, unary_op(\*first))</tt> is <tt>accum</tt>, it is assigned to
* <tt>\*result</tt> and the result of <tt>binary_op(accum,
* unary_op(\*(first + 1)))</tt> is assigned to <tt>\*(result + 1)</tt>,
* and so on. The transform scan operation is permitted to be in-place.
*
* \param first The beginning of the input sequence.
* \param last The end of the input sequence.
* \param result The beginning of the output sequence.
* \param unary_op The function used to transform the input sequence.
* \param init The initial value of the \p inclusive_scan
* \param init The initial value of the \p transform_inclusive_scan
* \param binary_op The associative operator used to 'sum' transformed values.
* \return The end of the output sequence.
*
Expand Down

0 comments on commit 1ead3e2

Please sign in to comment.