pyopencl 2025.2.5__cp310-cp310-macosx_11_0_arm64.whl → 2025.2.7__cp310-cp310-macosx_11_0_arm64.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of pyopencl might be problematic.

pyopencl/tools.py CHANGED
@@ -128,23 +128,28 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
128
128
  OTHER DEALINGS IN THE SOFTWARE.
129
129
  """
130
130
 
131
-
131
+ import atexit
132
132
  import re
133
133
  from abc import ABC, abstractmethod
134
- from dataclasses import dataclass
134
+ from dataclasses import dataclass, field
135
135
  from sys import intern
136
136
  from typing import (
137
137
  TYPE_CHECKING,
138
138
  Any,
139
+ ClassVar,
139
140
  Concatenate,
140
141
  ParamSpec,
142
+ TypeAlias,
143
+ TypedDict,
141
144
  TypeVar,
145
+ cast,
146
+ overload,
142
147
  )
143
148
 
144
149
  import numpy as np
145
150
  from typing_extensions import TypeIs, override
146
151
 
147
- from pytools import memoize, memoize_method
152
+ from pytools import Hash, memoize, memoize_method
148
153
  from pytools.persistent_dict import KeyBuilder as KeyBuilderBase
149
154
 
150
155
  from pyopencl._cl import bitlog2, get_cl_header_version
@@ -157,10 +162,11 @@ from pyopencl.compyte.dtypes import (
157
162
 
158
163
 
159
164
  if TYPE_CHECKING:
160
- from collections.abc import Callable, Hashable, Sequence
161
-
162
- from numpy.typing import DTypeLike
165
+ from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
163
166
 
167
+ import pytest
168
+ from mako.template import Template
169
+ from numpy.typing import DTypeLike, NDArray
164
170
 
165
171
  # Do not add a pyopencl import here: This will add an import cycle.
166
172
 
@@ -373,7 +379,7 @@ def _monkeypatch_svm_docstrings():
373
379
 
374
380
  # {{{ PooledSVM
375
381
 
376
- PooledSVM.__doc__ = ( # pylint: disable=possibly-used-before-assignment
382
+ PooledSVM.__doc__ = ( # pyright: ignore[reportPossiblyUnboundVariable]
377
383
  """An object representing a :class:`SVMPool`-based allocation of
378
384
  :ref:`svm`. Analogous to :class:`~pyopencl.SVMAllocation`, however once
379
385
  this object is deleted, its associated device memory is returned to the
@@ -420,7 +426,7 @@ def _monkeypatch_svm_docstrings():
420
426
 
421
427
  # {{{ SVMAllocator
422
428
 
423
- SVMAllocator.__doc__ = ( # pylint: disable=possibly-used-before-assignment
429
+ SVMAllocator.__doc__ = ( # pyright: ignore[reportPossiblyUnboundVariable]
424
430
  """
425
431
  .. versionadded:: 2022.2
426
432
 
@@ -453,7 +459,7 @@ def _monkeypatch_svm_docstrings():
453
459
 
454
460
  # {{{ SVMPool
455
461
 
456
- SVMPool.__doc__ = ( # pylint: disable=possibly-used-before-assignment
462
+ SVMPool.__doc__ = ( # pyright: ignore[reportPossiblyUnboundVariable]
457
463
  remove_common_indentation("""
458
464
  A memory pool for OpenCL device memory in :ref:`SVM <svm>` form.
459
465
  *allocator* must be an instance of :class:`SVMAllocator`.
@@ -477,17 +483,18 @@ if get_cl_header_version() >= (2, 0):
477
483
 
478
484
  # {{{ first-arg caches
479
485
 
480
- _first_arg_dependent_caches: list[dict[Hashable, object]] = []
486
+ _first_arg_dependent_caches: list[Mapping[Hashable, object]] = []
481
487
 
482
488
 
489
+ HashableT = TypeVar("HashableT", bound="Hashable")
483
490
  RetT = TypeVar("RetT")
484
491
  P = ParamSpec("P")
485
492
 
486
493
 
487
494
  def first_arg_dependent_memoize(
488
- func: Callable[Concatenate[object, P], RetT]
489
- ) -> Callable[Concatenate[object, P], RetT]:
490
- def wrapper(cl_object: Hashable, *args: P.args, **kwargs: P.kwargs) -> RetT:
495
+ func: Callable[Concatenate[HashableT, P], RetT]
496
+ ) -> Callable[Concatenate[HashableT, P], RetT]:
497
+ def wrapper(cl_object: HashableT, *args: P.args, **kwargs: P.kwargs) -> RetT:
491
498
  """Provides memoization for a function. Typically used to cache
492
499
  things that get created inside a :class:`pyopencl.Context`, e.g. programs
493
500
  and kernels. Assumes that the first argument of the decorated function is
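The retyped decorator above keys its cache on the (hashable) first argument. A minimal usage sketch, assuming an OpenCL context is available; build_noop_program and its kernel source are illustrative, not part of pyopencl:

import pyopencl as cl
from pyopencl.tools import first_arg_dependent_memoize


@first_arg_dependent_memoize
def build_noop_program(ctx: cl.Context) -> cl.Program:
    # built once per context; later calls with the same context hit the cache
    return cl.Program(ctx, "__kernel void noop() { }").build()


ctx = cl.create_some_context()
assert build_noop_program(ctx) is build_noop_program(ctx)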
@@ -527,33 +534,41 @@ def first_arg_dependent_memoize(
527
534
  context_dependent_memoize = first_arg_dependent_memoize
528
535
 
529
536
 
530
- def first_arg_dependent_memoize_nested(nested_func):
531
- """Provides memoization for nested functions. Typically used to cache
532
- things that get created inside a :class:`pyopencl.Context`, e.g. programs
533
- and kernels. Assumes that the first argument of the decorated function is
534
- an OpenCL object that might go away, such as a :class:`pyopencl.Context` or
535
- a :class:`pyopencl.CommandQueue`, and will therefore respond to
536
- :func:`clear_first_arg_caches`.
537
+ def first_arg_dependent_memoize_nested(
538
+ nested_func: Callable[Concatenate[Hashable, P], RetT]
539
+ ) -> Callable[Concatenate[Hashable, P], RetT]:
540
+ """Provides memoization for nested functions.
541
+
542
+ Typically used to cache things that get created inside a
543
+ :class:`pyopencl.Context`, e.g. programs and kernels. Assumes that the first
544
+ argument of the decorated function is an OpenCL object that might go away,
545
+ such as a :class:`pyopencl.Context` or a :class:`pyopencl.CommandQueue`, and
546
+ will therefore respond to :func:`clear_first_arg_caches`.
537
547
 
538
548
  .. versionadded:: 2013.1
539
549
  """
540
550
 
541
551
  from functools import wraps
542
- cache_dict_name = intern("_memoize_inner_dic_%s_%s_%d"
543
- % (nested_func.__name__, nested_func.__code__.co_filename,
544
- nested_func.__code__.co_firstlineno))
552
+ cache_dict_name = intern(
553
+ f"_memoize_inner_dic_{nested_func.__name__}_"
554
+ f"{nested_func.__code__.co_filename}_"
555
+ f"{nested_func.__code__.co_firstlineno}")
545
556
 
546
557
  from inspect import currentframe
547
558
 
548
559
  # prevent ref cycle
549
- try:
550
- caller_frame = currentframe().f_back
551
- cache_context = caller_frame.f_globals[
552
- caller_frame.f_code.co_name]
553
- finally:
554
- # del caller_frame
555
- pass
556
-
560
+ frame = currentframe()
561
+ cache_context = None
562
+ if frame:
563
+ try:
564
+ caller_frame = frame.f_back
565
+ if caller_frame:
566
+ cache_context = caller_frame.f_globals[caller_frame.f_code.co_name]
567
+ finally:
568
+ # del caller_frame
569
+ pass
570
+
571
+ cache_dict: dict[Hashable, dict[Hashable, RetT]]
557
572
  try:
558
573
  cache_dict = getattr(cache_context, cache_dict_name)
559
574
  except AttributeError:
@@ -562,12 +577,14 @@ def first_arg_dependent_memoize_nested(nested_func):
562
577
  setattr(cache_context, cache_dict_name, cache_dict)
563
578
 
564
579
  @wraps(nested_func)
565
- def new_nested_func(cl_object, *args):
580
+ def new_nested_func(cl_object: Hashable, *args: P.args, **kwargs: P.kwargs) -> RetT:
581
+ assert not kwargs
582
+
566
583
  try:
567
584
  return cache_dict[cl_object][args]
568
585
  except KeyError:
569
586
  arg_dict = cache_dict.setdefault(cl_object, {})
570
- result = nested_func(cl_object, *args)
587
+ result = nested_func(cl_object, *args, **kwargs)
571
588
  arg_dict[args] = result
572
589
  return result
573
590
 
@@ -575,23 +592,24 @@ def first_arg_dependent_memoize_nested(nested_func):
575
592
 
576
593
 
577
594
  def clear_first_arg_caches():
578
- """Empties all first-argument-dependent memoization caches. Also releases
579
- all held reference contexts. If it is important to you that the
580
- program detaches from its context, you might need to call this
581
- function to free all remaining references to your context.
595
+ """Empties all first-argument-dependent memoization caches.
596
+
597
+ Also releases all held reference contexts. If it is important to you that the
598
+ program detaches from its context, you might need to call this function to
599
+ free all remaining references to your context.
582
600
 
583
601
  .. versionadded:: 2011.2
584
602
  """
585
603
  for cache in _first_arg_dependent_caches:
586
- cache.clear()
587
-
588
-
589
- import atexit
604
+ # NOTE: this could be fixed by making the caches a MutableMapping, but
605
+ # that doesn't seem to be correctly covariant in its values, so other
606
+ # parts fail to work nicely..
607
+ cache.clear() # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType]
590
608
 
591
609
 
592
610
  if TYPE_CHECKING:
593
611
  import pyopencl as cl
594
-
612
+ from pyopencl.array import Array as CLArray
595
613
 
596
614
  atexit.register(clear_first_arg_caches)
597
615
 
@@ -621,9 +639,9 @@ class _ContextFactory:
621
639
  @override
622
640
  def __str__(self) -> str:
623
641
  # Don't show address, so that parallel test collection works
624
- return ("<context factory for <pyopencl.Device '%s' on '%s'>>" %
625
- (self.device.name.strip(),
626
- self.device.platform.name.strip()))
642
+ device = self.device.name.strip()
643
+ platform = self.device.platform.name.strip()
644
+ return f"<context factory for <pyopencl.Device '{device}' on '{platform}'>>"
627
645
 
628
646
 
629
647
  DeviceOrPlatformT = TypeVar("DeviceOrPlatformT", "cl.Device", "cl.Platform")
@@ -643,7 +661,7 @@ def _find_cl_obj(
643
661
  for obj in objs:
644
662
  if identifier.lower() in (obj.name + " " + obj.vendor).lower():
645
663
  return obj
646
- raise RuntimeError("object '%s' not found" % identifier)
664
+ raise RuntimeError(f"object '{identifier}' not found")
647
665
 
648
666
 
649
667
  def get_test_platforms_and_devices(
@@ -690,7 +708,9 @@ def get_test_platforms_and_devices(
690
708
  for platform in cl.get_platforms()]
691
709
 
692
710
 
693
- def get_pyopencl_fixture_arg_names(metafunc, extra_arg_names=None):
711
+ def get_pyopencl_fixture_arg_names(
712
+ metafunc: pytest.Metafunc,
713
+ extra_arg_names: list[str] | None = None) -> list[str]:
694
714
  if extra_arg_names is None:
695
715
  extra_arg_names = []
696
716
 
@@ -700,7 +720,7 @@ def get_pyopencl_fixture_arg_names(metafunc, extra_arg_names=None):
700
720
  *extra_arg_names
701
721
  ]
702
722
 
703
- arg_names = []
723
+ arg_names: list[str] = []
704
724
  for arg in supported_arg_names:
705
725
  if arg not in metafunc.fixturenames:
706
726
  continue
@@ -716,10 +736,11 @@ def get_pyopencl_fixture_arg_names(metafunc, extra_arg_names=None):
716
736
  return arg_names
717
737
 
718
738
 
719
- def get_pyopencl_fixture_arg_values():
739
+ def get_pyopencl_fixture_arg_values() -> tuple[list[dict[str, Any]],
740
+ Callable[[Any], str]]:
720
741
  import pyopencl as cl
721
742
 
722
- arg_values = []
743
+ arg_values: list[dict[str, Any]] = []
723
744
  for platform, devices in get_test_platforms_and_devices():
724
745
  for device in devices:
725
746
  arg_dict = {
@@ -730,7 +751,7 @@ def get_pyopencl_fixture_arg_values():
730
751
  }
731
752
  arg_values.append(arg_dict)
732
753
 
733
- def idfn(val):
754
+ def idfn(val: Any) -> str:
734
755
  if isinstance(val, cl.Platform):
735
756
  # Don't show address, so that parallel test collection works
736
757
  return f"<pyopencl.Platform '{val.name}'>"
@@ -740,7 +761,7 @@ def get_pyopencl_fixture_arg_values():
740
761
  return arg_values, idfn
741
762
 
742
763
 
743
- def pytest_generate_tests_for_pyopencl(metafunc):
764
+ def pytest_generate_tests_for_pyopencl(metafunc: pytest.Metafunc) -> None:
744
765
  """Using the line::
745
766
 
746
767
  from pyopencl.tools import pytest_generate_tests_for_pyopencl
@@ -751,7 +772,7 @@ def pytest_generate_tests_for_pyopencl(metafunc):
751
772
  functions, and they will automatically be run for each OpenCL device/platform
752
773
  in the system, as appropriate.
753
774
 
754
- The following two environment variabls is also supported to control
775
+ The following two environment variables is also supported to control
755
776
  device/platform choice::
756
777
 
757
778
  PYOPENCL_TEST=0:0,1;intel=i5,i7
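For context, the hook documented above is typically wired into a test suite roughly as follows (a sketch; the test body is illustrative, ctx_factory is the fixture this hook parametrizes):

# conftest.py (sketch)
from pyopencl.tools import (
    pytest_generate_tests_for_pyopencl as pytest_generate_tests,
)


# test_something.py (sketch)
def test_roundtrip(ctx_factory):
    import numpy as np
    import pyopencl as cl
    import pyopencl.array as cl_array

    queue = cl.CommandQueue(ctx_factory())
    host = np.arange(16, dtype=np.float32)
    assert np.array_equal(cl_array.to_device(queue, host).get(), host)


# Device/platform choice can then be narrowed via, e.g.:
#   PYOPENCL_TEST=0:0,1;intel=i5,i7 python -m pytest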
@@ -774,6 +795,10 @@ def pytest_generate_tests_for_pyopencl(metafunc):
774
795
 
775
796
  # {{{ C argument lists
776
797
 
798
+ ArgType: TypeAlias = "np.dtype[Any] | VectorArg"
799
+ ArgDType: TypeAlias = "np.dtype[Any] | None"
800
+
801
+
777
802
  class Argument(ABC):
778
803
  """
779
804
  .. automethod:: declarator
@@ -850,18 +875,19 @@ def parse_c_arg(c_arg: str, with_offset: bool = False) -> DtypedArgument:
850
875
  c_arg = c_arg.replace("__global", "")
851
876
 
852
877
  if with_offset:
853
- def vec_arg_factory(dtype, name):
878
+ def vec_arg_factory(dtype: DTypeLike, name: str) -> VectorArg:
854
879
  return VectorArg(dtype, name, with_offset=True)
855
880
  else:
856
881
  vec_arg_factory = VectorArg
857
882
 
858
883
  from pyopencl.compyte.dtypes import parse_c_arg_backend
884
+
859
885
  return parse_c_arg_backend(c_arg, ScalarArg, vec_arg_factory)
860
886
 
861
887
 
862
888
  def parse_arg_list(
863
- arguments: str | list[str] | Sequence[DtypedArgument],
864
- with_offset: bool = False) -> list[DtypedArgument]:
889
+ arguments: str | Sequence[str] | Sequence[Argument],
890
+ with_offset: bool = False) -> Sequence[DtypedArgument]:
865
891
  """Parse a list of kernel arguments. *arguments* may be a comma-separate
866
892
  list of C declarators in a string, a list of strings representing C
867
893
  declarators, or :class:`Argument` objects.
@@ -870,7 +896,7 @@ def parse_arg_list(
870
896
  if isinstance(arguments, str):
871
897
  arguments = arguments.split(",")
872
898
 
873
- def parse_single_arg(obj: str | DtypedArgument) -> DtypedArgument:
899
+ def parse_single_arg(obj: str | Argument) -> DtypedArgument:
874
900
  if isinstance(obj, str):
875
901
  from pyopencl.tools import parse_c_arg
876
902
  return parse_c_arg(obj, with_offset=with_offset)
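As a concrete illustration of the (now more precisely typed) argument parsing, the following sketch shows the three accepted input forms. The declarator strings are illustrative, and the example assumes the standard OpenCL C scalar type names are registered (as they are once pyopencl is imported):

import numpy as np

import pyopencl  # noqa: F401  (populates the C type name registry)
from pyopencl.tools import ScalarArg, VectorArg, parse_arg_list

# 1. a comma-separated string of C declarators
args = parse_arg_list("float *x, long n")
assert isinstance(args[0], VectorArg) and isinstance(args[1], ScalarArg)

# 2. a sequence of declarator strings (optionally with offsets for buffers)
args = parse_arg_list(["float *x", "long n"], with_offset=True)

# 3. pre-built Argument objects are passed through unchanged
args = parse_arg_list([VectorArg(np.float32, "x"), ScalarArg(np.int64, "n")])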
@@ -881,8 +907,8 @@ def parse_arg_list(
881
907
  return [parse_single_arg(arg) for arg in arguments]
882
908
 
883
909
 
884
- def get_arg_list_arg_types(arg_types):
885
- result = []
910
+ def get_arg_list_arg_types(arg_types: Sequence[Argument]) -> tuple[ArgType, ...]:
911
+ result: list[ArgType] = []
886
912
 
887
913
  for arg_type in arg_types:
888
914
  if isinstance(arg_type, ScalarArg):
@@ -890,15 +916,15 @@ def get_arg_list_arg_types(arg_types):
890
916
  elif isinstance(arg_type, VectorArg):
891
917
  result.append(arg_type)
892
918
  else:
893
- raise RuntimeError("arg type not understood: %s" % type(arg_type))
919
+ raise RuntimeError(f"arg type not understood: {type(arg_type)}")
894
920
 
895
921
  return tuple(result)
896
922
 
897
923
 
898
924
  def get_arg_list_scalar_arg_dtypes(
899
- arg_types: Sequence[DtypedArgument]
900
- ) -> list[np.dtype | None]:
901
- result: list[np.dtype | None] = []
925
+ arg_types: Sequence[Argument]
926
+ ) -> Sequence[ArgDType]:
927
+ result: list[ArgDType] = []
902
928
 
903
929
  for arg_type in arg_types:
904
930
  if isinstance(arg_type, ScalarArg):
@@ -918,26 +944,26 @@ def get_arg_offset_adjuster_code(arg_types: Sequence[Argument]) -> str:
918
944
 
919
945
  for arg_type in arg_types:
920
946
  if isinstance(arg_type, VectorArg) and arg_type.with_offset:
921
- result.append("__global %(type)s *%(name)s = "
922
- "(__global %(type)s *) "
923
- "((__global char *) %(name)s__base + %(name)s__offset);"
924
- % {
925
- "type": dtype_to_ctype(arg_type.dtype),
926
- "name": arg_type.name})
947
+ name = arg_type.name
948
+ ctype = dtype_to_ctype(arg_type.dtype)
949
+ result.append(
950
+ f"__global {ctype} *{name} = "
951
+ f"(__global {ctype} *) "
952
+ f"((__global char *) {name}__base + {name}__offset);")
927
953
 
928
954
  return "\n".join(result)
929
955
 
930
956
  # }}}
931
957
 
932
958
 
933
- def get_gl_sharing_context_properties():
959
+ def get_gl_sharing_context_properties() -> list[tuple[cl.context_properties, Any]]:
934
960
  import pyopencl as cl
935
961
 
936
962
  ctx_props = cl.context_properties
937
963
 
938
964
  from OpenGL import platform as gl_platform
939
965
 
940
- props = []
966
+ props: list[tuple[cl.context_properties, Any]] = []
941
967
 
942
968
  import sys
943
969
  if sys.platform in ["linux", "linux2"]:
@@ -959,24 +985,23 @@ def get_gl_sharing_context_properties():
959
985
  (ctx_props.CONTEXT_PROPERTY_USE_CGL_SHAREGROUP_APPLE,
960
986
  cl.get_apple_cgl_share_group()))
961
987
  else:
962
- raise NotImplementedError("platform '%s' not yet supported"
963
- % sys.platform)
988
+ raise NotImplementedError(f"platform '{sys.platform}' not yet supported")
964
989
 
965
990
  return props
966
991
 
967
992
 
968
993
  class _CDeclList:
969
- def __init__(self, device):
970
- self.device = device
971
- self.declared_dtypes = set()
972
- self.declarations = []
973
- self.saw_double = False
974
- self.saw_complex = False
975
-
976
- def add_dtype(self, dtype):
994
+ def __init__(self, device: cl.Device) -> None:
995
+ self.device: cl.Device = device
996
+ self.declared_dtypes: set[np.dtype[Any]] = set()
997
+ self.declarations: list[str] = []
998
+ self.saw_double: bool = False
999
+ self.saw_complex: bool = False
1000
+
1001
+ def add_dtype(self, dtype: DTypeLike) -> None:
977
1002
  dtype = np.dtype(dtype)
978
1003
 
979
- if dtype in (np.float64, np.complex128):
1004
+ if dtype.type in (np.float64, np.complex128):
980
1005
  self.saw_double = True
981
1006
 
982
1007
  if dtype.kind == "c":
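The switch above from comparing the dtype itself to comparing dtype.type relies on the following numpy behaviour (a small stand-alone illustration):

import numpy as np

dt = np.dtype("float64")
assert dt.type is np.float64                 # .type is the scalar type object
assert dt == np.float64                      # the old membership test relied on this
assert np.dtype(np.complex128).kind == "c"   # complex dtypes report kind "c"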
@@ -988,17 +1013,20 @@ class _CDeclList:
988
1013
  if dtype in self.declared_dtypes:
989
1014
  return
990
1015
 
991
- import pyopencl.cltypes
992
- if dtype in pyopencl.cltypes.vec_type_to_scalar_and_count:
1016
+ from pyopencl.cltypes import vec_type_to_scalar_and_count
1017
+
1018
+ if dtype in vec_type_to_scalar_and_count:
993
1019
  return
994
1020
 
995
1021
  if hasattr(dtype, "subdtype") and dtype.subdtype is not None:
996
1022
  self.add_dtype(dtype.subdtype[0])
997
1023
  return
998
1024
 
999
- for _name, field_data in sorted(dtype.fields.items()):
1000
- field_dtype, _offset = field_data[:2]
1001
- self.add_dtype(field_dtype)
1025
+ fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None", dtype.fields)
1026
+ if fields is not None:
1027
+ for _name, field_data in sorted(fields.items()):
1028
+ field_dtype, _offset = field_data[:2]
1029
+ self.add_dtype(field_dtype)
1002
1030
 
1003
1031
  _, cdecl = match_dtype_to_c_struct(
1004
1032
  self.device, dtype_to_ctype(dtype), dtype)
@@ -1006,16 +1034,19 @@ class _CDeclList:
1006
1034
  self.declarations.append(cdecl)
1007
1035
  self.declared_dtypes.add(dtype)
1008
1036
 
1009
- def visit_arguments(self, arguments):
1037
+ def visit_arguments(self, arguments: Sequence[Argument]) -> None:
1010
1038
  for arg in arguments:
1039
+ if not isinstance(arg, DtypedArgument):
1040
+ continue
1041
+
1011
1042
  dtype = arg.dtype
1012
- if dtype in (np.float64, np.complex128):
1043
+ if dtype.type in (np.float64, np.complex128):
1013
1044
  self.saw_double = True
1014
1045
 
1015
1046
  if dtype.kind == "c":
1016
1047
  self.saw_complex = True
1017
1048
 
1018
- def get_declarations(self):
1049
+ def get_declarations(self) -> str:
1019
1050
  result = "\n\n".join(self.declarations)
1020
1051
 
1021
1052
  if self.saw_complex:
@@ -1036,8 +1067,19 @@ class _CDeclList:
1036
1067
  return result
1037
1068
 
1038
1069
 
1070
+ class _DTypeDict(TypedDict):
1071
+ names: list[str]
1072
+ formats: list[np.dtype[Any]]
1073
+ offsets: list[int]
1074
+ itemsize: int
1075
+
1076
+
1039
1077
  @memoize
1040
- def match_dtype_to_c_struct(device, name, dtype, context=None):
1078
+ def match_dtype_to_c_struct(
1079
+ device: cl.Device,
1080
+ name: str,
1081
+ dtype: np.dtype[Any],
1082
+ context: cl.Context | None = None) -> tuple[np.dtype[Any], str]:
1041
1083
  """Return a tuple ``(dtype, c_decl)`` such that the C struct declaration
1042
1084
  in ``c_decl`` and the structure :class:`numpy.dtype` instance ``dtype``
1043
1085
  have the same memory layout.
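A usage sketch for the function documented above, assuming an OpenCL context is available; the struct layout and the name "particle" are illustrative:

import numpy as np

import pyopencl as cl
from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct

ctx = cl.create_some_context()
dtype = np.dtype([("id", np.uint32), ("value", np.float32)])

dtype, c_decl = match_dtype_to_c_struct(
    ctx.devices[0], "particle", dtype, context=ctx)
get_or_register_dtype("particle", dtype)   # register the device-matched layout
print(c_decl)                              # C source declaring 'particle' with that layout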
@@ -1073,14 +1115,18 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
1073
1115
  :func:`get_or_register_dtype` on the modified ``dtype`` returned by this
1074
1116
  function, not the original one.
1075
1117
  """
1118
+ fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None", dtype.fields)
1119
+ if not fields:
1120
+ raise ValueError(f"dtype has no fields: '{dtype}'")
1076
1121
 
1077
1122
  import pyopencl as cl
1078
1123
 
1079
- fields = sorted(dtype.fields.items(),
1124
+ sorted_fields = sorted(
1125
+ fields.items(),
1080
1126
  key=lambda name_dtype_offset: name_dtype_offset[1][1])
1081
1127
 
1082
- c_fields = []
1083
- for field_name, dtype_and_offset in fields:
1128
+ c_fields: list[str] = []
1129
+ for field_name, dtype_and_offset in sorted_fields:
1084
1130
  field_dtype, _offset = dtype_and_offset[:2]
1085
1131
  if hasattr(field_dtype, "subdtype") and field_dtype.subdtype is not None:
1086
1132
  array_dtype = field_dtype.subdtype[0]
@@ -1090,9 +1136,9 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
1090
1136
  dims_str = ""
1091
1137
  try:
1092
1138
  for dim in array_dims:
1093
- dims_str += "[%d]" % dim
1139
+ dims_str += f"[{dim}]"
1094
1140
  except TypeError:
1095
- dims_str = "[%d]" % array_dims
1141
+ dims_str = f"[{array_dims}]"
1096
1142
  c_fields.append(" {} {}{};".format(
1097
1143
  dtype_to_ctype(array_dtype), field_name, dims_str)
1098
1144
  )
@@ -1105,15 +1151,15 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
1105
1151
  name)
1106
1152
 
1107
1153
  cdl = _CDeclList(device)
1108
- for _field_name, dtype_and_offset in fields:
1154
+ for _field_name, dtype_and_offset in sorted_fields:
1109
1155
  field_dtype, _offset = dtype_and_offset[:2]
1110
1156
  cdl.add_dtype(field_dtype)
1111
1157
 
1112
1158
  pre_decls = cdl.get_declarations()
1113
1159
 
1114
1160
  offset_code = "\n".join(
1115
- "result[%d] = pycl_offsetof(%s, %s);" % (i+1, name, field_name)
1116
- for i, (field_name, _) in enumerate(fields))
1161
+ f"result[{i + 1}] = pycl_offsetof({name}, {field_name});"
1162
+ for i, (field_name, _) in enumerate(sorted_fields))
1117
1163
 
1118
1164
  src = rf"""
1119
1165
  #define pycl_offsetof(st, m) \
@@ -1140,30 +1186,29 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
1140
1186
  prg = cl.Program(context, src)
1141
1187
  knl = prg.build(devices=[device]).get_size_and_offsets
1142
1188
 
1143
- import pyopencl.array
1189
+ import pyopencl.array as cl_array
1190
+
1191
+ result_buf = cl_array.empty(queue, 1+len(sorted_fields), np.uint32)
1192
+ assert result_buf.data is not None
1144
1193
 
1145
- result_buf = cl.array.empty(queue, 1+len(fields), np.uint32)
1146
1194
  knl(queue, (1,), (1,), result_buf.data)
1147
1195
  queue.finish()
1148
1196
  size_and_offsets = result_buf.get()
1149
1197
 
1150
1198
  size = int(size_and_offsets[0])
1151
-
1152
1199
  offsets = size_and_offsets[1:]
1200
+
1153
1201
  if any(ofs >= size for ofs in offsets):
1154
1202
  # offsets not plausible
1155
1203
 
1156
1204
  if dtype.itemsize == size:
1157
1205
  # If sizes match, use numpy's idea of the offsets.
1158
- offsets = [dtype_and_offset[1]
1159
- for field_name, dtype_and_offset in fields]
1206
+ offsets = [dtype_and_offset[1] for _name, dtype_and_offset in sorted_fields]
1160
1207
  else:
1161
1208
  raise RuntimeError(
1162
- "OpenCL compiler reported offsetof() past sizeof() "
1163
- "for struct layout on '%s'. "
1164
- "This makes no sense, and it's usually indicates a "
1165
- "compiler bug. "
1166
- "Refusing to discover struct layout." % device)
1209
+ "OpenCL compiler reported offsetof() past sizeof() for struct "
1210
+ f"layout on '{device}'. This makes no sense, and it usually "
1211
+ "indicates a compiler bug. Refusing to discover struct layout.")
1167
1212
 
1168
1213
  result_buf.data.release()
1169
1214
  del knl
@@ -1172,58 +1217,66 @@ def match_dtype_to_c_struct(device, name, dtype, context=None):
1172
1217
  del context
1173
1218
 
1174
1219
  try:
1175
- dtype_arg_dict = {
1176
- "names": [field_name
1177
- for field_name, (field_dtype, offset) in fields],
1178
- "formats": [field_dtype
1179
- for field_name, (field_dtype, offset) in fields],
1180
- "offsets": [int(x) for x in offsets],
1181
- "itemsize": int(size_and_offsets[0]),
1182
- }
1183
- dtype = np.dtype(dtype_arg_dict)
1184
- if dtype.itemsize != size_and_offsets[0]:
1220
+ dtype_arg_dict = _DTypeDict(
1221
+ names=[name for name, _ in sorted_fields],
1222
+ formats=[dtype_and_offset[0] for _, dtype_and_offset in sorted_fields],
1223
+ offsets=[int(x) for x in offsets],
1224
+ itemsize=int(size_and_offsets[0]),
1225
+ )
1226
+ arg_dtype = np.dtype(dtype_arg_dict)
1227
+
1228
+ if arg_dtype.itemsize != size_and_offsets[0]:
1185
1229
  # "Old" versions of numpy (1.6.x?) silently ignore "itemsize". Boo.
1186
1230
  dtype_arg_dict["names"].append("_pycl_size_fixer")
1187
- dtype_arg_dict["formats"].append(np.uint8)
1188
- dtype_arg_dict["offsets"].append(int(size_and_offsets[0])-1)
1189
- dtype = np.dtype(dtype_arg_dict)
1231
+ dtype_arg_dict["formats"].append(np.dtype(np.uint8))
1232
+ dtype_arg_dict["offsets"].append(int(size_and_offsets[0]) - 1)
1233
+
1234
+ arg_dtype = np.dtype(dtype_arg_dict)
1190
1235
  except NotImplementedError:
1191
- def calc_field_type():
1236
+ def calc_field_type() -> Iterator[tuple[str, str | np.dtype[Any]]]:
1192
1237
  total_size = 0
1193
1238
  padding_count = 0
1194
- for offset, (field_name, (field_dtype, _)) in zip(
1195
- offsets, fields, strict=True):
1239
+ for offset, (field_name, dtype_and_offset) in zip(
1240
+ offsets, sorted_fields, strict=True):
1241
+ field_dtype, _ = dtype_and_offset[:2]
1196
1242
  if offset > total_size:
1197
1243
  padding_count += 1
1198
- yield ("__pycl_padding%d" % padding_count,
1199
- "V%d" % offset - total_size)
1244
+ yield f"__pycl_padding{padding_count}", f"V{offset - total_size}"
1245
+
1200
1246
  yield field_name, field_dtype
1201
1247
  total_size = field_dtype.itemsize + offset
1202
- dtype = np.dtype(list(calc_field_type()))
1203
1248
 
1204
- assert dtype.itemsize == size_and_offsets[0]
1249
+ arg_dtype = np.dtype(list(calc_field_type()))
1250
+
1251
+ assert arg_dtype.itemsize == size_and_offsets[0]
1205
1252
 
1206
- return dtype, c_decl
1253
+ return arg_dtype, c_decl
1207
1254
 
1208
1255
 
1209
1256
  @memoize
1210
- def dtype_to_c_struct(device, dtype):
1211
- if dtype.fields is None:
1257
+ def dtype_to_c_struct(device: cl.Device, dtype: np.dtype[Any]) -> str:
1258
+ fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None", dtype.fields)
1259
+ if fields is None:
1212
1260
  return ""
1213
1261
 
1214
- import pyopencl.cltypes
1215
- if dtype in pyopencl.cltypes.vec_type_to_scalar_and_count:
1262
+ from pyopencl.cltypes import vec_type_to_scalar_and_count
1263
+
1264
+ if dtype in vec_type_to_scalar_and_count:
1216
1265
  # Vector types are built-in. Don't try to redeclare those.
1217
1266
  return ""
1218
1267
 
1219
1268
  matched_dtype, c_decl = match_dtype_to_c_struct(
1220
1269
  device, dtype_to_ctype(dtype), dtype)
1221
1270
 
1222
- def dtypes_match():
1223
- result = len(dtype.fields) == len(matched_dtype.fields)
1271
+ matched_fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None",
1272
+ matched_dtype.fields)
1273
+ assert matched_fields is not None
1274
+
1275
+ def dtypes_match() -> bool:
1276
+ result = len(fields) == len(matched_fields)
1224
1277
 
1225
- for name, val in dtype.fields.items():
1226
- result = result and matched_dtype.fields[name] == val
1278
+ for name, val in fields.items():
1279
+ result = result and matched_fields[name] == val
1227
1280
 
1228
1281
  return result
1229
1282
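The _DTypeDict introduced in this hunk is just a typed view of the dictionary form that numpy.dtype already accepts; for reference, in plain numpy (no pyopencl involved):

import numpy as np

dt = np.dtype({
    "names": ["id", "value"],
    "formats": [np.dtype(np.uint32), np.dtype(np.float32)],
    "offsets": [0, 8],          # deliberately padded layout
    "itemsize": 16,
})
assert dt.itemsize == 16
assert dt.fields["value"] == (np.dtype(np.float32), 8)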
 
@@ -1234,38 +1287,52 @@ def dtype_to_c_struct(device, dtype):
1234
1287
 
1235
1288
  # {{{ code generation/templating helper
1236
1289
 
1237
- def _process_code_for_macro(code):
1290
+ def _process_code_for_macro(code: str) -> str:
1238
1291
  code = code.replace("//CL//", "\n")
1239
1292
 
1240
1293
  if "//" in code:
1241
- raise RuntimeError("end-of-line comments ('//') may not be used in "
1242
- "code snippets")
1294
+ raise RuntimeError(
1295
+ "end-of-line comments ('//') may not be used in code snippets")
1243
1296
 
1244
1297
  return code.replace("\n", " \\\n")
1245
1298
 
1246
1299
 
1247
- class _SimpleTextTemplate:
1248
- def __init__(self, txt):
1249
- self.txt = txt
1300
+ class _TextTemplate(ABC):
1301
+ @abstractmethod
1302
+ def render(self, context: dict[str, Any]) -> str:
1303
+ pass
1304
+
1305
+
1306
+ @dataclass(frozen=True)
1307
+ class _SimpleTextTemplate(_TextTemplate):
1308
+ txt: str
1250
1309
 
1251
- def render(self, context):
1310
+ @override
1311
+ def render(self, context: dict[str, Any]) -> str:
1252
1312
  return self.txt
1253
1313
 
1254
1314
 
1255
- class _PrintfTextTemplate:
1256
- def __init__(self, txt):
1257
- self.txt = txt
1315
+ @dataclass(frozen=True)
1316
+ class _PrintfTextTemplate(_TextTemplate):
1317
+ txt: str
1258
1318
 
1259
- def render(self, context):
1319
+ @override
1320
+ def render(self, context: dict[str, Any]) -> str:
1260
1321
  return self.txt % context
1261
1322
 
1262
1323
 
1263
- class _MakoTextTemplate:
1264
- def __init__(self, txt):
1324
+ @dataclass(frozen=True)
1325
+ class _MakoTextTemplate(_TextTemplate):
1326
+ txt: str
1327
+ template: Template = field(init=False)
1328
+
1329
+ def __post_init__(self) -> None:
1265
1330
  from mako.template import Template
1266
- self.template = Template(txt, strict_undefined=True)
1267
1331
 
1268
- def render(self, context):
1332
+ object.__setattr__(self, "template", Template(self.txt, strict_undefined=True))
1333
+
1334
+ @override
1335
+ def render(self, context: dict[str, Any]) -> str:
1269
1336
  return self.template.render(**context)
1270
1337
 
1271
1338
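The _MakoTextTemplate above uses a standard trick for derived attributes on frozen dataclasses: declare the field with init=False and assign it in __post_init__ via object.__setattr__, since ordinary assignment is blocked. A generic illustration of the same pattern, with no Mako dependency:

import re
from dataclasses import dataclass, field


@dataclass(frozen=True)
class CompiledPattern:
    pattern: str
    regex: re.Pattern[str] = field(init=False)

    def __post_init__(self) -> None:
        # frozen dataclasses forbid plain attribute assignment, so bypass it
        object.__setattr__(self, "regex", re.compile(self.pattern))


p = CompiledPattern(r"\d+")
assert p.regex.findall("a1b22") == ["1", "22"]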
 
@@ -1279,36 +1346,52 @@ class _ArgumentPlaceholder:
1279
1346
  See also :class:`_TemplateRenderer.render_arg`.
1280
1347
  """
1281
1348
 
1282
- def __init__(self, typename, name, **extra_kwargs):
1283
- self.typename = typename
1284
- self.name = name
1285
- self.extra_kwargs = extra_kwargs
1349
+ target_class: ClassVar[type[DtypedArgument]]
1350
+
1351
+ def __init__(self,
1352
+ typename: DTypeLike,
1353
+ name: str,
1354
+ **extra_kwargs: Any) -> None:
1355
+ self.typename: DTypeLike = typename
1356
+ self.name: str = name
1357
+ self.extra_kwargs: dict[str, Any] = extra_kwargs
1286
1358
 
1287
1359
 
1288
1360
  class _VectorArgPlaceholder(_ArgumentPlaceholder):
1289
- target_class = VectorArg
1361
+ target_class: ClassVar[type[DtypedArgument]] = VectorArg
1290
1362
 
1291
1363
 
1292
1364
  class _ScalarArgPlaceholder(_ArgumentPlaceholder):
1293
- target_class = ScalarArg
1365
+ target_class: ClassVar[type[DtypedArgument]] = ScalarArg
1294
1366
 
1295
1367
 
1296
1368
  class _TemplateRenderer:
1297
- def __init__(self, template, type_aliases, var_values, context=None,
1298
- options=None):
1299
- self.template = template
1300
- self.type_aliases = dict(type_aliases)
1301
- self.var_dict = dict(var_values)
1369
+ def __init__(self,
1370
+ template: KernelTemplateBase,
1371
+ type_aliases: (
1372
+ dict[str, np.dtype[Any]]
1373
+ | Sequence[tuple[str, np.dtype[Any]]]),
1374
+ var_values: dict[str, str] | Sequence[tuple[str, str]],
1375
+ context: cl.Context | None = None,
1376
+ options: Any = None) -> None:
1377
+ self.template: KernelTemplateBase = template
1378
+ self.type_aliases: dict[str, np.dtype[Any]] = dict(type_aliases)
1379
+ self.var_dict: dict[str, str] = dict(var_values)
1302
1380
 
1303
1381
  for name in self.var_dict:
1304
1382
  if name.startswith("macro_"):
1305
- self.var_dict[name] = _process_code_for_macro(
1306
- self.var_dict[name])
1383
+ self.var_dict[name] = _process_code_for_macro(self.var_dict[name])
1384
+
1385
+ self.context: cl.Context | None = context
1386
+ self.options: Any = options
1307
1387
 
1308
- self.context = context
1309
- self.options = options
1388
+ @overload
1389
+ def __call__(self, txt: None) -> None: ...
1310
1390
 
1311
- def __call__(self, txt):
1391
+ @overload
1392
+ def __call__(self, txt: str) -> str: ...
1393
+
1394
+ def __call__(self, txt: str | None) -> str | None:
1312
1395
  if txt is None:
1313
1396
  return txt
1314
1397
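The pair of @overload declarations added to __call__ above encodes None-propagation for type checkers: pass None, get None; pass str, get str. The same pattern in isolation, with render as an illustrative stand-in:

from typing import overload


@overload
def render(txt: None) -> None: ...
@overload
def render(txt: str) -> str: ...


def render(txt: str | None) -> str | None:
    # single runtime implementation backing both declared signatures
    if txt is None:
        return None
    return txt.upper()


assert render(None) is None
assert render("abc") == "ABC"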
 
@@ -1316,7 +1399,10 @@ class _TemplateRenderer:
1316
1399
 
1317
1400
  return str(result)
1318
1401
 
1319
- def get_rendered_kernel(self, txt, kernel_name):
1402
+ def get_rendered_kernel(self, txt: str, kernel_name: str) -> cl.Kernel:
1403
+ if self.context is None:
1404
+ raise ValueError("context not provided -- cannot render kernel")
1405
+
1320
1406
  import pyopencl as cl
1321
1407
  prg = cl.Program(self.context, self(txt)).build(self.options)
1322
1408
 
@@ -1326,7 +1412,7 @@ class _TemplateRenderer:
1326
1412
 
1327
1413
  return getattr(prg, kernel_name)
1328
1414
 
1329
- def parse_type(self, typename):
1415
+ def parse_type(self, typename: Any) -> np.dtype[Any]:
1330
1416
  if isinstance(typename, str):
1331
1417
  try:
1332
1418
  return self.type_aliases[typename]
@@ -1336,21 +1422,22 @@ class _TemplateRenderer:
1336
1422
  else:
1337
1423
  return np.dtype(typename)
1338
1424
 
1339
- def render_arg(self, arg_placeholder):
1425
+ def render_arg(self, arg_placeholder: _ArgumentPlaceholder) -> DtypedArgument:
1340
1426
  return arg_placeholder.target_class(
1341
1427
  self.parse_type(arg_placeholder.typename),
1342
1428
  arg_placeholder.name,
1343
1429
  **arg_placeholder.extra_kwargs)
1344
1430
 
1345
- _C_COMMENT_FINDER = re.compile(r"/\*.*?\*/")
1431
+ _C_COMMENT_FINDER: ClassVar[re.Pattern[str]] = re.compile(r"/\*.*?\*/")
1346
1432
 
1347
- def render_argument_list(self, *arg_lists, **kwargs):
1348
- with_offset = kwargs.pop("with_offset", False)
1433
+ def render_argument_list(self,
1434
+ *arg_lists: Any,
1435
+ with_offset: bool = False,
1436
+ **kwargs: Any) -> list[Argument]:
1349
1437
  if kwargs:
1350
1438
  raise TypeError("unrecognized kwargs: " + ", ".join(kwargs))
1351
1439
 
1352
- all_args = []
1353
-
1440
+ all_args: list[Any] = []
1354
1441
  for arg_list in arg_lists:
1355
1442
  if isinstance(arg_list, str):
1356
1443
  arg_list = str(
@@ -1364,13 +1451,16 @@ class _TemplateRenderer:
1364
1451
  all_args.extend(arg_list)
1365
1452
 
1366
1453
  if with_offset:
1367
- def vec_arg_factory(typename, name):
1454
+ def vec_arg_factory(
1455
+ typename: DTypeLike,
1456
+ name: str) -> _VectorArgPlaceholder:
1368
1457
  return _VectorArgPlaceholder(typename, name, with_offset=True)
1369
1458
  else:
1370
1459
  vec_arg_factory = _VectorArgPlaceholder
1371
1460
 
1372
1461
  from pyopencl.compyte.dtypes import parse_c_arg_backend
1373
- parsed_args = []
1462
+
1463
+ parsed_args: list[Argument] = []
1374
1464
  for arg in all_args:
1375
1465
  if isinstance(arg, str):
1376
1466
  arg = arg.strip()
@@ -1379,21 +1469,26 @@ class _TemplateRenderer:
1379
1469
 
1380
1470
  ph = parse_c_arg_backend(arg,
1381
1471
  _ScalarArgPlaceholder, vec_arg_factory,
1382
- name_to_dtype=lambda x: x)
1472
+ name_to_dtype=lambda x: x) # pyright: ignore[reportArgumentType]
1383
1473
  parsed_arg = self.render_arg(ph)
1384
-
1385
1474
  elif isinstance(arg, Argument):
1386
1475
  parsed_arg = arg
1387
1476
  elif isinstance(arg, tuple):
1477
+ assert isinstance(arg[0], str)
1478
+ assert isinstance(arg[1], str)
1388
1479
  parsed_arg = ScalarArg(self.parse_type(arg[0]), arg[1])
1389
1480
  else:
1390
- raise TypeError("unexpected argument type: %s" % type(arg))
1481
+ raise TypeError(f"unexpected argument type: {type(arg)}")
1391
1482
 
1392
1483
  parsed_args.append(parsed_arg)
1393
1484
 
1394
1485
  return parsed_args
1395
1486
 
1396
- def get_type_decl_preamble(self, device, decl_type_names, arguments=None):
1487
+ def get_type_decl_preamble(self,
1488
+ device: cl.Device,
1489
+ decl_type_names: Sequence[DTypeLike],
1490
+ arguments: Sequence[Argument] | None = None,
1491
+ ) -> str:
1397
1492
  cdl = _CDeclList(device)
1398
1493
 
1399
1494
  for typename in decl_type_names:
@@ -1413,20 +1508,19 @@ class _TemplateRenderer:
1413
1508
  return cdl.get_declarations() + "\n" + "\n".join(type_alias_decls)
1414
1509
 
1415
1510
 
1416
- class KernelTemplateBase:
1417
- def __init__(self, template_processor=None):
1418
- self.template_processor = template_processor
1511
+ class KernelTemplateBase(ABC):
1512
+ def __init__(self, template_processor: str | None = None) -> None:
1513
+ self.template_processor: str | None = template_processor
1419
1514
 
1420
- self.build_cache = {}
1515
+ self.build_cache: dict[Hashable, Any] = {}
1421
1516
  _first_arg_dependent_caches.append(self.build_cache)
1422
1517
 
1423
- def get_preamble(self):
1424
- pass
1425
-
1426
- _TEMPLATE_PROCESSOR_PATTERN = re.compile(r"^//CL(?::([a-zA-Z0-9_]+))?//")
1518
+ _TEMPLATE_PROCESSOR_PATTERN: ClassVar[re.Pattern[str]] = (
1519
+ re.compile(r"^//CL(?::([a-zA-Z0-9_]+))?//")
1520
+ )
1427
1521
 
1428
1522
  @memoize_method
1429
- def get_text_template(self, txt):
1523
+ def get_text_template(self, txt: str) -> _TextTemplate:
1430
1524
  proc_match = self._TEMPLATE_PROCESSOR_PATTERN.match(txt)
1431
1525
  tpl_processor = None
1432
1526
 
@@ -1434,6 +1528,7 @@ class KernelTemplateBase:
1434
1528
  tpl_processor = proc_match.group(1)
1435
1529
  # chop off //CL// mark
1436
1530
  txt = txt[len(proc_match.group(0)):]
1531
+
1437
1532
  if tpl_processor is None:
1438
1533
  tpl_processor = self.template_processor
1439
1534
 
@@ -1444,16 +1539,30 @@ class KernelTemplateBase:
1444
1539
  elif tpl_processor == "mako":
1445
1540
  return _MakoTextTemplate(txt)
1446
1541
  else:
1447
- raise RuntimeError(
1448
- "unknown template processor '%s'" % proc_match.group(1))
1542
+ raise RuntimeError(f"unknown template processor '{tpl_processor}'")
1543
+
1544
+ # TODO: this does not seem to be used anywhere -> deprecate / remove
1545
+ def get_preamble(self) -> str:
1546
+ return ""
1449
1547
 
1450
- def get_renderer(self, type_aliases, var_values, context=None, options=None):
1548
+ def get_renderer(self,
1549
+ type_aliases: (
1550
+ dict[str, np.dtype[Any]]
1551
+ | Sequence[tuple[str, np.dtype[Any]]]),
1552
+ var_values: dict[str, str] | Sequence[tuple[str, str]],
1553
+ context: cl.Context | None = None, # pyright: ignore[reportUnusedParameter]
1554
+ options: Any = None, # pyright: ignore[reportUnusedParameter]
1555
+ ) -> _TemplateRenderer:
1451
1556
  return _TemplateRenderer(self, type_aliases, var_values)
1452
1557
 
1453
- def build_inner(self, context, *args, **kwargs):
1454
- raise NotImplementedError
1558
+ @abstractmethod
1559
+ def build_inner(self,
1560
+ context: cl.Context,
1561
+ *args: Any,
1562
+ **kwargs: Any) -> Callable[..., cl.Event]:
1563
+ pass
1455
1564
 
1456
- def build(self, context, *args, **kwargs):
1565
+ def build(self, context: cl.Context, *args: Any, **kwargs: Any) -> Any:
1457
1566
  """Provide caching for an :meth:`build_inner`."""
1458
1567
 
1459
1568
  cache_key = (context, args, tuple(sorted(kwargs.items())))
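The caching in build hinges on the cache key in the last line above: positional args are kept as a tuple and keyword args are canonicalized by sorting, so call-site kwarg order does not cause spurious cache misses. A small stand-alone illustration of that keying scheme (hypothetical cache and build body, not the class itself):

from typing import Any

_cache: dict[Any, str] = {}


def build(context: str, *args: Any, **kwargs: Any) -> str:
    cache_key = (context, args, tuple(sorted(kwargs.items())))
    try:
        return _cache[cache_key]
    except KeyError:
        result = f"built for {cache_key}"    # stands in for build_inner()
        _cache[cache_key] = result
        return result


a = build("ctx", 4, dtype="f4", fast=True)
b = build("ctx", 4, fast=True, dtype="f4")   # different kwarg order, same key
assert a is b and len(_cache) == 1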
@@ -1469,41 +1578,47 @@ class KernelTemplateBase:
1469
1578
 
1470
1579
  # {{{ array_module
1471
1580
 
1581
+ # TODO: this is not used anywhere: deprecate + remove
1582
+
1472
1583
  class _CLFakeArrayModule:
1473
- def __init__(self, queue):
1474
- self.queue = queue
1584
+ def __init__(self, queue: cl.CommandQueue | None = None) -> None:
1585
+ self.queue: cl.CommandQueue | None = queue
1475
1586
 
1476
1587
  @property
1477
- def ndarray(self):
1588
+ def ndarray(self) -> type[CLArray]:
1478
1589
  from pyopencl.array import Array
1479
1590
  return Array
1480
1591
 
1481
- def dot(self, x, y):
1592
+ def dot(self, x: CLArray, y: CLArray) -> NDArray[Any]:
1482
1593
  from pyopencl.array import dot
1483
1594
  return dot(x, y, queue=self.queue).get()
1484
1595
 
1485
- def vdot(self, x, y):
1596
+ def vdot(self, x: CLArray, y: CLArray) -> NDArray[Any]:
1486
1597
  from pyopencl.array import vdot
1487
1598
  return vdot(x, y, queue=self.queue).get()
1488
1599
 
1489
- def empty(self, shape, dtype, order="C"):
1600
+ def empty(self,
1601
+ shape: int | tuple[int, ...],
1602
+ dtype: DTypeLike,
1603
+ order: str = "C") -> CLArray:
1490
1604
  from pyopencl.array import empty
1491
1605
  return empty(self.queue, shape, dtype, order=order)
1492
1606
 
1493
- def hstack(self, arrays):
1607
+ def hstack(self, arrays: Sequence[CLArray]) -> CLArray:
1494
1608
  from pyopencl.array import hstack
1495
1609
  return hstack(arrays, self.queue)
1496
1610
 
1497
1611
 
1498
- def array_module(a):
1612
+ def array_module(a: Any) -> Any:
1499
1613
  if isinstance(a, np.ndarray):
1500
1614
  return np
1501
1615
  else:
1502
1616
  from pyopencl.array import Array
1617
+
1503
1618
  if isinstance(a, Array):
1504
1619
  return _CLFakeArrayModule(a.queue)
1505
1620
  else:
1506
- raise TypeError("array type not understood: %s" % type(a))
1621
+ raise TypeError(f"array type not understood: {type(a)}")
1507
1622
 
1508
1623
  # }}}
1509
1624
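The dispatch at the end of this hunk gives array_module a numpy-like entry point for both host and device arrays; with a numpy input it simply returns numpy (a minimal check):

import numpy as np

from pyopencl.tools import array_module

assert array_module(np.zeros(3)) is np
# For a pyopencl.array.Array `a`, array_module(a) instead returns a
# _CLFakeArrayModule bound to a.queue, exposing dot/vdot/empty/hstack.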
 
@@ -1519,19 +1634,19 @@ def is_spirv(s: str | bytes) -> TypeIs[bytes]:
1519
1634
 
1520
1635
  # {{{ numpy key types builder
1521
1636
 
1522
- class _NumpyTypesKeyBuilder(KeyBuilderBase):
1523
- def update_for_VectorArg(self, key_hash, key): # noqa: N802
1637
+ class _NumpyTypesKeyBuilder(KeyBuilderBase): # pyright: ignore[reportUnusedClass]
1638
+ def update_for_VectorArg(self, key_hash: Hash, key: VectorArg) -> None: # noqa: N802
1524
1639
  self.rec(key_hash, key.dtype)
1525
1640
  self.update_for_str(key_hash, key.name)
1526
1641
  self.rec(key_hash, key.with_offset)
1527
1642
 
1528
- def update_for_type(self, key_hash, key):
1643
+ @override
1644
+ def update_for_type(self, key_hash: Hash, key: type) -> None:
1529
1645
  if issubclass(key, np.generic):
1530
1646
  self.update_for_str(key_hash, key.__name__)
1531
1647
  return
1532
1648
 
1533
- raise TypeError("unsupported type for persistent hash keying: %s"
1534
- % type(key))
1649
+ raise TypeError(f"unsupported type for persistent hash keying: {key}")
1535
1650
 
1536
1651
  # }}}
1537
1652