pyopencl 2026.1.1__cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. pyopencl/.libs/libOpenCL-34a55fe4.so.1.0.0 +0 -0
  2. pyopencl/__init__.py +1995 -0
  3. pyopencl/_cl.cpython-314t-aarch64-linux-gnu.so +0 -0
  4. pyopencl/_cl.pyi +2009 -0
  5. pyopencl/_cluda.py +57 -0
  6. pyopencl/_monkeypatch.py +1104 -0
  7. pyopencl/_mymako.py +17 -0
  8. pyopencl/algorithm.py +1454 -0
  9. pyopencl/array.py +3530 -0
  10. pyopencl/bitonic_sort.py +245 -0
  11. pyopencl/bitonic_sort_templates.py +597 -0
  12. pyopencl/cache.py +553 -0
  13. pyopencl/capture_call.py +200 -0
  14. pyopencl/characterize/__init__.py +461 -0
  15. pyopencl/characterize/performance.py +240 -0
  16. pyopencl/cl/pyopencl-airy.cl +324 -0
  17. pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
  18. pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
  19. pyopencl/cl/pyopencl-bessel-y.cl +435 -0
  20. pyopencl/cl/pyopencl-complex.h +303 -0
  21. pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
  22. pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
  23. pyopencl/cl/pyopencl-random123/array.h +325 -0
  24. pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
  25. pyopencl/cl/pyopencl-random123/philox.cl +486 -0
  26. pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
  27. pyopencl/clmath.py +281 -0
  28. pyopencl/clrandom.py +412 -0
  29. pyopencl/cltypes.py +217 -0
  30. pyopencl/compyte/.gitignore +21 -0
  31. pyopencl/compyte/__init__.py +0 -0
  32. pyopencl/compyte/array.py +211 -0
  33. pyopencl/compyte/dtypes.py +314 -0
  34. pyopencl/compyte/pyproject.toml +49 -0
  35. pyopencl/elementwise.py +1288 -0
  36. pyopencl/invoker.py +417 -0
  37. pyopencl/ipython_ext.py +70 -0
  38. pyopencl/py.typed +0 -0
  39. pyopencl/reduction.py +829 -0
  40. pyopencl/scan.py +1921 -0
  41. pyopencl/tools.py +1680 -0
  42. pyopencl/typing.py +61 -0
  43. pyopencl/version.py +11 -0
  44. pyopencl-2026.1.1.dist-info/METADATA +108 -0
  45. pyopencl-2026.1.1.dist-info/RECORD +47 -0
  46. pyopencl-2026.1.1.dist-info/WHEEL +6 -0
  47. pyopencl-2026.1.1.dist-info/licenses/LICENSE +104 -0
pyopencl/tools.py ADDED
@@ -0,0 +1,1680 @@
1
+ r"""
2
+ .. _memory-pools:
3
+
4
+ Memory Pools
5
+ ------------
6
+
7
+ Memory allocation (e.g. in the form of the :func:`pyopencl.Buffer` constructor)
8
+ can be expensive if used frequently. For example, code based on
9
+ :class:`pyopencl.array.Array` can easily run into this issue because a fresh
10
+ memory area is allocated for each intermediate result. Memory pools are a
11
+ remedy for this problem based on the observation that often many of the block
12
+ allocations are of the same sizes as previously used ones.
13
+
14
+ Then, instead of fully returning the memory to the system and incurring the
15
+ associated reallocation overhead, the pool holds on to the memory and uses it
16
+ to satisfy future allocations of similarly-sized blocks. The pool reacts
17
+ appropriately to out-of-memory conditions as long as all memory allocations
18
+ are made through it. Allocations performed from outside of the pool may run
19
+ into spurious out-of-memory conditions due to the pool owning much or all of
20
+ the available memory.
21
+
22
+ There are two flavors of allocators and memory pools:
23
+
24
+ - :ref:`buf-mempool`
25
+ - :ref:`svm-mempool`
26
+
27
+ :class:`pyopencl.array.Array`\ s can be used with memory pools in a
28
+ straightforward manner::
29
+
30
+ mem_pool = pyopencl.tools.MemoryPool(pyopencl.tools.ImmediateAllocator(queue))
31
+ a_dev = cl_array.arange(queue, 2000, dtype=np.float32, allocator=mem_pool)
32
+
33
+ Likewise, SVM-based allocators are directly usable with
34
+ :class:`pyopencl.array.Array`.
35
+
36
+ .. _buf-mempool:
37
+
38
+ :class:`~pyopencl.Buffer`-based Allocators and Memory Pools
39
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
40
+
41
+ .. autoclass:: PooledBuffer
42
+
43
+ .. autoclass:: AllocatorBase
44
+
45
+ .. autoclass:: DeferredAllocator
46
+
47
+ .. autoclass:: ImmediateAllocator
48
+
49
+ .. autoclass:: MemoryPool
50
+
51
+ .. _svm-mempool:
52
+
53
+ :ref:`SVM <svm>`-Based Allocators and Memory Pools
54
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
55
+
56
+ SVM functionality requires OpenCL 2.0.
57
+
58
+ .. autoclass:: PooledSVM
59
+
60
+ .. autoclass:: SVMAllocator
61
+
62
+ .. autoclass:: SVMPool
63
+
64
+ CL-Object-dependent Caching
65
+ ---------------------------
66
+
67
+ .. autofunction:: first_arg_dependent_memoize
68
+ .. autofunction:: clear_first_arg_caches
69
+
70
+ Testing
71
+ -------
72
+
73
+ .. autofunction:: pytest_generate_tests_for_pyopencl
74
+
75
+ Argument Types
76
+ --------------
77
+
78
+ .. autoclass:: Argument
79
+ .. autoclass:: DtypedArgument
80
+
81
+ .. autoclass:: VectorArg
82
+ .. autoclass:: ScalarArg
83
+ .. autoclass:: OtherArg
84
+
85
+ .. autofunction:: parse_arg_list
86
+
87
+ Device Characterization
88
+ -----------------------
89
+
90
+ .. automodule:: pyopencl.characterize
91
+ :members:
92
+
93
+ Type aliases
94
+ ------------
95
+
96
+ .. currentmodule:: pyopencl._cl
97
+
98
+ .. class:: AllocatorBase
99
+
100
+ See :class:`pyopencl.tools.AllocatorBase`.
101
+ """
102
+
103
+ from __future__ import annotations
104
+
105
+
106
+ __copyright__ = "Copyright (C) 2010 Andreas Kloeckner"
107
+
108
+ __license__ = """
109
+ Permission is hereby granted, free of charge, to any person
110
+ obtaining a copy of this software and associated documentation
111
+ files (the "Software"), to deal in the Software without
112
+ restriction, including without limitation the rights to use,
113
+ copy, modify, merge, publish, distribute, sublicense, and/or sell
114
+ copies of the Software, and to permit persons to whom the
115
+ Software is furnished to do so, subject to the following
116
+ conditions:
117
+
118
+ The above copyright notice and this permission notice shall be
119
+ included in all copies or substantial portions of the Software.
120
+
121
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
122
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
123
+ OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
124
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
125
+ HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
126
+ WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
127
+ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
128
+ OTHER DEALINGS IN THE SOFTWARE.
129
+ """
130
+
131
+ import atexit
132
+ import re
133
+ from abc import ABC, abstractmethod
134
+ from dataclasses import dataclass, field
135
+ from sys import intern
136
+ from typing import (
137
+ TYPE_CHECKING,
138
+ Any,
139
+ ClassVar,
140
+ Concatenate,
141
+ ParamSpec,
142
+ TypeAlias,
143
+ TypedDict,
144
+ TypeVar,
145
+ cast,
146
+ overload,
147
+ )
148
+
149
+ import numpy as np
150
+ from typing_extensions import TypeIs, override
151
+
152
+ from pytools import Hash, memoize, memoize_method
153
+ from pytools.persistent_dict import KeyBuilder as KeyBuilderBase
154
+
155
+ from pyopencl._cl import bitlog2, get_cl_header_version
156
+ from pyopencl.compyte.dtypes import (
157
+ TypeNameNotKnown,
158
+ dtype_to_ctype,
159
+ get_or_register_dtype,
160
+ register_dtype,
161
+ )
162
+
163
+
164
+ if TYPE_CHECKING:
165
+ from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
166
+
167
+ import pytest
168
+ from mako.template import Template
169
+ from numpy.typing import DTypeLike, NDArray
170
+
171
+ # Do not add a pyopencl import here: This will add an import cycle.
172
+
173
+
174
def _register_types():
    """Populate the shared dtype registry.

    Registers the OpenCL C scalar types, followed by the ``cfloat_t`` and
    ``cdouble_t`` names for NumPy's single- and double-precision complex
    types.
    """
    from pyopencl.compyte.dtypes import (
        TYPE_REGISTRY,
        fill_registry_with_opencl_c_types,
    )

    fill_registry_with_opencl_c_types(TYPE_REGISTRY)

    complex_typedefs = (
        ("cfloat_t", np.complex64),
        ("cdouble_t", np.complex128),
    )
    for c_name, np_type in complex_typedefs:
        get_or_register_dtype(c_name, np_type)


