pyopencl 2024.3__cp312-cp312-macosx_11_0_arm64.whl → 2025.2.1__cp312-cp312-macosx_11_0_arm64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyopencl might be problematic. Click here for more details.

pyopencl/_mymako.py CHANGED
@@ -1,3 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+
1
4
  try:
2
5
  import mako.template # noqa: F401
3
6
  except ImportError as err:
pyopencl/algorithm.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Algorithms built on scans."""
2
+ from __future__ import annotations
2
3
 
3
4
 
4
5
  __copyright__ = """
@@ -30,7 +31,7 @@ OTHER DEALINGS IN THE SOFTWARE.
30
31
  """
31
32
 
32
33
  from dataclasses import dataclass
33
- from typing import Optional
34
+ from typing import TYPE_CHECKING
34
35
 
35
36
  import numpy as np
36
37
  from mako.template import Template
@@ -38,12 +39,15 @@ from mako.template import Template
38
39
  from pytools import memoize, memoize_method
39
40
 
40
41
  import pyopencl as cl
41
- import pyopencl.array
42
- from pyopencl.elementwise import ElementwiseKernel
42
+ import pyopencl.array as cl_array
43
43
  from pyopencl.scan import GenericScanKernel, ScanTemplate
44
44
  from pyopencl.tools import dtype_to_ctype, get_arg_offset_adjuster_code
45
45
 
46
46
 
47
+ if TYPE_CHECKING:
48
+ from pyopencl.elementwise import ElementwiseKernel
49
+
50
+
47
51
  # {{{ "extra args" handling utility
48
52
 
49
53
  def _extract_extra_args_types_values(extra_args):
@@ -55,7 +59,7 @@ def _extract_extra_args_types_values(extra_args):
55
59
  extra_args_values = []
56
60
  extra_wait_for = []
57
61
  for name, val in extra_args:
58
- if isinstance(val, cl.array.Array):
62
+ if isinstance(val, cl_array.Array):
59
63
  extra_args_types.append(VectorArg(val.dtype, name, with_offset=False))
60
64
  extra_args_values.append(val)
61
65
  extra_wait_for.extend(val.events)
@@ -117,7 +121,7 @@ def copy_if(ary, predicate, extra_args=None, preamble="", queue=None, wait_for=N
117
121
  type_aliases=(("scan_t", scan_dtype), ("item_t", ary.dtype)),
118
122
  var_values=(("predicate", predicate),),
119
123
  more_preamble=preamble, more_arguments=extra_args_types)
120
- out = cl.array.empty_like(ary)
124
+ out = cl_array.empty_like(ary)
121
125
  count = ary._new_with_changes(data=None, offset=0,
122
126
  shape=(), strides=(), dtype=scan_dtype)
123
127
 
@@ -207,8 +211,8 @@ def partition(ary, predicate, extra_args=None, preamble="",
207
211
  var_values=(("predicate", predicate),),
208
212
  more_preamble=preamble, more_arguments=extra_args_types)
209
213
 
210
- out_true = cl.array.empty_like(ary)
211
- out_false = cl.array.empty_like(ary)
214
+ out_true = cl_array.empty_like(ary)
215
+ out_false = cl_array.empty_like(ary)
212
216
  count = ary._new_with_changes(data=None, offset=0,
213
217
  shape=(), strides=(), dtype=scan_dtype)
214
218
 
@@ -279,7 +283,7 @@ def unique(ary, is_equal_expr="a == b", extra_args=None, preamble="",
279
283
  var_values=(("macro_is_equal_expr", is_equal_expr),),
280
284
  more_preamble=preamble, more_arguments=extra_args_types)
281
285
 
282
- out = cl.array.empty_like(ary)
286
+ out = cl_array.empty_like(ary)
283
287
  count = ary._new_with_changes(data=None, offset=0,
284
288
  shape=(), strides=(), dtype=scan_dtype)
285
289
 
@@ -556,7 +560,7 @@ class RadixSort:
556
560
  base_bit = 0
557
561
  while base_bit < key_bits:
558
562
  sorted_args = [
559
- cl.array.empty(queue, n, arg_descr.dtype, allocator=allocator)
563
+ cl_array.empty(queue, n, arg_descr.dtype, allocator=allocator)
560
564
  for arg_descr in self.arguments
561
565
  if arg_descr.name in self.sort_arg_names]
562
566
 
@@ -574,7 +578,7 @@ class RadixSort:
574
578
  base_bit += self.bits
575
579
 
576
580
  return [arg_val
577
- for arg_descr, arg_val in zip(self.arguments, args)
581
+ for arg_descr, arg_val in zip(self.arguments, args, strict=True)
578
582
  if arg_descr.name in self.sort_arg_names], last_evt
579
583
 
580
584
  # }}}
@@ -725,12 +729,12 @@ def _get_arg_list(arg_list, prefix=""):
725
729
 
726
730
  @dataclass
727
731
  class BuiltList:
728
- count: Optional[int]
729
- starts: Optional[pyopencl.array.Array]
730
- lists: Optional[pyopencl.array.Array] = None
731
- num_nonempty_lists: Optional[int] = None
732
- nonempty_indices: Optional[pyopencl.array.Array] = None
733
- compressed_indices: Optional[pyopencl.array.Array] = None
732
+ count: int | None
733
+ starts: cl_array.Array | None
734
+ lists: cl_array.Array | None = None
735
+ num_nonempty_lists: int | None = None
736
+ nonempty_indices: cl_array.Array | None = None
737
+ compressed_indices: cl_array.Array | None = None
734
738
 
735
739
 
736
740
  class ListOfListsBuilder:
@@ -1139,7 +1143,8 @@ class ListOfListsBuilder:
1139
1143
  compress_kernel = self.get_compress_kernel(index_dtype)
1140
1144
 
1141
1145
  data_args = []
1142
- for i, (arg_descr, arg_val) in enumerate(zip(self.arg_decls, args)):
1146
+ for i, (arg_descr, arg_val) in enumerate(
1147
+ zip(self.arg_decls, args, strict=True)):
1143
1148
  from pyopencl.tools import VectorArg
1144
1149
  if isinstance(arg_descr, VectorArg):
1145
1150
  from pyopencl import MemoryObject
@@ -1179,7 +1184,7 @@ class ListOfListsBuilder:
1179
1184
  count_list_args.append(None)
1180
1185
  continue
1181
1186
 
1182
- counts = cl.array.empty(queue,
1187
+ counts = cl_array.empty(queue,
1183
1188
  (n_objects + 1), index_dtype, allocator=allocator)
1184
1189
  counts[-1] = 0
1185
1190
  wait_for = wait_for + counts.events
@@ -1219,14 +1224,14 @@ class ListOfListsBuilder:
1219
1224
  if name not in self.eliminate_empty_output_lists:
1220
1225
  continue
1221
1226
 
1222
- compressed_counts = cl.array.empty(
1227
+ compressed_counts = cl_array.empty(
1223
1228
  queue, (n_objects + 1,), index_dtype, allocator=allocator)
1224
1229
  info_record = result[name]
1225
- info_record.nonempty_indices = cl.array.empty(
1230
+ info_record.nonempty_indices = cl_array.empty(
1226
1231
  queue, (n_objects + 1,), index_dtype, allocator=allocator)
1227
- info_record.num_nonempty_lists = cl.array.empty(
1232
+ info_record.num_nonempty_lists = cl_array.empty(
1228
1233
  queue, (1,), index_dtype, allocator=allocator)
1229
- info_record.compressed_indices = cl.array.empty(
1234
+ info_record.compressed_indices = cl_array.empty(
1230
1235
  queue, (n_objects + 1,), index_dtype, allocator=allocator)
1231
1236
  info_record.compressed_indices[0] = 0
1232
1237
 
@@ -1301,7 +1306,7 @@ class ListOfListsBuilder:
1301
1306
  else:
1302
1307
  info_record = result[name]
1303
1308
 
1304
- info_record.lists = cl.array.empty(queue,
1309
+ info_record.lists = cl_array.empty(queue,
1305
1310
  info_record.count, dtype, allocator=allocator)
1306
1311
  write_list_args.append(info_record.lists.data)
1307
1312
 
@@ -1431,7 +1436,7 @@ class KeyValueSorter:
1431
1436
  (values_sorted_by_key, keys_sorted_by_key), evt = knl_info.by_target_sorter(
1432
1437
  values, keys, queue=queue, wait_for=wait_for)
1433
1438
 
1434
- starts = (cl.array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator)
1439
+ starts = (cl_array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator)
1435
1440
  .fill(len(values_sorted_by_key), wait_for=[evt]))
1436
1441
  evt, = starts.events
1437
1442
 
pyopencl/array.py CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  # NOTE: for elwise_kernel_runner which adds keyword arguments
4
4
  # pylint:disable=unexpected-keyword-arg
5
+ from __future__ import annotations
6
+
5
7
 
6
8
  __copyright__ = "Copyright (C) 2009 Andreas Kloeckner"
7
9
 
@@ -32,13 +34,14 @@ import builtins
32
34
  from dataclasses import dataclass
33
35
  from functools import reduce
34
36
  from numbers import Number
35
- from typing import Any, Dict, Hashable, List, Optional, Tuple, Union
37
+ from typing import Any
36
38
  from warnings import warn
37
39
 
38
40
  import numpy as np
39
41
 
40
42
  import pyopencl as cl
41
43
  import pyopencl.elementwise as elementwise
44
+ import pyopencl.tools as cl_tools
42
45
  from pyopencl import cltypes
43
46
  from pyopencl.characterize import has_double_support
44
47
  from pyopencl.compyte.array import (
@@ -58,20 +61,14 @@ else:
58
61
  _SVMPointer_or_nothing = ()
59
62
 
60
63
 
61
- _NUMPY_PRE_2 = np.__version__.startswith("1.")
62
-
63
-
64
64
  # {{{ _get_common_dtype
65
65
 
66
- _COMMON_DTYPE_CACHE: Dict[Tuple[Hashable, ...], np.dtype] = {}
67
-
68
-
69
66
  class DoubleDowncastWarning(UserWarning):
70
67
  pass
71
68
 
72
69
 
73
70
  _DOUBLE_DOWNCAST_WARNING = (
74
- "The operation you requested would result in a double-precisision "
71
+ "The operation you requested would result in a double-precision "
75
72
  "quantity according to numpy semantics. Since your device does not "
76
73
  "support double precision, a single-precision quantity is being returned.")
77
74
 
@@ -81,78 +78,12 @@ def _get_common_dtype(obj1, obj2, queue):
81
78
  raise ValueError("PyOpenCL array has no queue; call .with_queue() to "
82
79
  "add one in order to be able to perform operations")
83
80
 
84
- allow_double = has_double_support(queue.device)
85
- cache_key = None
86
- o1_is_array = isinstance(obj1, Array)
87
- o2_is_array = isinstance(obj2, Array)
88
-
89
- if o1_is_array and o2_is_array:
90
- o1_dtype = obj1.dtype
91
- o2_dtype = obj2.dtype
92
- cache_key = (obj1.dtype, obj2.dtype, allow_double)
93
- else:
94
- o1_dtype = getattr(obj1, "dtype", type(obj1))
95
- o2_dtype = getattr(obj2, "dtype", type(obj2))
96
-
97
- o1_is_integral = np.issubdtype(o1_dtype, np.integer)
98
- o2_is_integral = np.issubdtype(o1_dtype, np.integer)
99
-
100
- o1_key = obj1 if o1_is_integral and not o1_is_array else o1_dtype
101
- o2_key = obj2 if o2_is_integral and not o2_is_array else o2_dtype
102
-
103
- cache_key = (o1_key, o2_key, o1_is_array, o2_is_array, allow_double)
81
+ # Note: We are calling np.result_type with pyopencl arrays here.
82
+ # Luckily, np.result_type only looks at the dtype of input arrays up until
83
+ # at least numpy v2.1.
84
+ result = np.result_type(obj1, obj2)
104
85
 
105
- try:
106
- return _COMMON_DTYPE_CACHE[cache_key]
107
- except KeyError:
108
- pass
109
-
110
- # Numpy's behavior around integers is a bit bizarre, and definitely value-
111
- # and not just type-sensitive when it comes to scalars. We'll just do our
112
- # best to emulate it.
113
- #
114
- # Some samples that are true as of numpy 1.23.1.
115
- #
116
- # >>> a = np.zeros(1, dtype=np.int16)
117
- # >>> (a + 123123123312).dtype
118
- # dtype('int64')
119
- # >>> (a + 12312).dtype
120
- # dtype('int16')
121
- # >>> (a + 12312444).dtype
122
- # dtype('int32')
123
- # >>> (a + np.int32(12312444)).dtype
124
- # dtype('int32')
125
- # >>> (a + np.int32(1234)).dtype
126
- # dtype('int16')
127
- #
128
- # Note that np.find_common_type, while appealing, won't be able to tell
129
- # the full story.
130
-
131
- if (_NUMPY_PRE_2
132
- and not (o1_is_array and o2_is_array)
133
- and o1_is_integral and o2_is_integral):
134
- if o1_is_array:
135
- obj1 = np.zeros(1, dtype=o1_dtype)
136
- if o2_is_array:
137
- obj2 = np.zeros(1, dtype=o2_dtype)
138
-
139
- result = (obj1 + obj2).dtype
140
- else:
141
- array_types = []
142
- scalars = []
143
-
144
- if o1_is_array:
145
- array_types.append(o1_dtype)
146
- else:
147
- scalars.append(obj1)
148
- if o2_is_array:
149
- array_types.append(o2_dtype)
150
- else:
151
- scalars.append(obj2)
152
-
153
- result = np.result_type(*array_types, *scalars)
154
-
155
- if not allow_double:
86
+ if not has_double_support(queue.device):
156
87
  if result == np.float64:
157
88
  result = np.dtype(np.float32)
158
89
  warn(_DOUBLE_DOWNCAST_WARNING, DoubleDowncastWarning, stacklevel=3)
@@ -160,9 +91,6 @@ def _get_common_dtype(obj1, obj2, queue):
160
91
  result = np.dtype(np.complex64)
161
92
  warn(_DOUBLE_DOWNCAST_WARNING, DoubleDowncastWarning, stacklevel=3)
162
93
 
163
- if cache_key is not None:
164
- _COMMON_DTYPE_CACHE[cache_key] = result
165
-
166
94
  return result
167
95
 
168
96
  # }}}
@@ -305,13 +233,13 @@ def elwise_kernel_runner(kernel_getter):
305
233
  return kernel_runner
306
234
 
307
235
 
308
- class DefaultAllocator(cl.tools.DeferredAllocator):
236
+ class DefaultAllocator(cl_tools.DeferredAllocator):
309
237
  def __init__(self, *args, **kwargs):
310
238
  warn("pyopencl.array.DefaultAllocator is deprecated. "
311
239
  "It will be continue to exist throughout the 2013.x "
312
240
  "versions of PyOpenCL.",
313
241
  DeprecationWarning, stacklevel=2)
314
- cl.tools.DeferredAllocator.__init__(self, *args, **kwargs)
242
+ cl_tools.DeferredAllocator.__init__(self, *args, **kwargs)
315
243
 
316
244
  # }}}
317
245
 
@@ -337,7 +265,7 @@ class _copy_queue: # noqa: N801
337
265
  pass
338
266
 
339
267
 
340
- _ARRAY_GET_SIZES_CACHE: Dict[Tuple[int, int, int], Tuple[int, int]] = {}
268
+ _ARRAY_GET_SIZES_CACHE: dict[tuple[int, int, int], tuple[int, int]] = {}
341
269
  _BOOL_DTYPE = np.dtype(np.int8)
342
270
  _NOT_PRESENT = object()
343
271
 
@@ -532,22 +460,22 @@ class Array:
532
460
 
533
461
  def __init__(
534
462
  self,
535
- cq: Optional[Union[cl.Context, cl.CommandQueue]],
536
- shape: Union[Tuple[int, ...], int],
463
+ cq: cl.Context | cl.CommandQueue | None,
464
+ shape: tuple[int, ...] | int,
537
465
  dtype: Any,
538
466
  order: str = "C",
539
- allocator: Optional[cl.tools.AllocatorBase] = None,
467
+ allocator: cl_tools.AllocatorBase | None = None,
540
468
  data: Any = None,
541
469
  offset: int = 0,
542
- strides: Optional[Tuple[int, ...]] = None,
543
- events: Optional[List[cl.Event]] = None,
470
+ strides: tuple[int, ...] | None = None,
471
+ events: list[cl.Event] | None = None,
544
472
 
545
473
  # NOTE: following args are used for the fast constructor
546
474
  _flags: Any = None,
547
475
  _fast: bool = False,
548
- _size: Optional[int] = None,
549
- _context: Optional[cl.Context] = None,
550
- _queue: Optional[cl.CommandQueue] = None) -> None:
476
+ _size: int | None = None,
477
+ _context: cl.Context | None = None,
478
+ _queue: cl.CommandQueue | None = None) -> None:
551
479
  if _fast:
552
480
  # Assumptions, should be disabled if not testing
553
481
  if 0:
@@ -2031,13 +1959,13 @@ class Array:
2031
1959
  raise ValueError("new type not compatible with array")
2032
1960
 
2033
1961
  new_shape = (
2034
- self.shape[:min_stride_axis]
2035
- + (self.shape[min_stride_axis] * old_itemsize // itemsize,)
2036
- + self.shape[min_stride_axis+1:])
1962
+ *self.shape[:min_stride_axis],
1963
+ self.shape[min_stride_axis] * old_itemsize // itemsize,
1964
+ *self.shape[min_stride_axis+1:])
2037
1965
  new_strides = (
2038
- self.strides[:min_stride_axis]
2039
- + (self.strides[min_stride_axis] * itemsize // old_itemsize,)
2040
- + self.strides[min_stride_axis+1:])
1966
+ *self.strides[:min_stride_axis],
1967
+ self.strides[min_stride_axis] * itemsize // old_itemsize,
1968
+ *self.strides[min_stride_axis+1:])
2041
1969
 
2042
1970
  return self._new_with_changes(
2043
1971
  self.base_data, self.offset,
@@ -2427,11 +2355,11 @@ def zeros_like(ary):
2427
2355
 
2428
2356
  @dataclass
2429
2357
  class _ArangeInfo:
2430
- start: Optional[int] = None
2431
- stop: Optional[int] = None
2432
- step: Optional[int] = None
2433
- dtype: Optional["np.dtype"] = None
2434
- allocator: Optional[Any] = None
2358
+ start: int | None = None
2359
+ stop: int | None = None
2360
+ step: int | None = None
2361
+ dtype: np.dtype | None = None
2362
+ allocator: Any | None = None
2435
2363
 
2436
2364
 
2437
2365
  @elwise_kernel_runner
@@ -2518,7 +2446,7 @@ def arange(queue, *args, **kwargs):
2518
2446
  raise TypeError("arange requires a dtype argument")
2519
2447
 
2520
2448
  from math import ceil
2521
- size = int(ceil((stop-start)/step))
2449
+ size = ceil((stop-start)/step)
2522
2450
 
2523
2451
  result = Array(queue, (size,), dtype, allocator=inf.allocator)
2524
2452
  result.add_event(_arange_knl(result, start, step, queue=queue))
@@ -2834,9 +2762,9 @@ def concatenate(arrays, axis=0, queue=None, allocator=None):
2834
2762
  for ary in arrays:
2835
2763
  my_len = ary.shape[axis]
2836
2764
  result.setitem(
2837
- full_slice[:axis]
2838
- + (slice(base_idx, base_idx+my_len),)
2839
- + full_slice[axis+1:],
2765
+ (*full_slice[:axis],
2766
+ slice(base_idx, base_idx+my_len),
2767
+ *full_slice[axis+1:]),
2840
2768
  ary)
2841
2769
 
2842
2770
  base_idx += my_len
@@ -2942,7 +2870,7 @@ def stack(arrays, axis=0, queue=None):
2942
2870
  # pyopencl.Array.__setitem__ does not support non-contiguous assignments
2943
2871
  raise NotImplementedError
2944
2872
 
2945
- result_shape = input_shape[:axis] + (len(arrays),) + input_shape[axis:]
2873
+ result_shape = (*input_shape[:axis], len(arrays), *input_shape[axis:])
2946
2874
 
2947
2875
  if __debug__:
2948
2876
  if builtins.any(type(ary) != type(arrays[0]) # noqa: E721
pyopencl/bitonic_sort.py CHANGED
@@ -1,3 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+
1
4
  __copyright__ = """
