cuda-cccl 0.3.1__cp312-cp312-manylinux_2_24_aarch64.whl → 0.3.2__cp312-cp312-manylinux_2_24_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cuda-cccl might be problematic. Click here for more details.
- cuda/cccl/headers/include/cub/agent/agent_histogram.cuh +354 -572
- cuda/cccl/headers/include/cub/block/block_adjacent_difference.cuh +6 -8
- cuda/cccl/headers/include/cub/block/block_discontinuity.cuh +24 -14
- cuda/cccl/headers/include/cub/block/block_exchange.cuh +5 -0
- cuda/cccl/headers/include/cub/block/block_histogram.cuh +4 -0
- cuda/cccl/headers/include/cub/block/block_load.cuh +4 -0
- cuda/cccl/headers/include/cub/block/block_radix_rank.cuh +1 -0
- cuda/cccl/headers/include/cub/block/block_reduce.cuh +1 -0
- cuda/cccl/headers/include/cub/block/block_scan.cuh +12 -2
- cuda/cccl/headers/include/cub/block/block_store.cuh +3 -2
- cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +34 -30
- cuda/cccl/headers/include/cub/detail/ptx-json-parser.h +1 -1
- cuda/cccl/headers/include/cub/device/device_for.cuh +118 -40
- cuda/cccl/headers/include/cub/device/device_reduce.cuh +6 -7
- cuda/cccl/headers/include/cub/device/device_segmented_reduce.cuh +12 -13
- cuda/cccl/headers/include/cub/device/device_transform.cuh +122 -91
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_merge.cuh +2 -3
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce.cuh +4 -3
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_deterministic.cuh +1 -1
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_streaming_reduce.cuh +4 -5
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_streaming_reduce_by_key.cuh +0 -1
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_topk.cuh +3 -5
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_transform.cuh +13 -5
- cuda/cccl/headers/include/cub/device/dispatch/kernels/for_each.cuh +72 -37
- cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +22 -27
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +61 -70
- cuda/cccl/headers/include/cub/thread/thread_reduce.cuh +24 -17
- cuda/cccl/headers/include/cub/warp/warp_load.cuh +6 -6
- cuda/cccl/headers/include/cub/warp/warp_reduce.cuh +7 -2
- cuda/cccl/headers/include/cub/warp/warp_scan.cuh +7 -3
- cuda/cccl/headers/include/cub/warp/warp_store.cuh +1 -0
- cuda/cccl/headers/include/cuda/__barrier/barrier_block_scope.h +19 -0
- cuda/cccl/headers/include/cuda/__cccl_config +1 -0
- cuda/cccl/headers/include/cuda/__cmath/fast_modulo_division.h +3 -74
- cuda/cccl/headers/include/cuda/__cmath/mul_hi.h +146 -0
- cuda/cccl/headers/include/cuda/__complex/get_real_imag.h +0 -4
- cuda/cccl/headers/include/cuda/__device/arch_id.h +176 -0
- cuda/cccl/headers/include/cuda/__device/arch_traits.h +239 -317
- cuda/cccl/headers/include/cuda/__device/attributes.h +4 -3
- cuda/cccl/headers/include/cuda/__device/compute_capability.h +171 -0
- cuda/cccl/headers/include/cuda/__device/device_ref.h +0 -10
- cuda/cccl/headers/include/cuda/__device/physical_device.h +1 -26
- cuda/cccl/headers/include/cuda/__event/event.h +26 -26
- cuda/cccl/headers/include/cuda/__event/event_ref.h +5 -5
- cuda/cccl/headers/include/cuda/__event/timed_event.h +9 -7
- cuda/cccl/headers/include/cuda/__fwd/devices.h +4 -4
- cuda/cccl/headers/include/cuda/__iterator/constant_iterator.h +46 -31
- cuda/cccl/headers/include/cuda/__iterator/strided_iterator.h +79 -47
- cuda/cccl/headers/include/cuda/__iterator/tabulate_output_iterator.h +59 -36
- cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +79 -49
- cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +74 -48
- cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +80 -55
- cuda/cccl/headers/include/cuda/__iterator/zip_common.h +2 -12
- cuda/cccl/headers/include/cuda/__iterator/zip_iterator.h +15 -19
- cuda/cccl/headers/include/cuda/__iterator/zip_transform_iterator.h +59 -60
- cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +127 -60
- cuda/cccl/headers/include/cuda/__mdspan/host_device_mdspan.h +178 -3
- cuda/cccl/headers/include/cuda/__mdspan/restrict_accessor.h +38 -8
- cuda/cccl/headers/include/cuda/__mdspan/restrict_mdspan.h +67 -1
- cuda/cccl/headers/include/cuda/__memory/ptr_in_range.h +93 -0
- cuda/cccl/headers/include/cuda/__memory_resource/get_memory_resource.h +4 -4
- cuda/cccl/headers/include/cuda/__memory_resource/properties.h +44 -0
- cuda/cccl/headers/include/cuda/__memory_resource/resource.h +1 -1
- cuda/cccl/headers/include/cuda/__memory_resource/resource_ref.h +4 -6
- cuda/cccl/headers/include/cuda/__nvtx/nvtx3.h +2 -1
- cuda/cccl/headers/include/cuda/__runtime/ensure_current_context.h +5 -4
- cuda/cccl/headers/include/cuda/__stream/stream.h +8 -8
- cuda/cccl/headers/include/cuda/__stream/stream_ref.h +17 -16
- cuda/cccl/headers/include/cuda/__utility/in_range.h +65 -0
- cuda/cccl/headers/include/cuda/cmath +1 -0
- cuda/cccl/headers/include/cuda/devices +3 -0
- cuda/cccl/headers/include/cuda/memory +1 -0
- cuda/cccl/headers/include/cuda/std/__algorithm/equal_range.h +2 -2
- cuda/cccl/headers/include/cuda/std/__algorithm/find.h +1 -1
- cuda/cccl/headers/include/cuda/std/__algorithm/includes.h +2 -4
- cuda/cccl/headers/include/cuda/std/__algorithm/lower_bound.h +1 -1
- cuda/cccl/headers/include/cuda/std/__algorithm/make_projected.h +7 -15
- cuda/cccl/headers/include/cuda/std/__algorithm/min_element.h +1 -1
- cuda/cccl/headers/include/cuda/std/__algorithm/minmax_element.h +1 -2
- cuda/cccl/headers/include/cuda/std/__algorithm/partial_sort_copy.h +2 -2
- cuda/cccl/headers/include/cuda/std/__algorithm/upper_bound.h +1 -1
- cuda/cccl/headers/include/cuda/std/__cccl/algorithm_wrapper.h +36 -0
- cuda/cccl/headers/include/cuda/std/__cccl/builtin.h +46 -49
- cuda/cccl/headers/include/cuda/std/__cccl/execution_space.h +6 -0
- cuda/cccl/headers/include/cuda/std/__cccl/host_std_lib.h +52 -0
- cuda/cccl/headers/include/cuda/std/__cccl/memory_wrapper.h +36 -0
- cuda/cccl/headers/include/cuda/std/__cccl/numeric_wrapper.h +36 -0
- cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +3 -2
- cuda/cccl/headers/include/cuda/std/__complex/complex.h +3 -2
- cuda/cccl/headers/include/cuda/std/__complex/literals.h +14 -34
- cuda/cccl/headers/include/cuda/std/__complex/nvbf16.h +2 -1
- cuda/cccl/headers/include/cuda/std/__complex/nvfp16.h +4 -3
- cuda/cccl/headers/include/cuda/std/__concepts/invocable.h +2 -2
- cuda/cccl/headers/include/cuda/std/__cstdlib/malloc.h +3 -2
- cuda/cccl/headers/include/cuda/std/__functional/bind.h +10 -13
- cuda/cccl/headers/include/cuda/std/__functional/function.h +5 -8
- cuda/cccl/headers/include/cuda/std/__functional/invoke.h +71 -335
- cuda/cccl/headers/include/cuda/std/__functional/mem_fn.h +1 -2
- cuda/cccl/headers/include/cuda/std/__functional/reference_wrapper.h +3 -3
- cuda/cccl/headers/include/cuda/std/__functional/weak_result_type.h +0 -6
- cuda/cccl/headers/include/cuda/std/__fwd/allocator.h +13 -0
- cuda/cccl/headers/include/cuda/std/__fwd/char_traits.h +13 -0
- cuda/cccl/headers/include/cuda/std/__fwd/complex.h +13 -4
- cuda/cccl/headers/include/cuda/std/__fwd/mdspan.h +23 -0
- cuda/cccl/headers/include/cuda/std/__fwd/pair.h +13 -0
- cuda/cccl/headers/include/cuda/std/__fwd/string.h +22 -0
- cuda/cccl/headers/include/cuda/std/__fwd/string_view.h +14 -0
- cuda/cccl/headers/include/cuda/std/__internal/features.h +0 -5
- cuda/cccl/headers/include/cuda/std/__internal/namespaces.h +21 -0
- cuda/cccl/headers/include/cuda/std/__iterator/iterator_traits.h +5 -5
- cuda/cccl/headers/include/cuda/std/__mdspan/extents.h +7 -1
- cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +53 -39
- cuda/cccl/headers/include/cuda/std/__memory/allocator.h +3 -3
- cuda/cccl/headers/include/cuda/std/__memory/construct_at.h +1 -3
- cuda/cccl/headers/include/cuda/std/__optional/optional_base.h +1 -0
- cuda/cccl/headers/include/cuda/std/__ranges/compressed_movable_box.h +892 -0
- cuda/cccl/headers/include/cuda/std/__ranges/movable_box.h +2 -2
- cuda/cccl/headers/include/cuda/std/__type_traits/is_primary_template.h +7 -5
- cuda/cccl/headers/include/cuda/std/__type_traits/result_of.h +1 -1
- cuda/cccl/headers/include/cuda/std/__utility/pair.h +0 -5
- cuda/cccl/headers/include/cuda/std/bitset +1 -1
- cuda/cccl/headers/include/cuda/std/detail/libcxx/include/__config +15 -12
- cuda/cccl/headers/include/cuda/std/detail/libcxx/include/variant +11 -9
- cuda/cccl/headers/include/cuda/std/inplace_vector +4 -4
- cuda/cccl/headers/include/cuda/std/numbers +5 -0
- cuda/cccl/headers/include/cuda/std/string_view +146 -11
- cuda/cccl/headers/include/cuda/stream_ref +5 -0
- cuda/cccl/headers/include/cuda/utility +1 -0
- cuda/cccl/headers/include/nv/target +7 -2
- cuda/cccl/headers/include/thrust/allocate_unique.h +1 -1
- cuda/cccl/headers/include/thrust/detail/allocator/allocator_traits.h +309 -33
- cuda/cccl/headers/include/thrust/detail/allocator/copy_construct_range.h +151 -4
- cuda/cccl/headers/include/thrust/detail/allocator/destroy_range.h +60 -3
- cuda/cccl/headers/include/thrust/detail/allocator/fill_construct_range.h +45 -3
- cuda/cccl/headers/include/thrust/detail/allocator/malloc_allocator.h +31 -6
- cuda/cccl/headers/include/thrust/detail/allocator/tagged_allocator.h +29 -16
- cuda/cccl/headers/include/thrust/detail/allocator/temporary_allocator.h +41 -4
- cuda/cccl/headers/include/thrust/detail/allocator/value_initialize_range.h +42 -4
- cuda/cccl/headers/include/thrust/detail/complex/ccosh.h +3 -3
- cuda/cccl/headers/include/thrust/detail/internal_functional.h +1 -1
- cuda/cccl/headers/include/thrust/detail/memory_algorithms.h +1 -1
- cuda/cccl/headers/include/thrust/detail/temporary_array.h +1 -1
- cuda/cccl/headers/include/thrust/detail/type_traits.h +1 -1
- cuda/cccl/headers/include/thrust/device_delete.h +18 -3
- cuda/cccl/headers/include/thrust/device_free.h +16 -3
- cuda/cccl/headers/include/thrust/device_new.h +29 -8
- cuda/cccl/headers/include/thrust/host_vector.h +1 -1
- cuda/cccl/headers/include/thrust/iterator/tabulate_output_iterator.h +5 -2
- cuda/cccl/headers/include/thrust/mr/disjoint_pool.h +1 -1
- cuda/cccl/headers/include/thrust/mr/pool.h +1 -1
- cuda/cccl/headers/include/thrust/system/cuda/detail/find.h +13 -115
- cuda/cccl/headers/include/thrust/system/cuda/detail/mismatch.h +8 -2
- cuda/cccl/headers/include/thrust/type_traits/is_contiguous_iterator.h +7 -7
- cuda/compute/__init__.py +2 -0
- cuda/compute/_bindings.pyi +43 -1
- cuda/compute/_bindings_impl.pyx +156 -7
- cuda/compute/algorithms/_scan.py +108 -36
- cuda/compute/algorithms/_transform.py +32 -11
- cuda/compute/cu12/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
- cuda/compute/cu12/cccl/libcccl.c.parallel.so +0 -0
- cuda/compute/cu13/_bindings_impl.cpython-312-aarch64-linux-gnu.so +0 -0
- cuda/compute/cu13/cccl/libcccl.c.parallel.so +0 -0
- cuda/compute/iterators/__init__.py +2 -0
- cuda/compute/iterators/_factories.py +28 -0
- cuda/compute/iterators/_iterators.py +206 -1
- cuda/compute/numba_utils.py +2 -2
- cuda/compute/typing.py +2 -0
- {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/METADATA +1 -1
- {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/RECORD +171 -175
- cuda/cccl/headers/include/thrust/detail/algorithm_wrapper.h +0 -37
- cuda/cccl/headers/include/thrust/detail/allocator/allocator_traits.inl +0 -371
- cuda/cccl/headers/include/thrust/detail/allocator/copy_construct_range.inl +0 -242
- cuda/cccl/headers/include/thrust/detail/allocator/destroy_range.inl +0 -137
- cuda/cccl/headers/include/thrust/detail/allocator/fill_construct_range.inl +0 -99
- cuda/cccl/headers/include/thrust/detail/allocator/malloc_allocator.inl +0 -68
- cuda/cccl/headers/include/thrust/detail/allocator/tagged_allocator.inl +0 -86
- cuda/cccl/headers/include/thrust/detail/allocator/temporary_allocator.inl +0 -79
- cuda/cccl/headers/include/thrust/detail/allocator/value_initialize_range.inl +0 -98
- cuda/cccl/headers/include/thrust/detail/device_delete.inl +0 -52
- cuda/cccl/headers/include/thrust/detail/device_free.inl +0 -47
- cuda/cccl/headers/include/thrust/detail/device_new.inl +0 -61
- cuda/cccl/headers/include/thrust/detail/memory_wrapper.h +0 -40
- cuda/cccl/headers/include/thrust/detail/numeric_wrapper.h +0 -37
- {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/WHEEL +0 -0
- {cuda_cccl-0.3.1.dist-info → cuda_cccl-0.3.2.dist-info}/licenses/LICENSE +0 -0
cuda/compute/algorithms/_scan.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
#
|
|
4
4
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
5
5
|
|
|
6
|
-
from typing import Callable, Union
|
|
6
|
+
from typing import Callable, Union, cast
|
|
7
7
|
|
|
8
8
|
import numba
|
|
9
9
|
import numpy as np
|
|
@@ -20,14 +20,27 @@ from ..op import OpKind
|
|
|
20
20
|
from ..typing import DeviceArrayLike, GpuStruct
|
|
21
21
|
|
|
22
22
|
|
|
23
|
+
def get_init_kind(
|
|
24
|
+
init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
|
|
25
|
+
) -> _bindings.InitKind:
|
|
26
|
+
match init_value:
|
|
27
|
+
case None:
|
|
28
|
+
return _bindings.InitKind.NO_INIT
|
|
29
|
+
case _ if isinstance(init_value, DeviceArrayLike):
|
|
30
|
+
return _bindings.InitKind.FUTURE_VALUE_INIT
|
|
31
|
+
case _:
|
|
32
|
+
return _bindings.InitKind.VALUE_INIT
|
|
33
|
+
|
|
34
|
+
|
|
23
35
|
class _Scan:
|
|
24
36
|
__slots__ = [
|
|
25
37
|
"build_result",
|
|
26
38
|
"d_in_cccl",
|
|
27
39
|
"d_out_cccl",
|
|
28
|
-
"
|
|
40
|
+
"init_value_cccl",
|
|
29
41
|
"op_wrapper",
|
|
30
42
|
"device_scan_fn",
|
|
43
|
+
"init_kind",
|
|
31
44
|
]
|
|
32
45
|
|
|
33
46
|
# TODO: constructor shouldn't require concrete `d_in`, `d_out`:
|
|
@@ -36,36 +49,74 @@ class _Scan:
|
|
|
36
49
|
d_in: DeviceArrayLike | IteratorBase,
|
|
37
50
|
d_out: DeviceArrayLike | IteratorBase,
|
|
38
51
|
op: Callable | OpKind,
|
|
39
|
-
|
|
52
|
+
init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
|
|
40
53
|
force_inclusive: bool,
|
|
41
54
|
):
|
|
42
55
|
self.d_in_cccl = cccl.to_cccl_input_iter(d_in)
|
|
43
56
|
self.d_out_cccl = cccl.to_cccl_output_iter(d_out)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
57
|
+
|
|
58
|
+
self.init_kind = get_init_kind(init_value)
|
|
59
|
+
|
|
60
|
+
self.init_value_cccl: _bindings.Iterator | _bindings.Value | None
|
|
61
|
+
|
|
62
|
+
match self.init_kind:
|
|
63
|
+
case _bindings.InitKind.NO_INIT:
|
|
64
|
+
# TODO: we just need to extract the dtype from the input iterator
|
|
65
|
+
if not isinstance(d_in, DeviceArrayLike):
|
|
66
|
+
raise ValueError(
|
|
67
|
+
"No init value not supported for non-DeviceArrayLike input"
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
self.init_value_cccl = None
|
|
71
|
+
value_type = numba.from_dtype(protocols.get_dtype(d_in))
|
|
72
|
+
init_value_type_info = self.d_in_cccl.value_type
|
|
73
|
+
|
|
74
|
+
case _bindings.InitKind.FUTURE_VALUE_INIT:
|
|
75
|
+
self.init_value_cccl = cccl.to_cccl_input_iter(init_value)
|
|
76
|
+
value_type = numba.from_dtype(
|
|
77
|
+
protocols.get_dtype(cast(DeviceArrayLike, init_value))
|
|
78
|
+
)
|
|
79
|
+
init_value_type_info = self.init_value_cccl.value_type
|
|
80
|
+
|
|
81
|
+
case _bindings.InitKind.VALUE_INIT:
|
|
82
|
+
self.init_value_cccl = cccl.to_cccl_value(init_value)
|
|
83
|
+
value_type = (
|
|
84
|
+
numba.from_dtype(init_value.dtype)
|
|
85
|
+
if isinstance(init_value, np.ndarray)
|
|
86
|
+
else numba.typeof(init_value)
|
|
87
|
+
)
|
|
88
|
+
init_value_type_info = self.init_value_cccl.type
|
|
49
89
|
|
|
50
90
|
# For well-known operations, we don't need a signature
|
|
51
91
|
if isinstance(op, OpKind):
|
|
52
92
|
self.op_wrapper = cccl.to_cccl_op(op, None)
|
|
53
93
|
else:
|
|
54
94
|
self.op_wrapper = cccl.to_cccl_op(op, value_type(value_type, value_type))
|
|
95
|
+
|
|
55
96
|
self.build_result = call_build(
|
|
56
97
|
_bindings.DeviceScanBuildResult,
|
|
57
98
|
self.d_in_cccl,
|
|
58
99
|
self.d_out_cccl,
|
|
59
100
|
self.op_wrapper,
|
|
60
|
-
|
|
101
|
+
init_value_type_info,
|
|
61
102
|
force_inclusive,
|
|
103
|
+
self.init_kind,
|
|
62
104
|
)
|
|
63
105
|
|
|
64
|
-
self.
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
106
|
+
match (force_inclusive, self.init_kind):
|
|
107
|
+
case (True, _bindings.InitKind.FUTURE_VALUE_INIT):
|
|
108
|
+
self.device_scan_fn = self.build_result.compute_inclusive_future_value
|
|
109
|
+
case (True, _bindings.InitKind.VALUE_INIT):
|
|
110
|
+
self.device_scan_fn = self.build_result.compute_inclusive
|
|
111
|
+
case (True, _bindings.InitKind.NO_INIT):
|
|
112
|
+
self.device_scan_fn = self.build_result.compute_inclusive_no_init
|
|
113
|
+
|
|
114
|
+
case (False, _bindings.InitKind.FUTURE_VALUE_INIT):
|
|
115
|
+
self.device_scan_fn = self.build_result.compute_exclusive_future_value
|
|
116
|
+
case (False, _bindings.InitKind.VALUE_INIT):
|
|
117
|
+
self.device_scan_fn = self.build_result.compute_exclusive
|
|
118
|
+
case (False, _bindings.InitKind.NO_INIT):
|
|
119
|
+
raise ValueError("Exclusive scan with No init value is not supported")
|
|
69
120
|
|
|
70
121
|
def __call__(
|
|
71
122
|
self,
|
|
@@ -73,13 +124,25 @@ class _Scan:
|
|
|
73
124
|
d_in,
|
|
74
125
|
d_out,
|
|
75
126
|
num_items: int,
|
|
76
|
-
|
|
127
|
+
init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
|
|
77
128
|
stream=None,
|
|
78
129
|
):
|
|
79
130
|
set_cccl_iterator_state(self.d_in_cccl, d_in)
|
|
80
131
|
set_cccl_iterator_state(self.d_out_cccl, d_out)
|
|
81
132
|
|
|
82
|
-
self.
|
|
133
|
+
match self.init_kind:
|
|
134
|
+
case _bindings.InitKind.FUTURE_VALUE_INIT:
|
|
135
|
+
# We know that the init_value_cccl is an Iterator, so this cast
|
|
136
|
+
# tells MyPy what the actual type is. cast() is a no-op at runtime,
|
|
137
|
+
# which makes it better than isinstance() since this is a hot path
|
|
138
|
+
# and we have to minimize the work we do prior to calling the
|
|
139
|
+
# kernel.
|
|
140
|
+
self.init_value_cccl = cast(_bindings.Iterator, self.init_value_cccl)
|
|
141
|
+
set_cccl_iterator_state(self.init_value_cccl, init_value)
|
|
142
|
+
|
|
143
|
+
case _bindings.InitKind.VALUE_INIT:
|
|
144
|
+
self.init_value_cccl = cast(_bindings.Value, self.init_value_cccl)
|
|
145
|
+
self.init_value_cccl.state = to_cccl_value_state(init_value)
|
|
83
146
|
|
|
84
147
|
stream_handle = validate_and_get_stream(stream)
|
|
85
148
|
|
|
@@ -97,7 +160,7 @@ class _Scan:
|
|
|
97
160
|
self.d_out_cccl,
|
|
98
161
|
num_items,
|
|
99
162
|
self.op_wrapper,
|
|
100
|
-
self.
|
|
163
|
+
self.init_value_cccl,
|
|
101
164
|
stream_handle,
|
|
102
165
|
)
|
|
103
166
|
return temp_storage_bytes
|
|
@@ -107,7 +170,7 @@ def make_cache_key(
|
|
|
107
170
|
d_in: DeviceArrayLike | IteratorBase,
|
|
108
171
|
d_out: DeviceArrayLike | IteratorBase,
|
|
109
172
|
op: Callable | OpKind,
|
|
110
|
-
|
|
173
|
+
init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
|
|
111
174
|
):
|
|
112
175
|
d_in_key = (
|
|
113
176
|
d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in)
|
|
@@ -123,8 +186,17 @@ def make_cache_key(
|
|
|
123
186
|
else:
|
|
124
187
|
op_key = CachableFunction(op)
|
|
125
188
|
|
|
126
|
-
|
|
127
|
-
|
|
189
|
+
init_kind_key = get_init_kind(init_value)
|
|
190
|
+
match init_kind_key:
|
|
191
|
+
case _bindings.InitKind.NO_INIT:
|
|
192
|
+
init_value_key = None
|
|
193
|
+
case _bindings.InitKind.FUTURE_VALUE_INIT:
|
|
194
|
+
init_value_key = protocols.get_dtype(cast(DeviceArrayLike, init_value))
|
|
195
|
+
case _bindings.InitKind.VALUE_INIT:
|
|
196
|
+
init_value = cast(np.ndarray | GpuStruct, init_value)
|
|
197
|
+
init_value_key = init_value.dtype
|
|
198
|
+
|
|
199
|
+
return (d_in_key, d_out_key, op_key, init_value_key, init_kind_key)
|
|
128
200
|
|
|
129
201
|
|
|
130
202
|
# TODO Figure out `sum` without operator and initial value
|
|
@@ -134,7 +206,7 @@ def make_exclusive_scan(
|
|
|
134
206
|
d_in: DeviceArrayLike | IteratorBase,
|
|
135
207
|
d_out: DeviceArrayLike | IteratorBase,
|
|
136
208
|
op: Callable | OpKind,
|
|
137
|
-
|
|
209
|
+
init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
|
|
138
210
|
):
|
|
139
211
|
"""Computes a device-wide scan using the specified binary ``op`` and initial value ``init``.
|
|
140
212
|
|
|
@@ -150,19 +222,19 @@ def make_exclusive_scan(
|
|
|
150
222
|
d_in: Device array or iterator containing the input sequence of data items
|
|
151
223
|
d_out: Device array that will store the result of the scan
|
|
152
224
|
op: Callable or OpKind representing the binary operator to apply
|
|
153
|
-
|
|
225
|
+
init_value: Numpy array, device array, or GPU struct storing initial value of the scan, or None for no initial value
|
|
154
226
|
|
|
155
227
|
Returns:
|
|
156
228
|
A callable object that can be used to perform the scan
|
|
157
229
|
"""
|
|
158
|
-
return _Scan(d_in, d_out, op,
|
|
230
|
+
return _Scan(d_in, d_out, op, init_value, False)
|
|
159
231
|
|
|
160
232
|
|
|
161
233
|
def exclusive_scan(
|
|
162
234
|
d_in: DeviceArrayLike | IteratorBase,
|
|
163
235
|
d_out: DeviceArrayLike | IteratorBase,
|
|
164
236
|
op: Callable | OpKind,
|
|
165
|
-
|
|
237
|
+
init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
|
|
166
238
|
num_items: int,
|
|
167
239
|
stream=None,
|
|
168
240
|
):
|
|
@@ -183,14 +255,14 @@ def exclusive_scan(
|
|
|
183
255
|
d_in: Device array or iterator containing the input sequence of data items
|
|
184
256
|
d_out: Device array or iterator to store the result of the scan
|
|
185
257
|
op: Binary scan operator
|
|
186
|
-
|
|
258
|
+
init_value: Initial value for the scan
|
|
187
259
|
num_items: Number of items to scan
|
|
188
260
|
stream: CUDA stream for the operation (optional)
|
|
189
261
|
"""
|
|
190
|
-
scanner = make_exclusive_scan(d_in, d_out, op,
|
|
191
|
-
tmp_storage_bytes = scanner(None, d_in, d_out, num_items,
|
|
262
|
+
scanner = make_exclusive_scan(d_in, d_out, op, init_value)
|
|
263
|
+
tmp_storage_bytes = scanner(None, d_in, d_out, num_items, init_value, stream)
|
|
192
264
|
tmp_storage = TempStorageBuffer(tmp_storage_bytes, stream)
|
|
193
|
-
scanner(tmp_storage, d_in, d_out, num_items,
|
|
265
|
+
scanner(tmp_storage, d_in, d_out, num_items, init_value, stream)
|
|
194
266
|
|
|
195
267
|
|
|
196
268
|
# TODO Figure out `sum` without operator and initial value
|
|
@@ -200,7 +272,7 @@ def make_inclusive_scan(
|
|
|
200
272
|
d_in: DeviceArrayLike | IteratorBase,
|
|
201
273
|
d_out: DeviceArrayLike | IteratorBase,
|
|
202
274
|
op: Callable | OpKind,
|
|
203
|
-
|
|
275
|
+
init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
|
|
204
276
|
):
|
|
205
277
|
"""Computes a device-wide scan using the specified binary ``op`` and initial value ``init``.
|
|
206
278
|
|
|
@@ -216,19 +288,19 @@ def make_inclusive_scan(
|
|
|
216
288
|
d_in: Device array or iterator containing the input sequence of data items
|
|
217
289
|
d_out: Device array that will store the result of the scan
|
|
218
290
|
op: Callable or OpKind representing the binary operator to apply
|
|
219
|
-
|
|
291
|
+
init_value: Numpy array, device array, or GPU struct storing initial value of the scan, or None for no initial value
|
|
220
292
|
|
|
221
293
|
Returns:
|
|
222
294
|
A callable object that can be used to perform the scan
|
|
223
295
|
"""
|
|
224
|
-
return _Scan(d_in, d_out, op,
|
|
296
|
+
return _Scan(d_in, d_out, op, init_value, True)
|
|
225
297
|
|
|
226
298
|
|
|
227
299
|
def inclusive_scan(
|
|
228
300
|
d_in: DeviceArrayLike | IteratorBase,
|
|
229
301
|
d_out: DeviceArrayLike | IteratorBase,
|
|
230
302
|
op: Callable | OpKind,
|
|
231
|
-
|
|
303
|
+
init_value: np.ndarray | DeviceArrayLike | GpuStruct | None,
|
|
232
304
|
num_items: int,
|
|
233
305
|
stream=None,
|
|
234
306
|
):
|
|
@@ -249,11 +321,11 @@ def inclusive_scan(
|
|
|
249
321
|
d_in: Device array or iterator containing the input sequence of data items
|
|
250
322
|
d_out: Device array or iterator to store the result of the scan
|
|
251
323
|
op: Binary scan operator
|
|
252
|
-
|
|
324
|
+
init_value: Initial value for the scan
|
|
253
325
|
num_items: Number of items to scan
|
|
254
326
|
stream: CUDA stream for the operation (optional)
|
|
255
327
|
"""
|
|
256
|
-
scanner = make_inclusive_scan(d_in, d_out, op,
|
|
257
|
-
tmp_storage_bytes = scanner(None, d_in, d_out, num_items,
|
|
328
|
+
scanner = make_inclusive_scan(d_in, d_out, op, init_value)
|
|
329
|
+
tmp_storage_bytes = scanner(None, d_in, d_out, num_items, init_value, stream)
|
|
258
330
|
tmp_storage = TempStorageBuffer(tmp_storage_bytes, stream)
|
|
259
|
-
scanner(tmp_storage, d_in, d_out, num_items,
|
|
331
|
+
scanner(tmp_storage, d_in, d_out, num_items, init_value, stream)
|
|
@@ -11,7 +11,7 @@ from .._caching import CachableFunction, cache_with_key
|
|
|
11
11
|
from .._cccl_interop import set_cccl_iterator_state
|
|
12
12
|
from .._utils import protocols
|
|
13
13
|
from ..iterators._iterators import IteratorBase
|
|
14
|
-
from ..numba_utils import get_inferred_return_type
|
|
14
|
+
from ..numba_utils import get_inferred_return_type, signature_from_annotations
|
|
15
15
|
from ..op import OpKind
|
|
16
16
|
from ..typing import DeviceArrayLike
|
|
17
17
|
|
|
@@ -32,16 +32,20 @@ class _UnaryTransform:
|
|
|
32
32
|
):
|
|
33
33
|
self.d_in_cccl = cccl.to_cccl_input_iter(d_in)
|
|
34
34
|
self.d_out_cccl = cccl.to_cccl_output_iter(d_out)
|
|
35
|
-
in_value_type = cccl.get_value_type(d_in)
|
|
36
|
-
out_value_type = cccl.get_value_type(d_out)
|
|
37
35
|
|
|
38
36
|
# For well-known operations, we don't need a signature
|
|
39
37
|
if isinstance(op, OpKind):
|
|
40
38
|
self.op_wrapper = cccl.to_cccl_op(op, None)
|
|
41
39
|
else:
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
40
|
+
try:
|
|
41
|
+
sig = signature_from_annotations(op)
|
|
42
|
+
except ValueError:
|
|
43
|
+
in_value_type = cccl.get_value_type(d_in)
|
|
44
|
+
out_value_type = cccl.get_value_type(d_out)
|
|
45
|
+
if not out_value_type.is_internal:
|
|
46
|
+
out_value_type = get_inferred_return_type(op, (in_value_type,))
|
|
47
|
+
sig = out_value_type(in_value_type)
|
|
48
|
+
|
|
45
49
|
self.op_wrapper = cccl.to_cccl_op(op, sig=sig)
|
|
46
50
|
self.build_result = cccl.call_build(
|
|
47
51
|
_bindings.DeviceUnaryTransform,
|
|
@@ -97,11 +101,14 @@ class _BinaryTransform:
|
|
|
97
101
|
if isinstance(op, OpKind):
|
|
98
102
|
self.op_wrapper = cccl.to_cccl_op(op, None)
|
|
99
103
|
else:
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
104
|
+
try:
|
|
105
|
+
sig = signature_from_annotations(op)
|
|
106
|
+
except ValueError:
|
|
107
|
+
if not out_value_type.is_internal:
|
|
108
|
+
out_value_type = get_inferred_return_type(
|
|
109
|
+
op, (in1_value_type, in2_value_type)
|
|
110
|
+
)
|
|
111
|
+
sig = out_value_type(in1_value_type, in2_value_type)
|
|
105
112
|
self.op_wrapper = cccl.to_cccl_op(op, sig=sig)
|
|
106
113
|
self.build_result = cccl.call_build(
|
|
107
114
|
_bindings.DeviceBinaryTransform,
|
|
@@ -263,6 +270,13 @@ def unary_transform(
|
|
|
263
270
|
:language: python
|
|
264
271
|
:start-after: # example-begin
|
|
265
272
|
|
|
273
|
+
When working with custom struct types, you need to provide type annotations
|
|
274
|
+
to help with type inference. See the binary transform struct example for reference:
|
|
275
|
+
|
|
276
|
+
.. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/transform/binary_transform_struct.py
|
|
277
|
+
:language: python
|
|
278
|
+
:start-after: # example-begin
|
|
279
|
+
|
|
266
280
|
|
|
267
281
|
Args:
|
|
268
282
|
d_in: Device array or iterator containing the input sequence of data items.
|
|
@@ -295,6 +309,13 @@ def binary_transform(
|
|
|
295
309
|
:language: python
|
|
296
310
|
:start-after: # example-begin
|
|
297
311
|
|
|
312
|
+
When working with custom struct types, you need to provide type annotations
|
|
313
|
+
to help with type inference. See the following example:
|
|
314
|
+
|
|
315
|
+
.. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/transform/binary_transform_struct.py
|
|
316
|
+
:language: python
|
|
317
|
+
:start-after: # example-begin
|
|
318
|
+
|
|
298
319
|
|
|
299
320
|
Args:
|
|
300
321
|
d_in1: Device array or iterator containing the first input sequence of data items.
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -2,6 +2,7 @@ from ._factories import (
|
|
|
2
2
|
CacheModifiedInputIterator,
|
|
3
3
|
ConstantIterator,
|
|
4
4
|
CountingIterator,
|
|
5
|
+
PermutationIterator,
|
|
5
6
|
ReverseIterator,
|
|
6
7
|
TransformIterator,
|
|
7
8
|
TransformOutputIterator,
|
|
@@ -12,6 +13,7 @@ __all__ = [
|
|
|
12
13
|
"CacheModifiedInputIterator",
|
|
13
14
|
"ConstantIterator",
|
|
14
15
|
"CountingIterator",
|
|
16
|
+
"PermutationIterator",
|
|
15
17
|
"ReverseIterator",
|
|
16
18
|
"TransformIterator",
|
|
17
19
|
"TransformOutputIterator",
|
|
@@ -10,6 +10,7 @@ from ._iterators import (
|
|
|
10
10
|
CountingIterator as _CountingIterator,
|
|
11
11
|
)
|
|
12
12
|
from ._iterators import (
|
|
13
|
+
make_permutation_iterator,
|
|
13
14
|
make_reverse_iterator,
|
|
14
15
|
make_transform_iterator,
|
|
15
16
|
)
|
|
@@ -165,6 +166,33 @@ def TransformOutputIterator(it, op):
|
|
|
165
166
|
return make_transform_iterator(it, op, "output")
|
|
166
167
|
|
|
167
168
|
|
|
169
|
+
def PermutationIterator(values, indices):
|
|
170
|
+
"""Returns an Iterator that accesses values through an index mapping.
|
|
171
|
+
|
|
172
|
+
Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1permutation__iterator.html
|
|
173
|
+
|
|
174
|
+
The permutation iterator accesses elements from the values collection using indices
|
|
175
|
+
from the indices collection, effectively computing values[indices[i]] at position i.
|
|
176
|
+
This is useful for gather/scatter operations and indirect array access patterns.
|
|
177
|
+
|
|
178
|
+
Example:
|
|
179
|
+
The code snippet below demonstrates the usage of a ``PermutationIterator``
|
|
180
|
+
to access values in a permuted order:
|
|
181
|
+
|
|
182
|
+
.. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/iterator/permutation_iterator_basic.py
|
|
183
|
+
:language: python
|
|
184
|
+
:start-after: # example-begin
|
|
185
|
+
|
|
186
|
+
Args:
|
|
187
|
+
values: The values array or iterator to be permuted
|
|
188
|
+
indices: An iterator or device array providing the indices for permutation
|
|
189
|
+
|
|
190
|
+
Returns:
|
|
191
|
+
A ``PermutationIterator`` object that yields values[indices[i]] at position i
|
|
192
|
+
"""
|
|
193
|
+
return make_permutation_iterator(values, indices)
|
|
194
|
+
|
|
195
|
+
|
|
168
196
|
def ZipIterator(*iterators):
|
|
169
197
|
"""Returns an Iterator representing a zipped sequence of values from N iterators.
|
|
170
198
|
|
|
@@ -207,7 +207,15 @@ def pointer_add_intrinsic(context, ptr, offset):
|
|
|
207
207
|
def codegen(context, builder, sig, args):
|
|
208
208
|
ptr, index = args
|
|
209
209
|
base = builder.ptrtoint(ptr, ir.IntType(_DEVICE_POINTER_BITWIDTH))
|
|
210
|
-
|
|
210
|
+
sizeof = sizeof_pointee(context, ptr)
|
|
211
|
+
# Cast index to match sizeof type if needed
|
|
212
|
+
if index.type != sizeof.type:
|
|
213
|
+
index = (
|
|
214
|
+
builder.sext(index, sizeof.type)
|
|
215
|
+
if index.type.width < sizeof.type.width
|
|
216
|
+
else builder.trunc(index, sizeof.type)
|
|
217
|
+
)
|
|
218
|
+
offset = builder.mul(index, sizeof)
|
|
211
219
|
result = builder.add(base, offset)
|
|
212
220
|
return builder.inttoptr(result, ptr.type)
|
|
213
221
|
|
|
@@ -610,3 +618,200 @@ def _get_last_element_ptr(device_array) -> int:
|
|
|
610
618
|
|
|
611
619
|
ptr = get_data_pointer(device_array)
|
|
612
620
|
return ptr + offset_to_last_element
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
class PermutationIteratorKind(IteratorKind):
    """Iterator kind used by ``PermutationIterator``; adds nothing to ``IteratorKind``."""
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def make_permutation_iterator(values, indices):
    """
    Create a PermutationIterator that accesses values through an index mapping.

    The permutation iterator accesses elements from `values` using indices from
    `indices`, effectively computing values[indices[i]] at position i.

    Args:
        values: The values array or iterator to permute. An object exposing
            ``__cuda_array_interface__`` is wrapped with ``pointer``.
        indices: The indices array or iterator specifying the permutation.
            An object exposing ``__cuda_array_interface__`` is wrapped with
            ``pointer``.

    Returns:
        PermutationIterator: Iterator that yields permuted values. Its
        ``output_dereference`` property raises ``AttributeError`` when
        ``values`` does not provide an output dereference.

    Raises:
        TypeError: If ``values`` or ``indices`` is neither a device array nor
            an ``IteratorBase`` instance.
    """
    # Convert arrays to iterators if needed
    if hasattr(values, "__cuda_array_interface__"):
        values = pointer(values, numba.from_dtype(get_dtype(values)))
    elif not isinstance(values, IteratorBase):
        raise TypeError("values must be a device array or iterator")

    if hasattr(indices, "__cuda_array_interface__"):
        indices = pointer(indices, numba.from_dtype(get_dtype(indices)))
    elif not isinstance(indices, IteratorBase):
        raise TypeError("indices must be an iterator or device array")

    # JIT compile value advance/dereference methods
    value_dtype = values.value_type
    values_state_type = values.state_type
    index_type = indices.value_type
    value_advance = cuda.jit(values.advance, device=True)
    value_input_dereference = cuda.jit(values.input_dereference, device=True)

    # Output support is optional: the values iterator may either lack the
    # attribute (AttributeError) or expose it as None.
    try:
        output_deref = values.output_dereference
        if output_deref is not None:
            value_output_dereference = cuda.jit(output_deref, device=True)
            values_is_output_iterator = True
        else:
            values_is_output_iterator = False
    except AttributeError:
        values_is_output_iterator = False

    # JIT compile index advance/dereference methods
    index_advance = cuda.jit(indices.advance, device=True)
    index_input_dereference = cuda.jit(indices.input_dereference, device=True)

    # The cvalue and state for PermutationIterator are
    # structs composed of the cvalues and states of the
    # value and index iterators.
    from ..struct import gpu_struct_from_numba_types

    class PermutationCValueStruct(ctypes.Structure):
        _fields_ = [
            ("value_state", values.cvalue.__class__),
            ("index_state", indices.cvalue.__class__),
        ]

    PermutationState = gpu_struct_from_numba_types(
        "PermutationState",
        ("value_state", "index_state"),
        (values_state_type, indices.state_type),
    )

    cvalue = PermutationCValueStruct(values.cvalue, indices.cvalue)
    state_type = PermutationState._numba_type
    value_type = value_dtype

    from numba.core import cgutils
    from numba.core.datamodel.registry import default_manager

    def _field_ptr_intrinsic(name, field_pos):
        """Build an intrinsic returning a pointer to member ``field_pos`` of the state struct."""

        def typer(context, struct_ptr_type):
            def codegen(context, builder, sig, args):
                struct_ptr = args[0]
                # GEP [0, field_pos]: step through the pointer, then select
                # the struct member.
                return builder.gep(
                    struct_ptr,
                    [
                        ir.Constant(ir.IntType(32), 0),
                        ir.Constant(ir.IntType(32), field_pos),
                    ],
                )

            struct_model = default_manager.lookup(struct_ptr_type.dtype)
            field_type = struct_model._members[field_pos]
            return types.CPointer(field_type)(struct_ptr_type), codegen

        # Keep the name each intrinsic had when it was written out by hand.
        typer.__name__ = name
        return intrinsic(typer)

    def _stack_temp_intrinsic(name, numba_type):
        """Build a zero-argument intrinsic returning a pointer to a stack slot of ``numba_type``."""

        def typer(context):
            def codegen(context, builder, sig, args):
                # alloca_once emits the alloca in the function's entry block
                # rather than at the current insertion point.
                return cgutils.alloca_once(
                    builder, context.get_value_type(numba_type)
                )

            return types.CPointer(numba_type)(), codegen

        typer.__name__ = name
        return intrinsic(typer)

    # Intrinsics for accessing struct fields (member 0: value_state,
    # member 1: index_state).
    get_value_state_field_ptr = _field_ptr_intrinsic("get_value_state_field_ptr", 0)
    get_index_state_field_ptr = _field_ptr_intrinsic("get_index_state_field_ptr", 1)

    # Intrinsics for allocating temporary storage for the dereferenced index
    # and for a scratch copy of the value state.
    alloca_temp_for_index_type = _stack_temp_intrinsic(
        "alloca_temp_for_index_type", index_type
    )
    alloca_temp_for_value_state = _stack_temp_intrinsic(
        "alloca_temp_for_value_state", values_state_type
    )

    class PermutationIterator(IteratorBase):
        iterator_kind_type = PermutationIteratorKind

        def __init__(self, values_it, indices_it):
            self._values = values_it
            self._indices = indices_it
            super().__init__(
                cvalue=cvalue,
                state_type=state_type,
                value_type=value_type,
            )
            self._kind = self.__class__.iterator_kind_type(
                (value_type, values_it.kind, indices_it.kind), state_type
            )

        @property
        def advance(self):
            return PermutationIterator._advance

        @property
        def input_dereference(self):
            return PermutationIterator._input_dereference

        @property
        def output_dereference(self):
            if not values_is_output_iterator:
                raise AttributeError(
                    "PermutationIterator cannot be used as output iterator "
                    "when values iterator does not support output"
                )
            return PermutationIterator._output_dereference

        @staticmethod
        def _advance(state, distance):
            # Only the index iterator moves; the stored value state is left
            # untouched so dereferencing can offset from it.
            index_state_ptr = get_index_state_field_ptr(state)
            index_advance(index_state_ptr, distance)

        @staticmethod
        def _input_dereference(state, result):
            # dereference index to get the index value
            index_state_ptr = get_index_state_field_ptr(state)
            temp_index = alloca_temp_for_index_type()
            index_input_dereference(index_state_ptr, temp_index)

            # copy the value state (which always points to position 0)
            # and advance it by the index value
            value_state_ptr = get_value_state_field_ptr(state)
            temp_value_state = alloca_temp_for_value_state()
            temp_value_state[0] = value_state_ptr[0]
            value_advance(temp_value_state, temp_index[0])
            value_input_dereference(temp_value_state, result)

        @staticmethod
        def _output_dereference(state, x):
            # dereference index to get the index value
            index_state_ptr = get_index_state_field_ptr(state)
            temp_index = alloca_temp_for_index_type()
            index_input_dereference(index_state_ptr, temp_index)

            # copy the value state (which always points to position 0)
            # and advance it by the index value
            value_state_ptr = get_value_state_field_ptr(state)
            temp_value_state = alloca_temp_for_value_state()
            temp_value_state[0] = value_state_ptr[0]
            value_advance(temp_value_state, temp_index[0])
            value_output_dereference(temp_value_state, x)

    return PermutationIterator(values, indices)
|
cuda/compute/numba_utils.py
CHANGED
|
@@ -39,10 +39,10 @@ def signature_from_annotations(func) -> numba.core.typing.Signature:
|
|
|
39
39
|
argspec = inspect.getfullargspec(func)
|
|
40
40
|
num_args = len(argspec.args)
|
|
41
41
|
try:
|
|
42
|
-
|
|
42
|
+
ret_ann = argspec.annotations["return"]
|
|
43
43
|
except KeyError:
|
|
44
44
|
raise ValueError("Function has incomplete annotations: missing return type")
|
|
45
|
-
|
|
45
|
+
retty = to_numba_type(ret_ann)
|
|
46
46
|
if num_args != len(argspec.annotations) - 1: # -1 for the return type
|
|
47
47
|
raise ValueError("One or more arguments are missing type annotations")
|
|
48
48
|
argtys = tuple(
|
cuda/compute/typing.py
CHANGED
|
@@ -7,9 +7,11 @@ from typing import Any
|
|
|
7
7
|
|
|
8
8
|
from typing_extensions import (
|
|
9
9
|
Protocol,
|
|
10
|
+
runtime_checkable,
|
|
10
11
|
) # TODO: typing_extensions required for Python 3.7 docs env
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
@runtime_checkable
|
|
13
15
|
class DeviceArrayLike(Protocol):
|
|
14
16
|
"""
|
|
15
17
|
Objects representing a device array, having a `.__cuda_array_interface__`
|