pyopencl 2024.3__cp312-cp312-musllinux_1_2_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyopencl might be problematic. Click here for more details.

Files changed (43) hide show
  1. pyopencl/.libs/libOpenCL-1ef0e16e.so.1.0.0 +0 -0
  2. pyopencl/__init__.py +2410 -0
  3. pyopencl/_cl.cpython-312-x86_64-linux-musl.so +0 -0
  4. pyopencl/_cluda.py +54 -0
  5. pyopencl/_mymako.py +14 -0
  6. pyopencl/algorithm.py +1449 -0
  7. pyopencl/array.py +3437 -0
  8. pyopencl/bitonic_sort.py +242 -0
  9. pyopencl/bitonic_sort_templates.py +594 -0
  10. pyopencl/cache.py +535 -0
  11. pyopencl/capture_call.py +177 -0
  12. pyopencl/characterize/__init__.py +456 -0
  13. pyopencl/characterize/performance.py +237 -0
  14. pyopencl/cl/pyopencl-airy.cl +324 -0
  15. pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
  16. pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
  17. pyopencl/cl/pyopencl-bessel-y.cl +435 -0
  18. pyopencl/cl/pyopencl-complex.h +303 -0
  19. pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
  20. pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
  21. pyopencl/cl/pyopencl-random123/array.h +325 -0
  22. pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
  23. pyopencl/cl/pyopencl-random123/philox.cl +486 -0
  24. pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
  25. pyopencl/clmath.py +280 -0
  26. pyopencl/clrandom.py +409 -0
  27. pyopencl/cltypes.py +137 -0
  28. pyopencl/compyte/.gitignore +21 -0
  29. pyopencl/compyte/__init__.py +0 -0
  30. pyopencl/compyte/array.py +214 -0
  31. pyopencl/compyte/dtypes.py +290 -0
  32. pyopencl/compyte/pyproject.toml +54 -0
  33. pyopencl/elementwise.py +1171 -0
  34. pyopencl/invoker.py +421 -0
  35. pyopencl/ipython_ext.py +68 -0
  36. pyopencl/reduction.py +786 -0
  37. pyopencl/scan.py +1915 -0
  38. pyopencl/tools.py +1527 -0
  39. pyopencl/version.py +9 -0
  40. pyopencl-2024.3.dist-info/METADATA +108 -0
  41. pyopencl-2024.3.dist-info/RECORD +43 -0
  42. pyopencl-2024.3.dist-info/WHEEL +5 -0
  43. pyopencl-2024.3.dist-info/licenses/LICENSE +104 -0
pyopencl/algorithm.py ADDED
@@ -0,0 +1,1449 @@
1
+ """Algorithms built on scans."""
2
+
3
+
4
+ __copyright__ = """
5
+ Copyright 2011-2012 Andreas Kloeckner
6
+ Copyright 2017 Hao Gao
7
+ """
8
+
9
+ __license__ = """
10
+ Permission is hereby granted, free of charge, to any person
11
+ obtaining a copy of this software and associated documentation
12
+ files (the "Software"), to deal in the Software without
13
+ restriction, including without limitation the rights to use,
14
+ copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ copies of the Software, and to permit persons to whom the
16
+ Software is furnished to do so, subject to the following
17
+ conditions:
18
+
19
+ The above copyright notice and this permission notice shall be
20
+ included in all copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
24
+ OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
26
+ HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
27
+ WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
28
+ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
29
+ OTHER DEALINGS IN THE SOFTWARE.
30
+ """
31
+
32
+ from dataclasses import dataclass
33
+ from typing import Optional
34
+
35
+ import numpy as np
36
+ from mako.template import Template
37
+
38
+ from pytools import memoize, memoize_method
39
+
40
+ import pyopencl as cl
41
+ import pyopencl.array
42
+ from pyopencl.elementwise import ElementwiseKernel
43
+ from pyopencl.scan import GenericScanKernel, ScanTemplate
44
+ from pyopencl.tools import dtype_to_ctype, get_arg_offset_adjuster_code
45
+
46
+
47
+ # {{{ "extra args" handling utility
48
+
49
def _extract_extra_args_types_values(extra_args):
    """Split ``(name, value)`` pairs into kernel argument declarations and values.

    :arg extra_args: an iterable of ``(name, value)`` pairs, or *None* for
        "no extra arguments". Each *value* must be either a
        :class:`pyopencl.array.Array` or a :mod:`numpy` scalar
        (an instance of :class:`numpy.generic`).
    :returns: a tuple ``(arg_types, arg_values, wait_for)`` where *arg_types*
        is a tuple of :class:`pyopencl.tools.VectorArg` /
        :class:`pyopencl.tools.ScalarArg` declarations, *arg_values* is the
        list of values in the same order, and *wait_for* is the list of
        events collected from the ``events`` attribute of the array values.
    :raises RuntimeError: if a value is neither an array nor a numpy scalar.
    """
    from pyopencl.tools import ScalarArg, VectorArg

    if extra_args is None:
        extra_args = []

    extra_args_types = []
    extra_args_values = []
    extra_wait_for = []
    for name, val in extra_args:
        if isinstance(val, cl.array.Array):
            extra_args_types.append(VectorArg(val.dtype, name, with_offset=False))
            extra_args_values.append(val)
            # The caller must wait for pending operations on this array.
            extra_wait_for.extend(val.events)
        elif isinstance(val, np.generic):
            extra_args_types.append(ScalarArg(val.dtype, name))
            extra_args_values.append(val)
        else:
            # *name* is a string: it has to be formatted with %s. The former
            # %d raised a TypeError here, hiding the intended error message.
            raise RuntimeError("argument '%s' not understood" % (name,))

    return tuple(extra_args_types), extra_args_values, extra_wait_for
69
+
70
+ # }}}
71
+
72
+
73
+ # {{{ copy_if
74
+
75
# Scan template used by copy_if() (and therefore remove_if()).
#
# Each element contributes 1 to an additive scan if the user-supplied
# predicate expression holds for it and 0 otherwise. The output statement
# copies ary[i] to out[item-1] whenever the scan value changed at index i,
# and stores the scan value seen at the last index (i+1 == N) into *count.
#
# "predicate" is substituted via var_values and item_t/scan_t via
# type_aliases when copy_if() calls .build().
# NOTE(review): prev_item, item, i and N are supplied by the scan machinery
# in pyopencl.scan -- presumably the scan result before/at index i and the
# input length; confirm there.
_copy_if_template = ScanTemplate(
    arguments="item_t *ary, item_t *out, scan_t *count",
    input_expr="(%(predicate)s) ? 1 : 0",
    scan_expr="a+b", neutral="0",
    output_statement="""
if (prev_item != item) out[item-1] = ary[i];
if (i+1 == N) *count = item;
""",
    template_processor="printf")
84
+
85
+
86
def copy_if(ary, predicate, extra_args=None, preamble="", queue=None, wait_for=None):
    """Copy the elements of *ary* satisfying *predicate* to an output array.

    :arg predicate: a C expression evaluating to a ``bool``, represented as a
        string. The value to test is available as ``ary[i]``; if the
        expression evaluates to ``true``, that value is written to the output.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg queue: passed on unchanged to the kernel invocation.
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array,
        *count* is an on-device scalar (fetch to host with ``count.get()``)
        indicating how many elements satisfied *predicate*, and *event* is a
        :class:`pyopencl.Event` for dependency management. *out* is allocated
        to the same length as *ary*, but only the first *count* entries carry
        meaning.

    .. versionadded:: 2013.1
    """
    # Pick the narrowest signed counter type able to hold len(ary).
    fits_in_int32 = not (len(ary) > np.iinfo(np.int32).max)
    scan_dtype = np.int32 if fits_in_int32 else np.int64

    if wait_for is None:
        wait_for = []

    arg_types, arg_values, arg_events = \
            _extract_extra_args_types_values(extra_args)
    # Also wait for pending operations on array-valued extra arguments.
    wait_for = wait_for + arg_events

    type_aliases = (("scan_t", scan_dtype), ("item_t", ary.dtype))
    var_values = (("predicate", predicate),)
    knl = _copy_if_template.build(
            ary.context,
            type_aliases=type_aliases,
            var_values=var_values,
            more_preamble=preamble,
            more_arguments=arg_types)

    out = cl.array.empty_like(ary)
    # Zero-dimensional array derived from *ary* that receives the match count.
    count = ary._new_with_changes(
            data=None, offset=0, shape=(), strides=(), dtype=scan_dtype)

    kernel_args = [ary, out, count, *arg_values]
    evt = knl(*kernel_args, queue=queue, wait_for=wait_for)

    return out, count, evt
