Skip to content

Commit

Permalink
Rework mdspan concept emulation (#2213)
Browse files Browse the repository at this point in the history
It is proving difficult to handle for msvc and also the one we are using in libcu++ it much cleaner

Gets #2160 compiling on MSVC
  • Loading branch information
miscco committed Aug 9, 2024
1 parent f95f211 commit 8e20c9a
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 437 deletions.
7 changes: 3 additions & 4 deletions libcudacxx/include/cuda/std/__mdspan/default_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ struct default_accessor

__MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr default_accessor() noexcept = default;

__MDSPAN_TEMPLATE_REQUIRES(class _OtherElementType,
/* requires */ (_CCCL_TRAIT(is_convertible, _OtherElementType (*)[], element_type (*)[])))
__MDSPAN_INLINE_FUNCTION
constexpr default_accessor(default_accessor<_OtherElementType>) noexcept {}
_LIBCUDACXX_TEMPLATE(class _OtherElementType)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(is_convertible, _OtherElementType (*)[], element_type (*)[]))
__MDSPAN_INLINE_FUNCTION constexpr default_accessor(default_accessor<_OtherElementType>) noexcept {}

__MDSPAN_INLINE_FUNCTION
constexpr data_handle_type offset(data_handle_type __p, size_t __i) const noexcept
Expand Down
95 changes: 41 additions & 54 deletions libcudacxx/include/cuda/std/__mdspan/extents.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,13 @@ class extents
__MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr extents() noexcept = default;

// Converting constructor
__MDSPAN_TEMPLATE_REQUIRES(
class _OtherIndexType,
size_t... _OtherExtents,
/* requires */
(
/* multi-stage check to protect from invalid pack expansion when sizes don't match? */
decltype(__detail::__check_compatible_extents(
integral_constant<bool, sizeof...(_Extents) == sizeof...(_OtherExtents)>{},
_CUDA_VSTD::integer_sequence<size_t, _Extents...>{},
_CUDA_VSTD::integer_sequence<size_t, _OtherExtents...>{}))::value))
_LIBCUDACXX_TEMPLATE(class _OtherIndexType, size_t... _OtherExtents)
_LIBCUDACXX_REQUIRES(
/* multi-stage check to protect from invalid pack expansion when sizes don't match? */
(decltype(__detail::__check_compatible_extents(
integral_constant<bool, sizeof...(_Extents) == sizeof...(_OtherExtents)>{},
_CUDA_VSTD::integer_sequence<size_t, _Extents...>{},
_CUDA_VSTD::integer_sequence<size_t, _OtherExtents...>{}))::value))
__MDSPAN_INLINE_FUNCTION
__MDSPAN_CONDITIONAL_EXPLICIT(
(((_Extents != dynamic_extent) && (_OtherExtents == dynamic_extent)) || ...)
Expand Down Expand Up @@ -287,23 +284,23 @@ class extents
}

