pyopencl 2025.1__cp312-cp312-win_amd64.whl → 2025.2.1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyopencl might be problematic. More details are available on the registry's advisory page for this release (the original link was not preserved in this copy).

pyopencl/scan.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Scan primitive."""
2
+ from __future__ import annotations
2
3
 
3
4
 
4
5
  __copyright__ = """
@@ -25,7 +26,7 @@ Derived from code within the Thrust project, https://github.com/NVIDIA/thrust
25
26
  import logging
26
27
  from abc import ABC, abstractmethod
27
28
  from dataclasses import dataclass
28
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
29
+ from typing import Any
29
30
 
30
31
  import numpy as np
31
32
 
@@ -33,7 +34,7 @@ from pytools.persistent_dict import WriteOncePersistentDict
33
34
 
34
35
  import pyopencl as cl
35
36
  import pyopencl._mymako as mako
36
- import pyopencl.array
37
+ import pyopencl.array as cl_array
37
38
  from pyopencl._cluda import CLUDA_PREAMBLE
38
39
  from pyopencl.tools import (
39
40
  DtypedArgument,
@@ -848,7 +849,7 @@ def _make_template(s: str):
848
849
  import re
849
850
  leftovers = set()
850
851
 
851
- def replace_id(match: "re.Match") -> str:
852
+ def replace_id(match: re.Match) -> str:
852
853
  # avoid name clashes with user code by adding 'psc_' prefix to
853
854
  # identifiers.
854
855
 
@@ -874,11 +875,11 @@ def _make_template(s: str):
874
875
  class _GeneratedScanKernelInfo:
875
876
  scan_src: str
876
877
  kernel_name: str
877
- scalar_arg_dtypes: List[Optional[np.dtype]]
878
+ scalar_arg_dtypes: list[np.dtype | None]
878
879
  wg_size: int
879
880
  k_group_size: int
880
881
 
881
- def build(self, context: cl.Context, options: Any) -> "_BuiltScanKernelInfo":
882
+ def build(self, context: cl.Context, options: Any) -> _BuiltScanKernelInfo:
882
883
  program = cl.Program(context, self.scan_src).build(options)
883
884
  kernel = getattr(program, self.kernel_name)
884
885
  kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
@@ -899,12 +900,12 @@ class _BuiltScanKernelInfo:
899
900
  class _GeneratedFinalUpdateKernelInfo:
900
901
  source: str
901
902
  kernel_name: str
902
- scalar_arg_dtypes: List[Optional[np.dtype]]
903
+ scalar_arg_dtypes: list[np.dtype | None]
903
904
  update_wg_size: int
904
905
 
905
906
  def build(self,
906
907
  context: cl.Context,
907
- options: Any) -> "_BuiltFinalUpdateKernelInfo":
908
+ options: Any) -> _BuiltFinalUpdateKernelInfo:
908
909
  program = cl.Program(context, self.source).build(options)
909
910
  kernel = getattr(program, self.kernel_name)
910
911
  kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
@@ -930,18 +931,18 @@ class GenericScanKernelBase(ABC):
930
931
  self,
931
932
  ctx: cl.Context,
932
933
  dtype: Any,
933
- arguments: Union[str, List[DtypedArgument]],
934
+ arguments: str | list[DtypedArgument],
934
935
  input_expr: str,
935
936
  scan_expr: str,
936
- neutral: Optional[str],
937
+ neutral: str | None,
937
938
  output_statement: str,
938
- is_segment_start_expr: Optional[str] = None,
939
- input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
939
+ is_segment_start_expr: str | None = None,
940
+ input_fetch_exprs: list[tuple[str, str, int]] | None = None,
940
941
  index_dtype: Any = None,
941
942
  name_prefix: str = "scan",
942
943
  options: Any = None,
943
944
  preamble: str = "",