_register_types()
184
+
185
+
186
+ # {{{ imported names
187
+
188
+ from pyopencl._cl import (
189
+ AllocatorBase,
190
+ DeferredAllocator,
191
+ ImmediateAllocator,
192
+ MemoryPool,
193
+ PooledBuffer,
194
+ )
195
+
196
+
197
+ if get_cl_header_version() >= (2, 0):
198
+ from pyopencl._cl import PooledSVM, SVMAllocator, SVMPool
199
+
200
+ # }}}
201
+
202
+
203
+ # {{{ monkeypatch docstrings into imported interfaces
204
+
205
+ _MEMPOOL_IFACE_DOCS = """
206
+ .. note::
207
+
208
+ The current implementation of the memory pool will retain allocated
209
+ memory after it is returned by the application and keep it in a bin
210
+ identified by the leading *leading_bits_in_bin_id* bits of the
211
+ allocation size. To ensure that allocations within each bin are
212
+ interchangeable, allocation sizes are rounded up to the largest size
213
+ that shares the leading bits of the requested allocation size.
214
+
215
+ The current default value of *leading_bits_in_bin_id* is
216
+ four, but this may change in future versions and is not
217
+ guaranteed.
218
+
219
+ *leading_bits_in_bin_id* must be passed by keyword,
220
+ and its role is purely advisory. It is not guaranteed
221
+ that future versions of the pool will use the
222
+ same allocation scheme and/or honor *leading_bits_in_bin_id*.
223
+
224
+ .. attribute:: held_blocks
225
+
226
+ The number of unused blocks being held by this pool.
227
+
228
+ .. attribute:: active_blocks
229
+
230
+ The number of blocks in active use that have been allocated
231
+ through this pool.
232
+
233
+ .. attribute:: managed_bytes
234
+
235
+ "Managed" memory is "active" and "held" memory.
236
+
237
+ .. versionadded:: 2021.1.2
238
+
239
+ .. attribute:: active_bytes
240
+
241
+ "Active" bytes are bytes under the control of the application.
242
+ This may be smaller than the actual allocated size reflected
243
+ in :attr:`managed_bytes`.
244
+
245
+ .. versionadded:: 2021.1.2
246
+
247
+
248
+ .. method:: free_held
249
+
250
+ Free all unused memory that the pool is currently holding.
251
+
252
+ .. method:: stop_holding
253
+
254
+ Instruct the memory to start immediately freeing memory returned
255
+ to it, instead of holding it for future allocations.
256
+ Implicitly calls :meth:`free_held`.
257
+ This is useful as a cleanup action when a memory pool falls out
258
+ of use.
259
+ """
260
+
261
+
262
def _monkeypatch_docstrings():
    """Assign reStructuredText docstrings to the Buffer-based allocator and
    memory-pool classes imported from ``pyopencl._cl``.

    This runs once at import time (see the call right after this function).
    """
    from pytools.codegen import remove_common_indentation

    PooledBuffer.__doc__ = """
    An object representing a :class:`MemoryPool`-based allocation of
    :class:`~pyopencl.Buffer`-style device memory. Analogous to
    :class:`~pyopencl.Buffer`, however once this object is deleted, its
    associated device memory is returned to the pool.

    Is a :class:`pyopencl.MemoryObject`.
    """

    AllocatorBase.__doc__ = """
    An interface implemented by various memory allocation functions
    in :mod:`pyopencl`.

    .. automethod:: __call__

        Allocate and return a :class:`pyopencl.Buffer` of the given *size*.
    """

    # {{{ DeferredAllocator

    DeferredAllocator.__doc__ = """
    *mem_flags* takes its values from :class:`pyopencl.mem_flags` and corresponds
    to the *flags* argument of :class:`pyopencl.Buffer`. DeferredAllocator
    has the same semantics as regular OpenCL buffer allocation, i.e. it may
    promise memory to be available that may (in any call to a buffer-using
    CL function) turn out to not exist later on. (Allocations in CL are
    bound to contexts, not devices, and memory availability depends on which
    device the buffer is used with.)

    Implements :class:`AllocatorBase`.

    .. versionchanged :: 2013.1

        ``CLAllocator`` was deprecated and replaced
        by :class:`DeferredAllocator`.

    .. method:: __init__(context, mem_flags=pyopencl.mem_flags.READ_WRITE)

    .. automethod:: __call__

        Allocate a :class:`pyopencl.Buffer` of the given *size*.

        .. versionchanged :: 2020.2

            The allocator will succeed even for allocations of size zero,
            returning *None*.
    """

    # }}}

    # {{{ ImmediateAllocator

    ImmediateAllocator.__doc__ = """
    *mem_flags* takes its values from :class:`pyopencl.mem_flags` and corresponds
    to the *flags* argument of :class:`pyopencl.Buffer`.
    :class:`ImmediateAllocator` will attempt to ensure at allocation time that
    allocated memory is actually available. If no memory is available, an
    out-of-memory error is reported at allocation time.

    Implements :class:`AllocatorBase`.

    .. versionadded:: 2013.1

    .. method:: __init__(queue, mem_flags=pyopencl.mem_flags.READ_WRITE)

    .. automethod:: __call__

        Allocate a :class:`pyopencl.Buffer` of the given *size*.

        .. versionchanged :: 2020.2

            The allocator will succeed even for allocations of size zero,
            returning *None*.
    """

    # }}}

    # {{{ MemoryPool

    # The class-specific text is normalized with remove_common_indentation
    # before the shared, column-0 _MEMPOOL_IFACE_DOCS text is appended to it.
    MemoryPool.__doc__ = remove_common_indentation("""
        A memory pool for OpenCL device memory in :class:`pyopencl.Buffer` form.
        *allocator* must be an instance of one of the above classes, and should be
        an :class:`ImmediateAllocator`. The memory pool assumes that allocation
        failures are reported by the allocator immediately, and not in the
        OpenCL-typical deferred manner.

        Implements :class:`AllocatorBase`.

        .. versionchanged:: 2019.1

            Current bin allocation behavior documented, *leading_bits_in_bin_id*
            added.

        .. automethod:: __init__

        .. automethod:: allocate

            Return a :class:`PooledBuffer` of the given *size*.

        .. automethod:: __call__

            Synonym for :meth:`allocate` to match :class:`AllocatorBase`.

            .. versionadded:: 2011.2
        """) + _MEMPOOL_IFACE_DOCS

    # }}}


_monkeypatch_docstrings()
375
+
376
+
377
def _monkeypatch_svm_docstrings():
    """Assign reStructuredText docstrings to the SVM allocator and pool classes
    imported from ``pyopencl._cl``.

    The SVM classes are only imported when the CL header version is at least
    2.0, so this function is only called under that same condition (see the
    guarded call right after it).
    """
    from pytools.codegen import remove_common_indentation

    # {{{ PooledSVM

    PooledSVM.__doc__ = (  # pyright: ignore[reportPossiblyUnboundVariable]
        """An object representing a :class:`SVMPool`-based allocation of
        :ref:`svm`. Analogous to :class:`~pyopencl.SVMAllocation`, however once
        this object is deleted, its associated device memory is returned to the
        pool from which it came.

        .. versionadded:: 2022.2

        .. note::

            If the :class:`SVMAllocator` for the :class:`SVMPool` that allocated an
            object of this type is associated with an (in-order)
            :class:`~pyopencl.CommandQueue`, sufficient synchronization is provided
            to ensure operations enqueued before deallocation complete before
            operations from a different use (possibly in a different queue) are
            permitted to start. This applies when :meth:`release` is called and
            also when the object is freed automatically by the garbage collector.

        Is a :class:`pyopencl.SVMPointer`.

        Supports structural equality and hashing.

        .. automethod:: release

            Return the held memory to the pool. See the note about synchronization
            behavior during deallocation above.

        .. automethod:: enqueue_release

            Synonymous to :meth:`release`, for consistency with
            :class:`~pyopencl.SVMAllocation`. Note that, unlike
            :meth:`pyopencl.SVMAllocation.enqueue_release`, specifying a queue
            or events to be waited for is not supported.

        .. automethod:: bind_to_queue

            Analogous to :meth:`pyopencl.SVMAllocation.bind_to_queue`.

        .. automethod:: unbind_from_queue

            Analogous to :meth:`pyopencl.SVMAllocation.unbind_from_queue`.
        """)

    # }}}

    # {{{ SVMAllocator

    SVMAllocator.__doc__ = (  # pyright: ignore[reportPossiblyUnboundVariable]
        """
        .. versionadded:: 2022.2

        .. automethod:: __init__

            :arg flags: See :class:`~pyopencl.svm_mem_flags`.
            :arg queue: If not specified, allocations will be freed
                eagerly, irrespective of whether pending/enqueued operations
                are still using the memory.

                If specified, deallocation of memory will be enqueued
                with the given queue, and will only be performed
                after previously-enqueued operations in the queue have
                completed.

                It is an error to specify an out-of-order queue.

                .. warning::

                    Not specifying a queue will typically lead to undesired
                    behavior, including crashes and memory corruption.
                    See the warning in :ref:`svm`.

        .. automethod:: __call__

            Return a :class:`~pyopencl.SVMAllocation` of the given *size*.
        """)

    # }}}

    # {{{ SVMPool

    # The class-specific text is normalized with remove_common_indentation
    # before the shared, column-0 _MEMPOOL_IFACE_DOCS text is appended to it.
    SVMPool.__doc__ = (  # pyright: ignore[reportPossiblyUnboundVariable]
        remove_common_indentation("""
        A memory pool for OpenCL device memory in :ref:`SVM <svm>` form.
        *allocator* must be an instance of :class:`SVMAllocator`.

        .. versionadded:: 2022.2

        .. automethod:: __init__
        .. automethod:: __call__

            Return a :class:`PooledSVM` of the given *size*.
        """) + _MEMPOOL_IFACE_DOCS)

    # }}}