# ifdef __NVCC__
__MDSPAN_TEMPLATE_REQUIRES(
class... _Integral,
/* requires */ (
// TODO: check whether the other version works with newest NVCC, doesn't with 11.4
// NVCC seems to pick up rank_dynamic from the wrong extents type???
__MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Integral, index_type) /* && ... */)
&& __MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Integral) /* && ... */) &&
// NVCC chokes on the fold thingy here so wrote the workaround
((sizeof...(_Integral) == __detail::__count_dynamic_extents<_Extents...>::val)
|| (sizeof...(_Integral) == sizeof...(_Extents)))))
_LIBCUDACXX_TEMPLATE(class... _Integral)
_LIBCUDACXX_REQUIRES(
// TODO: check whether the other version works with newest NVCC, doesn't with 11.4
// NVCC seems to pick up rank_dynamic from the wrong extents type???
__MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Integral, index_type) /* && ... */)
_LIBCUDACXX_AND __MDSPAN_FOLD_AND(
_CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Integral) /* && ... */) _LIBCUDACXX_AND
// NVCC chokes on the fold thingy here so wrote the workaround
((sizeof...(_Integral) == __detail::__count_dynamic_extents<_Extents...>::val)
|| (sizeof...(_Integral) == sizeof...(_Extents))))
# else
__MDSPAN_TEMPLATE_REQUIRES(
class... _Integral,
/* requires */ (
__MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Integral, index_type) /* && ... */)
&& __MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Integral) /* && ... */)
&& ((sizeof...(_Integral) == rank_dynamic()) || (sizeof...(_Integral) == rank()))))
_LIBCUDACXX_TEMPLATE(class... _Integral)
_LIBCUDACXX_REQUIRES(
__MDSPAN_FOLD_AND(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Integral, index_type) /* && ... */)
_LIBCUDACXX_AND __MDSPAN_FOLD_AND(
_CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Integral) /* && ... */)
_LIBCUDACXX_AND((sizeof...(_Integral) == rank_dynamic()) || (sizeof...(_Integral) == rank())))
# endif
__MDSPAN_INLINE_FUNCTION
explicit constexpr extents(_Integral... __exts) noexcept
Expand Down Expand Up @@ -337,21 +334,16 @@ class extents
# ifdef __NVCC__
// NVCC seems to pick up rank_dynamic from the wrong extents type???
// NVCC chokes on the fold thingy here so wrote the workaround
__MDSPAN_TEMPLATE_REQUIRES(
class _IndexType,
size_t _Np,
/* requires */
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
&& ((_Np == __detail::__count_dynamic_extents<_Extents...>::val) || (_Np == sizeof...(_Extents)))))
_LIBCUDACXX_TEMPLATE(class _IndexType, size_t _Np)
_LIBCUDACXX_REQUIRES(
_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
_LIBCUDACXX_AND _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
_LIBCUDACXX_AND((_Np == __detail::__count_dynamic_extents<_Extents...>::val) || (_Np == sizeof...(_Extents))))
# else
__MDSPAN_TEMPLATE_REQUIRES(
class _IndexType,
size_t _Np,
/* requires */
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
&& (_Np == rank() || _Np == rank_dynamic())))
_LIBCUDACXX_TEMPLATE(class _IndexType, size_t _Np)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
_LIBCUDACXX_AND _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
_LIBCUDACXX_AND(_Np == rank() || _Np == rank_dynamic()))
# endif
__MDSPAN_CONDITIONAL_EXPLICIT(_Np != rank_dynamic())
__MDSPAN_INLINE_FUNCTION
Expand Down Expand Up @@ -386,21 +378,16 @@ class extents
# ifdef __NVCC__
// NVCC seems to pick up rank_dynamic from the wrong extents type???
// NVCC chokes on the fold thingy here so wrote the workaround
__MDSPAN_TEMPLATE_REQUIRES(
class _IndexType,
size_t _Np,
/* requires */
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
&& ((_Np == __detail::__count_dynamic_extents<_Extents...>::val) || (_Np == sizeof...(_Extents)))))
_LIBCUDACXX_TEMPLATE(class _IndexType, size_t _Np)
_LIBCUDACXX_REQUIRES(
_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
_LIBCUDACXX_AND _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
_LIBCUDACXX_AND((_Np == __detail::__count_dynamic_extents<_Extents...>::val) || (_Np == sizeof...(_Extents))))
# else
__MDSPAN_TEMPLATE_REQUIRES(
class _IndexType,
size_t _Np,
/* requires */
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
&& (_Np == rank() || _Np == rank_dynamic())))
_LIBCUDACXX_TEMPLATE(class _IndexType, size_t _Np)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _IndexType, index_type)
_LIBCUDACXX_AND _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _IndexType)
_LIBCUDACXX_AND(_Np == rank() || _Np == rank_dynamic()))
# endif
__MDSPAN_CONDITIONAL_EXPLICIT(_Np != rank_dynamic())
__MDSPAN_INLINE_FUNCTION
Expand Down
27 changes: 13 additions & 14 deletions libcudacxx/include/cuda/std/__mdspan/layout_left.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ class layout_left::mapping
: __extents(__exts)
{}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents))
__MDSPAN_CONDITIONAL_EXPLICIT((!_CUDA_VSTD::is_convertible<_OtherExtents, extents_type>::value)) // needs two () due
// to comma
__MDSPAN_INLINE_FUNCTION constexpr mapping(
Expand All @@ -135,9 +135,9 @@ class layout_left::mapping
*/
}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)
&& (extents_type::rank() <= 1)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)
_LIBCUDACXX_AND(extents_type::rank() <= 1))
__MDSPAN_CONDITIONAL_EXPLICIT((!_CUDA_VSTD::is_convertible<_OtherExtents, extents_type>::value)) // needs two () due
// to comma
__MDSPAN_INLINE_FUNCTION constexpr mapping(
Expand All @@ -150,8 +150,8 @@ class layout_left::mapping
*/
}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents))
__MDSPAN_CONDITIONAL_EXPLICIT((extents_type::rank() > 0))
__MDSPAN_INLINE_FUNCTION constexpr mapping(
layout_stride::mapping<_OtherExtents> const& __other) // NOLINT(google-explicit-constructor)
Expand Down Expand Up @@ -190,11 +190,10 @@ class layout_left::mapping