128
+
129
+ # }}}
130
+
131
+
132
+ # {{{ remove_if
133
+
134
def remove_if(ary, predicate, extra_args=None, preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* not satisfying *predicate* to an output array.

    This is :func:`copy_if` applied to the negated predicate.

    :arg predicate: a C expression evaluating to a ``bool``, represented as a
        string. The value to test is available as ``ary[i]``; if the
        expression evaluates to ``false``, that value is written to the output.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg queue: passed on unchanged to :func:`copy_if`.
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array,
        *count* is an on-device scalar (fetch to host with ``count.get()``)
        indicating how many elements did not satisfy *predicate*, and *event*
        is a :class:`pyopencl.Event` for dependency management.

    .. versionadded:: 2013.1
    """
    # Parenthesize before negating so that '!' applies to the whole expression.
    negated_predicate = "!(%s)" % predicate

    return copy_if(
            ary, negated_predicate,
            extra_args=extra_args,
            preamble=preamble,
            queue=queue,
            wait_for=wait_for)
153
+
154
+ # }}}
155
+
156
+
157
+ # {{{ partition
158
+
159
# Scan template used by partition().
#
# As for copy_if, each element contributes 1 to an additive scan if the
# predicate holds and 0 otherwise. When the scan value changed at index i the
# element is written to out_true[item-1]; otherwise it is written to
# out_false[i-item]. The scan value seen at the last index (i+1 == N) is
# stored into *count_true.
#
# "predicate" is substituted via var_values and item_t/scan_t via
# type_aliases when partition() calls .build().
# NOTE(review): prev_item, item, i and N come from the scan machinery in
# pyopencl.scan -- confirm their exact meaning there.
_partition_template = ScanTemplate(
    arguments=(
        "item_t *ary, item_t *out_true, item_t *out_false, "
        "scan_t *count_true"),
    input_expr="(%(predicate)s) ? 1 : 0",
    scan_expr="a+b", neutral="0",
    output_statement="""//CL//
if (prev_item != item)
out_true[item-1] = ary[i];
else
out_false[i-item] = ary[i];
if (i+1 == N) *count_true = item;
""",
    template_processor="printf")
173
+
174
+
175
def partition(ary, predicate, extra_args=None, preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* into one of two arrays depending on whether
    they satisfy *predicate*.

    :arg predicate: a C expression evaluating to a ``bool``, represented as a
        string. The value to test is available as ``ary[i]``.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg queue: passed on unchanged to the kernel invocation.
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out_true, out_false, count, event)* where *count*
        is an on-device scalar (fetch to host with ``count.get()``) indicating
        how many elements satisfied the predicate, and *event* is a
        :class:`pyopencl.Event` for dependency management. Both output arrays
        are allocated with the same shape and dtype as *ary*.

    .. versionadded:: 2013.1
    """
    # Pick the narrowest unsigned counter type able to hold len(ary).
    index_limit = np.iinfo(np.uint32).max
    scan_dtype = np.uint64 if len(ary) > index_limit else np.uint32

    if wait_for is None:
        wait_for = []

    arg_types, arg_values, arg_events = \
            _extract_extra_args_types_values(extra_args)
    # Also wait for pending operations on array-valued extra arguments.
    wait_for = wait_for + arg_events

    type_aliases = (("item_t", ary.dtype), ("scan_t", scan_dtype))
    var_values = (("predicate", predicate),)
    knl = _partition_template.build(
            ary.context,
            type_aliases=type_aliases,
            var_values=var_values,
            more_preamble=preamble,
            more_arguments=arg_types)

    out_true = cl.array.empty_like(ary)
    out_false = cl.array.empty_like(ary)
    # Zero-dimensional array derived from *ary* that receives the match count.
    count = ary._new_with_changes(
            data=None, offset=0, shape=(), strides=(), dtype=scan_dtype)

    kernel_args = [ary, out_true, out_false, count, *arg_values]
    evt = knl(*kernel_args, queue=queue, wait_for=wait_for)

    return out_true, out_false, count, evt
219
+
220
+ # }}}
221
+
222
+
223
+ # {{{ unique
224
+
225
# Scan template used by unique().
#
# The scan input is 1 for index 0 and for every element that IS_EQUAL_EXPR
# reports as different from its predecessor (ary_im1 is fetched from "ary" at
# offset -1, ary_i at offset 0), and 0 otherwise. The output statement copies
# ary[i] to out[item-1] whenever the scan value changed at index i, and stores
# the scan value seen at the last index (i+1 == N) into *count_unique.
#
# IS_EQUAL_EXPR(a, b) is a C macro whose body is the user's comparison
# expression, substituted as "macro_is_equal_expr" by unique().
# NOTE(review): prev_item, item, i and N come from the scan machinery in
# pyopencl.scan -- confirm their exact meaning there.
_unique_template = ScanTemplate(
    arguments="item_t *ary, item_t *out, scan_t *count_unique",
    input_fetch_exprs=[
        ("ary_im1", "ary", -1),
        ("ary_i", "ary", 0),
        ],
    input_expr="(i == 0) || (IS_EQUAL_EXPR(ary_im1, ary_i) ? 0 : 1)",
    scan_expr="a+b", neutral="0",
    output_statement="""
if (prev_item != item) out[item-1] = ary[i];
if (i+1 == N) *count_unique = item;
""",
    preamble="#define IS_EQUAL_EXPR(a, b) %(macro_is_equal_expr)s\n",
    template_processor="printf")
239
+
240
+
241
def unique(ary, is_equal_expr="a == b", extra_args=None, preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* into the output if *is_equal_expr*, applied
    to the array element and its predecessor, yields false.

    Works like the UNIX command :program:`uniq`, with a potentially custom
    comparison. This operation is often used on sorted sequences.

    :arg is_equal_expr: a C expression evaluating to a ``bool``,
        represented as a string. The elements being compared are
        available as ``a`` and ``b``. If this expression yields ``false``,
        the two are considered distinct.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg queue: passed on unchanged to the kernel invocation.
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array,
        *count* is an on-device scalar (fetch to host with ``count.get()``)
        indicating how many elements were copied to *out*, and *event* is a
        :class:`pyopencl.Event` for dependency management.

    .. versionadded:: 2013.1
    """
    # Pick the narrowest unsigned counter type able to hold len(ary).
    index_limit = np.iinfo(np.uint32).max
    scan_dtype = np.uint64 if len(ary) > index_limit else np.uint32

    if wait_for is None:
        wait_for = []

    arg_types, arg_values, arg_events = \
            _extract_extra_args_types_values(extra_args)
    # Also wait for pending operations on array-valued extra arguments.
    wait_for = wait_for + arg_events

    type_aliases = (("item_t", ary.dtype), ("scan_t", scan_dtype))
    var_values = (("macro_is_equal_expr", is_equal_expr),)
    knl = _unique_template.build(
            ary.context,
            type_aliases=type_aliases,
            var_values=var_values,
            more_preamble=preamble,
            more_arguments=arg_types)

    out = cl.array.empty_like(ary)
    # Zero-dimensional array derived from *ary* that receives the count.
    count = ary._new_with_changes(
            data=None, offset=0, shape=(), strides=(), dtype=scan_dtype)

    kernel_args = [ary, out, count, *arg_values]
    evt = knl(*kernel_args, queue=queue, wait_for=wait_for)

    return out, count, evt
290
+
291
+ # }}}
292
+
293
+
294
+ # {{{ radix_sort
295
+
296
def to_bin(n):
    """Return the binary digits of the non-negative integer *n* as a string.

    The result has no prefix and no leading zeros; in particular,
    ``to_bin(0)`` is the empty string (callers such as ``_padded_bin``
    rely on padding to produce the zero digits).

    :arg n: a non-negative integer.
    :returns: a string consisting of the characters ``"0"`` and ``"1"``.
    :raises ValueError: if *n* is negative. (Repeatedly right-shifting a
        negative integer never reaches zero, so the former digit loop did
        not terminate for such input.)
    """
    if n < 0:
        raise ValueError("to_bin() requires a non-negative integer, got %r" % (n,))

    # format(0, "b") would be "0"; keep the historical empty result for zero.
    return format(n, "b") if n else ""
304
+
305
+
306
def _padded_bin(i, nbits):
    """Return the binary digits of *i*, left-padded with zeros to *nbits*.

    If the binary representation already has *nbits* or more digits, it is
    returned unchanged (it is never truncated).
    """
    return to_bin(i).rjust(nbits, "0")
311
+
312
+
313
@memoize
def _make_sort_scan_type(device, bits, index_dtype):
    """Build and register the struct type carried through the radix sort scan.

    The struct has one field of type *index_dtype* for each of the
    ``2**bits`` possible bin numbers, named ``c`` followed by the bin number
    in binary, zero-padded to *bits* digits.

    :returns: a tuple ``(name, dtype, c_decl)`` consisting of the C type
        name, the dtype returned by
        :func:`pyopencl.tools.get_or_register_dtype`, and the C declaration
        produced by :func:`pyopencl.tools.match_dtype_to_c_struct`.
    """
    from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct

    type_name = "pyopencl_sort_scan_%s_%dbits_t" % (
            index_dtype.type.__name__, bits)

    bin_counters = np.dtype([
        ("c%s" % _padded_bin(bin_nr, bits), index_dtype)
        for bin_nr in range(2**bits)])

    matched_dtype, c_decl = match_dtype_to_c_struct(
            device, type_name, bin_counters)
    registered_dtype = get_or_register_dtype(type_name, matched_dtype)

    return type_name, registered_dtype, c_decl
329
+
330
+
331
+ # {{{ types, helpers preamble
332
+
333
# Mako template for the type/helper preamble of the radix sort kernel.
# RadixSort.__init__ renders it with its codegen_args and prepends the C
# declaration of the scan struct.
#
# It defines the scan_t/key_t/index_t typedefs, a dbg_printf macro that only
# prints when DEBUG is defined, get_count() (whose body is the nested
# conditional expression generated by get_count_branch, selecting the
# per-bin counter of a scan_t for bin number mnr) and BIN_NR(), which
# extracts the `bits` key bits starting at base_bit.
RADIX_SORT_PREAMBLE_TPL = Template(r"""//CL//
typedef ${scan_ctype} scan_t;
typedef ${key_ctype} key_t;
typedef ${index_ctype} index_t;

// #define DEBUG
#ifdef DEBUG
#define dbg_printf(ARGS) printf ARGS
#else
#define dbg_printf(ARGS) /* */
#endif

index_t get_count(scan_t s, int mnr)
{
return ${get_count_branch("")};
}

#define BIN_NR(key_arg) ((key_arg >> base_bit) & ${2**bits - 1})

""", strict_undefined=True)
353
+
354
+ # }}}
355
+
356
+ # {{{ scan helpers
357
+
358
# Mako template for the scan helper functions of the radix sort kernel.
# RadixSort.__init__ renders it with its codegen_args and appends it to the
# rendered RADIX_SORT_PREAMBLE_TPL.
#
# - scan_t_neutral(): a scan_t with every per-bin counter set to 0.
# - scan_t_from_value(): a scan_t whose counter for the key's bin is 1 and
#   all others 0 (each field is assigned the result of bin_nr == <bin>).
# - scan_t_add(): field-wise sum of two scan_t values; the
#   across_seg_boundary parameter is accepted but not used in the body.
RADIX_SORT_SCAN_PREAMBLE_TPL = Template(r"""//CL//
scan_t scan_t_neutral()
{
scan_t result;
%for mnr in range(2**bits):
result.c${padded_bin(mnr, bits)} = 0;
%endfor
return result;
}

// considers bits (base_bit+bits-1, ..., base_bit)
scan_t scan_t_from_value(
key_t key,
int base_bit,
int i
)
{
// extract relevant bit range
key_t bin_nr = BIN_NR(key);

dbg_printf(("i: %d key:%d bin_nr:%d\n", i, key, bin_nr));

scan_t result;
%for mnr in range(2**bits):
result.c${padded_bin(mnr, bits)} = (bin_nr == ${mnr});
%endfor

return result;
}

scan_t scan_t_add(scan_t a, scan_t b, bool across_seg_boundary)
{
%for mnr in range(2**bits):
<% field = "c"+padded_bin(mnr, bits) %>
b.${field} = a.${field} + b.${field};
%endfor

return b;
}
""", strict_undefined=True)
398
+
399
# Mako template for the output statement of the radix sort scan kernel,
# rendered by RadixSort.__init__ with its codegen_args.
#
# For element i it computes the key's bin, sums the counters of all
# lower-numbered bins taken from last_item, adds the element's own running
# count within its bin (minus one) to obtain the target index, and copies
# every array named in sort_arg_names to sorted_<name>[tgt_idx].
# NOTE(review): item and last_item come from the scan machinery in
# pyopencl.scan -- presumably the scan value at index i and at the final
# index; confirm there.
RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL//
{
key_t key = ${key_expr};
key_t my_bin_nr = BIN_NR(key);

index_t previous_bins_size = 0;
%for mnr in range(2**bits):
previous_bins_size +=
(my_bin_nr > ${mnr})
? last_item.c${padded_bin(mnr, bits)}
: 0;
%endfor

index_t tgt_idx =
previous_bins_size
+ get_count(item, my_bin_nr) - 1;

%for arg_name in sort_arg_names:
sorted_${arg_name}[tgt_idx] = ${arg_name}[i];
%endfor
}
""", strict_undefined=True)
421
+
422
+ # }}}
423
+
424
+
425
+ # {{{ driver
426
+
427
class RadixSort:
    """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`__
    on the compute device.

    The sort runs in passes over *bits_at_a_time* key bits each, starting at
    the least significant bit. Every pass is one invocation of a scan kernel
    that writes the arrays being sorted into freshly allocated output arrays.

    .. seealso:: :class:`pyopencl.bitonic_sort.BitonicSort`

    .. versionadded:: 2013.1
    """

    def __init__(self, context, arguments, key_expr, sort_arg_names,
            bits_at_a_time=2, index_dtype=np.int32, key_dtype=np.uint32,
            scan_kernel=GenericScanKernel, options=None):
        """
        :arg context: the :class:`pyopencl.Context` to build the kernel for.
            The scan struct type is matched against ``context.devices[0]``.
        :arg arguments: A string of comma-separated C argument declarations.
            All types used here must be known to PyOpenCL.
            (see :func:`pyopencl.tools.get_or_register_dtype`).
            At least one argument must be an array.
        :arg key_expr: An integer-valued C expression returning the
            key based on which the sort is performed. The array index
            for which the key is to be computed is available as ``i``.
            The expression may refer to any of the *arguments*.
        :arg sort_arg_names: A list of argument names whose corresponding
            array arguments will be sorted according to *key_expr*.
        :arg bits_at_a_time: the number of key bits processed per pass
            (at least 1).
        :arg index_dtype: the dtype of the per-bin counters and of the
            target indices (``index_t`` in the generated code).
        :arg key_dtype: the dtype of the key (``key_t`` in the generated
            code). Its bit width is the default for *key_bits* in
            :meth:`__call__`.
        :arg scan_kernel: a callable with the constructor interface of
            :class:`pyopencl.scan.GenericScanKernel`, used to build the
            kernel.
        :arg options: passed to *scan_kernel* as ``options``.
        :raises ValueError: if *bits_at_a_time* is less than 1 or none of
            the *arguments* is an array.
        """
        from pyopencl.tools import ScalarArg, VectorArg, parse_arg_list

        # {{{ arg processing

        self.arguments = parse_arg_list(arguments)
        del arguments

        self.sort_arg_names = sort_arg_names
        self.bits = int(bits_at_a_time)
        if self.bits < 1:
            # With zero bits per pass, the pass loop in __call__ would
            # never advance base_bit and thus never terminate.
            raise ValueError("bits_at_a_time must be at least 1")

        self.index_dtype = np.dtype(index_dtype)
        self.key_dtype = np.dtype(key_dtype)

        self.options = options

        # The first array argument supplies the default length, allocator
        # and queue in __call__. Stop at the first match: without the break,
        # the *last* array argument would be recorded instead.
        for i, arg in enumerate(self.arguments):
            if isinstance(arg, VectorArg):
                self.first_array_arg_idx = i
                break
        else:
            raise ValueError(
                    "RadixSort requires at least one array argument")

        # }}}

        # {{{ kernel creation

        scan_ctype, scan_dtype, scan_t_cdecl = \
                _make_sort_scan_type(context.devices[0], self.bits, self.index_dtype)

        # Kernel signature: the user's arguments, one "sorted_<name>" output
        # array per sorted argument (in declaration order), then base_bit.
        scan_arguments = (
                list(self.arguments)
                + [VectorArg(arg.dtype, "sorted_"+arg.name)
                    for arg in self.arguments
                    if arg.name in sort_arg_names]
                + [ScalarArg(np.int32, "base_bit")])

        def get_count_branch(known_bits):
            # Recursively builds a nested C conditional expression that
            # selects the counter field "s.c<bits>" matching bin number mnr,
            # deciding one binary digit per level.
            if len(known_bits) == self.bits:
                return "s.c%s" % known_bits

            boundary_mnr = known_bits + "1" + (self.bits-len(known_bits)-1)*"0"

            return ("((mnr < {}) ? {} : {})".format(
                int(boundary_mnr, 2),
                get_count_branch(known_bits+"0"),
                get_count_branch(known_bits+"1")))

        codegen_args = {
                "bits": self.bits,
                "key_ctype": dtype_to_ctype(self.key_dtype),
                "key_expr": key_expr,
                "index_ctype": dtype_to_ctype(self.index_dtype),
                "index_type_max": np.iinfo(self.index_dtype).max,
                "padded_bin": _padded_bin,
                "scan_ctype": scan_ctype,
                "sort_arg_names": sort_arg_names,
                "get_count_branch": get_count_branch,
                }

        preamble = scan_t_cdecl+RADIX_SORT_PREAMBLE_TPL.render(**codegen_args)
        scan_preamble = preamble \
                + RADIX_SORT_SCAN_PREAMBLE_TPL.render(**codegen_args)

        self.scan_kernel = scan_kernel(
                context, scan_dtype,
                arguments=scan_arguments,
                input_expr="scan_t_from_value(%s, base_bit, i)" % key_expr,
                scan_expr="scan_t_add(a, b, across_seg_boundary)",
                neutral="scan_t_neutral()",
                output_statement=RADIX_SORT_OUTPUT_STMT_TPL.render(**codegen_args),
                preamble=scan_preamble, options=self.options)

        # }}}

    def __call__(self, *args, **kwargs):
        """Run the radix sort. In addition to *args* which must match the
        *arguments* specification on the constructor, the following
        keyword arguments are supported:

        :arg key_bits: specify how many bits (starting from least-significant)
            there are in the key. Defaults to the bit width of *key_dtype*.
            Must be positive.
        :arg allocator: See the *allocator* argument of :func:`pyopencl.array.empty`.
            Defaults to the allocator of the first argument array.
        :arg queue: A :class:`pyopencl.CommandQueue`, defaulting to the
            one from the first argument array.
        :arg wait_for: |explain-waitfor|
        :returns: A tuple ``(sorted, event)``. *sorted* consists of sorted
            copies of the arrays named in *sort_arg_names*, in the order in
            which those arrays appear in the constructor's *arguments*.
            *event* is a :class:`pyopencl.Event` for dependency management.
        :raises ValueError: if *key_bits* is not positive.
        """

        wait_for = kwargs.pop("wait_for", None)

        # {{{ run control

        key_bits = kwargs.pop("key_bits", None)
        if key_bits is None:
            key_bits = int(np.iinfo(self.key_dtype).bits)
        if key_bits <= 0:
            # Without at least one pass there is neither an event nor a
            # sorted copy to return.
            raise ValueError("key_bits must be positive")

        first_array = args[self.first_array_arg_idx]
        n = len(first_array)

        allocator = kwargs.pop("allocator", None)
        if allocator is None:
            allocator = first_array.allocator

        queue = kwargs.pop("queue", None)
        if queue is None:
            queue = first_array.queue

        args = list(args)

        # Positions (in declaration order) of the arguments being sorted.
        # The kernel's "sorted_*" outputs are declared in this same order.
        sorted_arg_indices = [
                i for i, arg_descr in enumerate(self.arguments)
                if arg_descr.name in self.sort_arg_names]

        base_bit = 0
        while base_bit < key_bits:
            sorted_args = [
                    cl.array.empty(queue, n, self.arguments[i].dtype,
                        allocator=allocator)
                    for i in sorted_arg_indices]

            scan_args = args + sorted_args + [base_bit]

            last_evt = self.scan_kernel(*scan_args,
                    queue=queue, wait_for=wait_for)
            # Each pass consumes the output of the previous one.
            wait_for = [last_evt]

            # Feed this pass's outputs into the next pass. Pair them up by
            # declaration position: sorted_args is in declaration order, so
            # indexing it by position within sort_arg_names would substitute
            # the wrong array whenever the two orders differ.
            for i, sorted_ary in zip(sorted_arg_indices, sorted_args):
                args[i] = sorted_ary

            base_bit += self.bits

        return [args[i] for i in sorted_arg_indices], last_evt

        # }}}
579
+
580
+ # }}}
581
+
582
+ # }}}
583
+
584
+ # }}}
585
+
586
+
587
+ # {{{ generic parallel list builder
588
+
589
+ # {{{ kernel template
590
+
591
# Mako template for the kernels of ListOfListsBuilder. It is rendered with
# is_count_stage=True by get_count_kernel().
# NOTE(review): presumably also rendered with is_count_stage=False for a
# write stage -- that call site is not visible here; confirm.
#
# Count stage (PLB_COUNT_STAGE): APPEND_<name>(value) increments a per-object
# local counter (or does nothing for lists that appear in count_sharing);
# after generate() returns, the counter is stored to plb_<name>_count[i] if
# that pointer is non-null.
#
# Write stage (PLB_WRITE_STAGE): APPEND_<name>(value) stores the value into
# plb_<name>_list at the running index, which is initialised from
# plb_<name>_start_index (looked up through <name>_compressed_indices for
# lists in eliminate_empty_output_lists) or 0 if that pointer is null. Lists
# in count_sharing write at the index of the list they share with, minus one.
#
# The kernel entry point calls the user-supplied generate() once for every
# i < n. If do_not_vectorize is set, the kernel requires a work-group size
# of (1, 1, 1) and each work item walks chunks of 128 consecutive indices.
_LIST_BUILDER_TEMPLATE = Template("""//CL//
% if double_support:
#if __OPENCL_C_VERSION__ < 120
#pragma OPENCL EXTENSION cl_khr_fp64: enable
#endif
#define PYOPENCL_DEFINE_CDOUBLE
% endif

#include <pyopencl-complex.h>

${preamble}

// {{{ declare helper macros for user interface

typedef ${index_type} index_type;

%if is_count_stage:
#define PLB_COUNT_STAGE

%for name, dtype in list_names_and_dtypes:
%if name in count_sharing:
#define APPEND_${name}(value) { /* nothing */ }
%else:
#define APPEND_${name}(value) { ++(*plb_loc_${name}_count); }
%endif
%endfor
%else:
#define PLB_WRITE_STAGE

%for name, dtype in list_names_and_dtypes:
%if name in count_sharing:
#define APPEND_${name}(value) \
{ plb_${name}_list[(*plb_${count_sharing[name]}_index) - 1] \
= value; }
%else:
#define APPEND_${name}(value) \
{ plb_${name}_list[(*plb_${name}_index)++] = value; }
%endif
%endfor
%endif

#define LIST_ARG_DECL ${user_list_arg_decl}
#define LIST_ARGS ${user_list_args}
#define USER_ARG_DECL ${user_arg_decl_no_offset}
#define USER_ARGS ${user_args_no_offset}

// }}}

${generate_template}

// {{{ kernel entry point

__kernel
%if do_not_vectorize:
__attribute__((reqd_work_group_size(1, 1, 1)))
%endif
void ${kernel_name}(
${kernel_list_arg_decl} ${user_arg_decl_with_offset} index_type n)

{
%if not do_not_vectorize:
int lid = get_local_id(0);
index_type gsize = get_global_size(0);
index_type work_group_start = get_local_size(0)*get_group_id(0);
for (index_type i = work_group_start + lid; i < n; i += gsize)
%else:
const int chunk_size = 128;
index_type chunk_base = get_global_id(0)*chunk_size;
index_type gsize = get_global_size(0);
for (; chunk_base < n; chunk_base += gsize*chunk_size)
for (index_type i = chunk_base; i < min(n, chunk_base+chunk_size); ++i)
%endif
{
%if is_count_stage:
%for name, dtype in list_names_and_dtypes:
%if name not in count_sharing:
index_type plb_loc_${name}_count = 0;
%endif
%endfor
%else:
%for name, dtype in list_names_and_dtypes:
%if name not in count_sharing:
index_type plb_${name}_index;
if (plb_${name}_start_index)
%if name in eliminate_empty_output_lists:
plb_${name}_index =
plb_${name}_start_index[
${name}_compressed_indices[i]
];
%else:
plb_${name}_index = plb_${name}_start_index[i];
%endif
else
plb_${name}_index = 0;
%endif
%endfor
%endif

${arg_offset_adjustment}
generate(${kernel_list_arg_values} USER_ARGS i);

%if is_count_stage:
%for name, dtype in list_names_and_dtypes:
%if name not in count_sharing:
if (plb_${name}_count)
plb_${name}_count[i] = plb_loc_${name}_count;
%endif
%endfor
%endif
}
}

// }}}

""", strict_undefined=True)
706
+
707
+ # }}}
708
+
709
+
710
def _get_arg_decl(arg_list):
    """Return the C declarators of *arg_list*, each followed by ``", "``.

    The trailing separator after the last entry is intentional: the result
    is pasted in front of further arguments in generated code. An empty
    *arg_list* yields the empty string.
    """
    return "".join(arg.declarator() + ", " for arg in arg_list)
716
+
717
+
718
def _get_arg_list(arg_list, prefix=""):
    """Return the names of *arg_list*, each preceded by *prefix* and
    followed by ``", "``.

    The trailing separator after the last entry is intentional: the result
    is pasted in front of further arguments in generated code. An empty
    *arg_list* yields the empty string.
    """
    return "".join(prefix + arg.name + ", " for arg in arg_list)
724
+
725
+
726
@dataclass
class BuiltList:
    """Record describing one built list.

    NOTE(review): presumably this is the per-list entry of the result
    returned by :class:`ListOfListsBuilder` -- confirm against its
    ``__call__``, which is not visible from here.
    """
    # Every field is Optional. *count* and *starts* have no default and must
    # always be passed; the remaining fields default to None.
    count: Optional[int]
    starts: Optional[pyopencl.array.Array]
    lists: Optional[pyopencl.array.Array] = None
    # NOTE(review): the following three look related to the
    # eliminate_empty_output_lists feature -- confirm.
    num_nonempty_lists: Optional[int] = None
    nonempty_indices: Optional[pyopencl.array.Array] = None
    compressed_indices: Optional[pyopencl.array.Array] = None
734
+
735
+
736
+ class ListOfListsBuilder:
737
+ """Generates and executes code to produce a large number of variable-size
738
+ lists, simply.
739
+
740
+ .. note:: This functionality is provided as a preview. Its interface
741
+ is subject to change until this notice is removed.
742
+
743
+ .. versionadded:: 2013.1
744
+
745
+ Here's a usage example::
746
+
747
+ from pyopencl.algorithm import ListOfListsBuilder
748
+ builder = ListOfListsBuilder(context, [("mylist", np.int32)], \"\"\"
749
+ void generate(LIST_ARG_DECL USER_ARG_DECL index_type i)
750
+ {
751
+ int count = i % 4;
752
+ for (int j = 0; j < count; ++j)
753
+ {
754
+ APPEND_mylist(count);
755
+ }
756
+ }
757
+ \"\"\", arg_decls=[])
758
+
759
+ result, event = builder(queue, 2000)
760
+
761
+ inf = result["mylist"]
762
+ assert inf.count == 3000
763
+ assert (inf.lists.get()[-6:] == [1, 2, 2, 3, 3, 3]).all()
764
+
765
+ The function ``generate`` above is called once for each "input object".
766
+ Each input object can then generate zero or more list entries.
767
+ The number of these input objects is given to :meth:`__call__` as *n_objects*.
768
+ List entries are generated by calls to ``APPEND_<list name>(value)``.
769
+ Multiple lists may be generated at once.
770
+
771
+ .. automethod:: __init__
772
+ .. automethod:: __call__
773
+ """
774
def __init__(self, context, list_names_and_dtypes, generate_template,
        arg_decls, count_sharing=None, devices=None,
        name_prefix="plb_build_list", options=None, preamble="",
        debug=False, complex_kernel=False,
        eliminate_empty_output_lists=False):
    """
    :arg context: A :class:`pyopencl.Context`.
    :arg list_names_and_dtypes: a list of ``(name, dtype)`` tuples
        indicating the lists to be built.
    :arg generate_template: a snippet of C as described below
    :arg arg_decls: A string of comma-separated C argument declarations.
    :arg count_sharing: A mapping consisting of ``(child, mother)``
        indicating that ``mother`` and ``child`` will always have the
        same number of indices, and the ``APPEND`` to ``mother``
        will always happen *before* the ``APPEND`` to the child.
    :arg devices: the devices to build for, defaulting to
        ``context.devices``.
    :arg name_prefix: the name prefix to use for the compiled kernels
    :arg options: OpenCL compilation options for kernels using
        *generate_template*.
    :arg preamble: C code inserted ahead of the generated helper macros
        and *generate_template*.
    :arg complex_kernel: If *True*, prevents vectorization on CPUs.
    :arg eliminate_empty_output_lists: A Python list of list names
        for which the empty output lists are eliminated. *True* selects
        all lists, *False* selects none.
    :raises ValueError: if *eliminate_empty_output_lists* contains a name
        that does not occur in *list_names_and_dtypes*.

    *generate_template* may use the following C macros/identifiers:

    * ``index_type``: expands to C identifier for the index type used
      for the calculation
    * ``USER_ARG_DECL``: expands to the C declarator for ``arg_decls``
    * ``USER_ARGS``: a list of C argument values corresponding to
      ``USER_ARG_DECL``
    * ``LIST_ARG_DECL``: expands to a C argument list representing the
      data for the output lists. These names are prefixed with
      ``"plb_"`` so as to not interfere with user-provided names.
    * ``LIST_ARGS``: a list of C argument values corresponding to
      ``LIST_ARG_DECL``
    * ``APPEND_name(entry)``: inserts ``entry`` into the list ``name``.
      *entry* must be a valid C expression of the correct type.

    All argument-list related macros have a trailing comma included
    if they are non-empty.

    *generate_template* must supply a function (the list arguments come
    first, matching the call made by the generated kernel):

    .. code-block:: c

        void generate(LIST_ARG_DECL USER_ARG_DECL index_type i)
        {
            APPEND_mylist(5);
        }

    Internally, the ``kernel_template`` is expanded (at least) twice. Once,
    for a 'counting' stage where the size of all the lists is determined,
    and a second time, for a 'generation' stage where the lists are
    actually filled. A ``generate`` function that has side effects beyond
    calling ``append`` is therefore ill-formed.

    .. versionchanged:: 2018.1

        Change *eliminate_empty_output_lists* argument type from ``bool`` to
        ``list``.
    """
    from pyopencl.tools import VectorArg, parse_arg_list

    self.context = context
    self.devices = context.devices if devices is None else devices

    self.list_names_and_dtypes = list_names_and_dtypes
    self.generate_template = generate_template

    self.arg_decls = parse_arg_list(arg_decls)

    # To match the signature of the user-supplied generate(), arguments
    # can't appear to have offsets: re-declare such arrays without one.
    self.arg_decls_no_offset = [
        VectorArg(arg.dtype, arg.name)
        if isinstance(arg, VectorArg) and arg.with_offset
        else arg
        for arg in self.arg_decls]

    self.count_sharing = {} if count_sharing is None else count_sharing

    self.name_prefix = name_prefix
    self.preamble = preamble
    self.options = options

    self.debug = debug

    self.complex_kernel = complex_kernel

    # Normalize the legacy boolean forms to an explicit list of names.
    if eliminate_empty_output_lists is True:
        eliminate_empty_output_lists = [
            name for name, _ in self.list_names_and_dtypes]
    elif eliminate_empty_output_lists is False:
        eliminate_empty_output_lists = []

    self.eliminate_empty_output_lists = eliminate_empty_output_lists

    known_list_names = [name for name, _ in self.list_names_and_dtypes]
    for list_name in self.eliminate_empty_output_lists:
        if list_name not in known_list_names:
            raise ValueError(
                "invalid list name '%s' in eliminate_empty_output_lists"
                % list_name)
882
+
883
+ # {{{ kernel generators
884
+
885
@memoize_method
def get_scan_kernel(self, index_dtype):
    """Return a scan kernel over a single array ``ary`` of *index_dtype*.

    The kernel performs an additive scan of ``ary`` and stores the result
    for index ``i`` at ``ary[i+1]``.
    """
    index_ctype = dtype_to_ctype(index_dtype)
    scan_options = {
        "arguments": "__global %s *ary" % index_ctype,
        "input_expr": "ary[i]",
        "scan_expr": "a+b",
        "neutral": "0",
        "output_statement": "ary[i+1] = item;",
        "devices": self.devices,
        }

    return GenericScanKernel(self.context, index_dtype, **scan_options)
894
+
895
@memoize_method
def get_compress_kernel(self, index_dtype):
    """Return a scan kernel that compresses away zero entries of ``count``.

    The scan input is 1 for every non-zero ``count[i]`` and 0 otherwise.
    For each index at which the scan value changes, the kernel records the
    index in ``nonempty_indices`` and the count in ``compressed_counts``;
    it stores the scan value for index ``i`` at ``compressed_indices[i+1]``
    and the final scan value in ``*num_non_empty_list``.
    """
    argument_template = Template("""
__global ${index_t} *count,
__global ${index_t} *compressed_counts,
__global ${index_t} *nonempty_indices,
__global ${index_t} *compressed_indices,
__global ${index_t} *num_non_empty_list
""")
    rendered_arguments = argument_template.render(
        index_t=dtype_to_ctype(index_dtype))

    output_statement = """
if (i + 1 < N) compressed_indices[i + 1] = item;
if (prev_item != item) {
nonempty_indices[item - 1] = i;
compressed_counts[item - 1] = count[i];
}
if (i + 1 == N) *num_non_empty_list = item;
"""

    return GenericScanKernel(
        self.context, index_dtype,
        arguments=rendered_arguments,
        input_expr="count[i] == 0 ? 0 : 1",
        scan_expr="a+b", neutral="0",
        output_statement=output_statement,
        devices=self.devices)
920
+
921
def do_not_vectorize(self):
    """Tell whether the generated kernels should avoid vectorization.

    The answer is truthy only when ``self.complex_kernel`` is set *and* at
    least one device of ``self.context`` has the CPU bit in its device type.
    When ``self.complex_kernel`` is falsy, that value is returned as-is and
    the devices are not inspected.
    """
    complex_kernel = self.complex_kernel
    if not complex_kernel:
        return complex_kernel

    for dev in self.context.devices:
        if dev.type & cl.device_type.CPU:
            return True

    return False
925
+
926
@memoize_method
def get_count_kernel(self, index_dtype):
    """Build and return the kernel for the counting stage.

    The source is rendered from ``_LIST_BUILDER_TEMPLATE`` with
    ``is_count_stage=True``, compiled with ``self.options`` and the kernel
    named ``<name_prefix>_count`` is extracted.  Lists that appear in
    ``self.count_sharing`` get no counter argument of their own.

    :arg index_dtype: dtype of the per-list counters and of the trailing
        scalar kernel argument.
    :returns: the compiled kernel with its scalar argument dtypes set.
    """
    from pyopencl.characterize import has_double_support
    from pyopencl.tools import (
        OtherArg,
        VectorArg,
        get_arg_list_scalar_arg_dtypes,
    )

    index_ctype = dtype_to_ctype(index_dtype)

    # Only lists that own a count take part in the counting stage.
    counted_names = [
        name
        for name, _ in self.list_names_and_dtypes
        if name not in self.count_sharing
    ]

    # Global per-list count arrays, as seen by the kernel itself.
    kernel_list_args = [
        VectorArg(index_dtype, "plb_%s_count" % name)
        for name in counted_names
    ]

    # Local counters handed to the user-supplied generate() function.
    user_list_args = []
    for name in counted_names:
        local_counter = "plb_loc_%s_count" % name
        user_list_args.append(
            OtherArg("{} *{}".format(index_ctype, local_counter),
                     local_counter))

    kernel_name = self.name_prefix+"_count"

    template_vars = {
        "is_count_stage": True,
        "kernel_name": kernel_name,
        "double_support": all(
            has_double_support(dev) for dev in self.context.devices),
        "debug": self.debug,
        "do_not_vectorize": self.do_not_vectorize(),
        "eliminate_empty_output_lists": self.eliminate_empty_output_lists,

        "kernel_list_arg_decl": _get_arg_decl(kernel_list_args),
        "kernel_list_arg_values": _get_arg_list(user_list_args, prefix="&"),
        "user_list_arg_decl": _get_arg_decl(user_list_args),
        "user_list_args": _get_arg_list(user_list_args),
        "user_arg_decl_with_offset": _get_arg_decl(self.arg_decls),
        "user_arg_decl_no_offset": _get_arg_decl(self.arg_decls_no_offset),
        "user_args_no_offset": _get_arg_list(self.arg_decls_no_offset),
        "arg_offset_adjustment":
            get_arg_offset_adjuster_code(self.arg_decls),

        "list_names_and_dtypes": self.list_names_and_dtypes,
        "count_sharing": self.count_sharing,
        "name_prefix": self.name_prefix,
        "generate_template": self.generate_template,
        "preamble": self.preamble,

        "index_type": index_ctype,
    }
    source = str(_LIST_BUILDER_TEMPLATE.render(**template_vars))

    program = cl.Program(self.context, source).build(self.options)
    kernel = getattr(program, kernel_name)

    # Kernel signature: list counters, then user arguments, then one
    # trailing scalar of the index type.
    scalar_dtypes = list(get_arg_list_scalar_arg_dtypes(
        [*kernel_list_args, *self.arg_decls]))
    scalar_dtypes.append(index_dtype)
    kernel.set_scalar_arg_dtypes(scalar_dtypes)

    return kernel
986
+
987
@memoize_method
def get_write_kernel(self, index_dtype):
    """Build and return the kernel for the writing stage.

    The source is rendered from ``_LIST_BUILDER_TEMPLATE`` with
    ``is_count_stage=False``, compiled with ``self.options`` and the kernel
    named ``<name_prefix>_write`` is extracted.

    Every list contributes a ``plb_<name>_list`` array.  Lists that do not
    share counts additionally contribute a ``plb_<name>_start_index`` array
    and, if named in ``self.eliminate_empty_output_lists``, a
    ``<name>_compressed_indices`` array.

    :arg index_dtype: dtype of the index arrays and of the trailing scalar
        kernel argument.
    :returns: the compiled kernel with its scalar argument dtypes set.
    """
    from pyopencl.characterize import has_double_support
    from pyopencl.tools import (
        OtherArg,
        VectorArg,
        get_arg_list_scalar_arg_dtypes,
    )

    index_ctype = dtype_to_ctype(index_dtype)

    kernel_list_args = []
    user_list_args = []
    # Pieces of the argument list passed on to the user's generate().
    value_fragments = []

    for name, dtype in self.list_names_and_dtypes:
        list_name = "plb_%s_list" % name
        list_arg = VectorArg(dtype, list_name)

        kernel_list_args.append(list_arg)
        user_list_args.append(list_arg)

        if name in self.count_sharing:
            # Count-sharing lists have no start index of their own.
            value_fragments.append("%s, " % list_name)
        else:
            kernel_list_args.append(
                VectorArg(index_dtype, "plb_%s_start_index" % name))

            if name in self.eliminate_empty_output_lists:
                kernel_list_args.append(
                    VectorArg(index_dtype, "%s_compressed_indices" % name))

            index_name = "plb_%s_index" % name
            user_list_args.append(
                OtherArg("{} *{}".format(index_ctype, index_name),
                         index_name))

            value_fragments.append(f"{list_name}, &{index_name}, ")

    kernel_list_arg_values = "".join(value_fragments)
    kernel_name = self.name_prefix+"_write"

    template_vars = {
        "is_count_stage": False,
        "kernel_name": kernel_name,
        "double_support": all(
            has_double_support(dev) for dev in self.context.devices),
        "debug": self.debug,
        "do_not_vectorize": self.do_not_vectorize(),
        "eliminate_empty_output_lists": self.eliminate_empty_output_lists,

        "kernel_list_arg_decl": _get_arg_decl(kernel_list_args),
        "kernel_list_arg_values": kernel_list_arg_values,
        "user_list_arg_decl": _get_arg_decl(user_list_args),
        "user_list_args": _get_arg_list(user_list_args),
        "user_arg_decl_with_offset": _get_arg_decl(self.arg_decls),
        "user_arg_decl_no_offset": _get_arg_decl(self.arg_decls_no_offset),
        "user_args_no_offset": _get_arg_list(self.arg_decls_no_offset),
        "arg_offset_adjustment":
            get_arg_offset_adjuster_code(self.arg_decls),

        "list_names_and_dtypes": self.list_names_and_dtypes,
        "count_sharing": self.count_sharing,
        "name_prefix": self.name_prefix,
        "generate_template": self.generate_template,
        "preamble": self.preamble,

        "index_type": index_ctype,
    }
    source = str(_LIST_BUILDER_TEMPLATE.render(**template_vars))

    program = cl.Program(self.context, source).build(self.options)
    kernel = getattr(program, kernel_name)

    # Kernel signature: list arguments, then user arguments, then one
    # trailing scalar of the index type.
    scalar_dtypes = list(get_arg_list_scalar_arg_dtypes(
        kernel_list_args + self.arg_decls))
    scalar_dtypes.append(index_dtype)
    kernel.set_scalar_arg_dtypes(scalar_dtypes)

    return kernel
1060
+
1061
+ # }}}
1062
+
1063
+ # {{{ driver
1064
+
1065
def __call__(self, queue, n_objects, *args, **kwargs):
    """
    :arg args: arguments corresponding to ``arg_decls`` in the constructor.
        Array-like arguments must be either 1D :class:`pyopencl.array.Array`
        objects or :class:`pyopencl.MemoryObject` objects, of which the latter
        can be obtained from a :class:`pyopencl.array.Array` using the
        :attr:`pyopencl.array.Array.data` attribute.
    :arg allocator: optionally, the allocator to use to allocate new
        arrays.
    :arg omit_lists: an iterable of list names that should *not* be built
        with this invocation. The kernel code may *not* call ``APPEND_name``
        for these omitted lists. If it does, undefined behavior will result.
        The returned *lists* dictionary will not contain an entry for names
        in *omit_lists*.
    :arg wait_for: |explain-waitfor|
    :returns: a tuple ``(lists, event)``, where ``lists`` is a mapping from
        (built) list names to objects which have attributes

        * ``count`` for the total number of entries in all lists combined
        * ``lists`` for the array containing all lists.
        * ``starts`` for the array of starting indices in ``lists``.
          ``starts`` is built so that it has n+1 entries, so that
          the *i*'th entry is the start of the *i*'th list, and the
          (*i*+1)'th entry is the index one past the *i*'th list's end,
          even for the last list.

          This implies that all lists are contiguous.

        If the list name is specified in *eliminate_empty_output_lists*
        constructor argument, *lists* has two additional attributes
        ``num_nonempty_lists`` and ``nonempty_indices``

        * ``num_nonempty_lists`` for the number of nonempty lists.
        * ``nonempty_indices`` for the index of nonempty list in input objects.

        In this case, ``starts`` has ``num_nonempty_lists + 1`` entries.
        The *i*'th entry is the start of the *i*'th nonempty list, which is
        generated by the object with index ``nonempty_indices[i]``.

        *event* is a :class:`pyopencl.Event` for dependency management.

    :raises TypeError: if an unknown keyword argument is passed.
    :raises ValueError: if *omit_lists* names an unknown list, if an array
        argument is not one-dimensional, or if a raw memory object is passed
        for an argument declared with an offset.

    .. versionchanged:: 2016.2

        Added omit_lists.
    """
    # Use 64-bit indices only when the object count demands it.
    if n_objects >= int(np.iinfo(np.int32).max):
        index_dtype = np.int64
    else:
        index_dtype = np.int32
    index_dtype = np.dtype(index_dtype)

    allocator = kwargs.pop("allocator", None)
    omit_lists = kwargs.pop("omit_lists", [])
    wait_for = kwargs.pop("wait_for", None)
    if kwargs:
        raise TypeError("invalid keyword arguments: '%s'" % ", ".join(kwargs))

    for oml in omit_lists:
        if not any(oml == name for name, _ in self.list_names_and_dtypes):
            # Interpolate the offending name so the message is actionable.
            raise ValueError("invalid list name '%s' in omit_lists" % oml)

    result = {}
    count_list_args = []

    if wait_for is None:
        wait_for = []
    else:
        # We'll be modifying it below.
        wait_for = list(wait_for)

    count_kernel = self.get_count_kernel(index_dtype)
    write_kernel = self.get_write_kernel(index_dtype)
    scan_kernel = self.get_scan_kernel(index_dtype)
    if self.eliminate_empty_output_lists:
        compress_kernel = self.get_compress_kernel(index_dtype)

    # {{{ flatten user arguments into raw kernel arguments

    data_args = []
    for i, (arg_descr, arg_val) in enumerate(zip(self.arg_decls, args)):
        from pyopencl.tools import VectorArg
        if isinstance(arg_descr, VectorArg):
            from pyopencl import MemoryObject
            if arg_val is None:
                data_args.append(arg_val)
                if arg_descr.with_offset:
                    data_args.append(0)
                continue

            if isinstance(arg_val, MemoryObject):
                data_args.append(arg_val)
                if arg_descr.with_offset:
                    raise ValueError(
                            "with_offset=True specified for argument %d "
                            "but the argument is not an array" % i)
                continue

            if arg_val.ndim != 1:
                raise ValueError("argument %d is a multidimensional array" % i)

            data_args.append(arg_val.base_data)
            if arg_descr.with_offset:
                data_args.append(arg_val.offset)
            # The kernels must not start before the array's pending events.
            wait_for.extend(arg_val.events)
        else:
            data_args.append(arg_val)

    del args
    data_args = tuple(data_args)

    # }}}

    # {{{ allocate memory for counts

    for name, _dtype in self.list_names_and_dtypes:
        if name in self.count_sharing:
            continue
        if name in omit_lists:
            count_list_args.append(None)
            continue

        counts = cl.array.empty(queue,
                (n_objects + 1), index_dtype, allocator=allocator)
        counts[-1] = 0
        wait_for = wait_for + counts.events

        # The scan will turn the "counts" array into the "starts" array
        # in-place.
        if name in self.eliminate_empty_output_lists:
            result[name] = BuiltList(count=None, starts=counts, lists=None,
                                     num_nonempty_lists=None,
                                     nonempty_indices=None)
        else:
            result[name] = BuiltList(count=None, starts=counts, lists=None)
        count_list_args.append(counts.data)

    # }}}

    if self.debug:
        gsize = (1,)
        lsize = (1,)
    elif self.do_not_vectorize():
        gsize = (4*queue.device.max_compute_units,)
        lsize = (1,)
    else:
        from pyopencl.array import _splay
        gsize, lsize = _splay(queue.device, n_objects)

    count_event = count_kernel(queue, gsize, lsize,
            *(tuple(count_list_args) + data_args + (n_objects,)),
            wait_for=wait_for)

    # {{{ compress away empty lists where requested

    compress_events = {}
    for name, _dtype in self.list_names_and_dtypes:
        if name in omit_lists:
            continue
        if name in self.count_sharing:
            continue
        if name not in self.eliminate_empty_output_lists:
            continue

        compressed_counts = cl.array.empty(
            queue, (n_objects + 1,), index_dtype, allocator=allocator)
        info_record = result[name]
        info_record.nonempty_indices = cl.array.empty(
            queue, (n_objects + 1,), index_dtype, allocator=allocator)
        info_record.num_nonempty_lists = cl.array.empty(
            queue, (1,), index_dtype, allocator=allocator)
        info_record.compressed_indices = cl.array.empty(
            queue, (n_objects + 1,), index_dtype, allocator=allocator)
        info_record.compressed_indices[0] = 0

        # compress_kernel is bound above whenever this loop body can run.
        compress_events[name] = compress_kernel(  # pylint: disable=possibly-used-before-assignment
            info_record.starts,
            compressed_counts,
            info_record.nonempty_indices,
            info_record.compressed_indices,
            info_record.num_nonempty_lists,
            wait_for=[count_event, *info_record.compressed_indices.events])

        info_record.starts = compressed_counts

    # }}}

    # {{{ run scans

    scan_events = []

    for name, _dtype in self.list_names_and_dtypes:
        if name in self.count_sharing:
            continue
        if name in omit_lists:
            continue

        info_record = result[name]
        if name in self.eliminate_empty_output_lists:
            # The number of nonempty lists is needed on the host to size
            # the result arrays, so wait for the compression to finish.
            compress_events[name].wait()
            num_nonempty_lists = info_record.num_nonempty_lists.get()[0]
            info_record.num_nonempty_lists = num_nonempty_lists
            info_record.starts = info_record.starts[:num_nonempty_lists + 1]
            info_record.nonempty_indices = \
                info_record.nonempty_indices[:num_nonempty_lists]
            info_record.starts[-1] = 0

        starts_ary = info_record.starts
        if name in self.eliminate_empty_output_lists:
            evt = scan_kernel(
                    starts_ary,
                    size=info_record.num_nonempty_lists,
                    wait_for=starts_ary.events)
        else:
            evt = scan_kernel(starts_ary, wait_for=[count_event],
                    size=n_objects)

        starts_ary.setitem(0, 0, queue=queue, wait_for=[evt])
        scan_events.extend(starts_ary.events)

        # retrieve count
        info_record.count = int(starts_ary[-1].get())

    # }}}

    # {{{ deal with count-sharing lists, allocate memory for lists

    write_list_args = []
    for name, dtype in self.list_names_and_dtypes:
        if name in omit_lists:
            # Omitted lists still occupy their argument slots.
            write_list_args.append(None)
            if name not in self.count_sharing:
                write_list_args.append(None)
            if name in self.eliminate_empty_output_lists:
                write_list_args.append(None)
            continue

        if name in self.count_sharing:
            sharing_from = self.count_sharing[name]

            info_record = result[name] = BuiltList(
                    count=result[sharing_from].count,
                    starts=result[sharing_from].starts,
                    )

        else:
            info_record = result[name]

        info_record.lists = cl.array.empty(queue,
                info_record.count, dtype, allocator=allocator)
        write_list_args.append(info_record.lists.data)

        if name not in self.count_sharing:
            write_list_args.append(info_record.starts.data)

        if name in self.eliminate_empty_output_lists:
            write_list_args.append(info_record.compressed_indices.data)

    # }}}

    evt = write_kernel(queue, gsize, lsize,
            *(tuple(write_list_args) + data_args + (n_objects,)),
            wait_for=scan_events)

    return result, evt
1321
+
1322
+ # }}}
1323
+
1324
+ # }}}
1325
+
1326
+
1327
+ # {{{ key-value sorting
1328
+
1329
@dataclass(frozen=True)
class _KernelInfo:
    """Immutable bundle of the kernels returned by
    :meth:`KeyValueSorter.get_kernels`.
    """

    # Radix sort over the ``values``/``keys`` arrays, keyed on ``keys[i]``.
    by_target_sorter: RadixSort
    # Elementwise kernel ("find_starts") that records the index at which each
    # key's group begins in the sorted key array.
    start_finder: ElementwiseKernel
    # Min-scan over ``starts``, traversed from its last entry to its first.
    bound_propagation_scan: GenericScanKernel
1334
+
1335
+
1336
+ def _make_cl_int_literal(value, dtype):
1337
+ iinfo = np.iinfo(dtype)
1338
+ result = str(int(value))
1339
+ if dtype.itemsize == 8:
1340
+ result += "l"
1341
+ if int(iinfo.min) < 0:
1342
+ result += "u"
1343
+
1344
+ return result
1345
+
1346
+
1347
class KeyValueSorter:
    """Given arrays *values* and *keys* of equal length
    and a number *nkeys* of keys, returns a tuple ``(starts,
    lists, event)``, as follows: *values* and *keys* are sorted
    by *keys*, and the sorted *values* is returned as
    *lists*. Then for each index *i* in ``range(nkeys)``,
    *starts[i]* is written to indicating where the
    group of *values* belonging to the key with index
    *i* begins. It implicitly ends at *starts[i+1]*.

    ``starts`` is built so that it has ``nkeys + 1`` entries, so that
    the *i*'th entry is the start of the *i*'th list, and the
    (*i*+1)'th entry is the index one past the *i*'th list's end,
    even for the last list.

    This implies that all lists are contiguous.

    *event* is the event returned by the last kernel that was enqueued.

    .. note:: This functionality is provided as a preview. Its
        interface is subject to change until this notice is removed.

    .. versionadded:: 2013.1
    """

    def __init__(self, context):
        """
        :arg context: the context in which all kernels are built.
        """
        self.context = context

    @memoize_method
    def get_kernels(self, key_dtype, value_dtype, starts_dtype):
        """Build the kernels needed for one combination of dtypes.

        :arg key_dtype: dtype of the ``keys`` array.
        :arg value_dtype: dtype of the ``values`` array.
        :arg starts_dtype: dtype of the ``starts`` array.
        :returns: a :class:`_KernelInfo` holding the sorter, the start
            finder and the bound-propagation scan.
        """
        from pyopencl.tools import ScalarArg, VectorArg

        by_target_sorter = RadixSort(
                self.context, [
                    VectorArg(value_dtype, "values"),
                    VectorArg(key_dtype, "keys"),
                    ],
                key_expr="keys[i]",
                sort_arg_names=["values", "keys"])

        from pyopencl.elementwise import ElementwiseTemplate
        start_finder = ElementwiseTemplate(
                arguments="""//CL//
                starts_t *key_group_starts,
                key_t *keys_sorted_by_key,
                """,

                operation=r"""//CL//
                key_t my_key = keys_sorted_by_key[i];

                if (i == 0 || my_key != keys_sorted_by_key[i-1])
                    key_group_starts[my_key] = i;
                """,
                name="find_starts").build(self.context,
                        type_aliases=(
                            # The keys must be read with their own element
                            # type, which may differ from that of starts.
                            ("key_t", key_dtype),
                            ("starts_t", starts_dtype),
                            ),
                        var_values=())

        # Walks starts from its last entry towards its first, carrying the
        # running minimum along.
        bound_propagation_scan = GenericScanKernel(
                self.context, starts_dtype,
                arguments=[
                    VectorArg(starts_dtype, "starts"),
                    # starts has length n+1
                    ScalarArg(key_dtype, "nkeys"),
                    ],
                input_expr="starts[nkeys-i]",
                scan_expr="min(a, b)",
                neutral=_make_cl_int_literal(
                    np.iinfo(starts_dtype).max, starts_dtype),
                output_statement="starts[nkeys-i] = item;")

        return _KernelInfo(
                by_target_sorter=by_target_sorter,
                start_finder=start_finder,
                bound_propagation_scan=bound_propagation_scan)

    def __call__(self, queue, keys, values, nkeys,
            starts_dtype, allocator=None, wait_for=None):
        """Sort *values* by *keys* and compute the start of each key group.

        :arg queue: the command queue on which to run.
        :arg keys: array of keys, one per entry of *values*.
        :arg values: array of values to be grouped by key.
        :arg nkeys: the number of distinct keys; ``starts`` receives
            ``nkeys + 1`` entries.
        :arg starts_dtype: dtype of the returned ``starts`` array.
        :arg allocator: allocator for ``starts``; defaults to
            ``values.allocator``.
        :arg wait_for: events passed on to the sort as its ``wait_for``.
        :returns: a tuple ``(starts, values_sorted_by_key, event)``.
        """
        if allocator is None:
            allocator = values.allocator

        knl_info = self.get_kernels(keys.dtype, values.dtype,
                starts_dtype)

        (values_sorted_by_key, keys_sorted_by_key), evt = knl_info.by_target_sorter(
                values, keys, queue=queue, wait_for=wait_for)

        # Every entry starts out as "one past the end"; the start finder
        # overwrites the entries of keys that actually occur.
        starts = (cl.array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator)
                .fill(len(values_sorted_by_key), wait_for=[evt]))
        evt, = starts.events

        evt = knl_info.start_finder(starts, keys_sorted_by_key,
                range=slice(len(keys_sorted_by_key)),
                wait_for=[evt])

        evt = knl_info.bound_propagation_scan(starts, nkeys,
                queue=queue, wait_for=[evt])

        return starts, values_sorted_by_key, evt
1446
+
1447
+ # }}}
1448
+
1449
+ # vim: filetype=pyopencl:fdm=marker