cuda-cccl 0.1.3.2.0.dev438__cp313-cp313-manylinux_2_24_aarch64.whl → 0.3.1__cp313-cp313-manylinux_2_24_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cuda-cccl might be problematic. Click here for more details.

Files changed (177) hide show
  1. cuda/cccl/cooperative/__init__.py +7 -1
  2. cuda/cccl/cooperative/experimental/__init__.py +21 -5
  3. cuda/cccl/headers/include/cub/agent/agent_adjacent_difference.cuh +2 -5
  4. cuda/cccl/headers/include/cub/agent/agent_batch_memcpy.cuh +2 -5
  5. cuda/cccl/headers/include/cub/agent/agent_for.cuh +2 -5
  6. cuda/cccl/headers/include/cub/agent/agent_merge.cuh +23 -21
  7. cuda/cccl/headers/include/cub/agent/agent_merge_sort.cuh +21 -3
  8. cuda/cccl/headers/include/cub/agent/agent_radix_sort_downsweep.cuh +25 -5
  9. cuda/cccl/headers/include/cub/agent/agent_radix_sort_histogram.cuh +2 -5
  10. cuda/cccl/headers/include/cub/agent/agent_radix_sort_onesweep.cuh +2 -5
  11. cuda/cccl/headers/include/cub/agent/agent_radix_sort_upsweep.cuh +2 -5
  12. cuda/cccl/headers/include/cub/agent/agent_rle.cuh +2 -5
  13. cuda/cccl/headers/include/cub/agent/agent_scan.cuh +5 -1
  14. cuda/cccl/headers/include/cub/agent/agent_scan_by_key.cuh +2 -5
  15. cuda/cccl/headers/include/cub/agent/agent_segmented_radix_sort.cuh +2 -5
  16. cuda/cccl/headers/include/cub/agent/agent_select_if.cuh +2 -5
  17. cuda/cccl/headers/include/cub/agent/agent_sub_warp_merge_sort.cuh +24 -19
  18. cuda/cccl/headers/include/cub/agent/agent_three_way_partition.cuh +2 -5
  19. cuda/cccl/headers/include/cub/agent/agent_unique_by_key.cuh +22 -5
  20. cuda/cccl/headers/include/cub/block/block_load_to_shared.cuh +432 -0
  21. cuda/cccl/headers/include/cub/block/block_radix_rank.cuh +3 -2
  22. cuda/cccl/headers/include/cub/block/block_radix_sort.cuh +4 -2
  23. cuda/cccl/headers/include/cub/detail/device_memory_resource.cuh +1 -0
  24. cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +18 -26
  25. cuda/cccl/headers/include/cub/device/device_copy.cuh +116 -27
  26. cuda/cccl/headers/include/cub/device/device_partition.cuh +5 -1
  27. cuda/cccl/headers/include/cub/device/device_segmented_reduce.cuh +158 -247
  28. cuda/cccl/headers/include/cub/device/dispatch/dispatch_copy_mdspan.cuh +79 -0
  29. cuda/cccl/headers/include/cub/device/dispatch/dispatch_merge.cuh +4 -4
  30. cuda/cccl/headers/include/cub/device/dispatch/dispatch_radix_sort.cuh +2 -11
  31. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce.cuh +8 -26
  32. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_deterministic.cuh +1 -6
  33. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_nondeterministic.cuh +0 -1
  34. cuda/cccl/headers/include/cub/device/dispatch/dispatch_segmented_sort.cuh +320 -262
  35. cuda/cccl/headers/include/cub/device/dispatch/kernels/reduce.cuh +10 -5
  36. cuda/cccl/headers/include/cub/device/dispatch/kernels/scan.cuh +2 -5
  37. cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_reduce.cuh +2 -5
  38. cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_sort.cuh +57 -10
  39. cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +37 -13
  40. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_adjacent_difference.cuh +2 -5
  41. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_batch_memcpy.cuh +2 -5
  42. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_for.cuh +2 -5
  43. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_histogram.cuh +2 -5
  44. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_merge.cuh +2 -5
  45. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_merge_sort.cuh +8 -0
  46. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_radix_sort.cuh +2 -5
  47. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh +2 -5
  48. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_run_length_encode.cuh +2 -5
  49. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_scan.cuh +2 -5
  50. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_scan_by_key.cuh +2 -5
  51. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_segmented_sort.cuh +204 -55
  52. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_three_way_partition.cuh +2 -5
  53. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +55 -19
  54. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_unique_by_key.cuh +10 -0
  55. cuda/cccl/headers/include/cub/util_device.cuh +51 -35
  56. cuda/cccl/headers/include/cub/warp/specializations/warp_reduce_shfl.cuh +3 -2
  57. cuda/cccl/headers/include/cub/warp/specializations/warp_reduce_smem.cuh +3 -2
  58. cuda/cccl/headers/include/cub/warp/specializations/warp_scan_shfl.cuh +2 -2
  59. cuda/cccl/headers/include/cuda/__algorithm/common.h +1 -1
  60. cuda/cccl/headers/include/cuda/__algorithm/copy.h +4 -4
  61. cuda/cccl/headers/include/cuda/__algorithm/fill.h +1 -1
  62. cuda/cccl/headers/include/cuda/__device/all_devices.h +47 -147
  63. cuda/cccl/headers/include/cuda/__device/arch_traits.h +51 -49
  64. cuda/cccl/headers/include/cuda/__device/attributes.h +177 -127
  65. cuda/cccl/headers/include/cuda/__device/device_ref.h +32 -51
  66. cuda/cccl/headers/include/cuda/__device/physical_device.h +120 -91
  67. cuda/cccl/headers/include/cuda/__driver/driver_api.h +330 -36
  68. cuda/cccl/headers/include/cuda/__event/event.h +8 -8
  69. cuda/cccl/headers/include/cuda/__event/event_ref.h +4 -5
  70. cuda/cccl/headers/include/cuda/__event/timed_event.h +4 -4
  71. cuda/cccl/headers/include/cuda/__fwd/devices.h +44 -0
  72. cuda/cccl/headers/include/cuda/__fwd/zip_iterator.h +9 -0
  73. cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +3 -3
  74. cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +3 -3
  75. cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +3 -3
  76. cuda/cccl/headers/include/cuda/__iterator/zip_common.h +158 -0
  77. cuda/cccl/headers/include/cuda/__iterator/zip_iterator.h +8 -120
  78. cuda/cccl/headers/include/cuda/__iterator/zip_transform_iterator.h +593 -0
  79. cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +14 -10
  80. cuda/cccl/headers/include/cuda/__runtime/ensure_current_context.h +4 -3
  81. cuda/cccl/headers/include/cuda/__runtime/types.h +1 -1
  82. cuda/cccl/headers/include/cuda/__stream/stream.h +2 -3
  83. cuda/cccl/headers/include/cuda/__stream/stream_ref.h +18 -12
  84. cuda/cccl/headers/include/cuda/__utility/__basic_any/virtual_tables.h +2 -2
  85. cuda/cccl/headers/include/cuda/__utility/basic_any.h +1 -1
  86. cuda/cccl/headers/include/cuda/algorithm +1 -1
  87. cuda/cccl/headers/include/cuda/devices +10 -0
  88. cuda/cccl/headers/include/cuda/iterator +1 -0
  89. cuda/cccl/headers/include/cuda/std/__bit/countl.h +8 -1
  90. cuda/cccl/headers/include/cuda/std/__bit/countr.h +2 -2
  91. cuda/cccl/headers/include/cuda/std/__bit/reference.h +11 -11
  92. cuda/cccl/headers/include/cuda/std/__cccl/cuda_capabilities.h +2 -2
  93. cuda/cccl/headers/include/cuda/std/__cccl/preprocessor.h +2 -0
  94. cuda/cccl/headers/include/cuda/std/__chrono/duration.h +16 -16
  95. cuda/cccl/headers/include/cuda/std/__chrono/steady_clock.h +5 -5
  96. cuda/cccl/headers/include/cuda/std/__chrono/system_clock.h +5 -5
  97. cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +10 -5
  98. cuda/cccl/headers/include/cuda/std/__cmath/min_max.h +44 -17
  99. cuda/cccl/headers/include/cuda/std/__concepts/constructible.h +1 -1
  100. cuda/cccl/headers/include/cuda/std/__cuda/api_wrapper.h +12 -12
  101. cuda/cccl/headers/include/cuda/std/__exception/cuda_error.h +1 -8
  102. cuda/cccl/headers/include/cuda/std/__floating_point/cast.h +15 -12
  103. cuda/cccl/headers/include/cuda/std/__floating_point/cuda_fp_types.h +3 -0
  104. cuda/cccl/headers/include/cuda/std/__floating_point/fp.h +1 -1
  105. cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +2 -1
  106. cuda/cccl/headers/include/cuda/std/__tuple_dir/make_tuple_types.h +23 -1
  107. cuda/cccl/headers/include/cuda/std/__tuple_dir/tuple_like.h +4 -0
  108. cuda/cccl/headers/include/cuda/std/__tuple_dir/tuple_like_ext.h +4 -0
  109. cuda/cccl/headers/include/cuda/std/__type_traits/promote.h +3 -2
  110. cuda/cccl/headers/include/cuda/std/string_view +12 -5
  111. cuda/cccl/headers/include/cuda/std/version +1 -4
  112. cuda/cccl/headers/include/thrust/detail/integer_math.h +3 -20
  113. cuda/cccl/headers/include/thrust/iterator/iterator_traits.h +11 -0
  114. cuda/cccl/headers/include/thrust/system/cuda/detail/copy.h +33 -0
  115. cuda/cccl/headers/include/thrust/system/cuda/detail/tabulate.h +8 -22
  116. cuda/cccl/headers/include/thrust/type_traits/unwrap_contiguous_iterator.h +15 -48
  117. cuda/cccl/parallel/experimental/__init__.py +21 -70
  118. cuda/compute/__init__.py +77 -0
  119. cuda/{cccl/parallel/experimental → compute}/_bindings.pyi +28 -0
  120. cuda/{cccl/parallel/experimental → compute}/_bindings_impl.pyx +141 -1
  121. cuda/{cccl/parallel/experimental → compute}/algorithms/__init__.py +4 -0
  122. cuda/{cccl/parallel/experimental → compute}/algorithms/_histogram.py +2 -2
  123. cuda/{cccl/parallel/experimental → compute}/algorithms/_merge_sort.py +2 -2
  124. cuda/{cccl/parallel/experimental → compute}/algorithms/_radix_sort.py +3 -3
  125. cuda/{cccl/parallel/experimental → compute}/algorithms/_reduce.py +2 -4
  126. cuda/{cccl/parallel/experimental → compute}/algorithms/_scan.py +4 -6
  127. cuda/{cccl/parallel/experimental → compute}/algorithms/_segmented_reduce.py +2 -2
  128. cuda/compute/algorithms/_three_way_partition.py +261 -0
  129. cuda/{cccl/parallel/experimental → compute}/algorithms/_transform.py +4 -4
  130. cuda/{cccl/parallel/experimental → compute}/algorithms/_unique_by_key.py +2 -2
  131. cuda/compute/cu12/_bindings_impl.cpython-313-aarch64-linux-gnu.so +0 -0
  132. cuda/{cccl/parallel/experimental → compute}/cu12/cccl/libcccl.c.parallel.so +0 -0
  133. cuda/compute/cu13/_bindings_impl.cpython-313-aarch64-linux-gnu.so +0 -0
  134. cuda/{cccl/parallel/experimental → compute}/cu13/cccl/libcccl.c.parallel.so +0 -0
  135. cuda/{cccl/parallel/experimental → compute}/iterators/_factories.py +8 -8
  136. cuda/{cccl/parallel/experimental → compute}/struct.py +2 -2
  137. cuda/coop/__init__.py +8 -0
  138. cuda/{cccl/cooperative/experimental → coop}/_nvrtc.py +3 -2
  139. cuda/{cccl/cooperative/experimental → coop}/_scan_op.py +3 -3
  140. cuda/{cccl/cooperative/experimental → coop}/_types.py +2 -2
  141. cuda/{cccl/cooperative/experimental → coop}/_typing.py +1 -1
  142. cuda/{cccl/cooperative/experimental → coop}/block/__init__.py +6 -6
  143. cuda/{cccl/cooperative/experimental → coop}/block/_block_exchange.py +4 -4
  144. cuda/{cccl/cooperative/experimental → coop}/block/_block_load_store.py +6 -6
  145. cuda/{cccl/cooperative/experimental → coop}/block/_block_merge_sort.py +4 -4
  146. cuda/{cccl/cooperative/experimental → coop}/block/_block_radix_sort.py +6 -6
  147. cuda/{cccl/cooperative/experimental → coop}/block/_block_reduce.py +6 -6
  148. cuda/{cccl/cooperative/experimental → coop}/block/_block_scan.py +7 -7
  149. cuda/coop/warp/__init__.py +9 -0
  150. cuda/{cccl/cooperative/experimental → coop}/warp/_warp_merge_sort.py +3 -3
  151. cuda/{cccl/cooperative/experimental → coop}/warp/_warp_reduce.py +6 -6
  152. cuda/{cccl/cooperative/experimental → coop}/warp/_warp_scan.py +4 -4
  153. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/METADATA +1 -1
  154. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/RECORD +171 -166
  155. cuda/cccl/cooperative/experimental/warp/__init__.py +0 -9
  156. cuda/cccl/headers/include/cub/device/dispatch/dispatch_advance_iterators.cuh +0 -111
  157. cuda/cccl/headers/include/cuda/std/__cuda/ensure_current_device.h +0 -72
  158. cuda/cccl/parallel/experimental/.gitignore +0 -4
  159. cuda/cccl/parallel/experimental/cu12/_bindings_impl.cpython-313-aarch64-linux-gnu.so +0 -0
  160. cuda/cccl/parallel/experimental/cu13/_bindings_impl.cpython-313-aarch64-linux-gnu.so +0 -0
  161. /cuda/{cccl/parallel/experimental → compute}/_bindings.py +0 -0
  162. /cuda/{cccl/parallel/experimental → compute}/_caching.py +0 -0
  163. /cuda/{cccl/parallel/experimental → compute}/_cccl_interop.py +0 -0
  164. /cuda/{cccl/parallel/experimental → compute}/_utils/__init__.py +0 -0
  165. /cuda/{cccl/parallel/experimental → compute}/_utils/protocols.py +0 -0
  166. /cuda/{cccl/parallel/experimental → compute}/_utils/temp_storage_buffer.py +0 -0
  167. /cuda/{cccl/parallel/experimental → compute}/cccl/.gitkeep +0 -0
  168. /cuda/{cccl/parallel/experimental → compute}/iterators/__init__.py +0 -0
  169. /cuda/{cccl/parallel/experimental → compute}/iterators/_iterators.py +0 -0
  170. /cuda/{cccl/parallel/experimental → compute}/iterators/_zip_iterator.py +0 -0
  171. /cuda/{cccl/parallel/experimental → compute}/numba_utils.py +0 -0
  172. /cuda/{cccl/parallel/experimental → compute}/op.py +0 -0
  173. /cuda/{cccl/parallel/experimental → compute}/typing.py +0 -0
  174. /cuda/{cccl/cooperative/experimental → coop}/_caching.py +0 -0
  175. /cuda/{cccl/cooperative/experimental → coop}/_common.py +0 -0
  176. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/WHEEL +0 -0
  177. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,9 @@
