And the specialization of LoadInternal for BLOCK_LOAD_VECTORIZE has a bug: the two overloads of Load have a template parameter InputIteratorT that cannot be inferred from the input arguments. So it will always use the overload that skips vectorization.
/// Load a linear segment of items from memory, specialized for native pointer types (attempts vectorization)
template <typename InputIteratorT>
__device__ __forceinline__ void Load(
    InputT *block_ptr,                  ///< [in] The thread block's base input iterator for loading from
    InputT (&items)[ITEMS_PER_THREAD])  ///< [out] Data to load
{
InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_ptr, items);
}
/// Load a linear segment of items from memory, specialized for native pointer types (attempts vectorization)
template <typename InputIteratorT>
__device__ __forceinline__ void Load(
    const InputT *block_ptr,            ///< [in] The thread block's base input iterator for loading from
    InputT (&items)[ITEMS_PER_THREAD])  ///< [out] Data to load
{
InternalLoadDirectBlockedVectorized<LOAD_DEFAULT>(linear_tid, block_ptr, items);
}
I tried removing that template parameter, which enables the vectorized load. But I've found another issue: some functions like ThreadLoad (which is used by BlockLoad) do not have const T* as an input parameter, which prevents passing a const T* to BlockLoad::Load in some cases. The non-constness propagates all the way to user code. I hope there will be a fix for that too.
Make a kernel that loads a block of data using the vectorized algorithm:
template <typename InputT, int BLOCK_SIZE>
__global__ void load_and_verify(const InputT *input_ptr, int numel) {
using BlockLoadT =
cub::BlockLoad<InputT, BLOCK_SIZE, 4, cub::BLOCK_LOAD_VECTORIZE>;
BlockLoadT block_load;
InputT thread_local_elems[4];
int block_offset =
blockIdx.x * (4 * BLOCK_SIZE); // offset of mem for this block to load
block_load.Load(input_ptr + block_offset, thread_local_elems);
cub::CTA_SYNC();
}
Allocate a thrust device vector and pass its data pointer to that kernel. Since the address of the data is aligned and the size is a power of two, we can expect a vectorized load.
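A host-side sketch for launching the kernel above (sizes and the element type are illustrative assumptions; a default-constructed thrust::device_vector allocation is well aligned for vector loads):

```cuda
#include <thrust/device_vector.h>
#include <cuda_runtime.h>

int main() {
  constexpr int BLOCK_SIZE       = 128;
  constexpr int ITEMS_PER_THREAD = 4;   // matches the kernel's fixed 4
  constexpr int NUM_BLOCKS       = 8;
  const int numel = NUM_BLOCKS * BLOCK_SIZE * ITEMS_PER_THREAD;  // power of two

  // Device allocation; cudaMalloc-backed storage is at least 256-byte aligned.
  thrust::device_vector<float> input(numel, 1.0f);

  load_and_verify<float, BLOCK_SIZE><<<NUM_BLOCKS, BLOCK_SIZE>>>(
      thrust::raw_pointer_cast(input.data()), numel);
  cudaDeviceSynchronize();
  return 0;
}
```

Inspecting the resulting SASS (e.g. with cuda-gdb or nvdisasm) shows whether a 128-bit vectorized load was emitted.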
Thanks for submitting this issue - the CCCL team has been notified and we'll get back to you as soon as we can!
In the mean time, feel free to add any relevant information to this issue.
I can verify that vectorization doesn't happen due to the extra template parameter. Removing the extra template parameters leads to vectorized loads, but I don't think we can fix the issue right away, because the vectorized load documentation guarantees that:
/* - The following conditions will prevent vectorization and loading will fall
 *   back to cub::BLOCK_LOAD_DIRECT:
 *   - The block input offset is not quadword-aligned
 */
There are no checks of the input alignment. When given a pointer that's not aligned to the vectorized type, block/warp load leads to invalid loads. I believe this has to be addressed before enabling vectorization.
I suggest we file another issue to check input pointer alignment and fall back to direct load, and mark this issue as blocked. Since adding extra checks might affect performance, I suggest we also consider providing a cuda::aligned_size_t overload.
Hi folks, it seems that BLOCK_LOAD_VECTORIZE is not working due to this bug. I tried it on half x8 and don't see LDG.E.128.SYS instructions being used.
If the issue is how to guarantee alignment, perhaps the documentation should be updated to say it's undefined behavior to call BlockLoad::Load() with an unaligned pointer? That would be equivalent to the behavior of loads/stores using vectorized types.
Is this a duplicate?
Type of Bug
Something else
Component
CUB
Describe the bug
BlockLoad/WarpLoad never attempts to vectorize, even when using the BLOCK_LOAD_VECTORIZE/WARP_LOAD_VECTORIZE algorithm.
Since the implementation uses

InternalLoad(temp_storage, linear_tid).Load(block_itr, items);

and the specialization of LoadInternal for BLOCK_LOAD_VECTORIZE has a bug: the two overloads of Load have a template parameter InputIteratorT that cannot be inferred from the input arguments. So it will always use the overload that skips vectorization.
https://github.com/NVIDIA/cccl/blob/3ef9dd9642da2d4e0b3ff77e445e73d7aabd4687/cub/cub/block/block_load.cuh#L793C1-L810C1
WarpLoad has similar issues.
https://github.com/NVIDIA/cccl/blob/3ef9dd9642da2d4e0b3ff77e445e73d7aabd4687/cub/cub/warp/warp_load.cuh#L320C1-L330C6
I tried removing that template parameter, which enables the vectorized load. But I've found another issue: some functions like ThreadLoad (which is used by BlockLoad) do not have const T* as an input parameter, which prevents passing a const T* to BlockLoad::Load in some cases. The non-constness propagates all the way to user code. I hope there will be a fix for that too.
https://github.com/NVIDIA/cccl/blob/3ef9dd9642da2d4e0b3ff77e445e73d7aabd4687/cub/cub/block/block_load.cuh#L185C1-L190C6
https://github.com/NVIDIA/cccl/blob/3ef9dd9642da2d4e0b3ff77e445e73d7aabd4687/cub/cub/thread/thread_load.cuh#L301C1-L311C2
https://github.com/NVIDIA/cccl/blob/3ef9dd9642da2d4e0b3ff77e445e73d7aabd4687/cub/cub/thread/thread_load.cuh#L348C1-L360C1
This is forwarded from NVIDIA/cub#739.
How to Reproduce
Using cuda-gdb, we can see that the vectorized load is not used; it falls back to the direct load.
Expected behavior
The vectorized load should be used. More specifically, the code below should be invoked:
https://github.com/NVIDIA/cccl/blob/3ef9dd9642da2d4e0b3ff77e445e73d7aabd4687/cub/cub/block/block_load.cuh#L793C1-L800C10
Reproduction link
No response
Operating System
Ubuntu Linux 23.04
nvidia-smi output
NVCC version