pyopencl 2025.2.5__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyopencl might be problematic. Click here for more details.

Files changed (47) hide show
  1. pyopencl/.libs/libOpenCL-83a5a7fd.so.1.0.0 +0 -0
  2. pyopencl/__init__.py +1995 -0
  3. pyopencl/_cl.cpython-310-x86_64-linux-gnu.so +0 -0
  4. pyopencl/_cl.pyi +2006 -0
  5. pyopencl/_cluda.py +57 -0
  6. pyopencl/_monkeypatch.py +1069 -0
  7. pyopencl/_mymako.py +17 -0
  8. pyopencl/algorithm.py +1454 -0
  9. pyopencl/array.py +3441 -0
  10. pyopencl/bitonic_sort.py +245 -0
  11. pyopencl/bitonic_sort_templates.py +597 -0
  12. pyopencl/cache.py +535 -0
  13. pyopencl/capture_call.py +200 -0
  14. pyopencl/characterize/__init__.py +463 -0
  15. pyopencl/characterize/performance.py +240 -0
  16. pyopencl/cl/pyopencl-airy.cl +324 -0
  17. pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
  18. pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
  19. pyopencl/cl/pyopencl-bessel-y.cl +435 -0
  20. pyopencl/cl/pyopencl-complex.h +303 -0
  21. pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
  22. pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
  23. pyopencl/cl/pyopencl-random123/array.h +325 -0
  24. pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
  25. pyopencl/cl/pyopencl-random123/philox.cl +486 -0
  26. pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
  27. pyopencl/clmath.py +282 -0
  28. pyopencl/clrandom.py +412 -0
  29. pyopencl/cltypes.py +202 -0
  30. pyopencl/compyte/.gitignore +21 -0
  31. pyopencl/compyte/__init__.py +0 -0
  32. pyopencl/compyte/array.py +241 -0
  33. pyopencl/compyte/dtypes.py +316 -0
  34. pyopencl/compyte/pyproject.toml +52 -0
  35. pyopencl/elementwise.py +1178 -0
  36. pyopencl/invoker.py +417 -0
  37. pyopencl/ipython_ext.py +70 -0
  38. pyopencl/py.typed +0 -0
  39. pyopencl/reduction.py +815 -0
  40. pyopencl/scan.py +1916 -0
  41. pyopencl/tools.py +1565 -0
  42. pyopencl/typing.py +61 -0
  43. pyopencl/version.py +11 -0
  44. pyopencl-2025.2.5.dist-info/METADATA +109 -0
  45. pyopencl-2025.2.5.dist-info/RECORD +47 -0
  46. pyopencl-2025.2.5.dist-info/WHEEL +6 -0
  47. pyopencl-2025.2.5.dist-info/licenses/LICENSE +104 -0
pyopencl/tools.py ADDED
@@ -0,0 +1,1565 @@
1
+ r"""
2
+ .. _memory-pools:
3
+
4
+ Memory Pools
5
+ ------------
6
+
7
+ Memory allocation (e.g. in the form of the :func:`pyopencl.Buffer` constructor)
8
+ can be expensive if used frequently. For example, code based on
9
+ :class:`pyopencl.array.Array` can easily run into this issue because a fresh
10
+ memory area is allocated for each intermediate result. Memory pools are a
11
+ remedy for this problem based on the observation that often many of the block
12
+ allocations are of the same sizes as previously used ones.
13
+
14
+ Then, instead of fully returning the memory to the system and incurring the
15
+ associated reallocation overhead, the pool holds on to the memory and uses it
16
+ to satisfy future allocations of similarly-sized blocks. The pool reacts
17
+ appropriately to out-of-memory conditions as long as all memory allocations
18
+ are made through it. Allocations performed from outside of the pool may run
19
+ into spurious out-of-memory conditions due to the pool owning much or all of
20
+ the available memory.
21
+
22
+ There are two flavors of allocators and memory pools:
23
+
24
+ - :ref:`buf-mempool`
25
+ - :ref:`svm-mempool`
26
+
27
+ :class:`pyopencl.array.Array`\ s can be used with memory pools in a
28
+ straightforward manner::
29
+
30
+ mem_pool = pyopencl.tools.MemoryPool(pyopencl.tools.ImmediateAllocator(queue))
31
+ a_dev = cl_array.arange(queue, 2000, dtype=np.float32, allocator=mem_pool)
32
+
33
+ Likewise, SVM-based allocators are directly usable with
34
+ :class:`pyopencl.array.Array`.
35
+
36
+ .. _buf-mempool:
37
+
38
+ :class:`~pyopencl.Buffer`-based Allocators and Memory Pools
39
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
40
+
41
+ .. autoclass:: PooledBuffer
42
+
43
+ .. autoclass:: AllocatorBase
44
+
45
+ .. autoclass:: DeferredAllocator
46
+
47
+ .. autoclass:: ImmediateAllocator
48
+
49
+ .. autoclass:: MemoryPool
50
+
51
+ .. _svm-mempool:
52
+
53
+ :ref:`SVM <svm>`-Based Allocators and Memory Pools
54
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
55
+
56
+ SVM functionality requires OpenCL 2.0.
57
+
58
+ .. autoclass:: PooledSVM
59
+
60
+ .. autoclass:: SVMAllocator
61
+
62
+ .. autoclass:: SVMPool
63
+
64
+ CL-Object-dependent Caching
65
+ ---------------------------
66
+
67
+ .. autofunction:: first_arg_dependent_memoize
68
+ .. autofunction:: clear_first_arg_caches
69
+
70
+ Testing
71
+ -------
72
+
73
+ .. autofunction:: pytest_generate_tests_for_pyopencl
74
+
75
+ Argument Types
76
+ --------------
77
+
78
+ .. autoclass:: Argument
79
+ .. autoclass:: DtypedArgument
80
+
81
+ .. autoclass:: VectorArg
82
+ .. autoclass:: ScalarArg
83
+ .. autoclass:: OtherArg
84
+
85
+ .. autofunction:: parse_arg_list
86
+
87
+ Device Characterization
88
+ -----------------------
89
+
90
+ .. automodule:: pyopencl.characterize
91
+ :members:
92
+
93
+ Type aliases
94
+ ------------
95
+
96
+ .. currentmodule:: pyopencl._cl
97
+
98
+ .. class:: AllocatorBase
99
+
100
+ See :class:`pyopencl.tools.AllocatorBase`.
101
+ """
102
+
103
+ from __future__ import annotations
104
+
105
+
106
+ __copyright__ = "Copyright (C) 2010 Andreas Kloeckner"
107
+
108
+ __license__ = """
109
+ Permission is hereby granted, free of charge, to any person
110
+ obtaining a copy of this software and associated documentation
111
+ files (the "Software"), to deal in the Software without
112
+ restriction, including without limitation the rights to use,
113
+ copy, modify, merge, publish, distribute, sublicense, and/or sell
114
+ copies of the Software, and to permit persons to whom the
115
+ Software is furnished to do so, subject to the following
116
+ conditions:
117
+
118
+ The above copyright notice and this permission notice shall be
119
+ included in all copies or substantial portions of the Software.
120
+
121
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
122
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
123
+ OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
124
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
125
+ HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
126
+ WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
127
+ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
128
+ OTHER DEALINGS IN THE SOFTWARE.
129
+ """
130
+
131
+
132
+ import re
133
+ from abc import ABC, abstractmethod
134
+ from dataclasses import dataclass
135
+ from sys import intern
136
+ from typing import (
137
+ TYPE_CHECKING,
138
+ Any,
139
+ Concatenate,
140
+ ParamSpec,
141
+ TypeVar,
142
+ )
143
+
144
+ import numpy as np
145
+ from typing_extensions import TypeIs, override
146
+
147
+ from pytools import memoize, memoize_method
148
+ from pytools.persistent_dict import KeyBuilder as KeyBuilderBase
149
+
150
+ from pyopencl._cl import bitlog2, get_cl_header_version
151
+ from pyopencl.compyte.dtypes import (
152
+ TypeNameNotKnown,
153
+ dtype_to_ctype,
154
+ get_or_register_dtype,
155
+ register_dtype,
156
+ )
157
+
158
+
159
+ if TYPE_CHECKING:
160
+ from collections.abc import Callable, Hashable, Sequence
161
+
162
+ from numpy.typing import DTypeLike
163
+
164
+
165
+ # Do not add a pyopencl import here: This will add an import cycle.
166
+
167
+
168
def _register_types():
    """Populate compyte's dtype registry with OpenCL C type names.

    The registry is first filled via ``fill_registry_with_opencl_c_types``;
    afterwards the names ``cfloat_t`` and ``cdouble_t`` are registered as
    aliases for :class:`numpy.complex64` and :class:`numpy.complex128`.
    """
    from pyopencl.compyte.dtypes import (
        TYPE_REGISTRY,
        fill_registry_with_opencl_c_types,
    )

    fill_registry_with_opencl_c_types(TYPE_REGISTRY)

    complex_aliases = (
        ("cfloat_t", np.complex64),
        ("cdouble_t", np.complex128),
    )
    for c_name, np_type in complex_aliases:
        get_or_register_dtype(c_name, np_type)