1
1
  # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
2
2
  #
3
- # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3
+ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
+
5
+ from . import experimental
6
+
7
+ __all__ = [
8
+ "experimental",
9
+ ]
@@ -1,8 +1,24 @@
1
- # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
1
+ # Copyright (c) 2025, NVIDIA CORPORATION.
2
2
  #
3
- # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # alias for backwards compatibility
16
+
17
+ from warnings import warn
4
18
 
5
- from cuda.cccl.cooperative.experimental import block, warp
6
- from cuda.cccl.cooperative.experimental._types import StatefulFunction
19
+ from cuda.coop import * # noqa: F403
7
20
 
8
- __all__ = ["block", "warp", "StatefulFunction"]
21
+ warn(
22
+ "The module cuda.cccl.cooperative.experimental is deprecated. Use cuda.coop instead.",
23
+ FutureWarning,
24
+ )
@@ -64,9 +64,7 @@ struct AgentAdjacentDifferencePolicy
64
64
  static constexpr cub::BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM;
65
65
  };
66
66
 
67
- namespace detail
68
- {
69
- namespace adjacent_difference
67
+ namespace detail::adjacent_difference
70
68
  {
71
69
 
72
70
  template <typename Policy,
@@ -256,7 +254,6 @@ struct AgentDifferenceInit
256
254
  }
257
255
  };
