cuda-cccl 0.1.3.2.0.dev438__cp312-cp312-manylinux_2_24_aarch64.whl → 0.3.0__cp312-cp312-manylinux_2_24_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cuda-cccl might be problematic. Click here for more details.

Files changed (60) hide show
  1. cuda/cccl/headers/include/cub/agent/agent_radix_sort_downsweep.cuh +23 -0
  2. cuda/cccl/headers/include/cub/agent/agent_sub_warp_merge_sort.cuh +22 -14
  3. cuda/cccl/headers/include/cub/block/block_load_to_shared.cuh +432 -0
  4. cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +18 -26
  5. cuda/cccl/headers/include/cub/device/device_copy.cuh +116 -27
  6. cuda/cccl/headers/include/cub/device/device_partition.cuh +5 -1
  7. cuda/cccl/headers/include/cub/device/dispatch/dispatch_copy_mdspan.cuh +79 -0
  8. cuda/cccl/headers/include/cub/device/dispatch/dispatch_segmented_sort.cuh +321 -262
  9. cuda/cccl/headers/include/cub/device/dispatch/kernels/reduce.cuh +8 -0
  10. cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_sort.cuh +57 -10
  11. cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +37 -13
  12. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_segmented_sort.cuh +203 -51
  13. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +55 -19
  14. cuda/cccl/headers/include/cub/util_device.cuh +51 -35
  15. cuda/cccl/headers/include/cuda/__algorithm/copy.h +3 -3
  16. cuda/cccl/headers/include/cuda/__device/all_devices.h +3 -6
  17. cuda/cccl/headers/include/cuda/__device/arch_traits.h +3 -3
  18. cuda/cccl/headers/include/cuda/__device/attributes.h +7 -7
  19. cuda/cccl/headers/include/cuda/__device/device_ref.h +3 -10
  20. cuda/cccl/headers/include/cuda/__driver/driver_api.h +225 -33
  21. cuda/cccl/headers/include/cuda/__event/event.h +7 -8
  22. cuda/cccl/headers/include/cuda/__event/event_ref.h +4 -5
  23. cuda/cccl/headers/include/cuda/__event/timed_event.h +3 -4
  24. cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +3 -3
  25. cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +3 -3
  26. cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +3 -3
  27. cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +14 -10
  28. cuda/cccl/headers/include/cuda/__runtime/types.h +1 -1
  29. cuda/cccl/headers/include/cuda/__stream/stream.h +2 -3
  30. cuda/cccl/headers/include/cuda/__stream/stream_ref.h +17 -12
  31. cuda/cccl/headers/include/cuda/__utility/__basic_any/virtual_tables.h +2 -2
  32. cuda/cccl/headers/include/cuda/std/__cccl/cuda_capabilities.h +2 -2
  33. cuda/cccl/headers/include/cuda/std/__cccl/preprocessor.h +2 -0
  34. cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +10 -5
  35. cuda/cccl/headers/include/cuda/std/__cmath/min_max.h +44 -17
  36. cuda/cccl/headers/include/cuda/std/__concepts/constructible.h +1 -1
  37. cuda/cccl/headers/include/cuda/std/__cuda/api_wrapper.h +12 -12
  38. cuda/cccl/headers/include/cuda/std/__exception/cuda_error.h +1 -8
  39. cuda/cccl/headers/include/cuda/std/__floating_point/cast.h +15 -12
  40. cuda/cccl/headers/include/cuda/std/__floating_point/cuda_fp_types.h +3 -0
  41. cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +2 -1
  42. cuda/cccl/headers/include/cuda/std/__type_traits/promote.h +3 -2
  43. cuda/cccl/headers/include/thrust/system/cuda/detail/tabulate.h +8 -22
  44. cuda/cccl/headers/include/thrust/type_traits/unwrap_contiguous_iterator.h +15 -48
  45. cuda/cccl/parallel/experimental/__init__.py +4 -0
  46. cuda/cccl/parallel/experimental/_bindings.pyi +28 -0
  47. cuda/cccl/parallel/experimental/_bindings_impl.pyx +140 -0
  48. cuda/cccl/parallel/experimental/algorithms/__init__.py +4 -0
  49. cuda/cccl/parallel/experimental/algorithms/_reduce.py +0 -2
  50. cuda/cccl/parallel/experimental/algorithms/_scan.py +0 -2
  51. cuda/cccl/parallel/experimental/algorithms/_three_way_partition.py +261 -0
  52. cuda/cccl/parallel/experimental/cu12/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
  53. cuda/cccl/parallel/experimental/cu12/cccl/libcccl.c.parallel.so +0 -0
  54. cuda/cccl/parallel/experimental/cu13/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
  55. cuda/cccl/parallel/experimental/cu13/cccl/libcccl.c.parallel.so +0 -0
  56. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/METADATA +1 -1
  57. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/RECORD +59 -57
  58. cuda/cccl/headers/include/cuda/std/__cuda/ensure_current_device.h +0 -72
  59. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/WHEEL +0 -0
  60. {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -76,25 +76,25 @@ template <class _To, class _From>
76
76
  #if _CCCL_HAS_NVFP8_E8M0()
77
77
  else if constexpr (is_same_v<_To, __nv_fp8_e8m0>)
78
78
  {
79
- return ::cuda::std::__fp_from_storage<__nv_fp8_e8m0>(::__nv_cvt_float_to_e8m0(__v, __NV_NOSAT, cudaRoundZero));
79
+ return ::cuda::std::__fp_from_storage<__nv_fp8_e8m0>(::__nv_cvt_float_to_e8m0(__v, __NV_NOSAT, ::cudaRoundZero));
80
80
  }
81
81
  #endif // _CCCL_HAS_NVFP8_E8M0()
82
82
  #if _CCCL_HAS_NVFP6_E2M3()
83
83
  else if constexpr (is_same_v<_To, __nv_fp6_e2m3>)
84
84
  {
85
- return ::cuda::std::__fp_from_storage<__nv_fp6_e2m3>(::__nv_cvt_float_to_fp6(__v, __NV_E2M3, cudaRoundNearest));
85
+ return ::cuda::std::__fp_from_storage<__nv_fp6_e2m3>(::__nv_cvt_float_to_fp6(__v, __NV_E2M3, ::cudaRoundNearest));
86
86
  }
87
87
  #endif // _CCCL_HAS_NVFP6_E2M3()
88
88
  #if _CCCL_HAS_NVFP6_E3M2()
89
89
  else if constexpr (is_same_v<_To, __nv_fp6_e3m2>)
90
90
  {
91
- return ::cuda::std::__fp_from_storage<__nv_fp6_e3m2>(::__nv_cvt_float_to_fp6(__v, __NV_E3M2, cudaRoundNearest));
91
+ return ::cuda::std::__fp_from_storage<__nv_fp6_e3m2>(::__nv_cvt_float_to_fp6(__v, __NV_E3M2, ::cudaRoundNearest));
92
92
  }
93
93
  #endif // _CCCL_HAS_NVFP6_E3M2()
94
94
  #if _CCCL_HAS_NVFP4_E2M1()
95
95
  else if constexpr (is_same_v<_To, __nv_fp4_e2m1>)
96
96
  {
97
- return ::cuda::std::__fp_from_storage<__nv_fp4_e2m1>(::__nv_cvt_float_to_fp4(__v, __NV_E2M1, cudaRoundNearest));
97
+ return ::cuda::std::__fp_from_storage<__nv_fp4_e2m1>(::__nv_cvt_float_to_fp4(__v, __NV_E2M1, ::cudaRoundNearest));
98
98
  }
99
99
  #endif // _CCCL_HAS_NVFP4_E2M1()
100
100
  else
@@ -145,25 +145,28 @@ template <class _To, class _From>
145
145
  #if _CCCL_HAS_NVFP8_E8M0()
146
146
  else if constexpr (is_same_v<_To, __nv_fp8_e8m0>)
147
147
  {
148
- return ::cuda::std::__fp_from_storage<__nv_fp8_e8m0>(::__nv_cvt_double_to_e8m0(__v, __NV_NOSAT, cudaRoundZero));
148
+ return ::cuda::std::__fp_from_storage<__nv_fp8_e8m0>(::__nv_cvt_double_to_e8m0(__v, __NV_NOSAT, ::cudaRoundZero));
149
149
  }
150
150
  #endif // _CCCL_HAS_NVFP8_E8M0()
151
151
  #if _CCCL_HAS_NVFP6_E2M3()
152
152
  else if constexpr (is_same_v<_To, __nv_fp6_e2m3>)
153
153
  {
154
- return ::cuda::std::__fp_from_storage<__nv_fp6_e2m3>(::__nv_cvt_double_to_fp6(__v, __NV_E2M3, cudaRoundNearest));
154
+ return ::cuda::std::__fp_from_storage<__nv_fp6_e2m3>(
155
+ ::__nv_cvt_double_to_fp6(__v, __NV_E2M3, ::cudaRoundNearest));
155
156
  }
156
157
  #endif // _CCCL_HAS_NVFP6_E2M3()
157
158
  #if _CCCL_HAS_NVFP6_E3M2()
158
159
  else if constexpr (is_same_v<_To, __nv_fp6_e3m2>)
159
160
  {
160
- return ::cuda::std::__fp_from_storage<__nv_fp6_e3m2>(::__nv_cvt_double_to_fp6(__v, __NV_E3M2, cudaRoundNearest));
161
+ return ::cuda::std::__fp_from_storage<__nv_fp6_e3m2>(
162
+ ::__nv_cvt_double_to_fp6(__v, __NV_E3M2, ::cudaRoundNearest));
161
163
  }
162
164
  #endif // _CCCL_HAS_NVFP6_E3M2()
163
165
  #if _CCCL_HAS_NVFP4_E2M1()
164
166
  else if constexpr (is_same_v<_To, __nv_fp4_e2m1>)
165
167
  {
166
- return ::cuda::std::__fp_from_storage<__nv_fp4_e2m1>(::__nv_cvt_double_to_fp4(__v, __NV_E2M1, cudaRoundNearest));
168
+ return ::cuda::std::__fp_from_storage<__nv_fp4_e2m1>(
169
+ ::__nv_cvt_double_to_fp4(__v, __NV_E2M1, ::cudaRoundNearest));
167
170
  }
168
171
  #endif // _CCCL_HAS_NVFP4_E2M1()
169
172
  else
@@ -352,28 +355,28 @@ template <class _To, class _From>
352
355
  else if constexpr (is_same_v<_To, __nv_fp8_e8m0>)
353
356
  {
354
357
  return ::cuda::std::__fp_from_storage<__nv_fp8_e8m0>(
355
- ::__nv_cvt_bfloat16raw_to_e8m0(__v, __NV_NOSAT, cudaRoundZero));
358
+ ::__nv_cvt_bfloat16raw_to_e8m0(__v, __NV_NOSAT, ::cudaRoundZero));
356
359
  }
