cuda-cccl 0.1.3.2.0.dev438__cp312-cp312-manylinux_2_24_aarch64.whl → 0.3.1__cp312-cp312-manylinux_2_24_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cuda-cccl might be problematic. Click here for more details.

Files changed (177)
  1. cuda/cccl/cooperative/__init__.py +7 -1
  2. cuda/cccl/cooperative/experimental/__init__.py +21 -5
  3. cuda/cccl/headers/include/cub/agent/agent_adjacent_difference.cuh +2 -5
  4. cuda/cccl/headers/include/cub/agent/agent_batch_memcpy.cuh +2 -5
  5. cuda/cccl/headers/include/cub/agent/agent_for.cuh +2 -5
  6. cuda/cccl/headers/include/cub/agent/agent_merge.cuh +23 -21
  7. cuda/cccl/headers/include/cub/agent/agent_merge_sort.cuh +21 -3
  8. cuda/cccl/headers/include/cub/agent/agent_radix_sort_downsweep.cuh +25 -5
  9. cuda/cccl/headers/include/cub/agent/agent_radix_sort_histogram.cuh +2 -5
  10. cuda/cccl/headers/include/cub/agent/agent_radix_sort_onesweep.cuh +2 -5
  11. cuda/cccl/headers/include/cub/agent/agent_radix_sort_upsweep.cuh +2 -5
  12. cuda/cccl/headers/include/cub/agent/agent_rle.cuh +2 -5
  13. cuda/cccl/headers/include/cub/agent/agent_scan.cuh +5 -1
  14. cuda/cccl/headers/include/cub/agent/agent_scan_by_key.cuh +2 -5
  15. cuda/cccl/headers/include/cub/agent/agent_segmented_radix_sort.cuh +2 -5
  16. cuda/cccl/headers/include/cub/agent/agent_select_if.cuh +2 -5
  17. cuda/cccl/headers/include/cub/agent/agent_sub_warp_merge_sort.cuh +24 -19
  18. cuda/cccl/headers/include/cub/agent/agent_three_way_partition.cuh +2 -5
  19. cuda/cccl/headers/include/cub/agent/agent_unique_by_key.cuh +22 -5
  20. cuda/cccl/headers/include/cub/block/block_load_to_shared.cuh +432 -0
  21. cuda/cccl/headers/include/cub/block/block_radix_rank.cuh +3 -2
  22. cuda/cccl/headers/include/cub/block/block_radix_sort.cuh +4 -2
  23. cuda/cccl/headers/include/cub/detail/device_memory_resource.cuh +1 -0
  24. cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +18 -26
  25. cuda/cccl/headers/include/cub/device/device_copy.cuh +116 -27
  26. cuda/cccl/headers/include/cub/device/device_partition.cuh +5 -1
  27. cuda/cccl/headers/include/cub/device/device_segmented_reduce.cuh +158 -247
  28. cuda/cccl/headers/include/cub/device/dispatch/dispatch_copy_mdspan.cuh +79 -0
  29. cuda/cccl/headers/include/cub/device/dispatch/dispatch_merge.cuh +4 -4
  30. cuda/cccl/headers/include/cub/device/dispatch/dispatch_radix_sort.cuh +2 -11
  31. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce.cuh +8 -26
  32. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_deterministic.cuh +1 -6
  33. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_nondeterministic.cuh +0 -1
  34. cuda/cccl/headers/include/cub/device/dispatch/dispatch_segmented_sort.cuh +320 -262
  35. cuda/cccl/headers/include/cub/device/dispatch/kernels/reduce.cuh +10 -5
  36. cuda/cccl/headers/include/cub/device/dispatch/kernels/scan.cuh +2 -5
  37. cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_reduce.cuh +2 -5
  38. cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_sort.cuh +57 -10
  39. cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +37 -13
  40. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_adjacent_difference.cuh +2 -5
  41. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_batch_memcpy.cuh +2 -5
  42. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_for.cuh +2 -5
  43. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_histogram.cuh +2 -5
  44. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_merge.cuh +2 -5
  45. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_merge_sort.cuh +8 -0
  46. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_radix_sort.cuh +2 -5
  47. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh +2 -5
  48. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_run_length_encode.cuh +2 -5
  49. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_scan.cuh +2 -5
  50. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_scan_by_key.cuh +2 -5
  51. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_segmented_sort.cuh +204 -55
  52. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_three_way_partition.cuh +2 -5
  53. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +55 -19
  54. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_unique_by_key.cuh +10 -0
  55. cuda/cccl/headers/include/cub/util_device.cuh +51 -35
  56. cuda/cccl/headers/include/cub/warp/specializations/warp_reduce_shfl.cuh +3 -2
  57. cuda/cccl/headers/include/cub/warp/specializations/warp_reduce_smem.cuh +3 -2
  58. cuda/cccl/headers/include/cub/warp/specializations/warp_scan_shfl.cuh +2 -2
  59. cuda/cccl/headers/include/cuda/__algorithm/common.h +1 -1
  60. cuda/cccl/headers/include/cuda/__algorithm/copy.h +4 -4
  61. cuda/cccl/headers/include/cuda/__algorithm/fill.h +1 -1
  62. cuda/cccl/headers/include/cuda/__device/all_devices.h +47 -147
  63. cuda/cccl/headers/include/cuda/__device/arch_traits.h +51 -49
  64. cuda/cccl/headers/include/cuda/__device/attributes.h +177 -127
  65. cuda/cccl/headers/include/cuda/__device/device_ref.h +32 -51
  66. cuda/cccl/headers/include/cuda/__device/physical_device.h +120 -91
  67. cuda/cccl/headers/include/cuda/__driver/driver_api.h +330 -36
  68. cuda/cccl/headers/include/cuda/__event/event.h +8 -8
  69. cuda/cccl/headers/include/cuda/__event/event_ref.h +4 -5
  70. cuda/cccl/headers/include/cuda/__event/timed_event.h +4 -4
  71. cuda/cccl/headers/include/cuda/__fwd/devices.h +44 -0
  72. cuda/cccl/headers/include/cuda/__fwd/zip_iterator.h +9 -0
  73. cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +3 -3
  74. cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +3 -3
  75. cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +3 -3
  76. cuda/cccl/headers/include/cuda/__iterator/zip_common.h +158 -0
  77. cuda/cccl/headers/include/cuda/__iterator/zip_iterator.h +8 -120
  78. cuda/cccl/headers/include/cuda/__iterator/zip_transform_iterator.h +593 -0
  79. cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +14 -10
  80. cuda/cccl/headers/include/cuda/__runtime/ensure_current_context.h +4 -3
  81. cuda/cccl/headers/include/cuda/__runtime/types.h +1 -1
  82. cuda/cccl/headers/include/cuda/__stream/stream.h +2 -3
  83. cuda/cccl/headers/include/cuda/__stream/stream_ref.h +18 -12
  84. cuda/cccl/headers/include/cuda/__utility/__basic_any/virtual_tables.h +2 -2
  85. cuda/cccl/headers/include/cuda/__utility/basic_any.h +1 -1
  86. cuda/cccl/headers/include/cuda/algorithm +1 -1
  87. cuda/cccl/headers/include/cuda/devices +10 -0
  88. cuda/cccl/headers/include/cuda/iterator +1 -0
  89. cuda/cccl/headers/include/cuda/std/__bit/countl.h +8 -1
  90. cuda/cccl/headers/include/cuda/std/__bit/countr.h +2 -2
  91. cuda/cccl/headers/include/cuda/std/__bit/reference.h +11 -11
  92. cuda/cccl/headers/include/cuda/std/__cccl/cuda_capabilities.h +2 -2
  93. cuda/cccl/headers/include/cuda/std/__cccl/preprocessor.h +2 -0
  94. cuda/cccl/headers/include/cuda/std/__chrono/duration.h +16 -16
  95. cuda/cccl/headers/include/cuda/std/__chrono/steady_clock.h +5 -5
  96. cuda/cccl/headers/include/cuda/std/__chrono/system_clock.h +5 -5
  97. cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +10 -5
  98. cuda/cccl/headers/include/cuda/std/__cmath/min_max.h +44 -17
  99. cuda/cccl/headers/include/cuda/std/__concepts/constructible.h +1 -1
  100. cuda/cccl/headers/include/cuda/std/__cuda/api_wrapper.h +12 -12
  101. cuda/cccl/headers/include/cuda/std/__exception/cuda_error.h +1 -8
  102. cuda/cccl/headers/include/cuda/std/__floating_point/cast.h +15 -12
  103. cuda/cccl/headers/include/cuda/std/__floating_point/cuda_fp_types.h +3 -0
  104. cuda/cccl/headers/include/cuda/std/__floating_point/fp.h +1 -1
  105. cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +2 -1
  106. cuda/cccl/headers/include/cuda/std/__tuple_dir/make_tuple_types.h +23 -1
  107. cuda/cccl/headers/include/cuda/std/__tuple_dir/tuple_like.h +4 -0
  108. cuda/cccl/headers/include/cuda/std/__tuple_dir/tuple_like_ext.h +4 -0
  109. cuda/cccl/headers/include/cuda/std/__type_traits/promote.h +3 -2
  110. cuda/cccl/headers/include/cuda/std/string_view +12 -5
  111. cuda/cccl/headers/include/cuda/std/version +1 -4
  112. cuda/cccl/headers/include/thrust/detail/integer_math.h +3 -20
  113. cuda/cccl/headers/include/thrust/iterator/iterator_traits.h +11 -0
  114. cuda/cccl/headers/include/thrust/system/cuda/detail/copy.h +33 -0
  115. cuda/cccl/headers/include/thrust/system/cuda/detail/tabulate.h +8 -22
  116. cuda/cccl/headers/include/thrust/type_traits/unwrap_contiguous_iterator.h +15 -48
  117. cuda/cccl/parallel/experimental/__init__.py +21 -70
  118. cuda/compute/__init__.py +77 -0
  119. cuda/{cccl/parallel/experimental → compute}/_bindings.pyi +28 -0
  120. cuda/{cccl/parallel/experimental → compute}/_bindings_impl.pyx +141 -1
  121. cuda/{cccl/parallel/experimental → compute}/algorithms/__init__.py +4 -0
  122. cuda/{cccl/parallel/experimental → compute}/algorithms/_histogram.py +2 -2
  123. cuda/{cccl/parallel/experimental → compute}/algorithms/_merge_sort.py +2 -2
  124. cuda/{cccl/parallel/experimental → compute}/algorithms/_radix_sort.py +3 -3
  125. cuda/{cccl/parallel/experimental → compute}/algorithms/_reduce.py +2 -4
  126. cuda/{cccl/parallel/experimental → compute}/algorithms/_scan.py +4 -6
  127. cuda/{cccl/parallel/experimental → compute}/algorithms/_segmented_reduce.py +2 -2
  128. cuda/compute/algorithms/_three_way_partition.py +261 -0
  129. cuda/{cccl/parallel/experimental → compute}/algorithms/_transform.py +4 -4
  130. cuda/{cccl/parallel/experimental → compute}/algorithms/_unique_by_key.py +2 -2
  131. cuda/compute/cu12/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
  132. cuda/{cccl/parallel/experimental → compute}/cu12/cccl/libcccl.c.parallel.so +0 -0
  133. cuda/compute/cu13/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
  134. cuda/{cccl/parallel/experimental → compute}/cu13/cccl/libcccl.c.parallel.so +0 -0
  135. cuda/{cccl/parallel/experimental → compute}/iterators/_factories.py +8 -8
  136. cuda/{cccl/parallel/experimental → compute}/struct.py +2 -2
  137. cuda/coop/__init__.py +8 -0
  138. cuda/{cccl/cooperative/experimental → coop}/_nvrtc.py +3 -2
  139. cuda/{cccl/cooperative/experimental → coop}/_scan_op.py +3 -3
  140. cuda/{cccl/cooperative/experimental → coop}/_types.py +2 -2
  141. cuda/{cccl/cooperative/experimental → coop}/_typing.py +1 -1
  142. cuda/{cccl/cooperative/experimental → coop}/block/__init__.py +6 -6
  143. cuda/{cccl/cooperative/experimental → coop}/block/_block_exchange.py +4 -4
  144. cuda/{cccl/cooperative/experimental → coop}/block/_block_load_store.py +6 -6
  145. cuda/{cccl/cooperative/experimental → coop}/block/_block_merge_sort.py +4 -4
  146. cuda/{cccl/cooperative/experimental → coop}/block/_block_radix_sort.py +6 -6
  147. cuda/{cccl/cooperative/experimental → coop}/block/_block_reduce.py +6 -6
  148. cuda/{cccl/cooperative/experimental → coop}/block/_block_scan.py +7 -7
  149. cuda/coop/warp/__init__.py +9 -0
  150. cuda/{cccl/cooperative/experimental → coop}/warp/_warp_merge_sort.py +3 -3
  151. cuda/{cccl/cooperative/experimental → coop}/warp/_warp_reduce.py +6 -6
  152. cuda/{cccl/cooperative/experimental → coop}/warp/_warp_scan.py +4 -4
  153. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/METADATA +1 -1
  154. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/RECORD +171 -166
  155. cuda/cccl/cooperative/experimental/warp/__init__.py +0 -9
  156. cuda/cccl/headers/include/cub/device/dispatch/dispatch_advance_iterators.cuh +0 -111
  157. cuda/cccl/headers/include/cuda/std/__cuda/ensure_current_device.h +0 -72
  158. cuda/cccl/parallel/experimental/.gitignore +0 -4
  159. cuda/cccl/parallel/experimental/cu12/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
  160. cuda/cccl/parallel/experimental/cu13/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
  161. /cuda/{cccl/parallel/experimental → compute}/_bindings.py +0 -0
  162. /cuda/{cccl/parallel/experimental → compute}/_caching.py +0 -0
  163. /cuda/{cccl/parallel/experimental → compute}/_cccl_interop.py +0 -0
  164. /cuda/{cccl/parallel/experimental → compute}/_utils/__init__.py +0 -0
  165. /cuda/{cccl/parallel/experimental → compute}/_utils/protocols.py +0 -0
  166. /cuda/{cccl/parallel/experimental → compute}/_utils/temp_storage_buffer.py +0 -0
  167. /cuda/{cccl/parallel/experimental → compute}/cccl/.gitkeep +0 -0
  168. /cuda/{cccl/parallel/experimental → compute}/iterators/__init__.py +0 -0
  169. /cuda/{cccl/parallel/experimental → compute}/iterators/_iterators.py +0 -0
  170. /cuda/{cccl/parallel/experimental → compute}/iterators/_zip_iterator.py +0 -0
  171. /cuda/{cccl/parallel/experimental → compute}/numba_utils.py +0 -0
  172. /cuda/{cccl/parallel/experimental → compute}/op.py +0 -0
  173. /cuda/{cccl/parallel/experimental → compute}/typing.py +0 -0
  174. /cuda/{cccl/cooperative/experimental → coop}/_caching.py +0 -0
  175. /cuda/{cccl/cooperative/experimental → coop}/_common.py +0 -0
  176. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/WHEEL +0 -0
  177. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -43,10 +43,120 @@
