cuda-cccl 0.1.3.2.0.dev438__cp312-cp312-manylinux_2_24_aarch64.whl → 0.3.0__cp312-cp312-manylinux_2_24_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (60)
  1. cuda/cccl/headers/include/cub/agent/agent_radix_sort_downsweep.cuh +23 -0
  2. cuda/cccl/headers/include/cub/agent/agent_sub_warp_merge_sort.cuh +22 -14
  3. cuda/cccl/headers/include/cub/block/block_load_to_shared.cuh +432 -0
  4. cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +18 -26
  5. cuda/cccl/headers/include/cub/device/device_copy.cuh +116 -27
  6. cuda/cccl/headers/include/cub/device/device_partition.cuh +5 -1
  7. cuda/cccl/headers/include/cub/device/dispatch/dispatch_copy_mdspan.cuh +79 -0
  8. cuda/cccl/headers/include/cub/device/dispatch/dispatch_segmented_sort.cuh +321 -262
  9. cuda/cccl/headers/include/cub/device/dispatch/kernels/reduce.cuh +8 -0
  10. cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_sort.cuh +57 -10
  11. cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +37 -13
  12. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_segmented_sort.cuh +203 -51
  13. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +55 -19
  14. cuda/cccl/headers/include/cub/util_device.cuh +51 -35
  15. cuda/cccl/headers/include/cuda/__algorithm/copy.h +3 -3
  16. cuda/cccl/headers/include/cuda/__device/all_devices.h +3 -6
  17. cuda/cccl/headers/include/cuda/__device/arch_traits.h +3 -3
  18. cuda/cccl/headers/include/cuda/__device/attributes.h +7 -7
  19. cuda/cccl/headers/include/cuda/__device/device_ref.h +3 -10
  20. cuda/cccl/headers/include/cuda/__driver/driver_api.h +225 -33
  21. cuda/cccl/headers/include/cuda/__event/event.h +7 -8
  22. cuda/cccl/headers/include/cuda/__event/event_ref.h +4 -5
  23. cuda/cccl/headers/include/cuda/__event/timed_event.h +3 -4
  24. cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +3 -3
  25. cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +3 -3
  26. cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +3 -3
  27. cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +14 -10
  28. cuda/cccl/headers/include/cuda/__runtime/types.h +1 -1
  29. cuda/cccl/headers/include/cuda/__stream/stream.h +2 -3
  30. cuda/cccl/headers/include/cuda/__stream/stream_ref.h +17 -12
  31. cuda/cccl/headers/include/cuda/__utility/__basic_any/virtual_tables.h +2 -2
  32. cuda/cccl/headers/include/cuda/std/__cccl/cuda_capabilities.h +2 -2
  33. cuda/cccl/headers/include/cuda/std/__cccl/preprocessor.h +2 -0
  34. cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +10 -5
  35. cuda/cccl/headers/include/cuda/std/__cmath/min_max.h +44 -17
  36. cuda/cccl/headers/include/cuda/std/__concepts/constructible.h +1 -1
  37. cuda/cccl/headers/include/cuda/std/__cuda/api_wrapper.h +12 -12
  38. cuda/cccl/headers/include/cuda/std/__exception/cuda_error.h +1 -8
  39. cuda/cccl/headers/include/cuda/std/__floating_point/cast.h +15 -12
  40. cuda/cccl/headers/include/cuda/std/__floating_point/cuda_fp_types.h +3 -0
  41. cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +2 -1
  42. cuda/cccl/headers/include/cuda/std/__type_traits/promote.h +3 -2
  43. cuda/cccl/headers/include/thrust/system/cuda/detail/tabulate.h +8 -22
  44. cuda/cccl/headers/include/thrust/type_traits/unwrap_contiguous_iterator.h +15 -48
  45. cuda/cccl/parallel/experimental/__init__.py +4 -0
  46. cuda/cccl/parallel/experimental/_bindings.pyi +28 -0
  47. cuda/cccl/parallel/experimental/_bindings_impl.pyx +140 -0
  48. cuda/cccl/parallel/experimental/algorithms/__init__.py +4 -0
  49. cuda/cccl/parallel/experimental/algorithms/_reduce.py +0 -2
  50. cuda/cccl/parallel/experimental/algorithms/_scan.py +0 -2
  51. cuda/cccl/parallel/experimental/algorithms/_three_way_partition.py +261 -0
  52. cuda/cccl/parallel/experimental/cu12/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
  53. cuda/cccl/parallel/experimental/cu12/cccl/libcccl.c.parallel.so +0 -0
  54. cuda/cccl/parallel/experimental/cu13/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
  55. cuda/cccl/parallel/experimental/cu13/cccl/libcccl.c.parallel.so +0 -0
  56. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/METADATA +1 -1
  57. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/RECORD +59 -57
  58. cuda/cccl/headers/include/cuda/std/__cuda/ensure_current_device.h +0 -72
  59. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/WHEEL +0 -0
  60. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/licenses/LICENSE +0 -0
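
The hunks below are from item 8, cuda/cccl/headers/include/cub/device/dispatch/dispatch_segmented_sort.cuh. Their recurring theme is indirection: compile-time policy types become a runtime-queryable wrapped policy, hard-coded triple_chevron launches go through a KernelLauncherFactory, and kernel templates are fetched from a KernelSource, so that the c.parallel bindings can substitute type-erased implementations (per the comments added in the diff itself). As rough orientation, a minimal launcher-factory sketch might look like the following; the names and bodies here are illustrative assumptions, not the CCCL implementation, which also exposes PtxVersion, MemcpyAsync, and occupancy queries.

    // Hypothetical triple_chevron-style launcher factory (sketch only).
    #include <cuda_runtime.h>

    struct launcher
    {
      dim3 grid;
      dim3 block;
      size_t shmem;
      cudaStream_t stream;

      // Launch `kernel` with the stored configuration; mirrors the
      // launcher_factory(...).doit(...) call shape used in the diff.
      template <typename Kernel, typename... Args>
      cudaError_t doit(Kernel kernel, Args... args) const
      {
        kernel<<<grid, block, shmem, stream>>>(args...);
        return cudaGetLastError();
      }
    };

    struct launcher_factory
    {
      launcher operator()(dim3 grid, dim3 block, size_t shmem, cudaStream_t stream) const
      {
        return launcher{grid, block, shmem, stream};
      }
    };

    __global__ void fill_kernel(int* data, int value, int n)
    {
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n)
      {
        data[i] = value;
      }
    }

    int main()
    {
      int* d_data = nullptr;
      cudaMalloc(&d_data, 256 * sizeof(int));
      // A dispatcher parameterized on the factory type never names <<<>>> itself,
      // so a caller could substitute a driver-API-backed factory instead.
      launcher_factory factory;
      cudaError_t status = factory(1, 256, 0, cudaStream_t{}).doit(fill_kernel, d_data, 42, 256);
      cudaFree(d_data);
      return status == cudaSuccess ? 0 : 1;
    }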
@@ -40,6 +40,7 @@
 #include <cub/detail/device_double_buffer.cuh>
 #include <cub/detail/temporary_storage.cuh>
 #include <cub/device/device_partition.cuh>
+#include <cub/device/dispatch/dispatch_advance_iterators.cuh>
 #include <cub/device/dispatch/kernels/segmented_sort.cuh>
 #include <cub/device/dispatch/tuning/tuning_segmented_sort.cuh>
 #include <cub/util_debug.cuh>
@@ -69,14 +70,14 @@ namespace detail::segmented_sort
  * of this stage is required to eliminate device-side synchronization in
  * the CDP mode.
  */
-template <typename LargeSegmentPolicyT,
-          typename SmallAndMediumPolicyT,
+template <typename WrappedPolicyT,
           typename LargeKernelT,
           typename SmallKernelT,
           typename KeyT,
           typename ValueT,
           typename BeginOffsetIteratorT,
-          typename EndOffsetIteratorT>
+          typename EndOffsetIteratorT,
+          typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY>
 CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortContinuation(
   LargeKernelT large_kernel,
   SmallKernelT small_kernel,
@@ -92,7 +93,9 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
   local_segment_index_t* group_sizes,
   local_segment_index_t* large_and_medium_segments_indices,
   local_segment_index_t* small_segments_indices,
-  cudaStream_t stream)
+  cudaStream_t stream,
+  KernelLauncherFactory launcher_factory,
+  WrappedPolicyT wrapped_policy)
 {
   using local_segment_index_t = local_segment_index_t;
 
@@ -109,11 +112,11 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
   _CubLog("Invoking "
           "DeviceSegmentedSortKernelLarge<<<%d, %d, 0, %lld>>>()\n",
           static_cast<int>(blocks_in_grid),
-          LargeSegmentPolicyT::BLOCK_THREADS,
+          wrapped_policy.LargeSegment().BlockThreads(),
           (long long) stream);
 #endif // CUB_DEBUG_LOG
 
-  THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(blocks_in_grid, LargeSegmentPolicyT::BLOCK_THREADS, 0, stream)
+  launcher_factory(blocks_in_grid, wrapped_policy.LargeSegment().BlockThreads(), 0, stream)
     .doit(large_kernel,
           large_and_medium_segments_indices,
           d_current_keys,
@@ -144,11 +147,10 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
   const local_segment_index_t medium_segments =
     static_cast<local_segment_index_t>(num_segments) - (large_segments + small_segments);
 
-  const local_segment_index_t small_blocks =
-    ::cuda::ceil_div(small_segments, SmallAndMediumPolicyT::SEGMENTS_PER_SMALL_BLOCK);
+  const local_segment_index_t small_blocks = ::cuda::ceil_div(small_segments, wrapped_policy.SegmentsPerSmallBlock());
 
   const local_segment_index_t medium_blocks =
-    ::cuda::ceil_div(medium_segments, SmallAndMediumPolicyT::SEGMENTS_PER_MEDIUM_BLOCK);
+    ::cuda::ceil_div(medium_segments, wrapped_policy.SegmentsPerMediumBlock());
 
   const local_segment_index_t small_and_medium_blocks_in_grid = small_blocks + medium_blocks;
 
@@ -158,12 +160,11 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
   _CubLog("Invoking "
           "DeviceSegmentedSortKernelSmall<<<%d, %d, 0, %lld>>>()\n",
           static_cast<int>(small_and_medium_blocks_in_grid),
-          SmallAndMediumPolicyT::BLOCK_THREADS,
+          wrapped_policy.SmallSegment().BlockThreads(),
           (long long) stream);
 #endif // CUB_DEBUG_LOG
 
-  THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(
-    small_and_medium_blocks_in_grid, SmallAndMediumPolicyT::BLOCK_THREADS, 0, stream)
+  launcher_factory(small_and_medium_blocks_in_grid, wrapped_policy.SmallSegment().BlockThreads(), 0, stream)
     .doit(small_kernel,
           small_segments,
           medium_segments,
@@ -200,13 +201,14 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
  * Continuation kernel is used only in the CDP mode. It's used to
  * launch DeviceSegmentedSortContinuation as a separate kernel.
  */
-template <typename ChainedPolicyT,
+template <typename WrappedPolicyT,
           typename LargeKernelT,
           typename SmallKernelT,
           typename KeyT,
           typename ValueT,
           typename BeginOffsetIteratorT,
-          typename EndOffsetIteratorT>
+          typename EndOffsetIteratorT,
+          typename KernelLauncherFactory>
 __launch_bounds__(1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortContinuationKernel(
   LargeKernelT large_kernel,
   SmallKernelT small_kernel,
@@ -221,12 +223,10 @@ __launch_bounds__(1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortContin
   EndOffsetIteratorT d_end_offsets,
   local_segment_index_t* group_sizes,
   local_segment_index_t* large_and_medium_segments_indices,
-  local_segment_index_t* small_segments_indices)
+  local_segment_index_t* small_segments_indices,
+  KernelLauncherFactory launcher_factory,
+  WrappedPolicyT wrapped_policy)
 {
-  using ActivePolicyT = typename ChainedPolicyT::ActivePolicy;
-  using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;
-  using SmallAndMediumPolicyT = typename ActivePolicyT::SmallAndMediumSegmentedSortPolicyT;
-
   // In case of CDP:
   // 1. each CTA has a different main stream
   // 2. all streams are non-blocking
@@ -236,86 +236,119 @@ __launch_bounds__(1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortContin
   //
   // Due to (4, 5), we can't pass the user-provided stream in the continuation.
   // Due to (1, 2, 3) it's safe to pass the main stream.
-  cudaError_t error =
-    detail::segmented_sort::DeviceSegmentedSortContinuation<LargeSegmentPolicyT, SmallAndMediumPolicyT>(
-      large_kernel,
-      small_kernel,
-      num_segments,
-      d_current_keys,
-      d_final_keys,
-      d_keys_double_buffer,
-      d_current_values,
-      d_final_values,
-      d_values_double_buffer,
-      d_begin_offsets,
-      d_end_offsets,
-      group_sizes,
-      large_and_medium_segments_indices,
-      small_segments_indices,
-      0); // always launching on the main stream (see motivation above)
+  cudaError_t error = detail::segmented_sort::DeviceSegmentedSortContinuation<WrappedPolicyT>(
+    large_kernel,
+    small_kernel,
+    num_segments,
+    d_current_keys,
+    d_final_keys,
+    d_keys_double_buffer,
+    d_current_values,
+    d_final_values,
+    d_values_double_buffer,
+    d_begin_offsets,
+    d_end_offsets,
+    group_sizes,
+    large_and_medium_segments_indices,
+    small_segments_indices,
+    0, // always launching on the main stream (see motivation above)
+    launcher_factory,
+    wrapped_policy);
 
   error = CubDebug(error);
 }
 #endif // CUB_RDC_ENABLED
-} // namespace detail::segmented_sort
-
-template <SortOrder Order,
+template <typename MaxPolicyT,
+          SortOrder Order,
           typename KeyT,
           typename ValueT,
-          typename OffsetT,
           typename BeginOffsetIteratorT,
           typename EndOffsetIteratorT,
-          typename PolicyHub = detail::segmented_sort::policy_hub<KeyT, ValueT>>
-struct DispatchSegmentedSort
+          typename OffsetT>
+struct DeviceSegmentedSortKernelSource
 {
-  using local_segment_index_t = detail::segmented_sort::local_segment_index_t;
-  using global_segment_offset_t = detail::segmented_sort::global_segment_offset_t;
+  CUB_DEFINE_KERNEL_GETTER(
+    SegmentedSortFallbackKernel,
+    DeviceSegmentedSortFallbackKernel<Order, MaxPolicyT, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT>);
 
-  static constexpr int KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;
+  CUB_DEFINE_KERNEL_GETTER(
+    SegmentedSortKernelSmall,
+    DeviceSegmentedSortKernelSmall<Order, MaxPolicyT, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT>);
+
+  CUB_DEFINE_KERNEL_GETTER(
+    SegmentedSortKernelLarge,
+    DeviceSegmentedSortKernelLarge<Order, MaxPolicyT, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT>);
 
-  struct LargeSegmentsSelectorT
+  CUB_RUNTIME_FUNCTION static constexpr size_t KeySize()
   {
-    OffsetT value{};
-    BeginOffsetIteratorT d_offset_begin{};
-    EndOffsetIteratorT d_offset_end{};
-    global_segment_offset_t base_segment_offset{};
-
-    _CCCL_HOST_DEVICE _CCCL_FORCEINLINE
-    LargeSegmentsSelectorT(OffsetT value, BeginOffsetIteratorT d_offset_begin, EndOffsetIteratorT d_offset_end)
-        : value(value)
-        , d_offset_begin(d_offset_begin)
-        , d_offset_end(d_offset_end)
-    {}
-
-    _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator()(local_segment_index_t segment_id) const
-    {
-      const OffsetT segment_size =
-        d_offset_end[base_segment_offset + segment_id] - d_offset_begin[base_segment_offset + segment_id];
-      return segment_size > value;
-    }
-  };
+    return sizeof(KeyT);
+  }
 
-  struct SmallSegmentsSelectorT
+  using LargeSegmentsSelectorT =
+    cub::detail::segmented_sort::LargeSegmentsSelectorT<OffsetT, BeginOffsetIteratorT, EndOffsetIteratorT>;
+  using SmallSegmentsSelectorT =
+    cub::detail::segmented_sort::SmallSegmentsSelectorT<OffsetT, BeginOffsetIteratorT, EndOffsetIteratorT>;
+
+  CUB_RUNTIME_FUNCTION static constexpr auto LargeSegmentsSelector(
+    OffsetT offset, BeginOffsetIteratorT begin_offset_iterator, EndOffsetIteratorT end_offset_iterator)
   {
-    OffsetT value{};
-    BeginOffsetIteratorT d_offset_begin{};
-    EndOffsetIteratorT d_offset_end{};
-    global_segment_offset_t base_segment_offset{};
-
-    _CCCL_HOST_DEVICE _CCCL_FORCEINLINE
-    SmallSegmentsSelectorT(OffsetT value, BeginOffsetIteratorT d_offset_begin, EndOffsetIteratorT d_offset_end)
-        : value(value)
-        , d_offset_begin(d_offset_begin)
-        , d_offset_end(d_offset_end)
-    {}
-
-    _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator()(local_segment_index_t segment_id) const
-    {
-      const OffsetT segment_size =
-        d_offset_end[base_segment_offset + segment_id] - d_offset_begin[base_segment_offset + segment_id];
-      return segment_size < value;
-    }
-  };
+    return LargeSegmentsSelectorT(offset, begin_offset_iterator, end_offset_iterator);
+  }
+
+  CUB_RUNTIME_FUNCTION static constexpr auto SmallSegmentsSelector(
+    OffsetT offset, BeginOffsetIteratorT begin_offset_iterator, EndOffsetIteratorT end_offset_iterator)
+  {
+    return SmallSegmentsSelectorT(offset, begin_offset_iterator, end_offset_iterator);
+  }
+
+  template <typename SelectorT>
+  CUB_RUNTIME_FUNCTION static constexpr void
+  SetSegmentOffset(SelectorT& selector, global_segment_offset_t base_segment_offset)
+  {
+    selector.base_segment_offset = base_segment_offset;
+  }
+};
+} // namespace detail::segmented_sort
+
+template <
+  SortOrder Order,
+  typename KeyT,
+  typename ValueT,
+  typename OffsetT,
+  typename BeginOffsetIteratorT,
+  typename EndOffsetIteratorT,
+  typename PolicyHub = detail::segmented_sort::policy_hub<KeyT, ValueT>,
+  typename KernelSource = detail::segmented_sort::DeviceSegmentedSortKernelSource<
+    typename PolicyHub::MaxPolicy,
+    Order,
+    KeyT,
+    ValueT,
+    BeginOffsetIteratorT,
+    EndOffsetIteratorT,
+    OffsetT>,
+  typename PartitionPolicyHub = detail::three_way_partition::policy_hub<
+    cub::detail::it_value_t<THRUST_NS_QUALIFIER::counting_iterator<cub::detail::segmented_sort::local_segment_index_t>>,
+    detail::three_way_partition::per_partition_offset_t>,
+  typename PartitionKernelSource = detail::three_way_partition::DeviceThreeWayPartitionKernelSource<
+    typename PartitionPolicyHub::MaxPolicy,
+    THRUST_NS_QUALIFIER::counting_iterator<cub::detail::segmented_sort::local_segment_index_t>,
+    cub::detail::segmented_sort::local_segment_index_t*,
+    cub::detail::segmented_sort::local_segment_index_t*,
+    ::cuda::std::reverse_iterator<cub::detail::segmented_sort::local_segment_index_t*>,
+    cub::detail::segmented_sort::local_segment_index_t*,
+    detail::three_way_partition::ScanTileStateT,
+    cub::detail::segmented_sort::LargeSegmentsSelectorT<OffsetT, BeginOffsetIteratorT, EndOffsetIteratorT>,
+    cub::detail::segmented_sort::SmallSegmentsSelectorT<OffsetT, BeginOffsetIteratorT, EndOffsetIteratorT>,
+    detail::three_way_partition::per_partition_offset_t,
+    detail::three_way_partition::streaming_context_t<cub::detail::segmented_sort::global_segment_offset_t>,
+    detail::choose_signed_offset<cub::detail::segmented_sort::global_segment_offset_t>::type>,
+  typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY>
+struct DispatchSegmentedSort
+{
+  using local_segment_index_t = detail::segmented_sort::local_segment_index_t;
+  using global_segment_offset_t = detail::segmented_sort::global_segment_offset_t;
+
+  static constexpr int KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;
 
   // Partition selects large and small groups. The middle group is not selected.
   static constexpr size_t num_selected_groups = 2;
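
The DeviceSegmentedSortKernelSource hunk above bundles the kernel template instantiations (plus KeySize() and the selector factories) behind one object that the dispatcher receives as a parameter. A hedged toy version of the getter idea follows; the macro name and expansion here are assumptions for illustration, not CUB's actual CUB_DEFINE_KERNEL_GETTER definition.

    #include <cuda_runtime.h>

    // Assumed expansion: turn a kernel instantiation into a static getter.
    #define DEMO_DEFINE_KERNEL_GETTER(name, ...) \
      static constexpr auto name()               \
      {                                          \
        return &__VA_ARGS__;                     \
      }

    template <typename T>
    __global__ void demo_kernel(T* out, T value)
    {
      *out = value;
    }

    template <typename T>
    struct demo_kernel_source
    {
      // The dispatcher asks the source for a kernel pointer instead of naming
      // the template itself, so a c.parallel-style caller can swap the source.
      DEMO_DEFINE_KERNEL_GETTER(DemoKernel, demo_kernel<T>)
    };

    int main()
    {
      int* d_out = nullptr;
      cudaMalloc(&d_out, sizeof(int));
      auto kernel = demo_kernel_source<int>::DemoKernel();
      kernel<<<1, 1>>>(d_out, 7); // launch through the retrieved pointer
      const cudaError_t status = cudaDeviceSynchronize();
      cudaFree(d_out);
      return status == cudaSuccess ? 0 : 1;
    }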
@@ -370,48 +403,33 @@ struct DispatchSegmentedSort
   /// CUDA stream to launch kernels within.
   cudaStream_t stream;
 
-  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE DispatchSegmentedSort(
-    void* d_temp_storage,
-    size_t& temp_storage_bytes,
-    DoubleBuffer<KeyT>& d_keys,
-    DoubleBuffer<ValueT>& d_values,
-    ::cuda::std::int64_t num_items,
-    global_segment_offset_t num_segments,
-    BeginOffsetIteratorT d_begin_offsets,
-    EndOffsetIteratorT d_end_offsets,
-    bool is_overwrite_okay,
-    cudaStream_t stream)
-      : d_temp_storage(d_temp_storage)
-      , temp_storage_bytes(temp_storage_bytes)
-      , d_keys(d_keys)
-      , d_values(d_values)
-      , num_items(num_items)
-      , num_segments(num_segments)
-      , d_begin_offsets(d_begin_offsets)
-      , d_end_offsets(d_end_offsets)
-      , is_overwrite_okay(is_overwrite_okay)
-      , stream(stream)
-  {}
+  KernelSource kernel_source;
+
+  PartitionKernelSource partition_kernel_source;
+
+  KernelLauncherFactory launcher_factory;
+
+  typename PartitionPolicyHub::MaxPolicy partition_max_policy;
 
   template <typename ActivePolicyT>
-  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke()
+  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT policy = {})
   {
-    using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;
-    using SmallAndMediumPolicyT = typename ActivePolicyT::SmallAndMediumSegmentedSortPolicyT;
+    auto wrapped_policy = detail::segmented_sort::MakeSegmentedSortPolicyWrapper(policy);
 
-    static_assert(LargeSegmentPolicyT::LOAD_MODIFIER != CacheLoadModifier::LOAD_LDG,
-                  "The memory consistency model does not apply to texture accesses");
+    CUB_DETAIL_STATIC_ISH_ASSERT(wrapped_policy.LargeSegmentLoadModifier() != CacheLoadModifier::LOAD_LDG,
+                                 "The memory consistency model does not apply to texture accesses");
 
-    static_assert(KEYS_ONLY || LargeSegmentPolicyT::LOAD_ALGORITHM != BLOCK_LOAD_STRIPED
-                    || SmallAndMediumPolicyT::MediumPolicyT::LOAD_ALGORITHM != WARP_LOAD_STRIPED
-                    || SmallAndMediumPolicyT::SmallPolicyT::LOAD_ALGORITHM != WARP_LOAD_STRIPED,
-                  "Striped load will make this algorithm unstable");
+    CUB_DETAIL_STATIC_ISH_ASSERT(
+      KEYS_ONLY || wrapped_policy.LargeSegmentLoadAlgorithm() != BLOCK_LOAD_STRIPED
+        || wrapped_policy.MediumSegmentLoadAlgorithm() != WARP_LOAD_STRIPED
+        || wrapped_policy.SmallSegmentLoadAlgorithm() != WARP_LOAD_STRIPED,
+      "Striped load will make this algorithm unstable");
 
-    static_assert(SmallAndMediumPolicyT::MediumPolicyT::STORE_ALGORITHM != WARP_STORE_STRIPED
-                    || SmallAndMediumPolicyT::SmallPolicyT::STORE_ALGORITHM != WARP_STORE_STRIPED,
-                  "Striped stores will produce unsorted results");
+    CUB_DETAIL_STATIC_ISH_ASSERT(wrapped_policy.MediumSegmentStoreAlgorithm() != WARP_STORE_STRIPED
+                                   || wrapped_policy.SmallSegmentStoreAlgorithm() != WARP_STORE_STRIPED,
+                                 "Striped stores will produce unsorted results");
 
-    constexpr int radix_bits = LargeSegmentPolicyT::RADIX_BITS;
+    const int radix_bits = wrapped_policy.LargeSegmentRadixBits();
 
     cudaError error = cudaSuccess;
 
@@ -421,7 +439,7 @@ struct DispatchSegmentedSort
     // Prepare temporary storage layout
     //------------------------------------------------------------------------
 
-    const bool partition_segments = num_segments > ActivePolicyT::PARTITIONING_THRESHOLD;
+    const bool partition_segments = num_segments > wrapped_policy.PartitioningThreshold();
 
     cub::detail::temporary_storage::layout<5> temporary_storage_layout;
 
@@ -451,11 +469,10 @@ struct DispatchSegmentedSort
 
     size_t three_way_partition_temp_storage_bytes{};
 
-    LargeSegmentsSelectorT large_segments_selector(
-      SmallAndMediumPolicyT::MediumPolicyT::ITEMS_PER_TILE, d_begin_offsets, d_end_offsets);
-
-    SmallSegmentsSelectorT small_segments_selector(
-      SmallAndMediumPolicyT::SmallPolicyT::ITEMS_PER_TILE + 1, d_begin_offsets, d_end_offsets);
+    auto large_segments_selector =
+      kernel_source.LargeSegmentsSelector(wrapped_policy.MediumPolicyItemsPerTile(), d_begin_offsets, d_end_offsets);
+    auto small_segments_selector = kernel_source.SmallSegmentsSelector(
+      wrapped_policy.SmallPolicyItemsPerTile() + 1, d_begin_offsets, d_end_offsets);
 
     auto device_partition_temp_storage = keys_slot->create_alias<uint8_t>();
 
@@ -472,7 +489,32 @@ struct DispatchSegmentedSort
 
       auto medium_indices_iterator = ::cuda::std::make_reverse_iterator(large_and_medium_segments_indices.get());
 
-      cub::DevicePartition::IfNoNVTX(
+      // We call partition through dispatch instead of device because c.parallel needs to be able to call the kernel.
+      // This approach propagates the type erasure to partition.
+      using ChooseOffsetT = detail::choose_signed_offset<global_segment_offset_t>;
+      using PartitionOffsetT = typename ChooseOffsetT::type;
+      using DispatchThreeWayPartitionIfT = cub::DispatchThreeWayPartitionIf<
+        THRUST_NS_QUALIFIER::counting_iterator<local_segment_index_t>,
+        decltype(large_and_medium_segments_indices.get()),
+        decltype(small_segments_indices.get()),
+        decltype(medium_indices_iterator),
+        decltype(group_sizes.get()),
+        decltype(large_segments_selector),
+        decltype(small_segments_selector),
+        PartitionOffsetT,
+        PartitionPolicyHub,
+        PartitionKernelSource,
+        KernelLauncherFactory>;
+
+      // Signed integer type for global offsets
+      // Check if the number of items exceeds the range covered by the selected signed offset type
+      error = ChooseOffsetT::is_exceeding_offset_type(num_items);
+      if (error)
+      {
+        return error;
+      }
+
+      DispatchThreeWayPartitionIfT::Dispatch(
         nullptr,
         three_way_partition_temp_storage_bytes,
         THRUST_NS_QUALIFIER::counting_iterator<local_segment_index_t>(0),
@@ -480,10 +522,13 @@ struct DispatchSegmentedSort
         small_segments_indices.get(),
         medium_indices_iterator,
         group_sizes.get(),
-        max_num_segments_per_invocation,
         large_segments_selector,
         small_segments_selector,
-        stream);
+        max_num_segments_per_invocation,
+        stream,
+        partition_kernel_source,
+        launcher_factory,
+        partition_max_policy);
 
       device_partition_temp_storage.grow(three_way_partition_temp_storage_bytes);
     }
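
This Dispatch call (and the matching one in SortWithPartitioning further down) now guards the item count with choose_signed_offset<...>::is_exceeding_offset_type(num_items) before dispatching. The guard's job, sketched below with stand-in types (not CUB's detail implementation; the type-selection rule and error code here are assumptions):

    #include <cstdint>
    #include <limits>
    #include <type_traits>
    #include <cuda_runtime.h>

    template <typename GlobalOffsetT>
    struct choose_signed_offset_demo
    {
      // Prefer a 32-bit signed offset when the global offset type fits in it.
      using type = std::conditional_t<(sizeof(GlobalOffsetT) <= 4), std::int32_t, std::int64_t>;

      // Refuse dispatch when num_items cannot be represented by `type`.
      static cudaError_t is_exceeding_offset_type(std::int64_t num_items)
      {
        return num_items > std::numeric_limits<type>::max() ? cudaErrorInvalidValue : cudaSuccess;
      }
    };

    static_assert(std::is_same_v<choose_signed_offset_demo<std::int64_t>::type, std::int64_t>);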
@@ -573,29 +618,13 @@ struct DispatchSegmentedSort
                          : (is_num_passes_odd) ? values_allocation.get()
                                                : d_values.Alternate());
 
-    using MaxPolicyT = typename PolicyHub::MaxPolicy;
-
     if (partition_segments)
     {
       // Partition input segments into size groups and assign specialized
       // kernels for each of them.
-      error = SortWithPartitioning<LargeSegmentPolicyT, SmallAndMediumPolicyT>(
-        detail::segmented_sort::DeviceSegmentedSortKernelLarge<
-          Order,
-          MaxPolicyT,
-          KeyT,
-          ValueT,
-          BeginOffsetIteratorT,
-          EndOffsetIteratorT,
-          OffsetT>,
-        detail::segmented_sort::DeviceSegmentedSortKernelSmall<
-          Order,
-          MaxPolicyT,
-          KeyT,
-          ValueT,
-          BeginOffsetIteratorT,
-          EndOffsetIteratorT,
-          OffsetT>,
+      error = SortWithPartitioning(
+        kernel_source.SegmentedSortKernelLarge(),
+        kernel_source.SegmentedSortKernelSmall(),
         three_way_partition_temp_storage_bytes,
         d_keys_double_buffer,
         d_values_double_buffer,
@@ -604,24 +633,16 @@ struct DispatchSegmentedSort
         device_partition_temp_storage,
         large_and_medium_segments_indices,
         small_segments_indices,
-        group_sizes);
+        group_sizes,
+        wrapped_policy);
     }
     else
     {
       // If there are not enough segments, there's no reason to spend time
       // on extra partitioning steps.
 
-      error = SortWithoutPartitioning<LargeSegmentPolicyT>(
-        detail::segmented_sort::DeviceSegmentedSortFallbackKernel<
-          Order,
-          MaxPolicyT,
-          KeyT,
-          ValueT,
-          BeginOffsetIteratorT,
-          EndOffsetIteratorT,
-          OffsetT>,
-        d_keys_double_buffer,
-        d_values_double_buffer);
+      error = SortWithoutPartitioning(
+        kernel_source.SegmentedSortFallbackKernel(), d_keys_double_buffer, d_values_double_buffer, wrapped_policy);
     }
 
     d_keys.selector = GetFinalSelector(d_keys.selector, radix_bits);
@@ -632,6 +653,8 @@ struct DispatchSegmentedSort
     return error;
   }
 
+  template <typename MaxPolicyT = typename PolicyHub::MaxPolicy,
+            typename PartitionMaxPolicyT = typename PartitionPolicyHub::MaxPolicy>
   CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch(
     void* d_temp_storage,
     size_t& temp_storage_bytes,
@@ -642,49 +665,46 @@ struct DispatchSegmentedSort
     BeginOffsetIteratorT d_begin_offsets,
     EndOffsetIteratorT d_end_offsets,
     bool is_overwrite_okay,
-    cudaStream_t stream)
+    cudaStream_t stream,
+    KernelSource kernel_source = {},
+    PartitionKernelSource partition_kernel_source = {},
+    KernelLauncherFactory launcher_factory = {},
+    MaxPolicyT max_policy = {},
+    PartitionMaxPolicyT partition_max_policy = {})
   {
-    cudaError error = cudaSuccess;
-
-    do
+    // Get PTX version
+    int ptx_version = 0;
+    if (cudaError error = CubDebug(launcher_factory.PtxVersion(ptx_version)); cudaSuccess != error)
     {
-      // Get PTX version
-      int ptx_version = 0;
-      error = CubDebug(PtxVersion(ptx_version));
-      if (cudaSuccess != error)
-      {
-        break;
-      }
-
-      // Create dispatch functor
-      DispatchSegmentedSort dispatch(
-        d_temp_storage,
-        temp_storage_bytes,
-        d_keys,
-        d_values,
-        num_items,
-        num_segments,
-        d_begin_offsets,
-        d_end_offsets,
-        is_overwrite_okay,
-        stream);
-
-      // Dispatch to chained policy
-      error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch));
-      if (cudaSuccess != error)
-      {
-        break;
-      }
-    } while (false);
+      return error;
+    }
 
-    return error;
+    // Create dispatch functor
+    DispatchSegmentedSort dispatch{
+      d_temp_storage,
+      temp_storage_bytes,
+      d_keys,
+      d_values,
+      num_items,
+      num_segments,
+      d_begin_offsets,
+      d_end_offsets,
+      is_overwrite_okay,
+      stream,
+      kernel_source,
+      partition_kernel_source,
+      launcher_factory,
+      partition_max_policy};
+
+    // Dispatch to chained policy
+    return CubDebug(max_policy.Invoke(ptx_version, dispatch));
   }
 
 private:
   CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE int GetNumPasses(int radix_bits)
   {
     constexpr int byte_size = 8;
-    constexpr int num_bits = sizeof(KeyT) * byte_size;
+    const int num_bits = static_cast<int>(kernel_source.KeySize()) * byte_size;
     const int num_passes = ::cuda::ceil_div(num_bits, radix_bits);
     return num_passes;
   }
@@ -707,19 +727,20 @@ private:
     return buffer.d_buffers[final_selector];
   }
 
-  template <typename LargeSegmentPolicyT, typename SmallAndMediumPolicyT, typename LargeKernelT, typename SmallKernelT>
+  template <typename WrappedPolicyT, typename LargeKernelT, typename SmallKernelT>
   CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t SortWithPartitioning(
     LargeKernelT large_kernel,
     SmallKernelT small_kernel,
     size_t three_way_partition_temp_storage_bytes,
     cub::detail::device_double_buffer<KeyT>& d_keys_double_buffer,
     cub::detail::device_double_buffer<ValueT>& d_values_double_buffer,
-    LargeSegmentsSelectorT& large_segments_selector,
-    SmallSegmentsSelectorT& small_segments_selector,
+    typename KernelSource::LargeSegmentsSelectorT& large_segments_selector,
+    typename KernelSource::SmallSegmentsSelectorT& small_segments_selector,
     cub::detail::temporary_storage::alias<uint8_t>& device_partition_temp_storage,
     cub::detail::temporary_storage::alias<local_segment_index_t>& large_and_medium_segments_indices,
     cub::detail::temporary_storage::alias<local_segment_index_t>& small_segments_indices,
-    cub::detail::temporary_storage::alias<local_segment_index_t>& group_sizes)
+    cub::detail::temporary_storage::alias<local_segment_index_t>& group_sizes,
+    WrappedPolicyT wrapped_policy)
   {
     cudaError_t error = cudaSuccess;
 
@@ -737,15 +758,44 @@ private:
           ? static_cast<local_segment_index_t>(num_segments - current_seg_offset)
           : num_segments_per_invocation_limit;
 
-      large_segments_selector.base_segment_offset = current_seg_offset;
-      small_segments_selector.base_segment_offset = current_seg_offset;
-      [[maybe_unused]] auto current_begin_offset = d_begin_offsets + current_seg_offset;
-      [[maybe_unused]] auto current_end_offset = d_end_offsets + current_seg_offset;
+      kernel_source.SetSegmentOffset(large_segments_selector, current_seg_offset);
+      kernel_source.SetSegmentOffset(small_segments_selector, current_seg_offset);
+
+      BeginOffsetIteratorT current_begin_offset = d_begin_offsets;
+      EndOffsetIteratorT current_end_offset = d_end_offsets;
+
+      detail::advance_iterators_inplace_if_supported(current_begin_offset, current_seg_offset);
+      detail::advance_iterators_inplace_if_supported(current_end_offset, current_seg_offset);
 
       auto medium_indices_iterator =
         ::cuda::std::make_reverse_iterator(large_and_medium_segments_indices.get() + current_num_segments);
 
-      error = CubDebug(cub::DevicePartition::IfNoNVTX(
+      // We call partition through dispatch instead of device because c.parallel needs to be able to call the kernel.
+      // This approach propagates the type erasure to partition.
+      using ChooseOffsetT = detail::choose_signed_offset<global_segment_offset_t>;
+      using PartitionOffsetT = typename ChooseOffsetT::type;
+      using DispatchThreeWayPartitionIfT = cub::DispatchThreeWayPartitionIf<
+        THRUST_NS_QUALIFIER::counting_iterator<local_segment_index_t>,
+        decltype(large_and_medium_segments_indices.get()),
+        decltype(small_segments_indices.get()),
+        decltype(medium_indices_iterator),
+        decltype(group_sizes.get()),
+        decltype(large_segments_selector),
+        decltype(small_segments_selector),
+        PartitionOffsetT,
+        PartitionPolicyHub,
+        PartitionKernelSource,
+        KernelLauncherFactory>;
+
+      // Signed integer type for global offsets
+      // Check if the number of items exceeds the range covered by the selected signed offset type
+      error = ChooseOffsetT::is_exceeding_offset_type(num_items);
+      if (error)
+      {
+        return error;
+      }
+
+      DispatchThreeWayPartitionIfT::Dispatch(
         device_partition_temp_storage.get(),
         three_way_partition_temp_storage_bytes,
         THRUST_NS_QUALIFIER::counting_iterator<local_segment_index_t>(0),
@@ -753,10 +803,14 @@ private:
         small_segments_indices.get(),
         medium_indices_iterator,
         group_sizes.get(),
-        current_num_segments,
         large_segments_selector,
         small_segments_selector,
-        stream));
+        current_num_segments,
+        stream,
+        partition_kernel_source,
+        launcher_factory,
+        partition_max_policy);
+
       if (cudaSuccess != error)
       {
         return error;
@@ -771,43 +825,46 @@ private:
 
 #else // CUB_RDC_ENABLED
 
-#  define CUB_TEMP_DEVICE_CODE                                                      \
-    error =                                                                         \
-      THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(1, 1, 0, stream)        \
-        .doit(                                                                      \
-          detail::segmented_sort::DeviceSegmentedSortContinuationKernel<            \
-            typename PolicyHub::MaxPolicy,                                          \
-            LargeKernelT,                                                           \
-            SmallKernelT,                                                           \
-            KeyT,                                                                   \
-            ValueT,                                                                 \
-            BeginOffsetIteratorT,                                                   \
-            EndOffsetIteratorT>,                                                    \
-          large_kernel,                                                             \
-          small_kernel,                                                             \
-          current_num_segments,                                                     \
-          d_keys.Current(),                                                         \
-          GetFinalOutput<KeyT>(LargeSegmentPolicyT::RADIX_BITS, d_keys),            \
-          d_keys_double_buffer,                                                     \
-          d_values.Current(),                                                       \
-          GetFinalOutput<ValueT>(LargeSegmentPolicyT::RADIX_BITS, d_values),        \
-          d_values_double_buffer,                                                   \
-          current_begin_offset,                                                     \
-          current_end_offset,                                                       \
-          group_sizes.get(),                                                        \
-          large_and_medium_segments_indices.get(),                                  \
-          small_segments_indices.get());                                            \
-    error = CubDebug(error);                                                        \
-                                                                                    \
-    if (cudaSuccess != error)                                                       \
-    {                                                                               \
-      return error;                                                                 \
-    }                                                                               \
-                                                                                    \
-    error = CubDebug(detail::DebugSyncStream(stream));                              \
-    if (cudaSuccess != error)                                                       \
-    {                                                                               \
-      return error;                                                                 \
+#  define CUB_TEMP_DEVICE_CODE                                                      \
+    error =                                                                         \
+      launcher_factory(1, 1, 0, stream)                                             \
+        .doit(                                                                      \
+          detail::segmented_sort::DeviceSegmentedSortContinuationKernel<            \
+            WrappedPolicyT,                                                         \
+            LargeKernelT,                                                           \
+            SmallKernelT,                                                           \
+            KeyT,                                                                   \
+            ValueT,                                                                 \
+            BeginOffsetIteratorT,                                                   \
+            EndOffsetIteratorT,                                                     \
+            KernelLauncherFactory>,                                                 \
+          large_kernel,                                                             \
+          small_kernel,                                                             \
+          current_num_segments,                                                     \
+          d_keys.Current(),                                                         \
+          GetFinalOutput<KeyT>(wrapped_policy.LargeSegmentRadixBits(), d_keys),     \
+          d_keys_double_buffer,                                                     \
+          d_values.Current(),                                                       \
+          GetFinalOutput<ValueT>(wrapped_policy.LargeSegmentRadixBits(), d_values), \
+          d_values_double_buffer,                                                   \
+          current_begin_offset,                                                     \
+          current_end_offset,                                                       \
+          group_sizes.get(),                                                        \
+          large_and_medium_segments_indices.get(),                                  \
+          small_segments_indices.get(),                                             \
+          launcher_factory,                                                         \
+          wrapped_policy);                                                          \
+    error = CubDebug(error);                                                        \
+                                                                                    \
+    if (cudaSuccess != error)                                                       \
+    {                                                                               \
+      return error;                                                                 \
+    }                                                                               \
+                                                                                    \
+    error = CubDebug(detail::DebugSyncStream(stream));                              \
+    if (cudaSuccess != error)                                                       \
+    {                                                                               \
+      return error;                                                                 \
     }
 
 #endif // CUB_RDC_ENABLED
@@ -818,12 +875,12 @@ private:
       NV_IS_HOST,
       (
         local_segment_index_t h_group_sizes[num_selected_groups];
-        error = CubDebug(cudaMemcpyAsync(h_group_sizes,
-                                         group_sizes.get(),
-                                         num_selected_groups *
-                                         sizeof(local_segment_index_t),
-                                         cudaMemcpyDeviceToHost,
-                                         stream));
+        error = CubDebug(launcher_factory.MemcpyAsync(h_group_sizes,
+                                                      group_sizes.get(),
+                                                      num_selected_groups *
+                                                      sizeof(local_segment_index_t),
+                                                      cudaMemcpyDeviceToHost,
+                                                      stream));
 
         if (cudaSuccess != error)
         {
@@ -836,23 +893,24 @@ private:
           return error;
         }
 
-        error = detail::segmented_sort::DeviceSegmentedSortContinuation<LargeSegmentPolicyT,
-                                                                        SmallAndMediumPolicyT>(
+        error = detail::segmented_sort::DeviceSegmentedSortContinuation(
           large_kernel,
           small_kernel,
           current_num_segments,
          d_keys.Current(),
-          GetFinalOutput<KeyT>(LargeSegmentPolicyT::RADIX_BITS, d_keys),
+          GetFinalOutput<KeyT>(wrapped_policy.LargeSegmentRadixBits(), d_keys),
          d_keys_double_buffer,
          d_values.Current(),
-          GetFinalOutput<ValueT>(LargeSegmentPolicyT::RADIX_BITS, d_values),
+          GetFinalOutput<ValueT>(wrapped_policy.LargeSegmentRadixBits(), d_values),
          d_values_double_buffer,
          current_begin_offset,
          current_end_offset,
          h_group_sizes,
          large_and_medium_segments_indices.get(),
          small_segments_indices.get(),
-          stream);),
+          stream,
+          launcher_factory,
+          wrapped_policy);),
       // NV_IS_DEVICE:
       (CUB_TEMP_DEVICE_CODE));
     // clang-format on
@@ -862,16 +920,17 @@ private:
     return error;
   }
 
-  template <typename LargeSegmentPolicyT, typename FallbackKernelT>
+  template <typename WrappedPolicyT, typename FallbackKernelT>
   CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t SortWithoutPartitioning(
     FallbackKernelT fallback_kernel,
     cub::detail::device_double_buffer<KeyT>& d_keys_double_buffer,
-    cub::detail::device_double_buffer<ValueT>& d_values_double_buffer)
+    cub::detail::device_double_buffer<ValueT>& d_values_double_buffer,
+    WrappedPolicyT wrapped_policy)
   {
     cudaError_t error = cudaSuccess;
 
-    const auto blocks_in_grid       = static_cast<local_segment_index_t>(num_segments);
-    constexpr auto threads_in_block = static_cast<unsigned int>(LargeSegmentPolicyT::BLOCK_THREADS);
+    const auto blocks_in_grid   = static_cast<local_segment_index_t>(num_segments);
+    const auto threads_in_block = static_cast<unsigned int>(wrapped_policy.LargeSegment().BlockThreads());
 
     // Log kernel configuration
 #ifdef CUB_DEBUG_LOG
@@ -880,18 +939,18 @@ private:
             blocks_in_grid,
             threads_in_block,
             (long long) stream,
-            LargeSegmentPolicyT::ITEMS_PER_THREAD,
-            LargeSegmentPolicyT::RADIX_BITS);
+            wrapped_policy.LargeSegment().ItemsPerThread(),
+            wrapped_policy.LargeSegmentRadixBits());
 #endif // CUB_DEBUG_LOG
 
     // Invoke fallback kernel
-    THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(blocks_in_grid, threads_in_block, 0, stream)
+    launcher_factory(blocks_in_grid, threads_in_block, 0, stream)
       .doit(fallback_kernel,
             d_keys.Current(),
-            GetFinalOutput(LargeSegmentPolicyT::RADIX_BITS, d_keys),
+            GetFinalOutput(wrapped_policy.LargeSegmentRadixBits(), d_keys),
             d_keys_double_buffer,
             d_values.Current(),
-            GetFinalOutput(LargeSegmentPolicyT::RADIX_BITS, d_values),
+            GetFinalOutput(wrapped_policy.LargeSegmentRadixBits(), d_values),
             d_values_double_buffer,
             d_begin_offsets,
             d_end_offsets);
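
Taken together with the PtxVersion and MemcpyAsync calls in the hunks above, the launcher factory acts as a small backend interface rather than just a launch helper. A hedged sketch of that surface with CUDA-runtime-backed defaults follows; the bodies are assumptions, and CUB's real default factory reports the PTX target of the compiled kernels rather than the runtime version.

    #include <cuda_runtime.h>

    // Illustrative backend matching the calls this diff makes on launcher_factory.
    struct demo_backend
    {
      cudaError_t PtxVersion(int& version) const
      {
        // Stand-in: report the CUDA runtime version.
        return cudaRuntimeGetVersion(&version);
      }

      cudaError_t MemcpyAsync(
        void* dst, const void* src, size_t bytes, cudaMemcpyKind kind, cudaStream_t stream) const
      {
        // Forward to the runtime; a driver-API backend could substitute here.
        return cudaMemcpyAsync(dst, src, bytes, kind, stream);
      }
    };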