357
360
  # endif // _CCCL_HAS_NVFP8_E8M0()
358
361
  # if _CCCL_HAS_NVFP6_E2M3()
359
362
  else if constexpr (is_same_v<_To, __nv_fp6_e2m3>)
360
363
  {
361
364
  return ::cuda::std::__fp_from_storage<__nv_fp6_e2m3>(
362
- ::__nv_cvt_bfloat16raw_to_fp6(__v, __NV_E2M3, cudaRoundNearest));
365
+ ::__nv_cvt_bfloat16raw_to_fp6(__v, __NV_E2M3, ::cudaRoundNearest));
363
366
  }
364
367
  # endif // _CCCL_HAS_NVFP6_E2M3()
365
368
  # if _CCCL_HAS_NVFP6_E3M2()
366
369
  else if constexpr (is_same_v<_To, __nv_fp6_e3m2>)
367
370
  {
368
371
  return ::cuda::std::__fp_from_storage<__nv_fp6_e3m2>(
369
- ::__nv_cvt_bfloat16raw_to_fp6(__v, __NV_E3M2, cudaRoundNearest));
372
+ ::__nv_cvt_bfloat16raw_to_fp6(__v, __NV_E3M2, ::cudaRoundNearest));
370
373
  }
371
374
  # endif // _CCCL_HAS_NVFP6_E3M2()
