Skip to content

Commit

Permalink
Add thrust::transform_inclusive_scan with init value (#2326)
Browse files Browse the repository at this point in the history
* Add thrust::transform_inclusive_scan with init value implementations

* Add tests for thrust::transform_inclusive_scan with init

* Add more tests and rebase on bug fix from thrust::inclusive_scan

* Add docs

* Use __accumulator_t

* Fix thrust tests readability with initializer_list and docs indentation

* Fix docs bugs and use correct accumulator and intermediate result types
  • Loading branch information
gonidelis committed Sep 6, 2024
1 parent 4a32b1c commit 5647255
Show file tree
Hide file tree
Showing 7 changed files with 407 additions and 146 deletions.
162 changes: 85 additions & 77 deletions thrust/testing/cuda/transform_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@ __global__ void transform_inclusive_scan_kernel(
*result2 = thrust::transform_inclusive_scan(exec, first, last, result1, f1, f2);
}

// Test kernel wrapping the overload of thrust::transform_inclusive_scan that
// takes an initial value.
//
// Every argument except `result2` is forwarded unchanged, in the order
// (exec, first, last, result1, f1, init, f2). The value returned by the scan
// call is then written through `result2`, so the host side can read it back
// after the kernel has finished.
//
// There is no thread-index guard: each launched thread performs the scan and
// the store. The call sites visible in this file launch it as <<<1, 1>>>.
template <typename ExecutionPolicy,
          typename Iterator1,
          typename Iterator2,
          typename Function1,
          typename T,
          typename Function2,
          typename Iterator3>
__global__ void transform_inclusive_scan_init_kernel(
  ExecutionPolicy exec,
  Iterator1 first,
  Iterator1 last,
  Iterator2 result1,
  Function1 f1,
  T init,
  Function2 f2,
  Iterator3 result2)
{
  // Hold the scan's return value in a local so that publishing it is a
  // separate, clearly visible step.
  auto const scan_end = thrust::transform_inclusive_scan(exec, first, last, result1, f1, init, f2);

  // Hand the returned iterator back to the caller through result2.
  *result2 = scan_end;
}

template <typename ExecutionPolicy,
typename Iterator1,
typename Iterator2,
Expand Down Expand Up @@ -50,16 +70,10 @@ void TestTransformScanDevice(ExecutionPolicy exec)

typename Vector::iterator iter;

Vector input(5);
Vector ref(5);
Vector input{1, 3, -2, 4, -5};
Vector ref{-1, -4, -2, -6, -1};
Vector output(5);

input[0] = 1;
input[1] = 3;
input[2] = -2;
input[3] = 4;
input[4] = -5;

Vector input_copy(input);

thrust::device_vector<typename Vector::iterator> iter_vec(1);
Expand All @@ -72,12 +86,21 @@ void TestTransformScanDevice(ExecutionPolicy exec)
ASSERT_EQUAL(cudaSuccess, err);
}

iter = iter_vec[0];
ref[0] = -1;
ref[1] = -4;
ref[2] = -2;
ref[3] = -6;
ref[4] = -1;
iter = iter_vec[0];
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(ref, output);

// inclusive scan with nonzero init
transform_inclusive_scan_init_kernel<<<1, 1>>>(
exec, input.begin(), input.end(), output.begin(), thrust::negate<T>(), 3, thrust::plus<T>(), iter_vec.begin());
{
cudaError_t const err = cudaDeviceSynchronize();
ASSERT_EQUAL(cudaSuccess, err);
}

iter = iter_vec[0];
ref = {2, -1, 1, -3, 2};
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(ref, output);
Expand All @@ -90,11 +113,7 @@ void TestTransformScanDevice(ExecutionPolicy exec)
ASSERT_EQUAL(cudaSuccess, err);
}

ref[0] = 0;
ref[1] = -1;
ref[2] = -4;
ref[3] = -2;
ref[4] = -6;
ref = {0, -1, -4, -2, -6};
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(ref, output);
Expand All @@ -107,12 +126,8 @@ void TestTransformScanDevice(ExecutionPolicy exec)
ASSERT_EQUAL(cudaSuccess, err);
}

iter = iter_vec[0];
ref[0] = 3;
ref[1] = 2;
ref[2] = -1;
ref[3] = 1;
ref[4] = -3;
iter = iter_vec[0];
ref = {3, 2, -1, 1, -3};
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(ref, output);
Expand All @@ -126,12 +141,22 @@ void TestTransformScanDevice(ExecutionPolicy exec)
ASSERT_EQUAL(cudaSuccess, err);
}

iter = iter_vec[0];
ref[0] = -1;
ref[1] = -4;
ref[2] = -2;
ref[3] = -6;
ref[4] = -1;
iter = iter_vec[0];
ref = {-1, -4, -2, -6, -1};
ASSERT_EQUAL(std::size_t(iter - input.begin()), input.size());
ASSERT_EQUAL(ref, input);

// inplace inclusive scan with init
input = input_copy;
transform_inclusive_scan_init_kernel<<<1, 1>>>(
exec, input.begin(), input.end(), input.begin(), thrust::negate<T>(), 3, thrust::plus<T>(), iter_vec.begin());
{
cudaError_t const err = cudaDeviceSynchronize();
ASSERT_EQUAL(cudaSuccess, err);
}