43
43
 
44
44
  CUB_NAMESPACE_BEGIN
45
45
 
46
- namespace detail
46
+ namespace detail::segmented_sort
47
47
  {
48
- namespace segmented_sort
48
+
49
+ template <typename PolicyT, typename = void>
50
+ struct SegmentedSortPolicyWrapper : PolicyT
51
+ {
52
+ CUB_RUNTIME_FUNCTION SegmentedSortPolicyWrapper(PolicyT base)
53
+ : PolicyT(base)
54
+ {}
55
+ };
56
+
57
+ template <typename StaticPolicyT>
58
+ struct SegmentedSortPolicyWrapper<StaticPolicyT,
59
+ _CUDA_VSTD::void_t<typename StaticPolicyT::LargeSegmentPolicy,
60
+ typename StaticPolicyT::SmallSegmentPolicy,
61
+ typename StaticPolicyT::MediumSegmentPolicy>> : StaticPolicyT
49
62
  {
63
+ CUB_RUNTIME_FUNCTION SegmentedSortPolicyWrapper(StaticPolicyT base)
64
+ : StaticPolicyT(base)
65
+ {}
66
+
67
+ CUB_RUNTIME_FUNCTION static constexpr auto LargeSegment()
68
+ {
69
+ return cub::detail::MakePolicyWrapper(typename StaticPolicyT::LargeSegmentPolicy());
70
+ }
71
+
72
+ CUB_RUNTIME_FUNCTION static constexpr auto SmallSegment()
73
+ {
74
+ return cub::detail::MakePolicyWrapper(typename StaticPolicyT::SmallSegmentPolicy());
75
+ }
76
+
77
+ CUB_RUNTIME_FUNCTION static constexpr auto MediumSegment()
78
+ {
79
+ return cub::detail::MakePolicyWrapper(typename StaticPolicyT::MediumSegmentPolicy());
80
+ }
81
+
82
+ CUB_RUNTIME_FUNCTION static constexpr int PartitioningThreshold()
83
+ {
84
+ return StaticPolicyT::PARTITIONING_THRESHOLD;
85
+ }
86
+
87
+ CUB_RUNTIME_FUNCTION static constexpr int LargeSegmentRadixBits()
88
+ {
89
+ return StaticPolicyT::LargeSegmentPolicy::RADIX_BITS;
90
+ }
91
+
92
+ CUB_RUNTIME_FUNCTION static constexpr int SegmentsPerSmallBlock()
93
+ {
94
+ return StaticPolicyT::SmallSegmentPolicy::SEGMENTS_PER_BLOCK;
95
+ }
96
+
97
+ CUB_RUNTIME_FUNCTION static constexpr int SegmentsPerMediumBlock()
98
+ {
99
+ return StaticPolicyT::MediumSegmentPolicy::SEGMENTS_PER_BLOCK;
100
+ }
101
+
102
+ CUB_RUNTIME_FUNCTION static constexpr int SmallPolicyItemsPerTile()
103
+ {
104
+ return StaticPolicyT::SmallSegmentPolicy::ITEMS_PER_TILE;
105
+ }
106
+
107
+ CUB_RUNTIME_FUNCTION static constexpr int MediumPolicyItemsPerTile()
108
+ {
109
+ return StaticPolicyT::MediumSegmentPolicy::ITEMS_PER_TILE;
110
+ }
111
+
112
+ CUB_RUNTIME_FUNCTION static constexpr CacheLoadModifier LargeSegmentLoadModifier()
113
+ {
114
+ return StaticPolicyT::LargeSegmentPolicy::LOAD_MODIFIER;
115
+ }
116
+
117
+ CUB_RUNTIME_FUNCTION static constexpr BlockLoadAlgorithm LargeSegmentLoadAlgorithm()
118
+ {
119
+ return StaticPolicyT::LargeSegmentPolicy::LOAD_ALGORITHM;
120
+ }
121
+
122
+ CUB_RUNTIME_FUNCTION static constexpr WarpLoadAlgorithm MediumSegmentLoadAlgorithm()
123
+ {
124
+ return StaticPolicyT::MediumSegmentPolicy::LOAD_ALGORITHM;
125
+ }
126
+
127
+ CUB_RUNTIME_FUNCTION static constexpr WarpLoadAlgorithm SmallSegmentLoadAlgorithm()
128
+ {
129
+ return StaticPolicyT::SmallSegmentPolicy::LOAD_ALGORITHM;
130
+ }
131
+
132
+ CUB_RUNTIME_FUNCTION static constexpr WarpStoreAlgorithm MediumSegmentStoreAlgorithm()
133
+ {
134
+ return StaticPolicyT::MediumSegmentPolicy::STORE_ALGORITHM;
135
+ }
136
+
137
+ CUB_RUNTIME_FUNCTION static constexpr WarpStoreAlgorithm SmallSegmentStoreAlgorithm()
138
+ {
139
+ return StaticPolicyT::SmallSegmentPolicy::STORE_ALGORITHM;
140
+ }
141
+
142
+ #if defined(CUB_ENABLE_POLICY_PTX_JSON)
143
+ _CCCL_DEVICE static constexpr auto EncodedPolicy()
144
+ {
145
+ using namespace ptx_json;
146
+ return object<key<"LargeSegmentPolicy">() = LargeSegment().EncodedPolicy(),
147
+ key<"SmallSegmentPolicy">() = SmallSegment().EncodedPolicy(),
148
+ key<"MediumSegmentPolicy">() = MediumSegment().EncodedPolicy(),
149
+ key<"PartitioningThreshold">() = value<StaticPolicyT::PARTITIONING_THRESHOLD>()>();
150
+ }
151
+ #endif
152
+ };
153
+
154
+ template <typename PolicyT>
155
+ CUB_RUNTIME_FUNCTION SegmentedSortPolicyWrapper<PolicyT> MakeSegmentedSortPolicyWrapper(PolicyT policy)
156
+ {
157
+ return SegmentedSortPolicyWrapper<PolicyT>{policy};
158
+ }
159
+
50
160
  template <typename KeyT, typename ValueT>
51
161
  struct policy_hub
52
162
  {
@@ -71,12 +181,19 @@ struct policy_hub
71
181
 
72
182
  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(7);
73
183
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(7);
74
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
75
- BLOCK_THREADS,
76
- // Small policy
77
- AgentSubWarpMergeSortPolicy<4 /* Threads per segment */, ITEMS_PER_SMALL_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>,
78
- // Medium policy
79
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
184
+
185
+ using SmallSegmentPolicy =
186
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
187
+ 4 /* Threads per segment */,
188
+ ITEMS_PER_SMALL_THREAD,
189
+ WARP_LOAD_DIRECT,
190
+ LOAD_DEFAULT>;
191
+ using MediumSegmentPolicy =
192
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
193
+ 32 /* Threads per segment */,
194
+ ITEMS_PER_MEDIUM_THREAD,
195
+ WARP_LOAD_DIRECT,
196
+ LOAD_DEFAULT>;
80
197
  };
81
198
 
82
199
  struct Policy600 : ChainedPolicy<600, Policy600, Policy500>
@@ -97,12 +214,19 @@ struct policy_hub
97
214
 
98
215
  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
99
216
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(9);
100
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
101
- BLOCK_THREADS,
102
- // Small policy
103
- AgentSubWarpMergeSortPolicy<4 /* Threads per segment */, ITEMS_PER_SMALL_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>,
104
- // Medium policy
105
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
217
+
218
+ using SmallSegmentPolicy =
219
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
220
+ 4 /* Threads per segment */,
221
+ ITEMS_PER_SMALL_THREAD,
222
+ WARP_LOAD_DIRECT,
223
+ LOAD_DEFAULT>;
224
+ using MediumSegmentPolicy =
225
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
226
+ 32 /* Threads per segment */,
227
+ ITEMS_PER_MEDIUM_THREAD,
228
+ WARP_LOAD_DIRECT,
229
+ LOAD_DEFAULT>;
106
230
  };