372
375
  # if _CCCL_HAS_NVFP4_E2M1()
373
376
  else if constexpr (is_same_v<_To, __nv_fp4_e2m1>)
374
377
  {
375
378
  return ::cuda::std::__fp_from_storage<__nv_fp4_e2m1>(
376
- ::__nv_cvt_bfloat16raw_to_fp4(__v, __NV_E2M1, cudaRoundNearest));
379
+ ::__nv_cvt_bfloat16raw_to_fp4(__v, __NV_E2M1, ::cudaRoundNearest));
377
380
  }
378
381
  # endif // _CCCL_HAS_NVFP4_E2M1()
379
382
  else
@@ -55,6 +55,9 @@ _CCCL_DIAG_SUPPRESS_MSVC(4100) // unreferenced formal parameter
55
55
  _CCCL_DIAG_POP
56
56
  #endif // _CCCL_HAS_NVFP4()
57
57
 
58
+ // crt/device_fp128_functions.h is available in CUDA 12.8+.
59
+ // _CCCL_HAS_FLOAT128() checks the *compiler* compatibility with __float128.
60
+ // We also need to check the toolkit version to ensure the compatibility with nvc++.
58
61
  #if _CCCL_HAS_FLOAT128() && _CCCL_DEVICE_COMPILATION() && _CCCL_CTK_AT_LEAST(12, 8)