258
256
 
259
- } // namespace adjacent_difference
260
- } // namespace detail
257
+ } // namespace detail::adjacent_difference
261
258
 
262
259
  CUB_NAMESPACE_END
@@ -62,9 +62,7 @@
62
62
 
63
63
  CUB_NAMESPACE_BEGIN
64
64
 
65
- namespace detail
66
- {
67
- namespace batch_memcpy
65
+ namespace detail::batch_memcpy
68
66
  {
69
67
  template <bool PTR_IS_FOUR_BYTE_ALIGNED>
70
68
  _CCCL_FORCEINLINE _CCCL_DEVICE void
@@ -1179,7 +1177,6 @@ private:
1179
1177
  // buffers
1180
1178
  BLevBlockOffsetTileState blev_block_scan_state;
1181
1179
  };
1182
- } // namespace batch_memcpy
1183
- } // namespace detail
1180
+ } // namespace detail::batch_memcpy
1184
1181
 
1185
1182
  CUB_NAMESPACE_END
@@ -42,9 +42,7 @@
42
42
 
43
43
  CUB_NAMESPACE_BEGIN
44
44
 
45
- namespace detail
46
- {
47
- namespace for_each
45
+ namespace detail::for_each
48
46
  {
49
47
 
50
48
  template <int BlockThreads, int ItemsPerThread>
@@ -78,7 +76,6 @@ struct agent_block_striped_t
78
76
  }
79
77
  };
