cuda-cccl 0.3.1__cp311-cp311-manylinux_2_24_aarch64.whl → 0.3.2__cp311-cp311-manylinux_2_24_aarch64.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.

Potentially problematic release: this version of cuda-cccl has been flagged as possibly problematic.

Files changed (185)
  1. cuda/cccl/headers/include/cub/agent/agent_histogram.cuh +354 -572
  2. cuda/cccl/headers/include/cub/block/block_adjacent_difference.cuh +6 -8
  3. cuda/cccl/headers/include/cub/block/block_discontinuity.cuh +24 -14
  4. cuda/cccl/headers/include/cub/block/block_exchange.cuh +5 -0
  5. cuda/cccl/headers/include/cub/block/block_histogram.cuh +4 -0
  6. cuda/cccl/headers/include/cub/block/block_load.cuh +4 -0
  7. cuda/cccl/headers/include/cub/block/block_radix_rank.cuh +1 -0
  8. cuda/cccl/headers/include/cub/block/block_reduce.cuh +1 -0
  9. cuda/cccl/headers/include/cub/block/block_scan.cuh +12 -2
  10. cuda/cccl/headers/include/cub/block/block_store.cuh +3 -2
  11. cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +34 -30
  12. cuda/cccl/headers/include/cub/detail/ptx-json-parser.h +1 -1
  13. cuda/cccl/headers/include/cub/device/device_for.cuh +118 -40
  14. cuda/cccl/headers/include/cub/device/device_reduce.cuh +6 -7
  15. cuda/cccl/headers/include/cub/device/device_segmented_reduce.cuh +12 -13
  16. cuda/cccl/headers/include/cub/device/device_transform.cuh +122 -91
  17. cuda/cccl/headers/include/cub/device/dispatch/dispatch_merge.cuh +2 -3
  18. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce.cuh +4 -3
  19. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_deterministic.cuh +1 -1
  20. cuda/cccl/headers/include/cub/device/dispatch/dispatch_streaming_reduce.cuh +4 -5
  21. cuda/cccl/headers/include/cub/device/dispatch/dispatch_streaming_reduce_by_key.cuh +0 -1
  22. cuda/cccl/headers/include/cub/device/dispatch/dispatch_topk.cuh +3 -5
  23. cuda/cccl/headers/include/cub/device/dispatch/dispatch_transform.cuh +13 -5
  24. cuda/cccl/headers/include/cub/device/dispatch/kernels/for_each.cuh +72 -37
  25. cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +22 -27
  26. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +61 -70
  27. cuda/cccl/headers/include/cub/thread/thread_reduce.cuh +24 -17
  28. cuda/cccl/headers/include/cub/warp/warp_load.cuh +6 -6
  29. cuda/cccl/headers/include/cub/warp/warp_reduce.cuh +7 -2
  30. cuda/cccl/headers/include/cub/warp/warp_scan.cuh +7 -3
  31. cuda/cccl/headers/include/cub/warp/warp_store.cuh +1 -0
  32. cuda/cccl/headers/include/cuda/__barrier/barrier_block_scope.h +19 -0
  33. cuda/cccl/headers/include/cuda/__cccl_config +1 -0
  34. cuda/cccl/headers/include/cuda/__cmath/fast_modulo_division.h +3 -74
  35. cuda/cccl/headers/include/cuda/__cmath/mul_hi.h +146 -0
  36. cuda/cccl/headers/include/cuda/__complex/get_real_imag.h +0 -4
  37. cuda/cccl/headers/include/cuda/__device/arch_id.h +176 -0
  38. cuda/cccl/headers/include/cuda/__device/arch_traits.h +239 -317
  39. cuda/cccl/headers/include/cuda/__device/attributes.h +4 -3
  40. cuda/cccl/headers/include/cuda/__device/compute_capability.h +171 -0
  41. cuda/cccl/headers/include/cuda/__device/device_ref.h +0 -10
  42. cuda/cccl/headers/include/cuda/__device/physical_device.h +1 -26
  43. cuda/cccl/headers/include/cuda/__event/event.h +26 -26
  44. cuda/cccl/headers/include/cuda/__event/event_ref.h +5 -5
  45. cuda/cccl/headers/include/cuda/__event/timed_event.h +9 -7
  46. cuda/cccl/headers/include/cuda/__fwd/devices.h +4 -4
  47. cuda/cccl/headers/include/cuda/__iterator/constant_iterator.h +46 -31
  48. cuda/cccl/headers/include/cuda/__iterator/strided_iterator.h +79 -47
  49. cuda/cccl/headers/include/cuda/__iterator/tabulate_output_iterator.h +59 -36
  50. cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +79 -49
  51. cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +74 -48
  52. cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +80 -55
  53. cuda/cccl/headers/include/cuda/__iterator/zip_common.h +2 -12
  54. cuda/cccl/headers/include/cuda/__iterator/zip_iterator.h +15 -19
  55. cuda/cccl/headers/include/cuda/__iterator/zip_transform_iterator.h +59 -60
  56. cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +127 -60
  57. cuda/cccl/headers/include/cuda/__mdspan/host_device_mdspan.h +178 -3
  58. cuda/cccl/headers/include/cuda/__mdspan/restrict_accessor.h +38 -8
  59. cuda/cccl/headers/include/cuda/__mdspan/restrict_mdspan.h +67 -1
  60. cuda/cccl/headers/include/cuda/__memory/ptr_in_range.h +93 -0
  61. cuda/cccl/headers/include/cuda/__memory_resource/get_memory_resource.h +4 -4
  62. cuda/cccl/headers/include/cuda/__memory_resource/properties.h +44 -0
  63. cuda/cccl/headers/include/cuda/__memory_resource/resource.h +1 -1
  64. cuda/cccl/headers/include/cuda/__memory_resource/resource_ref.h +4 -6
  65. cuda/cccl/headers/include/cuda/__nvtx/nvtx3.h +2 -1
  66. cuda/cccl/headers/include/cuda/__runtime/ensure_current_context.h +5 -4
  67. cuda/cccl/headers/include/cuda/__stream/stream.h +8 -8
  68. cuda/cccl/headers/include/cuda/__stream/stream_ref.h +17 -16
  69. cuda/cccl/headers/include/cuda/__utility/in_range.h +65 -0
  70. cuda/cccl/headers/include/cuda/cmath +1 -0
  71. cuda/cccl/headers/include/cuda/devices +3 -0
  72. cuda/cccl/headers/include/cuda/memory +1 -0
  73. cuda/cccl/headers/include/cuda/std/__algorithm/equal_range.h +2 -2
  74. cuda/cccl/headers/include/cuda/std/__algorithm/find.h +1 -1
  75. cuda/cccl/headers/include/cuda/std/__algorithm/includes.h +2 -4
  76. cuda/cccl/headers/include/cuda/std/__algorithm/lower_bound.h +1 -1
  77. cuda/cccl/headers/include/cuda/std/__algorithm/make_projected.h +7 -15
  78. cuda/cccl/headers/include/cuda/std/__algorithm/min_element.h +1 -1
  79. cuda/cccl/headers/include/cuda/std/__algorithm/minmax_element.h +1 -2
  80. cuda/cccl/headers/include/cuda/std/__algorithm/partial_sort_copy.h +2 -2
  81. cuda/cccl/headers/include/cuda/std/__algorithm/upper_bound.h +1 -1
  82. cuda/cccl/headers/include/cuda/std/__cccl/algorithm_wrapper.h +36 -0
  83. cuda/cccl/headers/include/cuda/std/__cccl/builtin.h +46 -49
  84. cuda/cccl/headers/include/cuda/std/__cccl/execution_space.h +6 -0
  85. cuda/cccl/headers/include/cuda/std/__cccl/host_std_lib.h +52 -0
  86. cuda/cccl/headers/include/cuda/std/__cccl/memory_wrapper.h +36 -0
  87. cuda/cccl/headers/include/cuda/std/__cccl/numeric_wrapper.h +36 -0
  88. cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +3 -2
  89. cuda/cccl/headers/include/cuda/std/__complex/complex.h +3 -2
  90. cuda/cccl/headers/include/cuda/std/__complex/literals.h +14 -34
  91. cuda/cccl/headers/include/cuda/std/__complex/nvbf16.h +2 -1
  92. cuda/cccl/headers/include/cuda/std/__complex/nvfp16.h +4 -3
  93. cuda/cccl/headers/include/cuda/std/__concepts/invocable.h +2 -2
  94. cuda/cccl/headers/include/cuda/std/__cstdlib/malloc.h +3 -2
  95. cuda/cccl/headers/include/cuda/std/__functional/bind.h +10 -13
  96. cuda/cccl/headers/include/cuda/std/__functional/function.h +5 -8
  97. cuda/cccl/headers/include/cuda/std/__functional/invoke.h +71 -335
  98. cuda/cccl/headers/include/cuda/std/__functional/mem_fn.h +1 -2
  99. cuda/cccl/headers/include/cuda/std/__functional/reference_wrapper.h +3 -3
  100. cuda/cccl/headers/include/cuda/std/__functional/weak_result_type.h +0 -6
  101. cuda/cccl/headers/include/cuda/std/__fwd/allocator.h +13 -0
  102. cuda/cccl/headers/include/cuda/std/__fwd/char_traits.h +13 -0
  103. cuda/cccl/headers/include/cuda/std/__fwd/complex.h +13 -4
  104. cuda/cccl/headers/include/cuda/std/__fwd/mdspan.h +23 -0
  105. cuda/cccl/headers/include/cuda/std/__fwd/pair.h +13 -0
  106. cuda/cccl/headers/include/cuda/std/__fwd/string.h +22 -0
  107. cuda/cccl/headers/include/cuda/std/__fwd/string_view.h +14 -0
  108. cuda/cccl/headers/include/cuda/std/__internal/features.h +0 -5
  109. cuda/cccl/headers/include/cuda/std/__internal/namespaces.h +21 -0
  110. cuda/cccl/headers/include/cuda/std/__iterator/iterator_traits.h +5 -5
  111. cuda/cccl/headers/include/cuda/std/__mdspan/extents.h +7 -1
  112. cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +53 -39
  113. cuda/cccl/headers/include/cuda/std/__memory/allocator.h +3 -3
  114. cuda/cccl/headers/include/cuda/std/__memory/construct_at.h +1 -3
  115. cuda/cccl/headers/include/cuda/std/__optional/optional_base.h +1 -0
  116. cuda/cccl/headers/include/cuda/std/__ranges/compressed_movable_box.h +892 -0
  117. cuda/cccl/headers/include/cuda/std/__ranges/movable_box.h +2 -2
  118. cuda/cccl/headers/include/cuda/std/__type_traits/is_primary_template.h +7 -5
  119. cuda/cccl/headers/include/cuda/std/__type_traits/result_of.h +1 -1
  120. cuda/cccl/headers/include/cuda/std/__utility/pair.h +0 -5
  121. cuda/cccl/headers/include/cuda/std/bitset +1 -1
  122. cuda/cccl/headers/include/cuda/std/detail/libcxx/include/__config +15 -12
  123. cuda/cccl/headers/include/cuda/std/detail/libcxx/include/variant +11 -9
  124. cuda/cccl/headers/include/cuda/std/inplace_vector +4 -4
  125. cuda/cccl/headers/include/cuda/std/numbers +5 -0
  126. cuda/cccl/headers/include/cuda/std/string_view +146 -11
  127. cuda/cccl/headers/include/cuda/stream_ref +5 -0
  128. cuda/cccl/headers/include/cuda/utility +1 -0
  129. cuda/cccl/headers/include/nv/target +7 -2
  130. cuda/cccl/headers/include/thrust/allocate_unique.h +1 -1
  131. cuda/cccl/headers/include/thrust/detail/allocator/allocator_traits.h +309 -33
  132. cuda/cccl/headers/include/thrust/detail/allocator/copy_construct_range.h +151 -4
  133. cuda/cccl/headers/include/thrust/detail/allocator/destroy_range.h +60 -3
  134. cuda/cccl/headers/include/thrust/detail/allocator/fill_construct_range.h +45 -3
  135. cuda/cccl/headers/include/thrust/detail/allocator/malloc_allocator.h +31 -6
  136. cuda/cccl/headers/include/thrust/detail/allocator/tagged_allocator.h +29 -16
  137. cuda/cccl/headers/include/thrust/detail/allocator/temporary_allocator.h +41 -4
  138. cuda/cccl/headers/include/thrust/detail/allocator/value_initialize_range.h +42 -4
  139. cuda/cccl/headers/include/thrust/detail/complex/ccosh.h +3 -3
  140. cuda/cccl/headers/include/thrust/detail/internal_functional.h +1 -1
  141. cuda/cccl/headers/include/thrust/detail/memory_algorithms.h +1 -1
  142. cuda/cccl/headers/include/thrust/detail/temporary_array.h +1 -1
  143. cuda/cccl/headers/include/thrust/detail/type_traits.h +1 -1
  144. cuda/cccl/headers/include/thrust/device_delete.h +18 -3
  145. cuda/cccl/headers/include/thrust/device_free.h +16 -3
  146. cuda/cccl/headers/include/thrust/device_new.h +29 -8
  147. cuda/cccl/headers/include/thrust/host_vector.h +1 -1
  148. cuda/cccl/headers/include/thrust/iterator/tabulate_output_iterator.h +5 -2
  149. cuda/cccl/headers/include/thrust/mr/disjoint_pool.h +1 -1
  150. cuda/cccl/headers/include/thrust/mr/pool.h +1 -1
  151. cuda/cccl/headers/include/thrust/system/cuda/detail/find.h +13 -115
  152. cuda/cccl/headers/include/thrust/system/cuda/detail/mismatch.h +8 -2
  153. cuda/cccl/headers/include/thrust/type_traits/is_contiguous_iterator.h +7 -7
  154. cuda/compute/__init__.py +2 -0
  155. cuda/compute/_bindings.pyi +43 -1
  156. cuda/compute/_bindings_impl.pyx +156 -7
  157. cuda/compute/algorithms/_scan.py +108 -36
  158. cuda/compute/algorithms/_transform.py +32 -11
  159. cuda/compute/cu12/_bindings_impl.cpython-311-aarch64-linux-gnu.so +0 -0
  160. cuda/compute/cu12/cccl/libcccl.c.parallel.so +0 -0
  161. cuda/compute/cu13/_bindings_impl.cpython-311-aarch64-linux-gnu.so +0 -0
  162. cuda/compute/cu13/cccl/libcccl.c.parallel.so +0 -0
  163. cuda/compute/iterators/__init__.py +2 -0
  164. cuda/compute/iterators/_factories.py +28 -0
  165. cuda/compute/iterators/_iterators.py +206 -1
  166. cuda/compute/numba_utils.py +2 -2
  167. cuda/compute/typing.py +2 -0
  168. {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/METADATA +1 -1
  169. {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/RECORD +171 -175
  170. cuda/cccl/headers/include/thrust/detail/algorithm_wrapper.h +0 -37
  171. cuda/cccl/headers/include/thrust/detail/allocator/allocator_traits.inl +0 -371
  172. cuda/cccl/headers/include/thrust/detail/allocator/copy_construct_range.inl +0 -242
  173. cuda/cccl/headers/include/thrust/detail/allocator/destroy_range.inl +0 -137
  174. cuda/cccl/headers/include/thrust/detail/allocator/fill_construct_range.inl +0 -99
  175. cuda/cccl/headers/include/thrust/detail/allocator/malloc_allocator.inl +0 -68
  176. cuda/cccl/headers/include/thrust/detail/allocator/tagged_allocator.inl +0 -86
  177. cuda/cccl/headers/include/thrust/detail/allocator/temporary_allocator.inl +0 -79
  178. cuda/cccl/headers/include/thrust/detail/allocator/value_initialize_range.inl +0 -98
  179. cuda/cccl/headers/include/thrust/detail/device_delete.inl +0 -52
  180. cuda/cccl/headers/include/thrust/detail/device_free.inl +0 -47
  181. cuda/cccl/headers/include/thrust/detail/device_new.inl +0 -61
  182. cuda/cccl/headers/include/thrust/detail/memory_wrapper.h +0 -40
  183. cuda/cccl/headers/include/thrust/detail/numeric_wrapper.h +0 -37
  184. {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/WHEEL +0 -0
  185. {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -14,19 +14,17 @@
  #endif // no system header

  #include <cub/agent/agent_for.cuh>
- #include <cub/detail/fast_modulo_division.cuh> // fast_div_mod
  #include <cub/detail/mdspan_utils.cuh> // is_sub_size_static
  #include <cub/detail/type_traits.cuh> // implicit_prom_t

- #include <cuda/std/__fwd/span.h>
  #include <cuda/std/__type_traits/enable_if.h>
  #include <cuda/std/__type_traits/integral_constant.h>
  #include <cuda/std/__type_traits/is_convertible.h>
  #include <cuda/std/__type_traits/is_reference.h>
  #include <cuda/std/__type_traits/is_trivially_constructible.h>
- #include <cuda/std/__type_traits/is_trivially_copy_constructible.h>
+ #include <cuda/std/__type_traits/is_trivially_copy_assignable.h>
  #include <cuda/std/__type_traits/is_trivially_destructible.h>
- #include <cuda/std/__type_traits/is_trivially_move_constructible.h>
+ #include <cuda/std/__type_traits/is_trivially_move_assignable.h>
  #include <cuda/std/__type_traits/make_unsigned.h>
  #include <cuda/std/__utility/integer_sequence.h>
  #include <cuda/std/cstddef> // size_t
@@ -140,16 +138,21 @@ __launch_bounds__(ChainedPolicyT::ActivePolicy::for_policy_t::block_threads) //
  * ForEachInExtents
  **********************************************************************************************************************/

- // Returns the extent at the given rank. If the extents is static, returns it, otherwise returns the precomputed value
- template <int Rank, typename ExtentType, typename FastDivModType>
- _CCCL_DEVICE _CCCL_FORCEINLINE auto extent_at(ExtentType extents, FastDivModType dynamic_extent)
+ // Retrieves the extent (dimension size) at a specific position in a multi-dimensional array
+ //
+ // This function efficiently returns the extent at the given position, optimizing for static extents by returning
+ // compile-time constants when possible. For dynamic extents, it returns the precomputed value to avoid runtime
+ // computation overhead.
+ template <int Position, typename ExtentType, typename FastDivModType>
+ _CCCL_DEVICE_API auto extent_at(ExtentType extents, FastDivModType dynamic_extent)
  {
- if constexpr (ExtentType::static_extent(Rank) != ::cuda::std::dynamic_extent)
+ if constexpr (ExtentType::static_extent(Position) != ::cuda::std::dynamic_extent)
  {
  using extent_index_type = typename ExtentType::index_type;
  using index_type = implicit_prom_t<extent_index_type>;
  using unsigned_index_type = ::cuda::std::make_unsigned_t<index_type>;
- return static_cast<unsigned_index_type>(extents.static_extent(Rank));
+ constexpr auto extent = extents.static_extent(Position);
+ return static_cast<unsigned_index_type>(extent);
  }
  else
  {
@@ -157,17 +160,22 @@ _CCCL_DEVICE _CCCL_FORCEINLINE auto extent_at(ExtentType extents, FastDivModType
  }
  }

- // Returns the product of all extents from position Rank. If the result is static, returns it, otherwise returns the
- // precomputed value
- template <int Rank, typename ExtentType, typename FastDivModType>
- _CCCL_DEVICE _CCCL_FORCEINLINE auto get_extents_sub_size(ExtentType extents, FastDivModType extent_sub_size)
+ // Computes the product of extents in a specified range for multi-dimensional indexing.
+ // This function calculates the product of all extent dimensions from Start (inclusive) to End (exclusive).
+ //
+ // Performance characteristics:
+ // - Static extents in range: Product computed at compile-time, zero runtime cost
+ // - Dynamic extents present: Returns precomputed value, avoiding runtime multiplication
+ template <int Start, int End, typename ExtentType, typename FastDivModType>
+ _CCCL_DEVICE_API auto get_extents_sub_size(ExtentType extents, FastDivModType extent_sub_size)
  {
- if constexpr (cub::detail::is_sub_size_static<Rank + 1, ExtentType>())
+ if constexpr (cub::detail::are_extents_in_range_static<ExtentType>(Start, End))
  {
  using extent_index_type = typename ExtentType::index_type;
  using index_type = implicit_prom_t<extent_index_type>;
  using unsigned_index_type = ::cuda::std::make_unsigned_t<index_type>;
- return static_cast<unsigned_index_type>(cub::detail::sub_size<Rank + 1>(extents));
+ auto sub_size = cub::detail::size_range(extents, Start, End);
+ return static_cast<unsigned_index_type>(sub_size);
  }
  else
  {
@@ -175,49 +183,76 @@ _CCCL_DEVICE _CCCL_FORCEINLINE auto get_extents_sub_size(ExtentType extents, Fas
  }
  }

- template <int Rank, typename IndexType, typename ExtentType, typename FastDivModType>
- _CCCL_DEVICE _CCCL_FORCEINLINE auto
+ // Converts a linear index to a multi-dimensional coordinate at a specific position.
+ //
+ // This function performs the mathematical conversion from a linear (flat) index to the coordinate value at a specific
+ // position in a multi-dimensional array. It supports both row-major (layout_right) and column-major (layout_left)
+ // memory layouts, which affects the indexing calculation order.
+ //
+ // The mathematical formulation depends on the layout:
+ // - Right layout (row-major): index_i = (index / product(extent[j] for j in [i+1, rank-1])) % extent[i]
+ // - Left layout (column-major): index_i = (index / product(extent[j] for j in [0, i])) % extent[i]
+ //
+ // This function leverages precomputed fast division and modulo operations to minimize runtime arithmetic overhead.
+ template <bool IsLayoutRight, int Position, typename IndexType, typename ExtentType, typename FastDivModType>
+ _CCCL_DEVICE_API auto
  coordinate_at(IndexType index, ExtentType extents, FastDivModType extent_sub_size, FastDivModType dynamic_extent)
  {
  using cub::detail::for_each::extent_at;
  using cub::detail::for_each::get_extents_sub_size;
  using extent_index_type = typename ExtentType::index_type;
- return static_cast<extent_index_type>(
- (index / get_extents_sub_size<Rank>(extents, extent_sub_size)) % extent_at<Rank>(extents, dynamic_extent));
+ constexpr auto start = IsLayoutRight ? Position + 1 : 0;
+ constexpr auto end = IsLayoutRight ? ExtentType::rank() : Position;
+ return static_cast<extent_index_type>((index / get_extents_sub_size<start, end>(extents, extent_sub_size))
+ % extent_at<Position>(extents, dynamic_extent));
  }

- template <typename OpT, typename ExtentsT, typename FastDivModArrayT>
+ // Function object wrapper for applying operations with multi-dimensional coordinate conversion.
+ //
+ // The wrapped operation will be called with signature: `op(linear_index, coord_0, coord_1, ..., coord_n)`
+ // where the number of coordinate parameters matches the rank of the extents object.
+ //
+ // This wrapper is used internally by DeviceFor::ForEachInLayout/ForEachInExtents
+ template <typename OpT, typename ExtentsType, bool IsLayoutRight, typename FastDivModArrayT>
  struct op_wrapper_extents_t
  {
- OpT op;
- ExtentsT extents;
- FastDivModArrayT sub_sizes_div_array;
- FastDivModArrayT extents_mod_array;
-
- template <typename OffsetT, size_t... Ranks>
- _CCCL_DEVICE _CCCL_FORCEINLINE void impl(OffsetT i, ::cuda::std::index_sequence<Ranks...>)
+ OpT op; ///< The user-provided operation to be called with coordinates
+ ExtentsType extents; ///< The multi-dimensional extents defining array dimensions
+ FastDivModArrayT sub_sizes_div_array; ///< Precomputed fast division values for extent sub-products
+ FastDivModArrayT extents_mod_array; ///< Precomputed fast modulo values for individual extents
+
+ // Internal implementation that converts linear index to coordinates and calls the user operation
+ template <typename IndexType, size_t... Positions>
+ _CCCL_DEVICE_API void impl(IndexType i, ::cuda::std::index_sequence<Positions...>)
  {
  using cub::detail::for_each::coordinate_at;
- op(i, coordinate_at<Ranks>(i, extents, sub_sizes_div_array[Ranks], extents_mod_array[Ranks])...);
+ op(i,
+ coordinate_at<IsLayoutRight, Positions>(
+ i, extents, sub_sizes_div_array[Positions], extents_mod_array[Positions])...);
  }

- template <typename OffsetT, size_t... Ranks>
- _CCCL_DEVICE _CCCL_FORCEINLINE void impl(OffsetT i, ::cuda::std::index_sequence<Ranks...>) const
+ // Internal implementation that converts linear index to coordinates and calls the user operation
+ template <typename IndexType, size_t... Positions>
+ _CCCL_DEVICE_API void impl(IndexType i, ::cuda::std::index_sequence<Positions...>) const
  {
  using cub::detail::for_each::coordinate_at;
- op(i, coordinate_at<Ranks>(i, extents, sub_sizes_div_array[Ranks], extents_mod_array[Ranks])...);
+ op(i,
+ coordinate_at<IsLayoutRight, Positions>(
+ i, extents, sub_sizes_div_array[Positions], extents_mod_array[Positions])...);
  }

- template <typename OffsetT>
- _CCCL_DEVICE _CCCL_FORCEINLINE void operator()(OffsetT i)
+ // Function call operator that processes a linear index by converting it to multi-dimensional coordinates
+ template <typename IndexType>
+ _CCCL_DEVICE_API void operator()(IndexType i)
  {
- impl(i, ::cuda::std::make_index_sequence<ExtentsT::rank()>{});
+ impl(i, ::cuda::std::make_index_sequence<ExtentsType::rank()>{});
  }

- template <typename OffsetT>
- _CCCL_DEVICE _CCCL_FORCEINLINE void operator()(OffsetT i) const
+ // Function call operator that processes a linear index by converting it to multi-dimensional coordinates
+ template <typename IndexType>
+ _CCCL_DEVICE_API void operator()(IndexType i) const
  {
- impl(i, ::cuda::std::make_index_sequence<ExtentsT::rank()>{});
+ impl(i, ::cuda::std::make_index_sequence<ExtentsType::rank()>{});
  }
  };

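The row-major rule quoted in the coordinate_at comment above can be checked with a small host-side sketch; the helper name decompose_2d and the 4 x 5 extents are illustrative only and are not part of CUB:

  #include <array>

  // layout_right (row-major): coord_i = (index / product(extent[j] for j > i)) % extent[i]
  constexpr std::array<int, 2> decompose_2d(int index, int extent0, int extent1)
  {
    const int coord0 = (index / extent1) % extent0; // stride of dimension 0 is extent1
    const int coord1 = index % extent1;             // stride of dimension 1 is 1
    return {coord0, coord1};
  }

  // Linear index 13 in a 4 x 5 extent is (2, 3), since 13 == 2 * 5 + 3.
  static_assert(decompose_2d(13, 4, 5)[0] == 2);
  static_assert(decompose_2d(13, 4, 5)[1] == 3);

The wrapper above performs the same per-rank decomposition on the device, but replaces the divisions and modulos with the precomputed fast division/modulo values.
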
@@ -217,6 +217,7 @@ _CCCL_DEVICE void transform_kernel_vectorized(
  {
  constexpr int block_dim = VectorizedPolicy::block_threads;
  constexpr int items_per_thread = VectorizedPolicy::items_per_thread_vectorized;
+ constexpr int vec_size = VectorizedPolicy::vec_size;
  _CCCL_ASSERT(!can_vectorize || (items_per_thread == num_elem_per_thread_prefetch), "");
  constexpr int tile_size = block_dim * items_per_thread;
  const Offset offset = static_cast<Offset>(blockIdx.x) * tile_size;
@@ -241,23 +242,13 @@ _CCCL_DEVICE void transform_kernel_vectorized(
  out += offset;
  }

- constexpr int load_store_size = VectorizedPolicy::load_store_word_size;
- using load_store_t = decltype(load_store_type<load_store_size>());
- using output_t = it_value_t<RandomAccessIteratorOut>;
+ using output_t = it_value_t<RandomAccessIteratorOut>;
  using result_t = ::cuda::std::decay_t<::cuda::std::invoke_result_t<F, const it_value_t<RandomAccessIteratorsIn>&...>>;
- // picks output type size if there are no inputs
- constexpr int element_size = int{first_nonzero_value(
- (sizeof(it_value_t<RandomAccessIteratorsIn>)
- * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...,
- size_of<output_t>)};
- constexpr int load_store_count = (items_per_thread * element_size) / load_store_size;
+ constexpr int load_store_count = items_per_thread / vec_size;
+ static_assert(items_per_thread % vec_size == 0, "The items per thread must be a multiple of the vector size");

- static_assert((items_per_thread * element_size) % load_store_size == 0);
- static_assert(load_store_size % element_size == 0);
-
- constexpr bool can_vectorize_store =
- THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorOut>
- && THRUST_NS_QUALIFIER::is_trivially_relocatable_v<output_t> && size_of<output_t> == element_size;
+ constexpr bool can_vectorize_store = THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorOut>
+ && THRUST_NS_QUALIFIER::is_trivially_relocatable_v<output_t>;

  // if we can vectorize, we convert f's return type to the output type right away, so we can reinterpret later
  using THRUST_NS_QUALIFIER::cuda_cub::core::detail::uninitialized_array;
@@ -266,10 +257,15 @@ _CCCL_DEVICE void transform_kernel_vectorized(
  auto provide_array = [&](auto... inputs) {
  // load inputs
  [[maybe_unused]] auto load_tile = [](auto in, auto& input) {
+ using it_t = decltype(in);
+ using value_t = it_value_t<it_t>;
  if constexpr (THRUST_NS_QUALIFIER::is_contiguous_iterator_v<decltype(in)>)
  {
- auto in_vec = reinterpret_cast<const load_store_t*>(in) + threadIdx.x;
- auto input_vec = reinterpret_cast<load_store_t*>(input.data());
+ // TODO(bgruber): we could add a max_load_store_size to the policy to avoid huge load types and huge alignment
+ // requirements
+ using load_t = decltype(load_store_type<sizeof(value_t) * vec_size>());
+ auto in_vec = reinterpret_cast<const load_t*>(in) + threadIdx.x;
+ auto input_vec = reinterpret_cast<load_t*>(input.data());
  _CCCL_PRAGMA_UNROLL_FULL()
  for (int i = 0; i < load_store_count; ++i)
  {
@@ -278,15 +274,14 @@ _CCCL_DEVICE void transform_kernel_vectorized(
  }
  else
  {
- constexpr int elems = load_store_size / element_size;
- in += threadIdx.x * elems;
+ in += threadIdx.x * vec_size;
  _CCCL_PRAGMA_UNROLL_FULL()
  for (int i = 0; i < load_store_count; ++i)
  {
  _CCCL_PRAGMA_UNROLL_FULL()
- for (int j = 0; j < elems; ++j)
+ for (int j = 0; j < vec_size; ++j)
  {
- input[i * elems + j] = in[i * elems * VectorizedPolicy::block_threads + j];
+ input[i * vec_size + j] = in[i * vec_size * VectorizedPolicy::block_threads + j];
  }
  }
  }
@@ -310,8 +305,9 @@ _CCCL_DEVICE void transform_kernel_vectorized(
  if constexpr (can_vectorize_store)
  {
  // vector path
- auto output_vec = reinterpret_cast<const load_store_t*>(output.data());
- auto out_vec = reinterpret_cast<load_store_t*>(out) + threadIdx.x;
+ using store_t = decltype(load_store_type<sizeof(output_t) * vec_size>());
+ auto output_vec = reinterpret_cast<const store_t*>(output.data());
+ auto out_vec = reinterpret_cast<store_t*>(out) + threadIdx.x;
  _CCCL_PRAGMA_UNROLL_FULL()
  for (int i = 0; i < load_store_count; ++i)
  {
@@ -321,15 +317,14 @@ _CCCL_DEVICE void transform_kernel_vectorized(
  else
  {
  // serial path
- constexpr int elems = load_store_size / element_size;
- out += threadIdx.x * elems;
+ out += threadIdx.x * vec_size;
  _CCCL_PRAGMA_UNROLL_FULL()
  for (int i = 0; i < load_store_count; ++i)
  {
  _CCCL_PRAGMA_UNROLL_FULL()
- for (int j = 0; j < elems; ++j)
+ for (int j = 0; j < vec_size; ++j)
  {
- out[i * elems * VectorizedPolicy::block_threads + j] = output[i * elems + j];
+ out[i * vec_size * VectorizedPolicy::block_threads + j] = output[i * vec_size + j];
  }
  }
  }
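To keep the kernel change above readable in isolation, here is a standalone sketch of the underlying idea: vec_size contiguous elements are moved through one wider, suitably aligned word. The word16 type and copy_vec4 helper are illustrative stand-ins only (the kernel instead reinterprets aligned pointers through CUB's internal load_store_type helper):

  #include <cstring>

  struct alignas(16) word16
  {
    unsigned char bytes[16];
  };

  // Move vec_size == 4 floats (16 bytes) with one wide load and one wide store.
  inline void copy_vec4(const float* in, float* out)
  {
    word16 w;
    std::memcpy(&w, in, sizeof(w));  // one 16-byte load
    std::memcpy(out, &w, sizeof(w)); // one 16-byte store
  }

With the new policy, the width of that word is sizeof(value_t) * vec_size per iterator rather than a single fixed load_store_word_size.
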
@@ -113,11 +113,11 @@ CUB_DETAIL_POLICY_WRAPPER_DEFINE(
  (max_items_per_thread, MaxItemsPerThread, int),
  (not_a_vectorized_policy, NotAVectorizedPolicy, int) ) // TODO: remove with C++20

- template <int BlockThreads, int ItemsPerThread, int LoadStoreWordSize>
- struct vectorized_policy_t : prefetch_policy_t<BlockThreads>
+ template <typename Tuning>
+ struct vectorized_policy_t : prefetch_policy_t<Tuning::block_threads>
  {
- static constexpr int items_per_thread_vectorized = ItemsPerThread;
- static constexpr int load_store_word_size = LoadStoreWordSize;
+ static constexpr int items_per_thread_vectorized = Tuning::items_per_thread;
+ static constexpr int vec_size = Tuning::vec_size;

  using not_a_vectorized_policy = void; // TODO: remove with C++20, shadows the variable in prefetch_policy_t
  };
@@ -130,7 +130,7 @@ CUB_DETAIL_POLICY_WRAPPER_DEFINE(
  (min_items_per_thread, MinItemsPerThread, int),
  (max_items_per_thread, MaxItemsPerThread, int),
  (items_per_thread_vectorized, ItemsPerThreadVectorized, int),
- (load_store_word_size, LoadStoreWordSize, int) )
+ (vec_size, VecSize, int) )

  template <int BlockThreads, int BulkCopyAlignment>
  struct async_copy_policy_t
@@ -282,47 +282,6 @@ _CCCL_HOST_DEVICE constexpr int arch_to_min_bytes_in_flight(int sm_arch)
  return 12 * 1024; // V100 and below
  }

- template <typename H, typename... Ts>
- _CCCL_HOST_DEVICE constexpr bool all_nonzero_equal(H head, Ts... values)
- {
- size_t first = 0;
- for (size_t v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...})
- {
- if (v == 0)
- {
- continue;
- }
- if (first == 0)
- {
- first = v;
- }
- else if (v != first)
- {
- return false;
- }
- }
- return true;
- }
-
- _CCCL_HOST_DEVICE constexpr bool all_nonzero_equal()
- {
- return true;
- }
-
- template <typename H, typename... Ts>
- _CCCL_HOST_DEVICE constexpr auto first_nonzero_value(H head, Ts... values)
- {
- for (auto v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...})
- {
- if (v != 0)
- {
- return v;
- }
- }
- // we only reach here when all input are not contiguous and the output has a void value type
- return H{1};
- }
-
  template <typename T>
  inline constexpr size_t size_of = sizeof(T);

@@ -337,6 +296,47 @@ _CCCL_HOST_DEVICE static constexpr auto make_sizes_alignments()
  {{sizeof(it_value_t<RandomAccessIteratorsIn>), alignof(it_value_t<RandomAccessIteratorsIn>)}...}};
  }

+ template <int PtxVersion, int StoreSize, int... LoadSizes>
+ struct tuning_vec
+ {
+ // defaults from fill on RTX 5090, but can be changed
+ static constexpr int block_threads = 256;
+ static constexpr int vec_size = 4;
+ static constexpr int items_per_thread = 8;
+ };
+
+ // manually tuned fill on A100
+ template <int StoreSize>
+ struct tuning_vec<800, StoreSize>
+ {
+ static constexpr int block_threads = 256;
+ static constexpr int vec_size = ::cuda::std::max(8 / StoreSize, 1); // 64-bit instructions
+ static constexpr int items_per_thread = 8;
+ };
+
+ // manually tuned fill on H200
+ template <int StoreSize>
+ struct tuning_vec<900, StoreSize>
+ {
+ static constexpr int block_threads = StoreSize > 4 ? 128 : 256;
+ static constexpr int vec_size = ::cuda::std::max(8 / StoreSize, 1); // 64-bit instructions
+ static constexpr int items_per_thread = 16;
+ };
+
+ // manually tuned fill on B200, same as H200
+ template <int StoreSize>
+ struct tuning_vec<1000, StoreSize> : tuning_vec<900, StoreSize>
+ {};
+
+ // manually tuned fill on RTX 5090
+ template <int StoreSize>
+ struct tuning_vec<1200, StoreSize>
+ {
+ static constexpr int block_threads = 256;
+ static constexpr int vec_size = 4;
+ static constexpr int items_per_thread = 8;
+ };
+
  template <bool RequiresStableAddress,
  bool DenseOutput,
  typename RandomAccessIteratorTupleIn,
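To make the tuning arithmetic concrete: the SM80/SM90 specializations above choose vec_size = max(8 / StoreSize, 1), i.e. each vectorized access targets 8 bytes (a 64-bit instruction). A hypothetical helper spelling that out, shown only for illustration:

  // Mirrors vec_size = ::cuda::std::max(8 / StoreSize, 1) from the tuning structs above.
  constexpr int sm80_style_vec_size(int store_size)
  {
    const int v = 8 / store_size;
    return v > 0 ? v : 1;
  }

  static_assert(sm80_style_vec_size(1) == 8);  // 8 one-byte elements per access
  static_assert(sm80_style_vec_size(4) == 2);  // 2 four-byte elements per access
  static_assert(sm80_style_vec_size(8) == 1);  // one 8-byte element per access
  static_assert(sm80_style_vec_size(16) == 1); // clamped to at least one element
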
@@ -367,29 +367,12 @@ struct policy_hub<RequiresStableAddress,
  || THRUST_NS_QUALIFIER::is_trivially_relocatable_v<it_value_t<RandomAccessIteratorsIn>>)
  && ...);

- // for vectorized policy:
- static constexpr bool all_contiguous_input_values_same_size = all_nonzero_equal(
- (sizeof(it_value_t<RandomAccessIteratorsIn>)
- * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...);
- static constexpr int load_store_word_size = 8; // TODO(bgruber): make this 16, and 32 on Blackwell+
- // find the value type size of the first contiguous iterator. if there are no inputs, we take the size of the output
- // value type
- static constexpr int contiguous_value_type_size = first_nonzero_value(
- (int{sizeof(it_value_t<RandomAccessIteratorsIn>)}
- * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...,
- int{size_of<it_value_t<RandomAccessIteratorOut>>});
- static constexpr bool value_type_divides_load_store_size =
- load_store_word_size % contiguous_value_type_size == 0; // implicitly checks that value_type_size <=
- // load_store_word_size
- static constexpr int target_bytes_per_thread =
- no_input_streams ? 16 /* by experiment on RTX 5090 */ : 32 /* guestimate by gevtushenko for loading */;
- static constexpr int items_per_thread_vec =
- ::cuda::round_up(target_bytes_per_thread, load_store_word_size) / contiguous_value_type_size;
- using default_vectorized_policy_t = vectorized_policy_t<256, items_per_thread_vec, load_store_word_size>;
+ static constexpr bool all_value_types_have_power_of_two_size =
+ (::cuda::is_power_of_two(sizeof(it_value_t<RandomAccessIteratorsIn>)) && ...)
+ && ::cuda::is_power_of_two(size_of<it_value_t<RandomAccessIteratorOut>>);

  static constexpr bool fallback_to_prefetch =
- RequiresStableAddress || !can_memcpy_contiguous_inputs || !all_contiguous_input_values_same_size
- || !value_type_divides_load_store_size || !DenseOutput;
+ RequiresStableAddress || !can_memcpy_contiguous_inputs || !all_value_types_have_power_of_two_size || !DenseOutput;

  // TODO(bgruber): consider a separate kernel for just filling

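The rewritten fallback above only asks that every input and output value type have a power-of-two size. A minimal stand-in for the ::cuda::is_power_of_two check used there, shown for illustration:

  #include <cstddef>

  // A size is a power of two iff it is non-zero and has exactly one bit set.
  constexpr bool is_pow2(std::size_t n)
  {
    return n != 0 && (n & (n - 1)) == 0;
  }

  static_assert(is_pow2(sizeof(float))); // 4-byte types remain eligible for the vectorized policy
  static_assert(!is_pow2(12));           // a 12-byte value type forces fallback_to_prefetch
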
@@ -398,12 +381,16 @@ struct policy_hub<RequiresStableAddress,
  static constexpr int min_bif = arch_to_min_bytes_in_flight(300);
  // TODO(bgruber): we don't need algo, because we can just detect the type of algo_policy
  static constexpr auto algorithm = fallback_to_prefetch ? Algorithm::prefetch : Algorithm::vectorized;
- using algo_policy = ::cuda::std::_If<fallback_to_prefetch, prefetch_policy_t<256>, default_vectorized_policy_t>;
+ using vec_policy_t = vectorized_policy_t<
+ tuning_vec<500, size_of<it_value_t<RandomAccessIteratorOut>>, sizeof(it_value_t<RandomAccessIteratorsIn>)...>>;
+ using algo_policy = ::cuda::std::_If<fallback_to_prefetch, prefetch_policy_t<256>, vec_policy_t>;
  };

  struct policy800 : ChainedPolicy<800, policy800, policy300>
  {
  private:
+ using vec_policy_t = vectorized_policy_t<
+ tuning_vec<800, size_of<it_value_t<RandomAccessIteratorOut>>, sizeof(it_value_t<RandomAccessIteratorsIn>)...>>;
  static constexpr int block_threads = 256;
  using async_policy = async_copy_policy_t<block_threads, ldgsts_size_and_align>;
  // We cannot use the architecture-specific amount of SMEM here instead of max_smem_per_block, because this is not
@@ -427,13 +414,17 @@ struct policy_hub<RequiresStableAddress,
  using algo_policy =
  ::cuda::std::_If<fallback_to_prefetch,
  prefetch_policy_t<block_threads>,
- ::cuda::std::_If<fallback_to_vectorized, default_vectorized_policy_t, async_policy>>;
+ ::cuda::std::_If<fallback_to_vectorized, vec_policy_t, async_policy>>;
  };

  template <int AsyncBlockSize, int PtxVersion>
  struct bulk_copy_policy_base
  {
  private:
+ using vec_policy_t =
+ vectorized_policy_t<tuning_vec<PtxVersion,
+ size_of<it_value_t<RandomAccessIteratorOut>>,
+ sizeof(it_value_t<RandomAccessIteratorsIn>)...>>;
  static constexpr int alignment = bulk_copy_alignment(PtxVersion);
  using async_policy = async_copy_policy_t<AsyncBlockSize, alignment>;
  // We cannot use the architecture-specific amount of SMEM here instead of max_smem_per_block, because this is not
@@ -469,7 +460,7 @@ struct policy_hub<RequiresStableAddress,
  using algo_policy =
  ::cuda::std::_If<fallback_to_prefetch,
  prefetch_policy_t<256>,
- ::cuda::std::_If<fallback_to_vectorized, default_vectorized_policy_t, async_policy>>;
+ ::cuda::std::_If<fallback_to_vectorized, vec_policy_t, async_policy>>;
  };

  struct policy900
@@ -136,6 +136,7 @@ CUB_NAMESPACE_BEGIN
  //! {
  //! int array[4] = {1, 2, 3, 4};
  //! int sum = cub::ThreadReduce(array, ::cuda::std::plus<>{}); // sum = 10
+ //! }
  //!
  //! @endrst
  //!
@@ -437,10 +438,13 @@ template <typename Input, typename ReductionOp, typename ValueT, typename AccumT
  "Input must support the subscript operator[] and have a compile-time size");
  static_assert(has_binary_call_operator<ReductionOp, ValueT>::value,
  "ReductionOp must have the binary call operator: operator(ValueT, ValueT)");
- if constexpr (static_size_v<Input> == 1)
+
+ static constexpr auto length = static_size_v<Input>;
+ if constexpr (length == 1)
  {
  return static_cast<AccumT>(input[0]);
  }
+
  using PromT = ::cuda::std::_If<enable_min_max_promotion_v<ReductionOp, ValueT>, int, AccumT>;
  // TODO: should be part of the tuning policy
  if constexpr ((!is_simd_enabled_cuda_operator<ReductionOp, ValueT> && !is_simd_operator_v<ReductionOp>)
@@ -449,38 +453,41 @@ template <typename Input, typename ReductionOp, typename ValueT, typename AccumT
  return ThreadReduceSequential<AccumT>(input, reduction_op);
  }

- constexpr auto length = static_size_v<Input>;
- if constexpr (::cuda::std::is_same_v<Input, AccumT> && enable_sm90_simd_reduction_v<Input, ReductionOp, length>)
+ if constexpr (::cuda::std::is_same_v<ValueT, AccumT> && enable_sm90_simd_reduction_v<ValueT, ReductionOp, length>)
  {
  NV_IF_TARGET(NV_PROVIDES_SM_90, (return ThreadReduceSimd(input, reduction_op);))
  }

- if constexpr (::cuda::std::is_same_v<Input, AccumT> && enable_sm80_simd_reduction_v<Input, ReductionOp, length>)
+ if constexpr (::cuda::std::is_same_v<ValueT, AccumT> && enable_sm80_simd_reduction_v<ValueT, ReductionOp, length>)
  {
  NV_IF_TARGET(NV_PROVIDES_SM_80, (return ThreadReduceSimd(input, reduction_op);))
  }

- if constexpr (::cuda::std::is_same_v<Input, AccumT> && enable_sm70_simd_reduction_v<Input, ReductionOp, length>)
+ if constexpr (::cuda::std::is_same_v<ValueT, AccumT> && enable_sm70_simd_reduction_v<ValueT, ReductionOp, length>)
  {
  NV_IF_TARGET(NV_PROVIDES_SM_70, (return ThreadReduceSimd(input, reduction_op);))
  }

- if constexpr (enable_ternary_reduction_sm90_v<Input, ReductionOp>)
+ if constexpr (length >= 6)
  {
- // with the current tuning policies, SM90/int32/+ uses too many registers (TODO: fix tuning policy)
- if constexpr ((is_one_of_v<ReductionOp, ::cuda::std::plus<>, ::cuda::std::plus<PromT>>
- && is_one_of_v<PromT, int32_t, uint32_t>)
- // the compiler generates bad code for int8/uint8 and min/max for SM90
- || (is_cuda_minimum_maximum_v<ReductionOp, ValueT> && is_one_of_v<PromT, int8_t, uint8_t>) )
+ // apply SM90 min/max ternary reduction only if the input is natively int32/uint32
+ if constexpr (enable_ternary_reduction_sm90_v<ValueT, ReductionOp>)
  {
- NV_IF_TARGET(NV_PROVIDES_SM_90, (return ThreadReduceSequential<PromT>(input, reduction_op);));
+ // with the current tuning policies, SM90/int32/+ uses too many registers (TODO: fix tuning policy)
+ if constexpr ((is_one_of_v<ReductionOp, ::cuda::std::plus<>, ::cuda::std::plus<PromT>>
+ && is_one_of_v<PromT, int32_t, uint32_t>)
+ // the compiler generates bad code for int8/uint8 and min/max for SM90
+ || (is_cuda_minimum_maximum_v<ReductionOp, ValueT> && is_one_of_v<PromT, int8_t, uint8_t>) )
+ {
+ NV_IF_TARGET(NV_PROVIDES_SM_90, (return ThreadReduceSequential<PromT>(input, reduction_op);));
+ }
+ NV_IF_TARGET(NV_PROVIDES_SM_90, (return ThreadReduceTernaryTree<PromT>(input, reduction_op);));
  }
- NV_IF_TARGET(NV_PROVIDES_SM_90, (return ThreadReduceTernaryTree<PromT>(input, reduction_op);));
- }

- if constexpr (enable_ternary_reduction_sm50_v<Input, ReductionOp>)
- {
- NV_IF_TARGET(NV_PROVIDES_SM_50, (return ThreadReduceSequential<PromT>(input, reduction_op);));
+ if constexpr (enable_ternary_reduction_sm50_v<ValueT, ReductionOp>)
+ {
+ NV_IF_TARGET(NV_PROVIDES_SM_50, (return ThreadReduceSequential<PromT>(input, reduction_op);));
+ }
  }

  return ThreadReduceBinaryTree<PromT>(input, reduction_op);
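For context, the public entry point served by this dispatch is the one shown in the thread_reduce.cuh documentation hunk above; a minimal device-side usage sketch (the kernel name is illustrative):

  #include <cub/thread/thread_reduce.cuh>
  #include <cuda/std/functional>

  __global__ void thread_reduce_example(int* out)
  {
    int array[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    // The dispatch above picks a sequential, SIMD, ternary-tree, or binary-tree
    // strategy based on the value type, the operator, and the array length.
    out[threadIdx.x] = cub::ThreadReduce(array, ::cuda::std::plus<>{}); // 36
  }
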
@@ -191,8 +191,8 @@ enum WarpLoadAlgorithm
  //!
  //! // Load a segment of consecutive items that are blocked across threads
  //! int thread_data[items_per_thread];
- //! WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size,
- //! thread_data);
+ //! WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size, thread_data);
+ //! }
  //!
  //! Suppose the input ``d_data`` is ``0, 1, 2, 3, 4, 5, ...``.
  //! The set of ``thread_data`` across the first logical warp of threads in those
@@ -484,8 +484,8 @@ public:
  //!
  //! // Load a segment of consecutive items that are blocked across threads
  //! int thread_data[items_per_thread];
- //! WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size,
- //! thread_data);
+ //! WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size, thread_data);
+ //! }
  //!
  //! Suppose the input ``d_data`` is ``0, 1, 2, 3, 4, 5, ...``,
  //! The set of ``thread_data`` across the first logical warp of threads in those
@@ -533,9 +533,9 @@ public:
  //!
  //! // Load a segment of consecutive items that are blocked across threads
  //! int thread_data[items_per_thread];
- //! WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size,
- //! thread_data,
+ //! WarpLoadT(temp_storage[warp_id]).Load(d_data + warp_id * tile_size, thread_data,
  //! valid_items);
+ //! }
  //!
  //! Suppose the input ``d_data`` is ``0, 1, 2, 3, 4, 5, ...`` and ``valid_items`` is ``5``.
  //! The set of ``thread_data`` across the first logical warp of threads in those threads will be:
@@ -105,6 +105,7 @@ CUB_NAMESPACE_BEGIN
  //! // Return the warp-wide sums to each lane0 (threads 0, 32, 64, and 96)
  //! int warp_id = threadIdx.x / 32;
  //! int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data);
+ //! }
  //!
  //! Suppose the set of input ``thread_data`` across the block of threads is ``{0, 1, 2, 3, ..., 127}``.
  //! The corresponding output ``aggregate`` in threads 0, 32, 64, and 96 will be
@@ -130,6 +131,8 @@ CUB_NAMESPACE_BEGIN
  //! int thread_data = ...
  //! // Return the warp-wide sum to lane0
  //! int aggregate = WarpReduce(temp_storage).Sum(thread_data);
+ //! }
+ //! }
  //!
  //! Suppose the set of input ``thread_data`` across the warp of threads is ``{0, 1, 2, 3, ..., 31}``.
  //! The corresponding output ``aggregate`` in thread0 will be ``496`` (and is undefined in other threads).
@@ -218,6 +221,7 @@ public:
  //! // Return the warp-wide sums to each lane0
  //! int warp_id = threadIdx.x / 32;
  //! int aggregate = WarpReduce(temp_storage[warp_id]).Sum(thread_data);
+ //! }
  //!
  //! Suppose the set of input ``thread_data`` across the block of threads is ``{0, 1, 2, 3, ..., 127}``.
  //! The corresponding output ``aggregate`` in threads 0, 32, 64, and 96 will ``496``, ``1520``, ``2544``, and
@@ -299,8 +303,8 @@ public:
  //! thread_data = d_data[threadIdx.x];
  //!
  //! // Return the warp-wide sums to each lane0
- //! int aggregate = WarpReduce(temp_storage).Sum(
- //! thread_data, valid_items);
+ //! int aggregate = WarpReduce(temp_storage).Sum(thread_data, valid_items);
+ //! }
  //!
  //! Suppose the input ``d_data`` is ``{0, 1, 2, 3, 4, ...`` and ``valid_items`` is ``4``.
  //! The corresponding output ``aggregate`` in *lane*\ :sub:`0` is ``6``
@@ -363,6 +367,7 @@ public:
  //! // Return the warp-wide sums to each lane0
  //! int aggregate = WarpReduce(temp_storage).HeadSegmentedSum(
  //! thread_data, head_flag);
+ //! }
  //!
  //! Suppose the set of input ``thread_data`` and ``head_flag`` across the block of threads
  //! is ``{0, 1, 2, 3, ..., 31`` and is ``{1, 0, 0, 0, 1, 0, 0, 0, ..., 1, 0, 0, 0``,