if get_cl_header_version() >= (2, 0):
    _monkeypatch_svm_docstrings()
480
+
481
+ # }}}
482
+
483
+
484
+ # {{{ first-arg caches
485
+
486
# Every cache dict created by first_arg_dependent_memoize and
# first_arg_dependent_memoize_nested is appended here, so that
# clear_first_arg_caches() can empty all of them.
_first_arg_dependent_caches: list[Mapping[Hashable, object]] = []


# Type variables used in the signatures of the memoizing decorators below.
HashableT = TypeVar("HashableT", bound="Hashable")
RetT = TypeVar("RetT")
P = ParamSpec("P")
492
+
493
+
494
def first_arg_dependent_memoize(
        func: Callable[Concatenate[HashableT, P], RetT]
        ) -> Callable[Concatenate[HashableT, P], RetT]:
    """Provides memoization for a function. Typically used to cache
    things that get created inside a :class:`pyopencl.Context`, e.g. programs
    and kernels. Assumes that the first argument of the decorated function is
    an OpenCL object that might go away, such as a :class:`pyopencl.Context` or
    a :class:`pyopencl.CommandQueue`, and based on which we might want to clear
    the cache.

    Results are cached per first argument and keyed by the remaining
    positional and keyword arguments, all of which must be hashable. The
    caches are emptied by :func:`clear_first_arg_caches`.

    .. versionadded:: 2011.2
    """
    # NOTE: this docstring lives on the outer function on purpose.
    # update_wrapper() below copies *func*'s docstring onto the wrapper, so a
    # docstring on the wrapper itself would never be visible.
    from functools import update_wrapper

    def wrapper(cl_object: HashableT, *args: P.args, **kwargs: P.kwargs) -> RetT:
        # Keyword arguments go into a frozenset so that their order does not
        # affect the cache key.
        if kwargs:
            cache_key = (args, frozenset(kwargs.items()))
        else:
            cache_key = (args,)

        ctx_dict: dict[Hashable, dict[Hashable, RetT]]
        try:
            ctx_dict = func._pyopencl_first_arg_dep_memoize_dic  # pyright: ignore[reportFunctionMemberAccess]
        except AttributeError:
            # FIXME: This may keep contexts alive longer than desired.
            # But I guess since the memory in them is freed, who cares.
            ctx_dict = func._pyopencl_first_arg_dep_memoize_dic = {}  # pyright: ignore[reportFunctionMemberAccess]
            _first_arg_dependent_caches.append(ctx_dict)

        try:
            return ctx_dict[cl_object][cache_key]
        except KeyError:
            arg_dict = ctx_dict.setdefault(cl_object, {})
            result = func(cl_object, *args, **kwargs)
            arg_dict[cache_key] = result
            return result

    update_wrapper(wrapper, func)
    return wrapper


# Alias for first_arg_dependent_memoize.
context_dependent_memoize = first_arg_dependent_memoize
535
+
536
+
537
def first_arg_dependent_memoize_nested(
        nested_func: Callable[Concatenate[Hashable, P], RetT]
        ) -> Callable[Concatenate[Hashable, P], RetT]:
    """Provides memoization for nested functions.

    Typically used to cache things that get created inside a
    :class:`pyopencl.Context`, e.g. programs and kernels. Assumes that the first
    argument of the decorated function is an OpenCL object that might go away,
    such as a :class:`pyopencl.Context` or a :class:`pyopencl.CommandQueue`, and
    will therefore respond to :func:`clear_first_arg_caches`.

    The cache is stored as an attribute on the *enclosing* function. That
    function is found by looking up the calling function's name in the
    caller's module globals, so the decorator must be applied inside a
    module-level function (otherwise the lookup raises :exc:`KeyError`).
    Only positional arguments after the first take part in the cache key;
    keyword arguments are not supported.

    .. versionadded:: 2013.1
    """
    from functools import wraps
    from inspect import currentframe

    cache_dict_name = intern(
        f"_memoize_inner_dic_{nested_func.__name__}_"
        f"{nested_func.__code__.co_filename}_"
        f"{nested_func.__code__.co_firstlineno}")

    frame = currentframe()
    caller_frame = frame.f_back if frame is not None else None
    cache_context = None
    try:
        if caller_frame is not None:
            cache_context = caller_frame.f_globals[caller_frame.f_code.co_name]
    finally:
        # Keeping the current frame in a local creates a reference cycle
        # (frame -> its locals -> frame). Drop both references explicitly so
        # the frames can be freed without waiting for the cyclic GC.
        del frame, caller_frame

    cache_dict: dict[Hashable, dict[Hashable, RetT]]
    try:
        cache_dict = getattr(cache_context, cache_dict_name)
    except AttributeError:
        cache_dict = {}
        _first_arg_dependent_caches.append(cache_dict)
        setattr(cache_context, cache_dict_name, cache_dict)

    @wraps(nested_func)
    def new_nested_func(cl_object: Hashable, *args: P.args, **kwargs: P.kwargs) -> RetT:
        # Keyword arguments are not part of the cache key, so reject them.
        assert not kwargs

        try:
            return cache_dict[cl_object][args]
        except KeyError:
            arg_dict = cache_dict.setdefault(cl_object, {})
            result = nested_func(cl_object, *args, **kwargs)
            arg_dict[args] = result
            return result

    return new_nested_func
592
+
593
+
594
def clear_first_arg_caches():
    """Empty every cache created by the first-argument-dependent memoization
    decorators.

    Those caches hold references to the OpenCL objects (e.g. contexts) they
    are keyed by. If your program needs to let go of a context completely,
    calling this function may be necessary to release the remaining
    references to it.

    .. versionadded:: 2011.2
    """
    for registered_cache in _first_arg_dependent_caches:
        # NOTE: the registry is typed as holding Mappings because
        # MutableMapping does not seem to be correctly covariant in its
        # values, hence the pyright suppression on .clear().
        registered_cache.clear()  # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType]
608
+
609
+
610
if TYPE_CHECKING:
    # Only needed for annotations. A runtime pyopencl import at module level
    # would create an import cycle, so functions that need pyopencl at
    # runtime import it locally instead.
    import pyopencl as cl
    from pyopencl.array import Array as CLArray