80
78
 
81
- } // namespace for_each
82
- } // namespace detail
79
+ } // namespace detail::for_each
83
80
 
84
81
  CUB_NAMESPACE_END
@@ -53,14 +53,8 @@ struct agent_t
53
53
  using policy = Policy;
54
54
 
55
55
  // key and value type are taken from the first input sequence (consistent with old Thrust behavior)
56
- using key_type = it_value_t<KeysIt1>;
57
- using item_type = it_value_t<ItemsIt1>;
58
-
59
- using keys_load_it1 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt1>;
60
- using keys_load_it2 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, KeysIt2>;
61
- using items_load_it1 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt1>;
62
- using items_load_it2 = try_make_cache_modified_iterator_t<Policy::LOAD_MODIFIER, ItemsIt2>;
63
-
56
+ using key_type = it_value_t<KeysIt1>;
57
+ using item_type = it_value_t<ItemsIt1>;
64
58
  using block_store_keys = typename BlockStoreType<Policy, KeysOutputIt, key_type>::type;
65
59
  using block_store_items = typename BlockStoreType<Policy, ItemsOutputIt, item_type>::type;
66
60
 
@@ -84,11 +78,11 @@ struct agent_t
84
78
 
85
79
  // Per thread data
86
80
  temp_storages& storage;
