From 007e97b96789fce4a5a0301306db188d8756842e Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Fri, 6 Sep 2024 10:48:05 +0200 Subject: [PATCH] Widen histogram agent constructor to more types This allows the agent constructor to accept more data types than just arrays of an exact static size. No SASS changes on CUB device histogram test with CTK 12.6. Fixes #1877 for AgentHistogram --- cub/cub/agent/agent_histogram.cuh | 24 +++++++++---------- .../device/dispatch/dispatch_histogram.cuh | 12 +++++----- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/cub/cub/agent/agent_histogram.cuh b/cub/cub/agent/agent_histogram.cuh index ce204273da..f324de52bc 100644 --- a/cub/cub/agent/agent_histogram.cuh +++ b/cub/cub/agent/agent_histogram.cuh @@ -287,22 +287,22 @@ struct AgentHistogram SampleT* d_native_samples; /// The number of output bins for each channel - int (&num_output_bins)[NUM_ACTIVE_CHANNELS]; + int* num_output_bins; /// The number of privatized bins for each channel - int (&num_privatized_bins)[NUM_ACTIVE_CHANNELS]; + int* num_privatized_bins; - /// Reference to gmem privatized histograms for each channel + /// Copy of gmem privatized histograms for each channel CounterT* d_privatized_histograms[NUM_ACTIVE_CHANNELS]; /// Reference to final output histograms (gmem) - CounterT* (&d_output_histograms)[NUM_ACTIVE_CHANNELS]; + CounterT** d_output_histograms; /// The transform operator for determining output bin-ids from privatized counter indices, one for each channel - OutputDecodeOpT (&output_decode_op)[NUM_ACTIVE_CHANNELS]; + OutputDecodeOpT* output_decode_op; /// The transform operator for determining privatized counter indices from samples, one for each channel - PrivatizedDecodeOpT (&privatized_decode_op)[NUM_ACTIVE_CHANNELS]; + PrivatizedDecodeOpT* privatized_decode_op; /// Whether to prefer privatized smem counters vs privatized global counters bool prefer_smem; @@ -810,12 +810,12 @@ struct AgentHistogram _CCCL_DEVICE _CCCL_FORCEINLINE AgentHistogram( 
TempStorage& temp_storage, SampleIteratorT d_samples, - int (&num_output_bins)[NUM_ACTIVE_CHANNELS], - int (&num_privatized_bins)[NUM_ACTIVE_CHANNELS], - CounterT* (&d_output_histograms)[NUM_ACTIVE_CHANNELS], - CounterT* (&d_privatized_histograms)[NUM_ACTIVE_CHANNELS], - OutputDecodeOpT (&output_decode_op)[NUM_ACTIVE_CHANNELS], - PrivatizedDecodeOpT (&privatized_decode_op)[NUM_ACTIVE_CHANNELS]) + int* num_output_bins, + int* num_privatized_bins, + CounterT** d_output_histograms, + CounterT** d_privatized_histograms, + OutputDecodeOpT* output_decode_op, + PrivatizedDecodeOpT* privatized_decode_op) : temp_storage(temp_storage.Alias()) , d_wrapped_samples(d_samples) , d_native_samples(NativePointer(d_wrapped_samples)) diff --git a/cub/cub/device/dispatch/dispatch_histogram.cuh b/cub/cub/device/dispatch/dispatch_histogram.cuh index 1019569240..1839385b19 100644 --- a/cub/cub/device/dispatch/dispatch_histogram.cuh +++ b/cub/cub/device/dispatch/dispatch_histogram.cuh @@ -238,12 +238,12 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::AgentHistogramPolicyT::BLOCK AgentHistogramT agent( temp_storage, d_samples, - num_output_bins_wrapper.__elems_, - num_privatized_bins_wrapper.__elems_, - d_output_histograms_wrapper.__elems_, - d_privatized_histograms_wrapper.__elems_, - output_decode_op_wrapper.__elems_, - privatized_decode_op_wrapper.__elems_); + num_output_bins_wrapper.data(), + num_privatized_bins_wrapper.data(), + d_output_histograms_wrapper.data(), + d_privatized_histograms_wrapper.data(), + output_decode_op_wrapper.data(), + privatized_decode_op_wrapper.data()); // Initialize counters agent.InitBinCounters();