cuda-cccl 0.1.3.2.0.dev438__cp313-cp313-manylinux_2_24_aarch64.whl → 0.3.1__cp313-cp313-manylinux_2_24_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cuda-cccl might be problematic. Click here for more details.

Files changed (177)
  1. cuda/cccl/cooperative/__init__.py +7 -1
  2. cuda/cccl/cooperative/experimental/__init__.py +21 -5
  3. cuda/cccl/headers/include/cub/agent/agent_adjacent_difference.cuh +2 -5
  4. cuda/cccl/headers/include/cub/agent/agent_batch_memcpy.cuh +2 -5
  5. cuda/cccl/headers/include/cub/agent/agent_for.cuh +2 -5
  6. cuda/cccl/headers/include/cub/agent/agent_merge.cuh +23 -21
  7. cuda/cccl/headers/include/cub/agent/agent_merge_sort.cuh +21 -3
  8. cuda/cccl/headers/include/cub/agent/agent_radix_sort_downsweep.cuh +25 -5
  9. cuda/cccl/headers/include/cub/agent/agent_radix_sort_histogram.cuh +2 -5
  10. cuda/cccl/headers/include/cub/agent/agent_radix_sort_onesweep.cuh +2 -5
  11. cuda/cccl/headers/include/cub/agent/agent_radix_sort_upsweep.cuh +2 -5
  12. cuda/cccl/headers/include/cub/agent/agent_rle.cuh +2 -5
  13. cuda/cccl/headers/include/cub/agent/agent_scan.cuh +5 -1
  14. cuda/cccl/headers/include/cub/agent/agent_scan_by_key.cuh +2 -5
  15. cuda/cccl/headers/include/cub/agent/agent_segmented_radix_sort.cuh +2 -5
  16. cuda/cccl/headers/include/cub/agent/agent_select_if.cuh +2 -5
  17. cuda/cccl/headers/include/cub/agent/agent_sub_warp_merge_sort.cuh +24 -19
  18. cuda/cccl/headers/include/cub/agent/agent_three_way_partition.cuh +2 -5
  19. cuda/cccl/headers/include/cub/agent/agent_unique_by_key.cuh +22 -5
  20. cuda/cccl/headers/include/cub/block/block_load_to_shared.cuh +432 -0
  21. cuda/cccl/headers/include/cub/block/block_radix_rank.cuh +3 -2
  22. cuda/cccl/headers/include/cub/block/block_radix_sort.cuh +4 -2
  23. cuda/cccl/headers/include/cub/detail/device_memory_resource.cuh +1 -0
  24. cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +18 -26
  25. cuda/cccl/headers/include/cub/device/device_copy.cuh +116 -27
  26. cuda/cccl/headers/include/cub/device/device_partition.cuh +5 -1
  27. cuda/cccl/headers/include/cub/device/device_segmented_reduce.cuh +158 -247
  28. cuda/cccl/headers/include/cub/device/dispatch/dispatch_copy_mdspan.cuh +79 -0
  29. cuda/cccl/headers/include/cub/device/dispatch/dispatch_merge.cuh +4 -4
  30. cuda/cccl/headers/include/cub/device/dispatch/dispatch_radix_sort.cuh +2 -11
  31. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce.cuh +8 -26
  32. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_deterministic.cuh +1 -6
  33. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_nondeterministic.cuh +0 -1
  34. cuda/cccl/headers/include/cub/device/dispatch/dispatch_segmented_sort.cuh +320 -262
  35. cuda/cccl/headers/include/cub/device/dispatch/kernels/reduce.cuh +10 -5
  36. cuda/cccl/headers/include/cub/device/dispatch/kernels/scan.cuh +2 -5
  37. cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_reduce.cuh +2 -5
  38. cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_sort.cuh +57 -10
  39. cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +37 -13
  40. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_adjacent_difference.cuh +2 -5
  41. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_batch_memcpy.cuh +2 -5
  42. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_for.cuh +2 -5
  43. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_histogram.cuh +2 -5
  44. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_merge.cuh +2 -5
  45. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_merge_sort.cuh +8 -0
  46. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_radix_sort.cuh +2 -5
  47. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh +2 -5
  48. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_run_length_encode.cuh +2 -5
  49. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_scan.cuh +2 -5
  50. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_scan_by_key.cuh +2 -5
  51. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_segmented_sort.cuh +204 -55
  52. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_three_way_partition.cuh +2 -5
  53. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +55 -19
  54. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_unique_by_key.cuh +10 -0
  55. cuda/cccl/headers/include/cub/util_device.cuh +51 -35
  56. cuda/cccl/headers/include/cub/warp/specializations/warp_reduce_shfl.cuh +3 -2
  57. cuda/cccl/headers/include/cub/warp/specializations/warp_reduce_smem.cuh +3 -2
  58. cuda/cccl/headers/include/cub/warp/specializations/warp_scan_shfl.cuh +2 -2
  59. cuda/cccl/headers/include/cuda/__algorithm/common.h +1 -1
  60. cuda/cccl/headers/include/cuda/__algorithm/copy.h +4 -4
  61. cuda/cccl/headers/include/cuda/__algorithm/fill.h +1 -1
  62. cuda/cccl/headers/include/cuda/__device/all_devices.h +47 -147
  63. cuda/cccl/headers/include/cuda/__device/arch_traits.h +51 -49
  64. cuda/cccl/headers/include/cuda/__device/attributes.h +177 -127
  65. cuda/cccl/headers/include/cuda/__device/device_ref.h +32 -51
  66. cuda/cccl/headers/include/cuda/__device/physical_device.h +120 -91
  67. cuda/cccl/headers/include/cuda/__driver/driver_api.h +330 -36
  68. cuda/cccl/headers/include/cuda/__event/event.h +8 -8
  69. cuda/cccl/headers/include/cuda/__event/event_ref.h +4 -5
  70. cuda/cccl/headers/include/cuda/__event/timed_event.h +4 -4
  71. cuda/cccl/headers/include/cuda/__fwd/devices.h +44 -0
  72. cuda/cccl/headers/include/cuda/__fwd/zip_iterator.h +9 -0
  73. cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +3 -3
  74. cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +3 -3
  75. cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +3 -3
  76. cuda/cccl/headers/include/cuda/__iterator/zip_common.h +158 -0
  77. cuda/cccl/headers/include/cuda/__iterator/zip_iterator.h +8 -120
  78. cuda/cccl/headers/include/cuda/__iterator/zip_transform_iterator.h +593 -0
  79. cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +14 -10
  80. cuda/cccl/headers/include/cuda/__runtime/ensure_current_context.h +4 -3
  81. cuda/cccl/headers/include/cuda/__runtime/types.h +1 -1
  82. cuda/cccl/headers/include/cuda/__stream/stream.h +2 -3
  83. cuda/cccl/headers/include/cuda/__stream/stream_ref.h +18 -12
  84. cuda/cccl/headers/include/cuda/__utility/__basic_any/virtual_tables.h +2 -2
  85. cuda/cccl/headers/include/cuda/__utility/basic_any.h +1 -1
  86. cuda/cccl/headers/include/cuda/algorithm +1 -1
  87. cuda/cccl/headers/include/cuda/devices +10 -0
  88. cuda/cccl/headers/include/cuda/iterator +1 -0
  89. cuda/cccl/headers/include/cuda/std/__bit/countl.h +8 -1
  90. cuda/cccl/headers/include/cuda/std/__bit/countr.h +2 -2
  91. cuda/cccl/headers/include/cuda/std/__bit/reference.h +11 -11
  92. cuda/cccl/headers/include/cuda/std/__cccl/cuda_capabilities.h +2 -2
  93. cuda/cccl/headers/include/cuda/std/__cccl/preprocessor.h +2 -0
  94. cuda/cccl/headers/include/cuda/std/__chrono/duration.h +16 -16
  95. cuda/cccl/headers/include/cuda/std/__chrono/steady_clock.h +5 -5
  96. cuda/cccl/headers/include/cuda/std/__chrono/system_clock.h +5 -5
  97. cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +10 -5
  98. cuda/cccl/headers/include/cuda/std/__cmath/min_max.h +44 -17
  99. cuda/cccl/headers/include/cuda/std/__concepts/constructible.h +1 -1
  100. cuda/cccl/headers/include/cuda/std/__cuda/api_wrapper.h +12 -12
  101. cuda/cccl/headers/include/cuda/std/__exception/cuda_error.h +1 -8
  102. cuda/cccl/headers/include/cuda/std/__floating_point/cast.h +15 -12
  103. cuda/cccl/headers/include/cuda/std/__floating_point/cuda_fp_types.h +3 -0
  104. cuda/cccl/headers/include/cuda/std/__floating_point/fp.h +1 -1
  105. cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +2 -1
  106. cuda/cccl/headers/include/cuda/std/__tuple_dir/make_tuple_types.h +23 -1
  107. cuda/cccl/headers/include/cuda/std/__tuple_dir/tuple_like.h +4 -0
  108. cuda/cccl/headers/include/cuda/std/__tuple_dir/tuple_like_ext.h +4 -0
  109. cuda/cccl/headers/include/cuda/std/__type_traits/promote.h +3 -2
  110. cuda/cccl/headers/include/cuda/std/string_view +12 -5
  111. cuda/cccl/headers/include/cuda/std/version +1 -4
  112. cuda/cccl/headers/include/thrust/detail/integer_math.h +3 -20
  113. cuda/cccl/headers/include/thrust/iterator/iterator_traits.h +11 -0
  114. cuda/cccl/headers/include/thrust/system/cuda/detail/copy.h +33 -0
  115. cuda/cccl/headers/include/thrust/system/cuda/detail/tabulate.h +8 -22
  116. cuda/cccl/headers/include/thrust/type_traits/unwrap_contiguous_iterator.h +15 -48
  117. cuda/cccl/parallel/experimental/__init__.py +21 -70
  118. cuda/compute/__init__.py +77 -0
  119. cuda/{cccl/parallel/experimental → compute}/_bindings.pyi +28 -0
  120. cuda/{cccl/parallel/experimental → compute}/_bindings_impl.pyx +141 -1
  121. cuda/{cccl/parallel/experimental → compute}/algorithms/__init__.py +4 -0
  122. cuda/{cccl/parallel/experimental → compute}/algorithms/_histogram.py +2 -2
  123. cuda/{cccl/parallel/experimental → compute}/algorithms/_merge_sort.py +2 -2
  124. cuda/{cccl/parallel/experimental → compute}/algorithms/_radix_sort.py +3 -3
  125. cuda/{cccl/parallel/experimental → compute}/algorithms/_reduce.py +2 -4
  126. cuda/{cccl/parallel/experimental → compute}/algorithms/_scan.py +4 -6
  127. cuda/{cccl/parallel/experimental → compute}/algorithms/_segmented_reduce.py +2 -2
  128. cuda/compute/algorithms/_three_way_partition.py +261 -0
  129. cuda/{cccl/parallel/experimental → compute}/algorithms/_transform.py +4 -4
  130. cuda/{cccl/parallel/experimental → compute}/algorithms/_unique_by_key.py +2 -2
  131. cuda/compute/cu12/_bindings_impl.cpython-313-aarch64-linux-gnu.so +0 -0
  132. cuda/{cccl/parallel/experimental → compute}/cu12/cccl/libcccl.c.parallel.so +0 -0
  133. cuda/compute/cu13/_bindings_impl.cpython-313-aarch64-linux-gnu.so +0 -0
  134. cuda/{cccl/parallel/experimental → compute}/cu13/cccl/libcccl.c.parallel.so +0 -0
  135. cuda/{cccl/parallel/experimental → compute}/iterators/_factories.py +8 -8
  136. cuda/{cccl/parallel/experimental → compute}/struct.py +2 -2
  137. cuda/coop/__init__.py +8 -0
  138. cuda/{cccl/cooperative/experimental → coop}/_nvrtc.py +3 -2
  139. cuda/{cccl/cooperative/experimental → coop}/_scan_op.py +3 -3
  140. cuda/{cccl/cooperative/experimental → coop}/_types.py +2 -2
  141. cuda/{cccl/cooperative/experimental → coop}/_typing.py +1 -1
  142. cuda/{cccl/cooperative/experimental → coop}/block/__init__.py +6 -6
  143. cuda/{cccl/cooperative/experimental → coop}/block/_block_exchange.py +4 -4
  144. cuda/{cccl/cooperative/experimental → coop}/block/_block_load_store.py +6 -6
  145. cuda/{cccl/cooperative/experimental → coop}/block/_block_merge_sort.py +4 -4
  146. cuda/{cccl/cooperative/experimental → coop}/block/_block_radix_sort.py +6 -6
  147. cuda/{cccl/cooperative/experimental → coop}/block/_block_reduce.py +6 -6
  148. cuda/{cccl/cooperative/experimental → coop}/block/_block_scan.py +7 -7
  149. cuda/coop/warp/__init__.py +9 -0
  150. cuda/{cccl/cooperative/experimental → coop}/warp/_warp_merge_sort.py +3 -3
  151. cuda/{cccl/cooperative/experimental → coop}/warp/_warp_reduce.py +6 -6
  152. cuda/{cccl/cooperative/experimental → coop}/warp/_warp_scan.py +4 -4
  153. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/METADATA +1 -1
  154. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/RECORD +171 -166
  155. cuda/cccl/cooperative/experimental/warp/__init__.py +0 -9
  156. cuda/cccl/headers/include/cub/device/dispatch/dispatch_advance_iterators.cuh +0 -111
  157. cuda/cccl/headers/include/cuda/std/__cuda/ensure_current_device.h +0 -72
  158. cuda/cccl/parallel/experimental/.gitignore +0 -4
  159. cuda/cccl/parallel/experimental/cu12/_bindings_impl.cpython-313-aarch64-linux-gnu.so +0 -0
  160. cuda/cccl/parallel/experimental/cu13/_bindings_impl.cpython-313-aarch64-linux-gnu.so +0 -0
  161. /cuda/{cccl/parallel/experimental → compute}/_bindings.py +0 -0
  162. /cuda/{cccl/parallel/experimental → compute}/_caching.py +0 -0
  163. /cuda/{cccl/parallel/experimental → compute}/_cccl_interop.py +0 -0
  164. /cuda/{cccl/parallel/experimental → compute}/_utils/__init__.py +0 -0
  165. /cuda/{cccl/parallel/experimental → compute}/_utils/protocols.py +0 -0
  166. /cuda/{cccl/parallel/experimental → compute}/_utils/temp_storage_buffer.py +0 -0
  167. /cuda/{cccl/parallel/experimental → compute}/cccl/.gitkeep +0 -0
  168. /cuda/{cccl/parallel/experimental → compute}/iterators/__init__.py +0 -0
  169. /cuda/{cccl/parallel/experimental → compute}/iterators/_iterators.py +0 -0
  170. /cuda/{cccl/parallel/experimental → compute}/iterators/_zip_iterator.py +0 -0
  171. /cuda/{cccl/parallel/experimental → compute}/numba_utils.py +0 -0
  172. /cuda/{cccl/parallel/experimental → compute}/op.py +0 -0
  173. /cuda/{cccl/parallel/experimental → compute}/typing.py +0 -0
  174. /cuda/{cccl/cooperative/experimental → coop}/_caching.py +0 -0
  175. /cuda/{cccl/cooperative/experimental → coop}/_common.py +0 -0
  176. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/WHEEL +0 -0
  177. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -69,14 +69,14 @@ namespace detail::segmented_sort
