pyopencl 2026.1.1__cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. pyopencl/.libs/libOpenCL-34a55fe4.so.1.0.0 +0 -0
  2. pyopencl/__init__.py +1995 -0
  3. pyopencl/_cl.cpython-314t-aarch64-linux-gnu.so +0 -0
  4. pyopencl/_cl.pyi +2009 -0
  5. pyopencl/_cluda.py +57 -0
  6. pyopencl/_monkeypatch.py +1104 -0
  7. pyopencl/_mymako.py +17 -0
  8. pyopencl/algorithm.py +1454 -0
  9. pyopencl/array.py +3530 -0
  10. pyopencl/bitonic_sort.py +245 -0
  11. pyopencl/bitonic_sort_templates.py +597 -0
  12. pyopencl/cache.py +553 -0
  13. pyopencl/capture_call.py +200 -0
  14. pyopencl/characterize/__init__.py +461 -0
  15. pyopencl/characterize/performance.py +240 -0
  16. pyopencl/cl/pyopencl-airy.cl +324 -0
  17. pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
  18. pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
  19. pyopencl/cl/pyopencl-bessel-y.cl +435 -0
  20. pyopencl/cl/pyopencl-complex.h +303 -0
  21. pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
  22. pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
  23. pyopencl/cl/pyopencl-random123/array.h +325 -0
  24. pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
  25. pyopencl/cl/pyopencl-random123/philox.cl +486 -0
  26. pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
  27. pyopencl/clmath.py +281 -0
  28. pyopencl/clrandom.py +412 -0
  29. pyopencl/cltypes.py +217 -0
  30. pyopencl/compyte/.gitignore +21 -0
  31. pyopencl/compyte/__init__.py +0 -0
  32. pyopencl/compyte/array.py +211 -0
  33. pyopencl/compyte/dtypes.py +314 -0
  34. pyopencl/compyte/pyproject.toml +49 -0
  35. pyopencl/elementwise.py +1288 -0
  36. pyopencl/invoker.py +417 -0
  37. pyopencl/ipython_ext.py +70 -0
  38. pyopencl/py.typed +0 -0
  39. pyopencl/reduction.py +829 -0
  40. pyopencl/scan.py +1921 -0
  41. pyopencl/tools.py +1680 -0
  42. pyopencl/typing.py +61 -0
  43. pyopencl/version.py +11 -0
  44. pyopencl-2026.1.1.dist-info/METADATA +108 -0
  45. pyopencl-2026.1.1.dist-info/RECORD +47 -0
  46. pyopencl-2026.1.1.dist-info/WHEEL +6 -0
  47. pyopencl-2026.1.1.dist-info/licenses/LICENSE +104 -0
pyopencl/algorithm.py ADDED
@@ -0,0 +1,1454 @@
1
+ """Algorithms built on scans."""
2
+ from __future__ import annotations
3
+
4
+
5
+ __copyright__ = """
6
+ Copyright 2011-2012 Andreas Kloeckner
7
+ Copyright 2017 Hao Gao
8
+ """
9
+
10
+ __license__ = """
11
+ Permission is hereby granted, free of charge, to any person
12
+ obtaining a copy of this software and associated documentation
13
+ files (the "Software"), to deal in the Software without
14
+ restriction, including without limitation the rights to use,
15
+ copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the
17
+ Software is furnished to do so, subject to the following
18
+ conditions:
19
+
20
+ The above copyright notice and this permission notice shall be
21
+ included in all copies or substantial portions of the Software.
22
+
23
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
25
+ OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
27
+ HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
28
+ WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
29
+ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
30
+ OTHER DEALINGS IN THE SOFTWARE.
31
+ """
32
+
33
+ from dataclasses import dataclass
34
+ from typing import TYPE_CHECKING
35
+
36
+ import numpy as np
37
+ from mako.template import Template
38
+
39
+ from pytools import memoize, memoize_method
40
+
41
+ import pyopencl as cl
42
+ import pyopencl.array as cl_array
43
+ from pyopencl.scan import GenericScanKernel, ScanTemplate
44
+ from pyopencl.tools import dtype_to_ctype, get_arg_offset_adjuster_code
45
+
46
+
47
+ if TYPE_CHECKING:
48
+ from pyopencl.elementwise import ElementwiseKernel
49
+
50
+
51
+ # {{{ "extra args" handling utility
52
+
53
def _extract_extra_args_types_values(extra_args):
    """Split user-supplied *extra_args* into kernel argument descriptors,
    argument values, and events to wait for.

    :arg extra_args: an iterable of ``(name, value)`` pairs, or *None*
        (treated as empty). Each *value* must be a
        :class:`pyopencl.array.Array` or a :class:`numpy.generic` scalar.
    :returns: a tuple ``(types, values, wait_for)``. *types* is a tuple of
        :class:`pyopencl.tools.VectorArg` / :class:`pyopencl.tools.ScalarArg`
        descriptors. *values* is the list of corresponding argument values.
        *wait_for* collects the ``events`` of all array arguments.
    :raises RuntimeError: if a value is neither an array nor a numpy scalar.
    """
    if extra_args is None:
        extra_args = []
    from pyopencl.tools import ScalarArg, VectorArg

    extra_args_types = []
    extra_args_values = []
    extra_wait_for = []
    for name, val in extra_args:
        if isinstance(val, cl_array.Array):
            # Arrays are passed without a separate offset argument. Their
            # pending events must complete before the kernel reads them.
            extra_args_types.append(VectorArg(val.dtype, name, with_offset=False))
            extra_args_values.append(val)
            extra_wait_for.extend(val.events)
        elif isinstance(val, np.generic):
            extra_args_types.append(ScalarArg(val.dtype, name))
            extra_args_values.append(val)
        else:
            # *name* is a string, so it must be formatted with %s. A %d here
            # would raise TypeError and hide the intended error message.
            raise RuntimeError("argument '%s' not understood" % name)

    return tuple(extra_args_types), extra_args_values, extra_wait_for
73
+
74
+ # }}}
75
+
76
+
77
+ # {{{ copy_if
78
+
79
# printf-style scan template behind copy_if(). Each element contributes 1
# if %(predicate)s holds (else 0), and these are summed by the scan.
# An element is selected exactly when the inclusive scan value ``item``
# differs from ``prev_item``. It is then written to out[item-1], so the
# selected elements keep their relative order. The last element
# (i+1 == N) stores the total number selected in *count.
_copy_if_template = ScanTemplate(
    arguments="item_t *ary, item_t *out, scan_t *count",
    input_expr="(%(predicate)s) ? 1 : 0",
    scan_expr="a+b", neutral="0",
    output_statement="""
if (prev_item != item) out[item-1] = ary[i];
if (i+1 == N) *count = item;
""",
    template_processor="printf")
88
+
89
+
90
def copy_if(ary, predicate, extra_args=None, preamble="", queue=None, wait_for=None):
    """Copy the elements of *ary* satisfying *predicate* to an output array.

    :arg predicate: a C expression evaluating to a ``bool``, represented as a string.
        The value to test is available as ``ary[i]``, and if the expression evaluates
        to ``true``, then this value ends up in the output.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array, *count*
        is an on-device scalar (fetch to host with ``count.get()``) indicating
        how many elements satisfied *predicate*, and *event* is a
        :class:`pyopencl.Event` for dependency management. *out* is allocated
        to the same length as *ary*, but only the first *count* entries carry
        meaning.

    .. versionadded:: 2013.1
    """
    # The count type must be wide enough to hold len(ary).
    if len(ary) > np.iinfo(np.int32).max:
        scan_dtype = np.int64
    else:
        scan_dtype = np.int32

    extra_args_types, extra_args_values, extra_wait_for = \
        _extract_extra_args_types_values(extra_args)

    # Accept any iterable of events (e.g. a tuple). ``wait_for + list``
    # would raise TypeError for non-list sequences.
    wait_for = [] if wait_for is None else list(wait_for)
    wait_for.extend(extra_wait_for)

    knl = _copy_if_template.build(ary.context,
        type_aliases=(("scan_t", scan_dtype), ("item_t", ary.dtype)),
        var_values=(("predicate", predicate),),
        more_preamble=preamble, more_arguments=extra_args_types)
    out = cl_array.empty_like(ary)
    # A new 0-d array of scan_dtype that receives the number of selected
    # elements (data=None requests new storage).
    count = ary._new_with_changes(data=None, offset=0,
        shape=(), strides=(), dtype=scan_dtype)

    evt = knl(ary, out, count, *extra_args_values,
        queue=queue, wait_for=wait_for)

    return out, count, evt