59
62
  # if !_CCCL_COMPILER(NVRTC)
60
63
  _CCCL_DIAG_PUSH
@@ -439,7 +439,8 @@ public:
439
439
  [[nodiscard]] _CCCL_API constexpr bool is_exhaustive() const
440
440
  noexcept(noexcept(::cuda::std::declval<const mapping_type&>().is_exhaustive()))
441
441
  {
442
- return mapping().is_exhaustive();
442
+ auto __tmp = mapping(); // workaround for clang with nodiscard
443
+ return __tmp.is_exhaustive();
443
444
  }
444
445
  [[nodiscard]] _CCCL_API constexpr bool is_strided() const
445
446
  noexcept(noexcept(::cuda::std::declval<const mapping_type&>().is_strided()))
@@ -20,10 +20,8 @@
20
20
  # pragma system_header
21
21
  #endif // no system header
22
22
 
23
- #include <cuda/std/__type_traits/integral_constant.h>
24
23
  #include <cuda/std/__type_traits/is_same.h>
25
24
  #include <cuda/std/__utility/declval.h>
26
- #include <cuda/std/cstddef>
27
25
 
28
26
  #include <cuda/std/__cccl/prologue.h>
29
27
 
@@ -49,6 +47,9 @@ struct __numeric_type
49
47
  _CCCL_API inline static double __test(unsigned long long);
50
48
  _CCCL_API inline static double __test(double);
51
49
  _CCCL_API inline static long double __test(long double);
50
+ #if _CCCL_HAS_FLOAT128()
51
+ _CCCL_API inline static __float128 __test(__float128);
52
+ #endif // _CCCL_HAS_FLOAT128()
52
53
 
53
54
  using type = decltype(__test(declval<_Tp>()));
54
55
  static const bool value = !is_same_v<type, void>;
@@ -39,37 +39,23 @@
39
39
  #if _CCCL_HAS_CUDA_COMPILER()
40
40
  # include <thrust/system/cuda/config.h>
41
41
 
42
- # include <thrust/distance.h>
43
- # include <thrust/system/cuda/detail/parallel_for.h>
42
+ # include <thrust/system/cuda/detail/transform.h>
44
43
  # include <thrust/system/cuda/execution_policy.h>
45
44
 
45
+ # include <cuda/__functional/address_stability.h>
46
+ # include <cuda/std/iterator>
47
+
46
48
  THRUST_NAMESPACE_BEGIN
47
49
  namespace cuda_cub
