pyopencl 2025.2.7__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pyopencl might be problematic. Click here for more details.
- pyopencl/__init__.py +1995 -0
- pyopencl/_cl.cpython-314-darwin.so +0 -0
- pyopencl/_cl.pyi +2009 -0
- pyopencl/_cluda.py +57 -0
- pyopencl/_monkeypatch.py +1104 -0
- pyopencl/_mymako.py +17 -0
- pyopencl/algorithm.py +1454 -0
- pyopencl/array.py +3530 -0
- pyopencl/bitonic_sort.py +245 -0
- pyopencl/bitonic_sort_templates.py +597 -0
- pyopencl/cache.py +535 -0
- pyopencl/capture_call.py +200 -0
- pyopencl/characterize/__init__.py +461 -0
- pyopencl/characterize/performance.py +240 -0
- pyopencl/cl/pyopencl-airy.cl +324 -0
- pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
- pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
- pyopencl/cl/pyopencl-bessel-y.cl +435 -0
- pyopencl/cl/pyopencl-complex.h +303 -0
- pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
- pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
- pyopencl/cl/pyopencl-random123/array.h +325 -0
- pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
- pyopencl/cl/pyopencl-random123/philox.cl +486 -0
- pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
- pyopencl/clmath.py +281 -0
- pyopencl/clrandom.py +412 -0
- pyopencl/cltypes.py +217 -0
- pyopencl/compyte/.gitignore +21 -0
- pyopencl/compyte/__init__.py +0 -0
- pyopencl/compyte/array.py +211 -0
- pyopencl/compyte/dtypes.py +314 -0
- pyopencl/compyte/pyproject.toml +49 -0
- pyopencl/elementwise.py +1288 -0
- pyopencl/invoker.py +417 -0
- pyopencl/ipython_ext.py +70 -0
- pyopencl/py.typed +0 -0
- pyopencl/reduction.py +815 -0
- pyopencl/scan.py +1921 -0
- pyopencl/tools.py +1680 -0
- pyopencl/typing.py +61 -0
- pyopencl/version.py +11 -0
- pyopencl-2025.2.7.dist-info/METADATA +108 -0
- pyopencl-2025.2.7.dist-info/RECORD +46 -0
- pyopencl-2025.2.7.dist-info/WHEEL +6 -0
- pyopencl-2025.2.7.dist-info/licenses/LICENSE +282 -0
pyopencl/elementwise.py
ADDED
|
@@ -0,0 +1,1288 @@
|
|
|
1
|
+
"""Elementwise functionality."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
__copyright__ = "Copyright (C) 2009 Andreas Kloeckner"
|
|
6
|
+
|
|
7
|
+
__license__ = """
|
|
8
|
+
Permission is hereby granted, free of charge, to any person
|
|
9
|
+
obtaining a copy of this software and associated documentation
|
|
10
|
+
files (the "Software"), to deal in the Software without
|
|
11
|
+
restriction, including without limitation the rights to use,
|
|
12
|
+
copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
13
|
+
copies of the Software, and to permit persons to whom the
|
|
14
|
+
Software is furnished to do so, subject to the following
|
|
15
|
+
conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be
|
|
18
|
+
included in all copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
21
|
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
|
22
|
+
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
23
|
+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
|
24
|
+
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
|
25
|
+
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
26
|
+
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
|
27
|
+
OTHER DEALINGS IN THE SOFTWARE.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
import builtins
|
|
31
|
+
import enum
|
|
32
|
+
from typing import TYPE_CHECKING, Any, TextIO, cast
|
|
33
|
+
|
|
34
|
+
import numpy as np
|
|
35
|
+
from typing_extensions import override
|
|
36
|
+
|
|
37
|
+
from pytools import memoize_method
|
|
38
|
+
|
|
39
|
+
import pyopencl as cl
|
|
40
|
+
from pyopencl.tools import (
|
|
41
|
+
Argument,
|
|
42
|
+
KernelTemplateBase,
|
|
43
|
+
ScalarArg,
|
|
44
|
+
VectorArg,
|
|
45
|
+
context_dependent_memoize,
|
|
46
|
+
dtype_to_c_struct,
|
|
47
|
+
dtype_to_ctype,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
if TYPE_CHECKING:
|
|
52
|
+
from collections.abc import Callable, Sequence
|
|
53
|
+
|
|
54
|
+
from numpy.typing import DTypeLike
|
|
55
|
+
|
|
56
|
+
from pyopencl.array import Array
|
|
57
|
+
from pyopencl.typing import KernelArg, WaitList
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# {{{ elementwise kernel code generator
|
|
61
|
+
|
|
62
|
+
def get_elwise_program(
        context: cl.Context,
        arguments: Sequence[Argument],
        operation: str, *,
        name: str = "elwise_kernel",
        options: Any = None,
        preamble: str = "",
        loop_prep: str = "",
        after_loop: str = "",
        use_range: bool = False) -> cl.Program:
    """Generate the OpenCL source of an elementwise kernel and build it.

    :arg context: the context the program is created in.
    :arg arguments: the kernel arguments; each contributes the result of
        its ``declarator()`` method to the kernel signature.
    :arg operation: the C snippet executed once per index *i*.
    :arg name: the name given to the generated ``__kernel`` function.
    :arg options: handed to the program's ``build`` method.
    :arg preamble: C source emitted ahead of the kernel definition.
    :arg loop_prep: C source emitted inside the kernel, before the loop.
    :arg after_loop: C source emitted inside the kernel, after the loop.
    :arg use_range: if true, the loop runs from ``start`` towards ``stop``
        in increments of ``step`` (all three are expected to be kernel
        arguments); otherwise it runs over ``i < n``.
    :returns: the built program.
    """
    import re
    from warnings import warn

    # Loop used when the caller passed an explicit start/stop/step range.
    # The direction of the termination test depends on the sign of 'step'.
    ranged_loop = r"""//CL//
if (step < 0)
{
for (i = start + (work_group_start + lid)*step;
i > stop; i += gsize*step)
{
%(operation)s;
}
}
else
{
for (i = start + (work_group_start + lid)*step;
i < stop; i += gsize*step)
{
%(operation)s;
}
}
"""

    # Loop used when the whole index space [0, n) is processed.
    full_loop = """//CL//
for (i = work_group_start + lid; i < n; i += gsize)
{
%(operation)s;
}
"""

    loop_template = ranged_loop if use_range else full_loop

    # Every work item walks several indices in the loop above, so leaving
    # the kernel early would skip the remaining ones.
    if re.search(r"\breturn\b", operation) is not None:
        warn("Using a 'return' statement in an element-wise operation will "
                "likely lead to incorrect results. Use "
                "PYOPENCL_ELWISE_CONTINUE instead.",
                stacklevel=3)

    declarators = ", ".join(arg.declarator() for arg in arguments)
    loop_code = loop_template % {"operation": operation}

    source = f"""//CL//
{preamble}

#define PYOPENCL_ELWISE_CONTINUE continue

__kernel void {name}({declarators})
{{
int lid = get_local_id(0);
int gsize = get_global_size(0);
int work_group_start = get_local_size(0)*get_group_id(0);
long i;

{loop_prep};
{loop_code}
{after_loop};
}}
"""

    return cl.Program(context, source).build(options)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def get_elwise_kernel_and_types(
        context: cl.Context,
        arguments: str | Sequence[Argument],
        operation: str, *,
        name: str = "elwise_kernel",
        options: Any = None,
        preamble: str = "",
        use_range: bool = False,
        **kwargs: Any) -> tuple[cl.Kernel, Sequence[Argument]]:
    """Build an elementwise kernel and return it with its argument list.

    :arg arguments: a C-style argument string or a sequence of argument
        descriptors; parsed with offset support requested.
    :arg use_range: if true, ``start``, ``stop`` and ``step`` scalar
        arguments are appended; otherwise a single length argument ``n``.
    :arg kwargs: ``auto_preamble`` (default *True*) and ``loop_prep`` are
        consumed here; anything else is forwarded to
        :func:`get_elwise_program`.
    :returns: a tuple ``(kernel, args)`` where *args* is the parsed
        argument list including the appended scalar arguments.
    """
    from pyopencl.tools import get_arg_offset_adjuster_code, parse_arg_list
    kernel_args = list(parse_arg_list(arguments, with_offset=True))

    auto_preamble = kwargs.pop("auto_preamble", True)

    # Each snippet is emitted at most once, the pragma ahead of the include.
    fp64_pragma: str | None = None
    complex_include: str | None = None

    if auto_preamble:
        for arg in kernel_args:
            arg_dtype = arg.dtype

            if (fp64_pragma is None
                    and arg_dtype.type in [np.float64, np.complex128]):
                fp64_pragma = """
#if __OPENCL_C_VERSION__ < 120
#pragma OPENCL EXTENSION cl_khr_fp64: enable
#endif
#define PYOPENCL_DEFINE_CDOUBLE
"""

            if complex_include is None and arg_dtype.kind == "c":
                complex_include = "#include <pyopencl-complex.h>\n"

    header = [
        snippet
        for snippet in (fp64_pragma, complex_include)
        if snippet is not None]
    if header:
        preamble = "\n".join(header) + "\n" + preamble

    # Trailing scalar arguments that describe the index space.
    if use_range:
        kernel_args.extend(
            ScalarArg(np.intp, bound) for bound in ("start", "stop", "step"))
    else:
        kernel_args.append(ScalarArg(np.intp, "n"))

    user_loop_prep = kwargs.pop("loop_prep", "")
    loop_prep = get_arg_offset_adjuster_code(kernel_args) + user_loop_prep

    program = get_elwise_program(
        context, kernel_args, operation,
        name=name, options=options, preamble=preamble,
        use_range=use_range, loop_prep=loop_prep, **kwargs)

    from pyopencl.tools import get_arg_list_arg_types

    kernel = getattr(program, name)
    kernel.set_scalar_arg_dtypes(get_arg_list_arg_types(kernel_args))

    return kernel, kernel_args
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def get_elwise_kernel(
        context: cl.Context,
        arguments: str | Sequence[Argument],
        operation: str, *,
        name: str = "elwise_kernel",
        options: Any = None, **kwargs: Any) -> cl.Kernel:
    """
    Thin wrapper around :func:`get_elwise_kernel_and_types` that discards
    the parsed argument list.

    :returns: a :class:`pyopencl.Kernel` that performs the same scalar operation
        on one or several vectors.
    """
    kernel, _parsed_args = get_elwise_kernel_and_types(
        context, arguments, operation,
        name=name, options=options, **kwargs)
    return kernel
|
|
208
|
+
|
|
209
|
+
# }}}
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# {{{ ElementwiseKernel driver
|
|
213
|
+
|
|
214
|
+
class ElementwiseKernel:
    """
    A kernel that takes a number of scalar or vector *arguments* and performs
    an *operation* specified as a snippet of C on these arguments.

    :arg arguments: a string formatted as a C argument list.
    :arg operation: a snippet of C that carries out the desired 'map'
        operation. The current index is available as the variable *i*.
        *operation* may contain the statement ``PYOPENCL_ELWISE_CONTINUE``,
        which will terminate processing for the current element.
    :arg name: the function name as which the kernel is compiled
    :arg options: passed unmodified to :meth:`pyopencl.Program.build`.
    :arg preamble: a piece of C source code that gets inserted outside of the
        function context in the elementwise operation's kernel source code.

    .. warning :: Using a ``return`` statement in *operation* will lead to
        incorrect results, as some elements may never get processed. Use
        ``PYOPENCL_ELWISE_CONTINUE`` instead.

    .. versionchanged:: 2013.1

        Added ``PYOPENCL_ELWISE_CONTINUE``.

    .. automethod:: __call__
    """

    def __init__(
                self,
                context: cl.Context,
                arguments: str | Sequence[Argument],
                operation: str,
                name: str = "elwise_kernel",
                options: Any = None, **kwargs: Any) -> None:
        # Nothing is compiled here; the constructor arguments are stored and
        # the kernel is built by get_kernel().
        self.context: cl.Context = context
        self.arguments: str | Sequence[Argument] = arguments
        self.operation: str = operation
        self.name: str = name
        self.options: Any = options
        # Extra keyword arguments (e.g. 'preamble') forwarded verbatim to
        # get_elwise_kernel_and_types().
        self.kwargs: dict[str, Any] = kwargs

    @memoize_method
    def get_kernel(self, use_range: bool) -> tuple[cl.Kernel, Sequence[Argument]]:
        """Build the kernel variant for *use_range* (memoized per instance).

        :arg use_range: whether the kernel takes ``start``/``stop``/``step``
            arguments instead of a single length argument.
        :returns: a tuple ``(kernel, arg_descrs)``.
        :raises RuntimeError: if none of the arguments is a vector argument.
        """
        knl, arg_descrs = get_elwise_kernel_and_types(
            self.context, self.arguments, self.operation,
            name=self.name, options=self.options,
            use_range=use_range, **self.kwargs)

        vector_args = [arg for arg in arg_descrs if isinstance(arg, VectorArg)]

        for arg in vector_args:
            if not arg.with_offset:
                from warnings import warn
                warn(
                    f"ElementwiseKernel '{self.name}' used with VectorArgs "
                    "that do not have offset support enabled. This usage is "
                    "deprecated. Just pass with_offset=True to VectorArg, "
                    "everything should sort itself out automatically.",
                    DeprecationWarning, stacklevel=2)

        if not vector_args:
            raise RuntimeError(
                "ElementwiseKernel can only be used with functions that have "
                "at least one vector argument")

        return knl, arg_descrs

    def __call__(self,
                *args: KernelArg,
                range: builtins.slice | None = None,
                slice: builtins.slice | None = None,
                capture_as: str | TextIO | None = None,
                queue: cl.CommandQueue | None = None,
                wait_for: WaitList = None,
                **kwargs: Any) -> cl.Event:
        """
        Invoke the generated scalar kernel.

        The arguments may either be scalars or :class:`pyopencl.array.Array`
        instances.

        :arg range: a :class:`slice` whose ``start``, ``stop`` and ``step``
            are passed to the kernel as given (a missing ``start`` becomes 0,
            a missing ``step`` becomes 1).
        :arg slice: a :class:`slice` that is first resolved against the size
            of the first vector argument via :meth:`slice.indices`. Mutually
            exclusive with *range*.
        :arg capture_as: if given, the call is additionally recorded through
            the kernel's ``capture_call`` before it is enqueued.
        :arg queue: the queue to enqueue on; defaults to the queue of the
            first vector argument.
        :raises TypeError: on unknown keyword arguments, or if both *range*
            and *slice* are given.

        |std-enqueue-blurb|
        """
        if kwargs:
            raise TypeError(f"unknown keyword arguments: '{', '.join(kwargs)}'")

        use_range = range is not None or slice is not None
        kernel, arg_descrs = self.get_kernel(use_range)

        # Work on a private list so the caller's sequence is never aliased.
        wait_for = [] if wait_for is None else list(wait_for)

        # {{{ assemble arg array

        repr_vec: Array | None = None
        invocation_args: list[KernelArg] = []

        # non-strict because length arg gets appended below
        for arg, arg_descr in zip(args, arg_descrs, strict=False):
            # The first vector argument supplies the default queue and size.
            if repr_vec is None and isinstance(arg_descr, VectorArg):
                repr_vec = cast("Array", arg)

            invocation_args.append(arg)

        assert repr_vec is not None

        # }}}

        if queue is None:
            queue = repr_vec.queue

        if slice is not None:
            if range is not None:
                raise TypeError(
                    "may not specify both range and slice keyword arguments")

            range = builtins.slice(*slice.indices(repr_vec.size))

        assert queue is not None

        max_wg_size = kernel.get_work_group_info(
            cl.kernel_work_group_info.WORK_GROUP_SIZE,
            queue.device)

        if range is not None:
            start = 0 if range.start is None else range.start
            step = 1 if range.step is None else range.step

            invocation_args.append(start)
            invocation_args.append(range.stop)
            invocation_args.append(step)

            # Number of indices visited, rounded up. The magnitude of the
            # step is used: dividing by a negative step produced a negative
            # element count, which only ever sized the launch for a tiny
            # problem. The loop in the kernel is strided by the global size,
            # so this affects the launch configuration, not the result.
            n_items = -(-abs(range.stop - start) // abs(step))

            from pyopencl.array import _splay
            gs, ls = _splay(queue.device, n_items, max_wg_size)
        else:
            invocation_args.append(repr_vec.size)
            gs, ls = repr_vec._get_sizes(queue, max_wg_size)

        if capture_as is not None:
            kernel.set_args(*invocation_args)
            kernel.capture_call(
                capture_as, queue,
                gs, ls, *invocation_args, wait_for=wait_for)

        return kernel(queue, gs, ls, *invocation_args, wait_for=wait_for)
|
|
369
|
+
|
|
370
|
+
# }}}
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
# {{{ template
|
|
374
|
+
|
|
375
|
+
class ElementwiseTemplate(KernelTemplateBase):
    """A template from which :class:`ElementwiseKernel` instances are built.

    The argument list, operation, kernel name and preamble are stored as
    given and passed through the template renderer in :meth:`build_inner`.
    """

    def __init__(
                self,
                arguments: str | list[Argument],
                operation: str,
                name: str = "elwise",
                preamble: str = "",
                template_processor: str | None = None) -> None:
        super().__init__(template_processor=template_processor)
        self.arguments: str | list[Argument] = arguments
        self.operation: str = operation
        self.name: str = name
        self.preamble: str = preamble

    @override
    def build_inner(self,
                context: cl.Context,
                type_aliases: (
                    dict[str, np.dtype[Any]]
                    | Sequence[tuple[str, np.dtype[Any]]]) = (),
                var_values: dict[str, str] | Sequence[tuple[str, str]] = (),
                more_preamble: str = "",
                more_arguments: str | Sequence[Any] = (),
                declare_types: Sequence[DTypeLike] = (),
                options: Any = None) -> Callable[..., cl.Event]:
        """Render the template and wrap the result in an
        :class:`ElementwiseKernel`.

        :arg more_preamble: appended (after a newline) to the template's
            own preamble before rendering.
        :arg more_arguments: extra arguments rendered together with the
            template's argument list.
        :arg declare_types: passed to the renderer's
            ``get_type_decl_preamble`` along with the first device of
            *context*.
        :returns: an :class:`ElementwiseKernel` created with
            ``auto_preamble=False``.
        """
        renderer = self.get_renderer(
            type_aliases, var_values, context, options)

        arg_list = renderer.render_argument_list(
            self.arguments, more_arguments, with_offset=True)
        type_decl_preamble = renderer.get_type_decl_preamble(
            context.devices[0], declare_types, arg_list)

        rendered_operation = renderer(self.operation)
        rendered_name = renderer(self.name)
        rendered_preamble = renderer(self.preamble + "\n" + more_preamble)

        # The type declarations come first so the rendered preamble may
        # refer to them.
        full_preamble = type_decl_preamble + "\n" + rendered_preamble

        return ElementwiseKernel(
            context,
            arg_list,
            rendered_operation,
            name=rendered_name,
            options=options,
            preamble=full_preamble,
            auto_preamble=False)
|
|
419
|
+
|
|
420
|
+
# }}}
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
# {{{ argument kinds
|
|
424
|
+
|
|
425
|
+
@enum.unique
class ArgumentKind(enum.Enum):
    """The ways an operand can be handed to a generated kernel."""

    # accessed per element as name[i]
    ARRAY = enum.auto()
    # passed by pointer, accessed as name[0]
    DEV_SCALAR = enum.auto()
    # passed by value
    SCALAR = enum.auto()
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def get_argument_kind(v: Any) -> ArgumentKind:
    """Classify *v* by how it would be passed to a kernel.

    :returns: :attr:`ArgumentKind.SCALAR` for anything that is not a
        :class:`pyopencl.array.Array`; for arrays,
        :attr:`ArgumentKind.DEV_SCALAR` if the shape is ``()`` and
        :attr:`ArgumentKind.ARRAY` otherwise.
    """
    from pyopencl.array import Array

    if not isinstance(v, Array):
        return ArgumentKind.SCALAR

    is_zero_dim = v.shape == ()
    return ArgumentKind.DEV_SCALAR if is_zero_dim else ArgumentKind.ARRAY
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def get_decl_and_access_for_kind(name: str, kind: ArgumentKind) -> tuple[str, str]:
|
|
443
|
+
if kind == ArgumentKind.ARRAY:
|
|
444
|
+
return f"*{name}", f"{name}[i]"
|
|
445
|
+
elif kind == ArgumentKind.SCALAR:
|
|
446
|
+
return f"{name}", name
|
|
447
|
+
elif kind == ArgumentKind.DEV_SCALAR:
|
|
448
|
+
return f"*{name}", f"{name}[0]"
|
|
449
|
+
else:
|
|
450
|
+
raise AssertionError
|
|
451
|
+
|
|
452
|
+
# }}}
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
# {{{ kernels supporting array functionality
|
|
456
|
+
|
|
457
|
+
@context_dependent_memoize
def get_take_kernel(context: cl.Context,
            dtype: np.dtype[Any],
            idx_dtype: np.dtype[Any],
            vec_count: int = 1) -> cl.Kernel:
    """Build a kernel computing ``dest<k>[i] = src<k>[idx[i]]``.

    :arg dtype: element type of all destination and source vectors.
    :arg idx_dtype: element type of the index vector ``idx``.
    :arg vec_count: number of destination/source vector pairs that share
        the one index vector.
    """
    idx_tp = dtype_to_ctype(idx_dtype)
    vec_numbers = range(vec_count)

    # Argument order: all destinations, all sources, then the index vector.
    args = [
        *(VectorArg(dtype, f"dest{k}", with_offset=True) for k in vec_numbers),
        *(VectorArg(dtype, f"src{k}", with_offset=True) for k in vec_numbers),
        VectorArg(idx_dtype, "idx", with_offset=True),
    ]

    copy_stmts = [f"dest{k}[i] = src{k}[src_idx];" for k in vec_numbers]
    body = f"{idx_tp} src_idx = idx[i];\n" + "\n".join(copy_stmts)

    return get_elwise_kernel(
        context, args, body,
        preamble=dtype_to_c_struct(context.devices[0], dtype),
        name="take")
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@context_dependent_memoize
def get_take_put_kernel(context: cl.Context,
            dtype: np.dtype[Any],
            idx_dtype: np.dtype[Any],
            with_offsets: bool,
            vec_count: int = 1) -> cl.Kernel:
    """Build a kernel computing
    ``dest<k>[gmem_dest_idx[i]] = src<k>[gmem_src_idx[i]]``.

    :arg with_offsets: if true, one scalar ``offset<k>`` per vector pair is
        appended to the argument list and added to the source index.
    :arg vec_count: number of destination/source vector pairs.
    """
    idx_tp = dtype_to_ctype(idx_dtype)
    vec_numbers = range(vec_count)

    # NOTE: the destination vectors are declared without with_offset=True,
    # unlike every other vector argument of this kernel.
    args = [VectorArg(dtype, f"dest{k}") for k in vec_numbers]
    args.append(VectorArg(idx_dtype, "gmem_dest_idx", with_offset=True))
    args.append(VectorArg(idx_dtype, "gmem_src_idx", with_offset=True))
    args.extend(
        VectorArg(dtype, f"src{k}", with_offset=True) for k in vec_numbers)
    if with_offsets:
        args.extend(ScalarArg(idx_dtype, f"offset{k}") for k in vec_numbers)

    copy_stmts = []
    for k in vec_numbers:
        shift = f" + offset{k}" if with_offsets else ""
        copy_stmts.append(f"dest{k}[dest_idx] = src{k}[src_idx{shift}];")

    body = (
        f"{idx_tp} src_idx = gmem_src_idx[i];\n"
        f"{idx_tp} dest_idx = gmem_dest_idx[i];\n"
        + "\n".join(copy_stmts))

    return get_elwise_kernel(
        context, args, body,
        preamble=dtype_to_c_struct(context.devices[0], dtype),
        name="take_put")
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
@context_dependent_memoize
def get_put_kernel(context: cl.Context,
            dtype: np.dtype[Any],
            idx_dtype: np.dtype[Any],
            vec_count: int = 1) -> cl.Kernel:
    """Build a kernel that scatters into ``dest<k>[gmem_dest_idx[i]]``.

    For vector pair *k* the stored value is ``src<k>[0]`` when
    ``use_fill[k]`` is nonzero and ``src<k>[i % val_ary_lengths[k]]``
    otherwise.

    :arg vec_count: number of destination/source vector pairs.
    """
    idx_tp = dtype_to_ctype(idx_dtype)
    vec_numbers = range(vec_count)

    args = [
        *(VectorArg(dtype, f"dest{k}", with_offset=True) for k in vec_numbers),
        VectorArg(idx_dtype, "gmem_dest_idx", with_offset=True),
        *(VectorArg(dtype, f"src{k}", with_offset=True) for k in vec_numbers),
        VectorArg(np.uint8, "use_fill", with_offset=True),
        VectorArg(np.int64, "val_ary_lengths", with_offset=True),
    ]

    store_stmts = []
    for k in vec_numbers:
        fill_value = f"src{k}[0]"
        wrapped_value = f"src{k}[i % val_ary_lengths[{k}]]"
        store_stmts.append(
            f"dest{k}[dest_idx] = (use_fill[{k}] ? {fill_value} : "
            f"{wrapped_value});")

    body = f"{idx_tp} dest_idx = gmem_dest_idx[i];\n" + "\n".join(store_stmts)

    return get_elwise_kernel(context, args, body,
            preamble=dtype_to_c_struct(context.devices[0], dtype),
            name="put")
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
@context_dependent_memoize
def get_copy_kernel(context: cl.Context,
            dtype_dest: np.dtype[Any],
            dtype_src: np.dtype[Any]) -> cl.Kernel:
    """Build a kernel computing ``dest[i] = src[i]`` with type conversion.

    For a complex destination, a non-complex source is wrapped in
    ``<name>_fromreal`` and any source of a different dtype is wrapped in
    ``<name>_cast``, where *name* is the C name of the destination type.

    :raises TypeError: if the dtypes differ and either one is a struct
        (kind ``"V"``) type.
    """
    dest_is_complex = dtype_dest.kind == "c"
    value = "src[i]"

    if dest_is_complex and dtype_src.kind != "c":
        value = f"{complex_dtype_to_name(dtype_dest)}_fromreal({value})"

    if dest_is_complex and dtype_src != dtype_dest:
        value = f"{complex_dtype_to_name(dtype_dest)}_cast({value})"

    if dtype_dest != dtype_src and (
            dtype_dest.kind == "V" or dtype_src.kind == "V"):
        raise TypeError("copying between non-identical struct types")

    ctype_dst = dtype_to_ctype(dtype_dest)
    ctype_src = dtype_to_ctype(dtype_src)

    arguments = f"{ctype_dst} *dest, {ctype_src} *src"
    operation = f"dest[i] = {value}"

    return get_elwise_kernel(
        context,
        arguments,
        operation,
        preamble=dtype_to_c_struct(context.devices[0], dtype_dest),
        name="copy")
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def complex_dtype_to_name(dtype: DTypeLike) -> str:
    """Return the C name prefix used for the complex type *dtype*.

    :returns: ``"cdouble"`` for :class:`numpy.complex128` and ``"cfloat"``
        for :class:`numpy.complex64`.
    :raises RuntimeError: if *dtype* compares equal to neither.
    """
    known_names = (
        (np.complex128, "cdouble"),
        (np.complex64, "cfloat"),
    )

    for scalar_type, c_name in known_names:
        if dtype == scalar_type:
            return c_name

    raise RuntimeError(f"invalid complex type: {dtype}")
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def real_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]:
    """Return the dtype of the real part of a scalar of type *dtype*."""
    zero = dtype.type(0)
    return zero.real.dtype
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
@context_dependent_memoize
def get_axpbyz_kernel(context: cl.Context,
            dtype_x: np.dtype[Any],
            dtype_y: np.dtype[Any],
            dtype_z: np.dtype[Any],
            x_is_scalar: bool = False,
            y_is_scalar: bool = False) -> cl.Kernel:
    """Build a kernel computing ``z[i] = a*x[i] + b*y[i]``.

    The coefficients *a* and *b* are declared with the type of *z*.

    :arg x_is_scalar: if true, ``x[0]`` is read instead of ``x[i]``.
    :arg y_is_scalar: if true, ``y[0]`` is read instead of ``y[i]``.
    """
    result_t = dtype_to_ctype(dtype_z)

    x_is_complex = dtype_x.kind == "c"
    y_is_complex = dtype_y.kind == "c"

    x = "x[0]" if x_is_scalar else "x[i]"
    y = "y[0]" if y_is_scalar else "y[i]"

    if dtype_z.kind == "c":
        # a and b will always be complex here.
        z_ct = complex_dtype_to_name(dtype_z)

        def scaled(coeff: str, operand: str, operand_is_complex: bool) -> str:
            # complex*complex needs the operand cast to the result type,
            # complex*real uses the mixed multiply.
            if operand_is_complex:
                return f"{z_ct}_mul({coeff}, {z_ct}_cast({operand}))"
            return f"{z_ct}_mulr({coeff}, {operand})"

        ax = scaled("a", x, x_is_complex)
        by = scaled("b", y, y_is_complex)
        result = f"{z_ct}_add({ax}, {by})"
    else:
        # real-only
        result = f"a*(({result_t}) {x}) + b*(({result_t}) {y})"

    tp_x = dtype_to_ctype(dtype_x)
    tp_y = dtype_to_ctype(dtype_y)
    tp_z = dtype_to_ctype(dtype_z)

    return get_elwise_kernel(context,
            f"{tp_z} *z, {tp_z} a, {tp_x} *x, {tp_z} b, {tp_y} *y",
            f"z[i] = {result}",
            name="axpbyz")
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
@context_dependent_memoize
def get_axpbz_kernel(context: cl.Context,
            dtype_a: np.dtype[Any],
            dtype_x: np.dtype[Any],
            dtype_b: np.dtype[Any],
            dtype_z: np.dtype[Any]) -> cl.Kernel:
    """Build a kernel computing ``z[i] = a*x[i] + b`` for scalars *a*, *b*.

    When any operand is complex, the product and sum are spelled with the
    ``<name>_mul``/``_rmul``/``_mulr``/``_add`` helpers of the result type,
    with ``<name>_cast``/``_fromreal`` conversions inserted as needed.

    :arg dtype_a: type of the scalar *a*.
    :arg dtype_x: element type of the vector *x*.
    :arg dtype_b: type of the scalar *b*.
    :arg dtype_z: element type of the result vector *z*.
    """
    a_is_complex = dtype_a.kind == "c"
    x_is_complex = dtype_x.kind == "c"
    b_is_complex = dtype_b.kind == "c"
    z_is_complex = dtype_z.kind == "c"

    def z_call(suffix: str, *operands: str) -> str:
        # The name lookup stays lazy: it raises for a non-complex dtype_z,
        # which must not happen on the all-real path.
        z_name = complex_dtype_to_name(dtype_z)
        return f"{z_name}_{suffix}({', '.join(operands)})"

    ax = "a*x[i]"
    if a_is_complex or x_is_complex:
        a = "a"
        x = "x[i]"

        # Only complex operands are cast to the result type; a real
        # operand is consumed directly by the mixed multiply.
        if x_is_complex and dtype_x != dtype_z:
            x = z_call("cast", x)
        if a_is_complex and dtype_a != dtype_z:
            a = z_call("cast", a)

        if a_is_complex and x_is_complex:
            product = "mul"
        elif x_is_complex:
            product = "rmul"
        else:
            product = "mulr"

        ax = z_call(product, a, x)

    b = "b"
    if z_is_complex and not b_is_complex:
        b = z_call("fromreal", b)

    if z_is_complex and not (a_is_complex or x_is_complex):
        ax = z_call("fromreal", ax)

    if z_is_complex:
        ax = z_call("cast", ax)
        b = z_call("cast", b)

    if a_is_complex or x_is_complex or b_is_complex:
        expr = z_call("add", ax, b)
    else:
        expr = f"{ax} + {b}"

    tp_a = dtype_to_ctype(dtype_a)
    tp_x = dtype_to_ctype(dtype_x)
    tp_b = dtype_to_ctype(dtype_b)
    tp_z = dtype_to_ctype(dtype_z)

    return get_elwise_kernel(context,
            f"{tp_z} *z, {tp_a} a, {tp_x} *x,{tp_b} b",
            f"z[i] = {expr}",
            name="axpb")
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
@context_dependent_memoize
def get_multiply_kernel(context: cl.Context,
            dtype_x: np.dtype[Any],
            dtype_y: np.dtype[Any],
            dtype_z: np.dtype[Any],
            x_is_scalar: bool = False,
            y_is_scalar: bool = False) -> cl.Kernel:
    """Build a kernel computing ``z[i] = x[i] * y[i]``.

    :arg x_is_scalar: if true, ``x[0]`` is read instead of ``x[i]``.
    :arg y_is_scalar: if true, ``y[0]`` is read instead of ``y[i]``.
    """
    x_is_complex = dtype_x.kind == "c"
    y_is_complex = dtype_y.kind == "c"

    def z_call(suffix: str, *operands: str) -> str:
        # Looked up lazily: the name is only defined for complex dtype_z.
        z_name = complex_dtype_to_name(dtype_z)
        return f"{z_name}_{suffix}({', '.join(operands)})"

    x = "x[0]" if x_is_scalar else "x[i]"
    y = "y[0]" if y_is_scalar else "y[i]"

    if x_is_complex and dtype_x != dtype_z:
        x = z_call("cast", x)
    if y_is_complex and dtype_y != dtype_z:
        y = z_call("cast", y)

    # Which complex multiply applies, keyed by (x complex?, y complex?).
    product_by_kinds = {
        (True, True): "mul",
        (True, False): "mulr",
        (False, True): "rmul",
    }
    product = product_by_kinds.get((x_is_complex, y_is_complex))

    xy = f"{x} * {y}" if product is None else z_call(product, x, y)

    tp_x = dtype_to_ctype(dtype_x)
    tp_y = dtype_to_ctype(dtype_y)
    tp_z = dtype_to_ctype(dtype_z)

    return get_elwise_kernel(context,
            f"{tp_z} *z, {tp_x} *x, {tp_y} *y",
            f"z[i] = {xy}",
            name="multiply")
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
@context_dependent_memoize
def get_divide_kernel(context: cl.Context,
            dtype_x: np.dtype[Any],
            dtype_y: np.dtype[Any],
            dtype_z: np.dtype[Any],
            x_is_scalar: bool = False,
            y_is_scalar: bool = False) -> cl.Kernel:
    """Build a kernel computing ``z[i] = x[i] / y[i]``.

    :arg x_is_scalar: if true, ``x[0]`` is read instead of ``x[i]``.
    :arg y_is_scalar: if true, ``y[0]`` is read instead of ``y[i]``.
    """
    x_is_complex = dtype_x.kind == "c"
    y_is_complex = dtype_y.kind == "c"
    z_is_complex = dtype_z.kind == "c"

    def z_call(suffix: str, *operands: str) -> str:
        # Looked up lazily: the name is only defined for complex dtype_z.
        z_name = complex_dtype_to_name(dtype_z)
        return f"{z_name}_{suffix}({', '.join(operands)})"

    x = "x[0]" if x_is_scalar else "x[i]"
    y = "y[0]" if y_is_scalar else "y[i]"

    if z_is_complex and dtype_x != dtype_y:
        # Mixed operand types with a complex result: only complex operands
        # are converted, using the complex cast helper.
        if x_is_complex and dtype_x != dtype_z:
            x = z_call("cast", x)
        if y_is_complex and dtype_y != dtype_z:
            y = z_call("cast", y)
    else:
        # Otherwise a plain C cast to the result type is used.
        if dtype_x != dtype_z:
            x = f"({dtype_to_ctype(dtype_z)}) ({x})"
        if dtype_y != dtype_z:
            y = f"({dtype_to_ctype(dtype_z)}) ({y})"

    # Which complex division applies, keyed by (x complex?, y complex?).
    quotient_by_kinds = {
        (True, True): "divide",
        (False, True): "rdivide",
        (True, False): "divider",
    }
    quotient = quotient_by_kinds.get((x_is_complex, y_is_complex))

    xoy = f"{x} / {y}" if quotient is None else z_call(quotient, x, y)

    if z_is_complex:
        xoy = z_call("cast", xoy)

    tp_x = dtype_to_ctype(dtype_x)
    tp_y = dtype_to_ctype(dtype_y)
    tp_z = dtype_to_ctype(dtype_z)

    return get_elwise_kernel(context,
            f"{tp_z} *z, {tp_x} *x, {tp_y} *y",
            f"z[i] = {xoy}",
            name="divide")
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
@context_dependent_memoize
def get_rdivide_elwise_kernel(context: cl.Context,
            dtype_x: np.dtype[Any],
            dtype_y: np.dtype[Any],
            dtype_z: np.dtype[Any]) -> cl.Kernel:
    """Build a kernel computing ``z[i] = y / x[i]`` for a scalar *y*.

    Note the operand order: the scalar is the numerator.
    """
    # implements y / x!
    x_is_complex = dtype_x.kind == "c"
    y_is_complex = dtype_y.kind == "c"
    z_is_complex = dtype_z.kind == "c"

    def z_call(suffix: str, *operands: str) -> str:
        # Looked up lazily: the name is only defined for complex dtype_z.
        z_name = complex_dtype_to_name(dtype_z)
        return f"{z_name}_{suffix}({', '.join(operands)})"

    x = "x[i]"
    y = "y"

    if z_is_complex and dtype_x != dtype_y:
        if x_is_complex and dtype_x != dtype_z:
            x = z_call("cast", x)
        if y_is_complex and dtype_y != dtype_z:
            y = z_call("cast", y)

    # Which complex division applies, keyed by
    # (numerator y complex?, denominator x complex?).
    quotient_by_kinds = {
        (True, True): "divide",
        (False, True): "rdivide",
        (True, False): "divider",
    }
    quotient = quotient_by_kinds.get((y_is_complex, x_is_complex))

    yox = f"{y} / {x}" if quotient is None else z_call(quotient, y, x)

    tp_x = dtype_to_ctype(dtype_x)
    tp_y = dtype_to_ctype(dtype_y)
    tp_z = dtype_to_ctype(dtype_z)

    return get_elwise_kernel(context,
            f"{tp_z} *z, {tp_x} *x, {tp_y} y",
            f"z[i] = {yox}",
            name="divide_r")
|
|
829
|
+
|
|
830
|
+
|
|
831
|
+
@context_dependent_memoize
def get_fill_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel that assigns the scalar ``a`` to ``z[i]``.

    The preamble is the C struct declaration produced by
    :func:`dtype_to_c_struct` for the first device of *context*.
    """
    ctype = dtype_to_ctype(dtype)
    struct_decl = dtype_to_c_struct(context.devices[0], dtype)

    return get_elwise_kernel(
        context,
        f"{ctype} *z, {ctype} a",
        "z[i] = a",
        preamble=struct_decl,
        name="fill")
|
|
838
|
+
|
|
839
|
+
|
|
840
|
+
@context_dependent_memoize
def get_reverse_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel that stores ``y[n-1-i]`` into ``z[i]``.

    :arg context: the context the kernel is built for.
    :arg dtype: element type shared by the output *z* and the input *y*.
    :returns: the kernel produced by :func:`get_elwise_kernel`.
    """
    ctype = dtype_to_ctype(dtype)

    return get_elwise_kernel(
        context,
        f"{ctype} *z, {ctype} *y",
        # ``n`` is not declared here; it is referenced as-is in the operation.
        "z[i] = y[n-1-i]",
        name="reverse")
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
@context_dependent_memoize
def get_arange_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel that stores ``start + i * step`` in ``z[i]``.

    For complex *dtype* the sum and product are emitted as calls to the
    ``*_add`` and ``*_rmul`` helpers; otherwise ``i`` is cast to the C type
    of *dtype* before being multiplied by ``step``.
    """
    if dtype.kind == "c":
        root = complex_dtype_to_name(dtype)
        value = f"{root}_add(start, {root}_rmul(i, step))"
    else:
        value = "start + (({}) i) * step".format(dtype_to_ctype(dtype))

    arguments = [
        VectorArg(dtype, "z", with_offset=True),
        ScalarArg(dtype, "start"),
        ScalarArg(dtype, "step"),
    ]

    return get_elwise_kernel(
        context, arguments, f"z[i] = {value}", name="arange")
|
|
864
|
+
|
|
865
|
+
|
|
866
|
+
@context_dependent_memoize
def get_pow_kernel(context: cl.Context,
                   dtype_x: np.dtype[Any],
                   dtype_y: np.dtype[Any],
                   dtype_z: np.dtype[Any],
                   is_base_array: bool,
                   is_exp_array: bool) -> cl.Kernel:
    """Build the elementwise kernel that stores ``x`` raised to ``y`` in ``z[i]``.

    :arg is_base_array: if true, *x* is a pointer indexed by ``i``;
        otherwise it is a by-value scalar argument.
    :arg is_exp_array: same choice for the exponent *y*.
    """
    # Access expression and declarator for each operand.
    base, base_decl = ("x[i]", "*x") if is_base_array else ("x", "x")
    expo, expo_decl = ("y[i]", "*y") if is_exp_array else ("y", "y")

    complex_x = dtype_x.kind == "c"
    complex_y = dtype_y.kind == "c"
    complex_z = dtype_z.kind == "c"

    if complex_z and dtype_x != dtype_y:
        # Complex operands whose type differs from the result are converted
        # with the *_cast helper of the result type.
        if complex_x and dtype_x != dtype_z:
            base = f"{complex_dtype_to_name(dtype_z)}_cast({base})"
        if complex_y and dtype_y != dtype_z:
            expo = f"{complex_dtype_to_name(dtype_z)}_cast({expo})"
    elif dtype_x != dtype_y:
        # Differing operand types: C-cast whichever differs from the result.
        if dtype_x != dtype_z:
            base = f"({dtype_to_ctype(dtype_z)}) ({base})"
        if dtype_y != dtype_z:
            expo = f"({dtype_to_ctype(dtype_z)}) ({expo})"

    if complex_x or complex_y:
        if complex_x and complex_y:
            helper = "pow"
        elif complex_x:
            # complex base, real exponent
            helper = "powr"
        else:
            # real base, complex exponent
            helper = "rpow"
        power = f"{complex_dtype_to_name(dtype_z)}_{helper}({base}, {expo})"
    else:
        power = f"pow({base}, {expo})"

    tp_x = dtype_to_ctype(dtype_x)
    tp_y = dtype_to_ctype(dtype_y)
    tp_z = dtype_to_ctype(dtype_z)

    return get_elwise_kernel(
        context,
        f"{tp_z} *z, {tp_x} {base_decl}, {tp_y} {expo_decl}",
        f"z[i] = {power}",
        name="pow_method")
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
@context_dependent_memoize
def get_unop_kernel(context: cl.Context,
                    operator: str,
                    res_dtype: np.dtype[Any],
                    in_dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel ``z[i] = <operator> y[i]``.

    *operator* is pasted verbatim in front of ``y[i]`` in the generated code.
    """
    arguments = [
        VectorArg(res_dtype, "z", with_offset=True),
        VectorArg(in_dtype, "y", with_offset=True),
    ]
    operation = "z[i] = {} y[i]".format(operator)

    return get_elwise_kernel(
        context, arguments, operation, name="unary_op_kernel")
|
|
932
|
+
|
|
933
|
+
|
|
934
|
+
@context_dependent_memoize
def get_array_scalar_binop_kernel(context: cl.Context,
                                  operator: str,
                                  dtype_res: np.dtype[Any],
                                  dtype_a: np.dtype[Any],
                                  dtype_b: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel ``out[i] = a[i] <operator> b``.

    *a* is an array argument and *b* a by-value scalar argument; *operator*
    is pasted verbatim between them.
    """
    arguments = [
        VectorArg(dtype_res, "out", with_offset=True),
        VectorArg(dtype_a, "a", with_offset=True),
        ScalarArg(dtype_b, "b"),
    ]
    operation = "out[i] = a[i] {} b".format(operator)

    return get_elwise_kernel(
        context, arguments, operation, name="scalar_binop_kernel")
|
|
947
|
+
|
|
948
|
+
|
|
949
|
+
@context_dependent_memoize
def get_array_binop_kernel(context: cl.Context,
                           operator: str,
                           dtype_res: np.dtype[Any],
                           dtype_a: np.dtype[Any],
                           dtype_b: np.dtype[Any],
                           a_is_scalar: bool = False,
                           b_is_scalar: bool = False) -> cl.Kernel:
    """Build the elementwise kernel ``out[i] = a <operator> b``.

    Both operands are array arguments.  An operand flagged as scalar is
    read from index ``0`` instead of index ``i``.
    """
    lhs = "a[0]" if a_is_scalar else "a[i]"
    rhs = "b[0]" if b_is_scalar else "b[i]"

    arguments = [
        VectorArg(dtype_res, "out", with_offset=True),
        VectorArg(dtype_a, "a", with_offset=True),
        VectorArg(dtype_b, "b", with_offset=True),
    ]
    operation = "out[i] = {} {} {}".format(lhs, operator, rhs)

    return get_elwise_kernel(
        context, arguments, operation, name="binop_kernel")
|
|
966
|
+
|
|
967
|
+
|
|
968
|
+
@context_dependent_memoize
def get_array_scalar_comparison_kernel(context: cl.Context,
                                       operator: str,
                                       dtype_a: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel ``out[i] = a[i] <operator> b``.

    The output array is declared as ``int8``; the scalar *b* is declared
    with the same dtype as the array *a*.
    """
    arguments = [
        VectorArg(np.int8, "out", with_offset=True),
        VectorArg(dtype_a, "a", with_offset=True),
        ScalarArg(dtype_a, "b"),
    ]
    operation = "out[i] = a[i] {} b".format(operator)

    return get_elwise_kernel(
        context, arguments, operation, name="scalar_comparison_kernel")
|
|
979
|
+
|
|
980
|
+
|
|
981
|
+
@context_dependent_memoize
def get_array_comparison_kernel(context: cl.Context,
                                operator: str,
                                dtype_a: np.dtype[Any],
                                dtype_b: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel ``out[i] = a[i] <operator> b[i]``.

    The output array is declared as ``int8``.
    """
    arguments = [
        VectorArg(np.int8, "out", with_offset=True),
        VectorArg(dtype_a, "a", with_offset=True),
        VectorArg(dtype_b, "b", with_offset=True),
    ]
    operation = "out[i] = a[i] {} b[i]".format(operator)

    return get_elwise_kernel(
        context, arguments, operation, name="comparison_kernel")
|
|
993
|
+
|
|
994
|
+
|
|
995
|
+
@context_dependent_memoize
def get_unary_func_kernel(context: cl.Context,
                          func_name: str,
                          in_dtype: np.dtype[Any],
                          out_dtype: np.dtype[Any] | None = None) -> cl.Kernel:
    """Build the elementwise kernel ``z[i] = func_name(y[i])``.

    :arg out_dtype: dtype of the output array; defaults to *in_dtype* when
        *None*.
    """
    result_dtype = in_dtype if out_dtype is None else out_dtype

    arguments = [
        VectorArg(result_dtype, "z", with_offset=True),
        VectorArg(in_dtype, "y", with_offset=True),
    ]
    operation = "z[i] = {}(y[i])".format(func_name)

    return get_elwise_kernel(
        context, arguments, operation, name="{}_kernel".format(func_name))
|
|
1009
|
+
|
|
1010
|
+
|
|
1011
|
+
@context_dependent_memoize
def get_binary_func_kernel(context: cl.Context,
                           func_name: str,
                           x_dtype: np.dtype[Any],
                           y_dtype: np.dtype[Any],
                           out_dtype: np.dtype[Any],
                           preamble: str = "",
                           name: str | None = None) -> cl.Kernel:
    """Build the elementwise kernel ``z[i] = func_name(x[i], y[i])``.

    :arg preamble: source text forwarded to :func:`get_elwise_kernel`.
    :arg name: stem of the kernel name (``"<name>_kernel"``); defaults to
        *func_name* when *None*.
    """
    stem = func_name if name is None else name

    arguments = [
        VectorArg(out_dtype, "z", with_offset=True),
        VectorArg(x_dtype, "x", with_offset=True),
        VectorArg(y_dtype, "y", with_offset=True),
    ]
    operation = "z[i] = {}(x[i], y[i])".format(func_name)

    return get_elwise_kernel(
        context, arguments, operation,
        name="{}_kernel".format(stem),
        preamble=preamble)
|
|
1030
|
+
|
|
1031
|
+
|
|
1032
|
+
@context_dependent_memoize
def get_float_binary_func_kernel(context: cl.Context,
                                 func_name: str,
                                 x_dtype: np.dtype[Any],
                                 y_dtype: np.dtype[Any],
                                 out_dtype: np.dtype[Any],
                                 preamble: str = "",
                                 name: str | None = None) -> cl.Kernel:
    """Build the kernel ``z[i] = func_name((T) x[i], (T) y[i])``.

    ``T`` is ``double`` when the product of a zero of *x_dtype* and a zero
    of *y_dtype* has an item size above four bytes, and ``float`` otherwise.
    In the ``double`` case the fp64 pragma block is prepended to *preamble*.

    :arg name: stem of the kernel name (``"<name>_kernel"``); defaults to
        *func_name* when *None*.
    """
    stem = func_name if name is None else name

    # Let numpy's promotion rules pick the common type of the two operands.
    promoted = np.array(0, x_dtype) * np.array(0, y_dtype)

    if promoted.itemsize > 4:
        arg_type = "double"
        preamble = """
        #if __OPENCL_C_VERSION__ < 120
        #pragma OPENCL EXTENSION cl_khr_fp64: enable
        #endif
        #define PYOPENCL_DEFINE_CDOUBLE
        """ + preamble
    else:
        arg_type = "float"

    arguments = [
        VectorArg(out_dtype, "z", with_offset=True),
        VectorArg(x_dtype, "x", with_offset=True),
        VectorArg(y_dtype, "y", with_offset=True),
    ]
    operation = "z[i] = {fn}(({tp})x[i], ({tp})y[i])".format(
        fn=func_name, tp=arg_type)

    return get_elwise_kernel(
        context, arguments, operation,
        name="{}_kernel".format(stem),
        preamble=preamble)
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
@context_dependent_memoize
def get_fmod_kernel(context: cl.Context,
                    out_dtype: np.dtype[Any],
                    arg_dtype: np.dtype[Any],
                    mod_dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the ``fmod`` kernel via :func:`get_float_binary_func_kernel`.

    Note the argument reordering: *arg_dtype* and *mod_dtype* are forwarded
    as the ``x`` and ``y`` dtypes, followed by *out_dtype*.
    """
    return get_float_binary_func_kernel(
        context, "fmod", arg_dtype, mod_dtype, out_dtype)
|
|
1071
|
+
|
|
1072
|
+
|
|
1073
|
+
@context_dependent_memoize
def get_modf_kernel(context: cl.Context,
                    int_dtype: np.dtype[Any],
                    frac_dtype: np.dtype[Any],
                    x_dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel that applies ``modf`` to ``x[i]``.

    The return value of ``modf`` is stored in ``fracpart[i]``; the address
    of ``intpart[i]`` is passed as its output pointer.

    :arg int_dtype: dtype of the ``intpart`` output array.
    :arg frac_dtype: dtype of the ``fracpart`` output array.
    :arg x_dtype: dtype of the input array ``x``.
    :returns: the kernel produced by :func:`get_elwise_kernel`.
    """
    arguments = [
        VectorArg(int_dtype, "intpart", with_offset=True),
        VectorArg(frac_dtype, "fracpart", with_offset=True),
        VectorArg(x_dtype, "x", with_offset=True),
    ]

    return get_elwise_kernel(
        context, arguments,
        """
        fracpart[i] = modf(x[i], &intpart[i])
        """,
        name="modf_kernel")
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
@context_dependent_memoize
def get_frexp_kernel(context: cl.Context,
                     sign_dtype: np.dtype[Any],
                     exp_dtype: np.dtype[Any],
                     x_dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel that applies ``frexp`` to ``x[i]``.

    The return value of ``frexp`` is stored in ``significand[i]``.  The
    exponent is received in a local ``int`` and then copied into
    ``exponent[i]``.
    """
    arguments = [
        VectorArg(sign_dtype, "significand", with_offset=True),
        VectorArg(exp_dtype, "exponent", with_offset=True),
        VectorArg(x_dtype, "x", with_offset=True),
    ]

    return get_elwise_kernel(
        context, arguments,
        """
        int expt = 0;
        significand[i] = frexp(x[i], &expt);
        exponent[i] = expt;
        """,
        name="frexp_kernel")
|
|
1105
|
+
|
|
1106
|
+
|
|
1107
|
+
@context_dependent_memoize
def get_ldexp_kernel(context: cl.Context,
                     out_dtype: np.dtype[Any],
                     sig_dtype: np.dtype[Any],
                     expt_dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the ``ldexp`` kernel via :func:`get_binary_func_kernel`.

    The binary function is a preamble macro that casts its second argument
    to ``int`` before calling ``ldexp``.
    """
    ldexp_macro = "#define _PYOCL_LDEXP(x, y) ldexp(x, (int)(y))"

    return get_binary_func_kernel(
        context, "_PYOCL_LDEXP", sig_dtype, expt_dtype, out_dtype,
        preamble=ldexp_macro,
        name="ldexp_kernel")
|
|
1116
|
+
|
|
1117
|
+
|
|
1118
|
+
@context_dependent_memoize
def get_minmaximum_kernel(context: cl.Context,
                          minmax: str,
                          dtype_z: np.dtype[Any],
                          dtype_x: np.dtype[Any],
                          dtype_y: np.dtype[Any],
                          kind_x: ArgumentKind,
                          kind_y: ArgumentKind) -> cl.Kernel:
    """Build the elementwise kernel ``z[i] = <minmax>(x, y)``.

    For a floating-point *dtype_z* the ``f<minmax>_nanprop`` preamble macro
    is used, which yields ``a+b`` when either operand is NaN.  For integer
    *dtype_z*, *minmax* itself is used as the function name.

    :raises TypeError: if *dtype_z* is neither floating-point nor integer.
    """
    z_kind = dtype_z.kind
    if z_kind == "f":
        combiner = "f{}_nanprop".format(minmax)
    elif z_kind in "iu":
        combiner = minmax
    else:
        raise TypeError("unsupported dtype specified")

    tp_x = dtype_to_ctype(dtype_x)
    tp_y = dtype_to_ctype(dtype_y)
    tp_z = dtype_to_ctype(dtype_z)

    # Declarator and access expression for each operand, per argument kind.
    decl_x, acc_x = get_decl_and_access_for_kind("x", kind_x)
    decl_y, acc_y = get_decl_and_access_for_kind("y", kind_y)

    arguments = "{} *z, {} {}, {} {}".format(tp_z, tp_x, decl_x, tp_y, decl_y)
    operation = "z[i] = {}({}, {})".format(combiner, acc_x, acc_y)

    return get_elwise_kernel(
        context, arguments, operation,
        name="{}imum".format(minmax),
        preamble="""
        #define fmin_nanprop(a, b) (isnan(a) || isnan(b)) ? a+b : fmin(a, b)
        #define fmax_nanprop(a, b) (isnan(a) || isnan(b)) ? a+b : fmax(a, b)
        """)
|
|
1147
|
+
|
|
1148
|
+
|
|
1149
|
+
@context_dependent_memoize
def get_bessel_kernel(context: cl.Context,
                      which_func: str,
                      out_dtype: np.dtype[Any],
                      order_dtype: np.dtype[Any],
                      x_dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise Bessel-function kernel for order ``ord_n``.

    For non-complex *x_dtype* the kernel calls ``bessel_<which_func>n`` from
    ``pyopencl-bessel-<which_func>.cl``.  For complex input it calls
    ``bessel_j_complex`` and stores the first of its two outputs.

    :raises NotImplementedError: for complex input when *which_func* is not
        ``"j"``, when *x_dtype* is not ``complex128``, or when *x_dtype*
        differs from *out_dtype*.
    """
    if x_dtype.kind != "c":
        real_arguments = [
            VectorArg(out_dtype, "z", with_offset=True),
            ScalarArg(order_dtype, "ord_n"),
            VectorArg(x_dtype, "x", with_offset=True),
        ]
        return get_elwise_kernel(
            context, real_arguments,
            f"z[i] = bessel_{which_func}n(ord_n, x[i])",
            name=f"bessel_{which_func}n_kernel",
            preamble=f"""
            #if __OPENCL_C_VERSION__ < 120
            #pragma OPENCL EXTENSION cl_khr_fp64: enable
            #endif
            #define PYOPENCL_DEFINE_CDOUBLE
            #include <pyopencl-bessel-{which_func}.cl>
            """)

    # Complex input from here on.
    if which_func != "j":
        raise NotImplementedError("complex arguments for Bessel Y")

    if x_dtype != np.complex128:
        raise NotImplementedError("non-complex double dtype")
    if x_dtype != out_dtype:
        raise NotImplementedError("different input/output types")

    complex_arguments = [
        VectorArg(out_dtype, "z", with_offset=True),
        ScalarArg(order_dtype, "ord_n"),
        VectorArg(x_dtype, "x", with_offset=True),
    ]
    return get_elwise_kernel(
        context, complex_arguments,
        """
        cdouble_t jv_loc;
        cdouble_t jvp1_loc;
        bessel_j_complex(ord_n, x[i], &jv_loc, &jvp1_loc);
        z[i] = jv_loc;
        """,
        name="bessel_j_complex_kernel",
        preamble="""
        #if __OPENCL_C_VERSION__ < 120
        #pragma OPENCL EXTENSION cl_khr_fp64: enable
        #endif
        #define PYOPENCL_DEFINE_CDOUBLE
        #include <pyopencl-complex.h>
        #include <pyopencl-bessel-j-complex.cl>
        """)
|
|
1199
|
+
|
|
1200
|
+
|
|
1201
|
+
@context_dependent_memoize
def get_hankel_01_kernel(context: cl.Context,
                         out_dtype: np.dtype[Any],
                         x_dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel that calls ``hankel_01_complex`` on ``x[i]``.

    The two outputs of the call are stored in ``h0[i]`` and ``h1[i]``.

    :raises NotImplementedError: when *x_dtype* is not ``complex128`` or
        differs from *out_dtype*.
    """
    if x_dtype != np.complex128:
        raise NotImplementedError("non-complex double dtype")
    if x_dtype != out_dtype:
        raise NotImplementedError("different input/output types")

    arguments = [
        VectorArg(out_dtype, "h0", with_offset=True),
        VectorArg(out_dtype, "h1", with_offset=True),
        VectorArg(x_dtype, "x", with_offset=True),
    ]

    return get_elwise_kernel(
        context, arguments,
        """
        cdouble_t h0_loc;
        cdouble_t h1_loc;
        hankel_01_complex(x[i], &h0_loc, &h1_loc, 1);
        h0[i] = h0_loc;
        h1[i] = h1_loc;
        """,
        name="hankel_complex_kernel",
        preamble="""
        #if __OPENCL_C_VERSION__ < 120
        #pragma OPENCL EXTENSION cl_khr_fp64: enable
        #endif
        #define PYOPENCL_DEFINE_CDOUBLE
        #include <pyopencl-complex.h>
        #include <pyopencl-hankel-complex.cl>
        """)
|
|
1231
|
+
|
|
1232
|
+
|
|
1233
|
+
@context_dependent_memoize
def get_diff_kernel(context: cl.Context, dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel ``result[i] = array[i+1] - array[i]``.

    Input and output share *dtype*.  The kernel reads ``array[i+1]``, so
    the launch range must stay one short of the input length.
    NOTE(review): that bound is presumably enforced by the caller — confirm.
    """
    arguments = [
        VectorArg(dtype, "result", with_offset=True),
        VectorArg(dtype, "array", with_offset=True),
    ]

    return get_elwise_kernel(
        context, arguments,
        "result[i] = array[i+1] - array[i]",
        name="diff")
|
|
1241
|
+
|
|
1242
|
+
|
|
1243
|
+
@context_dependent_memoize
def get_if_positive_kernel(
        context: cl.Context,
        crit_dtype: np.dtype[Any],
        then_else_dtype: np.dtype[Any],
        is_then_array: bool,
        is_else_array: bool,
        is_then_scalar: bool,
        is_else_scalar: bool) -> cl.Kernel:
    """Build the kernel ``result[i] = crit[i] > 0 ? then_ : else_``.

    Each branch value is either an array argument (read at index ``0`` if
    flagged scalar, else at index ``i``) or a by-value scalar argument.  A
    branch that is not an array must be flagged as scalar.
    """
    if not is_then_array:
        assert is_then_scalar
        then_expr = "then_"
        then_arg = ScalarArg(then_else_dtype, "then_")
    else:
        then_expr = "then_[0]" if is_then_scalar else "then_[i]"
        then_arg = VectorArg(then_else_dtype, "then_", with_offset=True)

    if not is_else_array:
        assert is_else_scalar
        else_expr = "else_"
        else_arg = ScalarArg(then_else_dtype, "else_")
    else:
        else_expr = "else_[0]" if is_else_scalar else "else_[i]"
        else_arg = VectorArg(then_else_dtype, "else_", with_offset=True)

    arguments = [
        VectorArg(then_else_dtype, "result", with_offset=True),
        VectorArg(crit_dtype, "crit", with_offset=True),
        then_arg,
        else_arg,
    ]
    operation = "result[i] = crit[i] > 0 ? {} : {}".format(
        then_expr, else_expr)

    return get_elwise_kernel(
        context, arguments, operation, name="if_positive")
|
|
1275
|
+
|
|
1276
|
+
|
|
1277
|
+
@context_dependent_memoize
def get_logical_not_kernel(context: cl.Context,
                           in_dtype: np.dtype[Any]) -> cl.Kernel:
    """Build the elementwise kernel ``z[i] = (y[i] == 0)``.

    The output array is declared as ``int8``.
    """
    arguments = [
        VectorArg(np.int8, "z", with_offset=True),
        VectorArg(in_dtype, "y", with_offset=True),
    ]

    return get_elwise_kernel(
        context, arguments,
        "z[i] = (y[i] == 0)",
        name="logical_not_kernel")
|
|
1285
|
+
|
|
1286
|
+
# }}}
|
|
1287
|
+
|
|
1288
|
+
# vim: fdm=marker
|