pyopencl 2025.2.5__cp312-cp312-macosx_11_0_arm64.whl → 2025.2.7__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyopencl might be problematic. Click here for more details.
- pyopencl/_cl.cpython-312-darwin.so +0 -0
- pyopencl/_cl.pyi +10 -7
- pyopencl/_monkeypatch.py +40 -5
- pyopencl/algorithm.py +1 -1
- pyopencl/array.py +214 -125
- pyopencl/cache.py +1 -1
- pyopencl/characterize/__init__.py +2 -4
- pyopencl/clmath.py +0 -1
- pyopencl/cltypes.py +42 -27
- pyopencl/compyte/array.py +9 -39
- pyopencl/compyte/dtypes.py +9 -11
- pyopencl/compyte/pyproject.toml +0 -3
- pyopencl/elementwise.py +223 -113
- pyopencl/scan.py +30 -25
- pyopencl/tools.py +327 -212
- {pyopencl-2025.2.5.dist-info → pyopencl-2025.2.7.dist-info}/METADATA +3 -4
- {pyopencl-2025.2.5.dist-info → pyopencl-2025.2.7.dist-info}/RECORD +19 -19
- {pyopencl-2025.2.5.dist-info → pyopencl-2025.2.7.dist-info}/WHEEL +1 -1
- {pyopencl-2025.2.5.dist-info → pyopencl-2025.2.7.dist-info}/licenses/LICENSE +0 -0
pyopencl/tools.py
CHANGED
|
@@ -128,23 +128,28 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
|
|
128
128
|
OTHER DEALINGS IN THE SOFTWARE.
|
|
129
129
|
"""
|
|
130
130
|
|
|
131
|
-
|
|
131
|
+
import atexit
|
|
132
132
|
import re
|
|
133
133
|
from abc import ABC, abstractmethod
|
|
134
|
-
from dataclasses import dataclass
|
|
134
|
+
from dataclasses import dataclass, field
|
|
135
135
|
from sys import intern
|
|
136
136
|
from typing import (
|
|
137
137
|
TYPE_CHECKING,
|
|
138
138
|
Any,
|
|
139
|
+
ClassVar,
|
|
139
140
|
Concatenate,
|
|
140
141
|
ParamSpec,
|
|
142
|
+
TypeAlias,
|
|
143
|
+
TypedDict,
|
|
141
144
|
TypeVar,
|
|
145
|
+
cast,
|
|
146
|
+
overload,
|
|
142
147
|
)
|
|
143
148
|
|
|
144
149
|
import numpy as np
|
|
145
150
|
from typing_extensions import TypeIs, override
|
|
146
151
|
|
|
147
|
-
from pytools import memoize, memoize_method
|
|
152
|
+
from pytools import Hash, memoize, memoize_method
|
|
148
153
|
from pytools.persistent_dict import KeyBuilder as KeyBuilderBase
|
|
149
154
|
|
|
150
155
|
from pyopencl._cl import bitlog2, get_cl_header_version
|
|
@@ -157,10 +162,11 @@ from pyopencl.compyte.dtypes import (
|
|
|
157
162
|
|
|
158
163
|
|
|
159
164
|
if TYPE_CHECKING:
|
|
160
|
-
from collections.abc import Callable, Hashable, Sequence
|
|
161
|
-
|
|
162
|
-
from numpy.typing import DTypeLike
|
|
165
|
+
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
|
|
163
166
|
|
|
167
|
+
import pytest
|
|
168
|
+
from mako.template import Template
|
|
169
|
+
from numpy.typing import DTypeLike, NDArray
|
|
164
170
|
|
|
165
171
|
# Do not add a pyopencl import here: This will add an import cycle.
|
|
166
172
|
|
|
@@ -373,7 +379,7 @@ def _monkeypatch_svm_docstrings():
|
|
|
373
379
|
|
|
374
380
|
# {{{ PooledSVM
|
|
375
381
|
|
|
376
|
-
PooledSVM.__doc__ = ( #
|
|
382
|
+
PooledSVM.__doc__ = ( # pyright: ignore[reportPossiblyUnboundVariable]
|
|
377
383
|
"""An object representing a :class:`SVMPool`-based allocation of
|
|
378
384
|
:ref:`svm`. Analogous to :class:`~pyopencl.SVMAllocation`, however once
|
|
379
385
|
this object is deleted, its associated device memory is returned to the
|
|
@@ -420,7 +426,7 @@ def _monkeypatch_svm_docstrings():
|
|
|
420
426
|
|
|
421
427
|
# {{{ SVMAllocator
|
|
422
428
|
|
|
423
|
-
SVMAllocator.__doc__ = ( #
|
|
429
|
+
SVMAllocator.__doc__ = ( # pyright: ignore[reportPossiblyUnboundVariable]
|
|
424
430
|
"""
|
|
425
431
|
.. versionadded:: 2022.2
|
|
426
432
|
|
|
@@ -453,7 +459,7 @@ def _monkeypatch_svm_docstrings():
|
|
|
453
459
|
|
|
454
460
|
# {{{ SVMPool
|
|
455
461
|
|
|
456
|
-
SVMPool.__doc__ = ( #
|
|
462
|
+
SVMPool.__doc__ = ( # pyright: ignore[reportPossiblyUnboundVariable]
|
|
457
463
|
remove_common_indentation("""
|
|
458
464
|
A memory pool for OpenCL device memory in :ref:`SVM <svm>` form.
|
|
459
465
|
*allocator* must be an instance of :class:`SVMAllocator`.
|
|
@@ -477,17 +483,18 @@ if get_cl_header_version() >= (2, 0):
|
|
|
477
483
|
|
|
478
484
|
# {{{ first-arg caches
|
|
479
485
|
|
|
480
|
-
_first_arg_dependent_caches: list[
|
|
486
|
+
_first_arg_dependent_caches: list[Mapping[Hashable, object]] = []
|
|
481
487
|
|
|
482
488
|
|
|
489
|
+
HashableT = TypeVar("HashableT", bound="Hashable")
|
|
483
490
|
RetT = TypeVar("RetT")
|
|
484
491
|
P = ParamSpec("P")
|
|
485
492
|
|
|
486
493
|
|
|
487
494
|
def first_arg_dependent_memoize(
|
|
488
|
-
func: Callable[Concatenate[
|
|
489
|
-
) -> Callable[Concatenate[
|
|
490
|
-
def wrapper(cl_object:
|
|
495
|
+
func: Callable[Concatenate[HashableT, P], RetT]
|
|
496
|
+
) -> Callable[Concatenate[HashableT, P], RetT]:
|
|
497
|
+
def wrapper(cl_object: HashableT, *args: P.args, **kwargs: P.kwargs) -> RetT:
|
|
491
498
|
"""Provides memoization for a function. Typically used to cache
|
|
492
499
|
things that get created inside a :class:`pyopencl.Context`, e.g. programs
|
|
493
500
|
and kernels. Assumes that the first argument of the decorated function is
|
|
@@ -527,33 +534,41 @@ def first_arg_dependent_memoize(
|
|
|
527
534
|
context_dependent_memoize = first_arg_dependent_memoize
|
|
528
535
|
|
|
529
536
|
|
|
530
|
-
def first_arg_dependent_memoize_nested(
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
:
|
|
537
|
+
def first_arg_dependent_memoize_nested(
|
|
538
|
+
nested_func: Callable[Concatenate[Hashable, P], RetT]
|
|
539
|
+
) -> Callable[Concatenate[Hashable, P], RetT]:
|
|
540
|
+
"""Provides memoization for nested functions.
|
|
541
|
+
|
|
542
|
+
Typically used to cache things that get created inside a
|
|
543
|
+
:class:`pyopencl.Context`, e.g. programs and kernels. Assumes that the first
|
|
544
|
+
argument of the decorated function is an OpenCL object that might go away,
|
|
545
|
+
such as a :class:`pyopencl.Context` or a :class:`pyopencl.CommandQueue`, and
|
|
546
|
+
will therefore respond to :func:`clear_first_arg_caches`.
|
|
537
547
|
|
|
538
548
|
.. versionadded:: 2013.1
|
|
539
549
|
"""
|
|
540
550
|
|
|
541
551
|
from functools import wraps
|
|
542
|
-
cache_dict_name = intern(
|
|
543
|
-
|
|
544
|
-
|
|
552
|
+
cache_dict_name = intern(
|
|
553
|
+
f"_memoize_inner_dic_{nested_func.__name__}_"
|
|
554
|
+
f"{nested_func.__code__.co_filename}_"
|
|
555
|
+
f"{nested_func.__code__.co_firstlineno}")
|
|
545
556
|
|
|
546
557
|
from inspect import currentframe
|
|
547
558
|
|
|
548
559
|
# prevent ref cycle
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
560
|
+
frame = currentframe()
|
|
561
|
+
cache_context = None
|
|
562
|
+
if frame:
|
|
563
|
+
try:
|
|
564
|
+
caller_frame = frame.f_back
|
|
565
|
+
if caller_frame:
|
|
566
|
+
cache_context = caller_frame.f_globals[caller_frame.f_code.co_name]
|
|
567
|
+
finally:
|
|
568
|
+
# del caller_frame
|
|
569
|
+
pass
|
|
570
|
+
|
|
571
|
+
cache_dict: dict[Hashable, dict[Hashable, RetT]]
|
|
557
572
|
try:
|
|
558
573
|
cache_dict = getattr(cache_context, cache_dict_name)
|
|
559
574
|
except AttributeError:
|
|
@@ -562,12 +577,14 @@ def first_arg_dependent_memoize_nested(nested_func):
|
|
|
562
577
|
setattr(cache_context, cache_dict_name, cache_dict)
|
|
563
578
|
|
|
564
579
|
@wraps(nested_func)
|
|
565
|
-
def new_nested_func(cl_object, *args):
|
|
580
|
+
def new_nested_func(cl_object: Hashable, *args: P.args, **kwargs: P.kwargs) -> RetT:
|
|
581
|
+
assert not kwargs
|
|
582
|
+
|
|
566
583
|
try:
|
|
567
584
|
return cache_dict[cl_object][args]
|
|
568
585
|
except KeyError:
|
|
569
586
|
arg_dict = cache_dict.setdefault(cl_object, {})
|
|
570
|
-
result = nested_func(cl_object, *args)
|
|
587
|
+
result = nested_func(cl_object, *args, **kwargs)
|
|
571
588
|
arg_dict[args] = result
|
|
572
589
|
return result
|
|
573
590
|
|
|
@@ -575,23 +592,24 @@ def first_arg_dependent_memoize_nested(nested_func):
|
|
|
575
592
|
|
|
576
593
|
|
|
577
594
|
def clear_first_arg_caches():
|
|
578
|
-
"""Empties all first-argument-dependent memoization caches.
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
595
|
+
"""Empties all first-argument-dependent memoization caches.
|
|
596
|
+
|
|
597
|
+
Also releases all held reference contexts. If it is important to you that the
|
|
598
|
+
program detaches from its context, you might need to call this function to
|
|
599
|
+
free all remaining references to your context.
|
|
582
600
|
|
|
583
601
|
.. versionadded:: 2011.2
|
|
584
602
|
"""
|
|
585
603
|
for cache in _first_arg_dependent_caches:
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
604
|
+
# NOTE: this could be fixed by making the caches a MutableMapping, but
|
|
605
|
+
# that doesn't seem to be correctly covariant in its values, so other
|
|
606
|
+
# parts fail to work nicely..
|
|
607
|
+
cache.clear() # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType]
|
|
590
608
|
|
|
591
609
|
|
|
592
610
|
if TYPE_CHECKING:
|
|
593
611
|
import pyopencl as cl
|
|
594
|
-
|
|
612
|
+
from pyopencl.array import Array as CLArray
|
|
595
613
|
|
|
596
614
|
atexit.register(clear_first_arg_caches)
|
|
597
615
|
|
|
@@ -621,9 +639,9 @@ class _ContextFactory:
|
|
|
621
639
|
@override
|
|
622
640
|
def __str__(self) -> str:
|
|
623
641
|
# Don't show address, so that parallel test collection works
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
642
|
+
device = self.device.name.strip()
|
|
643
|
+
platform = self.device.platform.name.strip()
|
|
644
|
+
return f"<context factory for <pyopencl.Device '{device}' on '{platform}'>>"
|
|
627
645
|
|
|
628
646
|
|
|
629
647
|
DeviceOrPlatformT = TypeVar("DeviceOrPlatformT", "cl.Device", "cl.Platform")
|
|
@@ -643,7 +661,7 @@ def _find_cl_obj(
|
|
|
643
661
|
for obj in objs:
|
|
644
662
|
if identifier.lower() in (obj.name + " " + obj.vendor).lower():
|
|
645
663
|
return obj
|
|
646
|
-
raise RuntimeError("object '
|
|
664
|
+
raise RuntimeError(f"object '{identifier}' not found")
|
|
647
665
|
|
|
648
666
|
|
|
649
667
|
def get_test_platforms_and_devices(
|
|
@@ -690,7 +708,9 @@ def get_test_platforms_and_devices(
|
|
|
690
708
|
for platform in cl.get_platforms()]
|
|
691
709
|
|
|
692
710
|
|
|
693
|
-
def get_pyopencl_fixture_arg_names(
|
|
711
|
+
def get_pyopencl_fixture_arg_names(
|
|
712
|
+
metafunc: pytest.Metafunc,
|
|
713
|
+
extra_arg_names: list[str] | None = None) -> list[str]:
|
|
694
714
|
if extra_arg_names is None:
|
|
695
715
|
extra_arg_names = []
|
|
696
716
|
|
|
@@ -700,7 +720,7 @@ def get_pyopencl_fixture_arg_names(metafunc, extra_arg_names=None):
|
|
|
700
720
|
*extra_arg_names
|
|
701
721
|
]
|
|
702
722
|
|
|
703
|
-
arg_names = []
|
|
723
|
+
arg_names: list[str] = []
|
|
704
724
|
for arg in supported_arg_names:
|
|
705
725
|
if arg not in metafunc.fixturenames:
|
|
706
726
|
continue
|
|
@@ -716,10 +736,11 @@ def get_pyopencl_fixture_arg_names(metafunc, extra_arg_names=None):
|
|
|
716
736
|
return arg_names
|
|
717
737
|
|
|
718
738
|
|
|
719
|
-
def get_pyopencl_fixture_arg_values()
|
|
739
|
+
def get_pyopencl_fixture_arg_values() -> tuple[list[dict[str, Any]],
|
|
740
|
+
Callable[[Any], str]]:
|
|
720
741
|
import pyopencl as cl
|
|
721
742
|
|
|
722
|
-
arg_values = []
|
|
743
|
+
arg_values: list[dict[str, Any]] = []
|
|
723
744
|
for platform, devices in get_test_platforms_and_devices():
|
|
724
745
|
for device in devices:
|
|
725
746
|
arg_dict = {
|
|
@@ -730,7 +751,7 @@ def get_pyopencl_fixture_arg_values():
|
|
|
730
751
|
}
|
|
731
752
|
arg_values.append(arg_dict)
|
|
732
753
|
|
|
733
|
-
def idfn(val):
|
|
754
|
+
def idfn(val: Any) -> str:
|
|
734
755
|
if isinstance(val, cl.Platform):
|
|
735
756
|
# Don't show address, so that parallel test collection works
|
|
736
757
|
return f"<pyopencl.Platform '{val.name}'>"
|
|
@@ -740,7 +761,7 @@ def get_pyopencl_fixture_arg_values():
|
|
|
740
761
|
return arg_values, idfn
|
|
741
762
|
|
|
742
763
|
|
|
743
|
-
def pytest_generate_tests_for_pyopencl(metafunc):
|
|
764
|
+
def pytest_generate_tests_for_pyopencl(metafunc: pytest.Metafunc) -> None:
|
|
744
765
|
"""Using the line::
|
|
745
766
|
|
|
746
767
|
from pyopencl.tools import pytest_generate_tests_for_pyopencl
|
|
@@ -751,7 +772,7 @@ def pytest_generate_tests_for_pyopencl(metafunc):
|
|
|
751
772
|
functions, and they will automatically be run for each OpenCL device/platform
|
|
752
773
|
in the system, as appropriate.
|
|
753
774
|
|
|
754
|
-
The following two environment
|
|
775
|
+
The following two environment variables is also supported to control
|
|
755
776
|
device/platform choice::
|
|
756
777
|
|
|
757
778
|
PYOPENCL_TEST=0:0,1;intel=i5,i7
|
|
@@ -774,6 +795,10 @@ def pytest_generate_tests_for_pyopencl(metafunc):
|
|
|
774
795
|
|
|
775
796
|
# {{{ C argument lists
|
|
776
797
|
|
|
798
|
+
ArgType: TypeAlias = "np.dtype[Any] | VectorArg"
|
|
799
|
+
ArgDType: TypeAlias = "np.dtype[Any] | None"
|
|
800
|
+
|
|
801
|
+
|
|
777
802
|
class Argument(ABC):
|
|
778
803
|
"""
|
|
779
804
|
.. automethod:: declarator
|
|
@@ -850,18 +875,19 @@ def parse_c_arg(c_arg: str, with_offset: bool = False) -> DtypedArgument:
|
|
|
850
875
|
c_arg = c_arg.replace("__global", "")
|
|
851
876
|
|
|
852
877
|
if with_offset:
|
|
853
|
-
def vec_arg_factory(dtype, name):
|
|
878
|
+
def vec_arg_factory(dtype: DTypeLike, name: str) -> VectorArg:
|
|
854
879
|
return VectorArg(dtype, name, with_offset=True)
|
|
855
880
|
else:
|
|
856
881
|
vec_arg_factory = VectorArg
|
|
857
882
|
|
|
858
883
|
from pyopencl.compyte.dtypes import parse_c_arg_backend
|
|
884
|
+
|
|
859
885
|
return parse_c_arg_backend(c_arg, ScalarArg, vec_arg_factory)
|
|
860
886
|
|
|
861
887
|
|
|
862
888
|
def parse_arg_list(
|
|
863
|
-
arguments: str |
|
|
864
|
-
with_offset: bool = False) ->
|
|
889
|
+
arguments: str | Sequence[str] | Sequence[Argument],
|
|
890
|
+
with_offset: bool = False) -> Sequence[DtypedArgument]:
|
|
865
891
|
"""Parse a list of kernel arguments. *arguments* may be a comma-separate
|
|
866
892
|
list of C declarators in a string, a list of strings representing C
|
|
867
893
|
declarators, or :class:`Argument` objects.
|
|
@@ -870,7 +896,7 @@ def parse_arg_list(
|
|
|
870
896
|
if isinstance(arguments, str):
|
|
871
897
|
arguments = arguments.split(",")
|
|
872
898
|
|
|
873
|
-
def parse_single_arg(obj: str |
|
|
899
|
+
def parse_single_arg(obj: str | Argument) -> DtypedArgument:
|
|
874
900
|
if isinstance(obj, str):
|
|
875
901
|
from pyopencl.tools import parse_c_arg
|
|
876
902
|
return parse_c_arg(obj, with_offset=with_offset)
|
|
@@ -881,8 +907,8 @@ def parse_arg_list(
|
|
|
881
907
|
return [parse_single_arg(arg) for arg in arguments]
|
|
882
908
|
|
|
883
909
|
|
|
884
|
-
def get_arg_list_arg_types(arg_types):
|
|
885
|
-
result = []
|
|
910
|
+
def get_arg_list_arg_types(arg_types: Sequence[Argument]) -> tuple[ArgType, ...]:
|
|
911
|
+
result: list[ArgType] = []
|
|
886
912
|
|
|
887
913
|
for arg_type in arg_types:
|
|
888
914
|
if isinstance(arg_type, ScalarArg):
|
|
@@ -890,15 +916,15 @@ def get_arg_list_arg_types(arg_types):
|
|
|
890
916
|
elif isinstance(arg_type, VectorArg):
|
|
891
917
|
result.append(arg_type)
|
|
892
918
|
else:
|
|
893
|
-
raise RuntimeError("arg type not understood:
|
|
919
|
+
raise RuntimeError(f"arg type not understood: {type(arg_type)}")
|
|
894
920
|
|
|
895
921
|
return tuple(result)
|
|
896
922
|
|
|
897
923
|
|
|
898
924
|
def get_arg_list_scalar_arg_dtypes(
|
|
899
|
-
arg_types: Sequence[
|
|
900
|
-
) ->
|
|
901
|
-
result: list[
|
|
925
|
+
arg_types: Sequence[Argument]
|
|
926
|
+
) -> Sequence[ArgDType]:
|
|
927
|
+
result: list[ArgDType] = []
|
|
902
928
|
|
|
903
929
|
for arg_type in arg_types:
|
|
904
930
|
if isinstance(arg_type, ScalarArg):
|
|
@@ -918,26 +944,26 @@ def get_arg_offset_adjuster_code(arg_types: Sequence[Argument]) -> str:
|
|
|
918
944
|
|
|
919
945
|
for arg_type in arg_types:
|
|
920
946
|
if isinstance(arg_type, VectorArg) and arg_type.with_offset:
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
947
|
+
name = arg_type.name
|
|
948
|
+
ctype = dtype_to_ctype(arg_type.dtype)
|
|
949
|
+
result.append(
|
|
950
|
+
f"__global {ctype} *{name} = "
|
|
951
|
+
f"(__global {ctype} *) "
|
|
952
|
+
f"((__global char *) {name}__base + {name}__offset);")
|
|
927
953
|
|
|
928
954
|
return "\n".join(result)
|
|
929
955
|
|
|
930
956
|
# }}}
|
|
931
957
|
|
|
932
958
|
|
|
933
|
-
def get_gl_sharing_context_properties():
|
|
959
|
+
def get_gl_sharing_context_properties() -> list[tuple[cl.context_properties, Any]]:
|
|
934
960
|
import pyopencl as cl
|
|
935
961
|
|
|
936
962
|
ctx_props = cl.context_properties
|
|
937
963
|
|
|
938
964
|
from OpenGL import platform as gl_platform
|
|
939
965
|
|
|
940
|
-
props = []
|
|
966
|
+
props: list[tuple[cl.context_properties, Any]] = []
|
|
941
967
|
|
|
942
968
|
import sys
|
|
943
969
|
if sys.platform in ["linux", "linux2"]:
|
|
@@ -959,24 +985,23 @@ def get_gl_sharing_context_properties():
|
|
|
959
985
|
(ctx_props.CONTEXT_PROPERTY_USE_CGL_SHAREGROUP_APPLE,
|
|
960
986
|
cl.get_apple_cgl_share_group()))
|
|
961
987
|
else:
|
|
962
|
-
raise NotImplementedError("platform '
|
|
963
|
-
% sys.platform)
|
|
988
|
+
raise NotImplementedError(f"platform '{sys.platform}' not yet supported")
|
|
964
989
|
|
|
965
990
|
return props
|
|
966
991
|
|
|
967
992
|
|
|
968
993
|
class _CDeclList:
|
|
969
|
-
def __init__(self, device):
|
|
970
|
-
self.device = device
|
|
971
|
-
self.declared_dtypes = set()
|
|
972
|
-
self.declarations = []
|
|
973
|
-
self.saw_double = False
|
|
974
|
-
self.saw_complex = False
|
|
975
|
-
|
|
976
|
-
def add_dtype(self, dtype):
|
|
994
|
+
def __init__(self, device: cl.Device) -> None:
|
|
995
|
+
self.device: cl.Device = device
|
|
996
|
+
self.declared_dtypes: set[np.dtype[Any]] = set()
|
|
997
|
+
self.declarations: list[str] = []
|
|
998
|
+
self.saw_double: bool = False
|
|
999
|
+
self.saw_complex: bool = False
|
|
1000
|
+
|
|
1001
|
+
def add_dtype(self, dtype: DTypeLike) -> None:
|
|
977
1002
|
dtype = np.dtype(dtype)
|
|
978
1003
|
|
|
979
|
-
if dtype in (np.float64, np.complex128):
|
|
1004
|
+
if dtype.type in (np.float64, np.complex128):
|
|
980
1005
|
self.saw_double = True
|
|
981
1006
|
|
|
982
1007
|
if dtype.kind == "c":
|
|
@@ -988,17 +1013,20 @@ class _CDeclList:
|
|
|
988
1013
|
if dtype in self.declared_dtypes:
|
|
989
1014
|
return
|
|
990
1015
|
|
|
991
|
-
|
|
992
|
-
|
|
1016
|
+
from pyopencl.cltypes import vec_type_to_scalar_and_count
|
|
1017
|
+
|
|
1018
|
+
if dtype in vec_type_to_scalar_and_count:
|
|
993
1019
|
return
|
|
994
1020
|
|
|
995
1021
|
if hasattr(dtype, "subdtype") and dtype.subdtype is not None:
|
|
996
1022
|
self.add_dtype(dtype.subdtype[0])
|
|
997
1023
|
return
|
|
998
1024
|
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1025
|
+
fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None", dtype.fields)
|
|
1026
|
+
if fields is not None:
|
|
1027
|
+
for _name, field_data in sorted(fields.items()):
|
|
1028
|
+
field_dtype, _offset = field_data[:2]
|
|
1029
|
+
self.add_dtype(field_dtype)
|
|
1002
1030
|
|
|
1003
1031
|
_, cdecl = match_dtype_to_c_struct(
|
|
1004
1032
|
self.device, dtype_to_ctype(dtype), dtype)
|
|
@@ -1006,16 +1034,19 @@ class _CDeclList:
|
|
|
1006
1034
|
self.declarations.append(cdecl)
|
|
1007
1035
|
self.declared_dtypes.add(dtype)
|
|
1008
1036
|
|
|
1009
|
-
def visit_arguments(self, arguments):
|
|
1037
|
+
def visit_arguments(self, arguments: Sequence[Argument]) -> None:
|
|
1010
1038
|
for arg in arguments:
|
|
1039
|
+
if not isinstance(arg, DtypedArgument):
|
|
1040
|
+
continue
|
|
1041
|
+
|
|
1011
1042
|
dtype = arg.dtype
|
|
1012
|
-
if dtype in (np.float64, np.complex128):
|
|
1043
|
+
if dtype.type in (np.float64, np.complex128):
|
|
1013
1044
|
self.saw_double = True
|
|
1014
1045
|
|
|
1015
1046
|
if dtype.kind == "c":
|
|
1016
1047
|
self.saw_complex = True
|
|
1017
1048
|
|
|
1018
|
-
def get_declarations(self):
|
|
1049
|
+
def get_declarations(self) -> str:
|
|
1019
1050
|
result = "\n\n".join(self.declarations)
|
|
1020
1051
|
|
|
1021
1052
|
if self.saw_complex:
|
|
@@ -1036,8 +1067,19 @@ class _CDeclList:
|
|
|
1036
1067
|
return result
|
|
1037
1068
|
|
|
1038
1069
|
|
|
1070
|
+
class _DTypeDict(TypedDict):
|
|
1071
|
+
names: list[str]
|
|
1072
|
+
formats: list[np.dtype[Any]]
|
|
1073
|
+
offsets: list[int]
|
|
1074
|
+
itemsize: int
|
|
1075
|
+
|
|
1076
|
+
|
|
1039
1077
|
@memoize
|
|
1040
|
-
def match_dtype_to_c_struct(
|
|
1078
|
+
def match_dtype_to_c_struct(
|
|
1079
|
+
device: cl.Device,
|
|
1080
|
+
name: str,
|
|
1081
|
+
dtype: np.dtype[Any],
|
|
1082
|
+
context: cl.Context | None = None) -> tuple[np.dtype[Any], str]:
|
|
1041
1083
|
"""Return a tuple ``(dtype, c_decl)`` such that the C struct declaration
|
|
1042
1084
|
in ``c_decl`` and the structure :class:`numpy.dtype` instance ``dtype``
|
|
1043
1085
|
have the same memory layout.
|
|
@@ -1073,14 +1115,18 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
|
|
|
1073
1115
|
:func:`get_or_register_dtype` on the modified ``dtype`` returned by this
|
|
1074
1116
|
function, not the original one.
|
|
1075
1117
|
"""
|
|
1118
|
+
fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None", dtype.fields)
|
|
1119
|
+
if not fields:
|
|
1120
|
+
raise ValueError(f"dtype has no fields: '{dtype}'")
|
|
1076
1121
|
|
|
1077
1122
|
import pyopencl as cl
|
|
1078
1123
|
|
|
1079
|
-
|
|
1124
|
+
sorted_fields = sorted(
|
|
1125
|
+
fields.items(),
|
|
1080
1126
|
key=lambda name_dtype_offset: name_dtype_offset[1][1])
|
|
1081
1127
|
|
|
1082
|
-
c_fields = []
|
|
1083
|
-
for field_name, dtype_and_offset in
|
|
1128
|
+
c_fields: list[str] = []
|
|
1129
|
+
for field_name, dtype_and_offset in sorted_fields:
|
|
1084
1130
|
field_dtype, _offset = dtype_and_offset[:2]
|
|
1085
1131
|
if hasattr(field_dtype, "subdtype") and field_dtype.subdtype is not None:
|
|
1086
1132
|
array_dtype = field_dtype.subdtype[0]
|
|
@@ -1090,9 +1136,9 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
|
|
|
1090
1136
|
dims_str = ""
|
|
1091
1137
|
try:
|
|
1092
1138
|
for dim in array_dims:
|
|
1093
|
-
dims_str += "[
|
|
1139
|
+
dims_str += f"[{dim}]"
|
|
1094
1140
|
except TypeError:
|
|
1095
|
-
dims_str = "[
|
|
1141
|
+
dims_str = f"[{array_dims}]"
|
|
1096
1142
|
c_fields.append(" {} {}{};".format(
|
|
1097
1143
|
dtype_to_ctype(array_dtype), field_name, dims_str)
|
|
1098
1144
|
)
|
|
@@ -1105,15 +1151,15 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
|
|
|
1105
1151
|
name)
|
|
1106
1152
|
|
|
1107
1153
|
cdl = _CDeclList(device)
|
|
1108
|
-
for _field_name, dtype_and_offset in
|
|
1154
|
+
for _field_name, dtype_and_offset in sorted_fields:
|
|
1109
1155
|
field_dtype, _offset = dtype_and_offset[:2]
|
|
1110
1156
|
cdl.add_dtype(field_dtype)
|
|
1111
1157
|
|
|
1112
1158
|
pre_decls = cdl.get_declarations()
|
|
1113
1159
|
|
|
1114
1160
|
offset_code = "\n".join(
|
|
1115
|
-
"result[
|
|
1116
|
-
for i, (field_name, _) in enumerate(
|
|
1161
|
+
f"result[{i + 1}] = pycl_offsetof({name}, {field_name});"
|
|
1162
|
+
for i, (field_name, _) in enumerate(sorted_fields))
|
|
1117
1163
|
|
|
1118
1164
|
src = rf"""
|
|
1119
1165
|
#define pycl_offsetof(st, m) \
|
|
@@ -1140,30 +1186,29 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
|
|
|
1140
1186
|
prg = cl.Program(context, src)
|
|
1141
1187
|
knl = prg.build(devices=[device]).get_size_and_offsets
|
|
1142
1188
|
|
|
1143
|
-
import pyopencl.array
|
|
1189
|
+
import pyopencl.array as cl_array
|
|
1190
|
+
|
|
1191
|
+
result_buf = cl_array.empty(queue, 1+len(sorted_fields), np.uint32)
|
|
1192
|
+
assert result_buf.data is not None
|
|
1144
1193
|
|
|
1145
|
-
result_buf = cl.array.empty(queue, 1+len(fields), np.uint32)
|
|
1146
1194
|
knl(queue, (1,), (1,), result_buf.data)
|
|
1147
1195
|
queue.finish()
|
|
1148
1196
|
size_and_offsets = result_buf.get()
|
|
1149
1197
|
|
|
1150
1198
|
size = int(size_and_offsets[0])
|
|
1151
|
-
|
|
1152
1199
|
offsets = size_and_offsets[1:]
|
|
1200
|
+
|
|
1153
1201
|
if any(ofs >= size for ofs in offsets):
|
|
1154
1202
|
# offsets not plausible
|
|
1155
1203
|
|
|
1156
1204
|
if dtype.itemsize == size:
|
|
1157
1205
|
# If sizes match, use numpy's idea of the offsets.
|
|
1158
|
-
offsets = [dtype_and_offset[1]
|
|
1159
|
-
for field_name, dtype_and_offset in fields]
|
|
1206
|
+
offsets = [dtype_and_offset[1] for _name, dtype_and_offset in sorted_fields]
|
|
1160
1207
|
else:
|
|
1161
1208
|
raise RuntimeError(
|
|
1162
|
-
"OpenCL compiler reported offsetof() past sizeof() "
|
|
1163
|
-
"
|
|
1164
|
-
"
|
|
1165
|
-
"compiler bug. "
|
|
1166
|
-
"Refusing to discover struct layout." % device)
|
|
1209
|
+
"OpenCL compiler reported offsetof() past sizeof() for struct "
|
|
1210
|
+
f"layout on '{device}'. This makes no sense, and it usually "
|
|
1211
|
+
"indicates a compiler bug. Refusing to discover struct layout.")
|
|
1167
1212
|
|
|
1168
1213
|
result_buf.data.release()
|
|
1169
1214
|
del knl
|
|
@@ -1172,58 +1217,66 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
|
|
|
1172
1217
|
del context
|
|
1173
1218
|
|
|
1174
1219
|
try:
|
|
1175
|
-
dtype_arg_dict =
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
if dtype.itemsize != size_and_offsets[0]:
|
|
1220
|
+
dtype_arg_dict = _DTypeDict(
|
|
1221
|
+
names=[name for name, _ in sorted_fields],
|
|
1222
|
+
formats=[dtype_and_offset[0] for _, dtype_and_offset in sorted_fields],
|
|
1223
|
+
offsets=[int(x) for x in offsets],
|
|
1224
|
+
itemsize=int(size_and_offsets[0]),
|
|
1225
|
+
)
|
|
1226
|
+
arg_dtype = np.dtype(dtype_arg_dict)
|
|
1227
|
+
|
|
1228
|
+
if arg_dtype.itemsize != size_and_offsets[0]:
|
|
1185
1229
|
# "Old" versions of numpy (1.6.x?) silently ignore "itemsize". Boo.
|
|
1186
1230
|
dtype_arg_dict["names"].append("_pycl_size_fixer")
|
|
1187
|
-
dtype_arg_dict["formats"].append(np.uint8)
|
|
1188
|
-
dtype_arg_dict["offsets"].append(int(size_and_offsets[0])-1)
|
|
1189
|
-
|
|
1231
|
+
dtype_arg_dict["formats"].append(np.dtype(np.uint8))
|
|
1232
|
+
dtype_arg_dict["offsets"].append(int(size_and_offsets[0]) - 1)
|
|
1233
|
+
|
|
1234
|
+
arg_dtype = np.dtype(dtype_arg_dict)
|
|
1190
1235
|
except NotImplementedError:
|
|
1191
|
-
def calc_field_type():
|
|
1236
|
+
def calc_field_type() -> Iterator[tuple[str, str | np.dtype[Any]]]:
|
|
1192
1237
|
total_size = 0
|
|
1193
1238
|
padding_count = 0
|
|
1194
|
-
for offset, (field_name,
|
|
1195
|
-
|
|
1239
|
+
for offset, (field_name, dtype_and_offset) in zip(
|
|
1240
|
+
offsets, sorted_fields, strict=True):
|
|
1241
|
+
field_dtype, _ = dtype_and_offset[:2]
|
|
1196
1242
|
if offset > total_size:
|
|
1197
1243
|
padding_count += 1
|
|
1198
|
-
yield
|
|
1199
|
-
|
|
1244
|
+
yield f"__pycl_padding{padding_count}", f"V{offset - total_size}"
|
|
1245
|
+
|
|
1200
1246
|
yield field_name, field_dtype
|
|
1201
1247
|
total_size = field_dtype.itemsize + offset
|
|
1202
|
-
dtype = np.dtype(list(calc_field_type()))
|
|
1203
1248
|
|
|
1204
|
-
|
|
1249
|
+
arg_dtype = np.dtype(list(calc_field_type()))
|
|
1250
|
+
|
|
1251
|
+
assert arg_dtype.itemsize == size_and_offsets[0]
|
|
1205
1252
|
|
|
1206
|
-
return
|
|
1253
|
+
return arg_dtype, c_decl
|
|
1207
1254
|
|
|
1208
1255
|
|
|
1209
1256
|
@memoize
|
|
1210
|
-
def dtype_to_c_struct(device, dtype):
|
|
1211
|
-
|
|
1257
|
+
def dtype_to_c_struct(device: cl.Device, dtype: np.dtype[Any]) -> str:
|
|
1258
|
+
fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None", dtype.fields)
|
|
1259
|
+
if fields is None:
|
|
1212
1260
|
return ""
|
|
1213
1261
|
|
|
1214
|
-
|
|
1215
|
-
|
|
1262
|
+
from pyopencl.cltypes import vec_type_to_scalar_and_count
|
|
1263
|
+
|
|
1264
|
+
if dtype in vec_type_to_scalar_and_count:
|
|
1216
1265
|
# Vector types are built-in. Don't try to redeclare those.
|
|
1217
1266
|
return ""
|
|
1218
1267
|
|
|
1219
1268
|
matched_dtype, c_decl = match_dtype_to_c_struct(
|
|
1220
1269
|
device, dtype_to_ctype(dtype), dtype)
|
|
1221
1270
|
|
|
1222
|
-
|
|
1223
|
-
|
|
1271
|
+
matched_fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None",
|
|
1272
|
+
matched_dtype.fields)
|
|
1273
|
+
assert matched_fields is not None
|
|
1274
|
+
|
|
1275
|
+
def dtypes_match() -> bool:
|
|
1276
|
+
result = len(fields) == len(matched_fields)
|
|
1224
1277
|
|
|
1225
|
-
for name, val in
|
|
1226
|
-
result = result and
|
|
1278
|
+
for name, val in fields.items():
|
|
1279
|
+
result = result and matched_fields[name] == val
|
|
1227
1280
|
|
|
1228
1281
|
return result
|
|
1229
1282
|
|
|
@@ -1234,38 +1287,52 @@ def dtype_to_c_struct(device, dtype):
|
|
|
1234
1287
|
|
|
1235
1288
|
# {{{ code generation/templating helper
|
|
1236
1289
|
|
|
1237
|
-
def _process_code_for_macro(code):
|
|
1290
|
+
def _process_code_for_macro(code: str) -> str:
|
|
1238
1291
|
code = code.replace("//CL//", "\n")
|
|
1239
1292
|
|
|
1240
1293
|
if "//" in code:
|
|
1241
|
-
raise RuntimeError(
|
|
1242
|
-
|
|
1294
|
+
raise RuntimeError(
|
|
1295
|
+
"end-of-line comments ('//') may not be used in code snippets")
|
|
1243
1296
|
|
|
1244
1297
|
return code.replace("\n", " \\\n")
|
|
1245
1298
|
|
|
1246
1299
|
|
|
1247
|
-
class
|
|
1248
|
-
|
|
1249
|
-
|
|
1300
|
+
class _TextTemplate(ABC):
|
|
1301
|
+
@abstractmethod
|
|
1302
|
+
def render(self, context: dict[str, Any]) -> str:
|
|
1303
|
+
pass
|
|
1304
|
+
|
|
1305
|
+
|
|
1306
|
+
@dataclass(frozen=True)
|
|
1307
|
+
class _SimpleTextTemplate(_TextTemplate):
|
|
1308
|
+
txt: str
|
|
1250
1309
|
|
|
1251
|
-
|
|
1310
|
+
@override
|
|
1311
|
+
def render(self, context: dict[str, Any]) -> str:
|
|
1252
1312
|
return self.txt
|
|
1253
1313
|
|
|
1254
1314
|
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1315
|
+
@dataclass(frozen=True)
|
|
1316
|
+
class _PrintfTextTemplate(_TextTemplate):
|
|
1317
|
+
txt: str
|
|
1258
1318
|
|
|
1259
|
-
|
|
1319
|
+
@override
|
|
1320
|
+
def render(self, context: dict[str, Any]) -> str:
|
|
1260
1321
|
return self.txt % context
|
|
1261
1322
|
|
|
1262
1323
|
|
|
1263
|
-
|
|
1264
|
-
|
|
1324
|
+
@dataclass(frozen=True)
|
|
1325
|
+
class _MakoTextTemplate(_TextTemplate):
|
|
1326
|
+
txt: str
|
|
1327
|
+
template: Template = field(init=False)
|
|
1328
|
+
|
|
1329
|
+
def __post_init__(self) -> None:
|
|
1265
1330
|
from mako.template import Template
|
|
1266
|
-
self.template = Template(txt, strict_undefined=True)
|
|
1267
1331
|
|
|
1268
|
-
|
|
1332
|
+
object.__setattr__(self, "template", Template(self.txt, strict_undefined=True))
|
|
1333
|
+
|
|
1334
|
+
@override
|
|
1335
|
+
def render(self, context: dict[str, Any]) -> str:
|
|
1269
1336
|
return self.template.render(**context)
|
|
1270
1337
|
|
|
1271
1338
|
|
|
@@ -1279,36 +1346,52 @@ class _ArgumentPlaceholder:
|
|
|
1279
1346
|
See also :class:`_TemplateRenderer.render_arg`.
|
|
1280
1347
|
"""
|
|
1281
1348
|
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1349
|
+
target_class: ClassVar[type[DtypedArgument]]
|
|
1350
|
+
|
|
1351
|
+
def __init__(self,
|
|
1352
|
+
typename: DTypeLike,
|
|
1353
|
+
name: str,
|
|
1354
|
+
**extra_kwargs: Any) -> None:
|
|
1355
|
+
self.typename: DTypeLike = typename
|
|
1356
|
+
self.name: str = name
|
|
1357
|
+
self.extra_kwargs: dict[str, Any] = extra_kwargs
|
|
1286
1358
|
|
|
1287
1359
|
|
|
1288
1360
|
class _VectorArgPlaceholder(_ArgumentPlaceholder):
|
|
1289
|
-
target_class = VectorArg
|
|
1361
|
+
target_class: ClassVar[type[DtypedArgument]] = VectorArg
|
|
1290
1362
|
|
|
1291
1363
|
|
|
1292
1364
|
class _ScalarArgPlaceholder(_ArgumentPlaceholder):
|
|
1293
|
-
target_class = ScalarArg
|
|
1365
|
+
target_class: ClassVar[type[DtypedArgument]] = ScalarArg
|
|
1294
1366
|
|
|
1295
1367
|
|
|
1296
1368
|
class _TemplateRenderer:
|
|
1297
|
-
def __init__(self,
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1369
|
+
def __init__(self,
|
|
1370
|
+
template: KernelTemplateBase,
|
|
1371
|
+
type_aliases: (
|
|
1372
|
+
dict[str, np.dtype[Any]]
|
|
1373
|
+
| Sequence[tuple[str, np.dtype[Any]]]),
|
|
1374
|
+
var_values: dict[str, str] | Sequence[tuple[str, str]],
|
|
1375
|
+
context: cl.Context | None = None,
|
|
1376
|
+
options: Any = None) -> None:
|
|
1377
|
+
self.template: KernelTemplateBase = template
|
|
1378
|
+
self.type_aliases: dict[str, np.dtype[Any]] = dict(type_aliases)
|
|
1379
|
+
self.var_dict: dict[str, str] = dict(var_values)
|
|
1302
1380
|
|
|
1303
1381
|
for name in self.var_dict:
|
|
1304
1382
|
if name.startswith("macro_"):
|
|
1305
|
-
self.var_dict[name] = _process_code_for_macro(
|
|
1306
|
-
|
|
1383
|
+
self.var_dict[name] = _process_code_for_macro(self.var_dict[name])
|
|
1384
|
+
|
|
1385
|
+
self.context: cl.Context | None = context
|
|
1386
|
+
self.options: Any = options
|
|
1307
1387
|
|
|
1308
|
-
|
|
1309
|
-
|
|
1388
|
+
@overload
|
|
1389
|
+
def __call__(self, txt: None) -> None: ...
|
|
1310
1390
|
|
|
1311
|
-
|
|
1391
|
+
@overload
|
|
1392
|
+
def __call__(self, txt: str) -> str: ...
|
|
1393
|
+
|
|
1394
|
+
def __call__(self, txt: str | None) -> str | None:
|
|
1312
1395
|
if txt is None:
|
|
1313
1396
|
return txt
|
|
1314
1397
|
|
|
@@ -1316,7 +1399,10 @@ class _TemplateRenderer:
|
|
|
1316
1399
|
|
|
1317
1400
|
return str(result)
|
|
1318
1401
|
|
|
1319
|
-
def get_rendered_kernel(self, txt, kernel_name):
|
|
1402
|
+
def get_rendered_kernel(self, txt: str, kernel_name: str) -> cl.Kernel:
|
|
1403
|
+
if self.context is None:
|
|
1404
|
+
raise ValueError("context not provided -- cannot render kernel")
|
|
1405
|
+
|
|
1320
1406
|
import pyopencl as cl
|
|
1321
1407
|
prg = cl.Program(self.context, self(txt)).build(self.options)
|
|
1322
1408
|
|
|
@@ -1326,7 +1412,7 @@ class _TemplateRenderer:
|
|
|
1326
1412
|
|
|
1327
1413
|
return getattr(prg, kernel_name)
|
|
1328
1414
|
|
|
1329
|
-
def parse_type(self, typename):
|
|
1415
|
+
def parse_type(self, typename: Any) -> np.dtype[Any]:
|
|
1330
1416
|
if isinstance(typename, str):
|
|
1331
1417
|
try:
|
|
1332
1418
|
return self.type_aliases[typename]
|
|
@@ -1336,21 +1422,22 @@ class _TemplateRenderer:
|
|
|
1336
1422
|
else:
|
|
1337
1423
|
return np.dtype(typename)
|
|
1338
1424
|
|
|
1339
|
-
def render_arg(self, arg_placeholder):
|
|
1425
|
+
def render_arg(self, arg_placeholder: _ArgumentPlaceholder) -> DtypedArgument:
|
|
1340
1426
|
return arg_placeholder.target_class(
|
|
1341
1427
|
self.parse_type(arg_placeholder.typename),
|
|
1342
1428
|
arg_placeholder.name,
|
|
1343
1429
|
**arg_placeholder.extra_kwargs)
|
|
1344
1430
|
|
|
1345
|
-
_C_COMMENT_FINDER = re.compile(r"/\*.*?\*/")
|
|
1431
|
+
_C_COMMENT_FINDER: ClassVar[re.Pattern[str]] = re.compile(r"/\*.*?\*/")
|
|
1346
1432
|
|
|
1347
|
-
def render_argument_list(self,
|
|
1348
|
-
|
|
1433
|
+
def render_argument_list(self,
|
|
1434
|
+
*arg_lists: Any,
|
|
1435
|
+
with_offset: bool = False,
|
|
1436
|
+
**kwargs: Any) -> list[Argument]:
|
|
1349
1437
|
if kwargs:
|
|
1350
1438
|
raise TypeError("unrecognized kwargs: " + ", ".join(kwargs))
|
|
1351
1439
|
|
|
1352
|
-
all_args = []
|
|
1353
|
-
|
|
1440
|
+
all_args: list[Any] = []
|
|
1354
1441
|
for arg_list in arg_lists:
|
|
1355
1442
|
if isinstance(arg_list, str):
|
|
1356
1443
|
arg_list = str(
|
|
@@ -1364,13 +1451,16 @@ class _TemplateRenderer:
|
|
|
1364
1451
|
all_args.extend(arg_list)
|
|
1365
1452
|
|
|
1366
1453
|
if with_offset:
|
|
1367
|
-
def vec_arg_factory(
|
|
1454
|
+
def vec_arg_factory(
|
|
1455
|
+
typename: DTypeLike,
|
|
1456
|
+
name: str) -> _VectorArgPlaceholder:
|
|
1368
1457
|
return _VectorArgPlaceholder(typename, name, with_offset=True)
|
|
1369
1458
|
else:
|
|
1370
1459
|
vec_arg_factory = _VectorArgPlaceholder
|
|
1371
1460
|
|
|
1372
1461
|
from pyopencl.compyte.dtypes import parse_c_arg_backend
|
|
1373
|
-
|
|
1462
|
+
|
|
1463
|
+
parsed_args: list[Argument] = []
|
|
1374
1464
|
for arg in all_args:
|
|
1375
1465
|
if isinstance(arg, str):
|
|
1376
1466
|
arg = arg.strip()
|
|
@@ -1379,21 +1469,26 @@ class _TemplateRenderer:
|
|
|
1379
1469
|
|
|
1380
1470
|
ph = parse_c_arg_backend(arg,
|
|
1381
1471
|
_ScalarArgPlaceholder, vec_arg_factory,
|
|
1382
|
-
name_to_dtype=lambda x: x)
|
|
1472
|
+
name_to_dtype=lambda x: x) # pyright: ignore[reportArgumentType]
|
|
1383
1473
|
parsed_arg = self.render_arg(ph)
|
|
1384
|
-
|
|
1385
1474
|
elif isinstance(arg, Argument):
|
|
1386
1475
|
parsed_arg = arg
|
|
1387
1476
|
elif isinstance(arg, tuple):
|
|
1477
|
+
assert isinstance(arg[0], str)
|
|
1478
|
+
assert isinstance(arg[1], str)
|
|
1388
1479
|
parsed_arg = ScalarArg(self.parse_type(arg[0]), arg[1])
|
|
1389
1480
|
else:
|
|
1390
|
-
raise TypeError("unexpected argument type:
|
|
1481
|
+
raise TypeError(f"unexpected argument type: {type(arg)}")
|
|
1391
1482
|
|
|
1392
1483
|
parsed_args.append(parsed_arg)
|
|
1393
1484
|
|
|
1394
1485
|
return parsed_args
|
|
1395
1486
|
|
|
1396
|
-
def get_type_decl_preamble(self,
|
|
1487
|
+
def get_type_decl_preamble(self,
|
|
1488
|
+
device: cl.Device,
|
|
1489
|
+
decl_type_names: Sequence[DTypeLike],
|
|
1490
|
+
arguments: Sequence[Argument] | None = None,
|
|
1491
|
+
) -> str:
|
|
1397
1492
|
cdl = _CDeclList(device)
|
|
1398
1493
|
|
|
1399
1494
|
for typename in decl_type_names:
|
|
@@ -1413,20 +1508,19 @@ class _TemplateRenderer:
|
|
|
1413
1508
|
return cdl.get_declarations() + "\n" + "\n".join(type_alias_decls)
|
|
1414
1509
|
|
|
1415
1510
|
|
|
1416
|
-
class KernelTemplateBase:
|
|
1417
|
-
def __init__(self, template_processor=None):
|
|
1418
|
-
self.template_processor = template_processor
|
|
1511
|
+
class KernelTemplateBase(ABC):
|
|
1512
|
+
def __init__(self, template_processor: str | None = None) -> None:
|
|
1513
|
+
self.template_processor: str | None = template_processor
|
|
1419
1514
|
|
|
1420
|
-
self.build_cache = {}
|
|
1515
|
+
self.build_cache: dict[Hashable, Any] = {}
|
|
1421
1516
|
_first_arg_dependent_caches.append(self.build_cache)
|
|
1422
1517
|
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
_TEMPLATE_PROCESSOR_PATTERN = re.compile(r"^//CL(?::([a-zA-Z0-9_]+))?//")
|
|
1518
|
+
_TEMPLATE_PROCESSOR_PATTERN: ClassVar[re.Pattern[str]] = (
|
|
1519
|
+
re.compile(r"^//CL(?::([a-zA-Z0-9_]+))?//")
|
|
1520
|
+
)
|
|
1427
1521
|
|
|
1428
1522
|
@memoize_method
|
|
1429
|
-
def get_text_template(self, txt):
|
|
1523
|
+
def get_text_template(self, txt: str) -> _TextTemplate:
|
|
1430
1524
|
proc_match = self._TEMPLATE_PROCESSOR_PATTERN.match(txt)
|
|
1431
1525
|
tpl_processor = None
|
|
1432
1526
|
|
|
@@ -1434,6 +1528,7 @@ class KernelTemplateBase:
|
|
|
1434
1528
|
tpl_processor = proc_match.group(1)
|
|
1435
1529
|
# chop off //CL// mark
|
|
1436
1530
|
txt = txt[len(proc_match.group(0)):]
|
|
1531
|
+
|
|
1437
1532
|
if tpl_processor is None:
|
|
1438
1533
|
tpl_processor = self.template_processor
|
|
1439
1534
|
|
|
@@ -1444,16 +1539,30 @@ class KernelTemplateBase:
|
|
|
1444
1539
|
elif tpl_processor == "mako":
|
|
1445
1540
|
return _MakoTextTemplate(txt)
|
|
1446
1541
|
else:
|
|
1447
|
-
raise RuntimeError(
|
|
1448
|
-
|
|
1542
|
+
raise RuntimeError(f"unknown template processor '{tpl_processor}'")
|
|
1543
|
+
|
|
1544
|
+
# TODO: this does not seem to be used anywhere -> deprecate / remove
|
|
1545
|
+
def get_preamble(self) -> str:
|
|
1546
|
+
return ""
|
|
1449
1547
|
|
|
1450
|
-
def get_renderer(self,
|
|
1548
|
+
def get_renderer(self,
|
|
1549
|
+
type_aliases: (
|
|
1550
|
+
dict[str, np.dtype[Any]]
|
|
1551
|
+
| Sequence[tuple[str, np.dtype[Any]]]),
|
|
1552
|
+
var_values: dict[str, str] | Sequence[tuple[str, str]],
|
|
1553
|
+
context: cl.Context | None = None, # pyright: ignore[reportUnusedParameter]
|
|
1554
|
+
options: Any = None, # pyright: ignore[reportUnusedParameter]
|
|
1555
|
+
) -> _TemplateRenderer:
|
|
1451
1556
|
return _TemplateRenderer(self, type_aliases, var_values)
|
|
1452
1557
|
|
|
1453
|
-
|
|
1454
|
-
|
|
1558
|
+
@abstractmethod
|
|
1559
|
+
def build_inner(self,
|
|
1560
|
+
context: cl.Context,
|
|
1561
|
+
*args: Any,
|
|
1562
|
+
**kwargs: Any) -> Callable[..., cl.Event]:
|
|
1563
|
+
pass
|
|
1455
1564
|
|
|
1456
|
-
def build(self, context, *args, **kwargs):
|
|
1565
|
+
def build(self, context: cl.Context, *args: Any, **kwargs: Any) -> Any:
|
|
1457
1566
|
"""Provide caching for an :meth:`build_inner`."""
|
|
1458
1567
|
|
|
1459
1568
|
cache_key = (context, args, tuple(sorted(kwargs.items())))
|
|
@@ -1469,41 +1578,47 @@ class KernelTemplateBase:
|
|
|
1469
1578
|
|
|
1470
1579
|
# {{{ array_module
|
|
1471
1580
|
|
|
1581
|
+
# TODO: this is not used anywhere: deprecate + remove
|
|
1582
|
+
|
|
1472
1583
|
class _CLFakeArrayModule:
|
|
1473
|
-
def __init__(self, queue):
|
|
1474
|
-
self.queue = queue
|
|
1584
|
+
def __init__(self, queue: cl.CommandQueue | None = None) -> None:
|
|
1585
|
+
self.queue: cl.CommandQueue | None = queue
|
|
1475
1586
|
|
|
1476
1587
|
@property
|
|
1477
|
-
def ndarray(self):
|
|
1588
|
+
def ndarray(self) -> type[CLArray]:
|
|
1478
1589
|
from pyopencl.array import Array
|
|
1479
1590
|
return Array
|
|
1480
1591
|
|
|
1481
|
-
def dot(self, x, y):
|
|
1592
|
+
def dot(self, x: CLArray, y: CLArray) -> NDArray[Any]:
|
|
1482
1593
|
from pyopencl.array import dot
|
|
1483
1594
|
return dot(x, y, queue=self.queue).get()
|
|
1484
1595
|
|
|
1485
|
-
def vdot(self, x, y):
|
|
1596
|
+
def vdot(self, x: CLArray, y: CLArray) -> NDArray[Any]:
|
|
1486
1597
|
from pyopencl.array import vdot
|
|
1487
1598
|
return vdot(x, y, queue=self.queue).get()
|
|
1488
1599
|
|
|
1489
|
-
def empty(self,
|
|
1600
|
+
def empty(self,
|
|
1601
|
+
shape: int | tuple[int, ...],
|
|
1602
|
+
dtype: DTypeLike,
|
|
1603
|
+
order: str = "C") -> CLArray:
|
|
1490
1604
|
from pyopencl.array import empty
|
|
1491
1605
|
return empty(self.queue, shape, dtype, order=order)
|
|
1492
1606
|
|
|
1493
|
-
def hstack(self, arrays):
|
|
1607
|
+
def hstack(self, arrays: Sequence[CLArray]) -> CLArray:
|
|
1494
1608
|
from pyopencl.array import hstack
|
|
1495
1609
|
return hstack(arrays, self.queue)
|
|
1496
1610
|
|
|
1497
1611
|
|
|
1498
|
-
def array_module(a):
|
|
1612
|
+
def array_module(a: Any) -> Any:
|
|
1499
1613
|
if isinstance(a, np.ndarray):
|
|
1500
1614
|
return np
|
|
1501
1615
|
else:
|
|
1502
1616
|
from pyopencl.array import Array
|
|
1617
|
+
|
|
1503
1618
|
if isinstance(a, Array):
|
|
1504
1619
|
return _CLFakeArrayModule(a.queue)
|
|
1505
1620
|
else:
|
|
1506
|
-
raise TypeError("array type not understood:
|
|
1621
|
+
raise TypeError(f"array type not understood: {type(a)}")
|
|
1507
1622
|
|
|
1508
1623
|
# }}}
|
|
1509
1624
|
|
|
@@ -1519,19 +1634,19 @@ def is_spirv(s: str | bytes) -> TypeIs[bytes]:
|
|
|
1519
1634
|
|
|
1520
1635
|
# {{{ numpy key types builder
|
|
1521
1636
|
|
|
1522
|
-
class _NumpyTypesKeyBuilder(KeyBuilderBase):
|
|
1523
|
-
def update_for_VectorArg(self, key_hash, key): # noqa: N802
|
|
1637
|
+
class _NumpyTypesKeyBuilder(KeyBuilderBase): # pyright: ignore[reportUnusedClass]
|
|
1638
|
+
def update_for_VectorArg(self, key_hash: Hash, key: VectorArg) -> None: # noqa: N802
|
|
1524
1639
|
self.rec(key_hash, key.dtype)
|
|
1525
1640
|
self.update_for_str(key_hash, key.name)
|
|
1526
1641
|
self.rec(key_hash, key.with_offset)
|
|
1527
1642
|
|
|
1528
|
-
|
|
1643
|
+
@override
|
|
1644
|
+
def update_for_type(self, key_hash: Hash, key: type) -> None:
|
|
1529
1645
|
if issubclass(key, np.generic):
|
|
1530
1646
|
self.update_for_str(key_hash, key.__name__)
|
|
1531
1647
|
return
|
|
1532
1648
|
|
|
1533
|
-
raise TypeError("unsupported type for persistent hash keying:
|
|
1534
|
-
% type(key))
|
|
1649
|
+
raise TypeError(f"unsupported type for persistent hash keying: {key}")
|
|
1535
1650
|
|
|
1536
1651
|
# }}}
|
|
1537
1652
|
|