48
50
  {
49
- namespace __tabulate
50
- {
51
- template <class Iterator, class TabulateOp>
52
- struct functor
53
- {
54
- Iterator items;
55
- TabulateOp op;
56
-
57
- template <typename Size>
58
- void _CCCL_DEVICE operator()(Size idx)
59
- {
60
- items[idx] = op(idx);
61
- }
62
- };
63
- } // namespace __tabulate
64
-
65
51
  template <class Derived, class Iterator, class TabulateOp>
66
52
  void _CCCL_HOST_DEVICE tabulate(execution_policy<Derived>& policy, Iterator first, Iterator last, TabulateOp tabulate_op)
67
53
  {
68
- using size_type = thrust::detail::it_difference_t<Iterator>;
69
- size_type count = ::cuda::std::distance(first, last);
70
- cuda_cub::parallel_for(policy, __tabulate::functor<Iterator, TabulateOp>{first, tabulate_op}, count);
54
+ using size_type = ::cuda::std::iter_difference_t<Iterator>;
55
+ const auto count = ::cuda::std::distance(first, last);
56
+ cuda_cub::transform_n(
57
+ policy, ::cuda::counting_iterator<size_type>{}, count, first, ::cuda::proclaim_copyable_arguments(tabulate_op));
71
58
  }
72
-
73
59
  } // namespace cuda_cub
74
60
  THRUST_NAMESPACE_END
75
61
  #endif
@@ -25,72 +25,39 @@
25
25
 
26
26
  THRUST_NAMESPACE_BEGIN
27
27
 
28
- namespace detail
29
- {
30
- // Type traits for contiguous iterators:
31
- template <typename Iterator>
32
- struct contiguous_iterator_traits
33
- {
34
- static_assert(thrust::is_contiguous_iterator_v<Iterator>,
35
- "contiguous_iterator_traits requires a contiguous iterator.");
36
-
37
- using raw_pointer =
38
- typename thrust::detail::pointer_traits<decltype(&*::cuda::std::declval<Iterator>())>::raw_pointer;
39
- };
40
- } // namespace detail
41
-
42
- //! Converts a contiguous iterator type to its underlying raw pointer type.
43
- template <typename ContiguousIterator>
44
- using unwrap_contiguous_iterator_t = typename detail::contiguous_iterator_traits<ContiguousIterator>::raw_pointer;
45
-
46
28
  //! Converts a contiguous iterator to its underlying raw pointer.
29
+ _CCCL_EXEC_CHECK_DISABLE
47
30
  template <typename ContiguousIterator>
48
31
  _CCCL_HOST_DEVICE auto unwrap_contiguous_iterator(ContiguousIterator it)
49
- -> unwrap_contiguous_iterator_t<ContiguousIterator>
50
32
  {
51
33
  static_assert(thrust::is_contiguous_iterator_v<ContiguousIterator>,
52
34
  "unwrap_contiguous_iterator called with non-contiguous iterator.");
53
35
  return thrust::raw_pointer_cast(&*it);
54
36
  }
55
37
 
56
- namespace detail
57
- {
58
- // Implementation for non-contiguous iterators -- passthrough.
59
- template <typename Iterator, bool IsContiguous = thrust::is_contiguous_iterator_v<Iterator>>
60
- struct try_unwrap_contiguous_iterator_impl
61
- {
62
- using type = Iterator;
63
-
64
- static _CCCL_HOST_DEVICE type get(Iterator it)
65
- {
66
- return it;
67
- }
68
- };
38
+ //! Converts a contiguous iterator type to its underlying raw pointer type.
39
+ template <typename ContiguousIterator>
40
+ using unwrap_contiguous_iterator_t = decltype(unwrap_contiguous_iterator(::cuda::std::declval<ContiguousIterator>()));
69
41
 
70
- // Implementation for contiguous iterators -- unwraps to raw pointer.
42
+ //! Takes an iterator and, if it is contiguous, unwraps it to the raw pointer it represents. Otherwise returns the
43
+ //! iterator unmodified.
44
+ _CCCL_EXEC_CHECK_DISABLE
71
45
  template <typename Iterator>
72
- struct try_unwrap_contiguous_iterator_impl<Iterator, true /*is_contiguous*/>
46
+ _CCCL_HOST_DEVICE auto try_unwrap_contiguous_iterator(Iterator it)
73
47
  {
74
- using type = unwrap_contiguous_iterator_t<Iterator>;
75
-
76
- static _CCCL_HOST_DEVICE type get(Iterator it)
48
+ if constexpr (thrust::is_contiguous_iterator_v<Iterator>)
77
49
  {
78
50
  return unwrap_contiguous_iterator(it);
79
51
  }
80
- };
81
- } // namespace detail
52
+ else
53
+ {
54
+ return it;
55
+ }
56
+ }
82
57
 