//--------------------------------------------------------------------------------

__MDSPAN_TEMPLATE_REQUIRES(
class... _Indices,
/* requires */ ((sizeof...(_Indices) == extents_type::rank())
&& __MDSPAN_FOLD_AND((_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Indices, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Indices)))))
_LIBCUDACXX_TEMPLATE(class... _Indices)
_LIBCUDACXX_REQUIRES((sizeof...(_Indices) == extents_type::rank()) _LIBCUDACXX_AND __MDSPAN_FOLD_AND(
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Indices, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Indices))))
_CCCL_HOST_DEVICE constexpr index_type operator()(_Indices... __idxs) const noexcept
{
// Immediately cast incoming indices to `index_type`
Expand Down Expand Up @@ -227,8 +226,8 @@ class layout_left::mapping
return true;
}

__MDSPAN_TEMPLATE_REQUIRES(class _Ext = _Extents,
/* requires */ (_Ext::rank() > 0))
_LIBCUDACXX_TEMPLATE(class _Ext = _Extents)
_LIBCUDACXX_REQUIRES((_Ext::rank() > 0))
__MDSPAN_INLINE_FUNCTION
constexpr index_type stride(rank_type __i) const noexcept
{
Expand Down
27 changes: 13 additions & 14 deletions libcudacxx/include/cuda/std/__mdspan/layout_right.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ class layout_right::mapping
: __extents(__exts)
{}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents))
__MDSPAN_CONDITIONAL_EXPLICIT((!_CUDA_VSTD::is_convertible<_OtherExtents, extents_type>::value)) // needs two () due
// to comma
__MDSPAN_INLINE_FUNCTION constexpr mapping(
Expand All @@ -140,9 +140,9 @@ class layout_right::mapping
*/
}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)
&& (extents_type::rank() <= 1)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)
_LIBCUDACXX_AND(extents_type::rank() <= 1))
__MDSPAN_CONDITIONAL_EXPLICIT((!_CUDA_VSTD::is_convertible<_OtherExtents, extents_type>::value)) // needs two () due
// to comma
__MDSPAN_INLINE_FUNCTION constexpr mapping(
Expand All @@ -155,8 +155,8 @@ class layout_right::mapping
*/
}

__MDSPAN_TEMPLATE_REQUIRES(class _OtherExtents,
/* requires */ (_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents)))
_LIBCUDACXX_TEMPLATE(class _OtherExtents)
_LIBCUDACXX_REQUIRES(_CCCL_TRAIT(_CUDA_VSTD::is_constructible, extents_type, _OtherExtents))
__MDSPAN_CONDITIONAL_EXPLICIT((extents_type::rank() > 0))
__MDSPAN_INLINE_FUNCTION constexpr mapping(
layout_stride::mapping<_OtherExtents> const& __other) // NOLINT(google-explicit-constructor)
Expand Down Expand Up @@ -195,11 +195,10 @@ class layout_right::mapping

//--------------------------------------------------------------------------------

__MDSPAN_TEMPLATE_REQUIRES(
class... _Indices,
/* requires */ ((sizeof...(_Indices) == extents_type::rank())
&& __MDSPAN_FOLD_AND((_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Indices, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Indices)))))
_LIBCUDACXX_TEMPLATE(class... _Indices)
_LIBCUDACXX_REQUIRES((sizeof...(_Indices) == extents_type::rank()) _LIBCUDACXX_AND __MDSPAN_FOLD_AND(
(_CCCL_TRAIT(_CUDA_VSTD::is_convertible, _Indices, index_type)
&& _CCCL_TRAIT(_CUDA_VSTD::is_nothrow_constructible, index_type, _Indices))))
_CCCL_HOST_DEVICE constexpr index_type operator()(_Indices... __idxs) const noexcept
{
return __compute_offset(__rank_count<0, extents_type::rank()>(), static_cast<index_type>(__idxs)...);
Expand Down Expand Up @@ -230,8 +229,8 @@ class layout_right::mapping
return true;
}

__MDSPAN_TEMPLATE_REQUIRES(class _Ext = _Extents,
/* requires */ (_Ext::rank() > 0))
_LIBCUDACXX_TEMPLATE(class _Ext = _Extents)
_LIBCUDACXX_REQUIRES((_Ext::rank() > 0))
__MDSPAN_INLINE_FUNCTION
constexpr index_type stride(rank_type __i) const noexcept
{
Expand Down
Loading

0 comments on commit 8e20c9a

Please sign in to comment.