cuda-cccl 0.1.3.2.0.dev438__cp310-cp310-manylinux_2_24_aarch64.whl → 0.3.0__cp310-cp310-manylinux_2_24_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cuda-cccl might be problematic.
Files changed (60)
  1. cuda/cccl/headers/include/cub/agent/agent_radix_sort_downsweep.cuh +23 -0
  2. cuda/cccl/headers/include/cub/agent/agent_sub_warp_merge_sort.cuh +22 -14
  3. cuda/cccl/headers/include/cub/block/block_load_to_shared.cuh +432 -0
  4. cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +18 -26
  5. cuda/cccl/headers/include/cub/device/device_copy.cuh +116 -27
  6. cuda/cccl/headers/include/cub/device/device_partition.cuh +5 -1
  7. cuda/cccl/headers/include/cub/device/dispatch/dispatch_copy_mdspan.cuh +79 -0
  8. cuda/cccl/headers/include/cub/device/dispatch/dispatch_segmented_sort.cuh +321 -262
  9. cuda/cccl/headers/include/cub/device/dispatch/kernels/reduce.cuh +8 -0
  10. cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_sort.cuh +57 -10
  11. cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +37 -13
  12. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_segmented_sort.cuh +203 -51
  13. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +55 -19
  14. cuda/cccl/headers/include/cub/util_device.cuh +51 -35
  15. cuda/cccl/headers/include/cuda/__algorithm/copy.h +3 -3
  16. cuda/cccl/headers/include/cuda/__device/all_devices.h +3 -6
  17. cuda/cccl/headers/include/cuda/__device/arch_traits.h +3 -3
  18. cuda/cccl/headers/include/cuda/__device/attributes.h +7 -7
  19. cuda/cccl/headers/include/cuda/__device/device_ref.h +3 -10
  20. cuda/cccl/headers/include/cuda/__driver/driver_api.h +225 -33
  21. cuda/cccl/headers/include/cuda/__event/event.h +7 -8
  22. cuda/cccl/headers/include/cuda/__event/event_ref.h +4 -5
  23. cuda/cccl/headers/include/cuda/__event/timed_event.h +3 -4
  24. cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +3 -3
  25. cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +3 -3
  26. cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +3 -3
  27. cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +14 -10
  28. cuda/cccl/headers/include/cuda/__runtime/types.h +1 -1
  29. cuda/cccl/headers/include/cuda/__stream/stream.h +2 -3
  30. cuda/cccl/headers/include/cuda/__stream/stream_ref.h +17 -12
  31. cuda/cccl/headers/include/cuda/__utility/__basic_any/virtual_tables.h +2 -2
  32. cuda/cccl/headers/include/cuda/std/__cccl/cuda_capabilities.h +2 -2
  33. cuda/cccl/headers/include/cuda/std/__cccl/preprocessor.h +2 -0
  34. cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +10 -5
  35. cuda/cccl/headers/include/cuda/std/__cmath/min_max.h +44 -17
  36. cuda/cccl/headers/include/cuda/std/__concepts/constructible.h +1 -1
  37. cuda/cccl/headers/include/cuda/std/__cuda/api_wrapper.h +12 -12
  38. cuda/cccl/headers/include/cuda/std/__exception/cuda_error.h +1 -8
  39. cuda/cccl/headers/include/cuda/std/__floating_point/cast.h +15 -12
  40. cuda/cccl/headers/include/cuda/std/__floating_point/cuda_fp_types.h +3 -0
  41. cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +2 -1
  42. cuda/cccl/headers/include/cuda/std/__type_traits/promote.h +3 -2
  43. cuda/cccl/headers/include/thrust/system/cuda/detail/tabulate.h +8 -22
  44. cuda/cccl/headers/include/thrust/type_traits/unwrap_contiguous_iterator.h +15 -48
  45. cuda/cccl/parallel/experimental/__init__.py +4 -0
  46. cuda/cccl/parallel/experimental/_bindings.pyi +28 -0
  47. cuda/cccl/parallel/experimental/_bindings_impl.pyx +140 -0
  48. cuda/cccl/parallel/experimental/algorithms/__init__.py +4 -0
  49. cuda/cccl/parallel/experimental/algorithms/_reduce.py +0 -2
  50. cuda/cccl/parallel/experimental/algorithms/_scan.py +0 -2
  51. cuda/cccl/parallel/experimental/algorithms/_three_way_partition.py +261 -0
  52. cuda/cccl/parallel/experimental/cu12/_bindings_impl.cpython-310-aarch64-linux-gnu.so +0 -0
  53. cuda/cccl/parallel/experimental/cu12/cccl/libcccl.c.parallel.so +0 -0
  54. cuda/cccl/parallel/experimental/cu13/_bindings_impl.cpython-310-aarch64-linux-gnu.so +0 -0
  55. cuda/cccl/parallel/experimental/cu13/cccl/libcccl.c.parallel.so +0 -0
  56. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/METADATA +1 -1
  57. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/RECORD +59 -57
  58. cuda/cccl/headers/include/cuda/std/__cuda/ensure_current_device.h +0 -72
  59. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/WHEEL +0 -0
  60. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -172,6 +172,10 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)
  AccumT,
  TransformOpT>;

+ static_assert(sizeof(typename AgentReduceT::TempStorage) <= max_smem_per_block,
+ "cub::DeviceReduce ran out of CUDA shared memory, which we judged to be extremely unlikely. Please "
+ "file an issue at: https://github.com/NVIDIA/cccl/issues");
+
  // Shared memory storage
  __shared__ typename AgentReduceT::TempStorage temp_storage;

@@ -253,6 +257,10 @@ CUB_DETAIL_KERNEL_ATTRIBUTES __launch_bounds__(
  AccumT,
  TransformOpT>;

+ static_assert(sizeof(typename AgentReduceT::TempStorage) <= max_smem_per_block,
+ "cub::DeviceReduce ran out of CUDA shared memory, which we judged to be extremely unlikely. Please "
+ "file an issue at: https://github.com/NVIDIA/cccl/issues");
+
  // Shared memory storage
  __shared__ typename AgentReduceT::TempStorage temp_storage;

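Both reduce kernels gain the same guard, so an oversized temp-storage type now fails the build instead of producing a kernel that cannot launch. A minimal standalone analog of the pattern (kMaxSmemPerBlock and AgentTempStorage are illustrative stand-ins, not the CUB names):

```cpp
// Standalone analog of the new guard: an oversized per-block scratch type now
// breaks the build instead of failing at launch time. Names are illustrative.
#include <cuda_runtime.h>
#include <cstdio>

constexpr int kMaxSmemPerBlock = 48 * 1024; // conservative static shared-memory budget

struct AgentTempStorage
{
  float partials[256]; // 1 KiB of per-block scratch
};

__global__ void reduce_kernel(float* out)
{
  static_assert(sizeof(AgentTempStorage) <= kMaxSmemPerBlock,
                "temp storage exceeds the static shared-memory budget");
  __shared__ AgentTempStorage temp_storage;
  temp_storage.partials[threadIdx.x] = float(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0)
  {
    *out = temp_storage.partials[255];
  }
}

int main()
{
  float* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(float));
  reduce_kernel<<<1, 256>>>(d_out);
  float h = 0.f;
  cudaMemcpy(&h, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("%g\n", h); // 255
  cudaFree(d_out);
}
```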
@@ -29,6 +29,56 @@ using local_segment_index_t = ::cuda::std::uint32_t;
  // Type used for total number of segments and to index within segments globally
  using global_segment_offset_t = ::cuda::std::int64_t;

+ template <typename OffsetT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
+ struct LargeSegmentsSelectorT
+ {
+ OffsetT value{};
+ BeginOffsetIteratorT d_offset_begin{};
+ EndOffsetIteratorT d_offset_end{};
+ global_segment_offset_t base_segment_offset{};
+
+ #if !_CCCL_COMPILER(NVRTC)
+ _CCCL_HOST_DEVICE _CCCL_FORCEINLINE
+ LargeSegmentsSelectorT(OffsetT value, BeginOffsetIteratorT d_offset_begin, EndOffsetIteratorT d_offset_end)
+ : value(value)
+ , d_offset_begin(d_offset_begin)
+ , d_offset_end(d_offset_end)
+ {}
+ #endif
+
+ _CCCL_DEVICE _CCCL_FORCEINLINE bool operator()(local_segment_index_t segment_id) const
+ {
+ const OffsetT segment_size =
+ d_offset_end[base_segment_offset + segment_id] - d_offset_begin[base_segment_offset + segment_id];
+ return segment_size > value;
+ }
+ };
+
+ template <typename OffsetT, typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
+ struct SmallSegmentsSelectorT
+ {
+ OffsetT value{};
+ BeginOffsetIteratorT d_offset_begin{};
+ EndOffsetIteratorT d_offset_end{};
+ global_segment_offset_t base_segment_offset{};
+
+ #if !_CCCL_COMPILER(NVRTC)
+ _CCCL_HOST_DEVICE _CCCL_FORCEINLINE
+ SmallSegmentsSelectorT(OffsetT value, BeginOffsetIteratorT d_offset_begin, EndOffsetIteratorT d_offset_end)
+ : value(value)
+ , d_offset_begin(d_offset_begin)
+ , d_offset_end(d_offset_end)
+ {}
+ #endif
+
+ _CCCL_DEVICE _CCCL_FORCEINLINE bool operator()(local_segment_index_t segment_id) const
+ {
+ const OffsetT segment_size =
+ d_offset_end[base_segment_offset + segment_id] - d_offset_begin[base_segment_offset + segment_id];
+ return segment_size < value;
+ }
+ };
+
  /**
  * @brief Fallback kernel, in case there's not enough segments to
  * take advantage of partitioning.
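These two functors are the size predicates the dispatcher uses to split segment indices into large/medium/small buckets before routing each bucket to its own kernel. A host-side sketch of the classification they implement, with made-up offsets and thresholds:

```cpp
// Host-side sketch of the segment classification the selectors implement.
// Offsets and thresholds are made up; the real predicates run on device
// during a three-way partition of segment indices.
#include <cstdint>
#include <cstdio>

int main()
{
  const std::int64_t offsets[] = {0, 3, 3, 1000, 1010}; // segment i spans [offsets[i], offsets[i + 1])
  const std::int64_t large_threshold = 256; // "> value" in LargeSegmentsSelectorT
  const std::int64_t small_threshold = 9;   // "< value" in SmallSegmentsSelectorT

  for (int i = 0; i < 4; ++i)
  {
    const std::int64_t size = offsets[i + 1] - offsets[i];
    const char* bucket = size > large_threshold ? "large" : size < small_threshold ? "small" : "medium";
    std::printf("segment %d: size %lld -> %s\n", i, static_cast<long long>(size), bucket);
  }
}
```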
@@ -89,7 +139,7 @@ __launch_bounds__(ChainedPolicyT::ActivePolicy::LargeSegmentPolicy::BLOCK_THREAD
  {
  using ActivePolicyT = typename ChainedPolicyT::ActivePolicy;
  using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;
- using MediumPolicyT = typename ActivePolicyT::SmallAndMediumSegmentedSortPolicyT::MediumPolicyT;
+ using MediumPolicyT = typename ActivePolicyT::MediumSegmentPolicy;

  const auto segment_id = static_cast<local_segment_index_t>(blockIdx.x);
  OffsetT segment_begin = d_begin_offsets[segment_id];
@@ -253,7 +303,7 @@ template <SortOrder Order,
  typename BeginOffsetIteratorT,
  typename EndOffsetIteratorT,
  typename OffsetT>
- __launch_bounds__(ChainedPolicyT::ActivePolicy::SmallAndMediumSegmentedSortPolicyT::BLOCK_THREADS)
+ __launch_bounds__(ChainedPolicyT::ActivePolicy::SmallSegmentPolicy::BLOCK_THREADS)
  CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortKernelSmall(
  local_segment_index_t small_segments,
  local_segment_index_t medium_segments,
@@ -272,10 +322,9 @@ __launch_bounds__(ChainedPolicyT::ActivePolicy::SmallAndMediumSegmentedSortPolic
  const local_segment_index_t tid = threadIdx.x;
  const local_segment_index_t bid = blockIdx.x;

- using ActivePolicyT = typename ChainedPolicyT::ActivePolicy;
- using SmallAndMediumPolicyT = typename ActivePolicyT::SmallAndMediumSegmentedSortPolicyT;
- using MediumPolicyT = typename SmallAndMediumPolicyT::MediumPolicyT;
- using SmallPolicyT = typename SmallAndMediumPolicyT::SmallPolicyT;
+ using ActivePolicyT = typename ChainedPolicyT::ActivePolicy;
+ using SmallPolicyT = typename ActivePolicyT::SmallSegmentPolicy;
+ using MediumPolicyT = typename ActivePolicyT::MediumSegmentPolicy;

  constexpr auto threads_per_medium_segment = static_cast<local_segment_index_t>(MediumPolicyT::WARP_THREADS);
  constexpr auto threads_per_small_segment = static_cast<local_segment_index_t>(SmallPolicyT::WARP_THREADS);
@@ -286,11 +335,9 @@ __launch_bounds__(ChainedPolicyT::ActivePolicy::SmallAndMediumSegmentedSortPolic
  using SmallAgentWarpMergeSortT =
  sub_warp_merge_sort::AgentSubWarpSort<Order == SortOrder::Descending, SmallPolicyT, KeyT, ValueT, OffsetT>;

- constexpr auto segments_per_medium_block =
- static_cast<local_segment_index_t>(SmallAndMediumPolicyT::SEGMENTS_PER_MEDIUM_BLOCK);
+ constexpr auto segments_per_medium_block = static_cast<local_segment_index_t>(MediumPolicyT::SEGMENTS_PER_BLOCK);

- constexpr auto segments_per_small_block =
- static_cast<local_segment_index_t>(SmallAndMediumPolicyT::SEGMENTS_PER_SMALL_BLOCK);
+ constexpr auto segments_per_small_block = static_cast<local_segment_index_t>(SmallPolicyT::SEGMENTS_PER_BLOCK);

  __shared__ union
  {
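With the combined small-and-medium policy gone, each kernel now reads SEGMENTS_PER_BLOCK off its own sub-warp policy. A sketch of the grid-size bookkeeping this drives, with illustrative constants in place of the real tuning values:

```cpp
// Grid-size bookkeeping sketch: each kernel asks its own policy how many
// segments one thread block covers. Constants are illustrative, not the
// shipped tuning values.
#include <cstdio>

struct SmallPolicy  { static constexpr int SEGMENTS_PER_BLOCK = 32; };
struct MediumPolicy { static constexpr int SEGMENTS_PER_BLOCK = 4; };

constexpr int div_round_up(int a, int b) { return (a + b - 1) / b; }

int main()
{
  const int small_segments  = 1000;
  const int medium_segments = 100;
  std::printf("blocks for small segments:  %d\n", div_round_up(small_segments, SmallPolicy::SEGMENTS_PER_BLOCK));   // 32
  std::printf("blocks for medium segments: %d\n", div_round_up(medium_segments, MediumPolicy::SEGMENTS_PER_BLOCK)); // 25
}
```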
@@ -202,14 +202,18 @@ _CCCL_HOST_DEVICE _CCCL_CONSTEVAL auto load_store_type()
  }
  }

- template <typename VectorizedPolicy, typename Offset, typename F, typename RandomAccessIteratorOut, typename... InputT>
+ template <typename VectorizedPolicy,
+ typename Offset,
+ typename F,
+ typename RandomAccessIteratorOut,
+ typename... RandomAccessIteratorsIn>
  _CCCL_DEVICE void transform_kernel_vectorized(
  Offset num_items,
  int num_elem_per_thread_prefetch,
  bool can_vectorize,
  F f,
  RandomAccessIteratorOut out,
- const InputT*... ins)
+ RandomAccessIteratorsIn... ins)
  {
  constexpr int block_dim = VectorizedPolicy::block_threads;
  constexpr int items_per_thread = VectorizedPolicy::items_per_thread_vectorized;
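The input pack widens from `const InputT*...` to arbitrary random-access iterators, so fancy iterators can now reach the vectorized kernel. A toy host-side analog of why plain subscripting is the only operation the new signature needs (transform_n is a hypothetical helper, not a CUB API):

```cpp
// Toy analog of the widened signature: any random-access iterator works,
// because the body only ever subscripts the pack.
#include <cstdio>
#include <vector>

template <typename F, typename OutIt, typename... InIts>
void transform_n(int n, F f, OutIt out, InIts... ins)
{
  for (int i = 0; i < n; ++i)
  {
    out[i] = f(ins[i]...); // expands over the whole iterator pack
  }
}

int main()
{
  std::vector<int> a{1, 2, 3}, b{10, 20, 30}, c(3);
  transform_n(3, [](int x, int y) { return x + y; }, c.begin(), a.begin(), b.begin());
  std::printf("%d %d %d\n", c[0], c[1], c[2]); // 11 22 33
}
```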
@@ -240,9 +244,12 @@ _CCCL_DEVICE void transform_kernel_vectorized(
  constexpr int load_store_size = VectorizedPolicy::load_store_word_size;
  using load_store_t = decltype(load_store_type<load_store_size>());
  using output_t = it_value_t<RandomAccessIteratorOut>;
- using result_t = ::cuda::std::decay_t<::cuda::std::invoke_result_t<F, const InputT&...>>;
+ using result_t = ::cuda::std::decay_t<::cuda::std::invoke_result_t<F, const it_value_t<RandomAccessIteratorsIn>&...>>;
  // picks output type size if there are no inputs
- constexpr int element_size = int{first_item(sizeof(InputT)..., size_of<output_t>)};
+ constexpr int element_size = int{first_nonzero_value(
+ (sizeof(it_value_t<RandomAccessIteratorsIn>)
+ * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...,
+ size_of<output_t>)};
  constexpr int load_store_count = (items_per_thread * element_size) / load_store_size;

  static_assert((items_per_thread * element_size) % load_store_size == 0);
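Worked numbers for the divisibility bookkeeping above, under assumed sizes (4-byte elements, the 8-byte load/store word from the tuning header, 6 items per thread):

```cpp
// Worked example of the vector-width bookkeeping; all values are assumptions.
#include <cstdio>

int main()
{
  constexpr int element_size     = 4; // first contiguous input holds e.g. float
  constexpr int load_store_size  = 8; // load_store_word_size from the tuning policy
  constexpr int items_per_thread = 6; // items_per_thread_vectorized

  constexpr int load_store_count = (items_per_thread * element_size) / load_store_size; // 24 / 8 = 3 words

  static_assert((items_per_thread * element_size) % load_store_size == 0,
                "a thread's tile must be a whole number of load/store words");
  std::printf("each thread issues %d wide loads per input\n", load_store_count);
}
```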
@@ -258,18 +265,35 @@ _CCCL_DEVICE void transform_kernel_vectorized(

  auto provide_array = [&](auto... inputs) {
  // load inputs
- // TODO(bgruber): we could support fancy iterators for loading here as well (and only vectorize some inputs)
- [[maybe_unused]] auto load_tile_vectorized = [&](auto* in, auto& input) {
- auto in_vec = reinterpret_cast<const load_store_t*>(in);
- auto input_vec = reinterpret_cast<load_store_t*>(input.data());
- _CCCL_PRAGMA_UNROLL_FULL()
- for (int i = 0; i < load_store_count; ++i)
+ [[maybe_unused]] auto load_tile = [](auto in, auto& input) {
+ if constexpr (THRUST_NS_QUALIFIER::is_contiguous_iterator_v<decltype(in)>)
  {
- input_vec[i] = in_vec[i * VectorizedPolicy::block_threads + threadIdx.x];
+ auto in_vec = reinterpret_cast<const load_store_t*>(in) + threadIdx.x;
+ auto input_vec = reinterpret_cast<load_store_t*>(input.data());
+ _CCCL_PRAGMA_UNROLL_FULL()
+ for (int i = 0; i < load_store_count; ++i)
+ {
+ input_vec[i] = in_vec[i * VectorizedPolicy::block_threads];
+ }
+ }
+ else
+ {
+ constexpr int elems = load_store_size / element_size;
+ in += threadIdx.x * elems;
+ _CCCL_PRAGMA_UNROLL_FULL()
+ for (int i = 0; i < load_store_count; ++i)
+ {
+ _CCCL_PRAGMA_UNROLL_FULL()
+ for (int j = 0; j < elems; ++j)
+ {
+ input[i * elems + j] = in[i * elems * VectorizedPolicy::block_threads + j];
+ }
+ }
  }
  };
  _CCCL_PDL_GRID_DEPENDENCY_SYNC();
- (load_tile_vectorized(ins, inputs), ...);
+ (load_tile(ins, inputs), ...);
+
  // Benchmarks showed up to 38% slowdown on H200 (some improvements as well), so omitted. See #5249 for details.
  // _CCCL_PDL_TRIGGER_NEXT_LAUNCH();

@@ -280,7 +304,7 @@ _CCCL_DEVICE void transform_kernel_vectorized(
  output[i] = f(inputs[i]...);
  }
  };
- provide_array(uninitialized_array<InputT, items_per_thread>{}...);
+ provide_array(uninitialized_array<it_value_t<RandomAccessIteratorsIn>, items_per_thread>{}...);

  // write output
  if constexpr (can_vectorize_store)
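The old single-path `load_tile_vectorized` becomes a two-path `load_tile`: contiguous inputs keep the wide word loads, and every other iterator falls back to element-wise reads with the same block-strided addressing. A host-side, single-thread analog using memcpy in place of the kernel's reinterpret_cast (all sizes illustrative):

```cpp
// Host-side analog of the two load_tile paths (one thread's view): the
// contiguous path moves whole 8-byte words, the fallback path copies element
// by element with the same block-strided indexing.
#include <cstdio>
#include <cstring>

int main()
{
  constexpr int block_threads    = 2; // stand-in for VectorizedPolicy::block_threads
  constexpr int load_store_count = 2; // 8-byte words per thread
  constexpr int elems            = 2; // 4-byte elements per 8-byte word
  const int tid = 0;                  // threadIdx.x of the simulated thread

  alignas(8) float in[block_threads * load_store_count * elems] = {0, 1, 2, 3, 4, 5, 6, 7};
  float vec_tile[4], scalar_tile[4];

  // Contiguous path: word i comes from in_vec[i * block_threads + tid].
  for (int i = 0; i < load_store_count; ++i)
  {
    std::memcpy(&vec_tile[i * elems], &in[(i * block_threads + tid) * elems], sizeof(float) * elems);
  }
  // Fallback path: same addresses, one element at a time.
  for (int i = 0; i < load_store_count; ++i)
  {
    for (int j = 0; j < elems; ++j)
    {
      scalar_tile[i * elems + j] = in[tid * elems + i * elems * block_threads + j];
    }
  }
  for (int k = 0; k < 4; ++k)
  {
    std::printf("%g/%g ", vec_tile[k], scalar_tile[k]); // identical tiles: 0/0 1/1 4/4 5/5
  }
  std::printf("\n");
}
```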
@@ -47,6 +47,118 @@ namespace detail
  {
  namespace segmented_sort
  {
+
+ template <typename PolicyT, typename = void>
+ struct SegmentedSortPolicyWrapper : PolicyT
+ {
+ CUB_RUNTIME_FUNCTION SegmentedSortPolicyWrapper(PolicyT base)
+ : PolicyT(base)
+ {}
+ };
+
+ template <typename StaticPolicyT>
+ struct SegmentedSortPolicyWrapper<StaticPolicyT,
+ _CUDA_VSTD::void_t<typename StaticPolicyT::LargeSegmentPolicy,
+ typename StaticPolicyT::SmallSegmentPolicy,
+ typename StaticPolicyT::MediumSegmentPolicy>> : StaticPolicyT
+ {
+ CUB_RUNTIME_FUNCTION SegmentedSortPolicyWrapper(StaticPolicyT base)
+ : StaticPolicyT(base)
+ {}
+
+ CUB_RUNTIME_FUNCTION static constexpr auto LargeSegment()
+ {
+ return cub::detail::MakePolicyWrapper(typename StaticPolicyT::LargeSegmentPolicy());
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr auto SmallSegment()
+ {
+ return cub::detail::MakePolicyWrapper(typename StaticPolicyT::SmallSegmentPolicy());
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr auto MediumSegment()
+ {
+ return cub::detail::MakePolicyWrapper(typename StaticPolicyT::MediumSegmentPolicy());
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr int PartitioningThreshold()
+ {
+ return StaticPolicyT::PARTITIONING_THRESHOLD;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr int LargeSegmentRadixBits()
+ {
+ return StaticPolicyT::LargeSegmentPolicy::RADIX_BITS;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr int SegmentsPerSmallBlock()
+ {
+ return StaticPolicyT::SmallSegmentPolicy::SEGMENTS_PER_BLOCK;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr int SegmentsPerMediumBlock()
+ {
+ return StaticPolicyT::MediumSegmentPolicy::SEGMENTS_PER_BLOCK;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr int SmallPolicyItemsPerTile()
+ {
+ return StaticPolicyT::SmallSegmentPolicy::ITEMS_PER_TILE;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr int MediumPolicyItemsPerTile()
+ {
+ return StaticPolicyT::MediumSegmentPolicy::ITEMS_PER_TILE;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr CacheLoadModifier LargeSegmentLoadModifier()
+ {
+ return StaticPolicyT::LargeSegmentPolicy::LOAD_MODIFIER;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr BlockLoadAlgorithm LargeSegmentLoadAlgorithm()
+ {
+ return StaticPolicyT::LargeSegmentPolicy::LOAD_ALGORITHM;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr WarpLoadAlgorithm MediumSegmentLoadAlgorithm()
+ {
+ return StaticPolicyT::MediumSegmentPolicy::LOAD_ALGORITHM;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr WarpLoadAlgorithm SmallSegmentLoadAlgorithm()
+ {
+ return StaticPolicyT::SmallSegmentPolicy::LOAD_ALGORITHM;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr WarpStoreAlgorithm MediumSegmentStoreAlgorithm()
+ {
+ return StaticPolicyT::MediumSegmentPolicy::STORE_ALGORITHM;
+ }
+
+ CUB_RUNTIME_FUNCTION static constexpr WarpStoreAlgorithm SmallSegmentStoreAlgorithm()
+ {
+ return StaticPolicyT::SmallSegmentPolicy::STORE_ALGORITHM;
+ }
+
+ #if defined(CUB_ENABLE_POLICY_PTX_JSON)
+ _CCCL_DEVICE static constexpr auto EncodedPolicy()
+ {
+ using namespace ptx_json;
+ return object<key<"LargeSegmentPolicy">() = LargeSegment().EncodedPolicy(),
+ key<"SmallSegmentPolicy">() = SmallSegment().EncodedPolicy(),
+ key<"MediumSegmentPolicy">() = MediumSegment().EncodedPolicy(),
+ key<"PartitioningThreshold">() = value<StaticPolicyT::PARTITIONING_THRESHOLD>()>();
+ }
+ #endif
+ };
+
+ template <typename PolicyT>
+ CUB_RUNTIME_FUNCTION SegmentedSortPolicyWrapper<PolicyT> MakeSegmentedSortPolicyWrapper(PolicyT policy)
+ {
+ return SegmentedSortPolicyWrapper<PolicyT>{policy};
+ }
+
  template <typename KeyT, typename ValueT>
  struct policy_hub
  {
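The wrapper specialization is selected with the usual void_t detection idiom: only policies that expose the three nested segment policies get the named accessors. A minimal version of the pattern (Wrapper, StaticPolicy, and RuntimePolicy are illustrative stand-ins, not CUB types):

```cpp
// Minimal version of the void_t detection pattern used above.
#include <cstdio>
#include <type_traits>

template <typename PolicyT, typename = void>
struct Wrapper // primary template: wraps anything, no named accessors
{
  static constexpr const char* kind = "opaque policy";
};

template <typename PolicyT>
struct Wrapper<PolicyT,
               std::void_t<typename PolicyT::LargeSegmentPolicy,
                           typename PolicyT::SmallSegmentPolicy,
                           typename PolicyT::MediumSegmentPolicy>>
{
  static constexpr const char* kind = "static policy with named accessors";
};

struct StaticPolicy
{
  struct LargeSegmentPolicy {};
  struct SmallSegmentPolicy {};
  struct MediumSegmentPolicy {};
};
struct RuntimePolicy {};

int main()
{
  std::printf("%s\n", Wrapper<StaticPolicy>::kind);  // static policy with named accessors
  std::printf("%s\n", Wrapper<RuntimePolicy>::kind); // opaque policy
}
```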
@@ -71,12 +183,19 @@ struct policy_hub

  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(7);
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(7);
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
- BLOCK_THREADS,
- // Small policy
- AgentSubWarpMergeSortPolicy<4 /* Threads per segment */, ITEMS_PER_SMALL_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>,
- // Medium policy
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
+
+ using SmallSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 4 /* Threads per segment */,
+ ITEMS_PER_SMALL_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
+ using MediumSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 32 /* Threads per segment */,
+ ITEMS_PER_MEDIUM_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
  };

  struct Policy600 : ChainedPolicy<600, Policy600, Policy500>
@@ -97,12 +216,19 @@ struct policy_hub

  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(9);
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
- BLOCK_THREADS,
- // Small policy
- AgentSubWarpMergeSortPolicy<4 /* Threads per segment */, ITEMS_PER_SMALL_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>,
- // Medium policy
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
+
+ using SmallSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 4 /* Threads per segment */,
+ ITEMS_PER_SMALL_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
+ using MediumSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 32 /* Threads per segment */,
+ ITEMS_PER_MEDIUM_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
  };

  struct Policy610 : ChainedPolicy<610, Policy610, Policy600>
@@ -123,12 +249,19 @@ struct policy_hub

  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(9);
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
- BLOCK_THREADS,
- // Small policy
- AgentSubWarpMergeSortPolicy<4 /* Threads per segment */, ITEMS_PER_SMALL_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>,
- // Medium policy
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
+
+ using SmallSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 4 /* Threads per segment */,
+ ITEMS_PER_SMALL_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
+ using MediumSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 32 /* Threads per segment */,
+ ITEMS_PER_MEDIUM_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
  };

  struct Policy620 : ChainedPolicy<620, Policy620, Policy610>
@@ -149,12 +282,19 @@ struct policy_hub

  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(9);
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
- BLOCK_THREADS,
- // Small policy
- AgentSubWarpMergeSortPolicy<4 /* Threads per segment */, ITEMS_PER_SMALL_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>,
- // Medium policy
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
+
+ using SmallSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 4 /* Threads per segment */,
+ ITEMS_PER_SMALL_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
+ using MediumSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 32 /* Threads per segment */,
+ ITEMS_PER_MEDIUM_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
  };

  struct Policy700 : ChainedPolicy<700, Policy700, Policy620>
@@ -175,15 +315,19 @@ struct policy_hub

  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(7);
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(KEYS_ONLY ? 11 : 7);
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
- BLOCK_THREADS,
- // Small policy
- AgentSubWarpMergeSortPolicy<KEYS_ONLY ? 4 : 8 /* Threads per segment */,
- ITEMS_PER_SMALL_THREAD,
- WARP_LOAD_DIRECT,
- LOAD_DEFAULT>,
- // Medium policy
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
+
+ using SmallSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ KEYS_ONLY ? 4 : 8 /* Threads per segment */,
+ ITEMS_PER_SMALL_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
+ using MediumSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 32 /* Threads per segment */,
+ ITEMS_PER_MEDIUM_THREAD,
+ WARP_LOAD_DIRECT,
+ LOAD_DEFAULT>;
  };

  struct Policy800 : ChainedPolicy<800, Policy800, Policy700>
@@ -202,15 +346,19 @@ struct policy_hub

  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(KEYS_ONLY ? 7 : 11);
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
- BLOCK_THREADS,
- // Small policy
- AgentSubWarpMergeSortPolicy<KEYS_ONLY ? 4 : 2 /* Threads per segment */,
- ITEMS_PER_SMALL_THREAD,
- WARP_LOAD_TRANSPOSE,
- LOAD_DEFAULT>,
- // Medium policy
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_TRANSPOSE, LOAD_DEFAULT>>;
+
+ using SmallSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ KEYS_ONLY ? 4 : 2 /* Threads per segment */,
+ ITEMS_PER_SMALL_THREAD,
+ WARP_LOAD_TRANSPOSE,
+ LOAD_DEFAULT>;
+ using MediumSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 32 /* Threads per segment */,
+ ITEMS_PER_MEDIUM_THREAD,
+ WARP_LOAD_TRANSPOSE,
+ LOAD_DEFAULT>;
  };

  struct Policy860 : ChainedPolicy<860, Policy860, Policy800>
@@ -230,15 +378,19 @@ struct policy_hub
  static constexpr bool LARGE_ITEMS = sizeof(DominantT) > 4;
  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(LARGE_ITEMS ? 7 : 9);
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(LARGE_ITEMS ? 9 : 7);
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
- BLOCK_THREADS,
- // Small policy
- AgentSubWarpMergeSortPolicy<LARGE_ITEMS ? 8 : 2 /* Threads per segment */,
- ITEMS_PER_SMALL_THREAD,
- WARP_LOAD_TRANSPOSE,
- LOAD_LDG>,
- // Medium policy
- AgentSubWarpMergeSortPolicy<16 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_TRANSPOSE, LOAD_LDG>>;
+
+ using SmallSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ LARGE_ITEMS ? 8 : 2 /* Threads per segment */,
+ ITEMS_PER_SMALL_THREAD,
+ WARP_LOAD_TRANSPOSE,
+ LOAD_LDG>;
+ using MediumSegmentPolicy =
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
+ 16 /* Threads per segment */,
+ ITEMS_PER_MEDIUM_THREAD,
+ WARP_LOAD_TRANSPOSE,
+ LOAD_LDG>;
  };

  using MaxPolicy = Policy860;
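All of the per-architecture policies above scale their tuning constants through Nominal4BItemsToItems<DominantT>(N). A plain-function sketch of what that helper approximately computes; this mirrors the CUB helper only roughly, so consult the header for the exact definition:

```cpp
// Approximate re-derivation of Nominal4BItemsToItems: scale a count tuned for
// 4-byte types by 4 / sizeof(T), clamped to [1, nominal].
#include <cstdio>

constexpr int nominal_4b_items_to_items(int nominal, int sizeof_t)
{
  const int scaled = nominal * 4 / sizeof_t;
  return scaled < 1 ? 1 : (scaled > nominal ? nominal : scaled);
}

int main()
{
  std::printf("4-byte keys:  %d items/thread\n", nominal_4b_items_to_items(9, 4));  // 9
  std::printf("8-byte keys:  %d items/thread\n", nominal_4b_items_to_items(9, 8));  // 4
  std::printf("16-byte keys: %d items/thread\n", nominal_4b_items_to_items(9, 16)); // 2
}
```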
@@ -282,21 +282,45 @@ _CCCL_HOST_DEVICE constexpr int arch_to_min_bytes_in_flight(int sm_arch)
  return 12 * 1024; // V100 and below
  }

- template <typename T, typename... Ts>
- _CCCL_HOST_DEVICE constexpr bool all_equal([[maybe_unused]] T head, Ts... tail)
+ template <typename H, typename... Ts>
+ _CCCL_HOST_DEVICE constexpr bool all_nonzero_equal(H head, Ts... values)
  {
- return ((head == tail) && ...);
+ size_t first = 0;
+ for (size_t v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...})
+ {
+ if (v == 0)
+ {
+ continue;
+ }
+ if (first == 0)
+ {
+ first = v;
+ }
+ else if (v != first)
+ {
+ return false;
+ }
+ }
+ return true;
  }

- _CCCL_HOST_DEVICE constexpr bool all_equal()
+ _CCCL_HOST_DEVICE constexpr bool all_nonzero_equal()
  {
  return true;
  }

- template <typename T, typename... Ts>
- _CCCL_HOST_DEVICE constexpr auto first_item(T head, Ts...) -> T
+ template <typename H, typename... Ts>
+ _CCCL_HOST_DEVICE constexpr auto first_nonzero_value(H head, Ts... values)
  {
- return head;
+ for (auto v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...})
+ {
+ if (v != 0)
+ {
+ return v;
+ }
+ }
+ // we only reach here when all input are not contiguous and the output has a void value type
+ return H{1};
  }

  template <typename T>
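A standalone check of the two helpers' semantics: zero entries (the masked-out, non-contiguous inputs) are skipped, and an all-zero pack falls back to 1. The helpers are re-implemented here host-only, for illustration:

```cpp
// Host-only reimplementation of the two helpers plus checks of the semantics
// relied on above.
#include <cstddef>
#include <initializer_list>

template <typename H, typename... Ts>
constexpr H first_nonzero_value(H head, Ts... values)
{
  for (H v : {head, static_cast<H>(values)...})
  {
    if (v != 0)
    {
      return v;
    }
  }
  return H{1}; // all inputs masked out: any positive size keeps later divisions well-formed
}

template <typename H, typename... Ts>
constexpr bool all_nonzero_equal(H head, Ts... values)
{
  H first = 0;
  for (H v : {head, static_cast<H>(values)...})
  {
    if (v == 0)
    {
      continue;
    }
    if (first == 0)
    {
      first = v;
    }
    else if (v != first)
    {
      return false;
    }
  }
  return true;
}

// sizeof(T) * is_contiguous_iterator_v<It> masks non-contiguous inputs to 0:
static_assert(first_nonzero_value(std::size_t{0}, std::size_t{8}, std::size_t{4}) == 8);
static_assert(first_nonzero_value(std::size_t{0}, std::size_t{0}) == 1);
static_assert(all_nonzero_equal(std::size_t{0}, std::size_t{4}, std::size_t{4}));
static_assert(!all_nonzero_equal(std::size_t{4}, std::size_t{8}));

int main() {}
```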
@@ -336,25 +360,36 @@ struct policy_hub<RequiresStableAddress,
  (THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn> && ...);
  static constexpr bool all_input_values_trivially_reloc =
  (THRUST_NS_QUALIFIER::is_trivially_relocatable_v<it_value_t<RandomAccessIteratorsIn>> && ...);
- static constexpr bool can_memcpy_inputs = all_inputs_contiguous && all_input_values_trivially_reloc;
+ static constexpr bool can_memcpy_all_inputs = all_inputs_contiguous && all_input_values_trivially_reloc;
+ // the vectorized kernel supports mixing contiguous and non-contiguous iterators
+ static constexpr bool can_memcpy_contiguous_inputs =
+ ((!THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>
+ || THRUST_NS_QUALIFIER::is_trivially_relocatable_v<it_value_t<RandomAccessIteratorsIn>>)
+ && ...);

  // for vectorized policy:
- static constexpr bool all_input_values_same_size = all_equal(sizeof(it_value_t<RandomAccessIteratorsIn>)...);
- static constexpr int load_store_word_size = 8; // TODO(bgruber): make this 16, and 32 on Blackwell+
- // if there are no inputs, we take the size of the output value
- static constexpr int value_type_size =
- first_item(int{sizeof(it_value_t<RandomAccessIteratorsIn>)}..., int{size_of<it_value_t<RandomAccessIteratorOut>>});
+ static constexpr bool all_contiguous_input_values_same_size = all_nonzero_equal(
+ (sizeof(it_value_t<RandomAccessIteratorsIn>)
+ * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...);
+ static constexpr int load_store_word_size = 8; // TODO(bgruber): make this 16, and 32 on Blackwell+
+ // find the value type size of the first contiguous iterator. if there are no inputs, we take the size of the output
+ // value type
+ static constexpr int contiguous_value_type_size = first_nonzero_value(
+ (int{sizeof(it_value_t<RandomAccessIteratorsIn>)}
+ * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...,
+ int{size_of<it_value_t<RandomAccessIteratorOut>>});
  static constexpr bool value_type_divides_load_store_size =
- load_store_word_size % value_type_size == 0; // implicitly checks that value_type_size <= load_store_word_size
+ load_store_word_size % contiguous_value_type_size == 0; // implicitly checks that value_type_size <=
+ // load_store_word_size
  static constexpr int target_bytes_per_thread =
  no_input_streams ? 16 /* by experiment on RTX 5090 */ : 32 /* guestimate by gevtushenko for loading */;
  static constexpr int items_per_thread_vec =
- ::cuda::round_up(target_bytes_per_thread, load_store_word_size) / value_type_size;
+ ::cuda::round_up(target_bytes_per_thread, load_store_word_size) / contiguous_value_type_size;
  using default_vectorized_policy_t = vectorized_policy_t<256, items_per_thread_vec, load_store_word_size>;

  static constexpr bool fallback_to_prefetch =
- RequiresStableAddress || !can_memcpy_inputs || !all_input_values_same_size || !value_type_divides_load_store_size
- || !DenseOutput;
+ RequiresStableAddress || !can_memcpy_contiguous_inputs || !all_contiguous_input_values_same_size
+ || !value_type_divides_load_store_size || !DenseOutput;

  // TODO(bgruber): consider a separate kernel for just filling

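The fallback chain now keys off the contiguous-only predicates. A compile-time sketch of the resulting kernel choice, with the gating flags reduced to hand-set constexpr bools (values are illustrative, not derived from real iterator types):

```cpp
// Flip any flag to see which kernel the policy hub would fall back to.
#include <cstdio>

int main()
{
  constexpr bool requires_stable_address               = false;
  constexpr bool can_memcpy_contiguous_inputs          = true; // contiguous inputs are trivially relocatable
  constexpr bool all_contiguous_input_values_same_size = true;
  constexpr bool value_type_divides_load_store_size    = true; // e.g. 4-byte values into 8-byte words
  constexpr bool dense_output                          = true;

  constexpr bool fallback_to_prefetch =
    requires_stable_address || !can_memcpy_contiguous_inputs || !all_contiguous_input_values_same_size
    || !value_type_divides_load_store_size || !dense_output;

  std::printf("selected kernel: %s\n", fallback_to_prefetch ? "prefetch" : "vectorized");
}
```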
@@ -380,7 +415,7 @@ struct policy_hub<RequiresStableAddress,
  block_threads* async_policy::min_items_per_thread,
  ldgsts_size_and_align)
  > int{max_smem_per_block};
- static constexpr bool fallback_to_vectorized = exhaust_smem || no_input_streams;
+ static constexpr bool fallback_to_vectorized = exhaust_smem || no_input_streams || !can_memcpy_all_inputs;

  public:
  static constexpr int min_bif = arch_to_min_bytes_in_flight(800);
@@ -421,7 +456,8 @@ struct policy_hub<RequiresStableAddress,
  (((int{sizeof(it_value_t<RandomAccessIteratorsIn>)} * AsyncBlockSize) % max_alignment == 0) && ...);
  static constexpr bool enough_threads_for_peeling = AsyncBlockSize >= alignment; // head and tail bytes
  static constexpr bool fallback_to_vectorized =
- exhaust_smem || !tile_sizes_retain_alignment || !enough_threads_for_peeling || no_input_streams;
+ exhaust_smem || !tile_sizes_retain_alignment || !enough_threads_for_peeling || no_input_streams
+ || !can_memcpy_all_inputs;

  public:
  static constexpr int min_bif = arch_to_min_bytes_in_flight(PtxVersion);
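Both async fallbacks now also require can_memcpy_all_inputs, i.e. every input must be raw, trivially relocatable memory, while the vectorized kernel only requires that of the contiguous inputs. A sketch of the gate's shape with simplified stand-ins for the Thrust traits:

```cpp
// Simplified stand-ins for the Thrust traits, just to show the shape of the gate.
#include <iterator>
#include <type_traits>

template <typename It>
inline constexpr bool is_contiguous_iterator_v = std::is_pointer_v<It>; // real trait also covers vector iterators etc.

template <typename T>
inline constexpr bool is_trivially_relocatable_v = std::is_trivially_copyable_v<T>; // approximation

template <typename... InIts>
inline constexpr bool can_memcpy_all_inputs =
  (is_contiguous_iterator_v<InIts> && ...)
  && (is_trivially_relocatable_v<typename std::iterator_traits<InIts>::value_type> && ...);

static_assert(can_memcpy_all_inputs<const float*, const int*>);
static_assert(!can_memcpy_all_inputs<const float*, std::reverse_iterator<const int*>>);

int main() {}
```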