87
- keys_load_it1 keys1_in;
88
- items_load_it1 items1_in;
81
+ KeysIt1 keys1_in;
82
+ ItemsIt1 items1_in;
89
83
  Offset keys1_count;
90
- keys_load_it2 keys2_in;
91
- items_load_it2 items2_in;
84
+ KeysIt2 keys2_in;
85
+ ItemsIt2 items2_in;
92
86
  Offset keys2_count;
93
87
  KeysOutputIt keys_out;
94
88
  ItemsOutputIt items_out;
@@ -128,10 +122,14 @@ struct agent_t
128
122
  }
129
123
 
130
124
  key_type keys_loc[items_per_thread];
131
- merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
132
- keys_loc, keys1_in + keys1_beg, keys2_in + keys2_beg, keys1_count_tile, keys2_count_tile);
133
- merge_sort::reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
134
- __syncthreads();
125
+ {
126
+ auto keys1_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(keys1_in);
127
+ auto keys2_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(keys2_in);
128
+ merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
129
+ keys_loc, keys1_in_cm + keys1_beg, keys2_in_cm + keys2_beg, keys1_count_tile, keys2_count_tile);
130
+ merge_sort::reg_to_shared<threads_per_block>(&storage.keys_shared[0], keys_loc);
131
+ __syncthreads();
132
+ }
135
133
 
136
134
  // now find the merge path for each of thread.
137
135
  // we can use int type here, because the number of items in shared memory is limited
@@ -186,11 +184,15 @@ struct agent_t
186
184
  if constexpr (have_items)
187
185
  {
188
186
  item_type items_loc[items_per_thread];
189
- merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
190
- items_loc, items1_in + keys1_beg, items2_in + keys2_beg, keys1_count_tile, keys2_count_tile);
191
- __syncthreads(); // block_store_keys above uses SMEM, so make sure all threads are done before we write to it
192
- merge_sort::reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
193
- __syncthreads();
187
+ {
188
+ auto items1_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items1_in);
189
+ auto items2_in_cm = try_make_cache_modified_iterator<Policy::LOAD_MODIFIER>(items2_in);
190
+ merge_sort::gmem_to_reg<threads_per_block, IsFullTile>(
191
+ items_loc, items1_in_cm + keys1_beg, items2_in_cm + keys2_beg, keys1_count_tile, keys2_count_tile);
192
+ __syncthreads(); // block_store_keys above uses SMEM, so make sure all threads are done before we write to it
193
+ merge_sort::reg_to_shared<threads_per_block>(&storage.items_shared[0], items_loc);
194
+ __syncthreads();
195
+ }
194
196
 
195
197
  // gather items from shared mem
196
198
  _CCCL_PRAGMA_UNROLL_FULL()
@@ -66,9 +66,28 @@ struct AgentMergeSortPolicy
66
66
  static constexpr cub::BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM;
67
67
  };
68
68
 
69
+ #if defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
69
70
  namespace detail
70
71
  {
71
- namespace merge_sort
72
+ // Only define this when needed.
73
+ // Because of overload woes, this depends on C++20 concepts. util_device.cuh checks that concepts are available when
74
+ // either runtime policies or PTX JSON information are enabled, so if they are, this is always valid. The generic
75
+ // version is always defined, and that's the only one needed for regular CUB operations.
76
+ //
77
+ // TODO: enable this unconditionally once concepts are always available
78
+ CUB_DETAIL_POLICY_WRAPPER_DEFINE(
79
+ MergeSortAgentPolicy,
80
+ (GenericAgentPolicy),
81
+ (BLOCK_THREADS, BlockThreads, int),
82
+ (ITEMS_PER_THREAD, ItemsPerThread, int),
83
+ (ITEMS_PER_TILE, ItemsPerTile, int),
84
+ (LOAD_ALGORITHM, LoadAlgorithm, cub::BlockLoadAlgorithm),
85
+ (LOAD_MODIFIER, LoadModifier, cub::CacheLoadModifier),
86
+ (STORE_ALGORITHM, StoreAlgorithm, cub::BlockStoreAlgorithm))
87
+ } // namespace detail
88
+ #endif // defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
89
+
90
+ namespace detail::merge_sort
72
91
  {
73
92
 
74
93
  template <typename Policy,
@@ -724,7 +743,6 @@ struct AgentMerge
724
743
  }
725
744
  };
726
745
 
727
- } // namespace merge_sort
728
- } // namespace detail
746
+ } // namespace detail::merge_sort
729
747
 
730
748
  CUB_NAMESPACE_END
@@ -51,6 +51,7 @@
51
51
  #include <cub/block/radix_rank_sort_operations.cuh>
52
52
  #include <cub/iterator/cache_modified_input_iterator.cuh>
53
53
  #include <cub/thread/thread_load.cuh>