944
- devices: Optional[cl.Device] = None) -> None:
945
+ devices: cl.Device | None = None) -> None:
945
946
  """
946
947
  :arg ctx: a :class:`pyopencl.Context` within which the code
947
948
  for this scan kernel will be generated.
@@ -1142,7 +1143,7 @@ class GenericScanKernelBase(ABC):
1142
1143
 
1143
1144
  if not cl._PYOPENCL_NO_CACHE:
1144
1145
  generic_scan_kernel_cache: WriteOncePersistentDict[Any,
1145
- Tuple[_GeneratedScanKernelInfo, _GeneratedScanKernelInfo,
1146
+ tuple[_GeneratedScanKernelInfo, _GeneratedScanKernelInfo,
1146
1147
  _GeneratedFinalUpdateKernelInfo]] = \
1147
1148
  WriteOncePersistentDict(
1148
1149
  "pyopencl-generated-scan-kernel-cache-v1",
@@ -1329,7 +1330,7 @@ class GenericScanKernel(GenericScanKernelBase):
1329
1330
  VectorArg(self.dtype, "interval_sums"),
1330
1331
  ]
1331
1332
 
1332
- second_level_build_kwargs: Dict[str, Optional[str]] = {}
1333
+ second_level_build_kwargs: dict[str, str | None] = {}
1333
1334
  if self.is_segmented:
1334
1335
  second_level_arguments.append(
1335
1336
  VectorArg(self.index_dtype,
@@ -1401,7 +1402,7 @@ class GenericScanKernel(GenericScanKernelBase):
1401
1402
  for arg in self.parsed_args:
1402
1403
  arg_dtypes[arg.name] = arg.dtype
1403
1404
 
1404
- fetch_expr_offsets: Dict[str, Set] = {}
1405
+ fetch_expr_offsets: dict[str, set] = {}
1405
1406
  for _name, arg_name, ife_offset in self.input_fetch_exprs:
1406
1407
  fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
1407
1408
 
@@ -1427,10 +1428,10 @@ class GenericScanKernel(GenericScanKernelBase):
1427
1428
  def generate_scan_kernel(
1428
1429
  self,
1429
1430
  max_wg_size: int,
1430
- arguments: List[DtypedArgument],
1431
+ arguments: list[DtypedArgument],
1431
1432
  input_expr: str,
1432
- is_segment_start_expr: Optional[str],
1433
- input_fetch_exprs: List[Tuple[str, str, int]],
1433
+ is_segment_start_expr: str | None,
1434
+ input_fetch_exprs: list[tuple[str, str, int]],
1434
1435
  is_first_level: bool,
1435
1436
  store_segment_start_flags: bool,
1436
1437
  k_group_size: int,
@@ -1527,7 +1528,7 @@ class GenericScanKernel(GenericScanKernelBase):
1527
1528
  return cl.enqueue_marker(queue, wait_for=wait_for)
1528
1529
 
1529
1530
  data_args = []
1530
- for arg_descr, arg_val in zip(self.parsed_args, args):
1531
+ for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
1531
1532
  from pyopencl.tools import VectorArg
1532
1533
  if isinstance(arg_descr, VectorArg):
1533
1534
  data_args.append(arg_val.base_data)
@@ -1552,16 +1553,16 @@ class GenericScanKernel(GenericScanKernelBase):
1552
1553
 
1553
1554
  # {{{ allocate some buffers
1554
1555
 
1555
- interval_results = cl.array.empty(queue,
1556
+ interval_results = cl_array.empty(queue,
1556
1557
  num_intervals, dtype=self.dtype,
1557
1558
  allocator=allocator)
1558
1559
 
1559
- partial_scan_buffer = cl.array.empty(
1560
+ partial_scan_buffer = cl_array.empty(
1560
1561
  queue, n, dtype=self.dtype,
1561
1562
  allocator=allocator)
1562
1563
 
1563
1564
  if self.store_segment_start_flags:
1564
- segment_start_flags = cl.array.empty(
1565
+ segment_start_flags = cl_array.empty(
1565
1566
  queue, n, dtype=np.bool_,
1566
1567
  allocator=allocator)
1567
1568
 
@@ -1575,7 +1576,7 @@ class GenericScanKernel(GenericScanKernelBase):
1575
1576
  ]
1576
1577
 
1577
1578
  if self.is_segmented:
1578
- first_segment_start_in_interval = cl.array.empty(queue,
1579
+ first_segment_start_in_interval = cl_array.empty(queue,
1579
1580
  num_intervals, dtype=self.index_dtype,
1580
1581
  allocator=allocator)
1581
1582
  scan1_args.append(first_segment_start_in_interval.data)
@@ -1755,13 +1756,13 @@ class GenericDebugScanKernel(GenericScanKernelBase):
1755
1756
  if n is None:
1756
1757
  n, = first_array.shape
1757
1758
 
1758
- scan_tmp = cl.array.empty(queue,
1759
+ scan_tmp = cl_array.empty(queue,
1759
1760
  n, dtype=self.dtype,
1760
1761
  allocator=allocator)
1761
1762
 
1762
1763
  data_args = [scan_tmp.data]
1763
1764
  from pyopencl.tools import VectorArg
1764
- for arg_descr, arg_val in zip(self.parsed_args, args):
1765
+ for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
1765
1766
  if isinstance(arg_descr, VectorArg):
1766
1767
  data_args.append(arg_val.base_data)
1767
1768
  if arg_descr.with_offset:
@@ -1806,7 +1807,7 @@ class _LegacyScanKernelBase(GenericScanKernel):
1806
1807
  output_ary = input_ary
1807
1808
 
1808
1809
  if isinstance(output_ary, (str, str)) and output_ary == "new":
1809
- output_ary = cl.array.empty_like(input_ary, allocator=allocator)
1810
+ output_ary = cl_array.empty_like(input_ary, allocator=allocator)
1810
1811
 
1811
1812
  if input_ary.shape != output_ary.shape:
1812
1813
  raise ValueError("input and output must have the same shape")
@@ -1841,13 +1842,13 @@ class ExclusiveScanKernel(_LegacyScanKernelBase):
1841
1842
  class ScanTemplate(KernelTemplateBase):
1842
1843
  def __init__(
1843
1844
  self,
1844
- arguments: Union[str, List[DtypedArgument]],
1845
+ arguments: str | list[DtypedArgument],
1845
1846
  input_expr: str,
1846
1847
  scan_expr: str,
1847
- neutral: Optional[str],
1848
+ neutral: str | None,
1848
1849
  output_statement: str,
1849
- is_segment_start_expr: Optional[str] = None,
1850
- input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
1850
+ is_segment_start_expr: str | None = None,
1851
+ input_fetch_exprs: list[tuple[str, str, int]] | None = None,
1851
1852
  name_prefix: str = "scan",
1852
1853
  preamble: str = "",
1853
1854
  template_processor: Any = None) -> None:
pyopencl/tools.py CHANGED
@@ -100,6 +100,8 @@ Type aliases
100
100
  See :class:`pyopencl.tools.AllocatorBase`.
101
101
  """