# Empty the memoization caches at interpreter exit, dropping the references
# they hold to OpenCL objects such as contexts.
atexit.register(clear_first_arg_caches)
615
+
616
+ # }}}
617
+
618
+
619
+ # {{{ pytest fixtures
620
+
621
+ class _ContextFactory:
622
+ device: cl.Device
623
+
624
+ def __init__(self, device: cl.Device):
625
+ self.device = device
626
+
627
+ def __call__(self):
628
+ # Get rid of leftovers from past tests.
629
+ # CL implementations are surprisingly limited in how many
630
+ # simultaneous contexts they allow...
631
+ clear_first_arg_caches()
632
+
633
+ from gc import collect
634
+ collect()
635
+
636
+ import pyopencl as cl
637
+ return cl.Context([self.device])
638
+
639
+ @override
640
+ def __str__(self) -> str:
641
+ # Don't show address, so that parallel test collection works
642
+ device = self.device.name.strip()
643
+ platform = self.device.platform.name.strip()
644
+ return f"<context factory for <pyopencl.Device '{device}' on '{platform}'>>"
645
+
646
+
647
+ DeviceOrPlatformT = TypeVar("DeviceOrPlatformT", "cl.Device", "cl.Platform")
648
+
649
+
650
+ def _find_cl_obj(
651
+ objs: Sequence[DeviceOrPlatformT],
652
+ identifier: str
653
+ ) -> DeviceOrPlatformT:
654
+ try:
655
+ num = int(identifier)
656
+ except Exception:
657
+ pass
658
+ else:
659
+ return objs[num]
660
+
661
+ for obj in objs:
662
+ if identifier.lower() in (obj.name + " " + obj.vendor).lower():
663
+ return obj
664
+ raise RuntimeError(f"object '{identifier}' not found")
665
+
666
+
667
+ def get_test_platforms_and_devices(
668
+ plat_dev_string: str | None = None
669
+ ):
670
+ """Parse a string of the form 'PYOPENCL_TEST=0:0,1;intel:i5'.
671
+
672
+ :return: list of tuples (platform, [device, device, ...])
673
+ """
674
+
675
+ import pyopencl as cl
676
+
677
+ if plat_dev_string is None:
678
+ import os
679
+ plat_dev_string = os.environ.get("PYOPENCL_TEST", None)
680
+
681
+ if plat_dev_string:
682
+ result: list[tuple[cl.Platform, list[cl.Device]]] = []
683
+
684
+ for entry in plat_dev_string.split(";"):
685
+ lhsrhs = entry.split(":")
686
+
687
+ if len(lhsrhs) == 1:
688
+ platform = _find_cl_obj(cl.get_platforms(), lhsrhs[0])
689
+ result.append((platform, platform.get_devices()))
690
+
691
+ elif len(lhsrhs) != 2:
692
+ raise RuntimeError("invalid syntax of PYOPENCL_TEST")
693
+ else:
694
+ plat_str, dev_strs = lhsrhs
695
+
696
+ platform = _find_cl_obj(cl.get_platforms(), plat_str)
697
+ devs = platform.get_devices()
698
+ result.append(
699
+ (platform,
700
+ [_find_cl_obj(devs, dev_id)
701
+ for dev_id in dev_strs.split(",")]))
702
+
703
+ return result
704
+
705
+ else:
706
+ return [
707
+ (platform, platform.get_devices())
708
+ for platform in cl.get_platforms()]
709
+
710
+
711
def get_pyopencl_fixture_arg_names(
        metafunc: pytest.Metafunc,
        extra_arg_names: list[str] | None = None) -> list[str]:
    """Return the PyOpenCL fixture names that the test function requests.

    The candidates are ``platform``, ``device``, ``ctx_factory`` and
    ``ctx_getter`` followed by *extra_arg_names*. They are returned in that
    order, restricted to those found in ``metafunc.fixturenames``. Requesting
    the deprecated ``ctx_getter`` emits a :exc:`DeprecationWarning`.
    """
    from warnings import warn

    candidates = [
        "platform", "device",
        "ctx_factory", "ctx_getter",
        *(extra_arg_names if extra_arg_names is not None else []),
    ]

    selected: list[str] = []
    for name in candidates:
        if name in metafunc.fixturenames:
            if name == "ctx_getter":
                warn(
                    "The 'ctx_getter' arg is deprecated in favor of 'ctx_factory'.",
                    DeprecationWarning, stacklevel=2)
            selected.append(name)

    return selected
737
+
738
+
739
def get_pyopencl_fixture_arg_values() -> tuple[list[dict[str, Any]],
                                               Callable[[Any], str]]:
    """Build one fixture-value dictionary per selected device, plus an ID
    function for pytest.

    :returns: a tuple ``(arg_values, idfn)``. Each entry of *arg_values* maps
        ``"platform"``, ``"device"``, ``"ctx_factory"`` and ``"ctx_getter"``
        to the corresponding object for one device from
        :func:`get_test_platforms_and_devices`. *idfn* renders such values as
        test IDs.
    """
    import pyopencl as cl

    arg_values: list[dict[str, Any]] = [
        {
            "platform": plat,
            "device": dev,
            "ctx_factory": _ContextFactory(dev),
            "ctx_getter": _ContextFactory(dev),
        }
        for plat, devs in get_test_platforms_and_devices()
        for dev in devs
    ]

    def idfn(val: Any) -> str:
        # Don't show address, so that parallel test collection works
        return (f"<pyopencl.Platform '{val.name}'>"
                if isinstance(val, cl.Platform)
                else str(val))

    return arg_values, idfn
762
+
763
+
764
def pytest_generate_tests_for_pyopencl(metafunc: pytest.Metafunc) -> None:
    """Using the line::

        from pyopencl.tools import (
            pytest_generate_tests_for_pyopencl as pytest_generate_tests)

    in your `pytest <https://docs.pytest.org/en/latest/>`__ test scripts allows
    you to use the arguments *ctx_factory*, *device*, or *platform* in your test
    functions, and they will automatically be run for each OpenCL device/platform
    in the system, as appropriate.

    The following environment variable is also supported to control
    device/platform choice::

        PYOPENCL_TEST=0:0,1;intel:i5,i7

    Entries are separated by ``;``. Each entry is either a platform alone
    (selecting all of its devices) or ``platform:device,device,...``.
    Platforms and devices may be given by index or by a case-insensitive
    substring of their name and vendor.
    """

    arg_names = get_pyopencl_fixture_arg_names(metafunc)
    if not arg_names:
        return

    arg_dicts, ids = get_pyopencl_fixture_arg_values()
    # One tuple per device, ordered like arg_names as parametrize() expects.
    arg_values = [
        tuple(arg_dict[name] for name in arg_names)
        for arg_dict in arg_dicts
    ]

    metafunc.parametrize(arg_names, arg_values, ids=ids)
792
+
793
+ # }}}
794
+
795
+
796
+ # {{{ C argument lists
797
+
798
# Entries produced by get_arg_list_arg_types: the dtype of a ScalarArg, or
# the VectorArg itself.
ArgType: TypeAlias = "np.dtype[Any] | VectorArg"
# Entries produced by get_arg_list_scalar_arg_dtypes: a dtype for scalar
# (and offset) arguments, None for vector arguments.
ArgDType: TypeAlias = "np.dtype[Any] | None"
800
+
801
+
802
class Argument(ABC):
    """Abstract base class for an entry in a kernel's C argument list.

    .. automethod:: declarator
    """

    @abstractmethod
    def declarator(self) -> str:
        """Return the C declarator text for this argument."""
        ...
810
+
811
+
812
@dataclass(frozen=True, init=False)
class DtypedArgument(Argument, ABC):
    """Abstract base class for kernel arguments that carry a NumPy dtype.

    .. autoattribute:: name
    .. autoattribute:: dtype
    """
    dtype: np.dtype[Any]
    name: str

    def __init__(self, dtype: DTypeLike, name: str) -> None:
        # Instances are frozen, so fields must be assigned via
        # object.__setattr__; *dtype* is normalized with np.dtype().
        assign = object.__setattr__
        assign(self, "name", name)
        assign(self, "dtype", np.dtype(dtype))
824
+
825
+
826
@dataclass(frozen=True)
class VectorArg(DtypedArgument):
    """A ``__global`` pointer argument. Inherits from :class:`DtypedArgument`.

    .. automethod:: __init__
    """
    with_offset: bool

    def __init__(self, dtype: DTypeLike, name: str, with_offset: bool = False):
        """
        :arg with_offset: if true, the argument is declared as a base pointer
            ``<name>__base`` plus a separate ``long <name>__offset``.
        """
        super().__init__(dtype, name)
        object.__setattr__(self, "with_offset", with_offset)

    @override
    def declarator(self) -> str:
        ctype = dtype_to_ctype(self.dtype)
        if not self.with_offset:
            return f"__global {ctype} *{self.name}"

        # Two underscores -> less likelihood of a name clash.
        return f"__global {ctype} *{self.name}__base, long {self.name}__offset"
848
+
849
+
850
@dataclass(frozen=True, init=False)
class ScalarArg(DtypedArgument):
    """A by-value scalar argument. Inherits from :class:`DtypedArgument`."""

    @override
    def declarator(self) -> str:
        ctype = dtype_to_ctype(self.dtype)
        return f"{ctype} {self.name}"
857
+
858
+
859
@dataclass(frozen=True)
class OtherArg(Argument):
    """An argument described by a verbatim C declarator string (*decl*),
    with no associated dtype."""
    decl: str
    name: str

    @override
    def declarator(self) -> str:
        # The declarator is stored verbatim; nothing to generate.
        verbatim = self.decl
        return verbatim