_register_types()
178
+
179
+
180
+ # {{{ imported names
181
+
182
+ from pyopencl._cl import (
183
+ AllocatorBase,
184
+ DeferredAllocator,
185
+ ImmediateAllocator,
186
+ MemoryPool,
187
+ PooledBuffer,
188
+ )
189
+
190
+
191
+ if get_cl_header_version() >= (2, 0):
192
+ from pyopencl._cl import PooledSVM, SVMAllocator, SVMPool
193
+
194
+ # }}}
195
+
196
+
197
+ # {{{ monkeypatch docstrings into imported interfaces
198
+
199
+ _MEMPOOL_IFACE_DOCS = """
200
+ .. note::
201
+
202
+ The current implementation of the memory pool will retain allocated
203
+ memory after it is returned by the application and keep it in a bin
204
+ identified by the leading *leading_bits_in_bin_id* bits of the
205
+ allocation size. To ensure that allocations within each bin are
206
+ interchangeable, allocation sizes are rounded up to the largest size
207
+ that shares the leading bits of the requested allocation size.
208
+
209
+ The current default value of *leading_bits_in_bin_id* is
210
+ four, but this may change in future versions and is not
211
+ guaranteed.
212
+
213
+ *leading_bits_in_bin_id* must be passed by keyword,
214
+ and its role is purely advisory. It is not guaranteed
215
+ that future versions of the pool will use the
216
+ same allocation scheme and/or honor *leading_bits_in_bin_id*.
217
+
218
+ .. attribute:: held_blocks
219
+
220
+ The number of unused blocks being held by this pool.
221
+
222
+ .. attribute:: active_blocks
223
+
224
+ The number of blocks in active use that have been allocated
225
+ through this pool.
226
+
227
+ .. attribute:: managed_bytes
228
+
229
+ "Managed" memory is "active" and "held" memory.
230
+
231
+ .. versionadded:: 2021.1.2
232
+
233
+ .. attribute:: active_bytes
234
+
235
+ "Active" bytes are bytes under the control of the application.
236
+ This may be smaller than the actual allocated size reflected
237
+ in :attr:`managed_bytes`.
238
+
239
+ .. versionadded:: 2021.1.2
240
+
241
+
242
+ .. method:: free_held
243
+
244
+ Free all unused memory that the pool is currently holding.
245
+
246
+ .. method:: stop_holding
247
+
248
+ Instruct the memory to start immediately freeing memory returned
249
+ to it, instead of holding it for future allocations.
250
+ Implicitly calls :meth:`free_held`.
251
+ This is useful as a cleanup action when a memory pool falls out
252
+ of use.
253
+ """
254
+
255
+
256
def _monkeypatch_docstrings():
    """Attach reST documentation to the buffer allocator and memory-pool
    classes imported from :mod:`pyopencl._cl`.

    Those classes come from the extension module rather than from Python
    source in this file, so their ``__doc__`` is assigned here at import
    time. This module's docstring then documents them via ``autoclass``.
    """
    from pytools.codegen import remove_common_indentation

    PooledBuffer.__doc__ = """
        An object representing a :class:`MemoryPool`-based allocation of
        :class:`~pyopencl.Buffer`-style device memory. Analogous to
        :class:`~pyopencl.Buffer`, however once this object is deleted, its
        associated device memory is returned to the pool.

        Is a :class:`pyopencl.MemoryObject`.
        """

    AllocatorBase.__doc__ = """
        An interface implemented by various memory allocation functions
        in :mod:`pyopencl`.

        .. automethod:: __call__

            Allocate and return a :class:`pyopencl.Buffer` of the given *size*.
        """

    # {{{ DeferredAllocator

    DeferredAllocator.__doc__ = """
        *mem_flags* takes its values from :class:`pyopencl.mem_flags` and corresponds
        to the *flags* argument of :class:`pyopencl.Buffer`. DeferredAllocator
        has the same semantics as regular OpenCL buffer allocation, i.e. it may
        promise memory to be available that may (in any call to a buffer-using
        CL function) turn out to not exist later on. (Allocations in CL are
        bound to contexts, not devices, and memory availability depends on which
        device the buffer is used with.)

        Implements :class:`AllocatorBase`.

        .. versionchanged :: 2013.1

            ``CLAllocator`` was deprecated and replaced
            by :class:`DeferredAllocator`.

        .. method:: __init__(context, mem_flags=pyopencl.mem_flags.READ_WRITE)

        .. automethod:: __call__

            Allocate a :class:`pyopencl.Buffer` of the given *size*.

            .. versionchanged :: 2020.2

                The allocator will succeed even for allocations of size zero,
                returning *None*.
        """

    # }}}

    # {{{ ImmediateAllocator

    ImmediateAllocator.__doc__ = """
        *mem_flags* takes its values from :class:`pyopencl.mem_flags` and corresponds
        to the *flags* argument of :class:`pyopencl.Buffer`.
        :class:`ImmediateAllocator` will attempt to ensure at allocation time that
        allocated memory is actually available. If no memory is available, an
        out-of-memory error is reported at allocation time.

        Implements :class:`AllocatorBase`.

        .. versionadded:: 2013.1

        .. method:: __init__(queue, mem_flags=pyopencl.mem_flags.READ_WRITE)

        .. automethod:: __call__

            Allocate a :class:`pyopencl.Buffer` of the given *size*.

            .. versionchanged :: 2020.2

                The allocator will succeed even for allocations of size zero,
                returning *None*.
        """

    # }}}

    # {{{ MemoryPool

    # The class-specific text is dedented before the shared interface docs
    # are appended. Presumably this is so both parts line up at the same
    # indentation level for reST.
    MemoryPool.__doc__ = remove_common_indentation("""
        A memory pool for OpenCL device memory in :class:`pyopencl.Buffer` form.
        *allocator* must be an instance of one of the above classes, and should be
        an :class:`ImmediateAllocator`. The memory pool assumes that allocation
        failures are reported by the allocator immediately, and not in the
        OpenCL-typical deferred manner.

        Implements :class:`AllocatorBase`.

        .. versionchanged:: 2019.1

            Current bin allocation behavior documented, *leading_bits_in_bin_id*
            added.

        .. automethod:: __init__

        .. automethod:: allocate

            Return a :class:`PooledBuffer` of the given *size*.

        .. automethod:: __call__

            Synonym for :meth:`allocate` to match :class:`AllocatorBase`.

            .. versionadded:: 2011.2
        """) + _MEMPOOL_IFACE_DOCS

    # }}}