132
+
133
+ # }}}
134
+
135
+
136
+ # {{{ remove_if
137
+
138
def remove_if(ary, predicate, extra_args=None, preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* not satisfying *predicate* to an output array.

    This delegates to :func:`copy_if` with the C predicate wrapped as
    ``!(predicate)``.

    :arg predicate: a C expression evaluating to a ``bool``, represented as a string.
        The value to test is available as ``ary[i]``, and if the expression evaluates
        to ``false``, then this value ends up in the output.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array, *count*
        is an on-device scalar (fetch to host with ``count.get()``) indicating
        how many elements did not satisfy *predicate*, and *event* is a
        :class:`pyopencl.Event` for dependency management.

    .. versionadded:: 2013.1
    """
    negated_predicate = "!(%s)" % predicate
    return copy_if(
        ary,
        negated_predicate,
        extra_args=extra_args,
        preamble=preamble,
        queue=queue,
        wait_for=wait_for,
    )
157
+
158
+ # }}}
159
+
160
+
161
+ # {{{ partition
162
+
163
# printf-style scan template behind partition(). The scan counts the
# elements satisfying %(predicate)s. An element that satisfies it (item
# changed from prev_item) goes to out_true[item-1]. Otherwise it goes to
# out_false[i-item]: item trues have been seen among indices 0..i, so
# exactly i-item falses precede this one. Both outputs therefore keep the
# input order. The last element stores the number of trues in *count_true.
_partition_template = ScanTemplate(
    arguments=(
        "item_t *ary, item_t *out_true, item_t *out_false, "
        "scan_t *count_true"),
    input_expr="(%(predicate)s) ? 1 : 0",
    scan_expr="a+b", neutral="0",
    output_statement="""//CL//
if (prev_item != item)
out_true[item-1] = ary[i];
else
out_false[i-item] = ary[i];
if (i+1 == N) *count_true = item;
""",
    template_processor="printf")
177
+
178
+
179
def partition(ary, predicate, extra_args=None, preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* into one of two arrays depending on whether
    they satisfy *predicate*.

    :arg predicate: a C expression evaluating to a ``bool``, represented as a string.
        The value to test is available as ``ary[i]``.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out_true, out_false, count, event)* where *count*
        is an on-device scalar (fetch to host with ``count.get()``) indicating
        how many elements satisfied the predicate, and *event* is a
        :class:`pyopencl.Event` for dependency management. Both output
        arrays are allocated to the length of *ary*. Only the first *count*
        entries of *out_true* and the first ``len(ary) - count`` entries of
        *out_false* carry meaning.

    .. versionadded:: 2013.1
    """
    # The count type must be wide enough to hold len(ary).
    if len(ary) > np.iinfo(np.uint32).max:
        scan_dtype = np.uint64
    else:
        scan_dtype = np.uint32

    extra_args_types, extra_args_values, extra_wait_for = \
        _extract_extra_args_types_values(extra_args)

    # Accept any iterable of events (e.g. a tuple). ``wait_for + list``
    # would raise TypeError for non-list sequences.
    wait_for = [] if wait_for is None else list(wait_for)
    wait_for.extend(extra_wait_for)

    knl = _partition_template.build(
        ary.context,
        type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
        var_values=(("predicate", predicate),),
        more_preamble=preamble, more_arguments=extra_args_types)

    out_true = cl_array.empty_like(ary)
    out_false = cl_array.empty_like(ary)
    # A new 0-d array of scan_dtype that receives the number of elements
    # satisfying the predicate.
    count = ary._new_with_changes(data=None, offset=0,
        shape=(), strides=(), dtype=scan_dtype)

    evt = knl(ary, out_true, out_false, count, *extra_args_values,
        queue=queue, wait_for=wait_for)

    return out_true, out_false, count, evt
223
+
224
+ # }}}
225
+
226
+
227
+ # {{{ unique
228
+
229
# printf-style scan template behind unique(). Each element is fetched
# together with its predecessor (ary_im1, fetch offset -1). It counts as a
# new run start (input 1) when i == 0 or when IS_EQUAL_EXPR reports it
# unequal to its predecessor. Run starts are written to out[item-1], and
# the last element stores the number of retained elements in *count_unique.
# IS_EQUAL_EXPR(a, b) is a macro defined in the preamble from the
# user-supplied %(macro_is_equal_expr)s.
_unique_template = ScanTemplate(
    arguments="item_t *ary, item_t *out, scan_t *count_unique",
    input_fetch_exprs=[
        ("ary_im1", "ary", -1),
        ("ary_i", "ary", 0),
        ],
    input_expr="(i == 0) || (IS_EQUAL_EXPR(ary_im1, ary_i) ? 0 : 1)",
    scan_expr="a+b", neutral="0",
    output_statement="""
if (prev_item != item) out[item-1] = ary[i];
if (i+1 == N) *count_unique = item;
""",
    preamble="#define IS_EQUAL_EXPR(a, b) %(macro_is_equal_expr)s\n",
    template_processor="printf")
243
+
244
+
245
def unique(ary, is_equal_expr="a == b", extra_args=None, preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* into the output if *is_equal_expr*, applied to the
    array element and its predecessor, yields false.

    Works like the UNIX command :program:`uniq`, with a potentially custom
    comparison. This operation is often used on sorted sequences.

    :arg is_equal_expr: a C expression evaluating to a ``bool``,
        represented as a string. The elements being compared are
        available as ``a`` and ``b``. If this expression yields ``false``, the
        two are considered distinct.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array, *count*
        is an on-device scalar (fetch to host with ``count.get()``) indicating
        how many elements were retained (the first element plus every element
        distinct from its predecessor), and *event* is a
        :class:`pyopencl.Event` for dependency management. *out* is allocated
        to the same length as *ary*, but only the first *count* entries carry
        meaning.

    .. versionadded:: 2013.1
    """

    # The count type must be wide enough to hold len(ary).
    if len(ary) > np.iinfo(np.uint32).max:
        scan_dtype = np.uint64
    else:
        scan_dtype = np.uint32

    extra_args_types, extra_args_values, extra_wait_for = \
        _extract_extra_args_types_values(extra_args)

    # Accept any iterable of events (e.g. a tuple). ``wait_for + list``
    # would raise TypeError for non-list sequences.
    wait_for = [] if wait_for is None else list(wait_for)
    wait_for.extend(extra_wait_for)

    knl = _unique_template.build(
        ary.context,
        type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
        var_values=(("macro_is_equal_expr", is_equal_expr),),
        more_preamble=preamble, more_arguments=extra_args_types)

    out = cl_array.empty_like(ary)
    # A new 0-d array of scan_dtype that receives the number of retained
    # elements.
    count = ary._new_with_changes(data=None, offset=0,
        shape=(), strides=(), dtype=scan_dtype)

    evt = knl(ary, out, count, *extra_args_values,
        queue=queue, wait_for=wait_for)

    return out, count, evt
294
+
295
+ # }}}
296
+
297
+
298
+ # {{{ radix_sort
299
+
300
def to_bin(n):
    """Return the binary digits of the non-negative integer *n* as a string.

    Unlike :func:`bin`, the result has no ``0b`` prefix. ``to_bin(0)``
    returns the empty string, which :func:`_padded_bin` pads to the
    requested width.

    :raises TypeError: if *n* is not an integer.
    :raises ValueError: if *n* is negative. The earlier shift-based loop
        never terminated for negative input, because ``-1 >> 1 == -1``.
    """
    import operator

    n = operator.index(n)
    if n < 0:
        raise ValueError(f"to_bin requires a non-negative integer, got {n}")
    if n == 0:
        # Kept for backward compatibility with the historical behavior.
        return ""
    return format(n, "b")
308
+
309
+
310
def _padded_bin(i, nbits):
    """Return the binary digits of *i* left-padded with ``"0"`` to at least
    *nbits* characters. Longer results are returned unchanged."""
    digits = to_bin(i)
    return digits.rjust(nbits, "0")
315
+
316
+
317
@memoize
def _make_sort_scan_type(device, bits, index_dtype):
    """Build and register the struct type used as the radix-sort scan value.

    The struct has one *index_dtype* counter per bucket, i.e. ``2**bits``
    fields named ``c`` followed by the bucket number in *bits*-wide binary
    (e.g. ``c00`` .. ``c11`` for *bits* = 2).

    :returns: a tuple ``(name, dtype, c_decl)``: the C type name, the
        registered :class:`numpy.dtype` (as matched to *device* by
        :func:`pyopencl.tools.match_dtype_to_c_struct`), and the C
        declaration of the struct.
    """
    type_name = "pyopencl_sort_scan_%s_%dbits_t" % (
        index_dtype.type.__name__, bits)

    struct_fields = [
        ("c%s" % _padded_bin(bucket, bits), index_dtype)
        for bucket in range(2**bits)
    ]
    raw_dtype = np.dtype(struct_fields)

    from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct
    matched_dtype, c_decl = match_dtype_to_c_struct(device, type_name, raw_dtype)

    registered_dtype = get_or_register_dtype(type_name, matched_dtype)
    return type_name, registered_dtype, c_decl
333
+
334
+
335
+ # {{{ types, helpers preamble
336
+
337
# Mako template for the C preamble shared by the radix-sort scan kernel.
# It provides:
# - typedefs for the scan struct (scan_t), key (key_t) and index (index_t);
# - a dbg_printf macro that only prints when DEBUG is defined;
# - get_count(s, mnr), returning the counter field of s for the runtime
#   bucket number mnr, via the branch expression generated by the
#   get_count_branch callable passed in at render time;
# - BIN_NR(key), which extracts the *bits*-wide bucket number of key
#   starting at bit base_bit. base_bit must be in scope wherever BIN_NR is
#   expanded.
RADIX_SORT_PREAMBLE_TPL = Template(r"""//CL//
typedef ${scan_ctype} scan_t;
typedef ${key_ctype} key_t;
typedef ${index_ctype} index_t;

// #define DEBUG
#ifdef DEBUG
#define dbg_printf(ARGS) printf ARGS
#else
#define dbg_printf(ARGS) /* */
#endif

index_t get_count(scan_t s, int mnr)
{
return ${get_count_branch("")};
}

#define BIN_NR(key_arg) ((key_arg >> base_bit) & ${2**bits - 1})

""", strict_undefined=True)
357
+
358
+ # }}}
359
+
360
+ # {{{ scan helpers
361
+
362
# Mako template for the scan-specific C helpers of the radix sort:
# - scan_t_neutral(): all bucket counters zero (the scan's neutral element);
# - scan_t_from_value(key, base_bit, i): a one-hot scan_t with counter 1 in
#   the bucket BIN_NR(key) and 0 elsewhere;
# - scan_t_add(a, b, ...): field-wise sum of the bucket counters. The
#   across_seg_boundary argument is accepted but not used.
# This is a raw string, so the "\n" in dbg_printf reaches the C source
# unchanged.
RADIX_SORT_SCAN_PREAMBLE_TPL = Template(r"""//CL//
scan_t scan_t_neutral()
{
scan_t result;
%for mnr in range(2**bits):
result.c${padded_bin(mnr, bits)} = 0;
%endfor
return result;
}

// considers bits (base_bit+bits-1, ..., base_bit)
scan_t scan_t_from_value(
key_t key,
int base_bit,
int i
)
{
// extract relevant bit range
key_t bin_nr = BIN_NR(key);

dbg_printf(("i: %d key:%d bin_nr:%d\n", i, key, bin_nr));

scan_t result;
%for mnr in range(2**bits):
result.c${padded_bin(mnr, bits)} = (bin_nr == ${mnr});
%endfor

return result;
}

scan_t scan_t_add(scan_t a, scan_t b, bool across_seg_boundary)
{
%for mnr in range(2**bits):
<% field = "c"+padded_bin(mnr, bits) %>
b.${field} = a.${field} + b.${field};
%endfor

return b;
}
""", strict_undefined=True)
402
+
403
# Mako template for the output statement of the radix-sort scan. For
# element i in bucket my_bin_nr, the target index combines two terms:
# - the sum of all counters of lower-numbered buckets taken from last_item
#   (presumably the scan value of the final element, supplied by the scan
#   kernel; confirm against pyopencl.scan);
# - the element's inclusive rank within its own bucket,
#   get_count(item, my_bin_nr), minus one.
# Every array listed in sort_arg_names is then scattered to
# sorted_<name>[tgt_idx].
RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL//
{
key_t key = ${key_expr};
key_t my_bin_nr = BIN_NR(key);

index_t previous_bins_size = 0;
%for mnr in range(2**bits):
previous_bins_size +=
(my_bin_nr > ${mnr})
? last_item.c${padded_bin(mnr, bits)}
: 0;
%endfor

index_t tgt_idx =
previous_bins_size
+ get_count(item, my_bin_nr) - 1;

%for arg_name in sort_arg_names:
sorted_${arg_name}[tgt_idx] = ${arg_name}[i];
%endfor
}
""", strict_undefined=True)
425
+
426
+ # }}}
427
+
428
+
429
+ # {{{ driver
430
+
431
class RadixSort:
    """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`__
    on the compute device.

    Each pass is one scan over the data. It buckets every element by
    *bits_at_a_time* bits of its key, starting from the least-significant
    bits, and scatters the sorted arrays to their new positions.

    .. seealso:: :class:`pyopencl.bitonic_sort.BitonicSort`

    .. versionadded:: 2013.1
    """
    def __init__(self, context, arguments, key_expr, sort_arg_names,
            bits_at_a_time=2, index_dtype=np.int32, key_dtype=np.uint32,
            scan_kernel=GenericScanKernel, options=None):
        """
        :arg arguments: A string of comma-separated C argument declarations.
            All types used here must be known to PyOpenCL.
            (see :func:`pyopencl.tools.get_or_register_dtype`).
        :arg key_expr: An integer-valued C expression returning the
            key based on which the sort is performed. The array index
            for which the key is to be computed is available as ``i``.
            The expression may refer to any of the *arguments*.
        :arg sort_arg_names: A list of argument names whose corresponding
            array arguments will be sorted according to *key_expr*.
        :arg bits_at_a_time: number of key bits handled per scan pass. Each
            pass uses ``2**bits_at_a_time`` buckets.
        :arg index_dtype: dtype of the per-bucket counters and hence of the
            computed target indices.
        :arg key_dtype: dtype of the value produced by *key_expr*. By
            default, all of its bits are sorted.
        :arg scan_kernel: the scan kernel class to instantiate (defaults to
            :class:`pyopencl.scan.GenericScanKernel`).
        :arg options: compiler options passed on to *scan_kernel*.
        """

        # {{{ arg processing

        from pyopencl.tools import parse_arg_list
        self.arguments = parse_arg_list(arguments)
        del arguments

        self.sort_arg_names = sort_arg_names
        self.bits = int(bits_at_a_time)
        self.index_dtype = np.dtype(index_dtype)
        self.key_dtype = np.dtype(key_dtype)

        self.options = options

        # }}}

        # {{{ kernel creation

        scan_ctype, scan_dtype, scan_t_cdecl = \
            _make_sort_scan_type(context.devices[0], self.bits, self.index_dtype)

        from pyopencl.tools import ScalarArg, VectorArg
        # The "sorted_*" output arrays follow the order in which the sorted
        # arrays appear in *arguments*. __call__ relies on this ordering.
        scan_arguments = (
            list(self.arguments)
            + [VectorArg(arg.dtype, "sorted_"+arg.name) for arg in self.arguments
                if arg.name in sort_arg_names]
            + [ScalarArg(np.int32, "base_bit")])

        def get_count_branch(known_bits):
            # Emit a nested ternary expression that selects field s.c<mnr>
            # for the runtime bucket number mnr, deciding one bit of mnr
            # (most significant first) per nesting level.
            if len(known_bits) == self.bits:
                return "s.c%s" % known_bits

            boundary_mnr = known_bits + "1" + (self.bits-len(known_bits)-1)*"0"

            return ("((mnr < {}) ? {} : {})".format(
                int(boundary_mnr, 2),
                get_count_branch(known_bits+"0"),
                get_count_branch(known_bits+"1")))

        codegen_args = {
            "bits": self.bits,
            "key_ctype": dtype_to_ctype(self.key_dtype),
            "key_expr": key_expr,
            "index_ctype": dtype_to_ctype(self.index_dtype),
            "index_type_max": np.iinfo(self.index_dtype).max,
            "padded_bin": _padded_bin,
            "scan_ctype": scan_ctype,
            "sort_arg_names": sort_arg_names,
            "get_count_branch": get_count_branch,
            }

        preamble = scan_t_cdecl+RADIX_SORT_PREAMBLE_TPL.render(**codegen_args)
        scan_preamble = preamble \
            + RADIX_SORT_SCAN_PREAMBLE_TPL.render(**codegen_args)

        self.scan_kernel = scan_kernel(
            context, scan_dtype,
            arguments=scan_arguments,
            input_expr="scan_t_from_value(%s, base_bit, i)" % key_expr,
            scan_expr="scan_t_add(a, b, across_seg_boundary)",
            neutral="scan_t_neutral()",
            output_statement=RADIX_SORT_OUTPUT_STMT_TPL.render(**codegen_args),
            preamble=scan_preamble, options=self.options)

        # Remember the position of the *first* array argument. __call__ takes
        # the length, allocator and default queue from it. Without the
        # break this would record the last array argument instead.
        for i, arg in enumerate(self.arguments):
            if isinstance(arg, VectorArg):
                self.first_array_arg_idx = i
                break

        # }}}

    def __call__(self, *args, **kwargs):
        """Run the radix sort. In addition to *args* which must match the
        *arguments* specification on the constructor, the following
        keyword arguments are supported:

        :arg key_bits: specify how many bits (starting from least-significant)
            there are in the key. Defaults to the bit width of *key_dtype*.
        :arg allocator: See the *allocator* argument of :func:`pyopencl.array.empty`.
        :arg queue: A :class:`pyopencl.CommandQueue`, defaulting to the
            one from the first argument array.
        :arg wait_for: |explain-waitfor|
        :returns: A tuple ``(sorted, event)``. *sorted* is a list of sorted
            copies of the arrays named in *sort_arg_names*, in the order of that
            list. *event* is a :class:`pyopencl.Event` for dependency management.
        :raises ValueError: if *key_bits* is not positive.
        """

        wait_for = kwargs.pop("wait_for", None)

        # {{{ run control

        key_bits = kwargs.pop("key_bits", None)
        if key_bits is None:
            key_bits = int(np.iinfo(self.key_dtype).bits)
        if key_bits <= 0:
            # With no pass there would be no event to return.
            raise ValueError("key_bits must be positive, got %s" % key_bits)

        first_array = args[self.first_array_arg_idx]
        n = len(first_array)

        allocator = kwargs.pop("allocator", None)
        if allocator is None:
            allocator = first_array.allocator

        queue = kwargs.pop("queue", None)
        if queue is None:
            queue = first_array.queue

        args = list(args)

        # Positions (within self.arguments) of the arrays being sorted, in the
        # same order as the "sorted_*" kernel arguments built in __init__.
        sorted_arg_indices = [
            i for i, arg_descr in enumerate(self.arguments)
            if arg_descr.name in self.sort_arg_names]

        base_bit = 0
        while base_bit < key_bits:
            sorted_args = [
                cl_array.empty(queue, n, self.arguments[i].dtype,
                    allocator=allocator)
                for i in sorted_arg_indices]

            scan_args = args + sorted_args + [base_bit]

            last_evt = self.scan_kernel(*scan_args,
                    queue=queue, wait_for=wait_for)
            wait_for = [last_evt]

            # The sorted outputs become the inputs of the next pass.
            # sorted_args is ordered like *arguments*, not like
            # sort_arg_names. Indexing it by position in sort_arg_names
            # would swap arrays whenever the two orders differ.
            for i, sorted_ary in zip(sorted_arg_indices, sorted_args, strict=True):
                args[i] = sorted_ary

            base_bit += self.bits

        # }}}

        sorted_by_name = {
            self.arguments[i].name: args[i] for i in sorted_arg_indices}
        return [sorted_by_name[name] for name in self.sort_arg_names
                if name in sorted_by_name], last_evt
583
+
584
+ # }}}
585
+
586
+ # }}}
587
+
588
+ # }}}
589
+
590
+
591
+ # {{{ generic parallel list builder
592
+
593
+ # {{{ kernel template
594
+
595
# Mako template for both stages of the ListOfListsBuilder kernels.
#
# Count stage (is_count_stage true):
# - APPEND_<name> increments the per-object counter reached through
#   plb_loc_<name>_count.
# - The counter is stored to plb_<name>_count[i] when that array is
#   non-NULL.
#
# Write stage:
# - APPEND_<name> writes into plb_<name>_list, starting at
#   plb_<name>_start_index[i] (or at the compressed index for lists in
#   eliminate_empty_output_lists). The start index is 0 if no start-index
#   array is given.
# - Lists in count_sharing keep no index of their own. They write to
#   (mother's current index) - 1.
#
# Loop shape:
# - Without do_not_vectorize, work items walk the objects in steps of
#   get_global_size(0).
# - With do_not_vectorize, each work item handles chunks of 128 objects,
#   and the kernel requires a work group size of (1, 1, 1).
#
# This is not a raw string, so each backslash-newline in the APPEND macro
# definitions below is consumed by Python. The resulting macro is a
# single line of C.
_LIST_BUILDER_TEMPLATE = Template("""//CL//
% if double_support:
#if __OPENCL_C_VERSION__ < 120
#pragma OPENCL EXTENSION cl_khr_fp64: enable
#endif
#define PYOPENCL_DEFINE_CDOUBLE
% endif

#include <pyopencl-complex.h>

${preamble}

// {{{ declare helper macros for user interface

typedef ${index_type} index_type;

%if is_count_stage:
#define PLB_COUNT_STAGE

%for name, dtype in list_names_and_dtypes:
%if name in count_sharing:
#define APPEND_${name}(value) { /* nothing */ }
%else:
#define APPEND_${name}(value) { ++(*plb_loc_${name}_count); }
%endif
%endfor
%else:
#define PLB_WRITE_STAGE

%for name, dtype in list_names_and_dtypes:
%if name in count_sharing:
#define APPEND_${name}(value) \
{ plb_${name}_list[(*plb_${count_sharing[name]}_index) - 1] \
= value; }
%else:
#define APPEND_${name}(value) \
{ plb_${name}_list[(*plb_${name}_index)++] = value; }
%endif
%endfor
%endif

#define LIST_ARG_DECL ${user_list_arg_decl}
#define LIST_ARGS ${user_list_args}
#define USER_ARG_DECL ${user_arg_decl_no_offset}
#define USER_ARGS ${user_args_no_offset}

// }}}

${generate_template}

// {{{ kernel entry point

__kernel
%if do_not_vectorize:
__attribute__((reqd_work_group_size(1, 1, 1)))
%endif
void ${kernel_name}(
${kernel_list_arg_decl} ${user_arg_decl_with_offset} index_type n)

{
%if not do_not_vectorize:
int lid = get_local_id(0);
index_type gsize = get_global_size(0);
index_type work_group_start = get_local_size(0)*get_group_id(0);
for (index_type i = work_group_start + lid; i < n; i += gsize)
%else:
const int chunk_size = 128;
index_type chunk_base = get_global_id(0)*chunk_size;
index_type gsize = get_global_size(0);
for (; chunk_base < n; chunk_base += gsize*chunk_size)
for (index_type i = chunk_base; i < min(n, chunk_base+chunk_size); ++i)
%endif
{
%if is_count_stage:
%for name, dtype in list_names_and_dtypes:
%if name not in count_sharing:
index_type plb_loc_${name}_count = 0;
%endif
%endfor
%else:
%for name, dtype in list_names_and_dtypes:
%if name not in count_sharing:
index_type plb_${name}_index;
if (plb_${name}_start_index)
%if name in eliminate_empty_output_lists:
plb_${name}_index =
plb_${name}_start_index[
${name}_compressed_indices[i]
];
%else:
plb_${name}_index = plb_${name}_start_index[i];
%endif
else
plb_${name}_index = 0;
%endif
%endfor
%endif

${arg_offset_adjustment}
generate(${kernel_list_arg_values} USER_ARGS i);

%if is_count_stage:
%for name, dtype in list_names_and_dtypes:
%if name not in count_sharing:
if (plb_${name}_count)
plb_${name}_count[i] = plb_loc_${name}_count;
%endif
%endfor
%endif
}
}

// }}}

""", strict_undefined=True)
710
+
711
+ # }}}
712
+
713
+
714
+ def _get_arg_decl(arg_list):
715
+ result = ""
716
+ for arg in arg_list:
717
+ result += arg.declarator() + ", "
718
+
719
+ return result
720
+
721
+
722
+ def _get_arg_list(arg_list, prefix=""):
723
+ result = ""
724
+ for arg in arg_list:
725
+ result += prefix + arg.name + ", "
726
+
727
+ return result
728
+
729
+
730
@dataclass
class BuiltList:
    """Result record for one list produced by :class:`ListOfListsBuilder`.

    NOTE(review): the field meanings noted below are inferred from the
    builder's kernels and usage example. Confirm them against
    ``ListOfListsBuilder.__call__``.
    """
    # Total number of entries in the list (the usage example in
    # ListOfListsBuilder checks ``count == 3000``), or None.
    count: int | None
    # Presumably the per-object start offsets into *lists* (prefix sums of
    # the per-object counts). TODO confirm.
    starts: cl_array.Array | None
    # Presumably the concatenated entries of all per-object lists.
    lists: cl_array.Array | None = None
    # The remaining fields presumably apply only to lists named in
    # eliminate_empty_output_lists. They mirror the outputs of the compress
    # scan kernel (number of non-empty lists, their original indices, and
    # the original-index -> compressed-index map). Verify.
    num_nonempty_lists: int | None = None
    nonempty_indices: cl_array.Array | None = None
    compressed_indices: cl_array.Array | None = None
738
+
739
+
740
+ class ListOfListsBuilder:
741
+ """Generates and executes code to produce a large number of variable-size
742
+ lists, simply.
743
+
744
+ .. note:: This functionality is provided as a preview. Its interface
745
+ is subject to change until this notice is removed.
746
+
747
+ .. versionadded:: 2013.1
748
+
749
+ Here's a usage example::
750
+
751
+ from pyopencl.algorithm import ListOfListsBuilder
752
+ builder = ListOfListsBuilder(context, [("mylist", np.int32)], \"\"\"
753
+ void generate(LIST_ARG_DECL USER_ARG_DECL index_type i)
754
+ {
755
+ int count = i % 4;
756
+ for (int j = 0; j < count; ++j)
757
+ {
758
+ APPEND_mylist(count);
759
+ }
760
+ }
761
+ \"\"\", arg_decls=[])
762
+
763
+ result, event = builder(queue, 2000)
764
+
765
+ inf = result["mylist"]
766
+ assert inf.count == 3000
767
+ assert (inf.list.get()[-6:] == [1, 2, 2, 3, 3, 3]).all()
768
+
769
+ The function ``generate`` above is called once for each "input object".
770
+ Each input object can then generate zero or more list entries.
771
+ The number of these input objects is given to :meth:`__call__` as *n_objects*.
772
+ List entries are generated by calls to ``APPEND_<list name>(value)``.
773
+ Multiple lists may be generated at once.
774
+
775
+ .. automethod:: __init__
776
+ .. automethod:: __call__
777
+ """
778
def __init__(self, context, list_names_and_dtypes, generate_template,
        arg_decls, count_sharing=None, devices=None,
        name_prefix="plb_build_list", options=None, preamble="",
        debug=False, complex_kernel=False,
        eliminate_empty_output_lists=False):
    """
    :arg context: A :class:`pyopencl.Context`.
    :arg list_names_and_dtypes: a list of ``(name, dtype)`` tuples
        indicating the lists to be built.
    :arg generate_template: a snippet of C as described below
    :arg arg_decls: A string of comma-separated C argument declarations.
    :arg count_sharing: A mapping consisting of ``(child, mother)``
        indicating that ``mother`` and ``child`` will always have the
        same number of indices, and the ``APPEND`` to ``mother``
        will always happen *before* the ``APPEND`` to the child.
    :arg devices: the devices to build the kernels for. Defaults to all
        devices of *context*.
    :arg name_prefix: the name prefix to use for the compiled kernels
    :arg options: OpenCL compilation options for kernels using
        *generate_template*.
    :arg preamble: extra C source for the generated kernels.
    :arg debug: a debugging flag, stored on the builder.
    :arg complex_kernel: If *True*, prevents vectorization on CPUs.
    :arg eliminate_empty_output_lists: A Python list of list names
        for which the empty output lists are eliminated. For backward
        compatibility, *True* selects all lists and *False* selects none.
    :raises ValueError: if *eliminate_empty_output_lists* names a list that
        is not in *list_names_and_dtypes*.

    *generate_template* may use the following C macros/identifiers:

    * ``index_type``: expands to C identifier for the index type used
      for the calculation
    * ``USER_ARG_DECL``: expands to the C declarator for ``arg_decls``
    * ``USER_ARGS``: a list of C argument values corresponding to
      ``USER_ARG_DECL``
    * ``LIST_ARG_DECL``: expands to a C argument list representing the
      data for the output lists. These are prefixed with
      ``"plb_"`` so as to not interfere with user-provided names.
    * ``LIST_ARGS``: a list of C argument values corresponding to
      ``LIST_ARG_DECL``
    * ``APPEND_name(entry)``: inserts ``entry`` into the list ``name``.
      *entry* must be a valid C expression of the correct type.

    All argument-list related macros have a trailing comma included
    if they are non-empty.

    *generate_template* must supply a function (note that the list
    arguments come first):

    .. code-block:: c

        void generate(LIST_ARG_DECL USER_ARG_DECL index_type i)
        {
            APPEND_mylist(5);
        }

    Internally, *generate_template* is expanded (at least) twice. Once,
    for a 'counting' stage where the size of all the lists is determined,
    and a second time, for a 'generation' stage where the lists are
    actually filled. A ``generate`` function that has side effects beyond
    calling the ``APPEND_<name>`` macros is therefore ill-formed.

    .. versionchanged:: 2018.1

        Change *eliminate_empty_output_lists* argument type from ``bool`` to
        ``list``.
    """
    if devices is None:
        devices = context.devices

    if count_sharing is None:
        count_sharing = {}

    self.context = context
    self.devices = devices

    self.list_names_and_dtypes = list_names_and_dtypes
    self.generate_template = generate_template

    from pyopencl.tools import VectorArg, parse_arg_list
    self.arg_decls = parse_arg_list(arg_decls)

    # To match with the signature of the user-supplied generate(), arguments
    # can't appear to have offsets.
    self.arg_decls_no_offset = [
        VectorArg(arg.dtype, arg.name)
        if isinstance(arg, VectorArg) and arg.with_offset else arg
        for arg in self.arg_decls]

    self.count_sharing = count_sharing

    self.name_prefix = name_prefix
    self.preamble = preamble
    self.options = options

    self.debug = debug

    self.complex_kernel = complex_kernel

    # Backward compatibility: before 2018.1 this argument was a bool.
    if eliminate_empty_output_lists is True:
        eliminate_empty_output_lists = \
            [name for name, _ in self.list_names_and_dtypes]

    if eliminate_empty_output_lists is False:
        eliminate_empty_output_lists = []

    # Copy into a list so that a one-shot iterable is not exhausted by the
    # validation below, and later changes to the caller's list do not
    # affect this builder.
    self.eliminate_empty_output_lists = list(eliminate_empty_output_lists)
    for list_name in self.eliminate_empty_output_lists:
        if not any(list_name == name for name, _ in self.list_names_and_dtypes):
            raise ValueError(
                "invalid list name '%s' in eliminate_empty_output_lists"
                % list_name)
886
+
887
+ # {{{ kernel generators
888
+
889
@memoize_method
def get_scan_kernel(self, index_dtype):
    """Return (memoized per *index_dtype*) a sum-scan kernel over a single
    ``index_dtype`` array ``ary``. It reads ``ary[i]`` and stores the
    inclusive running sum up to *i* into ``ary[i+1]``."""
    index_ctype = dtype_to_ctype(index_dtype)
    return GenericScanKernel(
        self.context,
        index_dtype,
        arguments=f"__global {index_ctype} *ary",
        input_expr="ary[i]",
        scan_expr="a+b",
        neutral="0",
        output_statement="ary[i+1] = item;",
        devices=self.devices,
    )
898
+
899
@memoize_method
def get_compress_kernel(self, index_dtype):
    """Build (once per *index_dtype*) the scan that compacts empty lists away.

    Scanning an "is this count nonzero" indicator over ``count``, the
    kernel:

    * writes the running number of nonempty lists to
      ``compressed_indices[i + 1]``,
    * for each nonempty list, records its object index in
      ``nonempty_indices`` and its count in ``compressed_counts``,
    * stores the total number of nonempty lists in
      ``*num_non_empty_list``.
    """
    signature = Template("""
        __global ${index_t} *count,
        __global ${index_t} *compressed_counts,
        __global ${index_t} *nonempty_indices,
        __global ${index_t} *compressed_indices,
        __global ${index_t} *num_non_empty_list
        """)

    # prev_item != item exactly when element i was nonzero, i.e. when a
    # new nonempty list starts at object i.
    store_nonempty = """
        if (i + 1 < N) compressed_indices[i + 1] = item;
        if (prev_item != item) {
            nonempty_indices[item - 1] = i;
            compressed_counts[item - 1] = count[i];
        }
        if (i + 1 == N) *num_non_empty_list = item;
        """

    return GenericScanKernel(
        self.context,
        index_dtype,
        arguments=signature.render(index_t=dtype_to_ctype(index_dtype)),
        input_expr="count[i] == 0 ? 0 : 1",
        scan_expr="a+b",
        neutral="0",
        output_statement=store_nonempty,
        devices=self.devices,
    )
924
+
925
def do_not_vectorize(self):
    """Tell whether kernels must run with one work item per group.

    This holds only for complex-valued kernels (``self.complex_kernel``)
    on a context that contains at least one CPU device.  When
    ``self.complex_kernel`` is falsy, that value itself is returned.
    """
    if not self.complex_kernel:
        return self.complex_kernel

    cpu_flag = cl.device_type.CPU
    return any(device.type & cpu_flag for device in self.context.devices)
929
+
930
@memoize_method
def get_count_kernel(self, index_dtype):
    """Compile (once per *index_dtype*) the counting-stage kernel.

    Renders ``_LIST_BUILDER_TEMPLATE`` with ``is_count_stage=True``.  Every
    list that does not share its counts with another list gets a
    ``plb_<name>_count`` output buffer, plus a local ``plb_loc_<name>_count``
    pointer passed to the user code.

    :returns: the built kernel, with scalar argument dtypes set for the
        count buffers, the user arguments and a trailing *index_dtype*
        object count.
    """
    from pyopencl.characterize import has_double_support
    from pyopencl.tools import (
        OtherArg, VectorArg, get_arg_list_scalar_arg_dtypes)

    index_ctype = dtype_to_ctype(index_dtype)

    # Lists listed in count_sharing reuse another list's counts and
    # therefore get no buffers of their own in this stage.
    counted = [
        list_name for list_name, _ in self.list_names_and_dtypes
        if list_name not in self.count_sharing]

    kernel_list_args = [
        VectorArg(index_dtype, "plb_%s_count" % list_name)
        for list_name in counted]

    loc_names = ["plb_loc_%s_count" % list_name for list_name in counted]
    user_list_args = [
        OtherArg("{} *{}".format(index_ctype, loc_name), loc_name)
        for loc_name in loc_names]

    kernel_name = self.name_prefix + "_count"
    every_device_has_doubles = all(
        has_double_support(device) for device in self.context.devices)

    template_kwargs = dict(
        is_count_stage=True,
        kernel_name=kernel_name,
        double_support=every_device_has_doubles,
        debug=self.debug,
        do_not_vectorize=self.do_not_vectorize(),
        eliminate_empty_output_lists=self.eliminate_empty_output_lists,

        kernel_list_arg_decl=_get_arg_decl(kernel_list_args),
        kernel_list_arg_values=_get_arg_list(user_list_args, prefix="&"),
        user_list_arg_decl=_get_arg_decl(user_list_args),
        user_list_args=_get_arg_list(user_list_args),
        user_arg_decl_with_offset=_get_arg_decl(self.arg_decls),
        user_arg_decl_no_offset=_get_arg_decl(self.arg_decls_no_offset),
        user_args_no_offset=_get_arg_list(self.arg_decls_no_offset),
        arg_offset_adjustment=get_arg_offset_adjuster_code(self.arg_decls),

        list_names_and_dtypes=self.list_names_and_dtypes,
        count_sharing=self.count_sharing,
        name_prefix=self.name_prefix,
        generate_template=self.generate_template,
        preamble=self.preamble,

        index_type=index_ctype,
    )
    source = str(_LIST_BUILDER_TEMPLATE.render(**template_kwargs))

    program = cl.Program(self.context, source).build(self.options)
    kernel = getattr(program, kernel_name)

    scalar_dtypes = get_arg_list_scalar_arg_dtypes(
        [*kernel_list_args, *self.arg_decls])
    kernel.set_scalar_arg_dtypes([*scalar_dtypes, index_dtype])

    return kernel
990
+
991
@memoize_method
def get_write_kernel(self, index_dtype):
    """Compile (once per *index_dtype*) the writing-stage kernel.

    Renders ``_LIST_BUILDER_TEMPLATE`` with ``is_count_stage=False``.  Each
    list gets a ``plb_<name>_list`` output buffer.  Lists that do not share
    counts additionally get a ``plb_<name>_start_index`` buffer.  If they are
    in ``eliminate_empty_output_lists``, they also get a
    ``<name>_compressed_indices`` buffer, and a local ``plb_<name>_index``
    pointer is passed to the user code.

    :returns: the built kernel, with scalar argument dtypes set for the
        list buffers, the user arguments and a trailing *index_dtype*
        object count.
    """
    from pyopencl.characterize import has_double_support
    from pyopencl.tools import (
        OtherArg, VectorArg, get_arg_list_scalar_arg_dtypes)

    index_ctype = dtype_to_ctype(index_dtype)

    kernel_list_args = []
    user_list_args = []
    # Pieces of the argument list the kernel passes on to the user code.
    value_pieces = []

    for list_name, list_dtype in self.list_names_and_dtypes:
        out_name = "plb_%s_list" % list_name
        out_arg = VectorArg(list_dtype, out_name)
        kernel_list_args.append(out_arg)
        user_list_args.append(out_arg)

        if list_name in self.count_sharing:
            # Count-sharing lists only receive their output buffer.
            value_pieces.append("%s, " % out_name)
            continue

        kernel_list_args.append(
            VectorArg(index_dtype, "plb_%s_start_index" % list_name))

        if list_name in self.eliminate_empty_output_lists:
            kernel_list_args.append(
                VectorArg(index_dtype, "%s_compressed_indices" % list_name))

        cursor_name = "plb_%s_index" % list_name
        user_list_args.append(
            OtherArg("{} *{}".format(index_ctype, cursor_name), cursor_name))
        value_pieces.append(f"{out_name}, &{cursor_name}, ")

    kernel_name = self.name_prefix + "_write"
    every_device_has_doubles = all(
        has_double_support(device) for device in self.context.devices)

    template_kwargs = dict(
        is_count_stage=False,
        kernel_name=kernel_name,
        double_support=every_device_has_doubles,
        debug=self.debug,
        do_not_vectorize=self.do_not_vectorize(),
        eliminate_empty_output_lists=self.eliminate_empty_output_lists,

        kernel_list_arg_decl=_get_arg_decl(kernel_list_args),
        kernel_list_arg_values="".join(value_pieces),
        user_list_arg_decl=_get_arg_decl(user_list_args),
        user_list_args=_get_arg_list(user_list_args),
        user_arg_decl_with_offset=_get_arg_decl(self.arg_decls),
        user_arg_decl_no_offset=_get_arg_decl(self.arg_decls_no_offset),
        user_args_no_offset=_get_arg_list(self.arg_decls_no_offset),
        arg_offset_adjustment=get_arg_offset_adjuster_code(self.arg_decls),

        list_names_and_dtypes=self.list_names_and_dtypes,
        count_sharing=self.count_sharing,
        name_prefix=self.name_prefix,
        generate_template=self.generate_template,
        preamble=self.preamble,

        index_type=index_ctype,
    )
    source = str(_LIST_BUILDER_TEMPLATE.render(**template_kwargs))

    program = cl.Program(self.context, source).build(self.options)
    kernel = getattr(program, kernel_name)

    scalar_dtypes = get_arg_list_scalar_arg_dtypes(
        kernel_list_args + self.arg_decls)
    kernel.set_scalar_arg_dtypes([*scalar_dtypes, index_dtype])

    return kernel
1064
+
1065
+ # }}}
1066
+
1067
+ # {{{ driver
1068
+
1069
def __call__(self, queue, n_objects, *args, **kwargs):
    """
    :arg args: arguments corresponding to ``arg_decls`` in the constructor.
        Array-like arguments must be either 1D :class:`pyopencl.array.Array`
        objects or :class:`pyopencl.MemoryObject` objects, of which the latter
        can be obtained from a :class:`pyopencl.array.Array` using the
        :attr:`pyopencl.array.Array.data` attribute.
    :arg allocator: optionally, the allocator to use to allocate new
        arrays.
    :arg omit_lists: an iterable of list names that should *not* be built
        with this invocation. The kernel code may *not* call ``APPEND_name``
        for these omitted lists. If it does, undefined behavior will result.
        The returned *lists* dictionary will not contain an entry for names
        in *omit_lists*.
    :arg wait_for: |explain-waitfor|
    :raises TypeError: on unknown keyword arguments.
    :raises ValueError: if *omit_lists* names a list that does not exist,
        or if an array argument is unsuitable.
    :returns: a tuple ``(lists, event)``, where ``lists`` is a mapping from
        (built) list names to objects which have attributes

        * ``count`` for the total number of entries in all lists combined
        * ``lists`` for the array containing all lists.
        * ``starts`` for the array of starting indices in ``lists``.
          ``starts`` is built so that it has ``n_objects + 1`` entries, so
          that the *i*'th entry is the start of the *i*'th list, and the
          (*i*+1)'th entry is the index one past the *i*'th list's end,
          even for the last list.

          This implies that all lists are contiguous.

          If the list name is specified in *eliminate_empty_output_lists*
          constructor argument, *lists* has two additional attributes
          ``num_nonempty_lists`` and ``nonempty_indices``

          * ``num_nonempty_lists`` for the number of nonempty lists.
          * ``nonempty_indices`` for the index of nonempty list in input
            objects.

          In this case, ``starts`` has ``num_nonempty_lists + 1`` entries.
          The *i*'th entry is the start of the *i*'th nonempty list, which is
          generated by the object with index ``nonempty_indices[i]``.

        *event* is a :class:`pyopencl.Event` for dependency management.

    .. versionchanged:: 2016.2

        Added omit_lists.
    """
    from pyopencl import MemoryObject
    from pyopencl.tools import VectorArg

    if n_objects >= int(np.iinfo(np.int32).max):
        index_dtype = np.int64
    else:
        index_dtype = np.int32
    index_dtype = np.dtype(index_dtype)

    allocator = kwargs.pop("allocator", None)
    # Materialize once: the names are validated below and then
    # membership-tested several times, which would silently see nothing
    # if a one-shot iterator (e.g. a generator) were kept as-is.
    omit_lists = tuple(kwargs.pop("omit_lists", ()))
    wait_for = kwargs.pop("wait_for", None)
    if kwargs:
        raise TypeError("invalid keyword arguments: '%s'" % ", ".join(kwargs))

    for oml in omit_lists:
        if not any(oml == name for name, _ in self.list_names_and_dtypes):
            raise ValueError("invalid list name '%s' in omit_lists" % (oml,))

    result = {}
    count_list_args = []

    if wait_for is None:
        wait_for = []
    else:
        # We'll be modifying it below.
        wait_for = list(wait_for)

    count_kernel = self.get_count_kernel(index_dtype)
    write_kernel = self.get_write_kernel(index_dtype)
    scan_kernel = self.get_scan_kernel(index_dtype)
    if self.eliminate_empty_output_lists:
        compress_kernel = self.get_compress_kernel(index_dtype)

    # {{{ unpack user arguments into kernel arguments

    data_args = []
    for i, (arg_descr, arg_val) in enumerate(
            zip(self.arg_decls, args, strict=True)):
        if isinstance(arg_descr, VectorArg):
            if arg_val is None:
                data_args.append(arg_val)
                if arg_descr.with_offset:
                    data_args.append(0)
                continue

            if isinstance(arg_val, MemoryObject):
                data_args.append(arg_val)
                if arg_descr.with_offset:
                    raise ValueError(
                            "with_offset=True specified for argument %d "
                            "but the argument is not an array" % i)
                continue

            if arg_val.ndim != 1:
                raise ValueError("argument %d is a multidimensional array" % i)

            data_args.append(arg_val.base_data)
            if arg_descr.with_offset:
                data_args.append(arg_val.offset)
            wait_for.extend(arg_val.events)
        else:
            data_args.append(arg_val)

    del args
    data_args = tuple(data_args)

    # }}}

    # {{{ allocate memory for counts

    for name, _dtype in self.list_names_and_dtypes:
        if name in self.count_sharing:
            continue
        if name in omit_lists:
            count_list_args.append(None)
            continue

        counts = cl_array.empty(queue,
                (n_objects + 1), index_dtype, allocator=allocator)
        counts[-1] = 0
        wait_for = wait_for + counts.events

        # The scan will turn the "counts" array into the "starts" array
        # in-place.
        if name in self.eliminate_empty_output_lists:
            result[name] = BuiltList(count=None, starts=counts, lists=None,
                    num_nonempty_lists=None,
                    nonempty_indices=None)
        else:
            result[name] = BuiltList(count=None, starts=counts, lists=None)
        count_list_args.append(counts.data)

    # }}}

    if self.debug:
        gsize = (1,)
        lsize = (1,)
    elif self.do_not_vectorize():
        gsize = (4*queue.device.max_compute_units,)
        lsize = (1,)
    else:
        from pyopencl.array import _splay
        gsize, lsize = _splay(queue.device, n_objects)

    count_event = count_kernel(queue, gsize, lsize,
            *(tuple(count_list_args) + data_args + (n_objects,)),
            wait_for=wait_for)

    # {{{ compress away empty lists where requested

    compress_events = {}
    for name, _dtype in self.list_names_and_dtypes:
        if name in omit_lists:
            continue
        if name in self.count_sharing:
            continue
        if name not in self.eliminate_empty_output_lists:
            continue

        compressed_counts = cl_array.empty(
                queue, (n_objects + 1,), index_dtype, allocator=allocator)
        info_record = result[name]
        info_record.nonempty_indices = cl_array.empty(
                queue, (n_objects + 1,), index_dtype, allocator=allocator)
        info_record.num_nonempty_lists = cl_array.empty(
                queue, (1,), index_dtype, allocator=allocator)
        info_record.compressed_indices = cl_array.empty(
                queue, (n_objects + 1,), index_dtype, allocator=allocator)
        info_record.compressed_indices[0] = 0

        compress_events[name] = compress_kernel(
                info_record.starts,
                compressed_counts,
                info_record.nonempty_indices,
                info_record.compressed_indices,
                info_record.num_nonempty_lists,
                wait_for=[count_event, *info_record.compressed_indices.events])

        info_record.starts = compressed_counts

    # }}}

    # {{{ run scans

    scan_events = []

    for name, _dtype in self.list_names_and_dtypes:
        if name in self.count_sharing:
            continue
        if name in omit_lists:
            continue

        info_record = result[name]
        if name in self.eliminate_empty_output_lists:
            # The number of nonempty lists is needed on the host to size
            # the views below, hence the blocking wait.
            compress_events[name].wait()
            num_nonempty_lists = info_record.num_nonempty_lists.get()[0]
            info_record.num_nonempty_lists = num_nonempty_lists
            info_record.starts = info_record.starts[:num_nonempty_lists + 1]
            info_record.nonempty_indices = \
                    info_record.nonempty_indices[:num_nonempty_lists]
            info_record.starts[-1] = 0

        starts_ary = info_record.starts
        if name in self.eliminate_empty_output_lists:
            evt = scan_kernel(
                    starts_ary,
                    size=info_record.num_nonempty_lists,
                    wait_for=starts_ary.events)
        else:
            evt = scan_kernel(starts_ary, wait_for=[count_event],
                    size=n_objects)

        starts_ary.setitem(0, 0, queue=queue, wait_for=[evt])
        scan_events.extend(starts_ary.events)

        # retrieve count
        info_record.count = int(starts_ary[-1].get())

    # }}}

    # {{{ deal with count-sharing lists, allocate memory for lists

    write_list_args = []
    for name, dtype in self.list_names_and_dtypes:
        if name in omit_lists:
            # Keep positional kernel arguments aligned with the write
            # kernel's signature: list buffer, then (if not count-shared)
            # start indices and optionally compressed indices.
            write_list_args.append(None)
            if name not in self.count_sharing:
                write_list_args.append(None)
                if name in self.eliminate_empty_output_lists:
                    write_list_args.append(None)
            continue

        if name in self.count_sharing:
            sharing_from = self.count_sharing[name]

            info_record = result[name] = BuiltList(
                    count=result[sharing_from].count,
                    starts=result[sharing_from].starts,
                    )

        else:
            info_record = result[name]

        info_record.lists = cl_array.empty(queue,
                info_record.count, dtype, allocator=allocator)
        write_list_args.append(info_record.lists.data)

        if name not in self.count_sharing:
            write_list_args.append(info_record.starts.data)

            if name in self.eliminate_empty_output_lists:
                write_list_args.append(info_record.compressed_indices.data)

    # }}}

    evt = write_kernel(queue, gsize, lsize,
            *(tuple(write_list_args) + data_args + (n_objects,)),
            wait_for=scan_events)

    return result, evt
1326
+
1327
+ # }}}
1328
+
1329
+ # }}}
1330
+
1331
+
1332
+ # {{{ key-value sorting
1333
+
1334
@dataclass(frozen=True)
class _KernelInfo:
    """Immutable bundle of the kernels built by
    :meth:`KeyValueSorter.get_kernels`."""

    # Radix sort of (values, keys) keyed on keys[i].
    by_target_sorter: RadixSort
    # Writes, for each key present, the index where its group begins.
    start_finder: ElementwiseKernel
    # Reverse min-scan over ``starts`` so keys with no entries inherit the
    # start of the next nonempty group.
    bound_propagation_scan: GenericScanKernel
1339
+
1340
+
1341
def _make_cl_int_literal(value, dtype):
    """Spell *value* as an OpenCL C integer literal of type *dtype*.

    An ``l`` suffix is appended for 64-bit types and a ``u`` suffix for
    unsigned types, so the literal has the intended width and signedness.
    Without ``u``, for example, the maximum of ``uint32`` would be typed as
    a signed 64-bit ``long`` and the maximum of ``uint64`` would not fit a
    ``long`` at all.

    :arg value: an integer representable in *dtype*.
    :arg dtype: a numpy integer dtype, or anything :class:`numpy.dtype`
        accepts (e.g. ``numpy.int32``).
    :returns: the literal as a :class:`str`.
    """
    dtype = np.dtype(dtype)
    iinfo = np.iinfo(dtype)
    result = str(int(value))
    if dtype.itemsize == 8:
        result += "l"
    # Unsigned integer types have a minimum of zero; only they get "u".
    if int(iinfo.min) == 0:
        result += "u"

    return result
1350
+
1351
+
1352
class KeyValueSorter:
    """Given arrays *values* and *keys* of equal length
    and a number *nkeys* of keys, returns a tuple ``(starts,
    lists, event)``, as follows: *values* and *keys* are sorted
    by *keys*, and the sorted *values* is returned as
    *lists*. Then for each index *i* in ``range(nkeys)``,
    *starts[i]* is written to indicating where the
    group of *values* belonging to the key with index
    *i* begins. It implicitly ends at *starts[i+1]*.

    ``starts`` is built so that it has ``nkeys + 1`` entries, so that
    the *i*'th entry is the start of the *i*'th list, and the
    (*i*+1)'th entry is the index one past the *i*'th list's end,
    even for the last list.

    This implies that all lists are contiguous.

    .. note:: This functionality is provided as a preview. Its
        interface is subject to change until this notice is removed.

    .. versionadded:: 2013.1
    """

    def __init__(self, context):
        self.context = context

    @memoize_method
    def get_kernels(self, key_dtype, value_dtype, starts_dtype):
        """Build (once per dtype triple) the kernels used by :meth:`__call__`.

        :returns: a :class:`_KernelInfo` with the key/value radix sorter,
            the group-start finder and the reverse min-scan that fills in
            the starts of key groups without entries.
        """
        from pyopencl.tools import ScalarArg, VectorArg

        by_target_sorter = RadixSort(
                self.context, [
                    VectorArg(value_dtype, "values"),
                    VectorArg(key_dtype, "keys"),
                    ],
                key_expr="keys[i]",
                sort_arg_names=["values", "keys"])

        from pyopencl.elementwise import ElementwiseTemplate
        start_finder = ElementwiseTemplate(
                arguments="""//CL//
                starts_t *key_group_starts,
                key_t *keys_sorted_by_key,
                """,

                operation=r"""//CL//
                key_t my_key = keys_sorted_by_key[i];

                if (i == 0 || my_key != keys_sorted_by_key[i-1])
                    key_group_starts[my_key] = i;
                """,
                name="find_starts").build(self.context,
                        type_aliases=(
                            # key_t must match the dtype of the sorted keys
                            # buffer passed in; aliasing it to starts_dtype
                            # misreads the keys whenever the two differ.
                            ("key_t", key_dtype),
                            ("starts_t", starts_dtype),
                            ),
                        var_values=())

        bound_propagation_scan = GenericScanKernel(
                self.context, starts_dtype,
                arguments=[
                    VectorArg(starts_dtype, "starts"),
                    # starts has length n+1
                    ScalarArg(key_dtype, "nkeys"),
                    ],
                # Scan from the end so each empty key group picks up the
                # start of the next group (the minimum to its right).
                input_expr="starts[nkeys-i]",
                scan_expr="min(a, b)",
                neutral=_make_cl_int_literal(
                    np.iinfo(starts_dtype).max, starts_dtype),
                output_statement="starts[nkeys-i] = item;")

        return _KernelInfo(
                by_target_sorter=by_target_sorter,
                start_finder=start_finder,
                bound_propagation_scan=bound_propagation_scan)

    def __call__(self, queue, keys, values, nkeys,
            starts_dtype, allocator=None, wait_for=None):
        """Sort *values* by *keys* and compute the start of each key group.

        :arg nkeys: number of keys.  Keys are used directly as indices into
            ``starts``, so presumably all keys lie in ``[0, nkeys)``.
        :arg starts_dtype: dtype of the returned ``starts`` array.
        :arg allocator: defaults to ``values.allocator``.
        :returns: a tuple ``(starts, values_sorted_by_key, event)``.
        """
        if allocator is None:
            allocator = values.allocator

        knl_info = self.get_kernels(keys.dtype, values.dtype,
                starts_dtype)

        (values_sorted_by_key, keys_sorted_by_key), evt = knl_info.by_target_sorter(
                values, keys, queue=queue, wait_for=wait_for)

        # Initialize every start to "one past the end"; groups that are
        # found overwrite their entry, the scan fixes up the rest.
        starts = (cl_array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator)
                .fill(len(values_sorted_by_key), wait_for=[evt]))
        evt, = starts.events

        evt = knl_info.start_finder(starts, keys_sorted_by_key,
                range=slice(len(keys_sorted_by_key)),
                wait_for=[evt])

        evt = knl_info.bound_propagation_scan(starts, nkeys,
                queue=queue, wait_for=[evt])

        return starts, values_sorted_by_key, evt
1451
+
1452
+ # }}}
1453
+
1454
+ # vim: filetype=pyopencl:fdm=marker