867
+
868
+
869
def parse_c_arg(c_arg: str, with_offset: bool = False) -> DtypedArgument:
    """Parse a single C declarator into a :class:`ScalarArg` or
    :class:`VectorArg`.

    ``__global`` qualifiers are stripped before parsing. Pointer arguments
    become :class:`VectorArg` instances whose *with_offset* flag is set from
    *with_offset*.

    :raises RuntimeError: if the declarator mentions ``__local`` or
        ``__constant``.
    """
    if any(aspace in c_arg for aspace in ("__local", "__constant")):
        raise RuntimeError("cannot deal with local or constant "
                           "OpenCL address spaces in C argument lists ")

    c_arg = c_arg.replace("__global", "")

    def vec_arg_factory(dtype: DTypeLike, name: str) -> VectorArg:
        return VectorArg(dtype, name, with_offset=with_offset)

    from pyopencl.compyte.dtypes import parse_c_arg_backend

    return parse_c_arg_backend(c_arg, ScalarArg, vec_arg_factory)
886
+
887
+
888
def parse_arg_list(
        arguments: str | Sequence[str] | Sequence[Argument],
        with_offset: bool = False) -> Sequence[DtypedArgument]:
    """Parse a list of kernel arguments. *arguments* may be a comma-separated
    list of C declarators in a string, a list of strings representing C
    declarators, or :class:`DtypedArgument` objects (i.e. :class:`VectorArg`
    or :class:`ScalarArg`), which are passed through unchanged.

    :arg with_offset: forwarded to :func:`parse_c_arg` for string declarators.
    :raises TypeError: if an entry is neither a string nor a
        :class:`DtypedArgument`.
    """

    if isinstance(arguments, str):
        arguments = arguments.split(",")

    def parse_single_arg(obj: str | Argument) -> DtypedArgument:
        if isinstance(obj, str):
            # parse_c_arg is defined in this module; no self-import needed.
            return parse_c_arg(obj, with_offset=with_offset)
        if not isinstance(obj, DtypedArgument):
            # Explicit check instead of ``assert`` so that it also applies
            # under ``python -O``.
            raise TypeError(
                "expected a C declarator string or a DtypedArgument, "
                f"got {type(obj).__name__}")
        return obj

    return [parse_single_arg(arg) for arg in arguments]
908
+
909
+
910
def get_arg_list_arg_types(arg_types: Sequence[Argument]) -> tuple[ArgType, ...]:
    """Map each argument to its invoker type: a :class:`ScalarArg` becomes its
    dtype, and a :class:`VectorArg` is kept as is.

    :raises RuntimeError: for any other kind of argument.
    """
    def to_arg_type(arg: Argument) -> ArgType:
        if isinstance(arg, ScalarArg):
            return arg.dtype
        if isinstance(arg, VectorArg):
            return arg
        raise RuntimeError(f"arg type not understood: {type(arg)}")

    return tuple(to_arg_type(arg) for arg in arg_types)
922
+
923
+
924
def get_arg_list_scalar_arg_dtypes(
        arg_types: Sequence[Argument]
        ) -> Sequence[ArgDType]:
    """Flatten the arguments into per-kernel-parameter scalar dtypes.

    A :class:`ScalarArg` contributes its dtype. A :class:`VectorArg`
    contributes *None*, followed by ``int64`` for its offset parameter when
    it has ``with_offset`` set.

    :raises RuntimeError: for any other kind of argument.
    """
    def expand(arg: Argument):
        if isinstance(arg, ScalarArg):
            yield arg.dtype
        elif isinstance(arg, VectorArg):
            yield None
            if arg.with_offset:
                yield np.dtype(np.int64)
        else:
            raise RuntimeError(f"arg type not understood: {type(arg)}")

    return [dtype for arg in arg_types for dtype in expand(arg)]
940
+
941
+
942
def get_arg_offset_adjuster_code(arg_types: Sequence[Argument]) -> str:
    """Generate OpenCL C statements that apply per-argument buffer offsets.

    For every :class:`VectorArg` created with ``with_offset``, this emits a
    declaration of a ``__global`` pointer ``<name>``. The pointer is set to
    ``<name>__base`` advanced by ``<name>__offset`` bytes (the arithmetic is
    done on a ``char`` pointer). The statements are joined with newlines.
    """
    def adjuster(arg: VectorArg) -> str:
        ctype = dtype_to_ctype(arg.dtype)
        return (
            f"__global {ctype} *{arg.name} = "
            f"(__global {ctype} *) "
            f"((__global char *) {arg.name}__base + {arg.name}__offset);")

    return "\n".join(
        adjuster(arg) for arg in arg_types
        if isinstance(arg, VectorArg) and arg.with_offset)
955
+
956
+ # }}}
957
+
958
+
959
def get_gl_sharing_context_properties() -> list[tuple[cl.context_properties, Any]]:
    """Return context properties for sharing objects with the current GL context.

    The current OpenGL context is queried through PyOpenGL. The result is a
    list of ``(property, value)`` pairs for creating a
    :class:`pyopencl.Context`.

    :raises NotImplementedError: on platforms other than Linux, Windows and
        macOS.
    """
    import pyopencl as cl

    ctx_props = cl.context_properties

    from OpenGL import platform as gl_platform

    import sys
    current_platform = sys.platform

    if current_platform in ["linux", "linux2"]:
        from OpenGL import GLX
        return [
            (ctx_props.GL_CONTEXT_KHR, GLX.glXGetCurrentContext()),
            (ctx_props.GLX_DISPLAY_KHR, GLX.glXGetCurrentDisplay()),
        ]

    if current_platform == "win32":
        from OpenGL import WGL
        return [
            (ctx_props.GL_CONTEXT_KHR, gl_platform.GetCurrentContext()),
            (ctx_props.WGL_HDC_KHR, WGL.wglGetCurrentDC()),
        ]

    if current_platform == "darwin":
        return [
            (ctx_props.CONTEXT_PROPERTY_USE_CGL_SHAREGROUP_APPLE,
             cl.get_apple_cgl_share_group()),
        ]

    raise NotImplementedError(f"platform '{sys.platform}' not yet supported")