2
5
  Copyright (c) 2011, Eric Bainville
3
6
  Copyright (c) 2015, Ilya Efimoff
@@ -35,7 +38,7 @@ OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
38
 
36
39
  from functools import reduce
37
40
  from operator import mul
38
- from typing import ClassVar, Dict
41
+ from typing import ClassVar
39
42
 
40
43
  from mako.template import Template
41
44
 
@@ -64,7 +67,7 @@ class BitonicSort:
64
67
  .. automethod:: __call__
65
68
  """
66
69
 
67
- kernels_srcs: ClassVar[Dict[str, str]] = {
70
+ kernels_srcs: ClassVar[dict[str, str]] = {
68
71
  "B2": _tmpl.ParallelBitonic_B2,
69
72
  "B4": _tmpl.ParallelBitonic_B4,
70
73
  "B8": _tmpl.ParallelBitonic_B8,
pyopencl/bitonic_sort_templates.py CHANGED
@@ -1,3 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+
1
4
  __copyright__ = """
2
5
  Copyright (c) 2011, Eric Bainville
3
6
  Copyright (c) 2015, Ilya Efimoff
pyopencl/cache.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """PyOpenCL compiler cache."""
2
+ from __future__ import annotations
2
3
 
3
4
 
4
5
  __copyright__ = "Copyright (C) 2011 Andreas Kloeckner"