iter = iter_vec[0];
ref = {2, -1, 1, -3, 2};
ASSERT_EQUAL(std::size_t(iter - input.begin()), input.size());
ASSERT_EQUAL(ref, input);

Expand All @@ -144,12 +169,8 @@ void TestTransformScanDevice(ExecutionPolicy exec)
ASSERT_EQUAL(cudaSuccess, err);
}

iter = iter_vec[0];
ref[0] = 3;
ref[1] = 2;
ref[2] = -1;
ref[3] = 1;
ref[4] = -3;
iter = iter_vec[0];
ref = {3, 2, -1, 1, -3};
ASSERT_EQUAL(std::size_t(iter - input.begin()), input.size());
ASSERT_EQUAL(ref, input);
}
Expand All @@ -174,16 +195,10 @@ void TestTransformScanCudaStreams()

Vector::iterator iter;

Vector input(5);
Vector result(5);
Vector input{1, 3, -2, 4, -5};
Vector result{-1, -4, -2, -6, -1};
Vector output(5);

input[0] = 1;
input[1] = 3;
input[2] = -2;
input[3] = 4;
input[4] = -5;

Vector input_copy(input);

cudaStream_t s;
Expand All @@ -194,11 +209,16 @@ void TestTransformScanCudaStreams()
thrust::cuda::par.on(s), input.begin(), input.end(), output.begin(), thrust::negate<T>(), thrust::plus<T>());
cudaStreamSynchronize(s);

result[0] = -1;
result[1] = -4;
result[2] = -2;
result[3] = -6;
result[4] = -1;
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(output, result);

// inclusive scan with nonzero init
iter = thrust::transform_inclusive_scan(
thrust::cuda::par.on(s), input.begin(), input.end(), output.begin(), thrust::negate<T>(), 3, thrust::plus<T>());
cudaStreamSynchronize(s);

result = {2, -1, 1, -3, 2};
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(output, result);
Expand All @@ -208,11 +228,7 @@ void TestTransformScanCudaStreams()
thrust::cuda::par.on(s), input.begin(), input.end(), output.begin(), thrust::negate<T>(), 0, thrust::plus<T>());
cudaStreamSynchronize(s);

result[0] = 0;
result[1] = -1;
result[2] = -4;
result[3] = -2;
result[4] = -6;
result = {0, -1, -4, -2, -6};
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(output, result);
Expand All @@ -222,11 +238,7 @@ void TestTransformScanCudaStreams()
thrust::cuda::par.on(s), input.begin(), input.end(), output.begin(), thrust::negate<T>(), 3, thrust::plus<T>());
cudaStreamSynchronize(s);

result[0] = 3;
result[1] = 2;
result[2] = -1;
result[3] = 1;
result[4] = -3;
result = {3, 2, -1, 1, -3};
ASSERT_EQUAL(std::size_t(iter - output.begin()), input.size());
ASSERT_EQUAL(input, input_copy);
ASSERT_EQUAL(output, result);
Expand All @@ -237,11 +249,17 @@ void TestTransformScanCudaStreams()
thrust::cuda::par.on(s), input.begin(), input.end(), input.begin(), thrust::negate<T>(), thrust::plus<T>());
cudaStreamSynchronize(s);

result[0] = -1;
result[1] = -4;
result[2] = -2;
result[3] = -6;
result[4] = -1;
result = {-1, -4, -2, -6, -1};
ASSERT_EQUAL(std::size_t(iter - input.begin()), input.size());
ASSERT_EQUAL(input, result);

// inplace inclusive scan with init
input = input_copy;
iter = thrust::transform_inclusive_scan(
thrust::cuda::par.on(s), input.begin(), input.end(), input.begin(), thrust::negate<T>(), 3, thrust::plus<T>());
cudaStreamSynchronize(s);

result = {2, -1, 1, -3, 2};
ASSERT_EQUAL(std::size_t(iter - input.begin()), input.size());
ASSERT_EQUAL(input, result);

Expand All @@ -251,11 +269,7 @@ void TestTransformScanCudaStreams()
thrust::cuda::par.on(s), input.begin(), input.end(), input.begin(), thrust::negate<T>(), 3, thrust::plus<T>());
cudaStreamSynchronize(s);

result[0] = 3;
result[1] = 2;
result[2] = -1;
result[3] = 1;
result[4] = -3;
result = {3, 2, -1, 1, -3};
ASSERT_EQUAL(std::size_t(iter - input.begin()), input.size());
ASSERT_EQUAL(input, result);

Expand All @@ -270,16 +284,10 @@ void TestTransformScanConstAccumulator()

Vector::iterator iter;

Vector input(5);
Vector input{1, 3, -2, 4, -5};
Vector reference(5);
Vector output(5);

input[0] = 1;
input[1] = 3;
input[2] = -2;
input[3] = 4;
input[4] = -5;

thrust::transform_inclusive_scan(input.begin(), input.end(), output.begin(), thrust::identity<T>(), thrust::plus<T>());
thrust::inclusive_scan(input.begin(), input.end(), reference.begin(), thrust::plus<T>());

Expand Down
Loading

0 comments on commit 5647255

Please sign in to comment.