From 0251ae4c5d73e28e7cbc77850195ea41036631ed Mon Sep 17 00:00:00 2001 From: pciolkosz Date: Tue, 3 Sep 2024 23:51:16 -0700 Subject: [PATCH] [CUDAX] Branch out an experimental version of stream_ref (#2343) * Branch out experimental version of stream_ref * Add tests for the experimental part of stream_ref * Move inequality check * typo * Remove not needed using declaration * Add a TODO to remove NULL stream_ref * Remove TODO and remove NULL stream ref constructor * move runtime api include after the system header decl Co-authored-by: Michael Schellenberger Costa --- .../cuda/experimental/__stream/stream.cuh | 68 +---------- .../cuda/experimental/__stream/stream_ref.cuh | 107 ++++++++++++++++++ cudax/test/stream/stream_smoke.cu | 14 ++- 3 files changed, 121 insertions(+), 68 deletions(-) create mode 100644 cudax/include/cuda/experimental/__stream/stream_ref.cuh diff --git a/cudax/include/cuda/experimental/__stream/stream.cuh b/cudax/include/cuda/experimental/__stream/stream.cuh index 27f0f698db..54efde0116 100644 --- a/cudax/include/cuda/experimental/__stream/stream.cuh +++ b/cudax/include/cuda/experimental/__stream/stream.cuh @@ -23,10 +23,9 @@ #endif // no system header #include -#include #include -#include +#include #include namespace cuda::experimental @@ -51,6 +50,7 @@ struct stream : stream_ref //! //! @throws cuda_error if stream creation fails explicit stream(device_ref __dev, int __priority = default_priority) + : stream_ref(detail::invalid_stream) { [[maybe_unused]] __ensure_current_device __dev_setter(__dev); _CCCL_TRY_CUDA_API( @@ -110,70 +110,6 @@ struct stream : stream_ref stream& operator=(const stream&) = delete; - // Ideally records and waits below would be in stream_ref, but we can't have it depend on cudax yet - - //! @brief Create a new event and record it into this stream - //! - //! @return A new event that was recorded into this stream - //! - //! @throws cuda_error if event creation or record failed - _CCCL_NODISCARD event record_event(event::flags __flags = event::flags::none) const - { - return event(*this, __flags); - } - - //! @brief Create a new timed event and record it into this stream - //! - //! @return A new timed event that was recorded into this stream - //! - //! @throws cuda_error if event creation or record failed - _CCCL_NODISCARD timed_event record_timed_event(event::flags __flags = event::flags::none) const - { - return timed_event(*this, __flags); - } - - using stream_ref::wait; - - //! @brief Make all future work submitted into this stream depend on completion of the specified event - //! - //! @param __ev Event that this stream should wait for - //! - //! @throws cuda_error if inserting the dependency fails - void wait(event_ref __ev) const - { - assert(__ev.get() != nullptr); - // Need to use driver API, cudaStreamWaitEvent would push dev 0 if stack was empty - detail::driver::streamWaitEvent(get(), __ev.get()); - } - - //! @brief Make all future work submitted into this stream depend on completion of all work from the specified - //! stream - //! - //! @param __other Stream that this stream should wait for - //! - //! @throws cuda_error if inserting the dependency fails - void wait(stream_ref __other) const - { - // TODO consider an optimization to not create an event every time and instead have one persistent event or one - // per stream - assert(__stream != detail::invalid_stream); - event __tmp(__other); - wait(__tmp); - } - - //! @brief Get device under which this stream was created. - //! - //! @throws cuda_error if device check fails - device_ref device() const - { - // Because the stream can come from_native_handle, we can't just loop over devices comparing contexts, - // lower to CUDART for this instead - __ensure_current_device __dev_setter(*this); - int result; - _CCCL_TRY_CUDA_API(cudaGetDevice, "Could not get device from a stream", &result); - return result; - } - //! @brief Construct an `stream` object from a native `cudaStream_t` handle. //! //! @param __handle The native handle diff --git a/cudax/include/cuda/experimental/__stream/stream_ref.cuh b/cudax/include/cuda/experimental/__stream/stream_ref.cuh new file mode 100644 index 0000000000..14b438356f --- /dev/null +++ b/cudax/include/cuda/experimental/__stream/stream_ref.cuh @@ -0,0 +1,107 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDA Experimental in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#ifndef _CUDAX__STREAM_STREAM_REF +#define _CUDAX__STREAM_STREAM_REF + +#include +#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) +# pragma GCC system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG) +# pragma clang system_header +#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC) +# pragma system_header +#endif // no system header + +#include + +#include +#include + +#include +#include +#include + +namespace cuda::experimental +{ + +//! @brief A non-owning wrapper for cudaStream_t. +struct stream_ref : ::cuda::stream_ref +{ + using ::cuda::stream_ref::stream_ref; + + stream_ref() = delete; + + //! @brief Create a new event and record it into this stream + //! + //! @return A new event that was recorded into this stream + //! + //! @throws cuda_error if event creation or record failed + _CCCL_NODISCARD event record_event(event::flags __flags = event::flags::none) const + { + return event(*this, __flags); + } + + //! @brief Create a new timed event and record it into this stream + //! + //! @return A new timed event that was recorded into this stream + //! + //! @throws cuda_error if event creation or record failed + _CCCL_NODISCARD timed_event record_timed_event(event::flags __flags = event::flags::none) const + { + return timed_event(*this, __flags); + } + + using ::cuda::stream_ref::wait; + + //! @brief Make all future work submitted into this stream depend on completion of the specified event + //! + //! @param __ev Event that this stream should wait for + //! + //! @throws cuda_error if inserting the dependency fails + void wait(event_ref __ev) const + { + assert(__ev.get() != nullptr); + // Need to use driver API, cudaStreamWaitEvent would push dev 0 if stack was empty + detail::driver::streamWaitEvent(get(), __ev.get()); + } + + //! @brief Make all future work submitted into this stream depend on completion of all work from the specified + //! stream + //! + //! @param __other Stream that this stream should wait for + //! + //! @throws cuda_error if inserting the dependency fails + void wait(stream_ref __other) const + { + // TODO consider an optimization to not create an event every time and instead have one persistent event or one + // per stream + assert(__stream != detail::invalid_stream); + event __tmp(__other); + wait(__tmp); + } + + //! @brief Get device under which this stream was created. + //! + //! @throws cuda_error if device check fails + device_ref device() const + { + // Because the stream can come from_native_handle, we can't just loop over devices comparing contexts, + // lower to CUDART for this instead + __ensure_current_device __dev_setter(*this); + int result; + _CCCL_TRY_CUDA_API(cudaGetDevice, "Could not get device from a stream", &result); + return result; + } +}; + +} // namespace cuda::experimental + +#endif // _CUDAX__STREAM_STREAM_REF diff --git a/cudax/test/stream/stream_smoke.cu b/cudax/test/stream/stream_smoke.cu index c1adee89b6..58bf31b6d3 100644 --- a/cudax/test/stream/stream_smoke.cu +++ b/cudax/test/stream/stream_smoke.cu @@ -39,9 +39,9 @@ TEST_CASE("From native handle", "[stream]") CUDART(cudaStreamDestroy(handle)); } -TEST_CASE("Can add dependency into a stream", "[stream]") +template +void add_dependency_test(const StreamType& waiter, const StreamType& waitee) { - cudax::stream waiter, waitee; CUDAX_REQUIRE(waiter != waitee); auto verify_dependency = [&](const auto& insert_dependency) { @@ -91,6 +91,14 @@ TEST_CASE("Can add dependency into a stream", "[stream]") } } +TEST_CASE("Can add dependency into a stream", "[stream]") +{ + cudax::stream waiter, waitee; + + add_dependency_test(waiter, waitee); + add_dependency_test(waiter, waitee); +} + TEST_CASE("Stream priority", "[stream]") { cudax::stream stream_default_prio; @@ -111,4 +119,6 @@ TEST_CASE("Stream get device", "[stream]") CUDART(cudaStreamCreate(&stream_handle)); auto stream_cudart = cudax::stream::from_native_handle(stream_handle); CUDAX_REQUIRE(stream_cudart.device() == *std::prev(cudax::devices.end())); + auto stream_ref_cudart = cudax::stream_ref(stream_handle); + CUDAX_REQUIRE(stream_ref_cudart.device() == *std::prev(cudax::devices.end())); }