# Runs unconditionally at import: these classes are always imported above.
_monkeypatch_docstrings()
369
+
370
+
371
def _monkeypatch_svm_docstrings():
    """Attach reST documentation to the SVM allocator and pool classes
    imported from :mod:`pyopencl._cl`.

    Only called when :func:`get_cl_header_version` reports at least
    ``(2, 0)``. That is the only case in which :class:`PooledSVM`,
    :class:`SVMAllocator` and :class:`SVMPool` are imported into this module.
    """
    from pytools.codegen import remove_common_indentation

    # {{{ PooledSVM

    PooledSVM.__doc__ = (  # pylint: disable=possibly-used-before-assignment
        """An object representing a :class:`SVMPool`-based allocation of
        :ref:`svm`. Analogous to :class:`~pyopencl.SVMAllocation`, however once
        this object is deleted, its associated device memory is returned to the
        pool from which it came.

        .. versionadded:: 2022.2

        .. note::

            If the :class:`SVMAllocator` for the :class:`SVMPool` that allocated an
            object of this type is associated with an (in-order)
            :class:`~pyopencl.CommandQueue`, sufficient synchronization is provided
            to ensure operations enqueued before deallocation complete before
            operations from a different use (possibly in a different queue) are
            permitted to start. This applies when :meth:`release` is called and
            also when the object is freed automatically by the garbage collector.

        Is a :class:`pyopencl.SVMPointer`.

        Supports structural equality and hashing.

        .. automethod:: release

            Return the held memory to the pool. See the note about synchronization
            behavior during deallocation above.

        .. automethod:: enqueue_release

            Synonymous to :meth:`release`, for consistency with
            :class:`~pyopencl.SVMAllocation`. Note that, unlike
            :meth:`pyopencl.SVMAllocation.enqueue_release`, specifying a queue
            or events to be waited for is not supported.

        .. automethod:: bind_to_queue

            Analogous to :meth:`pyopencl.SVMAllocation.bind_to_queue`.

        .. automethod:: unbind_from_queue

            Analogous to :meth:`pyopencl.SVMAllocation.unbind_from_queue`.
        """)

    # }}}

    # {{{ SVMAllocator

    SVMAllocator.__doc__ = (  # pylint: disable=possibly-used-before-assignment
        """
        .. versionadded:: 2022.2

        .. automethod:: __init__

            :arg flags: See :class:`~pyopencl.svm_mem_flags`.
            :arg queue: If not specified, allocations will be freed
                eagerly, irrespective of whether pending/enqueued operations
                are still using the memory.

                If specified, deallocation of memory will be enqueued
                with the given queue, and will only be performed
                after previously-enqueued operations in the queue have
                completed.

                It is an error to specify an out-of-order queue.

                .. warning::

                    Not specifying a queue will typically lead to undesired
                    behavior, including crashes and memory corruption.
                    See the warning in :ref:`svm`.

        .. automethod:: __call__

            Return a :class:`~pyopencl.SVMAllocation` of the given *size*.
        """)

    # }}}

    # {{{ SVMPool

    SVMPool.__doc__ = (  # pylint: disable=possibly-used-before-assignment
        remove_common_indentation("""
        A memory pool for OpenCL device memory in :ref:`SVM <svm>` form.
        *allocator* must be an instance of :class:`SVMAllocator`.

        .. versionadded:: 2022.2

        .. automethod:: __init__
        .. automethod:: __call__

            Return a :class:`PooledSVM` of the given *size*.
        """) + _MEMPOOL_IFACE_DOCS)

    # }}}


# The SVM classes are only imported for CL headers >= 2.0 (see the imports
# above), so their docs can only be patched in that case.
if get_cl_header_version() >= (2, 0):
    _monkeypatch_svm_docstrings()
474
+
475
+ # }}}
476
+
477
+
478
+ # {{{ first-arg caches
479
+
480
+ _first_arg_dependent_caches: list[dict[Hashable, object]] = []
481
+
482
+
483
+ RetT = TypeVar("RetT")
484
+ P = ParamSpec("P")
485
+
486
+
487
+ def first_arg_dependent_memoize(
488
+ func: Callable[Concatenate[object, P], RetT]
489
+ ) -> Callable[Concatenate[object, P], RetT]:
490
+ def wrapper(cl_object: Hashable, *args: P.args, **kwargs: P.kwargs) -> RetT:
491
+ """Provides memoization for a function. Typically used to cache
492
+ things that get created inside a :class:`pyopencl.Context`, e.g. programs
493
+ and kernels. Assumes that the first argument of the decorated function is
494
+ an OpenCL object that might go away, such as a :class:`pyopencl.Context` or
495
+ a :class:`pyopencl.CommandQueue`, and based on which we might want to clear
496
+ the cache.
497
+
498
+ .. versionadded:: 2011.2
499
+ """
500
+ if kwargs:
501
+ cache_key = (args, frozenset(kwargs.items()))
502
+ else:
503
+ cache_key = (args,)
504
+
505
+ ctx_dict: dict[Hashable, dict[Hashable, RetT]]
506
+ try:
507
+ ctx_dict = func._pyopencl_first_arg_dep_memoize_dic # pyright: ignore[reportFunctionMemberAccess]
508
+ except AttributeError:
509
+ # FIXME: This may keep contexts alive longer than desired.
510
+ # But I guess since the memory in them is freed, who cares.
511
+ ctx_dict = func._pyopencl_first_arg_dep_memoize_dic = {} # pyright: ignore[reportFunctionMemberAccess]
512
+ _first_arg_dependent_caches.append(ctx_dict)
513
+
514
+ try:
515
+ return ctx_dict[cl_object][cache_key]
516
+ except KeyError:
517
+ arg_dict = ctx_dict.setdefault(cl_object, {})
518
+ result = func(cl_object, *args, **kwargs)
519
+ arg_dict[cache_key] = result
520
+ return result
521
+
522
+ from functools import update_wrapper
523
+ update_wrapper(wrapper, func)
524
+ return wrapper
525
+
526
+
527
+ context_dependent_memoize = first_arg_dependent_memoize
528
+
529
+
530
def first_arg_dependent_memoize_nested(nested_func):
    """Provides memoization for nested functions. Typically used to cache
    things that get created inside a :class:`pyopencl.Context`, e.g. programs
    and kernels. Assumes that the first argument of the decorated function is
    an OpenCL object that might go away, such as a :class:`pyopencl.Context` or
    a :class:`pyopencl.CommandQueue`, and will therefore respond to
    :func:`clear_first_arg_caches`.

    The cache is stored as an attribute on the *enclosing* function, which
    must be a module-level function (it is looked up by name in the caller's
    module globals). The memoized function accepts positional arguments
    only, and they must be hashable.

    .. versionadded:: 2013.1
    """

    from functools import wraps
    # Attribute name for the cache on the enclosing function. It includes the
    # nested function's name, file and first line so that distinct nested
    # functions inside the same enclosing function get distinct caches.
    cache_dict_name = intern("_memoize_inner_dic_%s_%s_%d"
            % (nested_func.__name__, nested_func.__code__.co_filename,
                nested_func.__code__.co_firstlineno))

    from inspect import currentframe

    # prevent ref cycle
    # The caller's frame is the body of the enclosing function that applies
    # this decorator. Its code name is used to fetch that function object
    # from the module globals.
    # NOTE(review): the cleanup in ``finally`` is commented out, so
    # ``caller_frame`` is only released when this function returns.
    try:
        caller_frame = currentframe().f_back
        cache_context = caller_frame.f_globals[
                caller_frame.f_code.co_name]
    finally:
        # del caller_frame
        pass

    # Reuse the cache from an earlier call of the enclosing function if one
    # exists. Otherwise create it and register it so clear_first_arg_caches()
    # empties it.
    try:
        cache_dict = getattr(cache_context, cache_dict_name)
    except AttributeError:
        cache_dict = {}
        _first_arg_dependent_caches.append(cache_dict)
        setattr(cache_context, cache_dict_name, cache_dict)

    @wraps(nested_func)
    def new_nested_func(cl_object, *args):
        # Two-level cache: keyed first by the CL object, then by the tuple of
        # remaining positional arguments.
        try:
            return cache_dict[cl_object][args]
        except KeyError:
            arg_dict = cache_dict.setdefault(cl_object, {})
            result = nested_func(cl_object, *args)
            arg_dict[args] = result
            return result

    return new_nested_func