69
69
  * of this stage is required to eliminate device-side synchronization in
70
70
  * the CDP mode.
71
71
  */
72
- template <typename LargeSegmentPolicyT,
73
- typename SmallAndMediumPolicyT,
72
+ template <typename WrappedPolicyT,
74
73
  typename LargeKernelT,
75
74
  typename SmallKernelT,
76
75
  typename KeyT,
77
76
  typename ValueT,
78
77
  typename BeginOffsetIteratorT,
79
- typename EndOffsetIteratorT>
78
+ typename EndOffsetIteratorT,
79
+ typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY>
80
80
  CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortContinuation(
81
81
  LargeKernelT large_kernel,
82
82
  SmallKernelT small_kernel,
@@ -92,7 +92,9 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
92
92
  local_segment_index_t* group_sizes,
93
93
  local_segment_index_t* large_and_medium_segments_indices,
94
94
  local_segment_index_t* small_segments_indices,
95
- cudaStream_t stream)
95
+ cudaStream_t stream,
96
+ KernelLauncherFactory launcher_factory,
97
+ WrappedPolicyT wrapped_policy)
96
98
  {
97
99
  using local_segment_index_t = local_segment_index_t;
98
100
 
@@ -109,11 +111,11 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
109
111
  _CubLog("Invoking "
110
112
  "DeviceSegmentedSortKernelLarge<<<%d, %d, 0, %lld>>>()\n",
111
113
  static_cast<int>(blocks_in_grid),
112
- LargeSegmentPolicyT::BLOCK_THREADS,
114
+ wrapped_policy.LargeSegment().BlockThreads(),
113
115
  (long long) stream);
114
116
  #endif // CUB_DEBUG_LOG
115
117
 
116
- THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(blocks_in_grid, LargeSegmentPolicyT::BLOCK_THREADS, 0, stream)
118
+ launcher_factory(blocks_in_grid, wrapped_policy.LargeSegment().BlockThreads(), 0, stream)
117
119
  .doit(large_kernel,
118
120
  large_and_medium_segments_indices,
119
121
  d_current_keys,
@@ -144,11 +146,10 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
144
146
  const local_segment_index_t medium_segments =
145
147
  static_cast<local_segment_index_t>(num_segments) - (large_segments + small_segments);
146
148
 
147
- const local_segment_index_t small_blocks =
148
- ::cuda::ceil_div(small_segments, SmallAndMediumPolicyT::SEGMENTS_PER_SMALL_BLOCK);
149
+ const local_segment_index_t small_blocks = ::cuda::ceil_div(small_segments, wrapped_policy.SegmentsPerSmallBlock());
149
150
 
150
151
  const local_segment_index_t medium_blocks =
151
- ::cuda::ceil_div(medium_segments, SmallAndMediumPolicyT::SEGMENTS_PER_MEDIUM_BLOCK);
152
+ ::cuda::ceil_div(medium_segments, wrapped_policy.SegmentsPerMediumBlock());
152
153
 
153
154
  const local_segment_index_t small_and_medium_blocks_in_grid = small_blocks + medium_blocks;
154
155
 
@@ -158,12 +159,11 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
158
159
  _CubLog("Invoking "
159
160
  "DeviceSegmentedSortKernelSmall<<<%d, %d, 0, %lld>>>()\n",
160
161
  static_cast<int>(small_and_medium_blocks_in_grid),
161
- SmallAndMediumPolicyT::BLOCK_THREADS,
162
+ wrapped_policy.SmallSegment().BlockThreads(),
162
163
  (long long) stream);
163
164
  #endif // CUB_DEBUG_LOG
164
165
 
165
- THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(
166
- small_and_medium_blocks_in_grid, SmallAndMediumPolicyT::BLOCK_THREADS, 0, stream)
166
+ launcher_factory(small_and_medium_blocks_in_grid, wrapped_policy.SmallSegment().BlockThreads(), 0, stream)
167
167
  .doit(small_kernel,
168
168
  small_segments,
169
169
  medium_segments,
@@ -200,13 +200,14 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN cudaError_t DeviceSegmentedSortCont
200
200
  * Continuation kernel is used only in the CDP mode. It's used to
201
201
  * launch DeviceSegmentedSortContinuation as a separate kernel.
202
202
  */
203
- template <typename ChainedPolicyT,
203
+ template <typename WrappedPolicyT,
204
204
  typename LargeKernelT,
205
205
  typename SmallKernelT,
206
206
  typename KeyT,
207
207
  typename ValueT,
208
208
  typename BeginOffsetIteratorT,
209
- typename EndOffsetIteratorT>
209
+ typename EndOffsetIteratorT,
210
+ typename KernelLauncherFactory>
210
211
  __launch_bounds__(1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortContinuationKernel(
211
212
  LargeKernelT large_kernel,
212
213
  SmallKernelT small_kernel,
@@ -221,12 +222,10 @@ __launch_bounds__(1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortContin
221
222
  EndOffsetIteratorT d_end_offsets,
222
223
  local_segment_index_t* group_sizes,
223
224
  local_segment_index_t* large_and_medium_segments_indices,
224
- local_segment_index_t* small_segments_indices)
225
+ local_segment_index_t* small_segments_indices,
226
+ KernelLauncherFactory launcher_factory,
227
+ WrappedPolicyT wrapped_policy)
225
228
  {
226
- using ActivePolicyT = typename ChainedPolicyT::ActivePolicy;
227
- using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;
228
- using SmallAndMediumPolicyT = typename ActivePolicyT::SmallAndMediumSegmentedSortPolicyT;
229
-
230
229
  // In case of CDP:
231
230
  // 1. each CTA has a different main stream
232
231
  // 2. all streams are non-blocking
@@ -236,86 +235,119 @@ __launch_bounds__(1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortContin
236
235
  //
237
236
  // Due to (4, 5), we can't pass the user-provided stream in the continuation.
238
237
  // Due to (1, 2, 3) it's safe to pass the main stream.
239
- cudaError_t error =
240
- detail::segmented_sort::DeviceSegmentedSortContinuation<LargeSegmentPolicyT, SmallAndMediumPolicyT>(
241
- large_kernel,
242
- small_kernel,
243
- num_segments,
244
- d_current_keys,
245
- d_final_keys,
246
- d_keys_double_buffer,
247
- d_current_values,
248
- d_final_values,
249
- d_values_double_buffer,
250
- d_begin_offsets,
251
- d_end_offsets,
252
- group_sizes,
253
- large_and_medium_segments_indices,
254
- small_segments_indices,
255
- 0); // always launching on the main stream (see motivation above)
238
+ cudaError_t error = detail::segmented_sort::DeviceSegmentedSortContinuation<WrappedPolicyT>(
239
+ large_kernel,
240
+ small_kernel,
241
+ num_segments,
242
+ d_current_keys,
243
+ d_final_keys,
244
+ d_keys_double_buffer,
245
+ d_current_values,
246
+ d_final_values,
247
+ d_values_double_buffer,
248
+ d_begin_offsets,
249
+ d_end_offsets,
250
+ group_sizes,
251
+ large_and_medium_segments_indices,
252
+ small_segments_indices,
253
+ 0, // always launching on the main stream (see motivation above)
254
+ launcher_factory,
255
+ wrapped_policy);
256
256
 
257
257
  error = CubDebug(error);
258
258
  }
259
259
  #endif // CUB_RDC_ENABLED
260
- } // namespace detail::segmented_sort
261
-
262
- template <SortOrder Order,
260
+ template <typename MaxPolicyT,
261
+ SortOrder Order,
263
262
  typename KeyT,
264
263
  typename ValueT,
265
- typename OffsetT,
266
264
  typename BeginOffsetIteratorT,
267
265
  typename EndOffsetIteratorT,
268
- typename PolicyHub = detail::segmented_sort::policy_hub<KeyT, ValueT>>
269
- struct DispatchSegmentedSort
266
+ typename OffsetT>
267
+ struct DeviceSegmentedSortKernelSource
270
268
  {
271
- using local_segment_index_t = detail::segmented_sort::local_segment_index_t;
272
- using global_segment_offset_t = detail::segmented_sort::global_segment_offset_t;
269
+ CUB_DEFINE_KERNEL_GETTER(
270
+ SegmentedSortFallbackKernel,
271
+ DeviceSegmentedSortFallbackKernel<Order, MaxPolicyT, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT>);
273
272
 
274
- static constexpr int KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;
273
+ CUB_DEFINE_KERNEL_GETTER(
274
+ SegmentedSortKernelSmall,
275
+ DeviceSegmentedSortKernelSmall<Order, MaxPolicyT, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT>);
276
+
277
+ CUB_DEFINE_KERNEL_GETTER(
278
+ SegmentedSortKernelLarge,
279
+ DeviceSegmentedSortKernelLarge<Order, MaxPolicyT, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT>);
275
280
 
276
- struct LargeSegmentsSelectorT
281
+ CUB_RUNTIME_FUNCTION static constexpr size_t KeySize()
277
282
  {
278
- OffsetT value{};
279
- BeginOffsetIteratorT d_offset_begin{};
280
- EndOffsetIteratorT d_offset_end{};
281
- global_segment_offset_t base_segment_offset{};
282
-
283
- _CCCL_HOST_DEVICE _CCCL_FORCEINLINE
284
- LargeSegmentsSelectorT(OffsetT value, BeginOffsetIteratorT d_offset_begin, EndOffsetIteratorT d_offset_end)
285
- : value(value)
286
- , d_offset_begin(d_offset_begin)
287
- , d_offset_end(d_offset_end)
288
- {}
289
-
290
- _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator()(local_segment_index_t segment_id) const
291
- {
292
- const OffsetT segment_size =
293
- d_offset_end[base_segment_offset + segment_id] - d_offset_begin[base_segment_offset + segment_id];
294
- return segment_size > value;
295
- }
296
- };
283
+ return sizeof(KeyT);
284
+ }
297
285
 
298
- struct SmallSegmentsSelectorT
286
+ using LargeSegmentsSelectorT =
287
+ cub::detail::segmented_sort::LargeSegmentsSelectorT<OffsetT, BeginOffsetIteratorT, EndOffsetIteratorT>;
288
+ using SmallSegmentsSelectorT =
289
+ cub::detail::segmented_sort::SmallSegmentsSelectorT<OffsetT, BeginOffsetIteratorT, EndOffsetIteratorT>;
290
+
291
+ CUB_RUNTIME_FUNCTION static constexpr auto LargeSegmentsSelector(
292
+ OffsetT offset, BeginOffsetIteratorT begin_offset_iterator, EndOffsetIteratorT end_offset_iterator)
299
293
  {
300
- OffsetT value{};
301
- BeginOffsetIteratorT d_offset_begin{};
302
- EndOffsetIteratorT d_offset_end{};
303
- global_segment_offset_t base_segment_offset{};
304
-
305
- _CCCL_HOST_DEVICE _CCCL_FORCEINLINE
306
- SmallSegmentsSelectorT(OffsetT value, BeginOffsetIteratorT d_offset_begin, EndOffsetIteratorT d_offset_end)
307
- : value(value)
308
- , d_offset_begin(d_offset_begin)
309
- , d_offset_end(d_offset_end)
310
- {}
311
-
312
- _CCCL_HOST_DEVICE _CCCL_FORCEINLINE bool operator()(local_segment_index_t segment_id) const
313
- {
314
- const OffsetT segment_size =
315
- d_offset_end[base_segment_offset + segment_id] - d_offset_begin[base_segment_offset + segment_id];
316
- return segment_size < value;
317
- }
318
- };
294
+ return LargeSegmentsSelectorT(offset, begin_offset_iterator, end_offset_iterator);
295
+ }
296
+
297
+ CUB_RUNTIME_FUNCTION static constexpr auto SmallSegmentsSelector(
298
+ OffsetT offset, BeginOffsetIteratorT begin_offset_iterator, EndOffsetIteratorT end_offset_iterator)
299
+ {
300
+ return SmallSegmentsSelectorT(offset, begin_offset_iterator, end_offset_iterator);
301
+ }
302
+
303
+ template <typename SelectorT>
304
+ CUB_RUNTIME_FUNCTION static constexpr void
305
+ SetSegmentOffset(SelectorT& selector, global_segment_offset_t base_segment_offset)
306
+ {
307
+ selector.base_segment_offset = base_segment_offset;
308
+ }
309
+ };
310
+ } // namespace detail::segmented_sort
311
+
312
+ template <
313
+ SortOrder Order,
314
+ typename KeyT,
315
+ typename ValueT,
316
+ typename OffsetT,
317
+ typename BeginOffsetIteratorT,
318
+ typename EndOffsetIteratorT,
319
+ typename PolicyHub = detail::segmented_sort::policy_hub<KeyT, ValueT>,
320
+ typename KernelSource = detail::segmented_sort::DeviceSegmentedSortKernelSource<
321
+ typename PolicyHub::MaxPolicy,
322
+ Order,
323
+ KeyT,
324
+ ValueT,
325
+ BeginOffsetIteratorT,
326
+ EndOffsetIteratorT,
327
+ OffsetT>,
328
+ typename PartitionPolicyHub = detail::three_way_partition::policy_hub<
329
+ cub::detail::it_value_t<THRUST_NS_QUALIFIER::counting_iterator<cub::detail::segmented_sort::local_segment_index_t>>,
330
+ detail::three_way_partition::per_partition_offset_t>,
331
+ typename PartitionKernelSource = detail::three_way_partition::DeviceThreeWayPartitionKernelSource<
332
+ typename PartitionPolicyHub::MaxPolicy,
333
+ THRUST_NS_QUALIFIER::counting_iterator<cub::detail::segmented_sort::local_segment_index_t>,
334
+ cub::detail::segmented_sort::local_segment_index_t*,
335
+ cub::detail::segmented_sort::local_segment_index_t*,
336
+ ::cuda::std::reverse_iterator<cub::detail::segmented_sort::local_segment_index_t*>,
337
+ cub::detail::segmented_sort::local_segment_index_t*,
338
+ detail::three_way_partition::ScanTileStateT,
339
+ cub::detail::segmented_sort::LargeSegmentsSelectorT<OffsetT, BeginOffsetIteratorT, EndOffsetIteratorT>,
340
+ cub::detail::segmented_sort::SmallSegmentsSelectorT<OffsetT, BeginOffsetIteratorT, EndOffsetIteratorT>,
341
+ detail::three_way_partition::per_partition_offset_t,
342
+ detail::three_way_partition::streaming_context_t<cub::detail::segmented_sort::global_segment_offset_t>,
343
+ detail::choose_signed_offset<cub::detail::segmented_sort::global_segment_offset_t>::type>,
344
+ typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY>
345
+ struct DispatchSegmentedSort
346
+ {
347
+ using local_segment_index_t = detail::segmented_sort::local_segment_index_t;
348
+ using global_segment_offset_t = detail::segmented_sort::global_segment_offset_t;
349
+
350
+ static constexpr int KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;
319
351
 
320
352
  // Partition selects large and small groups. The middle group is not selected.
321
353
  static constexpr size_t num_selected_groups = 2;
@@ -370,48 +402,33 @@ struct DispatchSegmentedSort
370
402
  /// CUDA stream to launch kernels within.
371
403
  cudaStream_t stream;
372
404
 
373
- CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE DispatchSegmentedSort(
374
- void* d_temp_storage,
375
- size_t& temp_storage_bytes,
376
- DoubleBuffer<KeyT>& d_keys,
377
- DoubleBuffer<ValueT>& d_values,
378
- ::cuda::std::int64_t num_items,
379
- global_segment_offset_t num_segments,
380
- BeginOffsetIteratorT d_begin_offsets,
381
- EndOffsetIteratorT d_end_offsets,
382
- bool is_overwrite_okay,
383
- cudaStream_t stream)
384
- : d_temp_storage(d_temp_storage)
385
- , temp_storage_bytes(temp_storage_bytes)
386
- , d_keys(d_keys)
387
- , d_values(d_values)
388
- , num_items(num_items)
389
- , num_segments(num_segments)
390
- , d_begin_offsets(d_begin_offsets)
391
- , d_end_offsets(d_end_offsets)
392
- , is_overwrite_okay(is_overwrite_okay)
393
- , stream(stream)
394
- {}
405
+ KernelSource kernel_source;
406
+
407
+ PartitionKernelSource partition_kernel_source;
408
+
409
+ KernelLauncherFactory launcher_factory;
410
+
411
+ typename PartitionPolicyHub::MaxPolicy partition_max_policy;
395
412
 
396
413
  template <typename ActivePolicyT>
397
- CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke()
414
+ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT policy = {})
398
415
  {
399
- using LargeSegmentPolicyT = typename ActivePolicyT::LargeSegmentPolicy;
400
- using SmallAndMediumPolicyT = typename ActivePolicyT::SmallAndMediumSegmentedSortPolicyT;
416
+ auto wrapped_policy = detail::segmented_sort::MakeSegmentedSortPolicyWrapper(policy);
401
417
 
402
- static_assert(LargeSegmentPolicyT::LOAD_MODIFIER != CacheLoadModifier::LOAD_LDG,
403
- "The memory consistency model does not apply to texture accesses");
418
+ CUB_DETAIL_STATIC_ISH_ASSERT(wrapped_policy.LargeSegmentLoadModifier() != CacheLoadModifier::LOAD_LDG,
419
+ "The memory consistency model does not apply to texture accesses");
404
420
 
405
- static_assert(KEYS_ONLY || LargeSegmentPolicyT::LOAD_ALGORITHM != BLOCK_LOAD_STRIPED
406
- || SmallAndMediumPolicyT::MediumPolicyT::LOAD_ALGORITHM != WARP_LOAD_STRIPED
407
- || SmallAndMediumPolicyT::SmallPolicyT::LOAD_ALGORITHM != WARP_LOAD_STRIPED,
408
- "Striped load will make this algorithm unstable");
421
+ CUB_DETAIL_STATIC_ISH_ASSERT(
422
+ KEYS_ONLY || wrapped_policy.LargeSegmentLoadAlgorithm() != BLOCK_LOAD_STRIPED
423
+ || wrapped_policy.MediumSegmentLoadAlgorithm() != WARP_LOAD_STRIPED
424
+ || wrapped_policy.SmallSegmentLoadAlgorithm() != WARP_LOAD_STRIPED,
425
+ "Striped load will make this algorithm unstable");
409
426
 
410
- static_assert(SmallAndMediumPolicyT::MediumPolicyT::STORE_ALGORITHM != WARP_STORE_STRIPED
411
- || SmallAndMediumPolicyT::SmallPolicyT::STORE_ALGORITHM != WARP_STORE_STRIPED,
412
- "Striped stores will produce unsorted results");
427
+ CUB_DETAIL_STATIC_ISH_ASSERT(wrapped_policy.MediumSegmentStoreAlgorithm() != WARP_STORE_STRIPED
428
+ || wrapped_policy.SmallSegmentStoreAlgorithm() != WARP_STORE_STRIPED,
429
+ "Striped stores will produce unsorted results");
413
430
 
414
- constexpr int radix_bits = LargeSegmentPolicyT::RADIX_BITS;
431
+ const int radix_bits = wrapped_policy.LargeSegmentRadixBits();
415
432
 
416
433
  cudaError error = cudaSuccess;
417
434
 
@@ -421,7 +438,7 @@ struct DispatchSegmentedSort
421
438
  // Prepare temporary storage layout
422
439
  //------------------------------------------------------------------------
423
440
 
424
- const bool partition_segments = num_segments > ActivePolicyT::PARTITIONING_THRESHOLD;
441
+ const bool partition_segments = num_segments > wrapped_policy.PartitioningThreshold();
425
442
 
426
443
  cub::detail::temporary_storage::layout<5> temporary_storage_layout;
427
444
 
@@ -451,11 +468,10 @@ struct DispatchSegmentedSort
451
468
 
452
469
  size_t three_way_partition_temp_storage_bytes{};
453
470
 
454
- LargeSegmentsSelectorT large_segments_selector(
455
- SmallAndMediumPolicyT::MediumPolicyT::ITEMS_PER_TILE, d_begin_offsets, d_end_offsets);
456
-
457
- SmallSegmentsSelectorT small_segments_selector(
458
- SmallAndMediumPolicyT::SmallPolicyT::ITEMS_PER_TILE + 1, d_begin_offsets, d_end_offsets);
471
+ auto large_segments_selector =
472
+ kernel_source.LargeSegmentsSelector(wrapped_policy.MediumPolicyItemsPerTile(), d_begin_offsets, d_end_offsets);
473
+ auto small_segments_selector = kernel_source.SmallSegmentsSelector(
474
+ wrapped_policy.SmallPolicyItemsPerTile() + 1, d_begin_offsets, d_end_offsets);
459
475
 
460
476
  auto device_partition_temp_storage = keys_slot->create_alias<uint8_t>();
461
477
 
@@ -472,7 +488,32 @@ struct DispatchSegmentedSort
472
488
 
473
489
  auto medium_indices_iterator = ::cuda::std::make_reverse_iterator(large_and_medium_segments_indices.get());
474
490
 
475
- cub::DevicePartition::IfNoNVTX(
491
+ // We call partition through dispatch instead of device because c.parallel needs to be able to call the kernel.
492
+ // This approach propagates the type erasure to partition.
493
+ using ChooseOffsetT = detail::choose_signed_offset<global_segment_offset_t>;
494
+ using PartitionOffsetT = typename ChooseOffsetT::type;
495
+ using DispatchThreeWayPartitionIfT = cub::DispatchThreeWayPartitionIf<
496
+ THRUST_NS_QUALIFIER::counting_iterator<local_segment_index_t>,
497
+ decltype(large_and_medium_segments_indices.get()),
498
+ decltype(small_segments_indices.get()),
499
+ decltype(medium_indices_iterator),
500
+ decltype(group_sizes.get()),
501
+ decltype(large_segments_selector),
502
+ decltype(small_segments_selector),
503
+ PartitionOffsetT,
504
+ PartitionPolicyHub,
505
+ PartitionKernelSource,
506
+ KernelLauncherFactory>;
507
+
508
+ // Signed integer type for global offsets
509
+ // Check if the number of items exceeds the range covered by the selected signed offset type
510
+ error = ChooseOffsetT::is_exceeding_offset_type(num_items);
511
+ if (error)
512
+ {
513
+ return error;
514
+ }
515
+
516
+ DispatchThreeWayPartitionIfT::Dispatch(
476
517
  nullptr,
477
518
  three_way_partition_temp_storage_bytes,
478
519
  THRUST_NS_QUALIFIER::counting_iterator<local_segment_index_t>(0),
@@ -480,10 +521,13 @@ struct DispatchSegmentedSort
480
521
  small_segments_indices.get(),
481
522
  medium_indices_iterator,
482
523
  group_sizes.get(),
483
- max_num_segments_per_invocation,
484
524
  large_segments_selector,
485
525
  small_segments_selector,
486
- stream);
526
+ max_num_segments_per_invocation,
527
+ stream,
528
+ partition_kernel_source,
529
+ launcher_factory,
530
+ partition_max_policy);
487
531
 
488
532
  device_partition_temp_storage.grow(three_way_partition_temp_storage_bytes);
489
533
  }
@@ -573,29 +617,13 @@ struct DispatchSegmentedSort
573
617
  : (is_num_passes_odd) ? values_allocation.get()
574
618
  : d_values.Alternate());