83
58
  //! Takes an iterator type and, if it is contiguous, yields the raw pointer type it represents. Otherwise returns the
84
59
  //! iterator type unmodified.
85
60
  template <typename Iterator>
86
- using try_unwrap_contiguous_iterator_t = typename detail::try_unwrap_contiguous_iterator_impl<Iterator>::type;
87
-
88
- //! Takes an iterator and, if it is contiguous, unwraps it to the raw pointer it represents. Otherwise returns the
89
- //! iterator unmodified.
90
- template <typename Iterator>
91
- _CCCL_HOST_DEVICE auto try_unwrap_contiguous_iterator(Iterator it) -> try_unwrap_contiguous_iterator_t<Iterator>
92
- {
93
- return detail::try_unwrap_contiguous_iterator_impl<Iterator>::get(it);
94
- }
61
+ using try_unwrap_contiguous_iterator_t = decltype(try_unwrap_contiguous_iterator(::cuda::std::declval<Iterator>()));
95
62
 
96
63
  THRUST_NAMESPACE_END
@@ -17,12 +17,14 @@ from .algorithms import (
17
17
  make_radix_sort,
18
18
  make_reduce_into,
19
19
  make_segmented_reduce,
20
+ make_three_way_partition,
20
21
  make_unary_transform,
21
22
  make_unique_by_key,
22
23
  merge_sort,
23
24
  radix_sort,
24
25
  reduce_into,
25
26
  segmented_reduce,
27
+ three_way_partition,
26
28
  unary_transform,
27
29
  unique_by_key,
28
30
  )
@@ -56,6 +58,7 @@ __all__ = [
56
58
  "make_radix_sort",
57
59
  "make_reduce_into",
58
60
  "make_segmented_reduce",
61
+ "make_three_way_partition",
59
62
  "make_unary_transform",
60
63
  "make_unique_by_key",
61
64
  "merge_sort",
@@ -66,6 +69,7 @@ __all__ = [
66
69
  "segmented_reduce",
67
70
  "SortOrder",
68
71
  "TransformIterator",
72
+ "three_way_partition",
69
73
  "TransformOutputIterator",
70
74
  "unary_transform",
71
75
  "unique_by_key",
@@ -390,6 +390,7 @@ class DeviceHistogramBuildResult:
390
390
  num_rows: int,
391
391
  row_stride_samples: int,
392
392
  is_evenly_segmented: bool,
393
+ info: CommonData,
393
394
  ): ...
394
395
  def compute_even(
395
396
  self,
@@ -403,3 +404,30 @@ class DeviceHistogramBuildResult:
403
404
  row_stride_samples: int,
404
405
  stream,
405
406
  ) -> None: ...
407
+
408
+ # ---------------------
409
+ # DeviceThreeWayPartition
410
+ # ---------------------
411
+
412
+ class DeviceThreeWayPartitionBuildResult:
413
+ def __init__(
414
+ self,
415
+ d_in: Iterator,
416
+ d_first_part_out: Iterator,
417
+ d_second_part_out: Iterator,
418
+ d_unselected_out: Iterator,
419
+ d_num_selected_out: Iterator,
420
+ select_first_part_op: Op,
421
+ select_second_part_op: Op,
422
+ info: CommonData,
423
+ ): ...
424
+ def compute(
425
+ self,
426
+ d_in: Iterator,
427
+ d_first_part_out: Iterator,
428
+ d_second_part_out: Iterator,
429
+ d_unselected_out: Iterator,
430
+ d_num_selected_out: Iterator,
431
+ num_items: int,
432
+ stream,
433
+ ) -> int: ...
@@ -1982,3 +1982,143 @@ cdef class DeviceHistogramBuildResult:
1982
1982
  <const char*>self.build_data.cubin,
1983
1983
  self.build_data.cubin_size
1984
1984
  )