54
+ #include <cub/util_device.cuh>
54
55
  #include <cub/util_type.cuh>
55
56
 
56
57
  #include <cuda/std/cstdint>
@@ -119,13 +120,33 @@ struct AgentRadixSortDownsweepPolicy : ScalingType
119
120
  static constexpr BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM;
120
121
  };
121
122
 
123
+ #if defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
124
+ namespace detail
125
+ {
126
+ // Only define this when needed.
127
+ // Because of overload woes, this depends on C++20 concepts. util_device.cuh checks that concepts are available when
128
+ // either runtime policies or PTX JSON information are enabled, so if they are, this is always valid. The generic
129
+ // version is always defined, and that's the only one needed for regular CUB operations.
130
+ //
131
+ // TODO: enable this unconditionally once concepts are always available
132
+ CUB_DETAIL_POLICY_WRAPPER_DEFINE(
133
+ RadixSortDownsweepAgentPolicy,
134
+ (GenericAgentPolicy),
135
+ (BLOCK_THREADS, BlockThreads, int),
136
+ (ITEMS_PER_THREAD, ItemsPerThread, int),
137
+ (RADIX_BITS, RadixBits, int),
138
+ (LOAD_ALGORITHM, LoadAlgorithm, cub::BlockLoadAlgorithm),
139
+ (LOAD_MODIFIER, LoadModifier, cub::CacheLoadModifier),
140
+ (RANK_ALGORITHM, RankAlgorithm, cub::RadixRankAlgorithm),
141
+ (SCAN_ALGORITHM, ScanAlgorithm, cub::BlockScanAlgorithm))
142
+ } // namespace detail
143
+ #endif // defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
144
+
122
145
  /******************************************************************************
123
146
  * Thread block abstractions
124
147
  ******************************************************************************/
125
148
 