575
619
 
576
- using MaxPolicyT = typename PolicyHub::MaxPolicy;
577
-
578
620
  if (partition_segments)
579
621
  {
580
622
  // Partition input segments into size groups and assign specialized
581
623
  // kernels for each of them.
582
- error = SortWithPartitioning<LargeSegmentPolicyT, SmallAndMediumPolicyT>(
583
- detail::segmented_sort::DeviceSegmentedSortKernelLarge<
584
- Order,
585
- MaxPolicyT,
586
- KeyT,
587
- ValueT,
588
- BeginOffsetIteratorT,
589
- EndOffsetIteratorT,
590
- OffsetT>,
591
- detail::segmented_sort::DeviceSegmentedSortKernelSmall<
592
- Order,
593
- MaxPolicyT,
594
- KeyT,
595
- ValueT,
596
- BeginOffsetIteratorT,
597
- EndOffsetIteratorT,
598
- OffsetT>,
624
+ error = SortWithPartitioning(
625
+ kernel_source.SegmentedSortKernelLarge(),
626
+ kernel_source.SegmentedSortKernelSmall(),
599
627
  three_way_partition_temp_storage_bytes,
600
628
  d_keys_double_buffer,
601
629
  d_values_double_buffer,
@@ -604,24 +632,16 @@ struct DispatchSegmentedSort
604
632
  device_partition_temp_storage,
605
633
  large_and_medium_segments_indices,
606
634
  small_segments_indices,
607
- group_sizes);
635
+ group_sizes,
636
+ wrapped_policy);
608
637
  }