107
231
 
108
232
  struct Policy610 : ChainedPolicy<610, Policy610, Policy600>
@@ -123,12 +247,19 @@ struct policy_hub
123
247
 
124
248
  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
125
249
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(9);
126
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
127
- BLOCK_THREADS,
128
- // Small policy
129
- AgentSubWarpMergeSortPolicy<4 /* Threads per segment */, ITEMS_PER_SMALL_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>,
130
- // Medium policy
131
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
250
+
251
+ using SmallSegmentPolicy =
252
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
253
+ 4 /* Threads per segment */,
254
+ ITEMS_PER_SMALL_THREAD,
255
+ WARP_LOAD_DIRECT,
256
+ LOAD_DEFAULT>;
257
+ using MediumSegmentPolicy =
258
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
259
+ 32 /* Threads per segment */,
260
+ ITEMS_PER_MEDIUM_THREAD,
261
+ WARP_LOAD_DIRECT,
262
+ LOAD_DEFAULT>;
132
263
  };
133
264
 
134
265
  struct Policy620 : ChainedPolicy<620, Policy620, Policy610>
@@ -149,12 +280,19 @@ struct policy_hub
149
280
 
150
281
  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
151
282
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(9);
152
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
153
- BLOCK_THREADS,
154
- // Small policy
155
- AgentSubWarpMergeSortPolicy<4 /* Threads per segment */, ITEMS_PER_SMALL_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>,
156
- // Medium policy
157
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
283
+
284
+ using SmallSegmentPolicy =
285
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
286
+ 4 /* Threads per segment */,
287
+ ITEMS_PER_SMALL_THREAD,
288
+ WARP_LOAD_DIRECT,
289
+ LOAD_DEFAULT>;
290
+ using MediumSegmentPolicy =
291
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
292
+ 32 /* Threads per segment */,
293
+ ITEMS_PER_MEDIUM_THREAD,
294
+ WARP_LOAD_DIRECT,
295
+ LOAD_DEFAULT>;
158
296
  };