575
+
576
+
577
def clear_first_arg_caches():
    """Empties all first-argument-dependent memoization caches.

    Emptying the caches also drops the references they hold to OpenCL
    objects (e.g. contexts). If your program must fully detach from a
    context, calling this may be necessary to free the last references.

    .. versionadded:: 2011.2
    """
    for memo_dict in _first_arg_dependent_caches:
        # Clear in place: the memoizers keep referring to these same dicts.
        memo_dict.clear()


import atexit


if TYPE_CHECKING:
    import pyopencl as cl


# Drop cached CL objects at interpreter exit.
atexit.register(clear_first_arg_caches)
597
+
598
+ # }}}
599
+
600
+
601
+ # {{{ pytest fixtures
602
+
603
+ class _ContextFactory:
604
+ device: cl.Device
605
+
606
+ def __init__(self, device: cl.Device):
607
+ self.device = device
608
+
609
+ def __call__(self):
610
+ # Get rid of leftovers from past tests.
611
+ # CL implementations are surprisingly limited in how many
612
+ # simultaneous contexts they allow...
613
+ clear_first_arg_caches()
614
+
615
+ from gc import collect
616
+ collect()
617
+
618
+ import pyopencl as cl
619
+ return cl.Context([self.device])
620
+
621
+ @override
622
+ def __str__(self) -> str:
623
+ # Don't show address, so that parallel test collection works
624
+ return ("<context factory for <pyopencl.Device '%s' on '%s'>>" %
625
+ (self.device.name.strip(),
626
+ self.device.platform.name.strip()))
627
+
628
+
629
+ DeviceOrPlatformT = TypeVar("DeviceOrPlatformT", "cl.Device", "cl.Platform")
630
+
631
+
632
+ def _find_cl_obj(
633
+ objs: Sequence[DeviceOrPlatformT],
634
+ identifier: str
635
+ ) -> DeviceOrPlatformT:
636
+ try:
637
+ num = int(identifier)
638
+ except Exception:
639
+ pass
640
+ else:
641
+ return objs[num]
642
+
643
+ for obj in objs:
644
+ if identifier.lower() in (obj.name + " " + obj.vendor).lower():
645
+ return obj
646
+ raise RuntimeError("object '%s' not found" % identifier)
647
+
648
+
649
def get_test_platforms_and_devices(
        plat_dev_string: str | None = None
        ):
    """Parse a platform/device selection string such as ``0:0,1;intel:i5``.

    Entries are separated by ``;``. Each entry is either ``PLATFORM``, which
    selects all devices of that platform, or ``PLATFORM:DEV1,DEV2,...``.
    Platforms and devices may be given as integer indices or as
    case-insensitive substrings of their name and vendor. Blank entries
    (e.g. from a trailing ``;``) are ignored.

    :arg plat_dev_string: the string to parse. If *None*, the value of the
        ``PYOPENCL_TEST`` environment variable is used.
    :return: list of tuples (platform, [device, device, ...]). If the
        selection string is unset or contains no entries, every device of
        every platform is returned.
    :raises RuntimeError: if an entry is malformed or matches no
        platform/device.
    """

    import pyopencl as cl

    if plat_dev_string is None:
        import os
        plat_dev_string = os.environ.get("PYOPENCL_TEST", None)

    # Skip blank entries: an empty identifier would otherwise match the
    # first platform through an empty-substring search and duplicate it.
    entries = [
        entry for entry in (plat_dev_string or "").split(";")
        if entry.strip()]

    if not entries:
        return [
            (platform, platform.get_devices())
            for platform in cl.get_platforms()]

    result: list[tuple[cl.Platform, list[cl.Device]]] = []

    for entry in entries:
        lhsrhs = entry.split(":")

        if len(lhsrhs) == 1:
            platform = _find_cl_obj(cl.get_platforms(), lhsrhs[0])
            result.append((platform, platform.get_devices()))

        elif len(lhsrhs) != 2:
            raise RuntimeError("invalid syntax of PYOPENCL_TEST")
        else:
            plat_str, dev_strs = lhsrhs

            platform = _find_cl_obj(cl.get_platforms(), plat_str)
            devs = platform.get_devices()
            result.append(
                (platform,
                    [_find_cl_obj(devs, dev_id)
                        for dev_id in dev_strs.split(",")]))

    return result
691
+
692
+
693
def get_pyopencl_fixture_arg_names(metafunc, extra_arg_names=None):
    """Return the PyOpenCL-provided fixture names that *metafunc* requests.

    The candidates are ``platform``, ``device``, ``ctx_factory`` and
    ``ctx_getter``, followed by *extra_arg_names*. They are returned in that
    order, keeping only those listed in ``metafunc.fixturenames``. A
    :exc:`DeprecationWarning` is issued for each requested ``ctx_getter``.
    """
    candidates = ["platform", "device", "ctx_factory", "ctx_getter"]
    if extra_arg_names is not None:
        candidates.extend(extra_arg_names)

    requested = [name for name in candidates if name in metafunc.fixturenames]

    for name in requested:
        if name == "ctx_getter":
            from warnings import warn
            warn(
                "The 'ctx_getter' arg is deprecated in favor of 'ctx_factory'.",
                DeprecationWarning, stacklevel=2)

    return requested
717
+
718
+
719
def get_pyopencl_fixture_arg_values():
    """Build one fixture-value dict per selected device, plus an id function.

    Each dict maps ``platform``, ``device``, ``ctx_factory`` and
    ``ctx_getter`` to values for one device returned by
    :func:`get_test_platforms_and_devices`. The returned ``idfn`` renders
    those values as test ids.
    """
    import pyopencl as cl

    arg_values = [
        {
            "platform": platform,
            "device": device,
            "ctx_factory": _ContextFactory(device),
            "ctx_getter": _ContextFactory(device),
        }
        for platform, devices in get_test_platforms_and_devices()
        for device in devices]

    def idfn(val):
        if not isinstance(val, cl.Platform):
            return str(val)
        # Don't show address, so that parallel test collection works
        return f"<pyopencl.Platform '{val.name}'>"

    return arg_values, idfn
741
+
742
+
743
def pytest_generate_tests_for_pyopencl(metafunc):
    """Using the line::

        from pyopencl.tools import (
            pytest_generate_tests_for_pyopencl as pytest_generate_tests)

    in your `pytest <https://docs.pytest.org/en/latest/>`__ test scripts allows
    you to use the arguments *ctx_factory*, *device*, or *platform* in your test
    functions, and they will automatically be run for each OpenCL device/platform
    in the system, as appropriate.

    The following environment variable is also supported to control
    device/platform choice::

        PYOPENCL_TEST=0:0,1;intel:i5,i7

    Entries are separated by ``;``. Within an entry, the platform and a
    comma-separated list of devices are separated by ``:``; omitting the
    device list selects all devices of that platform. Platforms and devices
    may be given by index or by a case-insensitive substring of their name
    and vendor.
    """

    arg_names = get_pyopencl_fixture_arg_names(metafunc)
    if not arg_names:
        return

    arg_values, ids = get_pyopencl_fixture_arg_values()
    # Keep, for each device, only the fixtures this test actually requests,
    # in the order pytest expects them.
    arg_values = [
        tuple(arg_dict[name] for name in arg_names)
        for arg_dict in arg_values
    ]

    metafunc.parametrize(arg_names, arg_values, ids=ids)
771
+
772
+ # }}}
773
+
774
+
775
+ # {{{ C argument lists
776
+
777
class Argument(ABC):
    """Abstract base class for one entry of a C kernel argument list.

    .. automethod:: declarator
    """

    @abstractmethod
    def declarator(self) -> str:
        """Return the C declarator text for this argument."""
