cuda-cccl 0.3.1__cp311-cp311-manylinux_2_24_aarch64.whl → 0.3.2__cp311-cp311-manylinux_2_24_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cuda-cccl might be problematic. See the registry's advisory for this release for more details.

Files changed (185)
  1. cuda/cccl/headers/include/cub/agent/agent_histogram.cuh +354 -572
  2. cuda/cccl/headers/include/cub/block/block_adjacent_difference.cuh +6 -8
  3. cuda/cccl/headers/include/cub/block/block_discontinuity.cuh +24 -14
  4. cuda/cccl/headers/include/cub/block/block_exchange.cuh +5 -0
  5. cuda/cccl/headers/include/cub/block/block_histogram.cuh +4 -0
  6. cuda/cccl/headers/include/cub/block/block_load.cuh +4 -0
  7. cuda/cccl/headers/include/cub/block/block_radix_rank.cuh +1 -0
  8. cuda/cccl/headers/include/cub/block/block_reduce.cuh +1 -0
  9. cuda/cccl/headers/include/cub/block/block_scan.cuh +12 -2
  10. cuda/cccl/headers/include/cub/block/block_store.cuh +3 -2
  11. cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +34 -30
  12. cuda/cccl/headers/include/cub/detail/ptx-json-parser.h +1 -1
  13. cuda/cccl/headers/include/cub/device/device_for.cuh +118 -40
  14. cuda/cccl/headers/include/cub/device/device_reduce.cuh +6 -7
  15. cuda/cccl/headers/include/cub/device/device_segmented_reduce.cuh +12 -13
  16. cuda/cccl/headers/include/cub/device/device_transform.cuh +122 -91
  17. cuda/cccl/headers/include/cub/device/dispatch/dispatch_merge.cuh +2 -3
  18. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce.cuh +4 -3
  19. cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_deterministic.cuh +1 -1
  20. cuda/cccl/headers/include/cub/device/dispatch/dispatch_streaming_reduce.cuh +4 -5
  21. cuda/cccl/headers/include/cub/device/dispatch/dispatch_streaming_reduce_by_key.cuh +0 -1
  22. cuda/cccl/headers/include/cub/device/dispatch/dispatch_topk.cuh +3 -5
  23. cuda/cccl/headers/include/cub/device/dispatch/dispatch_transform.cuh +13 -5
  24. cuda/cccl/headers/include/cub/device/dispatch/kernels/for_each.cuh +72 -37
  25. cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +22 -27
  26. cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +61 -70
  27. cuda/cccl/headers/include/cub/thread/thread_reduce.cuh +24 -17
  28. cuda/cccl/headers/include/cub/warp/warp_load.cuh +6 -6
  29. cuda/cccl/headers/include/cub/warp/warp_reduce.cuh +7 -2
  30. cuda/cccl/headers/include/cub/warp/warp_scan.cuh +7 -3
  31. cuda/cccl/headers/include/cub/warp/warp_store.cuh +1 -0
  32. cuda/cccl/headers/include/cuda/__barrier/barrier_block_scope.h +19 -0
  33. cuda/cccl/headers/include/cuda/__cccl_config +1 -0
  34. cuda/cccl/headers/include/cuda/__cmath/fast_modulo_division.h +3 -74
  35. cuda/cccl/headers/include/cuda/__cmath/mul_hi.h +146 -0
  36. cuda/cccl/headers/include/cuda/__complex/get_real_imag.h +0 -4
  37. cuda/cccl/headers/include/cuda/__device/arch_id.h +176 -0
  38. cuda/cccl/headers/include/cuda/__device/arch_traits.h +239 -317
  39. cuda/cccl/headers/include/cuda/__device/attributes.h +4 -3
  40. cuda/cccl/headers/include/cuda/__device/compute_capability.h +171 -0
  41. cuda/cccl/headers/include/cuda/__device/device_ref.h +0 -10
  42. cuda/cccl/headers/include/cuda/__device/physical_device.h +1 -26
  43. cuda/cccl/headers/include/cuda/__event/event.h +26 -26
  44. cuda/cccl/headers/include/cuda/__event/event_ref.h +5 -5
  45. cuda/cccl/headers/include/cuda/__event/timed_event.h +9 -7
  46. cuda/cccl/headers/include/cuda/__fwd/devices.h +4 -4
  47. cuda/cccl/headers/include/cuda/__iterator/constant_iterator.h +46 -31
  48. cuda/cccl/headers/include/cuda/__iterator/strided_iterator.h +79 -47
  49. cuda/cccl/headers/include/cuda/__iterator/tabulate_output_iterator.h +59 -36
  50. cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +79 -49
  51. cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +74 -48
  52. cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +80 -55
  53. cuda/cccl/headers/include/cuda/__iterator/zip_common.h +2 -12
  54. cuda/cccl/headers/include/cuda/__iterator/zip_iterator.h +15 -19
  55. cuda/cccl/headers/include/cuda/__iterator/zip_transform_iterator.h +59 -60
  56. cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +127 -60
  57. cuda/cccl/headers/include/cuda/__mdspan/host_device_mdspan.h +178 -3
  58. cuda/cccl/headers/include/cuda/__mdspan/restrict_accessor.h +38 -8
  59. cuda/cccl/headers/include/cuda/__mdspan/restrict_mdspan.h +67 -1
  60. cuda/cccl/headers/include/cuda/__memory/ptr_in_range.h +93 -0
  61. cuda/cccl/headers/include/cuda/__memory_resource/get_memory_resource.h +4 -4
  62. cuda/cccl/headers/include/cuda/__memory_resource/properties.h +44 -0
  63. cuda/cccl/headers/include/cuda/__memory_resource/resource.h +1 -1
  64. cuda/cccl/headers/include/cuda/__memory_resource/resource_ref.h +4 -6
  65. cuda/cccl/headers/include/cuda/__nvtx/nvtx3.h +2 -1
  66. cuda/cccl/headers/include/cuda/__runtime/ensure_current_context.h +5 -4
  67. cuda/cccl/headers/include/cuda/__stream/stream.h +8 -8
  68. cuda/cccl/headers/include/cuda/__stream/stream_ref.h +17 -16
  69. cuda/cccl/headers/include/cuda/__utility/in_range.h +65 -0
  70. cuda/cccl/headers/include/cuda/cmath +1 -0
  71. cuda/cccl/headers/include/cuda/devices +3 -0
  72. cuda/cccl/headers/include/cuda/memory +1 -0
  73. cuda/cccl/headers/include/cuda/std/__algorithm/equal_range.h +2 -2
  74. cuda/cccl/headers/include/cuda/std/__algorithm/find.h +1 -1
  75. cuda/cccl/headers/include/cuda/std/__algorithm/includes.h +2 -4
  76. cuda/cccl/headers/include/cuda/std/__algorithm/lower_bound.h +1 -1
  77. cuda/cccl/headers/include/cuda/std/__algorithm/make_projected.h +7 -15
  78. cuda/cccl/headers/include/cuda/std/__algorithm/min_element.h +1 -1
  79. cuda/cccl/headers/include/cuda/std/__algorithm/minmax_element.h +1 -2
  80. cuda/cccl/headers/include/cuda/std/__algorithm/partial_sort_copy.h +2 -2
  81. cuda/cccl/headers/include/cuda/std/__algorithm/upper_bound.h +1 -1
  82. cuda/cccl/headers/include/cuda/std/__cccl/algorithm_wrapper.h +36 -0
  83. cuda/cccl/headers/include/cuda/std/__cccl/builtin.h +46 -49
  84. cuda/cccl/headers/include/cuda/std/__cccl/execution_space.h +6 -0
  85. cuda/cccl/headers/include/cuda/std/__cccl/host_std_lib.h +52 -0
  86. cuda/cccl/headers/include/cuda/std/__cccl/memory_wrapper.h +36 -0
  87. cuda/cccl/headers/include/cuda/std/__cccl/numeric_wrapper.h +36 -0
  88. cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +3 -2
  89. cuda/cccl/headers/include/cuda/std/__complex/complex.h +3 -2
  90. cuda/cccl/headers/include/cuda/std/__complex/literals.h +14 -34
  91. cuda/cccl/headers/include/cuda/std/__complex/nvbf16.h +2 -1
  92. cuda/cccl/headers/include/cuda/std/__complex/nvfp16.h +4 -3
  93. cuda/cccl/headers/include/cuda/std/__concepts/invocable.h +2 -2
  94. cuda/cccl/headers/include/cuda/std/__cstdlib/malloc.h +3 -2
  95. cuda/cccl/headers/include/cuda/std/__functional/bind.h +10 -13
  96. cuda/cccl/headers/include/cuda/std/__functional/function.h +5 -8
  97. cuda/cccl/headers/include/cuda/std/__functional/invoke.h +71 -335
  98. cuda/cccl/headers/include/cuda/std/__functional/mem_fn.h +1 -2
  99. cuda/cccl/headers/include/cuda/std/__functional/reference_wrapper.h +3 -3
  100. cuda/cccl/headers/include/cuda/std/__functional/weak_result_type.h +0 -6
  101. cuda/cccl/headers/include/cuda/std/__fwd/allocator.h +13 -0
  102. cuda/cccl/headers/include/cuda/std/__fwd/char_traits.h +13 -0
  103. cuda/cccl/headers/include/cuda/std/__fwd/complex.h +13 -4
  104. cuda/cccl/headers/include/cuda/std/__fwd/mdspan.h +23 -0
  105. cuda/cccl/headers/include/cuda/std/__fwd/pair.h +13 -0
  106. cuda/cccl/headers/include/cuda/std/__fwd/string.h +22 -0
  107. cuda/cccl/headers/include/cuda/std/__fwd/string_view.h +14 -0
  108. cuda/cccl/headers/include/cuda/std/__internal/features.h +0 -5
  109. cuda/cccl/headers/include/cuda/std/__internal/namespaces.h +21 -0
  110. cuda/cccl/headers/include/cuda/std/__iterator/iterator_traits.h +5 -5
  111. cuda/cccl/headers/include/cuda/std/__mdspan/extents.h +7 -1
  112. cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +53 -39
  113. cuda/cccl/headers/include/cuda/std/__memory/allocator.h +3 -3
  114. cuda/cccl/headers/include/cuda/std/__memory/construct_at.h +1 -3
  115. cuda/cccl/headers/include/cuda/std/__optional/optional_base.h +1 -0
  116. cuda/cccl/headers/include/cuda/std/__ranges/compressed_movable_box.h +892 -0
  117. cuda/cccl/headers/include/cuda/std/__ranges/movable_box.h +2 -2
  118. cuda/cccl/headers/include/cuda/std/__type_traits/is_primary_template.h +7 -5
  119. cuda/cccl/headers/include/cuda/std/__type_traits/result_of.h +1 -1
  120. cuda/cccl/headers/include/cuda/std/__utility/pair.h +0 -5
  121. cuda/cccl/headers/include/cuda/std/bitset +1 -1
  122. cuda/cccl/headers/include/cuda/std/detail/libcxx/include/__config +15 -12
  123. cuda/cccl/headers/include/cuda/std/detail/libcxx/include/variant +11 -9
  124. cuda/cccl/headers/include/cuda/std/inplace_vector +4 -4
  125. cuda/cccl/headers/include/cuda/std/numbers +5 -0
  126. cuda/cccl/headers/include/cuda/std/string_view +146 -11
  127. cuda/cccl/headers/include/cuda/stream_ref +5 -0
  128. cuda/cccl/headers/include/cuda/utility +1 -0
  129. cuda/cccl/headers/include/nv/target +7 -2
  130. cuda/cccl/headers/include/thrust/allocate_unique.h +1 -1
  131. cuda/cccl/headers/include/thrust/detail/allocator/allocator_traits.h +309 -33
  132. cuda/cccl/headers/include/thrust/detail/allocator/copy_construct_range.h +151 -4
  133. cuda/cccl/headers/include/thrust/detail/allocator/destroy_range.h +60 -3
  134. cuda/cccl/headers/include/thrust/detail/allocator/fill_construct_range.h +45 -3
  135. cuda/cccl/headers/include/thrust/detail/allocator/malloc_allocator.h +31 -6
  136. cuda/cccl/headers/include/thrust/detail/allocator/tagged_allocator.h +29 -16
  137. cuda/cccl/headers/include/thrust/detail/allocator/temporary_allocator.h +41 -4
  138. cuda/cccl/headers/include/thrust/detail/allocator/value_initialize_range.h +42 -4
  139. cuda/cccl/headers/include/thrust/detail/complex/ccosh.h +3 -3
  140. cuda/cccl/headers/include/thrust/detail/internal_functional.h +1 -1
  141. cuda/cccl/headers/include/thrust/detail/memory_algorithms.h +1 -1
  142. cuda/cccl/headers/include/thrust/detail/temporary_array.h +1 -1
  143. cuda/cccl/headers/include/thrust/detail/type_traits.h +1 -1
  144. cuda/cccl/headers/include/thrust/device_delete.h +18 -3
  145. cuda/cccl/headers/include/thrust/device_free.h +16 -3
  146. cuda/cccl/headers/include/thrust/device_new.h +29 -8
  147. cuda/cccl/headers/include/thrust/host_vector.h +1 -1
  148. cuda/cccl/headers/include/thrust/iterator/tabulate_output_iterator.h +5 -2
  149. cuda/cccl/headers/include/thrust/mr/disjoint_pool.h +1 -1
  150. cuda/cccl/headers/include/thrust/mr/pool.h +1 -1
  151. cuda/cccl/headers/include/thrust/system/cuda/detail/find.h +13 -115
  152. cuda/cccl/headers/include/thrust/system/cuda/detail/mismatch.h +8 -2
  153. cuda/cccl/headers/include/thrust/type_traits/is_contiguous_iterator.h +7 -7
  154. cuda/compute/__init__.py +2 -0
  155. cuda/compute/_bindings.pyi +43 -1
  156. cuda/compute/_bindings_impl.pyx +156 -7
  157. cuda/compute/algorithms/_scan.py +108 -36
  158. cuda/compute/algorithms/_transform.py +32 -11
  159. cuda/compute/cu12/_bindings_impl.cpython-311-aarch64-linux-gnu.so +0 -0
  160. cuda/compute/cu12/cccl/libcccl.c.parallel.so +0 -0
  161. cuda/compute/cu13/_bindings_impl.cpython-311-aarch64-linux-gnu.so +0 -0
  162. cuda/compute/cu13/cccl/libcccl.c.parallel.so +0 -0
  163. cuda/compute/iterators/__init__.py +2 -0
  164. cuda/compute/iterators/_factories.py +28 -0
  165. cuda/compute/iterators/_iterators.py +206 -1
  166. cuda/compute/numba_utils.py +2 -2
  167. cuda/compute/typing.py +2 -0
  168. {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/METADATA +1 -1
  169. {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/RECORD +171 -175
  170. cuda/cccl/headers/include/thrust/detail/algorithm_wrapper.h +0 -37
  171. cuda/cccl/headers/include/thrust/detail/allocator/allocator_traits.inl +0 -371
  172. cuda/cccl/headers/include/thrust/detail/allocator/copy_construct_range.inl +0 -242
  173. cuda/cccl/headers/include/thrust/detail/allocator/destroy_range.inl +0 -137
  174. cuda/cccl/headers/include/thrust/detail/allocator/fill_construct_range.inl +0 -99
  175. cuda/cccl/headers/include/thrust/detail/allocator/malloc_allocator.inl +0 -68
  176. cuda/cccl/headers/include/thrust/detail/allocator/tagged_allocator.inl +0 -86
  177. cuda/cccl/headers/include/thrust/detail/allocator/temporary_allocator.inl +0 -79
  178. cuda/cccl/headers/include/thrust/detail/allocator/value_initialize_range.inl +0 -98
  179. cuda/cccl/headers/include/thrust/detail/device_delete.inl +0 -52
  180. cuda/cccl/headers/include/thrust/detail/device_free.inl +0 -47
  181. cuda/cccl/headers/include/thrust/detail/device_new.inl +0 -61
  182. cuda/cccl/headers/include/thrust/detail/memory_wrapper.h +0 -40
  183. cuda/cccl/headers/include/thrust/detail/numeric_wrapper.h +0 -37
  184. {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/WHEEL +0 -0
  185. {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,7 @@
3
3
  #
4
4
  # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5
5
 
6
- from typing import Callable, Union
6
+ from typing import Callable, Union, cast
7
7
 
8
8
  import numba
9
9
  import numpy as np
@@ -20,14 +20,27 @@ from ..op import OpKind
20
20
  from ..typing import DeviceArrayLike, GpuStruct
21
21
 
22
22
 
23
+ def get_init_kind(
24
+ init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
25
+ ) -> _bindings.InitKind:
26
+ match init_value:
27
+ case None:
28
+ return _bindings.InitKind.NO_INIT
29
+ case _ if isinstance(init_value, DeviceArrayLike):
30
+ return _bindings.InitKind.FUTURE_VALUE_INIT
31
+ case _:
32
+ return _bindings.InitKind.VALUE_INIT
33
+
34
+
23
35
  class _Scan:
24
36
  __slots__ = [
25
37
  "build_result",
26
38
  "d_in_cccl",
27
39
  "d_out_cccl",
28
- "h_init_cccl",
40
+ "init_value_cccl",
29
41
  "op_wrapper",
30
42
  "device_scan_fn",
43
+ "init_kind",
31
44
  ]
32
45
 
33
46
  # TODO: constructor shouldn't require concrete `d_in`, `d_out`:
@@ -36,36 +49,74 @@ class _Scan:
36
49
  d_in: DeviceArrayLike | IteratorBase,
37
50
  d_out: DeviceArrayLike | IteratorBase,
38
51
  op: Callable | OpKind,
39
- h_init: np.ndarray | GpuStruct,
52
+ init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
40
53
  force_inclusive: bool,
41
54
  ):
42
55
  self.d_in_cccl = cccl.to_cccl_input_iter(d_in)
43
56
  self.d_out_cccl = cccl.to_cccl_output_iter(d_out)
44
- self.h_init_cccl = cccl.to_cccl_value(h_init)
45
- if isinstance(h_init, np.ndarray):
46
- value_type = numba.from_dtype(h_init.dtype)
47
- else:
48
- value_type = numba.typeof(h_init)
57
+
58
+ self.init_kind = get_init_kind(init_value)
59
+
60
+ self.init_value_cccl: _bindings.Iterator | _bindings.Value | None
61
+
62
+ match self.init_kind:
63
+ case _bindings.InitKind.NO_INIT:
64
+ # TODO: we just need to extract the dtype from the input iterator
65
+ if not isinstance(d_in, DeviceArrayLike):
66
+ raise ValueError(
67
+ "No init value not supported for non-DeviceArrayLike input"
68
+ )
69
+
70
+ self.init_value_cccl = None
71
+ value_type = numba.from_dtype(protocols.get_dtype(d_in))
72
+ init_value_type_info = self.d_in_cccl.value_type
73
+
74
+ case _bindings.InitKind.FUTURE_VALUE_INIT:
75
+ self.init_value_cccl = cccl.to_cccl_input_iter(init_value)
76
+ value_type = numba.from_dtype(
77
+ protocols.get_dtype(cast(DeviceArrayLike, init_value))
78
+ )
79
+ init_value_type_info = self.init_value_cccl.value_type
80
+
81
+ case _bindings.InitKind.VALUE_INIT:
82
+ self.init_value_cccl = cccl.to_cccl_value(init_value)
83
+ value_type = (
84
+ numba.from_dtype(init_value.dtype)
85
+ if isinstance(init_value, np.ndarray)
86
+ else numba.typeof(init_value)
87
+ )
88
+ init_value_type_info = self.init_value_cccl.type
49
89
 
50
90
  # For well-known operations, we don't need a signature
51
91
  if isinstance(op, OpKind):
52
92
  self.op_wrapper = cccl.to_cccl_op(op, None)
53
93
  else:
54
94
  self.op_wrapper = cccl.to_cccl_op(op, value_type(value_type, value_type))
95
+
55
96
  self.build_result = call_build(
56
97
  _bindings.DeviceScanBuildResult,
57
98
  self.d_in_cccl,
58
99
  self.d_out_cccl,
59
100
  self.op_wrapper,
60
- self.h_init_cccl,
101
+ init_value_type_info,
61
102
  force_inclusive,
103
+ self.init_kind,
62
104
  )
63
105
 
64
- self.device_scan_fn = (
65
- self.build_result.compute_inclusive
66
- if force_inclusive
67
- else self.build_result.compute_exclusive
68
- )
106
+ match (force_inclusive, self.init_kind):
107
+ case (True, _bindings.InitKind.FUTURE_VALUE_INIT):
108
+ self.device_scan_fn = self.build_result.compute_inclusive_future_value
109
+ case (True, _bindings.InitKind.VALUE_INIT):
110
+ self.device_scan_fn = self.build_result.compute_inclusive
111
+ case (True, _bindings.InitKind.NO_INIT):
112
+ self.device_scan_fn = self.build_result.compute_inclusive_no_init
113
+
114
+ case (False, _bindings.InitKind.FUTURE_VALUE_INIT):
115
+ self.device_scan_fn = self.build_result.compute_exclusive_future_value
116
+ case (False, _bindings.InitKind.VALUE_INIT):
117
+ self.device_scan_fn = self.build_result.compute_exclusive
118
+ case (False, _bindings.InitKind.NO_INIT):
119
+ raise ValueError("Exclusive scan with No init value is not supported")
69
120
 
70
121
  def __call__(
71
122
  self,
@@ -73,13 +124,25 @@ class _Scan:
73
124
  d_in,
74
125
  d_out,
75
126
  num_items: int,
76
- h_init: np.ndarray | GpuStruct,
127
+ init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
77
128
  stream=None,
78
129
  ):
79
130
  set_cccl_iterator_state(self.d_in_cccl, d_in)
80
131
  set_cccl_iterator_state(self.d_out_cccl, d_out)
81
132
 
82
- self.h_init_cccl.state = to_cccl_value_state(h_init)
133
+ match self.init_kind:
134
+ case _bindings.InitKind.FUTURE_VALUE_INIT:
135
+ # We know that the init_value_cccl is an Iterator, so this cast
136
+ # tells MyPy what the actual type is. cast() is a no-op at runtime,
137
+ # which makes it better than isinstance() since this is a hot path
138
+ # and we have to minimize the work we do prior to calling the
139
+ # kernel.
140
+ self.init_value_cccl = cast(_bindings.Iterator, self.init_value_cccl)
141
+ set_cccl_iterator_state(self.init_value_cccl, init_value)
142
+
143
+ case _bindings.InitKind.VALUE_INIT:
144
+ self.init_value_cccl = cast(_bindings.Value, self.init_value_cccl)
145
+ self.init_value_cccl.state = to_cccl_value_state(init_value)
83
146
 
84
147
  stream_handle = validate_and_get_stream(stream)
85
148
 
@@ -97,7 +160,7 @@ class _Scan:
97
160
  self.d_out_cccl,
98
161
  num_items,
99
162
  self.op_wrapper,
100
- self.h_init_cccl,
163
+ self.init_value_cccl,
101
164
  stream_handle,
102
165
  )
103
166
  return temp_storage_bytes
@@ -107,7 +170,7 @@ def make_cache_key(
107
170
  d_in: DeviceArrayLike | IteratorBase,
108
171
  d_out: DeviceArrayLike | IteratorBase,
109
172
  op: Callable | OpKind,
110
- h_init: np.ndarray,
173
+ init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
111
174
  ):
112
175
  d_in_key = (
113
176
  d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in)
@@ -123,8 +186,17 @@ def make_cache_key(
123
186
  else:
124
187
  op_key = CachableFunction(op)
125
188
 
126
- h_init_key = h_init.dtype
127
- return (d_in_key, d_out_key, op_key, h_init_key)
189
+ init_kind_key = get_init_kind(init_value)
190
+ match init_kind_key:
191
+ case _bindings.InitKind.NO_INIT:
192
+ init_value_key = None
193
+ case _bindings.InitKind.FUTURE_VALUE_INIT:
194
+ init_value_key = protocols.get_dtype(cast(DeviceArrayLike, init_value))
195
+ case _bindings.InitKind.VALUE_INIT:
196
+ init_value = cast(np.ndarray | GpuStruct, init_value)
197
+ init_value_key = init_value.dtype
198
+
199
+ return (d_in_key, d_out_key, op_key, init_value_key, init_kind_key)
128
200
 
129
201
 
130
202
  # TODO Figure out `sum` without operator and initial value
@@ -134,7 +206,7 @@ def make_exclusive_scan(
134
206
  d_in: DeviceArrayLike | IteratorBase,
135
207
  d_out: DeviceArrayLike | IteratorBase,
136
208
  op: Callable | OpKind,
137
- h_init: np.ndarray,
209
+ init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
138
210
  ):
139
211
  """Computes a device-wide scan using the specified binary ``op`` and initial value ``init``.
140
212
 
@@ -150,19 +222,19 @@ def make_exclusive_scan(
150
222
  d_in: Device array or iterator containing the input sequence of data items
151
223
  d_out: Device array that will store the result of the scan
152
224
  op: Callable or OpKind representing the binary operator to apply
153
- init: Numpy array storing initial value of the scan
225
+ init_value: Numpy array, device array, or GPU struct storing initial value of the scan, or None for no initial value
154
226
 
155
227
  Returns:
156
228
  A callable object that can be used to perform the scan
157
229
  """
158
- return _Scan(d_in, d_out, op, h_init, False)
230
+ return _Scan(d_in, d_out, op, init_value, False)
159
231
 
160
232
 
161
233
  def exclusive_scan(
162
234
  d_in: DeviceArrayLike | IteratorBase,
163
235
  d_out: DeviceArrayLike | IteratorBase,
164
236
  op: Callable | OpKind,
165
- h_init: np.ndarray | GpuStruct,
237
+ init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
166
238
  num_items: int,
167
239
  stream=None,
168
240
  ):
@@ -183,14 +255,14 @@ def exclusive_scan(
183
255
  d_in: Device array or iterator containing the input sequence of data items
184
256
  d_out: Device array or iterator to store the result of the scan
185
257
  op: Binary scan operator
186
- h_init: Initial value for the scan
258
+ init_value: Initial value for the scan
187
259
  num_items: Number of items to scan
188
260
  stream: CUDA stream for the operation (optional)
189
261
  """
190
- scanner = make_exclusive_scan(d_in, d_out, op, h_init)
191
- tmp_storage_bytes = scanner(None, d_in, d_out, num_items, h_init, stream)
262
+ scanner = make_exclusive_scan(d_in, d_out, op, init_value)
263
+ tmp_storage_bytes = scanner(None, d_in, d_out, num_items, init_value, stream)
192
264
  tmp_storage = TempStorageBuffer(tmp_storage_bytes, stream)
193
- scanner(tmp_storage, d_in, d_out, num_items, h_init, stream)
265
+ scanner(tmp_storage, d_in, d_out, num_items, init_value, stream)
194
266
 
195
267
 
196
268
  # TODO Figure out `sum` without operator and initial value
@@ -200,7 +272,7 @@ def make_inclusive_scan(
200
272
  d_in: DeviceArrayLike | IteratorBase,
201
273
  d_out: DeviceArrayLike | IteratorBase,
202
274
  op: Callable | OpKind,
203
- h_init: np.ndarray,
275
+ init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
204
276
  ):
205
277
  """Computes a device-wide scan using the specified binary ``op`` and initial value ``init``.
206
278
 
@@ -216,19 +288,19 @@ def make_inclusive_scan(
216
288
  d_in: Device array or iterator containing the input sequence of data items
217
289
  d_out: Device array that will store the result of the scan
218
290
  op: Callable or OpKind representing the binary operator to apply
219
- init: Numpy array storing initial value of the scan
291
+ init_value: Numpy array, device array, or GPU struct storing initial value of the scan, or None for no initial value
220
292
 
221
293
  Returns:
222
294
  A callable object that can be used to perform the scan
223
295
  """
224
- return _Scan(d_in, d_out, op, h_init, True)
296
+ return _Scan(d_in, d_out, op, init_value, True)
225
297
 
226
298
 
227
299
  def inclusive_scan(
228
300
  d_in: DeviceArrayLike | IteratorBase,
229
301
  d_out: DeviceArrayLike | IteratorBase,
230
302
  op: Callable | OpKind,
231
- h_init: np.ndarray | GpuStruct,
303
+ init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
232
304
  num_items: int,
233
305
  stream=None,
234
306
  ):
@@ -249,11 +321,11 @@ def inclusive_scan(
249
321
  d_in: Device array or iterator containing the input sequence of data items
250
322
  d_out: Device array or iterator to store the result of the scan
251
323
  op: Binary scan operator
252
- h_init: Initial value for the scan
324
+ init_value: Initial value for the scan
253
325
  num_items: Number of items to scan
254
326
  stream: CUDA stream for the operation (optional)
255
327
  """
256
- scanner = make_inclusive_scan(d_in, d_out, op, h_init)
257
- tmp_storage_bytes = scanner(None, d_in, d_out, num_items, h_init, stream)
328
+ scanner = make_inclusive_scan(d_in, d_out, op, init_value)
329
+ tmp_storage_bytes = scanner(None, d_in, d_out, num_items, init_value, stream)
258
330
  tmp_storage = TempStorageBuffer(tmp_storage_bytes, stream)
259
- scanner(tmp_storage, d_in, d_out, num_items, h_init, stream)
331
+ scanner(tmp_storage, d_in, d_out, num_items, init_value, stream)
@@ -11,7 +11,7 @@ from .._caching import CachableFunction, cache_with_key
11
11
  from .._cccl_interop import set_cccl_iterator_state
12
12
  from .._utils import protocols
13
13
  from ..iterators._iterators import IteratorBase
14
- from ..numba_utils import get_inferred_return_type
14
+ from ..numba_utils import get_inferred_return_type, signature_from_annotations
15
15
  from ..op import OpKind
16
16
  from ..typing import DeviceArrayLike
17
17
 
@@ -32,16 +32,20 @@ class _UnaryTransform:
32
32
  ):
33
33
  self.d_in_cccl = cccl.to_cccl_input_iter(d_in)
34
34
  self.d_out_cccl = cccl.to_cccl_output_iter(d_out)
35
- in_value_type = cccl.get_value_type(d_in)
36
- out_value_type = cccl.get_value_type(d_out)
37
35
 
38
36
  # For well-known operations, we don't need a signature
39
37
  if isinstance(op, OpKind):
40
38
  self.op_wrapper = cccl.to_cccl_op(op, None)
41
39
  else:
42
- if not out_value_type.is_internal:
43
- out_value_type = get_inferred_return_type(op, (in_value_type,))
44
- sig = out_value_type(in_value_type)
40
+ try:
41
+ sig = signature_from_annotations(op)
42
+ except ValueError:
43
+ in_value_type = cccl.get_value_type(d_in)
44
+ out_value_type = cccl.get_value_type(d_out)
45
+ if not out_value_type.is_internal:
46
+ out_value_type = get_inferred_return_type(op, (in_value_type,))
47
+ sig = out_value_type(in_value_type)
48
+
45
49
  self.op_wrapper = cccl.to_cccl_op(op, sig=sig)
46
50
  self.build_result = cccl.call_build(
47
51
  _bindings.DeviceUnaryTransform,
@@ -97,11 +101,14 @@ class _BinaryTransform:
97
101
  if isinstance(op, OpKind):
98
102
  self.op_wrapper = cccl.to_cccl_op(op, None)
99
103
  else:
100
- if not out_value_type.is_internal:
101
- out_value_type = get_inferred_return_type(
102
- op, (in1_value_type, in2_value_type)
103
- )
104
- sig = out_value_type(in1_value_type, in2_value_type)
104
+ try:
105
+ sig = signature_from_annotations(op)
106
+ except ValueError:
107
+ if not out_value_type.is_internal:
108
+ out_value_type = get_inferred_return_type(
109
+ op, (in1_value_type, in2_value_type)
110
+ )
111
+ sig = out_value_type(in1_value_type, in2_value_type)
105
112
  self.op_wrapper = cccl.to_cccl_op(op, sig=sig)
106
113
  self.build_result = cccl.call_build(
107
114
  _bindings.DeviceBinaryTransform,
@@ -263,6 +270,13 @@ def unary_transform(
263
270
  :language: python
264
271
  :start-after: # example-begin
265
272
 
273
+ When working with custom struct types, you need to provide type annotations
274
+ to help with type inference. See the binary transform struct example for reference:
275
+
276
+ .. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/transform/binary_transform_struct.py
277
+ :language: python
278
+ :start-after: # example-begin
279
+
266
280
 
267
281
  Args:
268
282
  d_in: Device array or iterator containing the input sequence of data items.
@@ -295,6 +309,13 @@ def binary_transform(
295
309
  :language: python
296
310
  :start-after: # example-begin
297
311
 
312
+ When working with custom struct types, you need to provide type annotations
313
+ to help with type inference. See the following example:
314
+
315
+ .. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/transform/binary_transform_struct.py
316
+ :language: python
317
+ :start-after: # example-begin
318
+
298
319
 
299
320
  Args:
300
321
  d_in1: Device array or iterator containing the first input sequence of data items.
@@ -2,6 +2,7 @@ from ._factories import (
2
2
  CacheModifiedInputIterator,
3
3
  ConstantIterator,
4
4
  CountingIterator,
5
+ PermutationIterator,
5
6
  ReverseIterator,
6
7
  TransformIterator,
7
8
  TransformOutputIterator,
@@ -12,6 +13,7 @@ __all__ = [
12
13
  "CacheModifiedInputIterator",
13
14
  "ConstantIterator",
14
15
  "CountingIterator",
16
+ "PermutationIterator",
15
17
  "ReverseIterator",
16
18
  "TransformIterator",
17
19
  "TransformOutputIterator",
@@ -10,6 +10,7 @@ from ._iterators import (
10
10
  CountingIterator as _CountingIterator,
11
11
  )
12
12
  from ._iterators import (
13
+ make_permutation_iterator,
13
14
  make_reverse_iterator,
14
15
  make_transform_iterator,
15
16
  )
@@ -165,6 +166,33 @@ def TransformOutputIterator(it, op):
165
166
  return make_transform_iterator(it, op, "output")
166
167
 
167
168
 
169
+ def PermutationIterator(values, indices):
170
+ """Returns an Iterator that accesses values through an index mapping.
171
+
172
+ Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1permutation__iterator.html
173
+
174
+ The permutation iterator accesses elements from the values collection using indices
175
+ from the indices collection, effectively computing values[indices[i]] at position i.
176
+ This is useful for gather/scatter operations and indirect array access patterns.
177
+
178
+ Example:
179
+ The code snippet below demonstrates the usage of a ``PermutationIterator``
180
+ to access values in a permuted order:
181
+
182
+ .. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/iterator/permutation_iterator_basic.py
183
+ :language: python
184
+ :start-after: # example-begin
185
+
186
+ Args:
187
+ values: The values array or iterator to be permuted
188
+ indices: An iterator or device array providing the indices for permutation
189
+
190
+ Returns:
191
+ A ``PermutationIterator`` object that yields values[indices[i]] at position i
192
+ """
193
+ return make_permutation_iterator(values, indices)
194
+
195
+
168
196
  def ZipIterator(*iterators):
169
197
  """Returns an Iterator representing a zipped sequence of values from N iterators.
170
198
 
@@ -207,7 +207,15 @@ def pointer_add_intrinsic(context, ptr, offset):
207
207
  def codegen(context, builder, sig, args):
208
208
  ptr, index = args
209
209
  base = builder.ptrtoint(ptr, ir.IntType(_DEVICE_POINTER_BITWIDTH))
210
- offset = builder.mul(index, sizeof_pointee(context, ptr))
210
+ sizeof = sizeof_pointee(context, ptr)
211
+ # Cast index to match sizeof type if needed
212
+ if index.type != sizeof.type:
213
+ index = (
214
+ builder.sext(index, sizeof.type)
215
+ if index.type.width < sizeof.type.width
216
+ else builder.trunc(index, sizeof.type)
217
+ )
218
+ offset = builder.mul(index, sizeof)
211
219
  result = builder.add(base, offset)
212
220
  return builder.inttoptr(result, ptr.type)
213
221
 
@@ -610,3 +618,200 @@ def _get_last_element_ptr(device_array) -> int:
610
618
 
611
619
  ptr = get_data_pointer(device_array)
612
620
  return ptr + offset_to_last_element
621
+
622
+
623
class PermutationIteratorKind(IteratorKind):
    """Kind tag attached to iterators built by ``make_permutation_iterator``."""
625
+
626
+
627
def make_permutation_iterator(values, indices):
    """
    Create a PermutationIterator that accesses values through an index mapping.

    The permutation iterator accesses elements from `values` using indices from
    `indices`, effectively computing values[indices[i]] at position i.
    Advancing the iterator advances only the index iterator. Each dereference
    advances a private copy of the values state by the current index.

    Args:
        values: The values array or iterator to permute
        indices: The indices array or iterator specifying the permutation

    Returns:
        PermutationIterator: Iterator that yields permuted values

    Raises:
        TypeError: If `values` or `indices` is neither a device array nor an
            iterator, or if `indices` does not yield integer values.
    """
    from numba.core import cgutils
    from numba.core.datamodel.registry import default_manager

    # Convert arrays to iterators if needed
    if hasattr(values, "__cuda_array_interface__"):
        values = pointer(values, numba.from_dtype(get_dtype(values)))
    elif not isinstance(values, IteratorBase):
        raise TypeError("values must be a device array or iterator")

    if hasattr(indices, "__cuda_array_interface__"):
        indices = pointer(indices, numba.from_dtype(get_dtype(indices)))
    elif not isinstance(indices, IteratorBase):
        raise TypeError("indices must be an iterator or device array")

    index_type = indices.value_type
    # Each dereferenced index is used as the advance distance of the values
    # iterator, so it must be an integer. Reject other types here rather than
    # surfacing an opaque error later during device compilation.
    if not isinstance(index_type, types.Integer):
        raise TypeError(
            f"indices must yield integer values, got value type {index_type}"
        )

    # JIT compile value advance/dereference methods
    value_dtype = values.value_type
    values_state_type = values.state_type
    value_advance = cuda.jit(values.advance, device=True)
    value_input_dereference = cuda.jit(values.input_dereference, device=True)

    # Iterators signal "input only" either by raising AttributeError from the
    # `output_dereference` property or by returning None; handle both.
    try:
        output_deref = values.output_dereference
    except AttributeError:
        output_deref = None
    values_is_output_iterator = output_deref is not None
    value_output_dereference = None
    if values_is_output_iterator:
        value_output_dereference = cuda.jit(output_deref, device=True)

    # JIT compile index advance/dereference methods
    index_advance = cuda.jit(indices.advance, device=True)
    index_input_dereference = cuda.jit(indices.input_dereference, device=True)

    # The cvalue and state for PermutationIterator are
    # structs composed of the cvalues and states of the
    # value and index iterators.
    from ..struct import gpu_struct_from_numba_types

    class PermutationCValueStruct(ctypes.Structure):
        _fields_ = [
            ("value_state", values.cvalue.__class__),
            ("index_state", indices.cvalue.__class__),
        ]

    PermutationState = gpu_struct_from_numba_types(
        "PermutationState",
        ("value_state", "index_state"),
        (values_state_type, indices.state_type),
    )

    cvalue = PermutationCValueStruct(values.cvalue, indices.cvalue)
    state_type = PermutationState._numba_type
    value_type = value_dtype

    # Member positions inside PermutationState. They must match the field
    # order passed to gpu_struct_from_numba_types above.
    value_state_field = 0
    index_state_field = 1

    def make_field_ptr_intrinsic(field_index):
        # Build an intrinsic that maps a pointer-to-struct onto a pointer to
        # its member number `field_index`.
        @intrinsic
        def get_field_ptr(context, struct_ptr_type):
            def codegen(context, builder, sig, args):
                (struct_ptr,) = args
                # GEP [0, field_index]: stay on the pointed-to struct, then
                # select the requested member.
                return builder.gep(
                    struct_ptr,
                    [
                        ir.Constant(ir.IntType(32), 0),
                        ir.Constant(ir.IntType(32), field_index),
                    ],
                )

            struct_model = default_manager.lookup(struct_ptr_type.dtype)
            field_type = struct_model._members[field_index]
            return types.CPointer(field_type)(struct_ptr_type), codegen

        return get_field_ptr

    get_value_state_field_ptr = make_field_ptr_intrinsic(value_state_field)
    get_index_state_field_ptr = make_field_ptr_intrinsic(index_state_field)

    # Scratch slots are created with cgutils.alloca_once, which emits the
    # alloca in the function's entry block. A plain builder.alloca is emitted
    # at the current insertion point, which may not be the entry block. Such
    # allocas are not promoted to registers by LLVM and, once inlined into a
    # loop, can grow the stack on every iteration.
    @intrinsic
    def alloca_temp_for_index_type(context):
        def codegen(context, builder, sig, args):
            return cgutils.alloca_once(builder, context.get_value_type(index_type))

        return types.CPointer(index_type)(), codegen

    @intrinsic
    def alloca_temp_for_value_state(context):
        def codegen(context, builder, sig, args):
            return cgutils.alloca_once(
                builder, context.get_value_type(values_state_type)
            )

        return types.CPointer(values_state_type)(), codegen

    class PermutationIterator(IteratorBase):
        iterator_kind_type = PermutationIteratorKind

        def __init__(self, values_it, indices_it):
            self._values = values_it
            self._indices = indices_it
            super().__init__(
                cvalue=cvalue,
                state_type=state_type,
                value_type=value_type,
            )
            self._kind = self.__class__.iterator_kind_type(
                (value_type, values_it.kind, indices_it.kind), state_type
            )

        @property
        def advance(self):
            return PermutationIterator._advance

        @property
        def input_dereference(self):
            return PermutationIterator._input_dereference

        @property
        def output_dereference(self):
            if not values_is_output_iterator:
                raise AttributeError(
                    "PermutationIterator cannot be used as output iterator "
                    "when values iterator does not support output"
                )
            return PermutationIterator._output_dereference

        @staticmethod
        def _advance(state, distance):
            # Only the index iterator moves; the values state is left
            # untouched and re-offset on every dereference.
            index_state_ptr = get_index_state_field_ptr(state)
            index_advance(index_state_ptr, distance)

        @staticmethod
        def _input_dereference(state, result):
            # Read the current index from the index iterator.
            index_state_ptr = get_index_state_field_ptr(state)
            temp_index = alloca_temp_for_index_type()
            index_input_dereference(index_state_ptr, temp_index)

            # _advance never modifies the stored values state, so advance a
            # private copy of it by the index value and read through it.
            value_state_ptr = get_value_state_field_ptr(state)
            temp_value_state = alloca_temp_for_value_state()
            temp_value_state[0] = value_state_ptr[0]
            value_advance(temp_value_state, temp_index[0])
            value_input_dereference(temp_value_state, result)

        @staticmethod
        def _output_dereference(state, x):
            # Read the current index from the index iterator.
            index_state_ptr = get_index_state_field_ptr(state)
            temp_index = alloca_temp_for_index_type()
            index_input_dereference(index_state_ptr, temp_index)

            # _advance never modifies the stored values state, so advance a
            # private copy of it by the index value and write through it.
            value_state_ptr = get_value_state_field_ptr(state)
            temp_value_state = alloca_temp_for_value_state()
            temp_value_state[0] = value_state_ptr[0]
            value_advance(temp_value_state, temp_index[0])
            value_output_dereference(temp_value_state, x)

    return PermutationIterator(values, indices)
@@ -39,10 +39,10 @@ def signature_from_annotations(func) -> numba.core.typing.Signature:
39
39
  argspec = inspect.getfullargspec(func)
40
40
  num_args = len(argspec.args)
41
41
  try:
42
- retty = to_numba_type(argspec.annotations["return"])
42
+ ret_ann = argspec.annotations["return"]
43
43
  except KeyError:
44
44
  raise ValueError("Function has incomplete annotations: missing return type")
45
-
45
+ retty = to_numba_type(ret_ann)
46
46
  if num_args != len(argspec.annotations) - 1: # -1 for the return type
47
47
  raise ValueError("One or more arguments are missing type annotations")
48
48
  argtys = tuple(
cuda/compute/typing.py CHANGED
@@ -7,9 +7,11 @@ from typing import Any
7
7
 
8
8
  from typing_extensions import (
9
9
  Protocol,
10
+ runtime_checkable,
10
11
  ) # TODO: typing_extensions required for Python 3.7 docs env
11
12
 
12
13
 
14
+ @runtime_checkable
13
15
  class DeviceArrayLike(Protocol):
14
16
  """
15
17
  Objects representing a device array, having a `.__cuda_array_interface__`