159
297
 
160
298
  struct Policy700 : ChainedPolicy<700, Policy700, Policy620>
@@ -175,15 +313,19 @@ struct policy_hub
175
313
 
176
314
  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(7);
177
315
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(KEYS_ONLY ? 11 : 7);
178
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
179
- BLOCK_THREADS,
180
- // Small policy
181
- AgentSubWarpMergeSortPolicy<KEYS_ONLY ? 4 : 8 /* Threads per segment */,
182
- ITEMS_PER_SMALL_THREAD,
183
- WARP_LOAD_DIRECT,
184
- LOAD_DEFAULT>,
185
- // Medium policy
186
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_DIRECT, LOAD_DEFAULT>>;
316
+
317
+ using SmallSegmentPolicy =
318
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
319
+ KEYS_ONLY ? 4 : 8 /* Threads per segment */,
320
+ ITEMS_PER_SMALL_THREAD,
321
+ WARP_LOAD_DIRECT,
322
+ LOAD_DEFAULT>;
323
+ using MediumSegmentPolicy =
324
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
325
+ 32 /* Threads per segment */,
326
+ ITEMS_PER_MEDIUM_THREAD,
327
+ WARP_LOAD_DIRECT,
328
+ LOAD_DEFAULT>;
187
329
  };
188
330
 