785
+
786
+
787
@dataclass(frozen=True, init=False)
class DtypedArgument(Argument, ABC):
    """Abstract base for arguments that carry a :class:`numpy.dtype` and
    a name.

    .. autoattribute:: name
    .. autoattribute:: dtype
    """
    dtype: np.dtype[Any]
    name: str

    def __init__(self, dtype: DTypeLike, name: str) -> None:
        # The dataclass is frozen, so bypass its __setattr__ guard.
        # *dtype* is normalized to a concrete numpy dtype.
        for attr_name, value in (("name", name), ("dtype", np.dtype(dtype))):
            object.__setattr__(self, attr_name, value)
799
+
800
+
801
@dataclass(frozen=True)
class VectorArg(DtypedArgument):
    """Inherits from :class:`DtypedArgument`.

    A ``__global`` pointer argument. If *with_offset* is true, it is declared
    as a base pointer named ``<name>__base`` plus a ``long`` argument named
    ``<name>__offset``.

    .. automethod:: __init__
    """
    with_offset: bool

    def __init__(self, dtype: DTypeLike, name: str, with_offset: bool = False):
        super().__init__(dtype, name)
        # Frozen dataclass: bypass the immutability guard.
        object.__setattr__(self, "with_offset", with_offset)

    @override
    def declarator(self) -> str:
        ctype = dtype_to_ctype(self.dtype)
        if not self.with_offset:
            return f"__global {ctype} *{self.name}"
        # Two underscores -> less likelihood of a name clash.
        return f"__global {ctype} *{self.name}__base, long {self.name}__offset"
823
+
824
+
825
@dataclass(frozen=True, init=False)
class ScalarArg(DtypedArgument):
    """Inherits from :class:`DtypedArgument`.

    A by-value argument, declared as ``<ctype> <name>``.
    """

    @override
    def declarator(self) -> str:
        ctype = dtype_to_ctype(self.dtype)
        return f"{ctype} {self.name}"
832
+
833
+
834
@dataclass(frozen=True)
class OtherArg(Argument):
    """An argument given by a literal C declarator.

    *decl* is the full declarator text, returned verbatim by
    :meth:`declarator`; *name* is the argument's name.
    """
    decl: str
    name: str

    @override
    def declarator(self) -> str:
        # No dtype involved: the declarator was supplied as-is.
        declaration = self.decl
        return declaration
842
+
843
+
844
def parse_c_arg(c_arg: str, with_offset: bool = False) -> DtypedArgument:
    """Parse one C parameter declaration into a :class:`ScalarArg` or
    :class:`VectorArg`.

    ``__global`` qualifiers are dropped before parsing. If *with_offset* is
    true, pointer arguments become :class:`VectorArg` objects with
    ``with_offset=True``.

    :raises RuntimeError: if the declaration mentions ``__local`` or
        ``__constant``.
    """
    if any(aspace in c_arg for aspace in ("__local", "__constant")):
        raise RuntimeError("cannot deal with local or constant "
                "OpenCL address spaces in C argument lists ")

    without_global = c_arg.replace("__global", "")

    vec_arg_factory = (
        (lambda dtype, name: VectorArg(dtype, name, with_offset=True))
        if with_offset
        else VectorArg)

    from pyopencl.compyte.dtypes import parse_c_arg_backend
    return parse_c_arg_backend(without_global, ScalarArg, vec_arg_factory)
860
+
861
+
862
def parse_arg_list(
        arguments: str | list[str] | Sequence[DtypedArgument],
        with_offset: bool = False) -> list[DtypedArgument]:
    """Parse a list of kernel arguments. *arguments* may be a comma-separated
    list of C declarators in a string, a list of strings representing C
    declarators, or :class:`DtypedArgument` objects (passed through
    unchanged). Strings and :class:`DtypedArgument` objects may be mixed.

    :arg with_offset: forwarded to :func:`parse_c_arg` for arguments given as
        strings. If true, pointer arguments are parsed as :class:`VectorArg`
        objects with ``with_offset=True``.
    """

    if isinstance(arguments, str):
        arguments = arguments.split(",")

    def parse_single_arg(obj: str | DtypedArgument) -> DtypedArgument:
        if isinstance(obj, str):
            # parse_c_arg is defined in this module, so no import is needed.
            return parse_c_arg(obj, with_offset=with_offset)
        else:
            assert isinstance(obj, DtypedArgument)
            return obj

    return [parse_single_arg(arg) for arg in arguments]
882
+
883
+
884
def get_arg_list_arg_types(arg_types):
    """Map each argument to its type descriptor.

    A :class:`ScalarArg` maps to its dtype; a :class:`VectorArg` maps to the
    argument object itself.

    :returns: a tuple with one entry per argument.
    :raises RuntimeError: for any other kind of argument.
    """
    def type_of(arg):
        if isinstance(arg, ScalarArg):
            return arg.dtype
        if isinstance(arg, VectorArg):
            return arg
        raise RuntimeError("arg type not understood: %s" % type(arg))

    return tuple(type_of(arg) for arg in arg_types)
896
+
897
+
898
def get_arg_list_scalar_arg_dtypes(
        arg_types: Sequence[DtypedArgument]
        ) -> list[np.dtype | None]:
    """Return the scalar dtype for each kernel-level argument.

    A :class:`ScalarArg` contributes its dtype. A :class:`VectorArg`
    contributes *None* (no scalar dtype), followed by ``int64`` for its
    offset argument if it has ``with_offset`` set.

    :raises RuntimeError: for any other kind of argument.
    """
    dtypes: list[np.dtype | None] = []

    for arg in arg_types:
        if isinstance(arg, ScalarArg):
            dtypes.append(arg.dtype)
        elif isinstance(arg, VectorArg):
            dtypes.extend(
                [None, np.dtype(np.int64)] if arg.with_offset else [None])
        else:
            raise RuntimeError(f"arg type not understood: {type(arg)}")

    return dtypes
914
+
915
+
916
def get_arg_offset_adjuster_code(arg_types: Sequence[Argument]) -> str:
    """Return OpenCL C statements that rebuild offset vector arguments.

    For every :class:`VectorArg` with ``with_offset`` set, one line is
    emitted. It declares a pointer named after the argument, computed as its
    ``__base`` pointer plus its ``__offset`` in bytes (via a ``char *``
    cast). The lines are joined with newlines.
    """
    template = ("__global %(type)s *%(name)s = "
            "(__global %(type)s *) "
            "((__global char *) %(name)s__base + %(name)s__offset);")

    return "\n".join(
        template % {
            "type": dtype_to_ctype(arg.dtype),
            "name": arg.name}
        for arg in arg_types
        if isinstance(arg, VectorArg) and arg.with_offset)
929
+
930
+ # }}}
931
+
932
+
933
def get_gl_sharing_context_properties():
    """Return context properties for sharing with the current OpenGL context.

    The result is a list of ``(property, value)`` pairs whose contents depend
    on ``sys.platform``: GLX context and display on Linux, the GL context and
    WGL device context on Windows, and the CGL share group on macOS.

    :raises NotImplementedError: on any other platform.
    """
    import pyopencl as cl

    ctx_props = cl.context_properties

    from OpenGL import platform as gl_platform

    import sys
    current_platform = sys.platform

    if current_platform in ["linux", "linux2"]:
        from OpenGL import GLX
        return [
            (ctx_props.GL_CONTEXT_KHR, GLX.glXGetCurrentContext()),
            (ctx_props.GLX_DISPLAY_KHR, GLX.glXGetCurrentDisplay()),
        ]

    if current_platform == "win32":
        from OpenGL import WGL
        return [
            (ctx_props.GL_CONTEXT_KHR, gl_platform.GetCurrentContext()),
            (ctx_props.WGL_HDC_KHR, WGL.wglGetCurrentDC()),
        ]

    if current_platform == "darwin":
        return [
            (ctx_props.CONTEXT_PROPERTY_USE_CGL_SHAREGROUP_APPLE,
                cl.get_apple_cgl_share_group()),
        ]

    raise NotImplementedError("platform '%s' not yet supported"
            % current_platform)