991
+
992
+
993
+ class _CDeclList:
994
+ def __init__(self, device: cl.Device) -> None:
995
+ self.device: cl.Device = device
996
+ self.declared_dtypes: set[np.dtype[Any]] = set()
997
+ self.declarations: list[str] = []
998
+ self.saw_double: bool = False
999
+ self.saw_complex: bool = False
1000
+
1001
+ def add_dtype(self, dtype: DTypeLike) -> None:
1002
+ dtype = np.dtype(dtype)
1003
+
1004
+ if dtype.type in (np.float64, np.complex128):
1005
+ self.saw_double = True
1006
+
1007
+ if dtype.kind == "c":
1008
+ self.saw_complex = True
1009
+
1010
+ if dtype.kind != "V":
1011
+ return
1012
+
1013
+ if dtype in self.declared_dtypes:
1014
+ return
1015
+
1016
+ from pyopencl.cltypes import vec_type_to_scalar_and_count
1017
+
1018
+ if dtype in vec_type_to_scalar_and_count:
1019
+ return
1020
+
1021
+ if hasattr(dtype, "subdtype") and dtype.subdtype is not None:
1022
+ self.add_dtype(dtype.subdtype[0])
1023
+ return
1024
+
1025
+ fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None", dtype.fields)
1026
+ if fields is not None:
1027
+ for _name, field_data in sorted(fields.items()):
1028
+ field_dtype, _offset = field_data[:2]
1029
+ self.add_dtype(field_dtype)
1030
+
1031
+ _, cdecl = match_dtype_to_c_struct(
1032
+ self.device, dtype_to_ctype(dtype), dtype)
1033
+
1034
+ self.declarations.append(cdecl)
1035
+ self.declared_dtypes.add(dtype)
1036
+
1037
+ def visit_arguments(self, arguments: Sequence[Argument]) -> None:
1038
+ for arg in arguments:
1039
+ if not isinstance(arg, DtypedArgument):
1040
+ continue
1041
+
1042
+ dtype = arg.dtype
1043
+ if dtype.type in (np.float64, np.complex128):
1044
+ self.saw_double = True
1045
+
1046
+ if dtype.kind == "c":
1047
+ self.saw_complex = True
1048
+
1049
+ def get_declarations(self) -> str:
1050
+ result = "\n\n".join(self.declarations)
1051
+
1052
+ if self.saw_complex:
1053
+ result = (
1054
+ "#include <pyopencl-complex.h>\n\n"
1055
+ + result)
1056
+
1057
+ if self.saw_double:
1058
+ result = (
1059
+ """
1060
+ #if __OPENCL_C_VERSION__ < 120
1061
+ #pragma OPENCL EXTENSION cl_khr_fp64: enable
1062
+ #endif
1063
+ #define PYOPENCL_DEFINE_CDOUBLE
1064
+ """
1065
+ + result)
1066
+
1067
+ return result
1068
+
1069
+
1070
+ class _DTypeDict(TypedDict):
1071
+ names: list[str]
1072
+ formats: list[np.dtype[Any]]
1073
+ offsets: list[int]
1074
+ itemsize: int
1075
+
1076
+
1077
@memoize
def match_dtype_to_c_struct(
        device: cl.Device,
        name: str,
        dtype: np.dtype[Any],
        context: cl.Context | None = None) -> tuple[np.dtype[Any], str]:
    """Return a tuple ``(dtype, c_decl)`` such that the C struct declaration
    in ``c_decl`` and the structure :class:`numpy.dtype` instance ``dtype``
    have the same memory layout.

    Note that *dtype* may be modified from the value that was passed in,
    for example to insert padding.

    (As a remark on implementation, this routine runs a small kernel on
    the given *device* to ensure that :mod:`numpy` and C offsets and
    sizes match.)

    .. versionadded:: 2013.1

    This example explains the use of this function::

        >>> import numpy as np
        >>> import pyopencl as cl
        >>> import pyopencl.tools
        >>> ctx = cl.create_some_context()
        >>> dtype = np.dtype([("id", np.uint32), ("value", np.float32)])
        >>> dtype, c_decl = pyopencl.tools.match_dtype_to_c_struct(
        ...     ctx.devices[0], 'id_val', dtype)
        >>> print(c_decl)
        typedef struct {
         unsigned id;
         float value;
        } id_val;
        >>> print(dtype)
        [('id', '<u4'), ('value', '<f4')]
        >>> cl.tools.get_or_register_dtype('id_val', dtype)

    As this example shows, it is important to call
    :func:`get_or_register_dtype` on the modified ``dtype`` returned by this
    function, not the original one.

    :arg device: the device whose compiler determines the struct layout.
    :arg name: the C type name used in the generated ``typedef``.
    :arg dtype: a structured dtype (must have fields).
    :arg context: a context containing *device*. If not given, a new one is
        created for the layout probe.
    :raises ValueError: if *dtype* has no fields.
    :raises NotImplementedError: for nested sub-array field dtypes.
    :raises RuntimeError: if the compiler reports a field offset at or past
        the struct size and numpy's itemsize disagrees with the C size.
    """
    fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None", dtype.fields)
    if not fields:
        raise ValueError(f"dtype has no fields: '{dtype}'")

    import pyopencl as cl

    # Fields in memory order, i.e. sorted by numpy offset.
    sorted_fields = sorted(
            fields.items(),
            key=lambda name_dtype_offset: name_dtype_offset[1][1])

    # {{{ build the C struct declaration

    c_fields: list[str] = []
    for field_name, dtype_and_offset in sorted_fields:
        field_dtype, _offset = dtype_and_offset[:2]
        if hasattr(field_dtype, "subdtype") and field_dtype.subdtype is not None:
            array_dtype = field_dtype.subdtype[0]
            if hasattr(array_dtype, "subdtype") and array_dtype.subdtype is not None:
                raise NotImplementedError("nested array dtypes are not supported")
            array_dims = field_dtype.subdtype[1]
            dims_str = ""
            try:
                for dim in array_dims:
                    dims_str += f"[{dim}]"
            except TypeError:
                dims_str = f"[{array_dims}]"
            c_fields.append(" {} {}{};".format(
                dtype_to_ctype(array_dtype), field_name, dims_str)
                )
        else:
            c_fields.append(
                    " {} {};".format(dtype_to_ctype(field_dtype), field_name))

    c_decl = "typedef struct {{\n{}\n}} {};\n\n".format(
            "\n".join(c_fields),
            name)

    # }}}

    # {{{ probe sizeof/offsetof on the device

    # Field types may themselves need declarations (e.g. nested structs).
    cdl = _CDeclList(device)
    for _field_name, dtype_and_offset in sorted_fields:
        field_dtype, _offset = dtype_and_offset[:2]
        cdl.add_dtype(field_dtype)

    pre_decls = cdl.get_declarations()

    offset_code = "\n".join(
            f"result[{i + 1}] = pycl_offsetof({name}, {field_name});"
            for i, (field_name, _) in enumerate(sorted_fields))

    src = rf"""
#define pycl_offsetof(st, m) \
((uint) ((__local char *) &(dummy.m) \
- (__local char *)&dummy ))

{pre_decls}

{c_decl}

__kernel void get_size_and_offsets(__global uint *result)
{{
result[0] = sizeof({name});
__local {name} dummy;
{offset_code}
}}
"""

    if context is None:
        context = cl.Context([device])

    queue = cl.CommandQueue(context)

    prg = cl.Program(context, src)
    knl = prg.build(devices=[device]).get_size_and_offsets

    import pyopencl.array as cl_array

    result_buf = cl_array.empty(queue, 1+len(sorted_fields), np.uint32)
    assert result_buf.data is not None

    try:
        knl(queue, (1,), (1,), result_buf.data)
        queue.finish()
        size_and_offsets = result_buf.get()
    finally:
        # Free the device buffer eagerly, also if the launch or copy failed.
        result_buf.data.release()

    del knl
    del prg
    del queue
    del context

    # }}}

    size = int(size_and_offsets[0])
    offsets = [int(ofs) for ofs in size_and_offsets[1:]]

    if any(ofs >= size for ofs in offsets):
        # offsets not plausible

        if dtype.itemsize == size:
            # If sizes match, use numpy's idea of the offsets.
            offsets = [dtype_and_offset[1] for _name, dtype_and_offset in sorted_fields]
        else:
            raise RuntimeError(
                    "OpenCL compiler reported offsetof() past sizeof() for struct "
                    f"layout on '{device}'. This makes no sense, and it usually "
                    "indicates a compiler bug. Refusing to discover struct layout.")

    # {{{ build the matching numpy dtype

    try:
        dtype_arg_dict = _DTypeDict(
            names=[field_name for field_name, _ in sorted_fields],
            formats=[dtype_and_offset[0] for _, dtype_and_offset in sorted_fields],
            # copy, so that the size fixer below cannot alter `offsets`,
            # which the fallback path still needs
            offsets=[int(x) for x in offsets],
            itemsize=size,
            )
        arg_dtype = np.dtype(dtype_arg_dict)

        if arg_dtype.itemsize != size:
            # "Old" versions of numpy (1.6.x?) silently ignore "itemsize". Boo.
            dtype_arg_dict["names"].append("_pycl_size_fixer")
            dtype_arg_dict["formats"].append(np.dtype(np.uint8))
            dtype_arg_dict["offsets"].append(size - 1)

            arg_dtype = np.dtype(dtype_arg_dict)
    except NotImplementedError:
        # Fallback: spell out the layout with explicit void padding fields.
        def calc_field_type() -> Iterator[tuple[str, str | np.dtype[Any]]]:
            total_size = 0
            padding_count = 0
            for offset, (field_name, dtype_and_offset) in zip(
                    offsets, sorted_fields, strict=True):
                field_dtype, _ = dtype_and_offset[:2]
                if offset > total_size:
                    padding_count += 1
                    yield f"__pycl_padding{padding_count}", f"V{offset - total_size}"

                yield field_name, field_dtype
                total_size = field_dtype.itemsize + offset

            # Trailing padding, so that the itemsize matches C's sizeof().
            if size > total_size:
                padding_count += 1
                yield f"__pycl_padding{padding_count}", f"V{size - total_size}"

        arg_dtype = np.dtype(list(calc_field_type()))

    # }}}

    assert arg_dtype.itemsize == size

    return arg_dtype, c_decl
1254
+
1255
+
1256
@memoize
def dtype_to_c_struct(device: cl.Device, dtype: np.dtype[Any]) -> str:
    """Return the C ``typedef struct`` declaration for *dtype* on *device*.

    An empty string is returned for dtypes without fields and for OpenCL
    built-in vector types. Otherwise the declaration comes from
    :func:`match_dtype_to_c_struct`. An ``assert`` checks that the matched
    dtype has exactly the same fields as *dtype*.
    """
    fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None", dtype.fields)
    if fields is None:
        return ""

    from pyopencl.cltypes import vec_type_to_scalar_and_count

    # Vector types are built-in. Don't try to redeclare those.
    if dtype in vec_type_to_scalar_and_count:
        return ""

    matched_dtype, c_decl = match_dtype_to_c_struct(
            device, dtype_to_ctype(dtype), dtype)

    matched_fields = cast("Mapping[str, tuple[np.dtype[Any], int]] | None",
                          matched_dtype.fields)
    assert matched_fields is not None

    # Same field count and identical (dtype, offset) info for every field.
    assert (
        len(fields) == len(matched_fields)
        and all(matched_fields[field_name] == field_info
                for field_name, field_info in fields.items()))

    return c_decl