189
331
  struct Policy800 : ChainedPolicy<800, Policy800, Policy700>
@@ -202,15 +344,19 @@ struct policy_hub
202
344
 
203
345
  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
204
346
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(KEYS_ONLY ? 7 : 11);
205
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
206
- BLOCK_THREADS,
207
- // Small policy
208
- AgentSubWarpMergeSortPolicy<KEYS_ONLY ? 4 : 2 /* Threads per segment */,
209
- ITEMS_PER_SMALL_THREAD,
210
- WARP_LOAD_TRANSPOSE,
211
- LOAD_DEFAULT>,
212
- // Medium policy
213
- AgentSubWarpMergeSortPolicy<32 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_TRANSPOSE, LOAD_DEFAULT>>;
347
+
348
+ using SmallSegmentPolicy =
349
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
350
+ KEYS_ONLY ? 4 : 2 /* Threads per segment */,
351
+ ITEMS_PER_SMALL_THREAD,
352
+ WARP_LOAD_TRANSPOSE,
353
+ LOAD_DEFAULT>;
354
+ using MediumSegmentPolicy =
355
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
356
+ 32 /* Threads per segment */,
357
+ ITEMS_PER_MEDIUM_THREAD,
358
+ WARP_LOAD_TRANSPOSE,
359
+ LOAD_DEFAULT>;
214
360
  };
215
361
 
216
362
  struct Policy860 : ChainedPolicy<860, Policy860, Policy800>
@@ -230,20 +376,23 @@ struct policy_hub
230
376
  static constexpr bool LARGE_ITEMS = sizeof(DominantT) > 4;
231
377
  static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(LARGE_ITEMS ? 7 : 9);
232
378
  static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(LARGE_ITEMS ? 9 : 7);
233
- using SmallAndMediumSegmentedSortPolicyT = AgentSmallAndMediumSegmentedSortPolicy<
234
- BLOCK_THREADS,
235
- // Small policy
236
- AgentSubWarpMergeSortPolicy<LARGE_ITEMS ? 8 : 2 /* Threads per segment */,
237
- ITEMS_PER_SMALL_THREAD,
238
- WARP_LOAD_TRANSPOSE,
239
- LOAD_LDG>,
240
- // Medium policy
241
- AgentSubWarpMergeSortPolicy<16 /* Threads per segment */, ITEMS_PER_MEDIUM_THREAD, WARP_LOAD_TRANSPOSE, LOAD_LDG>>;
379
+
380
+ using SmallSegmentPolicy =
381
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
382
+ LARGE_ITEMS ? 8 : 2 /* Threads per segment */,
383
+ ITEMS_PER_SMALL_THREAD,
384
+ WARP_LOAD_TRANSPOSE,
385
+ LOAD_LDG>;
386
+ using MediumSegmentPolicy =
387
+ AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
388
+ 16 /* Threads per segment */,
389
+ ITEMS_PER_MEDIUM_THREAD,
390
+ WARP_LOAD_TRANSPOSE,
391
+ LOAD_LDG>;
242
392
  };