609
638
  else
610
639
  {
611
640
  // If there are not enough segments, there's no reason to spend time
612
641
  // on extra partitioning steps.
613
642
 
614
- error = SortWithoutPartitioning<LargeSegmentPolicyT>(
615
- detail::segmented_sort::DeviceSegmentedSortFallbackKernel<
616
- Order,
617
- MaxPolicyT,
618
- KeyT,
619
- ValueT,
620
- BeginOffsetIteratorT,
621
- EndOffsetIteratorT,
622
- OffsetT>,
623
- d_keys_double_buffer,
624
- d_values_double_buffer);
643
+ error = SortWithoutPartitioning(
644
+ kernel_source.SegmentedSortFallbackKernel(), d_keys_double_buffer, d_values_double_buffer, wrapped_policy);
625
645
  }
626
646
 
627
647
  d_keys.selector = GetFinalSelector(d_keys.selector, radix_bits);
@@ -632,6 +652,8 @@ struct DispatchSegmentedSort
632
652
  return error;
633
653
  }
634
654
 
655
+ template <typename MaxPolicyT = typename PolicyHub::MaxPolicy,
656
+ typename PartitionMaxPolicyT = typename PartitionPolicyHub::MaxPolicy>
635
657
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch(
636
658
  void* d_temp_storage,
637
659
  size_t& temp_storage_bytes,
@@ -642,49 +664,46 @@ struct DispatchSegmentedSort
642
664
  BeginOffsetIteratorT d_begin_offsets,
643
665
  EndOffsetIteratorT d_end_offsets,
644
666
  bool is_overwrite_okay,
645
- cudaStream_t stream)
667
+ cudaStream_t stream,
668
+ KernelSource kernel_source = {},
669
+ PartitionKernelSource partition_kernel_source = {},
670
+ KernelLauncherFactory launcher_factory = {},
671
+ MaxPolicyT max_policy = {},
672
+ PartitionMaxPolicyT partition_max_policy = {})
646
673
  {
647
- cudaError error = cudaSuccess;
648
-
649
- do
674
+ // Get PTX version
675
+ int ptx_version = 0;
676
+ if (cudaError error = CubDebug(launcher_factory.PtxVersion(ptx_version)); cudaSuccess != error)
650
677
  {
651
- // Get PTX version
652
- int ptx_version = 0;
653
- error = CubDebug(PtxVersion(ptx_version));
654
- if (cudaSuccess != error)
655
- {
656
- break;
657
- }
658
-
659
- // Create dispatch functor
660
- DispatchSegmentedSort dispatch(
661
- d_temp_storage,
662
- temp_storage_bytes,
663
- d_keys,
664
- d_values,
665
- num_items,
666
- num_segments,
667
- d_begin_offsets,
668
- d_end_offsets,
669
- is_overwrite_okay,
670
- stream);
671
-
672
- // Dispatch to chained policy
673
- error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch));
674
- if (cudaSuccess != error)
675
- {
676
- break;
677
- }
678
- } while (false);
678
+ return error;
679
+ }
679
680
 
