Skip to content

Commit

Permalink
[CUDAX] Branch out an experimental version of stream_ref (#2343)
Browse files Browse the repository at this point in the history
* Branch out experimental version of stream_ref

* Add tests for the experimental part of stream_ref

* Move inequality check

* typo

* Remove not needed using declaration

* Add a TODO to remove NULL stream_ref

* Remove TODO and remove NULL stream ref constructor

* move runtime api include after the system header decl

Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
pciolkosz and miscco committed Sep 4, 2024
1 parent 1e9125e commit 0251ae4
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 68 deletions.
68 changes: 2 additions & 66 deletions cudax/include/cuda/experimental/__stream/stream.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@
#endif // no system header

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/stream_ref>

#include <cuda/experimental/__device/device_ref.cuh>
#include <cuda/experimental/__event/timed_event.cuh>
#include <cuda/experimental/__stream/stream_ref.cuh>
#include <cuda/experimental/__utility/ensure_current_device.cuh>

namespace cuda::experimental
Expand All @@ -51,6 +50,7 @@ struct stream : stream_ref
//!
//! @throws cuda_error if stream creation fails
explicit stream(device_ref __dev, int __priority = default_priority)
: stream_ref(detail::invalid_stream)
{
[[maybe_unused]] __ensure_current_device __dev_setter(__dev);
_CCCL_TRY_CUDA_API(
Expand Down Expand Up @@ -110,70 +110,6 @@ struct stream : stream_ref

stream& operator=(const stream&) = delete;

// Ideally records and waits below would be in stream_ref, but we can't have it depend on cudax yet

//! @brief Create a new event and record it into this stream
//!
//! @return A new event that was recorded into this stream
//!
//! @throws cuda_error if event creation or record failed
_CCCL_NODISCARD event record_event(event::flags __flags = event::flags::none) const
{
return event(*this, __flags);
}

//! @brief Create a new timed event and record it into this stream
//!
//! @return A new timed event that was recorded into this stream
//!
//! @throws cuda_error if event creation or record failed
_CCCL_NODISCARD timed_event record_timed_event(event::flags __flags = event::flags::none) const
{
return timed_event(*this, __flags);
}

using stream_ref::wait;

//! @brief Make all future work submitted into this stream depend on completion of the specified event
//!
//! @param __ev Event that this stream should wait for
//!
//! @throws cuda_error if inserting the dependency fails
void wait(event_ref __ev) const
{
assert(__ev.get() != nullptr);
// Need to use driver API, cudaStreamWaitEvent would push dev 0 if stack was empty
detail::driver::streamWaitEvent(get(), __ev.get());
}

//! @brief Make all future work submitted into this stream depend on completion of all work from the specified
//! stream
//!
//! @param __other Stream that this stream should wait for
//!
//! @throws cuda_error if inserting the dependency fails
void wait(stream_ref __other) const
{
// TODO consider an optimization to not create an event every time and instead have one persistent event or one
// per stream
assert(__stream != detail::invalid_stream);
event __tmp(__other);
wait(__tmp);
}

//! @brief Get device under which this stream was created.
//!
//! @throws cuda_error if device check fails
device_ref device() const
{
// Because the stream can come from_native_handle, we can't just loop over devices comparing contexts,
// lower to CUDART for this instead
__ensure_current_device __dev_setter(*this);
int result;
_CCCL_TRY_CUDA_API(cudaGetDevice, "Could not get device from a stream", &result);
return result;
}

//! @brief Construct an `stream` object from a native `cudaStream_t` handle.
//!
//! @param __handle The native handle
Expand Down
107 changes: 107 additions & 0 deletions cudax/include/cuda/experimental/__stream/stream_ref.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__STREAM_STREAM_REF
#define _CUDAX__STREAM_STREAM_REF

#include <cuda/std/detail/__config>
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda_runtime_api.h>

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/stream_ref>

#include <cuda/experimental/__device/device_ref.cuh>
#include <cuda/experimental/__event/timed_event.cuh>
#include <cuda/experimental/__utility/ensure_current_device.cuh>

namespace cuda::experimental
{

//! @brief A non-owning wrapper for cudaStream_t.
struct stream_ref : ::cuda::stream_ref
{
using ::cuda::stream_ref::stream_ref;

stream_ref() = delete;

//! @brief Create a new event and record it into this stream
//!
//! @return A new event that was recorded into this stream
//!
//! @throws cuda_error if event creation or record failed
_CCCL_NODISCARD event record_event(event::flags __flags = event::flags::none) const
{
return event(*this, __flags);
}

//! @brief Create a new timed event and record it into this stream
//!
//! @return A new timed event that was recorded into this stream
//!
//! @throws cuda_error if event creation or record failed
_CCCL_NODISCARD timed_event record_timed_event(event::flags __flags = event::flags::none) const
{
return timed_event(*this, __flags);
}

using ::cuda::stream_ref::wait;

//! @brief Make all future work submitted into this stream depend on completion of the specified event
//!
//! @param __ev Event that this stream should wait for
//!
//! @throws cuda_error if inserting the dependency fails
void wait(event_ref __ev) const
{
assert(__ev.get() != nullptr);
// Need to use driver API, cudaStreamWaitEvent would push dev 0 if stack was empty
detail::driver::streamWaitEvent(get(), __ev.get());
}

//! @brief Make all future work submitted into this stream depend on completion of all work from the specified
//! stream
//!
//! @param __other Stream that this stream should wait for
//!
//! @throws cuda_error if inserting the dependency fails
void wait(stream_ref __other) const
{
// TODO consider an optimization to not create an event every time and instead have one persistent event or one
// per stream
assert(__stream != detail::invalid_stream);
event __tmp(__other);
wait(__tmp);
}

//! @brief Get device under which this stream was created.
//!
//! @throws cuda_error if device check fails
device_ref device() const
{
// Because the stream can come from_native_handle, we can't just loop over devices comparing contexts,
// lower to CUDART for this instead
__ensure_current_device __dev_setter(*this);
int result;
_CCCL_TRY_CUDA_API(cudaGetDevice, "Could not get device from a stream", &result);
return result;
}
};

} // namespace cuda::experimental

#endif // _CUDAX__STREAM_STREAM_REF
14 changes: 12 additions & 2 deletions cudax/test/stream/stream_smoke.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ TEST_CASE("From native handle", "[stream]")
CUDART(cudaStreamDestroy(handle));
}

TEST_CASE("Can add dependency into a stream", "[stream]")
template <typename StreamType>
void add_dependency_test(const StreamType& waiter, const StreamType& waitee)
{
cudax::stream waiter, waitee;
CUDAX_REQUIRE(waiter != waitee);

auto verify_dependency = [&](const auto& insert_dependency) {
Expand Down Expand Up @@ -91,6 +91,14 @@ TEST_CASE("Can add dependency into a stream", "[stream]")
}
}

TEST_CASE("Can add dependency into a stream", "[stream]")
{
cudax::stream waiter, waitee;

add_dependency_test<cudax::stream>(waiter, waitee);
add_dependency_test<cudax::stream_ref>(waiter, waitee);
}

TEST_CASE("Stream priority", "[stream]")
{
cudax::stream stream_default_prio;
Expand All @@ -111,4 +119,6 @@ TEST_CASE("Stream get device", "[stream]")
CUDART(cudaStreamCreate(&stream_handle));
auto stream_cudart = cudax::stream::from_native_handle(stream_handle);
CUDAX_REQUIRE(stream_cudart.device() == *std::prev(cudax::devices.end()));
auto stream_ref_cudart = cudax::stream_ref(stream_handle);
CUDAX_REQUIRE(stream_ref_cudart.device() == *std::prev(cudax::devices.end()));
}

0 comments on commit 0251ae4

Please sign in to comment.