102
102
 
103
+ from __future__ import annotations
104
+
103
105
 
104
106
  __copyright__ = "Copyright (C) 2010 Andreas Kloeckner"
105
107
 
@@ -126,18 +128,27 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
126
128
  OTHER DEALINGS IN THE SOFTWARE.
127
129
  """
128
130
 
131
+
129
132
  import re
130
133
  from abc import ABC, abstractmethod
134
+ from dataclasses import dataclass
131
135
  from sys import intern
132
- from typing import Any, List, Optional, Union
136
+ from typing import (
137
+ TYPE_CHECKING,
138
+ Any,
139
+ Concatenate,
140
+ ParamSpec,
141
+ TypeVar,
142
+ )
133
143
 
134
144
  import numpy as np
145
+ from typing_extensions import TypeIs, override
135
146
 
136
147
  from pytools import memoize, memoize_method
137
148
  from pytools.persistent_dict import KeyBuilder as KeyBuilderBase
138
149
 
139
- from pyopencl._cl import bitlog2, get_cl_header_version # noqa: F401
140
- from pyopencl.compyte.dtypes import ( # noqa: F401
150
+ from pyopencl._cl import bitlog2, get_cl_header_version
151
+ from pyopencl.compyte.dtypes import (
141
152
  TypeNameNotKnown,
142
153
  dtype_to_ctype,
143
154
  get_or_register_dtype,
@@ -145,6 +156,12 @@ from pyopencl.compyte.dtypes import ( # noqa: F401
145
156
  )
146
157
 
147
158
 
159
+ if TYPE_CHECKING:
160
+ from collections.abc import Callable, Hashable, Sequence
161
+
162
+ from numpy.typing import DTypeLike
163
+
164
+
148
165
  # Do not add a pyopencl import here: This will add an import cycle.
149
166
 
150
167
 
@@ -460,11 +477,17 @@ if get_cl_header_version() >= (2, 0):
460
477
 
461
478
  # {{{ first-arg caches
462
479
 
463
- _first_arg_dependent_caches = []
480
+ _first_arg_dependent_caches: list[dict[Hashable, object]] = []
481
+
482
+
483
+ RetT = TypeVar("RetT")
484
+ P = ParamSpec("P")
464
485
 
465
486
 
466
- def first_arg_dependent_memoize(func):
467
- def wrapper(cl_object, *args, **kwargs):
487
+ def first_arg_dependent_memoize(
488
+ func: Callable[Concatenate[object, P], RetT]
489
+ ) -> Callable[Concatenate[object, P], RetT]:
490
+ def wrapper(cl_object: Hashable, *args: P.args, **kwargs: P.kwargs) -> RetT:
468
491
  """Provides memoization for a function. Typically used to cache