966
+
967
+
968
+ class _CDeclList:
969
+ def __init__(self, device):
970
+ self.device = device
971
+ self.declared_dtypes = set()
972
+ self.declarations = []
973
+ self.saw_double = False
974
+ self.saw_complex = False
975
+
976
+ def add_dtype(self, dtype):
977
+ dtype = np.dtype(dtype)
978
+
979
+ if dtype in (np.float64, np.complex128):
980
+ self.saw_double = True
981
+
982
+ if dtype.kind == "c":
983
+ self.saw_complex = True
984
+
985
+ if dtype.kind != "V":
986
+ return
987
+
988
+ if dtype in self.declared_dtypes:
989
+ return
990
+
991
+ import pyopencl.cltypes
992
+ if dtype in pyopencl.cltypes.vec_type_to_scalar_and_count:
993
+ return
994
+
995
+ if hasattr(dtype, "subdtype") and dtype.subdtype is not None:
996
+ self.add_dtype(dtype.subdtype[0])
997
+ return
998
+
999
+ for _name, field_data in sorted(dtype.fields.items()):
1000
+ field_dtype, _offset = field_data[:2]
1001
+ self.add_dtype(field_dtype)
1002
+
1003
+ _, cdecl = match_dtype_to_c_struct(
1004
+ self.device, dtype_to_ctype(dtype), dtype)
1005
+
1006
+ self.declarations.append(cdecl)
1007
+ self.declared_dtypes.add(dtype)
1008
+
1009
+ def visit_arguments(self, arguments):
1010
+ for arg in arguments:
1011
+ dtype = arg.dtype
1012
+ if dtype in (np.float64, np.complex128):
1013
+ self.saw_double = True
1014
+
1015
+ if dtype.kind == "c":
1016
+ self.saw_complex = True
1017
+
1018
+ def get_declarations(self):
1019
+ result = "\n\n".join(self.declarations)
1020
+
1021
+ if self.saw_complex:
1022
+ result = (
1023
+ "#include <pyopencl-complex.h>\n\n"
1024
+ + result)
1025
+
1026
+ if self.saw_double:
1027
+ result = (
1028
+ """
1029
+ #if __OPENCL_C_VERSION__ < 120
1030
+ #pragma OPENCL EXTENSION cl_khr_fp64: enable
1031
+ #endif
1032
+ #define PYOPENCL_DEFINE_CDOUBLE
1033
+ """
1034
+ + result)
1035
+
1036
+ return result
1037
+
1038
+
1039
@memoize
def match_dtype_to_c_struct(device, name, dtype, context=None):
    """Return a tuple ``(dtype, c_decl)`` such that the C struct declaration
    in ``c_decl`` and the structure :class:`numpy.dtype` instance ``dtype``
    have the same memory layout.

    Note that *dtype* may be modified from the value that was passed in,
    for example to insert padding.

    (As a remark on implementation, this routine runs a small kernel on
    the given *device* to ensure that :mod:`numpy` and C offsets and
    sizes match.)

    .. versionadded:: 2013.1

    This example explains the use of this function::

        >>> import numpy as np
        >>> import pyopencl as cl
        >>> import pyopencl.tools
        >>> ctx = cl.create_some_context()
        >>> dtype = np.dtype([("id", np.uint32), ("value", np.float32)])
        >>> dtype, c_decl = pyopencl.tools.match_dtype_to_c_struct(
        ...     ctx.devices[0], 'id_val', dtype)
        >>> print(c_decl)
        typedef struct {
         unsigned id;
         float value;
        } id_val;
        >>> print(dtype)
        [('id', '<u4'), ('value', '<f4')]
        >>> cl.tools.get_or_register_dtype('id_val', dtype)

    As this example shows, it is important to call
    :func:`get_or_register_dtype` on the modified ``dtype`` returned by this
    function, not the original one.

    :arg device: device on which the probing kernel is built and run.
    :arg name: C type name used in the generated ``typedef``.
    :arg context: optional context containing *device*; if not given, a new
        one is created for *device*.
    :raises NotImplementedError: for nested array field dtypes.
    :raises RuntimeError: if the compiler reports a field offset at or past
        the struct size and that size disagrees with numpy's itemsize.
    """

    import pyopencl as cl

    # Fields in memory (offset) order.
    fields = sorted(dtype.fields.items(),
            key=lambda name_dtype_offset: name_dtype_offset[1][1])

    c_fields = []
    for field_name, dtype_and_offset in fields:
        field_dtype, _offset = dtype_and_offset[:2]
        if hasattr(field_dtype, "subdtype") and field_dtype.subdtype is not None:
            array_dtype = field_dtype.subdtype[0]
            if hasattr(array_dtype, "subdtype") and array_dtype.subdtype is not None:
                raise NotImplementedError("nested array dtypes are not supported")
            array_dims = field_dtype.subdtype[1]
            dims_str = ""
            try:
                for dim in array_dims:
                    dims_str += "[%d]" % dim
            except TypeError:
                # scalar (non-tuple) shape
                dims_str = "[%d]" % array_dims
            c_fields.append(" {} {}{};".format(
                dtype_to_ctype(array_dtype), field_name, dims_str)
                )
        else:
            c_fields.append(
                    " {} {};".format(dtype_to_ctype(field_dtype), field_name))

    c_decl = "typedef struct {{\n{}\n}} {};\n\n".format(
            "\n".join(c_fields),
            name)

    # Declarations of any struct types used by the fields.
    cdl = _CDeclList(device)
    for _field_name, dtype_and_offset in fields:
        field_dtype, _offset = dtype_and_offset[:2]
        cdl.add_dtype(field_dtype)

    pre_decls = cdl.get_declarations()

    # result[0] receives sizeof(struct); result[i+1] the offset of field i.
    offset_code = "\n".join(
            "result[%d] = pycl_offsetof(%s, %s);" % (i+1, name, field_name)
            for i, (field_name, _) in enumerate(fields))

    src = rf"""
        #define pycl_offsetof(st, m) \
                 ((uint) ((__local char *) &(dummy.m) \
                 - (__local char *)&dummy ))

        {pre_decls}

        {c_decl}

        __kernel void get_size_and_offsets(__global uint *result)
        {{
            result[0] = sizeof({name});
            __local {name} dummy;
            {offset_code}
        }}
    """

    if context is None:
        context = cl.Context([device])

    queue = cl.CommandQueue(context)

    prg = cl.Program(context, src)
    knl = prg.build(devices=[device]).get_size_and_offsets

    import pyopencl.array

    result_buf = cl.array.empty(queue, 1+len(fields), np.uint32)
    knl(queue, (1,), (1,), result_buf.data)
    queue.finish()
    size_and_offsets = result_buf.get()

    size = int(size_and_offsets[0])

    offsets = size_and_offsets[1:]
    if any(ofs >= size for ofs in offsets):
        # offsets not plausible

        if dtype.itemsize == size:
            # If sizes match, use numpy's idea of the offsets.
            offsets = [dtype_and_offset[1]
                    for field_name, dtype_and_offset in fields]
        else:
            raise RuntimeError(
                    "OpenCL compiler reported offsetof() past sizeof() "
                    "for struct layout on '%s'. "
                    "This makes no sense, and it's usually indicates a "
                    "compiler bug. "
                    "Refusing to discover struct layout." % device)

    result_buf.data.release()
    del knl
    del prg
    del queue
    del context

    try:
        dtype_arg_dict = {
            "names": [field_name
                for field_name, (field_dtype, offset) in fields],
            "formats": [field_dtype
                for field_name, (field_dtype, offset) in fields],
            "offsets": [int(x) for x in offsets],
            "itemsize": int(size_and_offsets[0]),
            }
        dtype = np.dtype(dtype_arg_dict)
        if dtype.itemsize != size_and_offsets[0]:
            # "Old" versions of numpy (1.6.x?) silently ignore "itemsize". Boo.
            dtype_arg_dict["names"].append("_pycl_size_fixer")
            dtype_arg_dict["formats"].append(np.uint8)
            dtype_arg_dict["offsets"].append(int(size_and_offsets[0])-1)
            dtype = np.dtype(dtype_arg_dict)
    except NotImplementedError:
        def calc_field_type():
            # Rebuild the layout as a plain field list, inserting opaque
            # "V<n>" padding fields wherever the device offset leaves a gap.
            total_size = 0
            padding_count = 0
            for offset, (field_name, (field_dtype, _)) in zip(
                    offsets, fields, strict=True):
                if offset > total_size:
                    padding_count += 1
                    # The parentheses are required: "%" binds tighter than
                    # "-", so without them this computed
                    # ("V%d" % offset) - total_size and raised TypeError.
                    yield ("__pycl_padding%d" % padding_count,
                           "V%d" % (offset - total_size))
                yield field_name, field_dtype
                total_size = field_dtype.itemsize + offset
        dtype = np.dtype(list(calc_field_type()))

    assert dtype.itemsize == size_and_offsets[0]

    return dtype, c_decl