680
- return error;
681
+ // Create dispatch functor
682
+ DispatchSegmentedSort dispatch{
683
+ d_temp_storage,
684
+ temp_storage_bytes,
685
+ d_keys,
686
+ d_values,
687
+ num_items,
688
+ num_segments,
689
+ d_begin_offsets,
690
+ d_end_offsets,
691
+ is_overwrite_okay,
692
+ stream,
693
+ kernel_source,
694
+ partition_kernel_source,
695
+ launcher_factory,
696
+ partition_max_policy};
697
+
698
+ // Dispatch to chained policy
699
+ return CubDebug(max_policy.Invoke(ptx_version, dispatch));
681
700
  }
682
701
 
683
702
  private:
684
703
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE int GetNumPasses(int radix_bits)
685
704
  {
686
705
  constexpr int byte_size = 8;
687
- constexpr int num_bits = sizeof(KeyT) * byte_size;
706
+ const int num_bits = static_cast<int>(kernel_source.KeySize()) * byte_size;
688
707
  const int num_passes = ::cuda::ceil_div(num_bits, radix_bits);
689
708
  return num_passes;
690
709
  }
@@ -707,19 +726,20 @@ private:
707
726
  return buffer.d_buffers[final_selector];
708
727
  }
709
728
 
710
- template <typename LargeSegmentPolicyT, typename SmallAndMediumPolicyT, typename LargeKernelT, typename SmallKernelT>
729
+ template <typename WrappedPolicyT, typename LargeKernelT, typename SmallKernelT>
711
730
  CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t SortWithPartitioning(
712
731
  LargeKernelT large_kernel,
713
732
  SmallKernelT small_kernel,
714
733
  size_t three_way_partition_temp_storage_bytes,
715
734
  cub::detail::device_double_buffer<KeyT>& d_keys_double_buffer,
716
735
  cub::detail::device_double_buffer<ValueT>& d_values_double_buffer,
717
- LargeSegmentsSelectorT& large_segments_selector,
718
- SmallSegmentsSelectorT& small_segments_selector,
736
+ typename KernelSource::LargeSegmentsSelectorT& large_segments_selector,
737
+ typename KernelSource::SmallSegmentsSelectorT& small_segments_selector,
719
738
  cub::detail::temporary_storage::alias<uint8_t>& device_partition_temp_storage,
720
739
  cub::detail::temporary_storage::alias<local_segment_index_t>& large_and_medium_segments_indices,
721
740
  cub::detail::temporary_storage::alias<local_segment_index_t>& small_segments_indices,
722
- cub::detail::temporary_storage::alias<local_segment_index_t>& group_sizes)
741
+ cub::detail::temporary_storage::alias<local_segment_index_t>& group_sizes,
742
+ WrappedPolicyT wrapped_policy)
723
743
  {
724
744
  cudaError_t error = cudaSuccess;
725
745
 
@@ -737,15 +757,44 @@ private:
737
757
  ? static_cast<local_segment_index_t>(num_segments - current_seg_offset)
738
758
  : num_segments_per_invocation_limit;
739
759
 
740
- large_segments_selector.base_segment_offset = current_seg_offset;
741
- small_segments_selector.base_segment_offset = current_seg_offset;
742
- [[maybe_unused]] auto current_begin_offset = d_begin_offsets + current_seg_offset;
743
- [[maybe_unused]] auto current_end_offset = d_end_offsets + current_seg_offset;
760
+ kernel_source.SetSegmentOffset(large_segments_selector, current_seg_offset);
761
+ kernel_source.SetSegmentOffset(small_segments_selector, current_seg_offset);
762
+
763
+ BeginOffsetIteratorT current_begin_offset = d_begin_offsets;
764
+ EndOffsetIteratorT current_end_offset = d_end_offsets;
765
+
766
+ current_begin_offset += current_seg_offset;
767
+ current_end_offset += current_seg_offset;
744
768
 
745
769
  auto medium_indices_iterator =
746
770
  ::cuda::std::make_reverse_iterator(large_and_medium_segments_indices.get() + current_num_segments);
747
771
 
748
- error = CubDebug(cub::DevicePartition::IfNoNVTX(
772
+ // We call partition through dispatch instead of device because c.parallel needs to be able to call the kernel.
773
+ // This approach propagates the type erasure to partition.
774
+ using ChooseOffsetT = detail::choose_signed_offset<global_segment_offset_t>;
775
+ using PartitionOffsetT = typename ChooseOffsetT::type;
776
+ using DispatchThreeWayPartitionIfT = cub::DispatchThreeWayPartitionIf<
777
+ THRUST_NS_QUALIFIER::counting_iterator<local_segment_index_t>,
778
+ decltype(large_and_medium_segments_indices.get()),
779
+ decltype(small_segments_indices.get()),
780
+ decltype(medium_indices_iterator),
781
+ decltype(group_sizes.get()),
782
+ decltype(large_segments_selector),
783
+ decltype(small_segments_selector),
784
+ PartitionOffsetT,
785
+ PartitionPolicyHub,
786
+ PartitionKernelSource,
787
+ KernelLauncherFactory>;
788
+
789
+ // Signed integer type for global offsets
790
+ // Check if the number of items exceeds the range covered by the selected signed offset type
791
+ error = ChooseOffsetT::is_exceeding_offset_type(num_items);
792
+ if (error)
793
+ {
794
+ return error;
795
+ }
796
+
797
+ DispatchThreeWayPartitionIfT::Dispatch(
749
798
  device_partition_temp_storage.get(),
750
799
  three_way_partition_temp_storage_bytes,
751
800
  THRUST_NS_QUALIFIER::counting_iterator<local_segment_index_t>(0),
@@ -753,10 +802,14 @@ private:
753
802
  small_segments_indices.get(),
754
803
  medium_indices_iterator,
755
804
  group_sizes.get(),
756
- current_num_segments,
757
805
  large_segments_selector,
758
806
  small_segments_selector,
759
- stream));
807
+ current_num_segments,
808
+ stream,
809
+ partition_kernel_source,
810
+ launcher_factory,
811
+ partition_max_policy);
812
+
760
813
  if (cudaSuccess != error)
761
814
  {
762
815
  return error;
@@ -771,43 +824,46 @@ private:
771
824
 
772
825
  #else // CUB_RDC_ENABLED
773
826
 
774
- # define CUB_TEMP_DEVICE_CODE \
775
- error = \
776
- THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(1, 1, 0, stream) \
777
- .doit( \
778
- detail::segmented_sort::DeviceSegmentedSortContinuationKernel< \
779
- typename PolicyHub::MaxPolicy, \
780
- LargeKernelT, \
781
- SmallKernelT, \
782
- KeyT, \
783
- ValueT, \
784
- BeginOffsetIteratorT, \
785
- EndOffsetIteratorT>, \
786
- large_kernel, \
787
- small_kernel, \
788
- current_num_segments, \
789
- d_keys.Current(), \
790
- GetFinalOutput<KeyT>(LargeSegmentPolicyT::RADIX_BITS, d_keys), \
791
- d_keys_double_buffer, \
792
- d_values.Current(), \
793
- GetFinalOutput<ValueT>(LargeSegmentPolicyT::RADIX_BITS, d_values), \
794
- d_values_double_buffer, \
795
- current_begin_offset, \
796
- current_end_offset, \
797
- group_sizes.get(), \
798
- large_and_medium_segments_indices.get(), \
799
- small_segments_indices.get()); \
800
- error = CubDebug(error); \
801
- \
802
- if (cudaSuccess != error) \
803
- { \
804
- return error; \
805
- } \
806
- \
807
- error = CubDebug(detail::DebugSyncStream(stream)); \
808
- if (cudaSuccess != error) \
809
- { \
810
- return error; \
827
+ # define CUB_TEMP_DEVICE_CODE \
828
+ error = \
829
+ launcher_factory(1, 1, 0, stream) \
830
+ .doit( \
831
+ detail::segmented_sort::DeviceSegmentedSortContinuationKernel< \
832
+ WrappedPolicyT, \
833
+ LargeKernelT, \
834
+ SmallKernelT, \
835
+ KeyT, \
836
+ ValueT, \
837
+ BeginOffsetIteratorT, \
838
+ EndOffsetIteratorT, \
839
+ KernelLauncherFactory>, \
840
+ large_kernel, \
841
+ small_kernel, \
842
+ current_num_segments, \
843
+ d_keys.Current(), \
844
+ GetFinalOutput<KeyT>(wrapped_policy.LargeSegmentRadixBits(), d_keys), \
845
+ d_keys_double_buffer, \
846
+ d_values.Current(), \
847
+ GetFinalOutput<ValueT>(wrapped_policy.LargeSegmentRadixBits(), d_values), \
848
+ d_values_double_buffer, \
849
+ current_begin_offset, \
850
+ current_end_offset, \
851
+ group_sizes.get(), \
852
+ large_and_medium_segments_indices.get(), \
853
+ small_segments_indices.get(), \
854
+ launcher_factory, \
855
+ wrapped_policy); \
856
+ error = CubDebug(error); \
857
+ \
858
+ if (cudaSuccess != error) \
859
+ { \
860
+ return error; \
861
+ } \
862
+ \
863
+ error = CubDebug(detail::DebugSyncStream(stream)); \
864
+ if (cudaSuccess != error) \
865
+ { \
866
+ return error; \
811
867
  }
812
868
 
813
869
  #endif // CUB_RDC_ENABLED
@@ -818,12 +874,12 @@ private:
818
874
  NV_IS_HOST,
819
875
  (
820
876
  local_segment_index_t h_group_sizes[num_selected_groups];
821
- error = CubDebug(cudaMemcpyAsync(h_group_sizes,
822
- group_sizes.get(),
823
- num_selected_groups *
824
- sizeof(local_segment_index_t),
825
- cudaMemcpyDeviceToHost,
826
- stream));
877
+ error = CubDebug(launcher_factory.MemcpyAsync(h_group_sizes,
878
+ group_sizes.get(),
879
+ num_selected_groups *
880
+ sizeof(local_segment_index_t),
881
+ cudaMemcpyDeviceToHost,
882
+ stream));
827
883
 
828
884
  if (cudaSuccess != error)
829
885
  {
@@ -836,23 +892,24 @@ private:
836
892
  return error;
837
893
  }
838
894
 
839
- error = detail::segmented_sort::DeviceSegmentedSortContinuation<LargeSegmentPolicyT,
840
- SmallAndMediumPolicyT>(
895
+ error = detail::segmented_sort::DeviceSegmentedSortContinuation(
841
896
  large_kernel,
842
897
  small_kernel,
843
898
  current_num_segments,
844
899
  d_keys.Current(),
845
- GetFinalOutput<KeyT>(LargeSegmentPolicyT::RADIX_BITS, d_keys),
900
+ GetFinalOutput<KeyT>(wrapped_policy.LargeSegmentRadixBits(), d_keys),
846
901
  d_keys_double_buffer,
847
902
  d_values.Current(),
848
- GetFinalOutput<ValueT>(LargeSegmentPolicyT::RADIX_BITS, d_values),
903
+ GetFinalOutput<ValueT>(wrapped_policy.LargeSegmentRadixBits(), d_values),
849
904
  d_values_double_buffer,
850
905
  current_begin_offset,
851
906
  current_end_offset,
852
907
  h_group_sizes,
853
908
  large_and_medium_segments_indices.get(),
854
909
  small_segments_indices.get(),
855
- stream);),
910
+ stream,
911
+ launcher_factory,
912
+ wrapped_policy);),
856
913
  // NV_IS_DEVICE:
857
914
  (CUB_TEMP_DEVICE_CODE));
858
915
  // clang-format on
@@ -862,16 +919,17 @@ private:
862
919
  return error;
863
920
  }
