pyopencl 2025.2.5__cp310-cp310-macosx_11_0_arm64.whl → 2025.2.7__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyopencl might be problematic. Click here for more details.

pyopencl/elementwise.py CHANGED
@@ -27,17 +27,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
27
27
  OTHER DEALINGS IN THE SOFTWARE.
28
28
  """
29
29
 
30
-
30
+ import builtins
31
31
  import enum
32
- from typing import TYPE_CHECKING, Any
32
+ from typing import TYPE_CHECKING, Any, TextIO, cast
33
33
 
34
34
  import numpy as np
35
+ from typing_extensions import override
35
36
 
36
37
  from pytools import memoize_method
37
38
 
38
39
  import pyopencl as cl
39
40
  from pyopencl.tools import (
40
- DtypedArgument,
41
+ Argument,
41
42
  KernelTemplateBase,
42
43
  ScalarArg,
43
44
  VectorArg,
@@ -48,14 +49,19 @@ from pyopencl.tools import (
48
49
 
49
50
 
50
51
  if TYPE_CHECKING:
51
- from collections.abc import Sequence
52
+ from collections.abc import Callable, Sequence
53
+
54
+ from numpy.typing import DTypeLike
55
+
56
+ from pyopencl.array import Array
57
+ from pyopencl.typing import KernelArg, WaitList
52
58
 
53
59
 
54
60
  # {{{ elementwise kernel code generator
55
61
 
56
62
  def get_elwise_program(
57
63
  context: cl.Context,
58
- arguments: list[DtypedArgument],
64
+ arguments: Sequence[Argument],
59
65
  operation: str, *,
60
66
  name: str = "elwise_kernel",
61
67
  options: Any = None,
@@ -123,27 +129,27 @@ def get_elwise_program(
123
129
 
124
130
  def get_elwise_kernel_and_types(
125
131
  context: cl.Context,
126
- arguments: str | Sequence[DtypedArgument],
132
+ arguments: str | Sequence[Argument],
127
133
  operation: str, *,
128
134
  name: str = "elwise_kernel",
129
135
  options: Any = None,
130
136
  preamble: str = "",
131
137
  use_range: bool = False,
132
- **kwargs: Any) -> tuple[cl.Kernel, list[DtypedArgument]]:
138
+ **kwargs: Any) -> tuple[cl.Kernel, Sequence[Argument]]:
133
139
 
134
140
  from pyopencl.tools import get_arg_offset_adjuster_code, parse_arg_list
135
- parsed_args = parse_arg_list(arguments, with_offset=True)
141
+ parsed_args = list(parse_arg_list(arguments, with_offset=True))
136
142
 
137
143
  auto_preamble = kwargs.pop("auto_preamble", True)
138
144
 
139
- pragmas = []
140
- includes = []
145
+ pragmas: list[str] = []
146
+ includes: list[str] = []
141
147
  have_double_pragma = False
142
148
  have_complex_include = False
143
149
 
144
150
  if auto_preamble:
145
151
  for arg in parsed_args:
146
- if arg.dtype in [np.float64, np.complex128]:
152
+ if arg.dtype.type in [np.float64, np.complex128]:
147
153
  if not have_double_pragma:
148
154
  pragmas.append("""
149
155
  #if __OPENCL_C_VERSION__ < 120