469
492
  things that get created inside a :class:`pyopencl.Context`, e.g. programs
470
493
  and kernels. Assumes that the first argument of the decorated function is
@@ -479,12 +502,13 @@ def first_arg_dependent_memoize(func):
479
502
  else:
480
503
  cache_key = (args,)
481
504
 
505
+ ctx_dict: dict[Hashable, dict[Hashable, RetT]]
482
506
  try:
483
- ctx_dict = func._pyopencl_first_arg_dep_memoize_dic
507
+ ctx_dict = func._pyopencl_first_arg_dep_memoize_dic # pyright: ignore[reportFunctionMemberAccess]
484
508
  except AttributeError:
485
509
  # FIXME: This may keep contexts alive longer than desired.
486
510
  # But I guess since the memory in them is freed, who cares.
487
- ctx_dict = func._pyopencl_first_arg_dep_memoize_dic = {}
511
+ ctx_dict = func._pyopencl_first_arg_dep_memoize_dic = {} # pyright: ignore[reportFunctionMemberAccess]
488
512
  _first_arg_dependent_caches.append(ctx_dict)
489
513
 
490
514
  try:
@@ -512,8 +536,6 @@ def first_arg_dependent_memoize_nested(nested_func):
512
536
  :func:`clear_first_arg_caches`.
513
537
 
514
538
  .. versionadded:: 2013.1
515
-
516
- Requires Python 2.5 or newer.
517
539
  """
518
540
 
519
541
  from functools import wraps
@@ -567,6 +589,10 @@ def clear_first_arg_caches():
567
589
  import atexit
568
590
 
569
591
 
592
+ if TYPE_CHECKING:
593
+ import pyopencl as cl
594
+
595
+
570
596
  atexit.register(clear_first_arg_caches)
571
597
 
572
598
  # }}}
@@ -575,7 +601,9 @@ atexit.register(clear_first_arg_caches)
575
601
  # {{{ pytest fixtures
576
602
 
577
603
  class _ContextFactory:
578
- def __init__(self, device):
604
+ device: cl.Device
605
+
606
+ def __init__(self, device: cl.Device):
579
607
  self.device = device
580
608
 
581
609
  def __call__(self):
@@ -590,14 +618,37 @@ class _ContextFactory:
590
618
  import pyopencl as cl
591
619
  return cl.Context([self.device])
592
620
 
593
- def __str__(self):
621
+ @override
622
+ def __str__(self) -> str:
594
623
  # Don't show address, so that parallel test collection works
595
624
  return ("<context factory for <pyopencl.Device '%s' on '%s'>>" %
596
625
  (self.device.name.strip(),
597
626
  self.device.platform.name.strip()))
598
627
 
599
628
 
600
- def get_test_platforms_and_devices(plat_dev_string=None):
629
+ DeviceOrPlatformT = TypeVar("DeviceOrPlatformT", "cl.Device", "cl.Platform")
630
+
631
+
632
+ def _find_cl_obj(
633
+ objs: Sequence[DeviceOrPlatformT],
634
+ identifier: str
635
+ ) -> DeviceOrPlatformT:
636
+ try:
637
+ num = int(identifier)
638
+ except Exception:
639
+ pass
640
+ else:
641
+ return objs[num]
642
+
643
+ for obj in objs:
644
+ if identifier.lower() in (obj.name + " " + obj.vendor).lower():
645
+ return obj
646
+ raise RuntimeError("object '%s' not found" % identifier)
647
+
648
+
649
+ def get_test_platforms_and_devices(
650
+ plat_dev_string: str | None = None
651
+ ):
601
652
  """Parse a string of the form 'PYOPENCL_TEST=0:0,1;intel:i5'.
