Skip to content

Commit

Permalink
Fix building for CUDA 12.4 and for torch>=2.4 (#1297)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jul 24, 2024
1 parent 07c00d1 commit 9897a3a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 6 deletions.
12 changes: 9 additions & 3 deletions k2/csrc/ragged_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2515,12 +2515,18 @@ struct HashOutputIteratorDeref { // this is what you get when you dereference

template <typename T>
struct HashOutputIterator { // outputs just the index of the pair.
explicit HashOutputIterator(T *t) : t_(t) {}
__device__ __forceinline__ HashOutputIteratorDeref<T> operator[](
explicit __host__ __device__ __forceinline__ HashOutputIterator(T *t)
: t_(t) {}
__host__ __device__ __forceinline__ HashOutputIteratorDeref<T> operator[](
int32_t idx) const {
return HashOutputIteratorDeref<T>(t_ + idx);
}
__device__ __forceinline__ HashOutputIterator operator+(size_t offset) {
__host__ __device__ __forceinline__ HashOutputIteratorDeref<T> operator*()
const {
return HashOutputIteratorDeref<T>(t_);
}
__host__ __device__ __forceinline__ HashOutputIterator
operator+(size_t offset) {
return HashOutputIterator{t_ + offset};
}
T *t_;
Expand Down
12 changes: 9 additions & 3 deletions k2/csrc/ragged_ops_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,12 +578,18 @@ struct PairOutputIteratorDeref { // this is what you get when you dereference

template <typename T>
struct PairOutputIterator { // outputs just the index of the pair.
explicit PairOutputIterator(int32_t *i) : i_(i) {}
__device__ __forceinline__ PairOutputIteratorDeref<T> operator[](
explicit __host__ __device__ __forceinline__ PairOutputIterator(int32_t *i)
: i_(i) {}
__host__ __device__ __forceinline__ PairOutputIteratorDeref<T> operator[](
int32_t idx) const {
return PairOutputIteratorDeref<T>(i_ + idx);
}
__device__ __forceinline__ PairOutputIterator operator+(int32_t offset) {
__host__ __device__ __forceinline__ PairOutputIteratorDeref<T> operator*()
const {
return PairOutputIteratorDeref<T>(i_);
}
__host__ __device__ __forceinline__ PairOutputIterator
operator+(int32_t offset) {
return PairOutputIterator{i_ + offset};
}
int32_t *i_;
Expand Down
9 changes: 9 additions & 0 deletions k2/python/csrc/torch.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
#include "k2/python/csrc/torch.h"
#include "torch/extension.h"

#if K2_TORCH_VERSION_MAJOR > 2 || \
(K2_TORCH_VERSION_MAJOR == 2 && K2_TORCH_VERSION_MINOR >= 4)
// For torch >= 2.4: intentionally define nothing in this branch.
// Providing our own type_caster here would fail to compile with:
//   error: class "pybind11::detail::type_caster<c10::ScalarType, void>" has
//   already been defined
#else
// For torch < 2.4
namespace pybind11 {
namespace detail {

Expand Down Expand Up @@ -71,6 +79,7 @@ struct type_caster<torch::ScalarType> {

} // namespace detail
} // namespace pybind11
#endif

namespace k2 {
/* Transfer an object to a specific device.
Expand Down

0 comments on commit 9897a3a

Please sign in to comment.