@@ -186,14 +192,15 @@ def get_elwise_kernel_and_types(
186
192
 
187
193
  def get_elwise_kernel(
188
194
  context: cl.Context,
189
- arguments: str | list[DtypedArgument],
195
+ arguments: str | Sequence[Argument],
190
196
  operation: str, *,
191
197
  name: str = "elwise_kernel",
192
198
  options: Any = None, **kwargs: Any) -> cl.Kernel:
193
- """Return a L{pyopencl.Kernel} that performs the same scalar operation
194
- on one or several vectors.
195
199
  """
196
- func, arguments = get_elwise_kernel_and_types(
200
+ :returns: a :class:`pyopencl.Kernel` that performs the same scalar operation
201
+ on one or several vectors.
202
+ """
203
+ func, _arguments = get_elwise_kernel_and_types(
197
204
  context, arguments, operation,
198
205
  name=name, options=options, **kwargs)
199
206
 
@@ -233,19 +240,19 @@ class ElementwiseKernel:
233
240
  def __init__(
234
241
  self,
235
242
  context: cl.Context,
236
- arguments: str | Sequence[DtypedArgument],
243
+ arguments: str | Sequence[Argument],
237
244
  operation: str,
238
245
  name: str = "elwise_kernel",
239
246
  options: Any = None, **kwargs: Any) -> None:
240
- self.context = context
241
- self.arguments = arguments
242
- self.operation = operation
243
- self.name = name
244
- self.options = options
245
- self.kwargs = kwargs
247
+ self.context: cl.Context = context
248
+ self.arguments: str | Sequence[Argument] = arguments
249
+ self.operation: str = operation
250
+ self.name: str = name
251
+ self.options: Any = options
252
+ self.kwargs: dict[str, Any] = kwargs
246
253
 
247
254
  @memoize_method
248
- def get_kernel(self, use_range: bool):
255
+ def get_kernel(self, use_range: bool) -> tuple[cl.Kernel, Sequence[Argument]]:
249
256
  knl, arg_descrs = get_elwise_kernel_and_types(
250
257
  self.context, self.arguments, self.operation,
251
258
  name=self.name, options=self.options,
@@ -268,7 +275,14 @@ class ElementwiseKernel:
268
275
 
269
276
  return knl, arg_descrs
270
277
 
271
- def __call__(self, *args, **kwargs) -> cl.Event:
278
+ def __call__(self,
279
+ *args: KernelArg,
280
+ range: builtins.slice | None = None,
281
+ slice: builtins.slice | None = None,
282
+ capture_as: str | TextIO | None = None,
283
+ queue: cl.CommandQueue | None = None,
284
+ wait_for: WaitList = None,
285
+ **kwargs: Any) -> cl.Event:
272
286
  """
273
287
  Invoke the generated scalar kernel.
274
288
 
@@ -277,16 +291,10 @@ class ElementwiseKernel:
277
291
 
278
292
  |std-enqueue-blurb|
279
293
  """
280
- range_ = kwargs.pop("range", None)
281
- slice_ = kwargs.pop("slice", None)
282
- capture_as = kwargs.pop("capture_as", None)
283
- queue = kwargs.pop("queue", None)
284
- wait_for = kwargs.pop("wait_for", None)
285
-
286
294
  if kwargs:
287
295
  raise TypeError(f"unknown keyword arguments: '{', '.join(kwargs)}'")
288
296
 
289
- use_range = range_ is not None or slice_ is not None
297
+ use_range = range is not None or slice is not None
290
298
  kernel, arg_descrs = self.get_kernel(use_range)
291
299
 
292
300
  if wait_for is None:
@@ -297,14 +305,14 @@ class ElementwiseKernel:
297
305
 
298
306
  # {{{ assemble arg array
299
307
 
300
- repr_vec = None
301
- invocation_args = []
308
+ repr_vec: Array | None = None
309
+ invocation_args: list[KernelArg] = []
302
310
 
303
311
  # non-strict because length arg gets appended below
304
312
  for arg, arg_descr in zip(args, arg_descrs, strict=False):
305
313
  if isinstance(arg_descr, VectorArg):
306
314
  if repr_vec is None:
307
- repr_vec = arg
315
+ repr_vec = cast("Array", arg)
308
316
 
309
317
  invocation_args.append(arg)
310
318
  else:
@@ -317,33 +325,35 @@ class ElementwiseKernel:
317
325
  if queue is None:
318
326
  queue = repr_vec.queue
319
327
 
320
- if slice_ is not None:
321
- if range_ is not None:
328
+ if slice is not None:
329
+ if range is not None:
322
330
  raise TypeError(
323
331
  "may not specify both range and slice keyword arguments")
324
332
 
325
- range_ = slice(*slice_.indices(repr_vec.size))
333
+ range = builtins.slice(*slice.indices(repr_vec.size))
334
+
335
+ assert queue is not None
326
336
 
327
337
  max_wg_size = kernel.get_work_group_info(
328
338
  cl.kernel_work_group_info.WORK_GROUP_SIZE,
329
339
  queue.device)
330
340
 
331
- if range_ is not None:
332
- start = range_.start
341
+ if range is not None:
342
+ start = range.start
333
343
  if start is None:
334
344
  start = 0
335
345
  invocation_args.append(start)
336
- invocation_args.append(range_.stop)
337
- if range_.step is None:
346
+ invocation_args.append(range.stop)
347
+ if range.step is None:
338
348
  step = 1
339
349
  else:
340
- step = range_.step
350
+ step = range.step
341
351
 
342
352
  invocation_args.append(step)
343
353
 
344
354
  from pyopencl.array import _splay
345
355
  gs, ls = _splay(queue.device,
346
- abs(range_.stop - start)//step,
356
+ abs(range.stop - start)//step,
347
357
  max_wg_size)
348
358
  else:
349
359
  invocation_args.append(repr_vec.size)
@@ -365,20 +375,28 @@ class ElementwiseKernel:
365
375
  class ElementwiseTemplate(KernelTemplateBase):
366
376
  def __init__(
367
377
  self,
368
- arguments: str | list[DtypedArgument],
378
+ arguments: str | list[Argument],
369
379
  operation: str,
370
380
  name: str = "elwise",
371
381
  preamble: str = "",
372
382
  template_processor: str | None = None) -> None:
373
383
  super().__init__(template_processor=template_processor)
374
- self.arguments = arguments
375
- self.operation = operation
376
- self.name = name
377
- self.preamble = preamble
378
-
379
- def build_inner(self, context, type_aliases=(), var_values=(),
380
- more_preamble="", more_arguments=(), declare_types=(),
381
- options=None):
384
+ self.arguments: str | list[Argument] = arguments
385
+ self.operation: str = operation
386
+ self.name: str = name
387
+ self.preamble: str = preamble
388
+
389
+ @override
390
+ def build_inner(self,
391
+ context: cl.Context,
392
+ type_aliases: (
393
+ dict[str, np.dtype[Any]]
394
+ | Sequence[tuple[str, np.dtype[Any]]]) = (),
395
+ var_values: dict[str, str] | Sequence[tuple[str, str]] = (),
396
+ more_preamble: str = "",
397
+ more_arguments: str | Sequence[Any] = (),
398
+ declare_types: Sequence[DTypeLike] = (),
399
+ options: Any = None) -> Callable[..., cl.Event]:
382
400
  renderer = self.get_renderer(
383
401
  type_aliases, var_values, context, options)
384
402
 
@@ -387,9 +405,12 @@ class ElementwiseTemplate(KernelTemplateBase):
387
405
  type_decl_preamble = renderer.get_type_decl_preamble(
388
406
  context.devices[0], declare_types, arg_list)
389
407
 
390
- return ElementwiseKernel(context,
391
- arg_list, renderer(self.operation),
392
- name=renderer(self.name), options=options,
408
+ return ElementwiseKernel(
409
+ context,
410
+ arg_list,
411
+ renderer(self.operation),
412
+ name=renderer(self.name),
413
+ options=options,
393
414
  preamble=(
394
415
  type_decl_preamble
395
416
  + "\n"
@@ -426,7 +447,7 @@ def get_decl_and_access_for_kind(name: str, kind: ArgumentKind) -> tuple[str, st
426
447
  elif kind == ArgumentKind.DEV_SCALAR:
427
448
  return f"*{name}", f"{name}[0]"
428
449
  else:
429
- raise AssertionError()
450
+ raise AssertionError
430
451
 
431
452
  # }}}
432
453
 
@@ -434,7 +455,10 @@ def get_decl_and_access_for_kind(name: str, kind: ArgumentKind) -> tuple[str, st
434
455
  # {{{ kernels supporting array functionality
435
456
 
436
457
  @context_dependent_memoize
437
- def get_take_kernel(context, dtype, idx_dtype, vec_count=1):
458
+ def get_take_kernel(context: cl.Context,
459
+ dtype: np.dtype[Any],
460
+ idx_dtype: np.dtype[Any],
461
+ vec_count: int = 1) -> cl.Kernel:
438
462
  idx_tp = dtype_to_ctype(idx_dtype)
439
463
 
440
464
  args = ([VectorArg(dtype, f"dest{i}", with_offset=True)
@@ -449,13 +473,18 @@ def get_take_kernel(context, dtype, idx_dtype, vec_count=1):
449
473
  for i in range(vec_count))
450
474
  )
451
475
 
452
- return get_elwise_kernel(context, args, body,
476
+ return get_elwise_kernel(
477
+ context, args, body,
453
478
  preamble=dtype_to_c_struct(context.devices[0], dtype),
454
479
  name="take")
455
480
 
456
481
 
457
482
  @context_dependent_memoize
458
- def get_take_put_kernel(context, dtype, idx_dtype, with_offsets, vec_count=1):
483
+ def get_take_put_kernel(context: cl.Context,
484
+ dtype: np.dtype[Any],
485
+ idx_dtype: np.dtype[Any],
486
+ with_offsets: bool,
487
+ vec_count: int = 1) -> cl.Kernel:
459
488
  idx_tp = dtype_to_ctype(idx_dtype)
460
489
 
461
490
  args = [
@@ -473,23 +502,27 @@ def get_take_put_kernel(context, dtype, idx_dtype, with_offsets, vec_count=1):
473
502
  ]
474
503
 
475
504
  if with_offsets:
476
- def get_copy_insn(i):
505
+ def get_copy_insn(i: int) -> str:
477
506
  return f"dest{i}[dest_idx] = src{i}[src_idx + offset{i}];"
478
507
  else:
479
- def get_copy_insn(i):
508
+ def get_copy_insn(i: int) -> str:
480
509
  return f"dest{i}[dest_idx] = src{i}[src_idx];"
481
510
 
482
511
  body = ((f"{idx_tp} src_idx = gmem_src_idx[i];\n"
483
512
  f"{idx_tp} dest_idx = gmem_dest_idx[i];\n")
484
513
  + "\n".join(get_copy_insn(i) for i in range(vec_count)))
485
514
 
486
- return get_elwise_kernel(context, args, body,
515
+ return get_elwise_kernel(
516
+ context, args, body,
487
517
  preamble=dtype_to_c_struct(context.devices[0], dtype),
488
518
  name="take_put")
489
519
 
490
520
 
491
521
  @context_dependent_memoize
492
- def get_put_kernel(context, dtype, idx_dtype, vec_count=1):
522
+ def get_put_kernel(context: cl.Context,
523
+ dtype: np.dtype[Any],
524
+ idx_dtype: np.dtype[Any],
525
+ vec_count: int = 1) -> cl.Kernel:
493
526
  idx_tp = dtype_to_ctype(idx_dtype)
494
527
 
495
528
  args = [
@@ -521,7 +554,9 @@ def get_put_kernel(context, dtype, idx_dtype, vec_count=1):
521
554
 
522
555
 
523
556
  @context_dependent_memoize
524
- def get_copy_kernel(context, dtype_dest, dtype_src):
557
+ def get_copy_kernel(context: cl.Context,
558
+ dtype_dest: np.dtype[Any],
559
+ dtype_src: np.dtype[Any]) -> cl.Kernel:
525
560
  src = "src[i]"
526
561
  if dtype_dest.kind == "c" != dtype_src.kind:
527
562
  name = complex_dtype_to_name(dtype_dest)
@@ -535,17 +570,17 @@ def get_copy_kernel(context, dtype_dest, dtype_src):
535
570
  dtype_dest.kind == "V" or dtype_src.kind == "V"):
536
571
  raise TypeError("copying between non-identical struct types")
537
572
 
538
- return get_elwise_kernel(context,
539
- "{tp_dest} *dest, {tp_src} *src".format(
540
- tp_dest=dtype_to_ctype(dtype_dest),
541
- tp_src=dtype_to_ctype(dtype_src),
542
- ),
573
+ ctype_dst = dtype_to_ctype(dtype_dest)
574
+ ctype_src = dtype_to_ctype(dtype_src)
575
+ return get_elwise_kernel(
576
+ context,
577
+ f"{ctype_dst} *dest, {ctype_src} *src",
543
578
  f"dest[i] = {src}",
544
579
  preamble=dtype_to_c_struct(context.devices[0], dtype_dest),
545
580
  name="copy")
546
581
 
547
582
 
548
- def complex_dtype_to_name(dtype) -> str:
583
+ def complex_dtype_to_name(dtype: DTypeLike) -> str:
549
584
  if dtype == np.complex128:
550
585
  return "cdouble"
551
586
  elif dtype == np.complex64:
@@ -554,13 +589,17 @@ def complex_dtype_to_name(dtype) -> str:
554
589
  raise RuntimeError(f"invalid complex type: {dtype}")
555
590
 
556
591
 
557
- def real_dtype(dtype):
592
+ def real_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]:
558
593
  return dtype.type(0).real.dtype
559
594
 
560
595
 
561
596
  @context_dependent_memoize
562
- def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z,
563
- x_is_scalar=False, y_is_scalar=False):
597
+ def get_axpbyz_kernel(context: cl.Context,
598
+ dtype_x: np.dtype[Any],
599
+ dtype_y: np.dtype[Any],
600
+ dtype_z: np.dtype[Any],
601
+ x_is_scalar: bool = False,
602
+ y_is_scalar: bool = False) -> cl.Kernel:
564
603
  result_t = dtype_to_ctype(dtype_z)
565
604
 
566
605
  x_is_complex = dtype_x.kind == "c"
@@ -603,7 +642,11 @@ def get_axpbyz_kernel(context, dtype_x, dtype_y, dtype_z,
603
642
 
604
643
 
605
644
  @context_dependent_memoize
606
- def get_axpbz_kernel(context, dtype_a, dtype_x, dtype_b, dtype_z):
645
+ def get_axpbz_kernel(context: cl.Context,
646
+ dtype_a: np.dtype[Any],
647
+ dtype_x: np.dtype[Any],
648
+ dtype_b: np.dtype[Any],
649
+ dtype_z: np.dtype[Any]):
607
650
  a_is_complex = dtype_a.kind == "c"
608
651
  x_is_complex = dtype_x.kind == "c"
609
652
  b_is_complex = dtype_b.kind == "c"
@@ -664,8 +707,12 @@ def get_axpbz_kernel(context, dtype_a, dtype_x, dtype_b, dtype_z):
664
707
 
665
708
 
666
709
  @context_dependent_memoize
667
- def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z,
668
- x_is_scalar=False, y_is_scalar=False):
710
+ def get_multiply_kernel(context: cl.Context,
711
+ dtype_x: np.dtype[Any],
712
+ dtype_y: np.dtype[Any],
713
+ dtype_z: np.dtype[Any],
714
+ x_is_scalar: bool = False,
715
+ y_is_scalar: bool = False) -> cl.Kernel:
669
716
  x_is_complex = dtype_x.kind == "c"
670
717
  y_is_complex = dtype_y.kind == "c"
671
718
 
@@ -697,8 +744,12 @@ def get_multiply_kernel(context, dtype_x, dtype_y, dtype_z,
697
744
 
698
745
 
699
746
  @context_dependent_memoize
700
- def get_divide_kernel(context, dtype_x, dtype_y, dtype_z,
701
- x_is_scalar=False, y_is_scalar=False):
747
+ def get_divide_kernel(context: cl.Context,
748
+ dtype_x: np.dtype[Any],
749
+ dtype_y: np.dtype[Any],
750
+ dtype_z: np.dtype[Any],
751
+ x_is_scalar: bool = False,
752
+ y_is_scalar: bool = False) -> cl.Kernel:
702
753
  x_is_complex = dtype_x.kind == "c"
703
754
  y_is_complex = dtype_y.kind == "c"
704
755
  z_is_complex = dtype_z.kind == "c"
@@ -740,7 +791,10 @@ def get_divide_kernel(context, dtype_x, dtype_y, dtype_z,
740
791
 
741
792
 
742
793
  @context_dependent_memoize
743
- def get_rdivide_elwise_kernel(context, dtype_x, dtype_y, dtype_z):
794
+ def get_rdivide_elwise_kernel(context: cl.Context,
795
+ dtype_x: np.dtype[Any],
796
+ dtype_y: np.dtype[Any],
797
+ dtype_z: np.dtype[Any]) -> cl.Kernel:
744
798
  # implements y / x!
745
799
  x_is_complex = dtype_x.kind == "c"
746
800
  y_is_complex = dtype_y.kind == "c"
@@ -775,7 +829,7 @@ def get_rdivide_elwise_kernel(context, dtype_x, dtype_y, dtype_z):
775
829
 
776
830
 
777
831
  @context_dependent_memoize
778
- def get_fill_kernel(context, dtype):
832
+ def get_fill_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
779
833
  return get_elwise_kernel(context,
780
834
  "{tp} *z, {tp} a".format(tp=dtype_to_ctype(dtype)),
781
835
  "z[i] = a",
@@ -784,7 +838,7 @@ def get_fill_kernel(context, dtype):
784
838
 
785
839
 
786
840
  @context_dependent_memoize
787
- def get_reverse_kernel(context, dtype):
841
+ def get_reverse_kernel(context: cl.Context, dtype: np.dtype[Any]):
788
842
  return get_elwise_kernel(context,
789
843
  "{tp} *z, {tp} *y".format(tp=dtype_to_ctype(dtype)),
790
844
  "z[i] = y[n-1-i]",
@@ -792,7 +846,7 @@ def get_reverse_kernel(context, dtype):
792
846
 
793
847
 
794
848
  @context_dependent_memoize
795
- def get_arange_kernel(context, dtype):
849
+ def get_arange_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
796
850
  if dtype.kind == "c":
797
851
  expr = (
798
852
  "{root}_add(start, {root}_rmul(i, step))"
@@ -810,8 +864,12 @@ def get_arange_kernel(context, dtype):
810
864
 
811
865
 
812
866
  @context_dependent_memoize
813
- def get_pow_kernel(context, dtype_x, dtype_y, dtype_z,
814
- is_base_array, is_exp_array):
867
+ def get_pow_kernel(context: cl.Context,
868
+ dtype_x: np.dtype[Any],
869
+ dtype_y: np.dtype[Any],
870
+ dtype_z: np.dtype[Any],
871
+ is_base_array: bool,
872
+ is_exp_array: bool) -> cl.Kernel:
815
873
  if is_base_array:
816
874
  x = "x[i]"
817
875
  x_ctype = "{tp_x} *x"
@@ -861,7 +919,10 @@ def get_pow_kernel(context, dtype_x, dtype_y, dtype_z,
861
919
 
862
920
 
863
921
  @context_dependent_memoize
864
- def get_unop_kernel(context, operator, res_dtype, in_dtype):
922
+ def get_unop_kernel(context: cl.Context,
923
+ operator: str,
924
+ res_dtype: np.dtype[Any],
925
+ in_dtype: np.dtype[Any]) -> cl.Kernel:
865
926
  return get_elwise_kernel(context, [
866
927
  VectorArg(res_dtype, "z", with_offset=True),
867
928
  VectorArg(in_dtype, "y", with_offset=True),
@@ -871,7 +932,11 @@ def get_unop_kernel(context, operator, res_dtype, in_dtype):
871
932
 
872
933
 
873
934
  @context_dependent_memoize
874
- def get_array_scalar_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b):
935
+ def get_array_scalar_binop_kernel(context: cl.Context,
936
+ operator: str,
937
+ dtype_res: np.dtype[Any],
938
+ dtype_a: np.dtype[Any],
939
+ dtype_b: np.dtype[Any]) -> cl.Kernel:
875
940
  return get_elwise_kernel(context, [
876
941
  VectorArg(dtype_res, "out", with_offset=True),
877
942
  VectorArg(dtype_a, "a", with_offset=True),
@@ -882,8 +947,13 @@ def get_array_scalar_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b
882
947
 
883
948
 
884
949
  @context_dependent_memoize
885
- def get_array_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b,
886
- a_is_scalar=False, b_is_scalar=False):
950
+ def get_array_binop_kernel(context: cl.Context,
951
+ operator: str,
952
+ dtype_res: np.dtype[Any],
953
+ dtype_a: np.dtype[Any],
954
+ dtype_b: np.dtype[Any],
955
+ a_is_scalar: bool = False,
956
+ b_is_scalar: bool = False) -> cl.Kernel:
887
957
  a = "a[0]" if a_is_scalar else "a[i]"
888
958
  b = "b[0]" if b_is_scalar else "b[i]"
889
959
  return get_elwise_kernel(context, [
@@ -896,7 +966,9 @@ def get_array_binop_kernel(context, operator, dtype_res, dtype_a, dtype_b,
896
966
 
897
967
 
898
968
  @context_dependent_memoize
899
- def get_array_scalar_comparison_kernel(context, operator, dtype_a):
969
+ def get_array_scalar_comparison_kernel(context: cl.Context,
970
+ operator: str,
971
+ dtype_a: np.dtype[Any]) -> cl.Kernel:
900
972
  return get_elwise_kernel(context, [
901
973
  VectorArg(np.int8, "out", with_offset=True),
902
974
  VectorArg(dtype_a, "a", with_offset=True),
@@ -907,7 +979,10 @@ def get_array_scalar_comparison_kernel(context, operator, dtype_a):
907
979
 
908
980
 
909
981
  @context_dependent_memoize
910
- def get_array_comparison_kernel(context, operator, dtype_a, dtype_b):
982
+ def get_array_comparison_kernel(context: cl.Context,
983
+ operator: str,
984
+ dtype_a: np.dtype[Any],
985
+ dtype_b: np.dtype[Any]) -> cl.Kernel:
911
986
  return get_elwise_kernel(context, [
912
987
  VectorArg(np.int8, "out", with_offset=True),
913
988
  VectorArg(dtype_a, "a", with_offset=True),
@@ -918,7 +993,10 @@ def get_array_comparison_kernel(context, operator, dtype_a, dtype_b):
918
993
 
919
994
 
920
995
  @context_dependent_memoize
921
- def get_unary_func_kernel(context, func_name, in_dtype, out_dtype=None):
996
+ def get_unary_func_kernel(context: cl.Context,
997
+ func_name: str,
998
+ in_dtype: np.dtype[Any],
999
+ out_dtype: np.dtype[Any] | None = None) -> cl.Kernel:
922
1000
  if out_dtype is None:
923
1001
  out_dtype = in_dtype
924
1002
 
@@ -931,8 +1009,13 @@ def get_unary_func_kernel(context, func_name, in_dtype, out_dtype=None):
931
1009
 
932
1010
 
933
1011
  @context_dependent_memoize
934
- def get_binary_func_kernel(context, func_name, x_dtype, y_dtype, out_dtype,
935
- preamble="", name=None):
1012
+ def get_binary_func_kernel(context: cl.Context,
1013
+ func_name: str,
1014
+ x_dtype: np.dtype[Any],
1015
+ y_dtype: np.dtype[Any],
1016
+ out_dtype: np.dtype[Any],
1017
+ preamble: str = "",
1018
+ name: str | None = None) -> cl.Kernel:
936
1019
  if name is None:
937
1020
  name = func_name
938
1021
 
@@ -947,8 +1030,13 @@ def get_binary_func_kernel(context, func_name, x_dtype, y_dtype, out_dtype,
947
1030
 
948
1031
 
949
1032
  @context_dependent_memoize
950
- def get_float_binary_func_kernel(context, func_name, x_dtype, y_dtype,
951
- out_dtype, preamble="", name=None):
1033
+ def get_float_binary_func_kernel(context: cl.Context,
1034
+ func_name: str,
1035
+ x_dtype: np.dtype[Any],
1036
+ y_dtype: np.dtype[Any],
1037
+ out_dtype: np.dtype[Any],
1038
+ preamble: str = "",
1039
+ name: str | None = None) -> cl.Kernel:
952
1040
  if name is None:
953
1041
  name = func_name
954
1042
 
@@ -974,15 +1062,19 @@ def get_float_binary_func_kernel(context, func_name, x_dtype, y_dtype,
974
1062
 
975
1063
 
976
1064
  @context_dependent_memoize
977
- def get_fmod_kernel(context, out_dtype=np.float32, arg_dtype=np.float32,
978
- mod_dtype=np.float32):
1065
+ def get_fmod_kernel(context: cl.Context,
1066
+ out_dtype: np.dtype[Any],
1067
+ arg_dtype: np.dtype[Any],
1068
+ mod_dtype: np.dtype[Any]) -> cl.Kernel:
979
1069
  return get_float_binary_func_kernel(context, "fmod", arg_dtype,
980
1070
  mod_dtype, out_dtype)
981
1071
 
982
1072
 
983
1073
  @context_dependent_memoize
984
- def get_modf_kernel(context, int_dtype=np.float32,
985
- frac_dtype=np.float32, x_dtype=np.float32):
1074
+ def get_modf_kernel(context: cl.Context,
1075
+ int_dtype: np.dtype[Any],
1076
+ frac_dtype: np.dtype[Any],
1077
+ x_dtype: np.dtype[Any]):
986
1078
  return get_elwise_kernel(context, [
987
1079
  VectorArg(int_dtype, "intpart", with_offset=True),
988
1080
  VectorArg(frac_dtype, "fracpart", with_offset=True),
@@ -995,8 +1087,10 @@ def get_modf_kernel(context, int_dtype=np.float32,
995
1087
 
996
1088
 
997
1089
  @context_dependent_memoize
998
- def get_frexp_kernel(context, sign_dtype=np.float32, exp_dtype=np.float32,
999
- x_dtype=np.float32):
1090
+ def get_frexp_kernel(context: cl.Context,
1091
+ sign_dtype: np.dtype[Any],
1092
+ exp_dtype: np.dtype[Any],
1093
+ x_dtype: np.dtype[Any]) -> cl.Kernel:
1000
1094
  return get_elwise_kernel(context, [
1001
1095
  VectorArg(sign_dtype, "significand", with_offset=True),
1002
1096
  VectorArg(exp_dtype, "exponent", with_offset=True),
@@ -1011,8 +1105,10 @@ def get_frexp_kernel(context, sign_dtype=np.float32, exp_dtype=np.float32,
1011
1105
 
1012
1106
 
1013
1107
  @context_dependent_memoize
1014
- def get_ldexp_kernel(context, out_dtype=np.float32, sig_dtype=np.float32,
1015
- expt_dtype=np.float32):
1108
+ def get_ldexp_kernel(context: cl.Context,
1109
+ out_dtype: np.dtype[Any],
1110
+ sig_dtype: np.dtype[Any],
1111
+ expt_dtype: np.dtype[Any]) -> cl.Kernel:
1016
1112
  return get_binary_func_kernel(
1017
1113
  context, "_PYOCL_LDEXP", sig_dtype, expt_dtype, out_dtype,
1018
1114
  preamble="#define _PYOCL_LDEXP(x, y) ldexp(x, (int)(y))",
@@ -1020,8 +1116,13 @@ def get_ldexp_kernel(context, out_dtype=np.float32, sig_dtype=np.float32,
1020
1116
 
1021
1117
 
1022
1118
  @context_dependent_memoize
1023
- def get_minmaximum_kernel(context, minmax, dtype_z, dtype_x, dtype_y,
1024
- kind_x: ArgumentKind, kind_y: ArgumentKind):
1119
+ def get_minmaximum_kernel(context: cl.Context,
1120
+ minmax: str,
1121
+ dtype_z: np.dtype[Any],
1122
+ dtype_x: np.dtype[Any],
1123
+ dtype_y: np.dtype[Any],
1124
+ kind_x: ArgumentKind,
1125
+ kind_y: ArgumentKind) -> cl.Kernel:
1025
1126
  if dtype_z.kind == "f":
1026
1127
  reduce_func = f"f{minmax}_nanprop"
1027
1128
  elif dtype_z.kind in "iu":
@@ -1046,8 +1147,11 @@ def get_minmaximum_kernel(context, minmax, dtype_z, dtype_x, dtype_y,
1046
1147
 
1047
1148
 
1048
1149
  @context_dependent_memoize
1049
- def get_bessel_kernel(context, which_func, out_dtype=np.float64,
1050
- order_dtype=np.int32, x_dtype=np.float64):
1150
+ def get_bessel_kernel(context: cl.Context,
1151
+ which_func: str,
1152
+ out_dtype: np.dtype[Any],
1153
+ order_dtype: np.dtype[Any],
1154
+ x_dtype: np.dtype[Any]) -> cl.Kernel:
1051
1155
  if x_dtype.kind != "c":
1052
1156
  return get_elwise_kernel(context, [
1053
1157
  VectorArg(out_dtype, "z", with_offset=True),
@@ -1095,7 +1199,9 @@ def get_bessel_kernel(context, which_func, out_dtype=np.float64,
1095
1199
 
1096
1200
 
1097
1201
  @context_dependent_memoize
1098
- def get_hankel_01_kernel(context, out_dtype, x_dtype):
1202
+ def get_hankel_01_kernel(context: cl.Context,
1203
+ out_dtype: np.dtype[Any],
1204
+ x_dtype: np.dtype[Any]) -> cl.Kernel:
1099
1205
  if x_dtype != np.complex128:
1100
1206
  raise NotImplementedError("non-complex double dtype")
1101
1207
  if x_dtype != out_dtype:
@@ -1125,7 +1231,7 @@ def get_hankel_01_kernel(context, out_dtype, x_dtype):
1125
1231
 
1126
1232
 
1127
1233
  @context_dependent_memoize
1128
- def get_diff_kernel(context, dtype):
1234
+ def get_diff_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
1129
1235
  return get_elwise_kernel(context, [
1130
1236
  VectorArg(dtype, "result", with_offset=True),
1131
1237
  VectorArg(dtype, "array", with_offset=True),
@@ -1136,9 +1242,13 @@ def get_diff_kernel(context, dtype):
1136
1242
 
1137
1243
  @context_dependent_memoize
1138
1244
  def get_if_positive_kernel(
1139
- context, crit_dtype, then_else_dtype,
1140
- is_then_array, is_else_array,
1141
- is_then_scalar, is_else_scalar):
1245
+ context: cl.Context,
1246
+ crit_dtype: np.dtype[Any],
1247
+ then_else_dtype: np.dtype[Any],
1248
+ is_then_array: bool,
1249
+ is_else_array: bool,
1250
+ is_then_scalar: bool,
1251
+ is_else_scalar: bool) -> cl.Kernel:
1142
1252
  if is_then_array:
1143
1253
  then_ = "then_[0]" if is_then_scalar else "then_[i]"
1144
1254
  then_arg = VectorArg(then_else_dtype, "then_", with_offset=True)
@@ -1165,7 +1275,7 @@ def get_if_positive_kernel(
1165
1275
 
1166
1276
 
1167
1277
  @context_dependent_memoize
1168
- def get_logical_not_kernel(context, in_dtype):
1278
+ def get_logical_not_kernel(context: cl.Context, in_dtype: np.dtype[Any]) -> cl.Kernel:
1169
1279
  return get_elwise_kernel(context, [
1170
1280
  VectorArg(np.int8, "z", with_offset=True),
1171
1281
  VectorArg(in_dtype, "y", with_offset=True),