pyopencl 2025.2.4__cp312-cp312-macosx_11_0_arm64.whl → 2025.2.6__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyopencl might be problematic. Click here for more details.

pyopencl/elementwise.py CHANGED
@@ -27,17 +27,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
27
27
  OTHER DEALINGS IN THE SOFTWARE.
28
28
  """
29
29
 
30
-
30
+ import builtins
31
31
  import enum
32
- from typing import Any
32
+ from typing import TYPE_CHECKING, Any, TextIO, cast
33
33
 
34
34
  import numpy as np
35
+ from typing_extensions import override
35
36
 
36
37
  from pytools import memoize_method
37
38
 
38
39
  import pyopencl as cl
39
40
  from pyopencl.tools import (
40
- DtypedArgument,
41
+ Argument,
41
42
  KernelTemplateBase,
42
43
  ScalarArg,
43
44
  VectorArg,
@@ -47,11 +48,20 @@ from pyopencl.tools import (
47
48
  )
48
49
 
49
50
 
51
+ if TYPE_CHECKING:
52
+ from collections.abc import Callable, Sequence
53
+
54
+ from numpy.typing import DTypeLike
55
+
56
+ from pyopencl.array import Array
57
+ from pyopencl.typing import KernelArg, WaitList
58
+
59
+
50
60
  # {{{ elementwise kernel code generator
51
61
 
52
62
  def get_elwise_program(
53
63
  context: cl.Context,
54
- arguments: list[DtypedArgument],
64
+ arguments: Sequence[Argument],
55
65
  operation: str, *,
56
66
  name: str = "elwise_kernel",
57
67
  options: Any = None,
@@ -119,27 +129,27 @@ def get_elwise_program(
119
129
 
120
130
  def get_elwise_kernel_and_types(
121
131
  context: cl.Context,
122
- arguments: str | list[DtypedArgument],
132
+ arguments: str | Sequence[Argument],
123
133
  operation: str, *,
124
134
  name: str = "elwise_kernel",
125
135
  options: Any = None,
126
136
  preamble: str = "",
127
137
  use_range: bool = False,
128
- **kwargs: Any) -> tuple[cl.Kernel, list[DtypedArgument]]:
138
+ **kwargs: Any) -> tuple[cl.Kernel, Sequence[Argument]]:
129
139
 
130
140
  from pyopencl.tools import get_arg_offset_adjuster_code, parse_arg_list
131
- parsed_args = parse_arg_list(arguments, with_offset=True)
141
+ parsed_args = list(parse_arg_list(arguments, with_offset=True))
132
142
 
133
143
  auto_preamble = kwargs.pop("auto_preamble", True)
134
144
 
135
- pragmas = []
136
- includes = []
145
+ pragmas: list[str] = []
146
+ includes: list[str] = []
137
147
  have_double_pragma = False
138
148
  have_complex_include = False
139
149
 
140
150
  if auto_preamble:
141
151
  for arg in parsed_args:
142
- if arg.dtype in [np.float64, np.complex128]:
152
+ if arg.dtype.type in [np.float64, np.complex128]:
143
153
  if not have_double_pragma:
144
154
  pragmas.append("""
145
155
  #if __OPENCL_C_VERSION__ < 120