1286
+
1287
+
1288
+ # {{{ code generation/templating helper
1289
+
1290
+ def _process_code_for_macro(code: str) -> str:
1291
+ code = code.replace("//CL//", "\n")
1292
+
1293
+ if "//" in code:
1294
+ raise RuntimeError(
1295
+ "end-of-line comments ('//') may not be used in code snippets")
1296
+
1297
+ return code.replace("\n", " \\\n")
1298
+
1299
+
1300
+ class _TextTemplate(ABC):
1301
+ @abstractmethod
1302
+ def render(self, context: dict[str, Any]) -> str:
1303
+ pass
1304
+
1305
+
1306
@dataclass(frozen=True)
class _SimpleTextTemplate(_TextTemplate):
    """Template without any processing: renders to its text verbatim."""

    txt: str

    @override
    def render(self, context: dict[str, Any]) -> str:
        # No substitution happens; the variables are deliberately ignored.
        del context
        return self.txt
1313
+
1314
+
1315
@dataclass(frozen=True)
class _PrintfTextTemplate(_TextTemplate):
    """Template rendered with ``%``-formatting against the variable dict,
    so that e.g. ``%(name)s`` is replaced by ``context["name"]``."""

    txt: str

    @override
    def render(self, context: dict[str, Any]) -> str:
        rendered = self.txt % context
        return rendered
1322
+
1323
+
1324
@dataclass(frozen=True)
class _MakoTextTemplate(_TextTemplate):
    """Template processed by Mako.

    The :class:`mako.template.Template` is compiled once, right after
    construction, with ``strict_undefined=True``.
    """

    txt: str
    template: Template = field(init=False)

    def __post_init__(self) -> None:
        from mako.template import Template

        compiled = Template(self.txt, strict_undefined=True)
        # The dataclass is frozen, so go around its __setattr__ guard.
        object.__setattr__(self, "template", compiled)

    @override
    def render(self, context: dict[str, Any]) -> str:
        return self.template.render(**context)
1337
+
1338
+
1339
+ class _ArgumentPlaceholder:
1340
+ """A placeholder for subclasses of :class:`DtypedArgument`. This is needed
1341
+ because the concrete dtype of the argument is not known at template
1342
+ creation time--it may be a type alias that will only be filled in
1343
+ at run time. These types take the place of these proto-arguments until
1344
+ all types are known.
1345
+
1346
+ See also :class:`_TemplateRenderer.render_arg`.
1347
+ """
1348
+
1349
+ target_class: ClassVar[type[DtypedArgument]]
1350
+
1351
+ def __init__(self,
1352
+ typename: DTypeLike,
1353
+ name: str,
1354
+ **extra_kwargs: Any) -> None:
1355
+ self.typename: DTypeLike = typename
1356
+ self.name: str = name
1357
+ self.extra_kwargs: dict[str, Any] = extra_kwargs
1358
+
1359
+
1360
class _VectorArgPlaceholder(_ArgumentPlaceholder):
    """Placeholder for a pointer (array) argument; renders to a
    :class:`VectorArg`."""

    target_class: ClassVar[type[DtypedArgument]] = VectorArg
1362
+
1363
+
1364
class _ScalarArgPlaceholder(_ArgumentPlaceholder):
    """Placeholder for a by-value argument; renders to a
    :class:`ScalarArg`."""

    target_class: ClassVar[type[DtypedArgument]] = ScalarArg
1366
+
1367
+
1368
+ class _TemplateRenderer:
1369
+ def __init__(self,
1370
+ template: KernelTemplateBase,
1371
+ type_aliases: (
1372
+ dict[str, np.dtype[Any]]
1373
+ | Sequence[tuple[str, np.dtype[Any]]]),
1374
+ var_values: dict[str, str] | Sequence[tuple[str, str]],
1375
+ context: cl.Context | None = None,
1376
+ options: Any = None) -> None:
1377
+ self.template: KernelTemplateBase = template
1378
+ self.type_aliases: dict[str, np.dtype[Any]] = dict(type_aliases)
1379
+ self.var_dict: dict[str, str] = dict(var_values)
1380
+
1381
+ for name in self.var_dict:
1382
+ if name.startswith("macro_"):
1383
+ self.var_dict[name] = _process_code_for_macro(self.var_dict[name])
1384
+
1385
+ self.context: cl.Context | None = context
1386
+ self.options: Any = options
1387
+
1388
+ @overload
1389
+ def __call__(self, txt: None) -> None: ...
1390
+
1391
+ @overload
1392
+ def __call__(self, txt: str) -> str: ...
1393
+
1394
+ def __call__(self, txt: str | None) -> str | None:
1395
+ if txt is None:
1396
+ return txt
1397
+
1398
+ result = self.template.get_text_template(txt).render(self.var_dict)
1399
+
1400
+ return str(result)
1401
+
1402
+ def get_rendered_kernel(self, txt: str, kernel_name: str) -> cl.Kernel:
1403
+ if self.context is None:
1404
+ raise ValueError("context not provided -- cannot render kernel")
1405
+
1406
+ import pyopencl as cl
1407
+ prg = cl.Program(self.context, self(txt)).build(self.options)
1408
+
1409
+ kernel_name_prefix = self.var_dict.get("kernel_name_prefix")
1410
+ if kernel_name_prefix is not None:
1411
+ kernel_name = kernel_name_prefix+kernel_name
1412
+
1413
+ return getattr(prg, kernel_name)
1414
+
1415
+ def parse_type(self, typename: Any) -> np.dtype[Any]:
1416
+ if isinstance(typename, str):
1417
+ try:
1418
+ return self.type_aliases[typename]
1419
+ except KeyError:
1420
+ from pyopencl.compyte.dtypes import NAME_TO_DTYPE
1421
+ return NAME_TO_DTYPE[typename]
1422
+ else:
1423
+ return np.dtype(typename)
1424
+
1425
+ def render_arg(self, arg_placeholder: _ArgumentPlaceholder) -> DtypedArgument:
1426
+ return arg_placeholder.target_class(
1427
+ self.parse_type(arg_placeholder.typename),
1428
+ arg_placeholder.name,
1429
+ **arg_placeholder.extra_kwargs)
1430
+
1431
+ _C_COMMENT_FINDER: ClassVar[re.Pattern[str]] = re.compile(r"/\*.*?\*/")
1432
+
1433
+ def render_argument_list(self,
1434
+ *arg_lists: Any,
1435
+ with_offset: bool = False,
1436
+ **kwargs: Any) -> list[Argument]:
1437
+ if kwargs:
1438
+ raise TypeError("unrecognized kwargs: " + ", ".join(kwargs))
1439
+
1440
+ all_args: list[Any] = []
1441
+ for arg_list in arg_lists:
1442
+ if isinstance(arg_list, str):
1443
+ arg_list = str(
1444
+ self.template
1445
+ .get_text_template(arg_list).render(self.var_dict))
1446
+ arg_list = self._C_COMMENT_FINDER.sub("", arg_list)
1447
+ arg_list = arg_list.replace("\n", " ")
1448
+
1449
+ all_args.extend(arg_list.split(","))
1450
+ else:
1451
+ all_args.extend(arg_list)
1452
+
1453
+ if with_offset:
1454
+ def vec_arg_factory(
1455
+ typename: DTypeLike,
1456
+ name: str) -> _VectorArgPlaceholder:
1457
+ return _VectorArgPlaceholder(typename, name, with_offset=True)
1458
+ else:
1459
+ vec_arg_factory = _VectorArgPlaceholder
1460
+
1461
+ from pyopencl.compyte.dtypes import parse_c_arg_backend
1462
+
1463
+ parsed_args: list[Argument] = []
1464
+ for arg in all_args:
1465
+ if isinstance(arg, str):
1466
+ arg = arg.strip()
1467
+ if not arg:
1468
+ continue
1469
+
1470
+ ph = parse_c_arg_backend(arg,
1471
+ _ScalarArgPlaceholder, vec_arg_factory,
1472
+ name_to_dtype=lambda x: x) # pyright: ignore[reportArgumentType]
1473
+ parsed_arg = self.render_arg(ph)
1474
+ elif isinstance(arg, Argument):
1475
+ parsed_arg = arg
1476
+ elif isinstance(arg, tuple):
1477
+ assert isinstance(arg[0], str)
1478
+ assert isinstance(arg[1], str)
1479
+ parsed_arg = ScalarArg(self.parse_type(arg[0]), arg[1])
1480
+ else:
1481
+ raise TypeError(f"unexpected argument type: {type(arg)}")
1482
+
1483
+ parsed_args.append(parsed_arg)
1484
+
1485
+ return parsed_args
1486
+
1487
+ def get_type_decl_preamble(self,
1488
+ device: cl.Device,
1489
+ decl_type_names: Sequence[DTypeLike],
1490
+ arguments: Sequence[Argument] | None = None,
1491
+ ) -> str:
1492
+ cdl = _CDeclList(device)
1493
+
1494
+ for typename in decl_type_names:
1495
+ cdl.add_dtype(self.parse_type(typename))
1496
+
1497
+ if arguments is not None:
1498
+ cdl.visit_arguments(arguments)
1499
+
1500
+ for _, tv in sorted(self.type_aliases.items()):
1501
+ cdl.add_dtype(tv)
1502
+
1503
+ type_alias_decls = [
1504
+ "typedef {} {};".format(dtype_to_ctype(val), name)
1505
+ for name, val in sorted(self.type_aliases.items())
1506
+ ]
1507
+
1508
+ return cdl.get_declarations() + "\n" + "\n".join(type_alias_decls)
1509
+
1510
+
1511
class KernelTemplateBase(ABC):
    """Base class for kernel templates whose source snippets may be run
    through a template processor (none, printf-style or Mako).

    Subclasses implement :meth:`build_inner`. :meth:`build` caches its
    results in :attr:`build_cache`. That cache is appended to
    ``_first_arg_dependent_caches``, presumably so that it can be cleared
    together with the other context-dependent caches.
    """

    def __init__(self, template_processor: str | None = None) -> None:
        # default processor for snippets without a //CL:<proc>// marker
        self.template_processor: str | None = template_processor

        self.build_cache: dict[Hashable, Any] = {}
        _first_arg_dependent_caches.append(self.build_cache)

    # Optional leading marker: "//CL//" or "//CL:<processor>//".
    _TEMPLATE_PROCESSOR_PATTERN: ClassVar[re.Pattern[str]] = (
        re.compile(r"^//CL(?::([a-zA-Z0-9_]+))?//")
        )

    @memoize_method
    def get_text_template(self, txt: str) -> _TextTemplate:
        """Return a text template for *txt*.

        A leading ``//CL//`` or ``//CL:<processor>//`` marker is removed. If
        it names a processor, that processor is used; otherwise
        :attr:`template_processor` is. Recognized processors are ``None`` or
        ``"none"``, ``"printf"`` and ``"mako"``.

        :raises RuntimeError: for an unknown processor name.
        """
        proc_match = self._TEMPLATE_PROCESSOR_PATTERN.match(txt)
        tpl_processor = None

        if proc_match is not None:
            tpl_processor = proc_match.group(1)
            # chop off //CL// mark
            txt = txt[len(proc_match.group(0)):]

        if tpl_processor is None:
            tpl_processor = self.template_processor

        if tpl_processor is None or tpl_processor == "none":
            return _SimpleTextTemplate(txt)
        elif tpl_processor == "printf":
            return _PrintfTextTemplate(txt)
        elif tpl_processor == "mako":
            return _MakoTextTemplate(txt)
        else:
            raise RuntimeError(f"unknown template processor '{tpl_processor}'")

    # TODO: this does not seem to be used anywhere -> deprecate / remove
    def get_preamble(self) -> str:
        return ""

    def get_renderer(self,
            type_aliases: (
                dict[str, np.dtype[Any]]
                | Sequence[tuple[str, np.dtype[Any]]]),
            var_values: dict[str, str] | Sequence[tuple[str, str]],
            context: cl.Context | None = None,
            options: Any = None,
            ) -> _TemplateRenderer:
        """Return a :class:`_TemplateRenderer` for this template.

        *context* and *options* are forwarded to the renderer, so that
        :meth:`_TemplateRenderer.get_rendered_kernel` can build programs.
        Previously they were silently dropped.
        """
        return _TemplateRenderer(self, type_aliases, var_values, context, options)

    @abstractmethod
    def build_inner(self,
            context: cl.Context,
            *args: Any,
            **kwargs: Any) -> Callable[..., cl.Event]:
        """Build the kernel(s) for *context*; implemented by subclasses."""

    def build(self, context: cl.Context, *args: Any, **kwargs: Any) -> Any:
        """Provide caching for :meth:`build_inner`.

        The cache key is ``(context, args, sorted kwargs items)``, so all
        arguments must be hashable.
        """

        cache_key = (context, args, tuple(sorted(kwargs.items())))
        try:
            return self.build_cache[cache_key]
        except KeyError:
            result = self.build_inner(context, *args, **kwargs)
            self.build_cache[cache_key] = result
            return result
