Make bit_cast play nice with extended floating point types (#2434)
* Move `__is_nvbf16` and `__is_nvfp16` to their own file

* Make `bit_cast` play nice with extended floating point types
miscco committed Sep 20, 2024
1 parent 28888eb commit 31c3eb9
Showing 7 changed files with 127 additions and 23 deletions.
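In practice, the relaxed constraints let `cuda::std::bit_cast` be called directly on `__half` and `__nv_bfloat16` values. A minimal sketch of the intended call pattern, assuming a CUDA toolchain where `cuda_fp16.h` is available and `_LIBCUDACXX_HAS_NVFP16` is enabled (the function name is illustrative, not part of the change):

```cpp
// Sketch only: shows the call pattern this commit is meant to enable.
#include <cuda/std/bit>      // cuda::std::bit_cast
#include <cuda/std/cstdint>  // cuda::std::uint16_t
#include <cuda_fp16.h>       // __half, __float2half

__host__ __device__ void half_bits_roundtrip()
{
  __half h = __float2half(3.14159f);
  // The constraint now accepts extended floating point types in addition
  // to trivially copyable ones, so both directions of the cast compile.
  auto bits   = cuda::std::bit_cast<cuda::std::uint16_t>(h);
  __half back = cuda::std::bit_cast<__half>(bits);
  (void) back;
}
```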
23 changes: 18 additions & 5 deletions libcudacxx/include/cuda/std/__bit/bit_cast.h
@@ -22,6 +22,7 @@
#endif // no system header

#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/is_extended_floating_point.h>
#include <cuda/std/__type_traits/is_trivially_copyable.h>
#include <cuda/std/__type_traits/is_trivially_default_constructible.h>
#include <cuda/std/detail/libcxx/include/cstring>
@@ -32,13 +33,19 @@ _LIBCUDACXX_BEGIN_NAMESPACE_STD
# define _LIBCUDACXX_CONSTEXPR_BIT_CAST constexpr
#else // ^^^ _LIBCUDACXX_BIT_CAST ^^^ / vvv !_LIBCUDACXX_BIT_CAST vvv
# define _LIBCUDACXX_CONSTEXPR_BIT_CAST
# if defined(_CCCL_COMPILER_GCC) && __GNUC__ >= 8
// GCC starting with GCC8 warns about our extended floating point types having protected data members
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_GCC("-Wclass-memaccess")
# endif // _CCCL_COMPILER_GCC >= 8
#endif // !_LIBCUDACXX_BIT_CAST

template <class _To,
class _From,
__enable_if_t<(sizeof(_To) == sizeof(_From)), int> = 0,
__enable_if_t<_CCCL_TRAIT(is_trivially_copyable, _To), int> = 0,
__enable_if_t<_CCCL_TRAIT(is_trivially_copyable, _From), int> = 0>
template <
class _To,
class _From,
__enable_if_t<(sizeof(_To) == sizeof(_From)), int> = 0,
__enable_if_t<_CCCL_TRAIT(is_trivially_copyable, _To) || _CCCL_TRAIT(__is_extended_floating_point, _To), int> = 0,
__enable_if_t<_CCCL_TRAIT(is_trivially_copyable, _From) || _CCCL_TRAIT(__is_extended_floating_point, _From), int> = 0>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI _LIBCUDACXX_CONSTEXPR_BIT_CAST _To bit_cast(const _From& __from) noexcept
{
#if defined(_LIBCUDACXX_BIT_CAST)
@@ -53,6 +60,12 @@ _CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI _LIBCUDACXX_CONSTEXPR_BIT_CAST _To bit
#endif // !_LIBCUDACXX_BIT_CAST
}

#if !defined(_LIBCUDACXX_BIT_CAST)
# if defined(_CCCL_COMPILER_GCC) && __GNUC__ >= 8
_CCCL_DIAG_POP
# endif // _CCCL_COMPILER_GCC >= 8
#endif // !_LIBCUDACXX_BIT_CAST

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX___BIT_BIT_CAST_H
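For readers unfamiliar with the non-builtin branch guarded above, it follows the well-known memcpy fallback idiom. A rough sketch of that idiom, with illustrative names rather than the library's actual internals (the real body is elided in the hunk):

```cpp
// Rough sketch of the non-constexpr bit_cast fallback used when
// __builtin_bit_cast is unavailable; names here are illustrative.
#include <cstring>  // memcpy

template <class To, class From>
To fallback_bit_cast(const From& from) noexcept
{
  static_assert(sizeof(To) == sizeof(From), "To and From must have the same size");
  To to{};  // hence the is_trivially_default_constructible include above
  // The raw byte copy is what GCC 8+ flags with -Wclass-memaccess for types
  // such as __half that have protected data members, hence the suppression.
  ::memcpy(&to, &from, sizeof(To));
  return to;
}
```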
5 changes: 1 addition & 4 deletions libcudacxx/include/cuda/std/__complex/nvbf16.h
@@ -33,6 +33,7 @@ _CCCL_DIAG_POP
# include <cuda/std/__type_traits/enable_if.h>
# include <cuda/std/__type_traits/integral_constant.h>
# include <cuda/std/__type_traits/is_constructible.h>
# include <cuda/std/__type_traits/is_extended_floating_point.h>
# include <cuda/std/cmath>
# include <cuda/std/complex>

@@ -42,10 +43,6 @@ _CCCL_DIAG_POP

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <>
struct __is_nvbf16<__nv_bfloat16> : true_type
{};

template <>
struct __complex_alignment<__nv_bfloat16> : integral_constant<size_t, alignof(__nv_bfloat162)>
{};
5 changes: 1 addition & 4 deletions libcudacxx/include/cuda/std/__complex/nvfp16.h
@@ -30,6 +30,7 @@
# include <cuda/std/__type_traits/enable_if.h>
# include <cuda/std/__type_traits/integral_constant.h>
# include <cuda/std/__type_traits/is_constructible.h>
# include <cuda/std/__type_traits/is_extended_floating_point.h>
# include <cuda/std/cmath>
# include <cuda/std/complex>

@@ -39,10 +40,6 @@

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <>
struct __is_nvfp16<__half> : true_type
{};

template <>
struct __complex_alignment<__half> : integral_constant<size_t, alignof(__half2)>
{};
12 changes: 2 additions & 10 deletions libcudacxx/include/cuda/std/__complex/vector_support.h
@@ -24,26 +24,18 @@
#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/__type_traits/is_arithmetic.h>
#include <cuda/std/__type_traits/is_extended_floating_point.h>
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/__type_traits/void_t.h>
#include <cuda/std/__utility/declval.h>
#include <cuda/std/cstddef>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <class _Tp>
struct __is_nvfp16 : false_type
{};

template <class _Tp>
struct __is_nvbf16 : false_type
{};

template <class _Tp>
struct __is_complex_float
{
static constexpr auto value =
_CCCL_TRAIT(is_floating_point, _Tp) || __is_nvfp16<_Tp>::value || __is_nvbf16<_Tp>::value;
static constexpr auto value = _CCCL_TRAIT(is_floating_point, _Tp) || _CCCL_TRAIT(__is_extended_floating_point, _Tp);
};

template <class _Tp>
70 changes: 70 additions & 0 deletions libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h
@@ -0,0 +1,70 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___TYPE_TRAITS_IS_EXTENDED_FLOATING_POINT_H
#define _LIBCUDACXX___TYPE_TRAITS_IS_EXTENDED_FLOATING_POINT_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__type_traits/integral_constant.h>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <class _Tp>
struct __is_extended_floating_point : false_type
{};

#if _CCCL_STD_VER >= 2017 && defined(__cpp_inline_variables) && (__cpp_inline_variables >= 201606L)
template <class _Tp>
_LIBCUDACXX_INLINE_VAR constexpr bool __is_extended_floating_point_v = false;
#elif _CCCL_STD_VER >= 2014 && !defined(_LIBCUDACXX_HAS_NO_VARIABLE_TEMPLATES)
template <class _Tp>
_LIBCUDACXX_INLINE_VAR constexpr bool __is_extended_floating_point_v = __is_extended_floating_point<_Tp>::value;
#endif // _CCCL_STD_VER >= 2014

#if defined(_LIBCUDACXX_HAS_NVFP16)
# include <cuda_fp16.h>

template <>
struct __is_extended_floating_point<__half> : true_type
{};

# if _CCCL_STD_VER >= 2017 && defined(__cpp_inline_variables) && (__cpp_inline_variables >= 201606L)
template <>
_LIBCUDACXX_INLINE_VAR constexpr bool __is_extended_floating_point_v<__half> = true;
# endif // _CCCL_STD_VER >= 2017
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
# include <cuda_bf16.h>
_CCCL_DIAG_POP

template <>
struct __is_extended_floating_point<__nv_bfloat16> : true_type
{};

# if _CCCL_STD_VER >= 2017 && defined(__cpp_inline_variables) && (__cpp_inline_variables >= 201606L)
template <>
_LIBCUDACXX_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> = true;
# endif // _CCCL_STD_VER >= 2017
#endif // _LIBCUDACXX_HAS_NVBF16

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX___TYPE_TRAITS_IS_EXTENDED_FLOATING_POINT_H
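The trait is an internal detail, but as an illustration of how it composes, here is a hypothetical helper in the spirit of the `__is_complex_float` rewrite in `vector_support.h` above (the helper name is not part of the library):

```cpp
// Hypothetical helper mirroring how vector_support.h combines the standard
// floating point check with the new extended floating point trait.
#include <cuda/std/type_traits>

template <class T>
struct is_any_floating_point
    : cuda::std::integral_constant<bool,
                                   cuda::std::is_floating_point<T>::value
                                     || cuda::std::__is_extended_floating_point<T>::value>
{};

// With NVFP16/NVBF16 enabled, is_any_floating_point<__half>::value and
// is_any_floating_point<__nv_bfloat16>::value are true, as is
// is_any_floating_point<float>::value.
```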
1 change: 1 addition & 0 deletions libcudacxx/include/cuda/std/type_traits
@@ -75,6 +75,7 @@
#include <cuda/std/__type_traits/is_destructible.h>
#include <cuda/std/__type_traits/is_empty.h>
#include <cuda/std/__type_traits/is_enum.h>
#include <cuda/std/__type_traits/is_extended_floating_point.h>
#include <cuda/std/__type_traits/is_final.h>
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/__type_traits/is_function.h>
34 changes: 34 additions & 0 deletions
@@ -297,6 +297,40 @@ __host__ __device__ bool tests()
test_roundtrip_through<long long>(i);
}

#ifdef _LIBCUDACXX_HAS_NVFP16
// Extended floating point type __half
for (__half i :
{__float2half(0.0f),
__float2half(1.0f),
__float2half(-1.0f),
__float2half(10.0f),
__float2half(-10.0f),
__float2half(2.71828f),
__float2half(3.14159f)})
{
test_roundtrip_through_nested_T(i);
test_roundtrip_through_buffer(i);
test_roundtrip_through<cuda::std::int16_t>(i);
}
#endif // _LIBCUDACXX_HAS_NVFP16

#ifdef _LIBCUDACXX_HAS_NVBF16
// Extended floating point type __nv_bfloat16
for (__nv_bfloat16 i :
{__float2bfloat16(0.0f),
__float2bfloat16(1.0f),
__float2bfloat16(-1.0f),
__float2bfloat16(10.0f),
__float2bfloat16(-10.0f),
__float2bfloat16(2.71828f),
__float2bfloat16(3.14159f)})
{
test_roundtrip_through_nested_T(i);
test_roundtrip_through_buffer(i);
test_roundtrip_through<cuda::std::int16_t>(i);
}
#endif // _LIBCUDACXX_HAS_NVBF16

// Test pointers
{
{
