Skip to content

Commit

Permalink
Fix bug remaining on thrust::inclusive_scan with init value with CDP (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed Sep 3, 2024
1 parent 709ddec commit 498251c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
31 changes: 31 additions & 0 deletions thrust/testing/cuda/scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ __global__ void inclusive_scan_kernel(ExecutionPolicy exec, Iterator1 first, Ite
thrust::inclusive_scan(exec, first, last, result);
}

template <typename ExecutionPolicy, typename Iterator1, typename Iterator2, typename T, typename Pred>
__global__ void
inclusive_scan_kernel(ExecutionPolicy exec, Iterator1 first, Iterator1 last, Iterator2 result, T init, Pred pred)
{
thrust::inclusive_scan(exec, first, last, result, init, pred);
}

template <typename ExecutionPolicy, typename Iterator1, typename Iterator2>
__global__ void exclusive_scan_kernel(ExecutionPolicy exec, Iterator1 first, Iterator1 last, Iterator2 result)
{
Expand Down Expand Up @@ -43,6 +50,16 @@ void TestScanDevice(ExecutionPolicy exec, const size_t n)

ASSERT_EQUAL(d_output, h_output);

thrust::inclusive_scan(h_input.begin(), h_input.end(), h_output.begin(), (T) 11, thrust::plus<T>{});

inclusive_scan_kernel<<<1, 1>>>(exec, d_input.begin(), d_input.end(), d_output.begin(), (T) 11, thrust::plus<T>{});
{
cudaError_t const err = cudaDeviceSynchronize();
ASSERT_EQUAL(cudaSuccess, err);
}

ASSERT_EQUAL(d_output, h_output);

thrust::exclusive_scan(h_input.begin(), h_input.end(), h_output.begin());

exclusive_scan_kernel<<<1, 1>>>(exec, d_input.begin(), d_input.end(), d_output.begin());
Expand Down Expand Up @@ -186,6 +203,20 @@ void TestScanCudaStreams()
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(output, result);

// inclusive scan with init and op
iter =
thrust::inclusive_scan(thrust::cuda::par.on(s), input.begin(), input.end(), output.begin(), 3, thrust::plus<T>());
cudaStreamSynchronize(s);

result[0] = 4;
result[1] = 7;
result[2] = 5;
result[3] = 9;
result[4] = 4;
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(output, result);

// exclusive scan with init and op
iter =
thrust::exclusive_scan(thrust::cuda::par.on(s), input.begin(), input.end(), output.begin(), 3, thrust::plus<T>());
Expand Down
3 changes: 2 additions & 1 deletion thrust/thrust/system/cuda/detail/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n(
{
THRUST_CDP_DISPATCH(
(result = thrust::cuda_cub::detail::inclusive_scan_n_impl(policy, first, num_items, result, init, scan_op);),
(result = thrust::inclusive_scan(cvt_to_seq(derived_cast(policy)), first, first + num_items, result, scan_op);));
(result =
thrust::inclusive_scan(cvt_to_seq(derived_cast(policy)), first, first + num_items, result, init, scan_op);));
return result;
}

Expand Down

0 comments on commit 498251c

Please sign in to comment.