1985
+
1986
+
1987
+ # ----------------------------------
1988
+ # DeviceThreeWayPartitionBuildResult
1989
+ # ----------------------------------
1990
+ cdef extern from "cccl/c/three_way_partition.h":
1991
+ cdef struct cccl_device_three_way_partition_build_result_t 'cccl_device_three_way_partition_build_result_t':
1992
+ const char* cubin
1993
+ size_t cubin_size
1994
+
1995
+ cdef CUresult cccl_device_three_way_partition_build(
1996
+ cccl_device_three_way_partition_build_result_t *build_ptr,
1997
+ cccl_iterator_t d_in,
1998
+ cccl_iterator_t d_first_part_out,
1999
+ cccl_iterator_t d_second_part_out,
2000
+ cccl_iterator_t d_unselected_out,
2001
+ cccl_iterator_t d_num_selected_out,
2002
+ cccl_op_t select_first_part_op,
2003
+ cccl_op_t select_second_part_op,
2004
+ int, int, const char *, const char *, const char *, const char *
2005
+ ) nogil
2006
+
2007
+ CUresult cccl_device_three_way_partition(
2008
+ cccl_device_three_way_partition_build_result_t build,
2009
+ void* d_temp_storage,
2010
+ size_t* temp_storage_bytes,
2011
+ cccl_iterator_t d_in,
2012
+ cccl_iterator_t d_first_part_out,
2013
+ cccl_iterator_t d_second_part_out,
2014
+ cccl_iterator_t d_unselected_out,
2015
+ cccl_iterator_t d_num_selected_out,
2016
+ cccl_op_t select_first_part_op,
2017
+ cccl_op_t select_second_part_op,
2018
+ int64_t num_items,
2019
+ CUstream stream
2020
+ ) nogil
2021
+
2022
+ cdef CUresult cccl_device_three_way_partition_cleanup(
2023
+ cccl_device_three_way_partition_build_result_t *build_ptr
2024
+ ) nogil
2025
+
2026
+
2027
+ cdef class DeviceThreeWayPartitionBuildResult:
2028
+ cdef cccl_device_three_way_partition_build_result_t build_data
2029
+
2030
+ def __dealloc__(DeviceThreeWayPartitionBuildResult self):
2031
+ cdef CUresult status = -1
2032
+ with nogil:
2033
+ status = cccl_device_three_way_partition_cleanup(&self.build_data)
2034
+ if (status != 0):
2035
+ print(f"Return code {status} encountered during three_way_partition result cleanup")
2036
+
2037
+
2038
+ def __cinit__(
2039
+ DeviceThreeWayPartitionBuildResult self,
2040
+ Iterator d_in,
2041
+ Iterator d_first_part_out,
2042
+ Iterator d_second_part_out,
2043
+ Iterator d_unselected_out,
2044
+ Iterator d_num_selected_out,
2045
+ Op select_first_part_op,
2046
+ Op select_second_part_op,
2047
+ CommonData common_data
2048
+ ):
2049
+ cdef CUresult status = -1
2050
+ cdef int cc_major = common_data.get_cc_major()
2051
+ cdef int cc_minor = common_data.get_cc_minor()
2052
+ cdef const char *cub_path = common_data.cub_path_get_c_str()
2053
+ cdef const char *thrust_path = common_data.thrust_path_get_c_str()
2054
+ cdef const char *libcudacxx_path = common_data.libcudacxx_path_get_c_str()
2055
+ cdef const char *ctk_path = common_data.ctk_path_get_c_str()
2056
+
2057
+ memset(&self.build_data, 0, sizeof(cccl_device_three_way_partition_build_result_t))
2058
+ with nogil:
2059
+ status = cccl_device_three_way_partition_build(
2060
+ &self.build_data,
2061
+ d_in.iter_data,
2062
+ d_first_part_out.iter_data,
2063
+ d_second_part_out.iter_data,
2064
+ d_unselected_out.iter_data,
2065
+ d_num_selected_out.iter_data,
2066
+ select_first_part_op.op_data,
2067
+ select_second_part_op.op_data,
2068
+ cc_major,
2069
+ cc_minor,
2070
+ cub_path,
2071
+ thrust_path,
2072
+ libcudacxx_path,
2073
+ ctk_path,
2074
+ )
2075
+ if status != 0:
2076
+ raise RuntimeError(
2077
+ f"Failed building three_way_partition, error code: {status}"
2078
+ )
2079
+
2080
+ cpdef int compute(
2081
+ DeviceThreeWayPartitionBuildResult self,
2082
+ temp_storage_ptr,
2083
+ temp_storage_bytes,
2084
+ Iterator d_in,
2085
+ Iterator d_first_part_out,
2086
+ Iterator d_second_part_out,
2087
+ Iterator d_unselected_out,
2088
+ Iterator d_num_selected_out,
2089
+ Op select_first_part_op,
2090
+ Op select_second_part_op,
2091
+ size_t num_items,
2092
+ stream
2093
+ ):
2094
+ cdef CUresult status = -1
2095
+ cdef void *storage_ptr = (<void *><uintptr_t>temp_storage_ptr) if temp_storage_ptr else NULL
2096
+ cdef size_t storage_sz = <size_t>temp_storage_bytes
2097
+ cdef CUstream c_stream = <CUstream><uintptr_t>(stream) if stream else NULL
2098
+
2099
+ with nogil:
2100
+ status = cccl_device_three_way_partition(
2101
+ self.build_data,
2102
+ storage_ptr,
2103
+ &storage_sz,
2104
+ d_in.iter_data,
2105
+ d_first_part_out.iter_data,
2106
+ d_second_part_out.iter_data,
2107
+ d_unselected_out.iter_data,
2108
+ d_num_selected_out.iter_data,
2109
+ select_first_part_op.op_data,
2110
+ select_second_part_op.op_data,
2111
+ <uint64_t>num_items,
2112
+ c_stream
2113
+ )
2114
+ if status != 0:
2115
+ raise RuntimeError(
2116
+ f"Failed executing three_way_partition, error code: {status}"
2117
+ )
2118
+ return storage_sz
2119
+
2120
+ def _get_cubin(self):
2121
+ return PyBytes_FromStringAndSize(
2122
+ <const char*>self.build_data.cubin,
2123
+ self.build_data.cubin_size
2124
+ )
@@ -18,6 +18,8 @@ from ._scan import make_exclusive_scan as make_exclusive_scan
18
18
  from ._scan import make_inclusive_scan as make_inclusive_scan