1575
+
1576
+ # }}}
1577
+
1578
+
1579
+ # {{{ array_module
1580
+
1581
+ # TODO: this is not used anywhere: deprecate + remove
1582
+
1583
+ class _CLFakeArrayModule:
1584
+ def __init__(self, queue: cl.CommandQueue | None = None) -> None:
1585
+ self.queue: cl.CommandQueue | None = queue
1586
+
1587
+ @property
1588
+ def ndarray(self) -> type[CLArray]:
1589
+ from pyopencl.array import Array
1590
+ return Array
1591
+
1592
+ def dot(self, x: CLArray, y: CLArray) -> NDArray[Any]:
1593
+ from pyopencl.array import dot
1594
+ return dot(x, y, queue=self.queue).get()
1595
+
1596
+ def vdot(self, x: CLArray, y: CLArray) -> NDArray[Any]:
1597
+ from pyopencl.array import vdot
1598
+ return vdot(x, y, queue=self.queue).get()
1599
+
1600
+ def empty(self,
1601
+ shape: int | tuple[int, ...],
1602
+ dtype: DTypeLike,
1603
+ order: str = "C") -> CLArray:
1604
+ from pyopencl.array import empty
1605
+ return empty(self.queue, shape, dtype, order=order)
1606
+
1607
+ def hstack(self, arrays: Sequence[CLArray]) -> CLArray:
1608
+ from pyopencl.array import hstack
1609
+ return hstack(arrays, self.queue)
1610
+
1611
+
1612
def array_module(a: Any) -> Any:
    """Return a numpy-like namespace suited to the array *a*.

    Returns :mod:`numpy` itself for a :class:`numpy.ndarray`, and a
    :class:`_CLFakeArrayModule` bound to the array's queue for a
    :class:`pyopencl.array.Array`.

    :raises TypeError: for any other type.
    """
    if isinstance(a, np.ndarray):
        return np

    from pyopencl.array import Array

    if not isinstance(a, Array):
        raise TypeError(f"array type not understood: {type(a)}")

    return _CLFakeArrayModule(a.queue)
1622
+
1623
+ # }}}
1624
+
1625
+
1626
def is_spirv(s: str | bytes) -> TypeIs[bytes]:
    """Return whether *s* looks like a SPIR-V binary.

    *s* must be :class:`bytes` whose first four bytes are the SPIR-V magic
    number ``0x07230203``, in either byte order.
    """
    if not isinstance(s, bytes):
        return False

    magic = b"\x07\x23\x02\x03"
    return s[:4] in (magic, magic[::-1])
1633
+
1634
+
1635
+ # {{{ numpy key types builder
1636
+
1637
class _NumpyTypesKeyBuilder(KeyBuilderBase):  # pyright: ignore[reportUnusedClass]
    """Persistent-hash key builder with support for :class:`VectorArg`
    objects and numpy scalar types.

    Numpy scalar types are keyed by their class name.
    """

    def update_for_VectorArg(self, key_hash: Hash, key: VectorArg) -> None:  # noqa: N802
        # Hash the dtype, the name and the offset flag, in that order.
        self.rec(key_hash, key.dtype)
        self.update_for_str(key_hash, key.name)
        self.rec(key_hash, key.with_offset)

    @override
    def update_for_type(self, key_hash: Hash, key: type) -> None:
        if not issubclass(key, np.generic):
            raise TypeError(f"unsupported type for persistent hash keying: {key}")

        self.update_for_str(key_hash, key.__name__)
1650
+
1651
+ # }}}
1652
+
1653
+
1654
# Public API of this module. Each name is listed exactly once.
__all__ = [
    "AllocatorBase",
    "Argument",
    "DeferredAllocator",
    "DtypedArgument",
    "ImmediateAllocator",
    "MemoryPool",
    "OtherArg",
    "PooledBuffer",
    "PooledSVM",
    "SVMAllocator",
    "SVMPool",
    "ScalarArg",
    "TypeNameNotKnown",
    "VectorArg",
    "bitlog2",
    "clear_first_arg_caches",
    "dtype_to_ctype",
    "first_arg_dependent_memoize",
    "get_or_register_dtype",
    "parse_arg_list",
    "pytest_generate_tests_for_pyopencl",
    "register_dtype",
]
1679
+
1680
+ # vim: foldmethod=marker