243
393
 
244
394
  using MaxPolicy = Policy860;
245
395
  };
246
- } // namespace segmented_sort
247
- } // namespace detail
396
+ } // namespace detail::segmented_sort
248
397
 
249
398
  CUB_NAMESPACE_END
@@ -47,9 +47,7 @@
47
47
 
48
48
  CUB_NAMESPACE_BEGIN
49
49
 
50
- namespace detail
51
- {
52
- namespace three_way_partition
50
+ namespace detail::three_way_partition
53
51
  {
54
52
 
55
53
  template <typename PolicyT, typename = void>
@@ -437,7 +435,6 @@ struct policy_hub
437
435
 
438
436
  using MaxPolicy = Policy1000;
439
437
  };
440
- } // namespace three_way_partition
441
- } // namespace detail
438
+ } // namespace detail::three_way_partition
442
439
 
443
440
  CUB_NAMESPACE_END
@@ -282,21 +282,45 @@ _CCCL_HOST_DEVICE constexpr int arch_to_min_bytes_in_flight(int sm_arch)
282
282
  return 12 * 1024; // V100 and below
283
283
  }
284
284
 
285
- template <typename T, typename... Ts>
286
- _CCCL_HOST_DEVICE constexpr bool all_equal([[maybe_unused]] T head, Ts... tail)
285
+ template <typename H, typename... Ts>
286
+ _CCCL_HOST_DEVICE constexpr bool all_nonzero_equal(H head, Ts... values)
287
287
  {
288
- return ((head == tail) && ...);
288
+ size_t first = 0;
289
+ for (size_t v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...})
290
+ {
291
+ if (v == 0)
292
+ {
293
+ continue;
294
+ }
295
+ if (first == 0)
296
+ {
297
+ first = v;
298
+ }
299
+ else if (v != first)
300
+ {
301
+ return false;
302
+ }
303
+ }
304
+ return true;
289
305
  }
290
306
 
291
- _CCCL_HOST_DEVICE constexpr bool all_equal()
307
+ _CCCL_HOST_DEVICE constexpr bool all_nonzero_equal()
292
308
  {
293
309
  return true;
294
310
  }
295
311
 
296
- template <typename T, typename... Ts>
297
- _CCCL_HOST_DEVICE constexpr auto first_item(T head, Ts...) -> T
312
+ template <typename H, typename... Ts>
313
+ _CCCL_HOST_DEVICE constexpr auto first_nonzero_value(H head, Ts... values)
298
314
  {
299
- return head;
315
+ for (auto v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...})
316
+ {
317
+ if (v != 0)
318
+ {
319
+ return v;
320
+ }
321
+ }
322
+ // we only reach here when all input are not contiguous and the output has a void value type
323
+ return H{1};
300
324
  }
301
325
 
302
326
  template <typename T>
@@ -336,25 +360,36 @@ struct policy_hub<RequiresStableAddress,
336
360
  (THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn> && ...);
337
361
  static constexpr bool all_input_values_trivially_reloc =
338
362
  (THRUST_NS_QUALIFIER::is_trivially_relocatable_v<it_value_t<RandomAccessIteratorsIn>> && ...);
339
- static constexpr bool can_memcpy_inputs = all_inputs_contiguous && all_input_values_trivially_reloc;
363
+ static constexpr bool can_memcpy_all_inputs = all_inputs_contiguous && all_input_values_trivially_reloc;
364
+ // the vectorized kernel supports mixing contiguous and non-contiguous iterators
365
+ static constexpr bool can_memcpy_contiguous_inputs =
366
+ ((!THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>
367
+ || THRUST_NS_QUALIFIER::is_trivially_relocatable_v<it_value_t<RandomAccessIteratorsIn>>)
368
+ && ...);
340
369
 
341
370
  // for vectorized policy:
342
- static constexpr bool all_input_values_same_size = all_equal(sizeof(it_value_t<RandomAccessIteratorsIn>)...);
343
- static constexpr int load_store_word_size = 8; // TODO(bgruber): make this 16, and 32 on Blackwell+
344
- // if there are no inputs, we take the size of the output value
345
- static constexpr int value_type_size =
346
- first_item(int{sizeof(it_value_t<RandomAccessIteratorsIn>)}..., int{size_of<it_value_t<RandomAccessIteratorOut>>});
371
+ static constexpr bool all_contiguous_input_values_same_size = all_nonzero_equal(
372
+ (sizeof(it_value_t<RandomAccessIteratorsIn>)
373
+ * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...);
374
+ static constexpr int load_store_word_size = 8; // TODO(bgruber): make this 16, and 32 on Blackwell+
375
+ // find the value type size of the first contiguous iterator. if there are no inputs, we take the size of the output
376
+ // value type
377
+ static constexpr int contiguous_value_type_size = first_nonzero_value(
378
+ (int{sizeof(it_value_t<RandomAccessIteratorsIn>)}
379
+ * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...,
380
+ int{size_of<it_value_t<RandomAccessIteratorOut>>});
347
381
  static constexpr bool value_type_divides_load_store_size =
348
- load_store_word_size % value_type_size == 0; // implicitly checks that value_type_size <= load_store_word_size
382
+ load_store_word_size % contiguous_value_type_size == 0; // implicitly checks that value_type_size <=
383
+ // load_store_word_size
349
384
  static constexpr int target_bytes_per_thread =
350
385
  no_input_streams ? 16 /* by experiment on RTX 5090 */ : 32 /* guestimate by gevtushenko for loading */;
351
386
  static constexpr int items_per_thread_vec =
352
- ::cuda::round_up(target_bytes_per_thread, load_store_word_size) / value_type_size;
387
+ ::cuda::round_up(target_bytes_per_thread, load_store_word_size) / contiguous_value_type_size;
353
388
  using default_vectorized_policy_t = vectorized_policy_t<256, items_per_thread_vec, load_store_word_size>;
354
389
 
355
390
  static constexpr bool fallback_to_prefetch =
356
- RequiresStableAddress || !can_memcpy_inputs || !all_input_values_same_size || !value_type_divides_load_store_size
357
- || !DenseOutput;
391
+ RequiresStableAddress || !can_memcpy_contiguous_inputs || !all_contiguous_input_values_same_size
392
+ || !value_type_divides_load_store_size || !DenseOutput;
358
393
 
359
394
  // TODO(bgruber): consider a separate kernel for just filling
360
395
 
@@ -380,7 +415,7 @@ struct policy_hub<RequiresStableAddress,
380
415
  block_threads* async_policy::min_items_per_thread,
381
416
  ldgsts_size_and_align)
382
417
  > int{max_smem_per_block};
383
- static constexpr bool fallback_to_vectorized = exhaust_smem || no_input_streams;
418
+ static constexpr bool fallback_to_vectorized = exhaust_smem || no_input_streams || !can_memcpy_all_inputs;
384
419
 
385
420
  public:
386
421
  static constexpr int min_bif = arch_to_min_bytes_in_flight(800);
@@ -421,7 +456,8 @@ struct policy_hub<RequiresStableAddress,
421
456
  (((int{sizeof(it_value_t<RandomAccessIteratorsIn>)} * AsyncBlockSize) % max_alignment == 0) && ...);
422
457
  static constexpr bool enough_threads_for_peeling = AsyncBlockSize >= alignment; // head and tail bytes
423
458
  static constexpr bool fallback_to_vectorized =
424
- exhaust_smem || !tile_sizes_retain_alignment || !enough_threads_for_peeling || no_input_streams;
459
+ exhaust_smem || !tile_sizes_retain_alignment || !enough_threads_for_peeling || no_input_streams
460
+ || !can_memcpy_all_inputs;
425
461
 
426
462
  public:
427
463
  static constexpr int min_bif = arch_to_min_bytes_in_flight(PtxVersion);
@@ -788,6 +788,16 @@ struct UniqueByKeyPolicyWrapper<StaticPolicyT,
788
788
  {
789
789
  return cub::detail::MakePolicyWrapper(typename StaticPolicyT::UniqueByKeyPolicyT());
790
790
  }
791
+
792
+ #if defined(CUB_ENABLE_POLICY_PTX_JSON)
793
+ _CCCL_DEVICE static constexpr auto EncodedPolicy()
794
+ {
795
+ using namespace ptx_json;
796
+ return object<key<"UniqueByKeyPolicyT">() = UniqueByKey().EncodedPolicy(),
797
+ key<"DelayConstructor">() =
798
+ StaticPolicyT::UniqueByKeyPolicyT::detail::delay_constructor_t::EncodedConstructor()>();
799
+ }
800
+ #endif
791
801
  };
792
802
 
793
803
  template <typename PolicyT>
@@ -47,7 +47,6 @@
47
47
  // for backward compatibility
48
48
  #include <cub/util_temporary_storage.cuh>
49
49
 
50
- #include <cuda/std/__cuda/ensure_current_device.h> // IWYU pragma: export
51
50
  #include <cuda/std/__type_traits/conditional.h>
52
51
  #include <cuda/std/__utility/forward.h>
53
52
  #include <cuda/std/array>
@@ -104,7 +103,34 @@ CUB_RUNTIME_FUNCTION inline int CurrentDevice()
104
103
 
105
104
  //! @brief RAII helper which saves the current device and switches to the specified device on construction and switches
106
105
  //! to the saved device on destruction.
107
- using SwitchDevice = ::cuda::__ensure_current_device;
106
+ class SwitchDevice
107
+ {
108
+ int target_device_;
109
+ int original_device_;
110
+
111
+ public:
112
+ //! @brief Queries the current device and if that is different than @p target_device sets the current device to
113
+ //! @p target_device
114
+ SwitchDevice(const int target_device)
115
+ : target_device_(target_device)
116
+ {
117
+ CubDebug(cudaGetDevice(&original_device_));
118
+ if (original_device_ != target_device_)
119
+ {
120
+ CubDebug(cudaSetDevice(target_device_));
121
+ }
122
+ }
123
+
124
+ //! @brief If the @p original_device was not equal to @p target_device sets the current device back to
125
+ //! @p original_device
126
+ ~SwitchDevice()
127
+ {
128
+ if (original_device_ != target_device_)
129
+ {
130
+ CubDebug(cudaSetDevice(original_device_));
131
+ }
132
+ }
133
+ };
108
134
 
109
135
  # endif // _CCCL_DOXYGEN_INVOKED
110
136
 
@@ -684,16 +710,31 @@ struct KernelConfig
684
710
  return launcher_factory.MaxSmOccupancy(sm_occupancy, kernel_ptr, block_threads);
685
711
  }
686
712
  };
687
-
688
713
  } // namespace detail
689
714
  #endif // !_CCCL_COMPILER(NVRTC)
690
715
 
716
+ namespace detail
717
+ {
718
+ template <typename T>
719
+ struct get_active_policy
720
+ {
721
+ using type = typename T::ActivePolicy;
722
+ };
723
+ } // namespace detail
724
+
691
725
  /// Helper for dispatching into a policy chain
692
726
  template <int PolicyPtxVersion, typename PolicyT, typename PrevPolicyT>
693
727
  struct ChainedPolicy