@@ -28,7 +29,6 @@ import os
28
29
  import re
29
30
  import sys
30
31
  from dataclasses import dataclass
31
- from typing import List, Optional, Tuple
32
32
 
33
33
  import pyopencl._cl as _cl
34
34
 
@@ -339,8 +339,8 @@ def retrieve_from_cache(cache_dir, cache_key):
339
339
 
340
340
  @dataclass(frozen=True)
341
341
  class _SourceInfo:
342
- dependencies: List[Tuple[str, ...]]
343
- log: Optional[str]
342
+ dependencies: list[tuple[str, ...]]
343
+ log: str | None
344
344
 
345
345
 
346
346
  def _create_built_program_from_source_cached(ctx, src, options_bytes,
@@ -373,7 +373,7 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
373
373
  binaries = []
374
374
  to_be_built_indices = []
375
375
  logs = []
376
- for i, (_device, cache_key) in enumerate(zip(devices, cache_keys)):
376
+ for i, (_device, cache_key) in enumerate(zip(devices, cache_keys, strict=True)):
377
377
  cache_result = retrieve_from_cache(cache_dir, cache_key)
378
378
 
379
379
  if cache_result is None:
@@ -391,7 +391,7 @@ def _create_built_program_from_source_cached(ctx, src, options_bytes,
391
391
 
392
392
  message = (75*"="+"\n").join(
393
393
  f"Build on {dev} succeeded, but said:\n\n{log}"
394
- for dev, log in zip(devices, logs)
394
+ for dev, log in zip(devices, logs, strict=True)
395
395
  if log is not None and log.strip())
396
396
 
397
397
  if message:
pyopencl/capture_call.py CHANGED
@@ -1,3 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+
1
4
  __copyright__ = "Copyright (C) 2013 Andreas Kloeckner"
2
5
 
3
6
  __license__ = """
@@ -21,6 +24,8 @@ THE SOFTWARE.
21
24
  """
22
25
 
23
26
 
27
+ from typing import TYPE_CHECKING, TextIO, cast
28
+
24
29
  import numpy as np
25
30
 
26
31
  from pytools.py_codegen import Indentation, PythonCodeGenerator
@@ -28,9 +33,26 @@ from pytools.py_codegen import Indentation, PythonCodeGenerator
28
33
  import pyopencl as cl
29
34
 
30
35
 
31
- def capture_kernel_call(kernel, output_file, queue, g_size, l_size, *args, **kwargs):
36
+ if TYPE_CHECKING:
37
+ from numpy.typing import DTypeLike
38
+
39
+ from pyopencl.typing import KernelArg, WaitList
40
+
41
+
42
+ def capture_kernel_call(
43
+ kernel: cl.Kernel,
44
+ output_file: str | TextIO,
45
+ queue: cl.CommandQueue,
46
+ g_size: tuple[int, ...],
47
+ l_size: tuple[int, ...] | None,
48
+ *args: KernelArg,
49
+ wait_for: WaitList = None, # pyright: ignore[reportUnusedParameter]
50
+ g_times_l: bool = False,
51
+ allow_empty_ndrange: bool = False,
52
+ global_offset: tuple[int, ...] | None = None,
53
+ ) -> None:
32
54
  try:
33
- source = kernel._source
55
+ source = cast("str | None", kernel._source) # pyright: ignore[reportAttributeAccessIssue]
34
56
  except AttributeError as err:
35
57
  raise RuntimeError("cannot capture call, kernel source not available") from err
36
58
 
@@ -55,7 +77,7 @@ def capture_kernel_call(kernel, output_file, queue, g_size, l_size, *args, **kwa
55
77
 
56
78
  # {{{ invocation
57
79
 
58
- arg_data = []
80
+ arg_data: list[tuple[str, memoryview | bytearray]] = []
59
81
 
60
82
  cg("")
61
83
  cg("")
@@ -65,7 +87,7 @@ def capture_kernel_call(kernel, output_file, queue, g_size, l_size, *args, **kwa
65
87
  cg("queue = cl.CommandQueue(ctx)")
66
88
  cg("")
67
89
 
68
- kernel_args = []
90
+ kernel_args: list[str] = []
69
91
 
70
92
  for i, arg in enumerate(args):
71
93
  if isinstance(arg, cl.Buffer):
@@ -101,22 +123,23 @@ def capture_kernel_call(kernel, output_file, queue, g_size, l_size, *args, **kwa
101
123
 
102
124
  cg("")
103
125
 
104
- g_times_l = kwargs.get("g_times_l", False)
105
126
  if g_times_l:
127
+ assert l_size is not None
106
128
  dim = max(len(g_size), len(l_size))
107
129
  l_size = l_size + (1,) * (dim-len(l_size))
108
130
  g_size = g_size + (1,) * (dim-len(g_size))
109
131
  g_size = tuple(
110
- gs*ls for gs, ls in zip(g_size, l_size))
132
+ gs*ls for gs, ls in zip(g_size, l_size, strict=True))
111
133
 
112
- global_offset = kwargs.get("global_offset", None)
113
134
  if global_offset is not None:
114
135
  kernel_args.append("global_offset=%s" % repr(global_offset))
136
+ if allow_empty_ndrange:
137
+ kernel_args.append("allow_empty_ndrange=%s" % repr(allow_empty_ndrange))
115
138
 
116
139
  cg("prg = cl.Program(ctx, CODE).build()")
117
140
  cg("knl = prg.%s" % kernel.function_name)
118
141
  if hasattr(kernel, "_scalar_arg_dtypes"):
119
- def strify_dtype(d):
142
+ def strify_dtype(d: DTypeLike):
120
143
  if d is None:
121
144
  return "None"
122
145