@@ -182,14 +192,15 @@ def get_elwise_kernel_and_types(
182
192
 
183
193
  def get_elwise_kernel(
184
194
  context: cl.Context,
185
- arguments: str | list[DtypedArgument],
195
+ arguments: str | Sequence[Argument],
186
196
  operation: str, *,
187
197
  name: str = "elwise_kernel",
188
198
  options: Any = None, **kwargs: Any) -> cl.Kernel:
189
- """Return a L{pyopencl.Kernel} that performs the same scalar operation
190
- on one or several vectors.
191
199
  """
192
- func, arguments = get_elwise_kernel_and_types(
200
+ :returns: a :class:`pyopencl.Kernel` that performs the same scalar operation
201
+ on one or several vectors.
202
+ """
203
+ func, _arguments = get_elwise_kernel_and_types(
193
204
  context, arguments, operation,
194
205
  name=name, options=options, **kwargs)
195
206
 
@@ -229,19 +240,19 @@ class ElementwiseKernel:
229
240
  def __init__(
230
241
  self,
231
242
  context: cl.Context,
232
- arguments: str | list[DtypedArgument],
243
+ arguments: str | Sequence[Argument],
233
244
  operation: str,
234
245
  name: str = "elwise_kernel",
235
246
  options: Any = None, **kwargs: Any) -> None:
236
- self.context = context
237
- self.arguments = arguments
238
- self.operation = operation
239
- self.name = name
240
- self.options = options
241
- self.kwargs = kwargs
247
+ self.context: cl.Context = context
248
+ self.arguments: str | Sequence[Argument] = arguments
249
+ self.operation: str = operation
250
+ self.name: str = name
251
+ self.options: Any = options
252
+ self.kwargs: dict[str, Any] = kwargs
242
253
 
243
254
  @memoize_method
244
- def get_kernel(self, use_range: bool):
255
+ def get_kernel(self, use_range: bool) -> tuple[cl.Kernel, Sequence[Argument]]:
245
256
  knl, arg_descrs = get_elwise_kernel_and_types(
246
257
  self.context, self.arguments, self.operation,
247
258
  name=self.name, options=self.options,
@@ -264,7 +275,14 @@ class ElementwiseKernel:
264
275
 
265
276
  return knl, arg_descrs
266
277
 
267
- def __call__(self, *args, **kwargs) -> cl.Event:
278
+ def __call__(self,
279
+ *args: KernelArg,
280
+ range: builtins.slice | None = None,
281
+ slice: builtins.slice | None = None,
282
+ capture_as: str | TextIO | None = None,
283
+ queue: cl.CommandQueue | None = None,
284
+ wait_for: WaitList = None,
285
+ **kwargs: Any) -> cl.Event:
268
286
  """
269
287
  Invoke the generated scalar kernel.
270
288
 
@@ -273,16 +291,10 @@ class ElementwiseKernel:
273
291
 
274
292
  |std-enqueue-blurb|
275
293
  """
276
- range_ = kwargs.pop("range", None)
277
- slice_ = kwargs.pop("slice", None)
278
- capture_as = kwargs.pop("capture_as", None)
279
- queue = kwargs.pop("queue", None)
280
- wait_for = kwargs.pop("wait_for", None)
281
-
282
294
  if kwargs:
283
295
  raise TypeError(f"unknown keyword arguments: '{', '.join(kwargs)}'")
284
296
 
285
- use_range = range_ is not None or slice_ is not None
297
+ use_range = range is not None or slice is not None
286
298
  kernel, arg_descrs = self.get_kernel(use_range)
287
299
 
288
300
  if wait_for is None:
@@ -293,14 +305,14 @@ class ElementwiseKernel:
293
305
 
294
306
  # {{{ assemble arg array
295
307
 
296
- repr_vec = None
297
- invocation_args = []
308
+ repr_vec: Array | None = None
309
+ invocation_args: list[KernelArg] = []
298
310
 
299
311
  # non-strict because length arg gets appended below
300
312
  for arg, arg_descr in zip(args, arg_descrs, strict=False):
301
313
  if isinstance(arg_descr, VectorArg):
302
314
  if repr_vec is None:
303
- repr_vec = arg
315
+ repr_vec = cast("Array", arg)
304
316
 
305
317
  invocation_args.append(arg)
306
318
  else:
@@ -313,33 +325,35 @@ class ElementwiseKernel:
313
325
  if queue is None:
314
326
  queue = repr_vec.queue
315
327
 
316
- if slice_ is not None:
317
- if range_ is not None:
328
+ if slice is not None:
329
+ if range is not None:
318
330
  raise TypeError(
319
331
  "may not specify both range and slice keyword arguments")
320
332
 
321
- range_ = slice(*slice_.indices(repr_vec.size))
333
+ range = builtins.slice(*slice.indices(repr_vec.size))
334
+
335
+ assert queue is not None
322
336
 
323
337
  max_wg_size = kernel.get_work_group_info(
324
338
  cl.kernel_work_group_info.WORK_GROUP_SIZE,
325
339
  queue.device)
326
340
 
327
- if range_ is not None:
328
- start = range_.start
341
+ if range is not None:
342
+ start = range.start
329
343
  if start is None:
330
344
  start = 0
331
345
  invocation_args.append(start)
332
- invocation_args.append(range_.stop)
333
- if range_.step is None:
346
+ invocation_args.append(range.stop)
347
+ if range.step is None:
334
348
  step = 1
335
349
  else:
336
- step = range_.step
350
+ step = range.step
337
351
 
338
352
  invocation_args.append(step)
339
353
 
340
354
  from pyopencl.array import _splay
341
355
  gs, ls = _splay(queue.device,
342
- abs(range_.stop - start)//step,
356
+ abs(range.stop - start)//step,
343
357
  max_wg_size)
344
358
  else:
345
359
  invocation_args.append(repr_vec.size)
@@ -361,20 +375,28 @@ class ElementwiseKernel:
361
375
  class ElementwiseTemplate(KernelTemplateBase):
362
376
  def __init__(
363
377
  self,
364
- arguments: str | list[DtypedArgument],
378
+ arguments: str | list[Argument],
365
379
  operation: str,
366
380
  name: str = "elwise",
367
381
  preamble: str = "",
368
382
  template_processor: str | None = None) -> None:
369
383
  super().__init__(template_processor=template_processor)
370
- self.arguments = arguments
371
- self.operation = operation
372
- self.name = name
373
- self.preamble = preamble
374
-
375
- def build_inner(self, context, type_aliases=(), var_values=(),
376
- more_preamble="", more_arguments=(), declare_types=(),
377
- options=None):
384
+ self.arguments: str | list[Argument] = arguments
385
+ self.operation: str = operation
386
+ self.name: str = name
387
+ self.preamble: str = preamble
388
+
389
+ @override
390
+ def build_inner(self,
391
+ context: cl.Context,
392
+ type_aliases: (
393
+ dict[str, np.dtype[Any]]
394
+ | Sequence[tuple[str, np.dtype[Any]]]) = (),
395
+ var_values: dict[str, str] | Sequence[tuple[str, str]] = (),
396
+ more_preamble: str = "",
397
+ more_arguments: str | Sequence[Any] = (),
398
+ declare_types: Sequence[DTypeLike] = (),
399
+ options: Any = None) -> Callable[..., cl.Event]:
378
400
  renderer = self.get_renderer(
379
401
  type_aliases, var_values, context, options)
380
402
 
@@ -383,9 +405,12 @@ class ElementwiseTemplate(KernelTemplateBase):
383
405
  type_decl_preamble = renderer.get_type_decl_preamble(
384
406
  context.devices[0], declare_types, arg_list)
385
407
 
386
- return ElementwiseKernel(context,
387
- arg_list, renderer(self.operation),
388
- name=renderer(self.name), options=options,
408
+ return ElementwiseKernel(
409
+ context,
410
+ arg_list,
411
+ renderer(self.operation),
412
+ name=renderer(self.name),
413
+ options=options,
389
414
  preamble=(
390
415
  type_decl_preamble
391
416
  + "\n"
@@ -422,7 +447,7 @@ def get_decl_and_access_for_kind(name: str, kind: ArgumentKind) -> tuple[str, st
422
447
  elif kind == ArgumentKind.DEV_SCALAR:
423
448
  return f"*{name}", f"{name}[0]"
424
449
  else:
425
- raise AssertionError()
450
+ raise AssertionError
426
451
 
427
452
  # }}}
428
453
 
@@ -430,7 +455,10 @@ def get_decl_and_access_for_kind(name: str, kind: ArgumentKind) -> tuple[str, st
430
455
  # {{{ kernels supporting array functionality
431
456
 
432
457
  @context_dependent_memoize
433
- def get_take_kernel(context, dtype, idx_dtype, vec_count=1):
458
+ def get_take_kernel(context: cl.Context,
459
+ dtype: np.dtype[Any],
460
+ idx_dtype: np.dtype[Any],
461
+ vec_count: int = 1) -> cl.Kernel:
434
462
  idx_tp = dtype_to_ctype(idx_dtype)
435
463
 
436
464
  args = ([VectorArg(dtype, f"dest{i}", with_offset=True)
@@ -445,13 +473,18 @@ def get_take_kernel(context, dtype, idx_dtype, vec_count=1):
445
473
  for i in range(vec_count))
446
474
  )
447
475
 
448
- return get_elwise_kernel(context, args, body,
476
+ return get_elwise_kernel(
477
+ context, args, body,
449
478
  preamble=dtype_to_c_struct(context.devices[0], dtype),
450
479
  name="take")
451
480
 
452
481
 
453
482
  @context_dependent_memoize
454
- def get_take_put_kernel(context, dtype, idx_dtype, with_offsets, vec_count=1):
483
+ def get_take_put_kernel(context: cl.Context,
484
+ dtype: np.dtype[Any],
485
+ idx_dtype: np.dtype[Any],
486
+ with_offsets: bool,
487
+ vec_count: int = 1) -> cl.Kernel:
455
488
  idx_tp = dtype_to_ctype(idx_dtype)
456
489
 
457
490
  args = [
@@ -469,23 +502,27 @@ def get_take_put_kernel(context, dtype, idx_dtype, with_offsets, vec_count=1):
469
502
  ]
470
503
 
471
504
  if with_offsets:
472
- def get_copy_insn(i):
505
+ def get_copy_insn(i: int) -> str:
473
506
  return f"dest{i}[dest_idx] = src{i}[src_idx + offset{i}];"
474
507
  else:
475
- def get_copy_insn(i):
508
+ def get_copy_insn(i: int) -> str:
476
509
  return f"dest{i}[dest_idx] = src{i}[src_idx];"
477
510
 
478
511
  body = ((f"{idx_tp} src_idx = gmem_src_idx[i];\n"
479
512
  f"{idx_tp} dest_idx = gmem_dest_idx[i];\n")
480
513
  + "\n".join(get_copy_insn(i) for i in range(vec_count)))
481
514
 
482
- return get_elwise_kernel(context, args, body,
515
+ return get_elwise_kernel(
516
+ context, args, body,
483
517
  preamble=dtype_to_c_struct(context.devices[0], dtype),
484
518
  name="take_put")
485
519
 
486
520
 
487
521
  @context_dependent_memoize
488
- def get_put_kernel(context, dtype, idx_dtype, vec_count=1):
522
+ def get_put_kernel(context: cl.Context,
523
+ dtype: np.dtype[Any],
524
+ idx_dtype: np.dtype[Any],
525
+ vec_count: int = 1) -> cl.Kernel:
489
526
  idx_tp = dtype_to_ctype(idx_dtype)
490
527
 
491
528
  args = [
@@ -517,7 +554,9 @@ def get_put_kernel(context, dtype, idx_dtype, vec_count=1):
517
554
 
518
555
 
519
556
  @context_dependent_memoize
520
- def get_copy_kernel(context, dtype_dest, dtype_src):
557
+ def get_copy_kernel(context: cl.Context,
558
+ dtype_dest: np.dtype[Any],
559
+ dtype_src: np.dtype[Any]) -> cl.Kernel:
521
560
  src = "src[i]"
522
561
  if dtype_dest.kind == "c" != dtype_src.kind:
523
562
  name = complex_dtype_to_name(dtype_dest)
@@ -531,17 +570,17 @@ def get_copy_kernel(context, dtype_dest, dtype_src):
531
570
  dtype_dest.kind == "V" or dtype_src.kind == "V"):
532
571
  raise TypeError("copying between non-identical struct types")
533
572
 
534
- return get_elwise_kernel(context,
535
- "{tp_dest} *dest, {tp_src} *src".format(
536
- tp_dest=dtype_to_ctype(dtype_dest),
537
- tp_src=dtype_to_ctype(dtype_src),
538
- ),
573
+ ctype_dst = dtype_to_ctype(dtype_dest)
574
+ ctype_src = dtype_to_ctype(dtype_src)
575
+ return get_elwise_kernel(
576
+ context,
577
+ f"{ctype_dst} *dest, {ctype_src} *src",
539
578
  f"dest[i] = {src}",
540
579
  preamble=dtype_to_c_struct(context.devices[0], dtype_dest),
541
580
  name="copy")
542
581
 
543
582
 
544
- def complex_dtype_to_name(dtype) -> str:
583
+ def complex_dtype_to_name(dtype: DTypeLike) -> str:
545
584
  if dtype == np.complex128:
546
585
  return "cdouble"
547
586
  elif dtype == np.complex64:
@@ -550,13 +589,17 @@ def complex_dtype_to_name(dtype) -> str:
550
589
  raise RuntimeError(f"invalid complex type: {dtype}")
551
590
 
552
591
 
553
- def real_dtype(dtype):
592
+ def real_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]:
554
593
  return dtype.type(0).real.dtype
555
594
 
556
595
 
557
596
  @context_dependent_memoize
558
- def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z,
559
- x_is_scalar=False, y_is_scalar=False):
597
+ def get_axpbyz_kernel(context: cl.Context,
598
+ dtype_x: np.dtype[Any],
599
+ dtype_y: np.dtype[Any],
600
+ dtype_z: np.dtype[Any],
601
+ x_is_scalar: bool = False,
602
+ y_is_scalar: bool = False) -> cl.Kernel:
560
603
  result_t = dtype_to_ctype(dtype_z)
561
604
 
562
605
  x_is_complex = dtype_x.kind == "c"
@@ -599,7 +642,11 @@ def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z,
599
642
 
600
643
 
601
644
  @context_dependent_memoize
602
- def get_axpbz_kernel(context, dtype_a, dtype_x, dtype_b, dtype_z):
645
+ def get_axpbz_kernel(context: cl.Context,
646
+ dtype_a: np.dtype[Any],
647
+ dtype_x: np.dtype[Any],
648
+ dtype_b: np.dtype[Any],
649
+ dtype_z: np.dtype[Any]):
603
650
  a_is_complex = dtype_a.kind == "c"
604
651
  x_is_complex = dtype_x.kind == "c"
605
652
  b_is_complex = dtype_b.kind == "c"
@@ -660,8 +707,12 @@ def get_axpbz_kernel(context, dtype_a, dtype_x, dtype_b, dtype_z):
660
707
 
661
708
 
662
709
  @context_dependent_memoize
663
- def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z,
664
- x_is_scalar=False, y_is_scalar=False):
710
+ def get_multiply_kernel(context: cl.Context,
711
+ dtype_x: np.dtype[Any],
712
+ dtype_y: np.dtype[Any],
713
+ dtype_z: np.dtype[Any],
714
+ x_is_scalar: bool = False,
715
+ y_is_scalar: bool = False) -> cl.Kernel:
665
716
  x_is_complex = dtype_x.kind == "c"
666
717
  y_is_complex = dtype_y.kind == "c"
667
718
 
@@ -693,8 +744,12 @@ def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z,
693
744
 
694
745
 
695
746
  @context_dependent_memoize
696
- def get_divide_kernel(context, dtype_x, dtype_y, dtype_z,
697
- x_is_scalar=False, y_is_scalar=False):
747
+ def get_divide_kernel(context: cl.Context,
748
+ dtype_x: np.dtype[Any],
749
+ dtype_y: np.dtype[Any],
750
+ dtype_z: np.dtype[Any],
751
+ x_is_scalar: bool = False,
752
+ y_is_scalar: bool = False) -> cl.Kernel:
698
753
  x_is_complex = dtype_x.kind == "c"
699
754
  y_is_complex = dtype_y.kind == "c"
700
755
  z_is_complex = dtype_z.kind == "c"
@@ -736,7 +791,10 @@ def get_divide_kernel(context, dtype_x, dtype_y, dtype_z,
736
791
 
737
792
 
738
793
  @context_dependent_memoize
739
- def get_rdivide_elwise_kernel(context, dtype_x, dtype_y, dtype_z):
794
+ def get_rdivide_elwise_kernel(context: cl.Context,
795
+ dtype_x: np.dtype[Any],
796
+ dtype_y: np.dtype[Any],
797
+ dtype_z: np.dtype[Any]) -> cl.Kernel:
740
798
  # implements y / x!
741
799
  x_is_complex = dtype_x.kind == "c"
742
800
  y_is_complex = dtype_y.kind == "c"
@@ -771,7 +829,7 @@ def get_rdivide_elwise_kernel(context, dtype_x, dtype_y, dtype_z):
771
829
 
772
830
 
773
831
  @context_dependent_memoize
774
- def get_fill_kernel(context, dtype):
832
+ def get_fill_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
775
833
  return get_elwise_kernel(context,
776
834
  "{tp} *z, {tp} a".format(tp=dtype_to_ctype(dtype)),
777
835
  "z[i] = a",
@@ -780,7 +838,7 @@ def get_fill_kernel(context, dtype):
780
838
 
781
839
 
782
840
  @context_dependent_memoize
783
- def get_reverse_kernel(context, dtype):
841
+ def get_reverse_kernel(context: cl.Context, dtype: np.dtype[Any]):
784
842
  return get_elwise_kernel(context,
785
843
  "{tp} *z, {tp} *y".format(tp=dtype_to_ctype(dtype)),
786
844
  "z[i] = y[n-1-i]",
@@ -788,7 +846,7 @@ def get_reverse_kernel(context, dtype):
788
846
 
789
847
 
790
848
  @context_dependent_memoize
791
- def get_arange_kernel(context, dtype):
849
+ def get_arange_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
792
850
  if dtype.kind == "c":
793
851
  expr = (
794
852
  "{root}_add(start, {root}_rmul(i, step))"
@@ -806,8 +864,12 @@ def get_arange_kernel(context, dtype):
806
864
 
807
865
 
808
866
  @context_dependent_memoize
809
- def get_pow_kernel(context, dtype_x, dtype_y, dtype_z,
810
- is_base_array, is_exp_array):
867
+ def get_pow_kernel(context: cl.Context,
868
+ dtype_x: np.dtype[Any],
869
+ dtype_y: np.dtype[Any],
870
+ dtype_z: np.dtype[Any],
871
+ is_base_array: bool,
872
+ is_exp_array: bool) -> cl.Kernel:
811
873
  if is_base_array:
812
874
  x = "x[i]"
813
875
  x_ctype = "{tp_x} *x"
@@ -857,7 +919,10 @@ def get_pow_kernel(context, dtype_x, dtype_y, dtype_z,
857
919
 
858
920
 
859
921
  @context_dependent_memoize
860
- def get_unop_kernel(context, operator, res_dtype, in_dtype):
922
+ def get_unop_kernel(context: cl.Context,
923
+ operator: str,
924
+ res_dtype: np.dtype[Any],
925
+ in_dtype: np.dtype[Any]) -> cl.Kernel:
861
926
  return get_elwise_kernel(context, [
862
927
  VectorArg(res_dtype, "z", with_offset=True),
863
928
  VectorArg(in_dtype, "y", with_offset=True),
@@ -867,7 +932,11 @@ def get_unop_kernel(context, operator, res_dtype, in_dtype):
867
932
 
868
933
 
869
934
  @context_dependent_memoize
870
- def get_array_scalar_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b):
935
+ def get_array_scalar_binop_kernel(context: cl.Context,
936
+ operator: str,
937
+ dtype_res: np.dtype[Any],
938
+ dtype_a: np.dtype[Any],
939
+ dtype_b: np.dtype[Any]) -> cl.Kernel:
871
940
  return get_elwise_kernel(context, [
872
941
  VectorArg(dtype_res, "out", with_offset=True),
873
942
  VectorArg(dtype_a, "a", with_offset=True),
@@ -878,8 +947,13 @@ def get_array_scalar_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b
878
947
 
879
948
 
880
949
  @context_dependent_memoize
881
- def get_array_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b,
882
- a_is_scalar=False, b_is_scalar=False):
950
+ def get_array_binop_kernel(context: cl.Context,
951
+ operator: str,
952
+ dtype_res: np.dtype[Any],
953
+ dtype_a: np.dtype[Any],
954
+ dtype_b: np.dtype[Any],
955
+ a_is_scalar: bool = False,
956
+ b_is_scalar: bool = False) -> cl.Kernel:
883
957
  a = "a[0]" if a_is_scalar else "a[i]"
884
958
  b = "b[0]" if b_is_scalar else "b[i]"
885
959
  return get_elwise_kernel(context, [
@@ -892,7 +966,9 @@ def get_array_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b,
892
966
 
893
967
 
894
968
  @context_dependent_memoize
895
- def get_array_scalar_comparison_kernel(context, operator, dtype_a):
969
+ def get_array_scalar_comparison_kernel(context: cl.Context,
970
+ operator: str,
971
+ dtype_a: np.dtype[Any]) -> cl.Kernel:
896
972
  return get_elwise_kernel(context, [
897
973
  VectorArg(np.int8, "out", with_offset=True),
898
974
  VectorArg(dtype_a, "a", with_offset=True),
@@ -903,7 +979,10 @@ def get_array_scalar_comparison_kernel(context, operator, dtype_a):
903
979
 
904
980
 
905
981
  @context_dependent_memoize
906
- def get_array_comparison_kernel(context, operator, dtype_a, dtype_b):
982
+ def get_array_comparison_kernel(context: cl.Context,
983
+ operator: str,
984
+ dtype_a: np.dtype[Any],
985
+ dtype_b: np.dtype[Any]) -> cl.Kernel:
907
986
  return get_elwise_kernel(context, [
908
987
  VectorArg(np.int8, "out", with_offset=True),
909
988
  VectorArg(dtype_a, "a", with_offset=True),
@@ -914,7 +993,10 @@ def get_array_comparison_kernel(context, operator, dtype_a, dtype_b):
914
993
 
915
994
 
916
995
  @context_dependent_memoize
917
- def get_unary_func_kernel(context, func_name, in_dtype, out_dtype=None):
996
+ def get_unary_func_kernel(context: cl.Context,
997
+ func_name: str,
998
+ in_dtype: np.dtype[Any],
999
+ out_dtype: np.dtype[Any] | None = None) -> cl.Kernel:
918
1000
  if out_dtype is None:
919
1001
  out_dtype = in_dtype
920
1002
 
@@ -927,8 +1009,13 @@ def get_unary_func_kernel(context, func_name, in_dtype, out_dtype=None):
927
1009
 
928
1010
 
929
1011
  @context_dependent_memoize
930
- def get_binary_func_kernel(context, func_name, x_dtype, y_dtype, out_dtype,
931
- preamble="", name=None):
1012
+ def get_binary_func_kernel(context: cl.Context,
1013
+ func_name: str,
1014
+ x_dtype: np.dtype[Any],
1015
+ y_dtype: np.dtype[Any],
1016
+ out_dtype: np.dtype[Any],
1017
+ preamble: str = "",
1018
+ name: str | None = None) -> cl.Kernel:
932
1019
  if name is None:
933
1020
  name = func_name
934
1021
 
@@ -943,8 +1030,13 @@ def get_binary_func_kernel(context, func_name, x_dtype, y_dtype, out_dtype,
943
1030
 
944
1031
 
945
1032
  @context_dependent_memoize
946
- def get_float_binary_func_kernel(context, func_name, x_dtype, y_dtype,
947
- out_dtype, preamble="", name=None):
1033
+ def get_float_binary_func_kernel(context: cl.Context,
1034
+ func_name: str,
1035
+ x_dtype: np.dtype[Any],
1036
+ y_dtype: np.dtype[Any],
1037
+ out_dtype: np.dtype[Any],
1038
+ preamble: str = "",
1039
+ name: str | None = None) -> cl.Kernel:
948
1040
  if name is None:
949
1041
  name = func_name
950
1042
 
@@ -970,15 +1062,19 @@ def get_float_binary_func_kernel(context, func_name, x_dtype, y_dtype,
970
1062
 
971
1063
 
972
1064
  @context_dependent_memoize
973
- def get_fmod_kernel(context, out_dtype=np.float32, arg_dtype=np.float32,
974
- mod_dtype=np.float32):
1065
+ def get_fmod_kernel(context: cl.Context,
1066
+ out_dtype: np.dtype[Any],
1067
+ arg_dtype: np.dtype[Any],
1068
+ mod_dtype: np.dtype[Any]) -> cl.Kernel:
975
1069
  return get_float_binary_func_kernel(context, "fmod", arg_dtype,
976
1070
  mod_dtype, out_dtype)
977
1071
 
978
1072
 
979
1073
  @context_dependent_memoize
980
- def get_modf_kernel(context, int_dtype=np.float32,
981
- frac_dtype=np.float32, x_dtype=np.float32):
1074
+ def get_modf_kernel(context: cl.Context,
1075
+ int_dtype: np.dtype[Any],
1076
+ frac_dtype: np.dtype[Any],
1077
+ x_dtype: np.dtype[Any]):
982
1078
  return get_elwise_kernel(context, [
983
1079
  VectorArg(int_dtype, "intpart", with_offset=True),
984
1080
  VectorArg(frac_dtype, "fracpart", with_offset=True),
@@ -991,8 +1087,10 @@ def get_modf_kernel(context, int_dtype=np.float32,
991
1087
 
992
1088
 
993
1089
  @context_dependent_memoize
994
- def get_frexp_kernel(context, sign_dtype=np.float32, exp_dtype=np.float32,
995
- x_dtype=np.float32):
1090
+ def get_frexp_kernel(context: cl.Context,
1091
+ sign_dtype: np.dtype[Any],
1092
+ exp_dtype: np.dtype[Any],
1093
+ x_dtype: np.dtype[Any]) -> cl.Kernel:
996
1094
  return get_elwise_kernel(context, [
997
1095
  VectorArg(sign_dtype, "significand", with_offset=True),
998
1096
  VectorArg(exp_dtype, "exponent", with_offset=True),
@@ -1007,8 +1105,10 @@ def get_frexp_kernel(context, sign_dtype=np.float32, exp_dtype=np.float32,
1007
1105
 
1008
1106
 
1009
1107
  @context_dependent_memoize
1010
- def get_ldexp_kernel(context, out_dtype=np.float32, sig_dtype=np.float32,
1011
- expt_dtype=np.float32):
1108
+ def get_ldexp_kernel(context: cl.Context,
1109
+ out_dtype: np.dtype[Any],
1110
+ sig_dtype: np.dtype[Any],
1111
+ expt_dtype: np.dtype[Any]) -> cl.Kernel:
1012
1112
  return get_binary_func_kernel(
1013
1113
  context, "_PYOCL_LDEXP", sig_dtype, expt_dtype, out_dtype,
1014
1114
  preamble="#define _PYOCL_LDEXP(x, y) ldexp(x, (int)(y))",
@@ -1016,8 +1116,13 @@ def get_ldexp_kernel(context, out_dtype=np.float32, sig_dtype=np.float32,
1016
1116
 
1017
1117
 
1018
1118
  @context_dependent_memoize
1019
- def get_minmaximum_kernel(context, minmax, dtype_z, dtype_x, dtype_y,
1020
- kind_x: ArgumentKind, kind_y: ArgumentKind):
1119
+ def get_minmaximum_kernel(context: cl.Context,
1120
+ minmax: str,
1121
+ dtype_z: np.dtype[Any],
1122
+ dtype_x: np.dtype[Any],
1123
+ dtype_y: np.dtype[Any],
1124
+ kind_x: ArgumentKind,
1125
+ kind_y: ArgumentKind) -> cl.Kernel:
1021
1126
  if dtype_z.kind == "f":
1022
1127
  reduce_func = f"f{minmax}_nanprop"
1023
1128
  elif dtype_z.kind in "iu":
@@ -1042,8 +1147,11 @@ def get_minmaximum_kernel(context, minmax, dtype_z, dtype_x, dtype_y,
1042
1147
 
1043
1148
 
1044
1149
  @context_dependent_memoize
1045
- def get_bessel_kernel(context, which_func, out_dtype=np.float64,
1046
- order_dtype=np.int32, x_dtype=np.float64):
1150
+ def get_bessel_kernel(context: cl.Context,
1151
+ which_func: str,
1152
+ out_dtype: np.dtype[Any],
1153
+ order_dtype: np.dtype[Any],
1154
+ x_dtype: np.dtype[Any]) -> cl.Kernel:
1047
1155
  if x_dtype.kind != "c":
1048
1156
  return get_elwise_kernel(context, [
1049
1157
  VectorArg(out_dtype, "z", with_offset=True),
@@ -1091,7 +1199,9 @@ def get_bessel_kernel(context, which_func, out_dtype=np.float64,
1091
1199
 
1092
1200
 
1093
1201
  @context_dependent_memoize
1094
- def get_hankel_01_kernel(context, out_dtype, x_dtype):
1202
+ def get_hankel_01_kernel(context: cl.Context,
1203
+ out_dtype: np.dtype[Any],
1204
+ x_dtype: np.dtype[Any]) -> cl.Kernel:
1095
1205
  if x_dtype != np.complex128:
1096
1206
  raise NotImplementedError("non-complex double dtype")
1097
1207
  if x_dtype != out_dtype:
@@ -1121,7 +1231,7 @@ def get_hankel_01_kernel(context, out_dtype, x_dtype):
1121
1231
 
1122
1232
 
1123
1233
  @context_dependent_memoize
1124
- def get_diff_kernel(context, dtype):
1234
+ def get_diff_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
1125
1235
  return get_elwise_kernel(context, [
1126
1236
  VectorArg(dtype, "result", with_offset=True),
1127
1237
  VectorArg(dtype, "array", with_offset=True),
@@ -1132,9 +1242,13 @@ def get_diff_kernel(context, dtype):
1132
1242
 
1133
1243
  @context_dependent_memoize
1134
1244
  def get_if_positive_kernel(
1135
- context, crit_dtype, then_else_dtype,
1136
- is_then_array, is_else_array,
1137
- is_then_scalar, is_else_scalar):
1245
+ context: cl.Context,
1246
+ crit_dtype: np.dtype[Any],
1247
+ then_else_dtype: np.dtype[Any],
1248
+ is_then_array: bool,
1249
+ is_else_array: bool,
1250
+ is_then_scalar: bool,
1251
+ is_else_scalar: bool) -> cl.Kernel:
1138
1252
  if is_then_array:
1139
1253
  then_ = "then_[0]" if is_then_scalar else "then_[i]"
1140
1254
  then_arg = VectorArg(then_else_dtype, "then_", with_offset=True)
@@ -1161,7 +1275,7 @@ def get_if_positive_kernel(
1161
1275
 
1162
1276
 
1163
1277
  @context_dependent_memoize
1164
- def get_logical_not_kernel(context, in_dtype):
1278
+ def get_logical_not_kernel(context: cl.Context, in_dtype: np.dtype[Any]) -> cl.Kernel:
1165
1279
  return get_elwise_kernel(context, [
1166
1280
  VectorArg(np.int8, "z", with_offset=True),
1167
1281
  VectorArg(in_dtype, "y", with_offset=True),