694
728
  {
729
+ private:
730
+ static constexpr bool have_previous_policy = !::cuda::std::is_same_v<PolicyT, PrevPolicyT>;
731
+
732
+ public:
695
733
  /// The policy for the active compiler pass
696
- using ActivePolicy = ::cuda::std::_If<(CUB_PTX_ARCH < PolicyPtxVersion), typename PrevPolicyT::ActivePolicy, PolicyT>;
734
+ using ActivePolicy =
735
+ typename ::cuda::std::_If<(CUB_PTX_ARCH < PolicyPtxVersion && have_previous_policy),
736
+ detail::get_active_policy<PrevPolicyT>,
737
+ ::cuda::std::type_identity<PolicyT>>::type;
697
738
 
698
739
  #if !_CCCL_COMPILER(NVRTC)
699
740
  /// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version
@@ -708,9 +749,12 @@ struct ChainedPolicy
708
749
  # elif defined(NV_TARGET_SM_INTEGER_LIST)
709
750
  return runtime_to_compiletime<10, NV_TARGET_SM_INTEGER_LIST>(device_ptx_version, op);
710
751
  # else
711
- if (device_ptx_version < PolicyPtxVersion)
752
+ if constexpr (have_previous_policy)
712
753
  {
713
- return PrevPolicyT::Invoke(device_ptx_version, op);
754
+ if (device_ptx_version < PolicyPtxVersion)
755
+ {
756
+ return PrevPolicyT::Invoke(device_ptx_version, op);
757
+ }
714
758
  }
715
759
  return op.template Invoke<PolicyT>();
716
760
  # endif
@@ -738,7 +782,7 @@ private:
738
782
  template <int DevicePtxVersion, typename FunctorT>
739
783
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op)
740
784
  {
741
- if constexpr (DevicePtxVersion < PolicyPtxVersion)
785
+ if constexpr (DevicePtxVersion < PolicyPtxVersion && have_previous_policy)
742
786
  {
743
787
  return PrevPolicyT::template invoke_static<DevicePtxVersion>(op);
744
788
  }
@@ -749,34 +793,6 @@ private:
749
793
  }
750
794
  #endif // !_CCCL_COMPILER(NVRTC)
751
795
  };
752
-
753
- /// Helper for dispatching into a policy chain (end-of-chain specialization)
754
- template <int PolicyPtxVersion, typename PolicyT>
755
- struct ChainedPolicy<PolicyPtxVersion, PolicyT, PolicyT>
756
- {
757
- template <int, typename, typename>
758
- friend struct ChainedPolicy; // befriend primary template, so it can call invoke_static
759
-
760
- /// The policy for the active compiler pass
761
- using ActivePolicy = PolicyT;
762
-
763
- #if !_CCCL_COMPILER(NVRTC)
764
- /// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version
765
- template <typename FunctorT>
766
- CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int /*ptx_version*/, FunctorT& op)
767
- {
768
- return op.template Invoke<PolicyT>();
769
- }
770
-
771
- private:
772
- template <int, typename FunctorT>
773
- CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op)
774
- {
775
- return op.template Invoke<PolicyT>();
776
- }
777
- #endif // !_CCCL_COMPILER(NVRTC)
778
- };
779
-
780
796
  CUB_NAMESPACE_END
781
797
 
782
798
  #if _CCCL_HAS_CUDA_COMPILER() && !_CCCL_COMPILER(NVRTC)
@@ -51,6 +51,7 @@
51
51
  #include <cuda/__functional/maximum.h>
52
52
  #include <cuda/__functional/minimum.h>
53
53
  #include <cuda/__ptx/instructions/get_sreg.h>
54
+ #include <cuda/std/__bit/countr.h>
54
55
  #include <cuda/std/__functional/operations.h>
55
56
  #include <cuda/std/__type_traits/enable_if.h>
56
57
  #include <cuda/std/__type_traits/integral_constant.h>
@@ -701,7 +702,7 @@ struct WarpReduceShfl
701
702
  _CCCL_DEVICE _CCCL_FORCEINLINE T SegmentedReduce(T input, FlagT flag, ReductionOp reduction_op)
702
703
  {
703
704
  // Get the start flags for each thread in the warp.
704
- int warp_flags = __ballot_sync(member_mask, flag);
705
+ unsigned warp_flags = __ballot_sync(member_mask, flag);
705
706
 
706
707
  // Convert to tail-segmented
707
708
  if (HEAD_SEGMENTED)
@@ -722,7 +723,7 @@ struct WarpReduceShfl
722
723
  warp_flags |= 1u << (LOGICAL_WARP_THREADS - 1);
723
724
 
724
725
  // Find the next set flag
725
- int last_lane = __clz(__brev(warp_flags));
726
+ int last_lane = ::cuda::std::countr_zero(warp_flags);
726
727
 
727
728
  T output = input;
728
729
  // Template-iterate reduction steps
@@ -49,6 +49,7 @@
49
49
  #include <cub/util_type.cuh>
50
50
 
51
51
  #include <cuda/__ptx/instructions/get_sreg.h>
52
+ #include <cuda/std/__bit/countr.h>
52
53
  #include <cuda/std/__type_traits/integral_constant.h>
53
54
 
54
55
  CUB_NAMESPACE_BEGIN
@@ -215,7 +216,7 @@ struct WarpReduceSmem
215
216
  SegmentedReduce(T input, FlagT flag, ReductionOp reduction_op, ::cuda::std::true_type /*has_ballot*/)
216
217
  {
217
218
  // Get the start flags for each thread in the warp.
218
- int warp_flags = __ballot_sync(member_mask, flag);
219
+ unsigned warp_flags = __ballot_sync(member_mask, flag);
219
220
 
220
221
  if (!HEAD_SEGMENTED)
221
222
  {
@@ -232,7 +233,7 @@ struct WarpReduceSmem
232
233
  }
233
234
 
234
235
  // Find next flag
235
- int next_flag = __clz(__brev(warp_flags));
236
+ int next_flag = ::cuda::std::countr_zero(warp_flags);
236
237
 
237
238
  // Clip the next segment at the warp boundary if necessary
238
239
  if (LOGICAL_WARP_THREADS != 32)