1207
+
1208
+
1209
@memoize
def dtype_to_c_struct(device, dtype):
    """Return the C ``typedef`` for the struct *dtype*, probed on *device*.

    Returns ``""`` for non-struct dtypes and for built-in OpenCL vector
    types. Asserts that the layout returned by
    :func:`match_dtype_to_c_struct` has exactly the fields, and the field
    dtypes and offsets, of *dtype*.
    """
    if dtype.fields is None:
        return ""

    import pyopencl.cltypes
    if dtype in pyopencl.cltypes.vec_type_to_scalar_and_count:
        # Vector types are built-in. Don't try to redeclare those.
        return ""

    matched_dtype, c_decl = match_dtype_to_c_struct(
            device, dtype_to_ctype(dtype), dtype)

    def dtypes_match():
        # Same field count, and every field entry of *dtype* reproduced.
        if len(dtype.fields) != len(matched_dtype.fields):
            return False
        return all(
            matched_dtype.fields[field_name] == field_info
            for field_name, field_info in dtype.fields.items())

    assert dtypes_match()

    return c_decl
1233
+
1234
+
1235
+ # {{{ code generation/templating helper
1236
+
1237
+ def _process_code_for_macro(code):
1238
+ code = code.replace("//CL//", "\n")
1239
+
1240
+ if "//" in code:
1241
+ raise RuntimeError("end-of-line comments ('//') may not be used in "
1242
+ "code snippets")
1243
+
1244
+ return code.replace("\n", " \\\n")
1245
+
1246
+
1247
+ class _SimpleTextTemplate:
1248
+ def __init__(self, txt):
1249
+ self.txt = txt
1250
+
1251
+ def render(self, context):
1252
+ return self.txt
1253
+
1254
+
1255
+ class _PrintfTextTemplate:
1256
+ def __init__(self, txt):
1257
+ self.txt = txt
1258
+
1259
+ def render(self, context):
1260
+ return self.txt % context
1261
+
1262
+
1263
+ class _MakoTextTemplate:
1264
+ def __init__(self, txt):
1265
+ from mako.template import Template
1266
+ self.template = Template(txt, strict_undefined=True)
1267
+
1268
+ def render(self, context):
1269
+ return self.template.render(**context)
1270
+
1271
+
1272
+ class _ArgumentPlaceholder:
1273
+ """A placeholder for subclasses of :class:`DtypedArgument`. This is needed
1274
+ because the concrete dtype of the argument is not known at template
1275
+ creation time--it may be a type alias that will only be filled in
1276
+ at run time. These types take the place of these proto-arguments until
1277
+ all types are known.
1278
+
1279
+ See also :class:`_TemplateRenderer.render_arg`.
1280
+ """
1281
+
1282
+ def __init__(self, typename, name, **extra_kwargs):
1283
+ self.typename = typename
1284
+ self.name = name
1285
+ self.extra_kwargs = extra_kwargs
1286
+
1287
+
1288
class _VectorArgPlaceholder(_ArgumentPlaceholder):
    """Placeholder that is rendered into a :class:`VectorArg`."""

    # Concrete class instantiated by _TemplateRenderer.render_arg.
    target_class = VectorArg
1290
+
1291
+
1292
class _ScalarArgPlaceholder(_ArgumentPlaceholder):
    """Placeholder that is rendered into a :class:`ScalarArg`."""

    # Concrete class instantiated by _TemplateRenderer.render_arg.
    target_class = ScalarArg