602
653
 
603
654
  :return: list of tuples (platform, [device, device, ...])
@@ -609,29 +660,14 @@ def get_test_platforms_and_devices(plat_dev_string=None):
609
660
  import os
610
661
  plat_dev_string = os.environ.get("PYOPENCL_TEST", None)
611
662
 
612
- def find_cl_obj(objs, identifier):
613
- try:
614
- num = int(identifier)
615
- except Exception:
616
- pass
617
- else:
618
- return objs[num]
619
-
620
- found = False
621
- for obj in objs:
622
- if identifier.lower() in (obj.name + " " + obj.vendor).lower():
623
- return obj
624
- if not found:
625
- raise RuntimeError("object '%s' not found" % identifier)
626
-
627
663
  if plat_dev_string:
628
- result = []
664
+ result: list[tuple[cl.Platform, list[cl.Device]]] = []
629
665
 
630
666
  for entry in plat_dev_string.split(";"):
631
667
  lhsrhs = entry.split(":")
632
668
 
633
669
  if len(lhsrhs) == 1:
634
- platform = find_cl_obj(cl.get_platforms(), lhsrhs[0])
670
+ platform = _find_cl_obj(cl.get_platforms(), lhsrhs[0])
635
671
  result.append((platform, platform.get_devices()))
636
672
 
637
673
  elif len(lhsrhs) != 2:
@@ -639,11 +675,11 @@ def get_test_platforms_and_devices(plat_dev_string=None):
639
675
  else:
640
676
  plat_str, dev_strs = lhsrhs
641
677
 
642
- platform = find_cl_obj(cl.get_platforms(), plat_str)
678
+ platform = _find_cl_obj(cl.get_platforms(), plat_str)
643
679
  devs = platform.get_devices()