864
921
 
865
- template <typename LargeSegmentPolicyT, typename FallbackKernelT>
922
+ template <typename WrappedPolicyT, typename FallbackKernelT>
866
923
  CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t SortWithoutPartitioning(
867
924
  FallbackKernelT fallback_kernel,
868
925
  cub::detail::device_double_buffer<KeyT>& d_keys_double_buffer,
869
- cub::detail::device_double_buffer<ValueT>& d_values_double_buffer)
926
+ cub::detail::device_double_buffer<ValueT>& d_values_double_buffer,
927
+ WrappedPolicyT wrapped_policy)
870
928
  {
871
929
  cudaError_t error = cudaSuccess;
872
930
 
873
- const auto blocks_in_grid = static_cast<local_segment_index_t>(num_segments);
874
- constexpr auto threads_in_block = static_cast<unsigned int>(LargeSegmentPolicyT::BLOCK_THREADS);
931
+ const auto blocks_in_grid = static_cast<local_segment_index_t>(num_segments);
932
+ const auto threads_in_block = static_cast<unsigned int>(wrapped_policy.LargeSegment().BlockThreads());
875
933
 
876
934
  // Log kernel configuration
877
935
  #ifdef CUB_DEBUG_LOG
@@ -880,18 +938,18 @@ private:
880
938
  blocks_in_grid,
881
939
  threads_in_block,
882
940
  (long long) stream,
883
- LargeSegmentPolicyT::ITEMS_PER_THREAD,
884
- LargeSegmentPolicyT::RADIX_BITS);
941
+ wrapped_policy.LargeSegment().ItemsPerThread(),
942
+ wrapped_policy.LargeSegmentRadixBits());
885
943
  #endif // CUB_DEBUG_LOG
886
944
 
887
945
  // Invoke fallback kernel
888
- THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(blocks_in_grid, threads_in_block, 0, stream)
946
+ launcher_factory(blocks_in_grid, threads_in_block, 0, stream)
889
947
  .doit(fallback_kernel,
890
948
  d_keys.Current(),
891
- GetFinalOutput(LargeSegmentPolicyT::RADIX_BITS, d_keys),
949
+ GetFinalOutput(wrapped_policy.LargeSegmentRadixBits(), d_keys),
892
950
  d_keys_double_buffer,
893
951
  d_values.Current(),
894
- GetFinalOutput(LargeSegmentPolicyT::RADIX_BITS, d_values),
952
+ GetFinalOutput(wrapped_policy.LargeSegmentRadixBits(), d_values),
895
953
  d_values_double_buffer,
896
954
  d_begin_offsets,
897
955
  d_end_offsets);