126
- namespace detail
127
- {
128
- namespace radix_sort
149
+ namespace detail::radix_sort
129
150
  {
130
151
 
131
152
  /**
@@ -760,7 +781,6 @@ struct AgentRadixSortDownsweep
760
781
  }
761
782
  };
762
783
 
763
- } // namespace radix_sort
764
- } // namespace detail
784
+ } // namespace detail::radix_sort
765
785
 
766
786
  CUB_NAMESPACE_END
@@ -85,9 +85,7 @@ struct AgentRadixSortExclusiveSumPolicy
85
85
  };
86
86
  };
87
87
 
88
- namespace detail
89
- {
90
- namespace radix_sort
88
+ namespace detail::radix_sort
91
89
  {
92
90
 
93
91
  template <typename AgentRadixSortHistogramPolicy,
@@ -283,7 +281,6 @@ struct AgentRadixSortHistogram
283
281
  }
284
282
  };
285
283
 
286
- } // namespace radix_sort
287
- } // namespace detail
284
+ } // namespace detail::radix_sort
288
285
 
289
286
  CUB_NAMESPACE_END
@@ -100,9 +100,7 @@ struct AgentRadixSortOnesweepPolicy : ScalingType
100
100
  static constexpr RadixSortStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM;
101
101
  };
102
102
 
103
- namespace detail
104
- {
105
- namespace radix_sort
103
+ namespace detail::radix_sort
106
104
  {
107
105
 
108
106
  template <typename AgentRadixSortOnesweepPolicy,
@@ -700,7 +698,6 @@ struct AgentRadixSortOnesweep
700
698
  }
701
699
  };
702
700
 
703
- } // namespace radix_sort
704
- } // namespace detail
701
+ } // namespace detail::radix_sort
705
702
 
706
703
  CUB_NAMESPACE_END
@@ -103,9 +103,7 @@ struct AgentRadixSortUpsweepPolicy : ScalingType
103
103
  * Thread block abstractions
104
104
  ******************************************************************************/
105
105
 
106
- namespace detail
107
- {
108
- namespace radix_sort
106
+ namespace detail::radix_sort
109
107
  {
110
108
 
111
109
  /**
@@ -552,7 +550,6 @@ struct AgentRadixSortUpsweep
552
550
  }
553
551
  };
554
552
 
555
- } // namespace radix_sort
556
- } // namespace detail
553
+ } // namespace detail::radix_sort
557
554
 
558
555
  CUB_NAMESPACE_END
@@ -134,9 +134,7 @@ struct AgentRlePolicy
134
134
  * Thread block abstractions
135
135
  ******************************************************************************/
136
136
 
137
- namespace detail
138
- {
139
- namespace rle
137
+ namespace detail::rle
140
138
  {
141
139
 
142
140
  /**
@@ -1121,7 +1119,6 @@ struct AgentRle
1121
1119
  }
1122
1120
  };
1123
1121
 
1124
- } // namespace rle
1125
- } // namespace detail
1122
+ } // namespace detail::rle
1126
1123
 
1127
1124
  CUB_NAMESPACE_END
@@ -51,6 +51,10 @@
51
51
  #include <cub/iterator/cache_modified_input_iterator.cuh>
52
52
  #include <cub/util_device.cuh>
53
53
 
54
+ #if defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
55
+ # include <cub/agent/agent_unique_by_key.cuh> // for UniqueByKeyAgentPolicy
56
+ #endif
57
+
54
58
  #include <cuda/std/__type_traits/conditional.h>
55
59
  #include <cuda/std/__type_traits/is_pointer.h>
56
60
  #include <cuda/std/__type_traits/is_same.h>
@@ -123,7 +127,7 @@ namespace detail
123
127
  // TODO: enable this unconditionally once concepts are always available
124
128
  CUB_DETAIL_POLICY_WRAPPER_DEFINE(
125
129
  ScanAgentPolicy,
126
- (GenericAgentPolicy),
130
+ (UniqueByKeyAgentPolicy),
127
131
  (BLOCK_THREADS, BlockThreads, int),
128
132
  (ITEMS_PER_THREAD, ItemsPerThread, int),
129
133
  (LOAD_ALGORITHM, LoadAlgorithm, cub::BlockLoadAlgorithm),
@@ -96,9 +96,7 @@ struct AgentScanByKeyPolicy
96
96
  * Thread block abstractions
97
97
  ******************************************************************************/
98
98
 
99
- namespace detail
100
- {
101
- namespace scan_by_key
99
+ namespace detail::scan_by_key
102
100
  {
103
101
 
104
102
  /**
@@ -471,7 +469,6 @@ struct AgentScanByKey
471
469
  }
472
470
  };
473
471
 
474
- } // namespace scan_by_key
475
- } // namespace detail
472
+ } // namespace detail::scan_by_key
476
473
 
477
474
  CUB_NAMESPACE_END
@@ -45,9 +45,7 @@
45
45
 
46
46
  CUB_NAMESPACE_BEGIN
47
47
 
48
- namespace detail
49
- {
50
- namespace radix_sort
48
+ namespace detail::radix_sort
51
49
  {
52
50
 
53
51
  /**
@@ -286,7 +284,6 @@ struct AgentSegmentedRadixSort
286
284
  }
287
285
  };
288
286
 
289
- } // namespace radix_sort
290
- } // namespace detail
287
+ } // namespace detail::radix_sort
291
288
 
292
289
  CUB_NAMESPACE_END
@@ -126,9 +126,7 @@ struct AgentSelectIfPolicy
126
126
  * Thread block abstractions
127
127
  ******************************************************************************/
128
128
 
129
- namespace detail
130
- {
131
- namespace select
129
+ namespace detail::select
132
130
  {
133
131
 
134
132
  template <typename EqualityOpT>
@@ -1114,7 +1112,6 @@ struct AgentSelectIf
1114
1112
  }
1115
1113
  };
1116
1114
 
1117
- } // namespace select
1118
- } // namespace detail
1115
+ } // namespace detail::select
1119
1116
 
1120
1117
  CUB_NAMESPACE_END
@@ -48,37 +48,43 @@
48
48
 
49
49
  CUB_NAMESPACE_BEGIN
50
50
 
51
- template <int WARP_THREADS_ARG,
51
+ template <int BLOCK_THREADS_ARG,
52
+ int WARP_THREADS_ARG,
52
53
  int ITEMS_PER_THREAD_ARG,
53
54
  cub::WarpLoadAlgorithm LOAD_ALGORITHM_ARG = cub::WARP_LOAD_DIRECT,
54
55
  cub::CacheLoadModifier LOAD_MODIFIER_ARG = cub::LOAD_LDG,
55
56
  cub::WarpStoreAlgorithm STORE_ALGORITHM_ARG = cub::WARP_STORE_DIRECT>
56
57
  struct AgentSubWarpMergeSortPolicy
57
58
  {
58
- static constexpr int WARP_THREADS = WARP_THREADS_ARG;
59
- static constexpr int ITEMS_PER_THREAD = ITEMS_PER_THREAD_ARG;
60
- static constexpr int ITEMS_PER_TILE = WARP_THREADS * ITEMS_PER_THREAD;
59
+ static constexpr int BLOCK_THREADS = BLOCK_THREADS_ARG;
60
+ static constexpr int WARP_THREADS = WARP_THREADS_ARG;
61
+ static constexpr int ITEMS_PER_THREAD = ITEMS_PER_THREAD_ARG;
62
+ static constexpr int ITEMS_PER_TILE = WARP_THREADS * ITEMS_PER_THREAD;
63
+ static constexpr int SEGMENTS_PER_BLOCK = BLOCK_THREADS / WARP_THREADS;
61
64
 
62
65
  static constexpr cub::WarpLoadAlgorithm LOAD_ALGORITHM = LOAD_ALGORITHM_ARG;
63
66
  static constexpr cub::CacheLoadModifier LOAD_MODIFIER = LOAD_MODIFIER_ARG;
64
67
  static constexpr cub::WarpStoreAlgorithm STORE_ALGORITHM = STORE_ALGORITHM_ARG;
65
68
  };
66
69
 
67
- template <int BLOCK_THREADS_ARG, typename SmallPolicy, typename MediumPolicy>
68
- struct AgentSmallAndMediumSegmentedSortPolicy
69
- {
70
- static constexpr int BLOCK_THREADS = BLOCK_THREADS_ARG;
71
- using SmallPolicyT = SmallPolicy;
72
- using MediumPolicyT = MediumPolicy;
73
-
74
- static constexpr int SEGMENTS_PER_MEDIUM_BLOCK = BLOCK_THREADS / MediumPolicyT::WARP_THREADS;
75
-
76
- static constexpr int SEGMENTS_PER_SMALL_BLOCK = BLOCK_THREADS / SmallPolicyT::WARP_THREADS;
77
- };
78
-
70
+ #if defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
79
71
  namespace detail
80
72
  {
81
- namespace sub_warp_merge_sort
73
+ CUB_DETAIL_POLICY_WRAPPER_DEFINE(
74
+ SubWarpMergeSortAgentPolicy,
75
+ (GenericAgentPolicy),
76
+ (BLOCK_THREADS, BlockThreads, int),
77
+ (WARP_THREADS, WarpThreads, int),
78
+ (ITEMS_PER_THREAD, ItemsPerThread, int),
79
+ (ITEMS_PER_TILE, ItemsPerTile, int),
80
+ (SEGMENTS_PER_BLOCK, SegmentsPerBlock, int),
81
+ (LOAD_ALGORITHM, LoadAlgorithm, cub::WarpLoadAlgorithm),
82
+ (LOAD_MODIFIER, LoadModifier, cub::CacheLoadModifier),
83
+ (STORE_ALGORITHM, StoreAlgorithm, cub::WarpStoreAlgorithm))
84
+ } // namespace detail
85
+ #endif // defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
86
+
87
+ namespace detail::sub_warp_merge_sort
82
88
  {
83
89
 
84
90
  /**
@@ -335,7 +341,6 @@ private:
335
341
  }
336
342
  };
337
343
 
338
- } // namespace sub_warp_merge_sort
339
- } // namespace detail
344
+ } // namespace detail::sub_warp_merge_sort
340
345
 
341
346
  CUB_NAMESPACE_END
@@ -91,9 +91,7 @@ CUB_DETAIL_POLICY_WRAPPER_DEFINE(
91
91
  } // namespace detail
92
92
  #endif // defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
93
93
 
94
- namespace detail
95
- {
96
- namespace three_way_partition
94
+ namespace detail::three_way_partition
97
95
  {
98
96
 
99
97
  template <class OffsetT>
@@ -603,7 +601,6 @@ struct AgentThreeWayPartition
603
601
  }
604
602
  };
605
603
 
606
- } // namespace three_way_partition
607
- } // namespace detail
604
+ } // namespace detail::three_way_partition
608
605
 
609
606
  CUB_NAMESPACE_END
@@ -85,13 +85,31 @@ struct AgentUniqueByKeyPolicy
85
85
  };
86
86
  };
87
87
 
88
+ #if defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
89
+ namespace detail
90
+ {
91
+ // Only define this when needed.
92
+ // Because of overload woes, this depends on C++20 concepts. util_device.cuh checks that concepts are available when
93
+ // either runtime policies or PTX JSON information are enabled, so if they are, this is always valid. The generic
94
+ // version is always defined, and that's the only one needed for regular CUB operations.
95
+ //
96
+ // TODO: enable this unconditionally once concepts are always available
97
+ CUB_DETAIL_POLICY_WRAPPER_DEFINE(
98
+ UniqueByKeyAgentPolicy,
99
+ (GenericAgentPolicy),
100
+ (BLOCK_THREADS, BlockThreads, int),
101
+ (ITEMS_PER_THREAD, ItemsPerThread, int),
102
+ (LOAD_ALGORITHM, LoadAlgorithm, cub::BlockLoadAlgorithm),
103
+ (LOAD_MODIFIER, LoadModifier, cub::CacheLoadModifier),
104
+ (SCAN_ALGORITHM, ScanAlgorithm, cub::BlockScanAlgorithm))
105
+ } // namespace detail
106
+ #endif // defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
107
+
88
108
  /******************************************************************************
89
109
  * Thread block abstractions
90
110
  ******************************************************************************/
91
111
 
92
- namespace detail
93
- {
94
- namespace unique_by_key
112
+ namespace detail::unique_by_key
95
113
  {
96
114
 
97
115
  /**
@@ -608,7 +626,6 @@ struct AgentUniqueByKey
608
626
  }
609
627
  };
610
628
 
611
- } // namespace unique_by_key
612
- } // namespace detail
629
+ } // namespace detail::unique_by_key
613
630
 
614
631
  CUB_NAMESPACE_END