pyopencl 2025.1__cp310-cp310-win_amd64.whl → 2025.2.2__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyopencl might be problematic. Click here for more details.
- pyopencl/__init__.py +582 -997
- pyopencl/_cl.cp310-win_amd64.pyd +0 -0
- pyopencl/_cl.pyi +2006 -0
- pyopencl/_cluda.py +3 -0
- pyopencl/_monkeypatch.py +1063 -0
- pyopencl/_mymako.py +3 -0
- pyopencl/algorithm.py +29 -24
- pyopencl/array.py +300 -255
- pyopencl/bitonic_sort.py +5 -2
- pyopencl/bitonic_sort_templates.py +3 -0
- pyopencl/cache.py +5 -5
- pyopencl/capture_call.py +31 -8
- pyopencl/characterize/__init__.py +26 -19
- pyopencl/characterize/performance.py +3 -0
- pyopencl/clmath.py +2 -0
- pyopencl/clrandom.py +3 -0
- pyopencl/cltypes.py +67 -2
- pyopencl/compyte/.basedpyright/baseline.json +1272 -0
- pyopencl/compyte/array.py +36 -9
- pyopencl/compyte/dtypes.py +61 -29
- pyopencl/compyte/pyproject.toml +17 -22
- pyopencl/elementwise.py +13 -10
- pyopencl/invoker.py +13 -17
- pyopencl/ipython_ext.py +2 -0
- pyopencl/py.typed +0 -0
- pyopencl/reduction.py +72 -43
- pyopencl/scan.py +31 -30
- pyopencl/tools.py +128 -90
- pyopencl/typing.py +57 -0
- pyopencl/version.py +2 -0
- {pyopencl-2025.1.dist-info → pyopencl-2025.2.2.dist-info}/METADATA +11 -10
- pyopencl-2025.2.2.dist-info/RECORD +47 -0
- {pyopencl-2025.1.dist-info → pyopencl-2025.2.2.dist-info}/WHEEL +1 -1
- pyopencl-2025.1.dist-info/RECORD +0 -42
- {pyopencl-2025.1.dist-info → pyopencl-2025.2.2.dist-info}/licenses/LICENSE +0 -0
pyopencl/reduction.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
"""Computation of reductions on vectors."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
2
4
|
|
|
3
5
|
__copyright__ = "Copyright (C) 2010 Andreas Kloeckner"
|
|
4
6
|
|
|
@@ -28,12 +30,14 @@ Based on code/ideas by Mark Harris <mharris@nvidia.com>.
|
|
|
28
30
|
None of the original source code remains.
|
|
29
31
|
"""
|
|
30
32
|
|
|
33
|
+
import builtins
|
|
31
34
|
from dataclasses import dataclass
|
|
32
|
-
from typing import
|
|
35
|
+
from typing import TYPE_CHECKING, Any, Literal, cast, overload
|
|
33
36
|
|
|
34
37
|
import numpy as np
|
|
35
38
|
|
|
36
39
|
import pyopencl as cl
|
|
40
|
+
import pyopencl.array as cl_array
|
|
37
41
|
from pyopencl.tools import (
|
|
38
42
|
DtypedArgument,
|
|
39
43
|
KernelTemplateBase,
|
|
@@ -43,6 +47,10 @@ from pyopencl.tools import (
|
|
|
43
47
|
)
|
|
44
48
|
|
|
45
49
|
|
|
50
|
+
if TYPE_CHECKING:
|
|
51
|
+
from pyopencl.typing import Allocator
|
|
52
|
+
|
|
53
|
+
|
|
46
54
|
# {{{ kernel source
|
|
47
55
|
|
|
48
56
|
KERNEL = r"""//CL//
|
|
@@ -133,7 +141,7 @@ class _ReductionInfo:
|
|
|
133
141
|
|
|
134
142
|
program: cl.Program
|
|
135
143
|
kernel: cl.Kernel
|
|
136
|
-
arg_types:
|
|
144
|
+
arg_types: list[DtypedArgument]
|
|
137
145
|
|
|
138
146
|
|
|
139
147
|
def _get_reduction_source(
|
|
@@ -143,12 +151,12 @@ def _get_reduction_source(
|
|
|
143
151
|
neutral: str,
|
|
144
152
|
reduce_expr: str,
|
|
145
153
|
map_expr: str,
|
|
146
|
-
parsed_args:
|
|
154
|
+
parsed_args: list[DtypedArgument],
|
|
147
155
|
name: str = "reduce_kernel",
|
|
148
156
|
preamble: str = "",
|
|
149
157
|
arg_prep: str = "",
|
|
150
|
-
device:
|
|
151
|
-
max_group_size:
|
|
158
|
+
device: cl.Device | None = None,
|
|
159
|
+
max_group_size: int | None = None) -> tuple[str, int]:
|
|
152
160
|
|
|
153
161
|
if device is not None:
|
|
154
162
|
devices = [device]
|
|
@@ -209,13 +217,13 @@ def get_reduction_kernel(
|
|
|
209
217
|
dtype_out: Any,
|
|
210
218
|
neutral: str,
|
|
211
219
|
reduce_expr: str,
|
|
212
|
-
map_expr:
|
|
213
|
-
arguments:
|
|
220
|
+
map_expr: str | None = None,
|
|
221
|
+
arguments: list[DtypedArgument] | None = None,
|
|
214
222
|
name: str = "reduce_kernel",
|
|
215
223
|
preamble: str = "",
|
|
216
|
-
device:
|
|
224
|
+
device: cl.Device | None = None,
|
|
217
225
|
options: Any = None,
|
|
218
|
-
max_group_size:
|
|
226
|
+
max_group_size: int | None = None) -> _ReductionInfo:
|
|
219
227
|
if stage not in (1, 2):
|
|
220
228
|
raise ValueError(f"unknown stage index: '{stage}'")
|
|
221
229
|
|
|
@@ -308,8 +316,8 @@ class ReductionKernel:
|
|
|
308
316
|
dtype_out: Any,
|
|
309
317
|
neutral: str,
|
|
310
318
|
reduce_expr: str,
|
|
311
|
-
map_expr:
|
|
312
|
-
arguments:
|
|
319
|
+
map_expr: str | None = None,
|
|
320
|
+
arguments: str | list[DtypedArgument] | None = None,
|
|
313
321
|
name: str = "reduce_kernel",
|
|
314
322
|
options: Any = None,
|
|
315
323
|
preamble: str = "") -> None:
|
|
@@ -349,7 +357,40 @@ class ReductionKernel:
|
|
|
349
357
|
name=f"{name}_stage2", options=options, preamble=preamble,
|
|
350
358
|
max_group_size=max_group_size)
|
|
351
359
|
|
|
352
|
-
|
|
360
|
+
@overload
|
|
361
|
+
def __call__(self,
|
|
362
|
+
*args: object,
|
|
363
|
+
return_event: Literal[True],
|
|
364
|
+
queue: cl.CommandQueue | None = None,
|
|
365
|
+
allocator: Allocator | None = None,
|
|
366
|
+
wait_for: cl.WaitList = None,
|
|
367
|
+
out: cl_array.Array | None = None,
|
|
368
|
+
range: slice | None = None,
|
|
369
|
+
slice: slice | None = None
|
|
370
|
+
) -> tuple[cl_array.Array, cl.Event]: ...
|
|
371
|
+
|
|
372
|
+
@overload
|
|
373
|
+
def __call__(self,
|
|
374
|
+
*args: object,
|
|
375
|
+
return_event: Literal[False],
|
|
376
|
+
queue: cl.CommandQueue | None = None,
|
|
377
|
+
allocator: Allocator | None = None,
|
|
378
|
+
wait_for: cl.WaitList = None,
|
|
379
|
+
out: cl_array.Array | None = None,
|
|
380
|
+
range: slice | None = None,
|
|
381
|
+
slice: slice | None = None
|
|
382
|
+
) -> cl_array.Array: ...
|
|
383
|
+
|
|
384
|
+
def __call__(self,
|
|
385
|
+
*args: object,
|
|
386
|
+
return_event: bool = False,
|
|
387
|
+
queue: cl.CommandQueue | None = None,
|
|
388
|
+
allocator: Allocator | None = None,
|
|
389
|
+
wait_for: cl.WaitList = None,
|
|
390
|
+
out: cl_array.Array | None = None,
|
|
391
|
+
range: slice | None = None,
|
|
392
|
+
slice: slice | None = None
|
|
393
|
+
) -> cl_array.Array | tuple[cl_array.Array, cl.Event]:
|
|
353
394
|
"""Invoke the generated kernel.
|
|
354
395
|
|
|
355
396
|
|explain-waitfor|
|
|
@@ -388,18 +429,6 @@ class ReductionKernel:
|
|
|
388
429
|
``(scalar_array, event)``.
|
|
389
430
|
"""
|
|
390
431
|
|
|
391
|
-
queue = kwargs.pop("queue", None)
|
|
392
|
-
allocator = kwargs.pop("allocator", None)
|
|
393
|
-
wait_for = kwargs.pop("wait_for", None)
|
|
394
|
-
return_event = kwargs.pop("return_event", False)
|
|
395
|
-
out = kwargs.pop("out", None)
|
|
396
|
-
|
|
397
|
-
range_ = kwargs.pop("range", None)
|
|
398
|
-
slice_ = kwargs.pop("slice", None)
|
|
399
|
-
|
|
400
|
-
if kwargs:
|
|
401
|
-
raise TypeError("invalid keyword argument to reduction kernel")
|
|
402
|
-
|
|
403
432
|
if wait_for is None:
|
|
404
433
|
wait_for = []
|
|
405
434
|
else:
|
|
@@ -413,13 +442,14 @@ class ReductionKernel:
|
|
|
413
442
|
|
|
414
443
|
while True:
|
|
415
444
|
invocation_args = []
|
|
416
|
-
vectors = []
|
|
445
|
+
vectors: list[cl_array.Array] = []
|
|
417
446
|
|
|
418
447
|
array_empty = empty
|
|
419
448
|
|
|
420
449
|
from pyopencl.tools import VectorArg
|
|
421
|
-
for arg, arg_tp in zip(args, stage_inf.arg_types):
|
|
450
|
+
for arg, arg_tp in zip(args, stage_inf.arg_types, strict=True):
|
|
422
451
|
if isinstance(arg_tp, VectorArg):
|
|
452
|
+
assert isinstance(arg, cl_array.Array)
|
|
423
453
|
array_empty = arg.__class__
|
|
424
454
|
if not arg.flags.forc:
|
|
425
455
|
raise RuntimeError(
|
|
@@ -441,31 +471,30 @@ class ReductionKernel:
|
|
|
441
471
|
|
|
442
472
|
# {{{ range/slice processing
|
|
443
473
|
|
|
444
|
-
if
|
|
445
|
-
if
|
|
474
|
+
if range is not None:
|
|
475
|
+
if slice is not None:
|
|
446
476
|
raise TypeError("may not specify both range and slice "
|
|
447
477
|
"keyword arguments")
|
|
448
478
|
|
|
449
479
|
else:
|
|
450
|
-
if
|
|
451
|
-
|
|
480
|
+
if slice is None:
|
|
481
|
+
slice = builtins.slice(None)
|
|
452
482
|
|
|
453
483
|
if repr_vec is None:
|
|
454
484
|
raise TypeError(
|
|
455
485
|
"must have vector argument when range is not specified")
|
|
456
486
|
|
|
457
|
-
|
|
487
|
+
range = builtins.slice(*slice.indices(repr_vec.size))
|
|
458
488
|
|
|
459
|
-
assert
|
|
489
|
+
assert range is not None
|
|
460
490
|
|
|
461
|
-
start =
|
|
491
|
+
start = cast("int | None", range.start)
|
|
462
492
|
if start is None:
|
|
463
493
|
start = 0
|
|
464
|
-
|
|
494
|
+
step = cast("int | None", range.step)
|
|
495
|
+
if step is None:
|
|
465
496
|
step = 1
|
|
466
|
-
|
|
467
|
-
step = range_.step
|
|
468
|
-
sz = abs(range_.stop - start)//step
|
|
497
|
+
sz = abs(cast("int", range.stop) - start) //step
|
|
469
498
|
|
|
470
499
|
# }}}
|
|
471
500
|
|
|
@@ -502,7 +531,7 @@ class ReductionKernel:
|
|
|
502
531
|
macrogroup_size = group_count*stage_inf.group_size
|
|
503
532
|
seq_count = (sz + macrogroup_size - 1) // macrogroup_size
|
|
504
533
|
|
|
505
|
-
size_args = [start, step,
|
|
534
|
+
size_args = [start, step, range.stop, seq_count, sz]
|
|
506
535
|
|
|
507
536
|
if group_count == 1 and out is not None:
|
|
508
537
|
result = out
|
|
@@ -534,7 +563,7 @@ class ReductionKernel:
|
|
|
534
563
|
stage_inf = self.stage_2_inf
|
|
535
564
|
args = (result, *stage1_args)
|
|
536
565
|
|
|
537
|
-
|
|
566
|
+
range = slice = None
|
|
538
567
|
|
|
539
568
|
# }}}
|
|
540
569
|
|
|
@@ -544,12 +573,12 @@ class ReductionKernel:
|
|
|
544
573
|
class ReductionTemplate(KernelTemplateBase):
|
|
545
574
|
def __init__(
|
|
546
575
|
self,
|
|
547
|
-
arguments:
|
|
576
|
+
arguments: str | list[DtypedArgument],
|
|
548
577
|
neutral: str,
|
|
549
578
|
reduce_expr: str,
|
|
550
|
-
map_expr:
|
|
551
|
-
is_segment_start_expr:
|
|
552
|
-
input_fetch_exprs:
|
|
579
|
+
map_expr: str | None = None,
|
|
580
|
+
is_segment_start_expr: str | None = None,
|
|
581
|
+
input_fetch_exprs: list[tuple[str, str, int]] | None = None,
|
|
553
582
|
name_prefix: str = "reduce",
|
|
554
583
|
preamble: str = "",
|
|
555
584
|
template_processor: Any = None) -> None:
|
pyopencl/scan.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Scan primitive."""
|
|
2
|
+
from __future__ import annotations
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
__copyright__ = """
|
|
@@ -25,7 +26,7 @@ Derived from code within the Thrust project, https://github.com/NVIDIA/thrust
|
|
|
25
26
|
import logging
|
|
26
27
|
from abc import ABC, abstractmethod
|
|
27
28
|
from dataclasses import dataclass
|
|
28
|
-
from typing import Any
|
|
29
|
+
from typing import Any
|
|
29
30
|
|
|
30
31
|
import numpy as np
|
|
31
32
|
|
|
@@ -33,7 +34,7 @@ from pytools.persistent_dict import WriteOncePersistentDict
|
|
|
33
34
|
|
|
34
35
|
import pyopencl as cl
|
|
35
36
|
import pyopencl._mymako as mako
|
|
36
|
-
import pyopencl.array
|
|
37
|
+
import pyopencl.array as cl_array
|
|
37
38
|
from pyopencl._cluda import CLUDA_PREAMBLE
|
|
38
39
|
from pyopencl.tools import (
|
|
39
40
|
DtypedArgument,
|
|
@@ -848,7 +849,7 @@ def _make_template(s: str):
|
|
|
848
849
|
import re
|
|
849
850
|
leftovers = set()
|
|
850
851
|
|
|
851
|
-
def replace_id(match:
|
|
852
|
+
def replace_id(match: re.Match) -> str:
|
|
852
853
|
# avoid name clashes with user code by adding 'psc_' prefix to
|
|
853
854
|
# identifiers.
|
|
854
855
|
|
|
@@ -874,11 +875,11 @@ def _make_template(s: str):
|
|
|
874
875
|
class _GeneratedScanKernelInfo:
|
|
875
876
|
scan_src: str
|
|
876
877
|
kernel_name: str
|
|
877
|
-
scalar_arg_dtypes:
|
|
878
|
+
scalar_arg_dtypes: list[np.dtype | None]
|
|
878
879
|
wg_size: int
|
|
879
880
|
k_group_size: int
|
|
880
881
|
|
|
881
|
-
def build(self, context: cl.Context, options: Any) ->
|
|
882
|
+
def build(self, context: cl.Context, options: Any) -> _BuiltScanKernelInfo:
|
|
882
883
|
program = cl.Program(context, self.scan_src).build(options)
|
|
883
884
|
kernel = getattr(program, self.kernel_name)
|
|
884
885
|
kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
|
|
@@ -899,12 +900,12 @@ class _BuiltScanKernelInfo:
|
|
|
899
900
|
class _GeneratedFinalUpdateKernelInfo:
|
|
900
901
|
source: str
|
|
901
902
|
kernel_name: str
|
|
902
|
-
scalar_arg_dtypes:
|
|
903
|
+
scalar_arg_dtypes: list[np.dtype | None]
|
|
903
904
|
update_wg_size: int
|
|
904
905
|
|
|
905
906
|
def build(self,
|
|
906
907
|
context: cl.Context,
|
|
907
|
-
options: Any) ->
|
|
908
|
+
options: Any) -> _BuiltFinalUpdateKernelInfo:
|
|
908
909
|
program = cl.Program(context, self.source).build(options)
|
|
909
910
|
kernel = getattr(program, self.kernel_name)
|
|
910
911
|
kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
|
|
@@ -930,18 +931,18 @@ class GenericScanKernelBase(ABC):
|
|
|
930
931
|
self,
|
|
931
932
|
ctx: cl.Context,
|
|
932
933
|
dtype: Any,
|
|
933
|
-
arguments:
|
|
934
|
+
arguments: str | list[DtypedArgument],
|
|
934
935
|
input_expr: str,
|
|
935
936
|
scan_expr: str,
|
|
936
|
-
neutral:
|
|
937
|
+
neutral: str | None,
|
|
937
938
|
output_statement: str,
|
|
938
|
-
is_segment_start_expr:
|
|
939
|
-
input_fetch_exprs:
|
|
939
|
+
is_segment_start_expr: str | None = None,
|
|
940
|
+
input_fetch_exprs: list[tuple[str, str, int]] | None = None,
|
|
940
941
|
index_dtype: Any = None,
|
|
941
942
|
name_prefix: str = "scan",
|
|
942
943
|
options: Any = None,
|
|
943
944
|
preamble: str = "",
|
|
944
|
-
devices:
|
|
945
|
+
devices: cl.Device | None = None) -> None:
|
|
945
946
|
"""
|
|
946
947
|
:arg ctx: a :class:`pyopencl.Context` within which the code
|
|
947
948
|
for this scan kernel will be generated.
|
|
@@ -1142,7 +1143,7 @@ class GenericScanKernelBase(ABC):
|
|
|
1142
1143
|
|
|
1143
1144
|
if not cl._PYOPENCL_NO_CACHE:
|
|
1144
1145
|
generic_scan_kernel_cache: WriteOncePersistentDict[Any,
|
|
1145
|
-
|
|
1146
|
+
tuple[_GeneratedScanKernelInfo, _GeneratedScanKernelInfo,
|
|
1146
1147
|
_GeneratedFinalUpdateKernelInfo]] = \
|
|
1147
1148
|
WriteOncePersistentDict(
|
|
1148
1149
|
"pyopencl-generated-scan-kernel-cache-v1",
|
|
@@ -1329,7 +1330,7 @@ class GenericScanKernel(GenericScanKernelBase):
|
|
|
1329
1330
|
VectorArg(self.dtype, "interval_sums"),
|
|
1330
1331
|
]
|
|
1331
1332
|
|
|
1332
|
-
second_level_build_kwargs:
|
|
1333
|
+
second_level_build_kwargs: dict[str, str | None] = {}
|
|
1333
1334
|
if self.is_segmented:
|
|
1334
1335
|
second_level_arguments.append(
|
|
1335
1336
|
VectorArg(self.index_dtype,
|
|
@@ -1401,7 +1402,7 @@ class GenericScanKernel(GenericScanKernelBase):
|
|
|
1401
1402
|
for arg in self.parsed_args:
|
|
1402
1403
|
arg_dtypes[arg.name] = arg.dtype
|
|
1403
1404
|
|
|
1404
|
-
fetch_expr_offsets:
|
|
1405
|
+
fetch_expr_offsets: dict[str, set] = {}
|
|
1405
1406
|
for _name, arg_name, ife_offset in self.input_fetch_exprs:
|
|
1406
1407
|
fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
|
|
1407
1408
|
|
|
@@ -1427,10 +1428,10 @@ class GenericScanKernel(GenericScanKernelBase):
|
|
|
1427
1428
|
def generate_scan_kernel(
|
|
1428
1429
|
self,
|
|
1429
1430
|
max_wg_size: int,
|
|
1430
|
-
arguments:
|
|
1431
|
+
arguments: list[DtypedArgument],
|
|
1431
1432
|
input_expr: str,
|
|
1432
|
-
is_segment_start_expr:
|
|
1433
|
-
input_fetch_exprs:
|
|
1433
|
+
is_segment_start_expr: str | None,
|
|
1434
|
+
input_fetch_exprs: list[tuple[str, str, int]],
|
|
1434
1435
|
is_first_level: bool,
|
|
1435
1436
|
store_segment_start_flags: bool,
|
|
1436
1437
|
k_group_size: int,
|
|
@@ -1527,7 +1528,7 @@ class GenericScanKernel(GenericScanKernelBase):
|
|
|
1527
1528
|
return cl.enqueue_marker(queue, wait_for=wait_for)
|
|
1528
1529
|
|
|
1529
1530
|
data_args = []
|
|
1530
|
-
for arg_descr, arg_val in zip(self.parsed_args, args):
|
|
1531
|
+
for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
|
|
1531
1532
|
from pyopencl.tools import VectorArg
|
|
1532
1533
|
if isinstance(arg_descr, VectorArg):
|
|
1533
1534
|
data_args.append(arg_val.base_data)
|
|
@@ -1552,16 +1553,16 @@ class GenericScanKernel(GenericScanKernelBase):
|
|
|
1552
1553
|
|
|
1553
1554
|
# {{{ allocate some buffers
|
|
1554
1555
|
|
|
1555
|
-
interval_results =
|
|
1556
|
+
interval_results = cl_array.empty(queue,
|
|
1556
1557
|
num_intervals, dtype=self.dtype,
|
|
1557
1558
|
allocator=allocator)
|
|
1558
1559
|
|
|
1559
|
-
partial_scan_buffer =
|
|
1560
|
+
partial_scan_buffer = cl_array.empty(
|
|
1560
1561
|
queue, n, dtype=self.dtype,
|
|
1561
1562
|
allocator=allocator)
|
|
1562
1563
|
|
|
1563
1564
|
if self.store_segment_start_flags:
|
|
1564
|
-
segment_start_flags =
|
|
1565
|
+
segment_start_flags = cl_array.empty(
|
|
1565
1566
|
queue, n, dtype=np.bool_,
|
|
1566
1567
|
allocator=allocator)
|
|
1567
1568
|
|
|
@@ -1575,7 +1576,7 @@ class GenericScanKernel(GenericScanKernelBase):
|
|
|
1575
1576
|
]
|
|
1576
1577
|
|
|
1577
1578
|
if self.is_segmented:
|
|
1578
|
-
first_segment_start_in_interval =
|
|
1579
|
+
first_segment_start_in_interval = cl_array.empty(queue,
|
|
1579
1580
|
num_intervals, dtype=self.index_dtype,
|
|
1580
1581
|
allocator=allocator)
|
|
1581
1582
|
scan1_args.append(first_segment_start_in_interval.data)
|
|
@@ -1755,13 +1756,13 @@ class GenericDebugScanKernel(GenericScanKernelBase):
|
|
|
1755
1756
|
if n is None:
|
|
1756
1757
|
n, = first_array.shape
|
|
1757
1758
|
|
|
1758
|
-
scan_tmp =
|
|
1759
|
+
scan_tmp = cl_array.empty(queue,
|
|
1759
1760
|
n, dtype=self.dtype,
|
|
1760
1761
|
allocator=allocator)
|
|
1761
1762
|
|
|
1762
1763
|
data_args = [scan_tmp.data]
|
|
1763
1764
|
from pyopencl.tools import VectorArg
|
|
1764
|
-
for arg_descr, arg_val in zip(self.parsed_args, args):
|
|
1765
|
+
for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
|
|
1765
1766
|
if isinstance(arg_descr, VectorArg):
|
|
1766
1767
|
data_args.append(arg_val.base_data)
|
|
1767
1768
|
if arg_descr.with_offset:
|
|
@@ -1806,7 +1807,7 @@ class _LegacyScanKernelBase(GenericScanKernel):
|
|
|
1806
1807
|
output_ary = input_ary
|
|
1807
1808
|
|
|
1808
1809
|
if isinstance(output_ary, (str, str)) and output_ary == "new":
|
|
1809
|
-
output_ary =
|
|
1810
|
+
output_ary = cl_array.empty_like(input_ary, allocator=allocator)
|
|
1810
1811
|
|
|
1811
1812
|
if input_ary.shape != output_ary.shape:
|
|
1812
1813
|
raise ValueError("input and output must have the same shape")
|
|
@@ -1841,13 +1842,13 @@ class ExclusiveScanKernel(_LegacyScanKernelBase):
|
|
|
1841
1842
|
class ScanTemplate(KernelTemplateBase):
|
|
1842
1843
|
def __init__(
|
|
1843
1844
|
self,
|
|
1844
|
-
arguments:
|
|
1845
|
+
arguments: str | list[DtypedArgument],
|
|
1845
1846
|
input_expr: str,
|
|
1846
1847
|
scan_expr: str,
|
|
1847
|
-
neutral:
|
|
1848
|
+
neutral: str | None,
|
|
1848
1849
|
output_statement: str,
|
|
1849
|
-
is_segment_start_expr:
|
|
1850
|
-
input_fetch_exprs:
|
|
1850
|
+
is_segment_start_expr: str | None = None,
|
|
1851
|
+
input_fetch_exprs: list[tuple[str, str, int]] | None = None,
|
|
1851
1852
|
name_prefix: str = "scan",
|
|
1852
1853
|
preamble: str = "",
|
|
1853
1854
|
template_processor: Any = None) -> None:
|