1294
+
1295
+
1296
+ class _TemplateRenderer:
1297
+ def __init__(self, template, type_aliases, var_values, context=None,
1298
+ options=None):
1299
+ self.template = template
1300
+ self.type_aliases = dict(type_aliases)
1301
+ self.var_dict = dict(var_values)
1302
+
1303
+ for name in self.var_dict:
1304
+ if name.startswith("macro_"):
1305
+ self.var_dict[name] = _process_code_for_macro(
1306
+ self.var_dict[name])
1307
+
1308
+ self.context = context
1309
+ self.options = options
1310
+
1311
+ def __call__(self, txt):
1312
+ if txt is None:
1313
+ return txt
1314
+
1315
+ result = self.template.get_text_template(txt).render(self.var_dict)
1316
+
1317
+ return str(result)
1318
+
1319
+ def get_rendered_kernel(self, txt, kernel_name):
1320
+ import pyopencl as cl
1321
+ prg = cl.Program(self.context, self(txt)).build(self.options)
1322
+
1323
+ kernel_name_prefix = self.var_dict.get("kernel_name_prefix")
1324
+ if kernel_name_prefix is not None:
1325
+ kernel_name = kernel_name_prefix+kernel_name
1326
+
1327
+ return getattr(prg, kernel_name)
1328
+
1329
+ def parse_type(self, typename):
1330
+ if isinstance(typename, str):
1331
+ try:
1332
+ return self.type_aliases[typename]
1333
+ except KeyError:
1334
+ from pyopencl.compyte.dtypes import NAME_TO_DTYPE
1335
+ return NAME_TO_DTYPE[typename]
1336
+ else:
1337
+ return np.dtype(typename)
1338
+
1339
+ def render_arg(self, arg_placeholder):
1340
+ return arg_placeholder.target_class(
1341
+ self.parse_type(arg_placeholder.typename),
1342
+ arg_placeholder.name,
1343
+ **arg_placeholder.extra_kwargs)
1344
+
1345
+ _C_COMMENT_FINDER = re.compile(r"/\*.*?\*/")
1346
+
1347
+ def render_argument_list(self, *arg_lists, **kwargs):
1348
+ with_offset = kwargs.pop("with_offset", False)
1349
+ if kwargs:
1350
+ raise TypeError("unrecognized kwargs: " + ", ".join(kwargs))
1351
+
1352
+ all_args = []
1353
+
1354
+ for arg_list in arg_lists:
1355
+ if isinstance(arg_list, str):
1356
+ arg_list = str(
1357
+ self.template
1358
+ .get_text_template(arg_list).render(self.var_dict))
1359
+ arg_list = self._C_COMMENT_FINDER.sub("", arg_list)
1360
+ arg_list = arg_list.replace("\n", " ")
1361
+
1362
+ all_args.extend(arg_list.split(","))
1363
+ else:
1364
+ all_args.extend(arg_list)
1365
+
1366
+ if with_offset:
1367
+ def vec_arg_factory(typename, name):
1368
+ return _VectorArgPlaceholder(typename, name, with_offset=True)
1369
+ else:
1370
+ vec_arg_factory = _VectorArgPlaceholder
1371
+
1372
+ from pyopencl.compyte.dtypes import parse_c_arg_backend
1373
+ parsed_args = []
1374
+ for arg in all_args:
1375
+ if isinstance(arg, str):
1376
+ arg = arg.strip()
1377
+ if not arg:
1378
+ continue
1379
+
1380
+ ph = parse_c_arg_backend(arg,
1381
+ _ScalarArgPlaceholder, vec_arg_factory,
1382
+ name_to_dtype=lambda x: x)
1383
+ parsed_arg = self.render_arg(ph)
1384
+
1385
+ elif isinstance(arg, Argument):
1386
+ parsed_arg = arg
1387
+ elif isinstance(arg, tuple):
1388
+ parsed_arg = ScalarArg(self.parse_type(arg[0]), arg[1])
1389
+ else:
1390
+ raise TypeError("unexpected argument type: %s" % type(arg))
1391
+
1392
+ parsed_args.append(parsed_arg)
1393
+
1394
+ return parsed_args
1395
+
1396
+ def get_type_decl_preamble(self, device, decl_type_names, arguments=None):
1397
+ cdl = _CDeclList(device)
1398
+
1399
+ for typename in decl_type_names:
1400
+ cdl.add_dtype(self.parse_type(typename))
1401
+
1402
+ if arguments is not None:
1403
+ cdl.visit_arguments(arguments)
1404
+
1405
+ for _, tv in sorted(self.type_aliases.items()):
1406
+ cdl.add_dtype(tv)
1407
+
1408
+ type_alias_decls = [
1409
+ "typedef {} {};".format(dtype_to_ctype(val), name)
1410
+ for name, val in sorted(self.type_aliases.items())
1411
+ ]
1412
+
1413
+ return cdl.get_declarations() + "\n" + "\n".join(type_alias_decls)
1414
+
1415
+
1416
class KernelTemplateBase:
    """Base class for kernel templates whose source text is rendered before
    being built.

    Subclasses implement :meth:`build_inner`. :meth:`build` caches its
    results in a per-instance dictionary.
    """

    def __init__(self, template_processor=None):
        # Default processor name (None/"none", "printf" or "mako") for texts
        # that do not select one with a leading //CL:<processor>// mark.
        self.template_processor = template_processor

        self.build_cache = {}
        # NOTE(review): presumably registered so clear_first_arg_caches()
        # can empty it -- confirm against that helper.
        _first_arg_dependent_caches.append(self.build_cache)

    def get_preamble(self):
        pass

    # Matches an optional leading "//CL//" or "//CL:<processor>//" mark.
    _TEMPLATE_PROCESSOR_PATTERN = re.compile(r"^//CL(?::([a-zA-Z0-9_]+))?//")

    @memoize_method
    def get_text_template(self, txt):
        """Return a template object for *txt*.

        A leading ``//CL:<processor>//`` mark selects the processor and is
        removed from the text. Without it, :attr:`template_processor` is
        used.

        :raises RuntimeError: if the selected processor is unknown.
        """
        proc_match = self._TEMPLATE_PROCESSOR_PATTERN.match(txt)
        tpl_processor = None

        if proc_match is not None:
            tpl_processor = proc_match.group(1)
            # chop off //CL// mark
            txt = txt[len(proc_match.group(0)):]
        if tpl_processor is None:
            tpl_processor = self.template_processor

        if tpl_processor is None or tpl_processor == "none":
            return _SimpleTextTemplate(txt)
        elif tpl_processor == "printf":
            return _PrintfTextTemplate(txt)
        elif tpl_processor == "mako":
            return _MakoTextTemplate(txt)
        else:
            # Report the processor actually selected. proc_match may be
            # None here, when the instance default is the unknown one.
            raise RuntimeError(
                    "unknown template processor '%s'" % tpl_processor)

    def get_renderer(self, type_aliases, var_values, context=None, options=None):
        """Return a :class:`_TemplateRenderer` bound to this template.

        *context* and *options* are forwarded to the renderer.
        :meth:`_TemplateRenderer.get_rendered_kernel` builds programs with
        them; before this fix they were silently dropped.
        """
        return _TemplateRenderer(self, type_aliases, var_values,
                context=context, options=options)

    def build_inner(self, context, *args, **kwargs):
        raise NotImplementedError

    def build(self, context, *args, **kwargs):
        """Provide caching for an :meth:`build_inner`."""

        cache_key = (context, args, tuple(sorted(kwargs.items())))
        try:
            return self.build_cache[cache_key]
        except KeyError:
            result = self.build_inner(context, *args, **kwargs)
            self.build_cache[cache_key] = result
            return result
1466
+
1467
+ # }}}
1468
+
1469
+
1470
+ # {{{ array_module
1471
+
1472
+ class _CLFakeArrayModule:
1473
+ def __init__(self, queue):
1474
+ self.queue = queue
1475
+
1476
+ @property
1477
+ def ndarray(self):
1478
+ from pyopencl.array import Array
1479
+ return Array
1480
+
1481
+ def dot(self, x, y):
1482
+ from pyopencl.array import dot
1483
+ return dot(x, y, queue=self.queue).get()
1484
+
1485
+ def vdot(self, x, y):
1486
+ from pyopencl.array import vdot
1487
+ return vdot(x, y, queue=self.queue).get()
1488
+
1489
+ def empty(self, shape, dtype, order="C"):
1490
+ from pyopencl.array import empty
1491
+ return empty(self.queue, shape, dtype, order=order)
1492
+
1493
+ def hstack(self, arrays):
1494
+ from pyopencl.array import hstack
1495
+ return hstack(arrays, self.queue)
1496
+
1497
+
1498
def array_module(a):
    """Return a numpy-like namespace that matches the array type of *a*.

    :mod:`numpy` itself is returned for :class:`numpy.ndarray`. For
    :class:`pyopencl.array.Array`, a :class:`_CLFakeArrayModule` bound to the
    array's queue is returned.

    :raises TypeError: for any other type.
    """
    if isinstance(a, np.ndarray):
        return np

    from pyopencl.array import Array
    if not isinstance(a, Array):
        raise TypeError("array type not understood: %s" % type(a))

    return _CLFakeArrayModule(a.queue)
1507
+
1508
+ # }}}
1509
+
1510
+
1511
def is_spirv(s: str | bytes) -> TypeIs[bytes]:
    """Return whether *s* is :class:`bytes` starting with the SPIR-V magic
    number, in either byte order."""
    if not isinstance(s, bytes):
        return False

    spirv_magic = b"\x07\x23\x02\x03"
    return s[:4] in (spirv_magic, spirv_magic[::-1])
1518
+
1519
+
1520
+ # {{{ numpy key types builder
1521
+
1522
class _NumpyTypesKeyBuilder(KeyBuilderBase):
    """Key builder that can also hash :class:`VectorArg` instances and
    numpy scalar type objects (subclasses of :class:`numpy.generic`)."""

    def update_for_VectorArg(self, key_hash, key):  # noqa: N802
        # Hash the element dtype, the argument name and the offset flag.
        self.rec(key_hash, key.dtype)
        self.update_for_str(key_hash, key.name)
        self.rec(key_hash, key.with_offset)

    def update_for_type(self, key_hash, key):
        if issubclass(key, np.generic):
            # numpy scalar types are identified by their class name
            self.update_for_str(key_hash, key.__name__)
            return

        # Report the offending type itself. type(key) would always be a
        # metaclass such as <class 'type'>, which tells the user nothing.
        raise TypeError("unsupported type for persistent hash keying: %s"
                % key)
1535
+
1536
+ # }}}
1537
+
1538
+
1539
# Public API of this module (sorted, each name listed once).
__all__ = [
    "AllocatorBase",
    "Argument",
    "DeferredAllocator",
    "DtypedArgument",
    "ImmediateAllocator",
    "MemoryPool",
    "OtherArg",
    "PooledBuffer",
    "PooledSVM",
    "SVMAllocator",
    "SVMPool",
    "ScalarArg",
    "TypeNameNotKnown",
    "VectorArg",
    "bitlog2",
    "clear_first_arg_caches",
    "dtype_to_ctype",
    "first_arg_dependent_memoize",
    "get_or_register_dtype",
    "parse_arg_list",
    "pytest_generate_tests_for_pyopencl",
    "register_dtype",
    ]
1564
+
1565
+ # vim: foldmethod=marker