19
19
  from ._segmented_reduce import make_segmented_reduce as make_segmented_reduce
20
20
  from ._segmented_reduce import segmented_reduce
21
+ from ._three_way_partition import make_three_way_partition as make_three_way_partition
22
+ from ._three_way_partition import three_way_partition as three_way_partition
21
23
  from ._transform import binary_transform, unary_transform
22
24
  from ._transform import make_binary_transform as make_binary_transform
23
25
  from ._transform import make_unary_transform as make_unary_transform
@@ -45,6 +47,8 @@ __all__ = [
45
47
  "make_segmented_reduce",
46
48
  "unique_by_key",
47
49
  "make_unique_by_key",
50
+ "three_way_partition",
51
+ "make_three_way_partition",
48
52
  "DoubleBuffer",
49
53
  "SortOrder",
50
54
  ]
@@ -3,8 +3,6 @@
3
3
  #
4
4
  # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5
5
 
6
- from __future__ import annotations # TODO: required for Python 3.7 docs env
7
-
8
6
  from typing import Callable, Union
9
7
 
10
8
  import numba
@@ -3,8 +3,6 @@
3
3
  #
4
4
  # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5
5
 
6
- from __future__ import annotations # TODO: required for Python 3.7 docs env
7
-
8
6
  from typing import Callable, Union
9
7
 
10
8
  import numba