644
680
  result.append(
645
681
  (platform,
646
- [find_cl_obj(devs, dev_id)
682
+ [_find_cl_obj(devs, dev_id)
647
683
  for dev_id in dev_strs.split(",")]))
648
684
 
649
685
  return result
@@ -748,44 +784,33 @@ class Argument(ABC):
748
784
  pass
749
785
 
750
786
 
751
- class DtypedArgument(Argument):
787
+ @dataclass(frozen=True, init=False)
788
+ class DtypedArgument(Argument, ABC):
752
789
  """
753
- .. attribute:: name
754
- .. attribute:: dtype
790
+ .. autoattribute:: name
791
+ .. autoattribute:: dtype
755
792
  """
793
+ dtype: np.dtype[Any]
794
+ name: str
756
795
 
757
- def __init__(self, dtype: Any, name: str) -> None:
758
- self.dtype = np.dtype(dtype)
759
- self.name = name
760
-
761
- def __repr__(self) -> str:
762
- return "{}({!r}, {})".format(
763
- self.__class__.__name__,
764
- self.name,
765
- self.dtype)
766
-
767
- def __eq__(self, other: Any) -> bool:
768
- return (type(self) is type(other)
769
- and self.dtype == other.dtype
770
- and self.name == other.name)
771
-
772
- def __hash__(self) -> int:
773
- return (
774
- hash(type(self))
775
- ^ hash(self.dtype)
776
- ^ hash(self.name))
796
+ def __init__(self, dtype: DTypeLike, name: str) -> None:
797
+ object.__setattr__(self, "name", name)
798
+ object.__setattr__(self, "dtype", np.dtype(dtype))
777
799
 
778
800
 
801
+ @dataclass(frozen=True)
779
802
  class VectorArg(DtypedArgument):
780
803
  """Inherits from :class:`DtypedArgument`.
781
804
 
782
805
  .. automethod:: __init__
783
806
  """
807
+ with_offset: bool
784
808
 
785
- def __init__(self, dtype: Any, name: str, with_offset: bool = False) -> None:
786
- super().__init__(dtype, name)
787
- self.with_offset = with_offset
809
+ def __init__(self, name: str, dtype: DTypeLike, with_offset: bool = False):
810
+ super().__init__(name, dtype)
811
+ object.__setattr__(self, "with_offset", with_offset)
788
812
 
813
+ @override
789
814
  def declarator(self) -> str:
790
815
  if self.with_offset:
791
816
  # Two underscores -> less likelihood of a name clash.
@@ -796,40 +821,25 @@ class VectorArg(DtypedArgument):
796
821
 
797
822
  return result
798
823
 
799
- def __eq__(self, other) -> bool:
800
- return (super().__eq__(other)
801
- and self.with_offset == other.with_offset)
802
-
803
- def __hash__(self) -> int:
804
- return super().__hash__() ^ hash(self.with_offset)
805
-
806
824
 
825
+ @dataclass(frozen=True, init=False)
807
826
  class ScalarArg(DtypedArgument):
808
827
  """Inherits from :class:`DtypedArgument`."""
809
828
 
810
- def declarator(self):
829
+ @override
830
+ def declarator(self) -> str:
811
831
  return "{} {}".format(dtype_to_ctype(self.dtype), self.name)
812
832
 
813
833
 
834
+ @dataclass(frozen=True)
814
835
  class OtherArg(Argument):
815
- def __init__(self, declarator: str, name: str) -> None:
816
- self.decl = declarator
817
- self.name = name
836
+ decl: str
837
+ name: str
818
838
 
839
+ @override
819
840
  def declarator(self) -> str:
820
841
  return self.decl
821
842
 
822
- def __eq__(self, other) -> bool:
823
- return (type(self) is type(other)
824
- and self.decl == other.decl
825
- and self.name == other.name)
826
-
827
- def __hash__(self) -> int:
828
- return (
829
- hash(type(self))
830
- ^ hash(self.decl)
831
- ^ hash(self.name))
832
-
833
843
 
834
844
  def parse_c_arg(c_arg: str, with_offset: bool = False) -> DtypedArgument:
835
845
  for aspace in ["__local", "__constant"]:
@@ -850,8 +860,8 @@ def parse_c_arg(c_arg: str, with_offset: bool = False) -> DtypedArgument:
850
860
 
851
861
 
852
862
  def parse_arg_list(
853
- arguments: Union[str, List[str], List[DtypedArgument]],
854
- with_offset: bool = False) -> List[DtypedArgument]:
863
+ arguments: str | list[str] | list[DtypedArgument],
864
+ with_offset: bool = False) -> list[DtypedArgument]:
855
865
  """Parse a list of kernel arguments. *arguments* may be a comma-separate
856
866
  list of C declarators in a string, a list of strings representing C
857
867
  declarators, or :class:`Argument` objects.
@@ -860,7 +870,7 @@ def parse_arg_list(
860
870
  if isinstance(arguments, str):
861
871
  arguments = arguments.split(",")
862
872
 
863
- def parse_single_arg(obj: Union[str, DtypedArgument]) -> DtypedArgument:
873
+ def parse_single_arg(obj: str | DtypedArgument) -> DtypedArgument:
864
874
  if isinstance(obj, str):
865
875
  from pyopencl.tools import parse_c_arg
866
876
  return parse_c_arg(obj, with_offset=with_offset)
@@ -886,9 +896,9 @@ def get_arg_list_arg_types(arg_types):
886
896
 
887
897
 
888
898
  def get_arg_list_scalar_arg_dtypes(
889
- arg_types: List[DtypedArgument]
890
- ) -> List[Optional[np.dtype]]:
891
- result: List[Optional[np.dtype]] = []
899
+ arg_types: Sequence[DtypedArgument]
900
+ ) -> list[np.dtype | None]:
901
+ result: list[np.dtype | None] = []
892
902
 
893
903
  for arg_type in arg_types:
894
904
  if isinstance(arg_type, ScalarArg):
@@ -903,8 +913,8 @@ def get_arg_list_scalar_arg_dtypes(
903
913
  return result
904
914
 
905
915
 
906
- def get_arg_offset_adjuster_code(arg_types):
907
- result = []
916
+ def get_arg_offset_adjuster_code(arg_types: Sequence[Argument]) -> str:
917
+ result: list[str] = []
908
918
 
909
919
  for arg_type in arg_types:
910
920
  if isinstance(arg_type, VectorArg) and arg_type.with_offset:
@@ -1181,7 +1191,8 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
1181
1191
  def calc_field_type():
1182
1192
  total_size = 0
1183
1193
  padding_count = 0
1184
- for offset, (field_name, (field_dtype, _)) in zip(offsets, fields):
1194
+ for offset, (field_name, (field_dtype, _)) in zip(
1195
+ offsets, fields, strict=True):
1185
1196
  if offset > total_size:
1186
1197
  padding_count += 1
1187
1198
  yield ("__pycl_padding%d" % padding_count,
@@ -1497,7 +1508,7 @@ def array_module(a):
1497
1508
  # }}}
1498
1509
 
1499
1510
 
1500
- def is_spirv(s):
1511
+ def is_spirv(s: str | bytes) -> TypeIs[bytes]:
1501
1512
  spirv_magic = b"\x07\x23\x02\x03"
1502
1513
  return (
1503
1514
  isinstance(s, bytes)
@@ -1524,4 +1535,31 @@ class _NumpyTypesKeyBuilder(KeyBuilderBase):
1524
1535
 
1525
1536
  # }}}
1526
1537
 
1538
+
1539
+ __all__ = [
1540
+ "AllocatorBase",
1541
+ "AllocatorBase",
1542
+ "Argument",
1543
+ "DeferredAllocator",
1544
+ "DtypedArgument",
1545
+ "ImmediateAllocator",
1546
+ "MemoryPool",
1547
+ "OtherArg",
1548
+ "PooledBuffer",
1549
+ "PooledSVM",
1550
+ "SVMAllocator",
1551
+ "SVMPool",
1552
+ "ScalarArg",
1553
+ "TypeNameNotKnown",
1554
+ "VectorArg",
1555
+ "bitlog2",
1556
+ "clear_first_arg_caches",
1557
+ "dtype_to_ctype",
1558
+ "first_arg_dependent_memoize",
1559
+ "get_or_register_dtype",
1560
+ "parse_arg_list",
1561
+ "pytest_generate_tests_for_pyopencl",
1562
+ "register_dtype",
1563
+ ]
1564
+
1527
1565
  # vim: foldmethod=marker
pyopencl/typing.py ADDED
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+
3
+
4
+ __copyright__ = "Copyright (C) 2025 University of Illinois Board of Trustees"
5
+
6
+ __license__ = """
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in
15
+ all copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23
+ THE SOFTWARE.
24
+ """
25
+
26
+ from collections.abc import Sequence
27
+ from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar
28
+
29
+ import numpy as np
30
+ from numpy.typing import NDArray
31
+ from typing_extensions import Buffer as abc_Buffer
32
+
33
+
34
+ if TYPE_CHECKING:
35
+ import pyopencl as _cl
36
+
37
+
38
+ DTypeT = TypeVar("DTypeT", bound=np.dtype[Any])
39
+
40
+ HasBufferInterface: TypeAlias = abc_Buffer | NDArray[Any]
41
+ SVMInnerT = TypeVar("SVMInnerT", bound=HasBufferInterface)
42
+ WaitList: TypeAlias = Sequence["_cl.Event"] | None
43
+ KernelArg: TypeAlias = """
44
+ int
45
+ | float
46
+ | complex
47
+ | HasBufferInterface
48
+ | np.generic
49
+ | _cl.Buffer
50
+ | _cl.Image
51
+ | _cl.Sampler
52
+ | _cl.SVMPointer"""
pyopencl/version.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import re
2
4
  from importlib import metadata
3
5