pyopencl 2026.1.1__cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyopencl/.libs/libOpenCL-34a55fe4.so.1.0.0 +0 -0
- pyopencl/__init__.py +1995 -0
- pyopencl/_cl.cpython-314t-aarch64-linux-gnu.so +0 -0
- pyopencl/_cl.pyi +2009 -0
- pyopencl/_cluda.py +57 -0
- pyopencl/_monkeypatch.py +1104 -0
- pyopencl/_mymako.py +17 -0
- pyopencl/algorithm.py +1454 -0
- pyopencl/array.py +3530 -0
- pyopencl/bitonic_sort.py +245 -0
- pyopencl/bitonic_sort_templates.py +597 -0
- pyopencl/cache.py +553 -0
- pyopencl/capture_call.py +200 -0
- pyopencl/characterize/__init__.py +461 -0
- pyopencl/characterize/performance.py +240 -0
- pyopencl/cl/pyopencl-airy.cl +324 -0
- pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
- pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
- pyopencl/cl/pyopencl-bessel-y.cl +435 -0
- pyopencl/cl/pyopencl-complex.h +303 -0
- pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
- pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
- pyopencl/cl/pyopencl-random123/array.h +325 -0
- pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
- pyopencl/cl/pyopencl-random123/philox.cl +486 -0
- pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
- pyopencl/clmath.py +281 -0
- pyopencl/clrandom.py +412 -0
- pyopencl/cltypes.py +217 -0
- pyopencl/compyte/.gitignore +21 -0
- pyopencl/compyte/__init__.py +0 -0
- pyopencl/compyte/array.py +211 -0
- pyopencl/compyte/dtypes.py +314 -0
- pyopencl/compyte/pyproject.toml +49 -0
- pyopencl/elementwise.py +1288 -0
- pyopencl/invoker.py +417 -0
- pyopencl/ipython_ext.py +70 -0
- pyopencl/py.typed +0 -0
- pyopencl/reduction.py +829 -0
- pyopencl/scan.py +1921 -0
- pyopencl/tools.py +1680 -0
- pyopencl/typing.py +61 -0
- pyopencl/version.py +11 -0
- pyopencl-2026.1.1.dist-info/METADATA +108 -0
- pyopencl-2026.1.1.dist-info/RECORD +47 -0
- pyopencl-2026.1.1.dist-info/WHEEL +6 -0
- pyopencl-2026.1.1.dist-info/licenses/LICENSE +104 -0
pyopencl/algorithm.py
ADDED
|
@@ -0,0 +1,1454 @@
|
|
|
1
|
+
"""Algorithms built on scans."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
__copyright__ = """
|
|
6
|
+
Copyright 2011-2012 Andreas Kloeckner
|
|
7
|
+
Copyright 2017 Hao Gao
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
__license__ = """
|
|
11
|
+
Permission is hereby granted, free of charge, to any person
|
|
12
|
+
obtaining a copy of this software and associated documentation
|
|
13
|
+
files (the "Software"), to deal in the Software without
|
|
14
|
+
restriction, including without limitation the rights to use,
|
|
15
|
+
copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
16
|
+
copies of the Software, and to permit persons to whom the
|
|
17
|
+
Software is furnished to do so, subject to the following
|
|
18
|
+
conditions:
|
|
19
|
+
|
|
20
|
+
The above copyright notice and this permission notice shall be
|
|
21
|
+
included in all copies or substantial portions of the Software.
|
|
22
|
+
|
|
23
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
24
|
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
|
25
|
+
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
26
|
+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
|
27
|
+
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
|
28
|
+
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
29
|
+
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
|
30
|
+
OTHER DEALINGS IN THE SOFTWARE.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from dataclasses import dataclass
|
|
34
|
+
from typing import TYPE_CHECKING
|
|
35
|
+
|
|
36
|
+
import numpy as np
|
|
37
|
+
from mako.template import Template
|
|
38
|
+
|
|
39
|
+
from pytools import memoize, memoize_method
|
|
40
|
+
|
|
41
|
+
import pyopencl as cl
|
|
42
|
+
import pyopencl.array as cl_array
|
|
43
|
+
from pyopencl.scan import GenericScanKernel, ScanTemplate
|
|
44
|
+
from pyopencl.tools import dtype_to_ctype, get_arg_offset_adjuster_code
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
if TYPE_CHECKING:
|
|
48
|
+
from pyopencl.elementwise import ElementwiseKernel
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# {{{ "extra args" handling utility
|
|
52
|
+
|
|
53
|
+
def _extract_extra_args_types_values(extra_args):
    """Split ``(name, value)`` pairs into kernel argument descriptors and values.

    :arg extra_args: an iterable of ``(name, value)`` pairs, or *None* for no
        extra arguments. Each *value* must be a :class:`pyopencl.array.Array`
        (passed as an array argument without offset) or a
        :class:`numpy.generic` scalar.
    :returns: a tuple ``(types, values, wait_for)``: *types* is a tuple of
        :class:`~pyopencl.tools.VectorArg` / :class:`~pyopencl.tools.ScalarArg`
        descriptors, *values* is the list of the corresponding argument values,
        and *wait_for* is a list of the pending events of all array arguments.
    :raises RuntimeError: if a value is neither an array nor a numpy scalar.
    """
    if extra_args is None:
        extra_args = []
    from pyopencl.tools import ScalarArg, VectorArg

    extra_args_types = []
    extra_args_values = []
    extra_wait_for = []
    for name, val in extra_args:
        if isinstance(val, cl_array.Array):
            extra_args_types.append(VectorArg(val.dtype, name, with_offset=False))
            extra_args_values.append(val)
            # the kernel must wait for operations still pending on this array
            extra_wait_for.extend(val.events)
        elif isinstance(val, np.generic):
            extra_args_types.append(ScalarArg(val.dtype, name))
            extra_args_values.append(val)
        else:
            # *name* is a string: formatting it with "%d" used to raise a
            # TypeError that masked this intended error.
            raise RuntimeError("argument '%s' not understood" % (name,))

    return tuple(extra_args_types), extra_args_values, extra_wait_for
|
|
73
|
+
|
|
74
|
+
# }}}
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# {{{ copy_if
|
|
78
|
+
|
|
79
|
+
# Scan template behind copy_if. Each element contributes 1 when *predicate*
# holds for ``ary[i]``, and 0 otherwise; the scan sums these. ``item`` is the
# running count including element i (hence the 1-based ``out[item-1]``), and
# ``prev_item`` is presumably the count before element i (scan-kernel
# convention). An element is copied only where the count increases, i.e. where
# it passed the predicate. The last element stores the total into ``*count``.
# ``%(predicate)s`` is substituted printf-style when the template is built.
_copy_if_template = ScanTemplate(
        arguments="item_t *ary, item_t *out, scan_t *count",
        input_expr="(%(predicate)s) ? 1 : 0",
        scan_expr="a+b", neutral="0",
        output_statement="""
            if (prev_item != item) out[item-1] = ary[i];
            if (i+1 == N) *count = item;
            """,
        template_processor="printf")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def copy_if(ary, predicate, extra_args=None, preamble="", queue=None, wait_for=None):
    """Copy the elements of *ary* satisfying *predicate* to an output array.

    :arg predicate: a C expression evaluating to a ``bool``, represented as a string.
        The value to test is available as ``ary[i]``, and if the expression evaluates
        to ``true``, then this value ends up in the output.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg wait_for: |explain-waitfor| Any iterable of events is accepted.
    :returns: a tuple *(out, count, event)* where *out* is the output array, *count*
        is an on-device scalar (fetch to host with ``count.get()``) indicating
        how many elements satisfied *predicate*, and *event* is a
        :class:`pyopencl.Event` for dependency management. *out* is allocated
        to the same length as *ary*, but only the first *count* entries carry
        meaning.

    .. versionadded:: 2013.1
    """
    # The count can never exceed len(ary), so 32 bits suffice unless the
    # array itself is longer than int32 can represent.
    if len(ary) > np.iinfo(np.int32).max:
        scan_dtype = np.int64
    else:
        scan_dtype = np.int32

    # Copy into a fresh list so tuples (or other iterables) work and the
    # caller's sequence is never modified.
    wait_for = [] if wait_for is None else list(wait_for)

    extra_args_types, extra_args_values, extra_wait_for = \
            _extract_extra_args_types_values(extra_args)
    wait_for = wait_for + extra_wait_for

    knl = _copy_if_template.build(ary.context,
            type_aliases=(("scan_t", scan_dtype), ("item_t", ary.dtype)),
            var_values=(("predicate", predicate),),
            more_preamble=preamble, more_arguments=extra_args_types)
    out = cl_array.empty_like(ary)
    # 0-d device array that receives the number of copied elements
    count = ary._new_with_changes(data=None, offset=0,
            shape=(), strides=(), dtype=scan_dtype)

    evt = knl(ary, out, count, *extra_args_values,
            queue=queue, wait_for=wait_for)

    return out, count, evt
|
|
132
|
+
|
|
133
|
+
# }}}
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# {{{ remove_if
|
|
137
|
+
|
|
138
|
+
def remove_if(ary, predicate, extra_args=None, preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* not satisfying *predicate* to an output array.

    This is :func:`copy_if` with the negated predicate.

    :arg predicate: a C expression evaluating to a ``bool``, represented as a string.
        The value to test is available as ``ary[i]``, and if the expression evaluates
        to ``false``, then this value ends up in the output.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array, *count*
        is an on-device scalar (fetch to host with ``count.get()``) indicating
        how many elements did not satisfy *predicate*, and *event* is a
        :class:`pyopencl.Event` for dependency management.

    .. versionadded:: 2013.1
    """
    # Keeping what fails the predicate == keeping what passes its negation.
    negated_predicate = "!(%s)" % predicate
    return copy_if(
            ary, negated_predicate,
            extra_args=extra_args,
            preamble=preamble,
            queue=queue,
            wait_for=wait_for)
|
|
157
|
+
|
|
158
|
+
# }}}
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# {{{ partition
|
|
162
|
+
|
|
163
|
+
# Scan template behind partition. The scan counts the elements satisfying
# *predicate*; ``item`` is that count including element i. An element that
# raised the count (``prev_item != item``) goes to slot ``item-1`` of
# ``out_true``. Otherwise it goes to ``out_false[i-item]``: i+1 elements have
# been seen and ``item`` of them were true, so ``i+1-item`` were false and the
# 0-based slot is ``i-item``. The last element stores the number of true
# elements into ``*count_true``.
_partition_template = ScanTemplate(
        arguments=(
            "item_t *ary, item_t *out_true, item_t *out_false, "
            "scan_t *count_true"),
        input_expr="(%(predicate)s) ? 1 : 0",
        scan_expr="a+b", neutral="0",
        output_statement="""//CL//
                if (prev_item != item)
                    out_true[item-1] = ary[i];
                else
                    out_false[i-item] = ary[i];
                if (i+1 == N) *count_true = item;
                """,
        template_processor="printf")
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def partition(ary, predicate, extra_args=None, preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* into one of two arrays depending on whether
    they satisfy *predicate*.

    :arg predicate: a C expression evaluating to a ``bool``, represented as a string.
        The value to test is available as ``ary[i]``.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg wait_for: |explain-waitfor| Any iterable of events is accepted.
    :returns: a tuple *(out_true, out_false, count, event)* where *count*
        is an on-device scalar (fetch to host with ``count.get()``) indicating
        how many elements satisfied the predicate, and *event* is a
        :class:`pyopencl.Event` for dependency management. Both output arrays
        have the length of *ary*; the first *count* entries of *out_true* and
        the first ``len(ary) - count`` entries of *out_false* are meaningful.

    .. versionadded:: 2013.1
    """
    # pick an unsigned counter type wide enough to hold len(ary)
    if len(ary) > np.iinfo(np.uint32).max:
        scan_dtype = np.uint64
    else:
        scan_dtype = np.uint32

    # Copy into a fresh list so tuples (or other iterables) work and the
    # caller's sequence is never modified.
    wait_for = [] if wait_for is None else list(wait_for)

    extra_args_types, extra_args_values, extra_wait_for = \
            _extract_extra_args_types_values(extra_args)
    wait_for = wait_for + extra_wait_for

    knl = _partition_template.build(
            ary.context,
            type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
            var_values=(("predicate", predicate),),
            more_preamble=preamble, more_arguments=extra_args_types)

    out_true = cl_array.empty_like(ary)
    out_false = cl_array.empty_like(ary)
    # 0-d device array that receives the number of "true" elements
    count = ary._new_with_changes(data=None, offset=0,
            shape=(), strides=(), dtype=scan_dtype)

    evt = knl(ary, out_true, out_false, count, *extra_args_values,
            queue=queue, wait_for=wait_for)

    return out_true, out_false, count, evt
|
|
223
|
+
|
|
224
|
+
# }}}
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
# {{{ unique
|
|
228
|
+
|
|
229
|
+
# Scan template behind unique. Each element is fetched together with its
# predecessor (``ary_im1`` = ary[i-1], ``ary_i`` = ary[i]). Element 0 always
# counts as new; every other element counts as new when IS_EQUAL_EXPR with its
# predecessor is false. New elements are compacted into ``out`` the same way
# copy_if does it, and the last element stores the number kept into
# ``*count_unique``. IS_EQUAL_EXPR is a preamble macro whose body is the
# user's *is_equal_expr*, substituted printf-style when the template is built.
_unique_template = ScanTemplate(
        arguments="item_t *ary, item_t *out, scan_t *count_unique",
        input_fetch_exprs=[
            ("ary_im1", "ary", -1),
            ("ary_i", "ary", 0),
            ],
        input_expr="(i == 0) || (IS_EQUAL_EXPR(ary_im1, ary_i) ? 0 : 1)",
        scan_expr="a+b", neutral="0",
        output_statement="""
            if (prev_item != item) out[item-1] = ary[i];
            if (i+1 == N) *count_unique = item;
            """,
        preamble="#define IS_EQUAL_EXPR(a, b) %(macro_is_equal_expr)s\n",
        template_processor="printf")
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def unique(ary, is_equal_expr="a == b", extra_args=None, preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* into the output if *is_equal_expr*, applied to the
    array element and its predecessor, yields false.

    Works like the UNIX command :program:`uniq`, with a potentially custom
    comparison. This operation is often used on sorted sequences.

    :arg is_equal_expr: a C expression evaluating to a ``bool``,
        represented as a string. The elements being compared are
        available as ``a`` and ``b``. If this expression yields ``false``, the
        two are considered distinct.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg wait_for: |explain-waitfor| Any iterable of events is accepted.
    :returns: a tuple *(out, count, event)* where *out* is the output array, *count*
        is an on-device scalar (fetch to host with ``count.get()``) indicating
        how many elements were written to *out* (the first element of *ary*
        is always kept), and *event* is a
        :class:`pyopencl.Event` for dependency management. *out* has the
        length of *ary*; only its first *count* entries carry meaning.

    .. versionadded:: 2013.1
    """

    # pick an unsigned counter type wide enough to hold len(ary)
    if len(ary) > np.iinfo(np.uint32).max:
        scan_dtype = np.uint64
    else:
        scan_dtype = np.uint32

    # Copy into a fresh list so tuples (or other iterables) work and the
    # caller's sequence is never modified.
    wait_for = [] if wait_for is None else list(wait_for)

    extra_args_types, extra_args_values, extra_wait_for = \
            _extract_extra_args_types_values(extra_args)
    wait_for = wait_for + extra_wait_for

    knl = _unique_template.build(
            ary.context,
            type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
            var_values=(("macro_is_equal_expr", is_equal_expr),),
            more_preamble=preamble, more_arguments=extra_args_types)

    out = cl_array.empty_like(ary)
    # 0-d device array that receives the number of unique elements
    count = ary._new_with_changes(data=None, offset=0,
            shape=(), strides=(), dtype=scan_dtype)

    evt = knl(ary, out, count, *extra_args_values,
            queue=queue, wait_for=wait_for)

    return out, count, evt
|
|
294
|
+
|
|
295
|
+
# }}}
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
# {{{ radix_sort
|
|
299
|
+
|
|
300
|
+
def to_bin(n):
    """Return the binary digits of the non-negative integer *n*.

    No ``0b`` prefix is included, and ``to_bin(0)`` returns the empty string
    (callers pad the result to a fixed width).

    :raises ValueError: if *n* is negative. (The former hand-rolled loop never
        terminated in that case, since ``-1 >> 1 == -1``.)
    """
    if n < 0:
        raise ValueError("to_bin requires a non-negative integer, got %r" % (n,))
    # format(0, "b") would give "0"; keep the historical "" for zero
    return format(n, "b") if n else ""
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _padded_bin(i, nbits):
    """Return :func:`to_bin` of *i*, left-padded with ``"0"`` to at least
    *nbits* characters (longer results are returned unchanged)."""
    digits = to_bin(i)
    return digits.rjust(nbits, "0")
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
@memoize
def _make_sort_scan_type(device, bits, index_dtype):
    """Create and register the struct type carried through a radix-sort scan.

    The struct holds one *index_dtype* counter per bin, i.e. ``2**bits``
    fields named ``c`` followed by the bin number written in *bits*-digit
    binary (e.g. ``c00``, ``c01``, ``c10``, ``c11`` for ``bits == 2``).

    :returns: a tuple ``(name, dtype, c_decl)``: the C type name, the dtype
        returned by :func:`pyopencl.tools.get_or_register_dtype`, and the C
        declaration produced by :func:`pyopencl.tools.match_dtype_to_c_struct`.
    """
    from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct

    type_name = "pyopencl_sort_scan_%s_%dbits_t" % (
            index_dtype.type.__name__, bits)

    struct_dtype = np.dtype([
        ("c%s" % _padded_bin(bin_nr, bits), index_dtype)
        for bin_nr in range(2**bits)
        ])

    # obtain the device-matched dtype and its C declaration, then register
    # the dtype under type_name
    struct_dtype, c_decl = match_dtype_to_c_struct(device, type_name, struct_dtype)
    struct_dtype = get_or_register_dtype(type_name, struct_dtype)

    return type_name, struct_dtype, c_decl
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
# {{{ types, helpers preamble
|
|
336
|
+
|
|
337
|
+
RADIX_SORT_PREAMBLE_TPL = Template(r"""//CL//
|
|
338
|
+
typedef ${scan_ctype} scan_t;
|
|
339
|
+
typedef ${key_ctype} key_t;
|
|
340
|
+
typedef ${index_ctype} index_t;
|
|
341
|
+
|
|
342
|
+
// #define DEBUG
|
|
343
|
+
#ifdef DEBUG
|
|
344
|
+
#define dbg_printf(ARGS) printf ARGS
|
|
345
|
+
#else
|
|
346
|
+
#define dbg_printf(ARGS) /* */
|
|
347
|
+
#endif
|
|
348
|
+
|
|
349
|
+
index_t get_count(scan_t s, int mnr)
|
|
350
|
+
{
|
|
351
|
+
return ${get_count_branch("")};
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
#define BIN_NR(key_arg) ((key_arg >> base_bit) & ${2**bits - 1})
|
|
355
|
+
|
|
356
|
+
""", strict_undefined=True)
|
|
357
|
+
|
|
358
|
+
# }}}
|
|
359
|
+
|
|
360
|
+
# {{{ scan helpers
|
|
361
|
+
|
|
362
|
+
RADIX_SORT_SCAN_PREAMBLE_TPL = Template(r"""//CL//
|
|
363
|
+
scan_t scan_t_neutral()
|
|
364
|
+
{
|
|
365
|
+
scan_t result;
|
|
366
|
+
%for mnr in range(2**bits):
|
|
367
|
+
result.c${padded_bin(mnr, bits)} = 0;
|
|
368
|
+
%endfor
|
|
369
|
+
return result;
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
// considers bits (base_bit+bits-1, ..., base_bit)
|
|
373
|
+
scan_t scan_t_from_value(
|
|
374
|
+
key_t key,
|
|
375
|
+
int base_bit,
|
|
376
|
+
int i
|
|
377
|
+
)
|
|
378
|
+
{
|
|
379
|
+
// extract relevant bit range
|
|
380
|
+
key_t bin_nr = BIN_NR(key);
|
|
381
|
+
|
|
382
|
+
dbg_printf(("i: %d key:%d bin_nr:%d\n", i, key, bin_nr));
|
|
383
|
+
|
|
384
|
+
scan_t result;
|
|
385
|
+
%for mnr in range(2**bits):
|
|
386
|
+
result.c${padded_bin(mnr, bits)} = (bin_nr == ${mnr});
|
|
387
|
+
%endfor
|
|
388
|
+
|
|
389
|
+
return result;
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
scan_t scan_t_add(scan_t a, scan_t b, bool across_seg_boundary)
|
|
393
|
+
{
|
|
394
|
+
%for mnr in range(2**bits):
|
|
395
|
+
<% field = "c"+padded_bin(mnr, bits) %>
|
|
396
|
+
b.${field} = a.${field} + b.${field};
|
|
397
|
+
%endfor
|
|
398
|
+
|
|
399
|
+
return b;
|
|
400
|
+
}
|
|
401
|
+
""", strict_undefined=True)
|
|
402
|
+
|
|
403
|
+
RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL//
|
|
404
|
+
{
|
|
405
|
+
key_t key = ${key_expr};
|
|
406
|
+
key_t my_bin_nr = BIN_NR(key);
|
|
407
|
+
|
|
408
|
+
index_t previous_bins_size = 0;
|
|
409
|
+
%for mnr in range(2**bits):
|
|
410
|
+
previous_bins_size +=
|
|
411
|
+
(my_bin_nr > ${mnr})
|
|
412
|
+
? last_item.c${padded_bin(mnr, bits)}
|
|
413
|
+
: 0;
|
|
414
|
+
%endfor
|
|
415
|
+
|
|
416
|
+
index_t tgt_idx =
|
|
417
|
+
previous_bins_size
|
|
418
|
+
+ get_count(item, my_bin_nr) - 1;
|
|
419
|
+
|
|
420
|
+
%for arg_name in sort_arg_names:
|
|
421
|
+
sorted_${arg_name}[tgt_idx] = ${arg_name}[i];
|
|
422
|
+
%endfor
|
|
423
|
+
}
|
|
424
|
+
""", strict_undefined=True)
|
|
425
|
+
|
|
426
|
+
# }}}
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
# {{{ driver
|
|
430
|
+
|
|
431
|
+
class RadixSort:
    """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`__
    on the compute device.

    .. seealso:: :class:`pyopencl.bitonic_sort.BitonicSort`

    .. versionadded:: 2013.1
    """
    def __init__(self, context, arguments, key_expr, sort_arg_names,
            bits_at_a_time=2, index_dtype=np.int32, key_dtype=np.uint32,
            scan_kernel=GenericScanKernel, options=None):
        """
        :arg arguments: A string of comma-separated C argument declarations.
            All types used here must be known to PyOpenCL
            (see :func:`pyopencl.tools.get_or_register_dtype`).
        :arg key_expr: An integer-valued C expression returning the
            key based on which the sort is performed. The array index
            for which the key is to be computed is available as ``i``.
            The expression may refer to any of the *arguments*.
        :arg sort_arg_names: A list of argument names whose corresponding
            array arguments will be sorted according to *key_expr*.
        :arg bits_at_a_time: number of key bits handled per scan pass. Each
            pass carries ``2**bits_at_a_time`` counters per element.
        :arg index_dtype: integer dtype of the per-bin counters and target
            indices.
        :arg key_dtype: dtype of the value computed by *key_expr*. Unless
            *key_bits* is passed to :meth:`__call__`, its bit width determines
            how many bits are sorted.
        :arg scan_kernel: the scan kernel class to instantiate, called with
            the same arguments as :class:`~pyopencl.scan.GenericScanKernel`.
        :arg options: OpenCL compiler options passed on to *scan_kernel*.
        """

        # {{{ arg processing

        from pyopencl.tools import parse_arg_list
        self.arguments = parse_arg_list(arguments)
        del arguments

        self.sort_arg_names = sort_arg_names
        self.bits = int(bits_at_a_time)
        self.index_dtype = np.dtype(index_dtype)
        self.key_dtype = np.dtype(key_dtype)

        self.options = options

        # }}}

        # {{{ kernel creation

        scan_ctype, scan_dtype, scan_t_cdecl = \
                _make_sort_scan_type(context.devices[0], self.bits, self.index_dtype)

        from pyopencl.tools import ScalarArg, VectorArg
        # user arguments, then one output array per sorted argument (in
        # argument order), then the bit offset of the current pass
        scan_arguments = (
                list(self.arguments)
                + [VectorArg(arg.dtype, "sorted_"+arg.name) for arg in self.arguments
                    if arg.name in sort_arg_names]
                + [ScalarArg(np.int32, "base_bit")])

        def get_count_branch(known_bits):
            # Emit a nested ternary that selects counter number ``mnr`` by
            # binary search over the remaining (unknown) bits.
            if len(known_bits) == self.bits:
                return "s.c%s" % known_bits

            boundary_mnr = known_bits + "1" + (self.bits-len(known_bits)-1)*"0"

            return ("((mnr < {}) ? {} : {})".format(
                int(boundary_mnr, 2),
                get_count_branch(known_bits+"0"),
                get_count_branch(known_bits+"1")))

        codegen_args = {
                "bits": self.bits,
                "key_ctype": dtype_to_ctype(self.key_dtype),
                "key_expr": key_expr,
                "index_ctype": dtype_to_ctype(self.index_dtype),
                "index_type_max": np.iinfo(self.index_dtype).max,
                "padded_bin": _padded_bin,
                "scan_ctype": scan_ctype,
                "sort_arg_names": sort_arg_names,
                "get_count_branch": get_count_branch,
                }

        preamble = scan_t_cdecl+RADIX_SORT_PREAMBLE_TPL.render(**codegen_args)
        scan_preamble = preamble \
                + RADIX_SORT_SCAN_PREAMBLE_TPL.render(**codegen_args)

        self.scan_kernel = scan_kernel(
                context, scan_dtype,
                arguments=scan_arguments,
                input_expr="scan_t_from_value(%s, base_bit, i)" % key_expr,
                scan_expr="scan_t_add(a, b, across_seg_boundary)",
                neutral="scan_t_neutral()",
                output_statement=RADIX_SORT_OUTPUT_STMT_TPL.render(**codegen_args),
                preamble=scan_preamble, options=self.options)

        # Index of the *first* array argument (as documented for __call__'s
        # *queue*). The length, default allocator and default queue are taken
        # from it. Previously the loop had no break and picked the last one.
        for i, arg in enumerate(self.arguments):
            if isinstance(arg, VectorArg):
                self.first_array_arg_idx = i
                break

        # }}}

    def __call__(self, *args, **kwargs):
        """Run the radix sort. In addition to *args* which must match the
        *arguments* specification on the constructor, the following
        keyword arguments are supported:

        :arg key_bits: specify how many bits (starting from least-significant)
            there are in the key. Must be positive.
        :arg allocator: See the *allocator* argument of :func:`pyopencl.array.empty`.
        :arg queue: A :class:`pyopencl.CommandQueue`, defaulting to the
            one from the first argument array.
        :arg wait_for: |explain-waitfor|
        :returns: A tuple ``(sorted, event)``. *sorted* consists of sorted
            copies of the arrays named in *sort_arg_names*, in the order of that
            list. *event* is a :class:`pyopencl.Event` for dependency management.
        :raises ValueError: if *key_bits* is not positive.
        """

        wait_for = kwargs.pop("wait_for", None)

        # {{{ run control

        key_bits = kwargs.pop("key_bits", None)
        if key_bits is None:
            key_bits = int(np.iinfo(self.key_dtype).bits)
        if key_bits <= 0:
            # Otherwise no pass runs and no event exists to return.
            raise ValueError("key_bits must be positive, got %r" % (key_bits,))

        first_array = args[self.first_array_arg_idx]
        n = len(first_array)

        allocator = kwargs.pop("allocator", None)
        if allocator is None:
            allocator = first_array.allocator

        queue = kwargs.pop("queue", None)
        if queue is None:
            queue = first_array.queue

        args = list(args)

        # Positions (within self.arguments) of the arrays being sorted, in
        # argument order. This is the order of the sorted_* kernel arguments
        # and of the per-pass output arrays below.
        sorted_positions = [
                i for i, arg_descr in enumerate(self.arguments)
                if arg_descr.name in self.sort_arg_names]

        base_bit = 0
        while base_bit < key_bits:
            sorted_args = [
                    cl_array.empty(queue, n, self.arguments[i].dtype,
                        allocator=allocator)
                    for i in sorted_positions]

            scan_args = args + sorted_args + [base_bit]

            last_evt = self.scan_kernel(*scan_args,
                    queue=queue, wait_for=wait_for)
            wait_for = [last_evt]

            # The next pass reads this pass's output. Match outputs to inputs
            # by position: looking them up via sort_arg_names.index() swapped
            # arrays whenever sort_arg_names was not in argument order.
            for i, sorted_ary in zip(sorted_positions, sorted_args, strict=True):
                args[i] = sorted_ary

            base_bit += self.bits

        # report results in the order of sort_arg_names, as documented
        sorted_by_name = {self.arguments[i].name: args[i] for i in sorted_positions}
        return [sorted_by_name[name] for name in self.sort_arg_names
                if name in sorted_by_name], last_evt

        # }}}
|
|
583
|
+
|
|
584
|
+
# }}}
|
|
585
|
+
|
|
586
|
+
# }}}
|
|
587
|
+
|
|
588
|
+
# }}}
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
# {{{ generic parallel list builder
|
|
592
|
+
|
|
593
|
+
# {{{ kernel template
|
|
594
|
+
|
|
595
|
+
_LIST_BUILDER_TEMPLATE = Template("""//CL//
|
|
596
|
+
% if double_support:
|
|
597
|
+
#if __OPENCL_C_VERSION__ < 120
|
|
598
|
+
#pragma OPENCL EXTENSION cl_khr_fp64: enable
|
|
599
|
+
#endif
|
|
600
|
+
#define PYOPENCL_DEFINE_CDOUBLE
|
|
601
|
+
% endif
|
|
602
|
+
|
|
603
|
+
#include <pyopencl-complex.h>
|
|
604
|
+
|
|
605
|
+
${preamble}
|
|
606
|
+
|
|
607
|
+
// {{{ declare helper macros for user interface
|
|
608
|
+
|
|
609
|
+
typedef ${index_type} index_type;
|
|
610
|
+
|
|
611
|
+
%if is_count_stage:
|
|
612
|
+
#define PLB_COUNT_STAGE
|
|
613
|
+
|
|
614
|
+
%for name, dtype in list_names_and_dtypes:
|
|
615
|
+
%if name in count_sharing:
|
|
616
|
+
#define APPEND_${name}(value) { /* nothing */ }
|
|
617
|
+
%else:
|
|
618
|
+
#define APPEND_${name}(value) { ++(*plb_loc_${name}_count); }
|
|
619
|
+
%endif
|
|
620
|
+
%endfor
|
|
621
|
+
%else:
|
|
622
|
+
#define PLB_WRITE_STAGE
|
|
623
|
+
|
|
624
|
+
%for name, dtype in list_names_and_dtypes:
|
|
625
|
+
%if name in count_sharing:
|
|
626
|
+
#define APPEND_${name}(value) \
|
|
627
|
+
{ plb_${name}_list[(*plb_${count_sharing[name]}_index) - 1] \
|
|
628
|
+
= value; }
|
|
629
|
+
%else:
|
|
630
|
+
#define APPEND_${name}(value) \
|
|
631
|
+
{ plb_${name}_list[(*plb_${name}_index)++] = value; }
|
|
632
|
+
%endif
|
|
633
|
+
%endfor
|
|
634
|
+
%endif
|
|
635
|
+
|
|
636
|
+
#define LIST_ARG_DECL ${user_list_arg_decl}
|
|
637
|
+
#define LIST_ARGS ${user_list_args}
|
|
638
|
+
#define USER_ARG_DECL ${user_arg_decl_no_offset}
|
|
639
|
+
#define USER_ARGS ${user_args_no_offset}
|
|
640
|
+
|
|
641
|
+
// }}}
|
|
642
|
+
|
|
643
|
+
${generate_template}
|
|
644
|
+
|
|
645
|
+
// {{{ kernel entry point
|
|
646
|
+
|
|
647
|
+
__kernel
|
|
648
|
+
%if do_not_vectorize:
|
|
649
|
+
__attribute__((reqd_work_group_size(1, 1, 1)))
|
|
650
|
+
%endif
|
|
651
|
+
void ${kernel_name}(
|
|
652
|
+
${kernel_list_arg_decl} ${user_arg_decl_with_offset} index_type n)
|
|
653
|
+
|
|
654
|
+
{
|
|
655
|
+
%if not do_not_vectorize:
|
|
656
|
+
int lid = get_local_id(0);
|
|
657
|
+
index_type gsize = get_global_size(0);
|
|
658
|
+
index_type work_group_start = get_local_size(0)*get_group_id(0);
|
|
659
|
+
for (index_type i = work_group_start + lid; i < n; i += gsize)
|
|
660
|
+
%else:
|
|
661
|
+
const int chunk_size = 128;
|
|
662
|
+
index_type chunk_base = get_global_id(0)*chunk_size;
|
|
663
|
+
index_type gsize = get_global_size(0);
|
|
664
|
+
for (; chunk_base < n; chunk_base += gsize*chunk_size)
|
|
665
|
+
for (index_type i = chunk_base; i < min(n, chunk_base+chunk_size); ++i)
|
|
666
|
+
%endif
|
|
667
|
+
{
|
|
668
|
+
%if is_count_stage:
|
|
669
|
+
%for name, dtype in list_names_and_dtypes:
|
|
670
|
+
%if name not in count_sharing:
|
|
671
|
+
index_type plb_loc_${name}_count = 0;
|
|
672
|
+
%endif
|
|
673
|
+
%endfor
|
|
674
|
+
%else:
|
|
675
|
+
%for name, dtype in list_names_and_dtypes:
|
|
676
|
+
%if name not in count_sharing:
|
|
677
|
+
index_type plb_${name}_index;
|
|
678
|
+
if (plb_${name}_start_index)
|
|
679
|
+
%if name in eliminate_empty_output_lists:
|
|
680
|
+
plb_${name}_index =
|
|
681
|
+
plb_${name}_start_index[
|
|
682
|
+
${name}_compressed_indices[i]
|
|
683
|
+
];
|
|
684
|
+
%else:
|
|
685
|
+
plb_${name}_index = plb_${name}_start_index[i];
|
|
686
|
+
%endif
|
|
687
|
+
else
|
|
688
|
+
plb_${name}_index = 0;
|
|
689
|
+
%endif
|
|
690
|
+
%endfor
|
|
691
|
+
%endif
|
|
692
|
+
|
|
693
|
+
${arg_offset_adjustment}
|
|
694
|
+
generate(${kernel_list_arg_values} USER_ARGS i);
|
|
695
|
+
|
|
696
|
+
%if is_count_stage:
|
|
697
|
+
%for name, dtype in list_names_and_dtypes:
|
|
698
|
+
%if name not in count_sharing:
|
|
699
|
+
if (plb_${name}_count)
|
|
700
|
+
plb_${name}_count[i] = plb_loc_${name}_count;
|
|
701
|
+
%endif
|
|
702
|
+
%endfor
|
|
703
|
+
%endif
|
|
704
|
+
}
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
// }}}
|
|
708
|
+
|
|
709
|
+
""", strict_undefined=True)
|
|
710
|
+
|
|
711
|
+
# }}}
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def _get_arg_decl(arg_list):
    """Return the C declarators of *arg_list*, each followed by ``", "``.

    Note that the separator is also emitted after the last entry; an empty
    *arg_list* yields ``""``.
    """
    return "".join(arg.declarator() + ", " for arg in arg_list)
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def _get_arg_list(arg_list, prefix=""):
    """Return the names of *arg_list*, each with *prefix* prepended and
    ``", "`` appended (including after the last entry; ``""`` if empty)."""
    pieces = [prefix + arg.name + ", " for arg in arg_list]
    return "".join(pieces)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
@dataclass
class BuiltList:
    """Result record describing one list produced by :class:`ListOfListsBuilder`.

    NOTE(review): the field meanings below are inferred from their names and
    from the usage example in :class:`ListOfListsBuilder`; confirm them against
    the builder's ``__call__``.
    """
    # total number of entries in the list (presumably; the builder's example
    # expects ``count == 3000`` for 2000 objects appending ``i % 4`` each)
    count: int | None
    # presumably the per-object start offsets into *lists* -- TODO confirm
    starts: cl_array.Array | None
    # presumably the concatenated entries of all per-object lists
    lists: cl_array.Array | None = None
    # the remaining fields presumably apply only when empty output lists are
    # eliminated (see *eliminate_empty_output_lists*) -- TODO confirm
    num_nonempty_lists: int | None = None
    nonempty_indices: cl_array.Array | None = None
    compressed_indices: cl_array.Array | None = None
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
class ListOfListsBuilder:
    """Generates and executes code to produce a large number of variable-size
    lists, simply.

    .. note:: This functionality is provided as a preview. Its interface
        is subject to change until this notice is removed.

    .. versionadded:: 2013.1

    Here's a usage example::

        from pyopencl.algorithm import ListOfListsBuilder
        builder = ListOfListsBuilder(context, [("mylist", np.int32)], \"\"\"
            void generate(LIST_ARG_DECL USER_ARG_DECL index_type i)
            {
                int count = i % 4;
                for (int j = 0; j < count; ++j)
                {
                    APPEND_mylist(count);
                }
            }
            \"\"\", arg_decls=[])

        result, event = builder(queue, 2000)

        inf = result["mylist"]
        assert inf.count == 3000
        assert (inf.lists.get()[-6:] == [1, 2, 2, 3, 3, 3]).all()

    The function ``generate`` above is called once for each "input object".
    Each input object can then generate zero or more list entries.
    The number of these input objects is given to :meth:`__call__` as *n_objects*.
    List entries are generated by calls to ``APPEND_<list name>(value)``.
    Multiple lists may be generated at once.

    .. automethod:: __init__
    .. automethod:: __call__
    """
    def __init__(self, context, list_names_and_dtypes, generate_template,
            arg_decls, count_sharing=None, devices=None,
            name_prefix="plb_build_list", options=None, preamble="",
            debug=False, complex_kernel=False,
            eliminate_empty_output_lists=False):
        """
        :arg context: A :class:`pyopencl.Context`.
        :arg list_names_and_dtypes: a list of ``(name, dtype)`` tuples
            indicating the lists to be built.
        :arg generate_template: a snippet of C as described below
        :arg arg_decls: A string of comma-separated C argument declarations.
        :arg count_sharing: A mapping consisting of ``(child, mother)``
            indicating that ``mother`` and ``child`` will always have the
            same number of indices, and the ``APPEND`` to ``mother``
            will always happen *before* the ``APPEND`` to the child.
        :arg name_prefix: the name prefix to use for the compiled kernels
        :arg options: OpenCL compilation options for kernels using
            *generate_template*.
        :arg complex_kernel: If *True*, prevents vectorization on CPUs.
        :arg eliminate_empty_output_lists: A Python list of list names
            for which the empty output lists are eliminated. For backward
            compatibility, *True* selects all lists and *False* selects none.

        *generate_template* may use the following C macros/identifiers:

        * ``index_type``: expands to C identifier for the index type used
          for the calculation
        * ``USER_ARG_DECL``: expands to the C declarator for ``arg_decls``
        * ``USER_ARGS``: a list of C argument values corresponding to
          ``USER_ARG_DECL``
        * ``LIST_ARG_DECL``: expands to a C argument list representing the
          data for the output lists. These are prefixed with
          ``"plb_"`` so as to not interfere with user-provided names.
        * ``LIST_ARGS``: a list of C argument values corresponding to
          ``LIST_ARG_DECL``
        * ``APPEND_name(entry)``: inserts ``entry`` into the list ``name``.
          *entry* must be a valid C expression of the correct type.

        All argument-list related macros have a trailing comma included
        if they are non-empty.

        *generate_template* must supply a function:

        .. code-block:: c

            void generate(USER_ARG_DECL LIST_ARG_DECL index_type i)
            {
                APPEND_mylist(5);
            }

        Internally, *generate_template* is expanded (at least) twice. Once,
        for a 'counting' stage where the size of all the lists is determined,
        and a second time, for a 'generation' stage where the lists are
        actually filled. A ``generate`` function that has side effects beyond
        calling ``APPEND_name`` is therefore ill-formed.

        .. versionchanged:: 2018.1

            Change *eliminate_empty_output_lists* argument type from ``bool`` to
            ``list``.
        """
        if devices is None:
            devices = context.devices

        if count_sharing is None:
            count_sharing = {}

        self.context = context
        self.devices = devices

        self.list_names_and_dtypes = list_names_and_dtypes
        self.generate_template = generate_template

        from pyopencl.tools import parse_arg_list
        self.arg_decls = parse_arg_list(arg_decls)

        # To match with the signature of the user-supplied generate(), arguments
        # can't appear to have offsets.
        arg_decls_no_offset = []
        from pyopencl.tools import VectorArg
        for arg in self.arg_decls:
            if isinstance(arg, VectorArg) and arg.with_offset:
                arg = VectorArg(arg.dtype, arg.name)
            arg_decls_no_offset.append(arg)
        self.arg_decls_no_offset = arg_decls_no_offset

        self.count_sharing = count_sharing

        self.name_prefix = name_prefix
        self.preamble = preamble
        self.options = options

        self.debug = debug

        self.complex_kernel = complex_kernel

        # Backward compatibility with the pre-2018.1 boolean form.
        if eliminate_empty_output_lists is True:
            eliminate_empty_output_lists = \
                    [name for name, _ in self.list_names_and_dtypes]

        if eliminate_empty_output_lists is False:
            eliminate_empty_output_lists = []

        self.eliminate_empty_output_lists = eliminate_empty_output_lists
        for list_name in self.eliminate_empty_output_lists:
            if not any(list_name == name for name, _ in self.list_names_and_dtypes):
                raise ValueError(
                    "invalid list name '%s' in eliminate_empty_output_lists"
                    % list_name)

    # {{{ kernel generators

    @memoize_method
    def get_scan_kernel(self, index_dtype):
        """Return a scan kernel that turns per-object counts in *ary* into
        start indices in place: ``ary[i+1]`` receives the running sum of the
        entries up to and including ``ary[i]``. (The caller sets ``ary[0]``.)
        """
        return GenericScanKernel(
                self.context, index_dtype,
                arguments="__global %s *ary" % dtype_to_ctype(index_dtype),
                input_expr="ary[i]",
                scan_expr="a+b", neutral="0",
                output_statement="ary[i+1] = item;",
                devices=self.devices)

    @memoize_method
    def get_compress_kernel(self, index_dtype):
        """Return a scan kernel that compacts away empty lists.

        From the per-object counts in *count* it writes the indices of the
        nonempty objects (*nonempty_indices*), their counts
        (*compressed_counts*), the running number of nonempty objects
        shifted by one (*compressed_indices*), and the total number of
        nonempty objects (*num_non_empty_list*).
        """
        arguments = """
            __global ${index_t} *count,
            __global ${index_t} *compressed_counts,
            __global ${index_t} *nonempty_indices,
            __global ${index_t} *compressed_indices,
            __global ${index_t} *num_non_empty_list
        """
        arguments = Template(arguments)

        return GenericScanKernel(
                self.context, index_dtype,
                arguments=arguments.render(index_t=dtype_to_ctype(index_dtype)),
                input_expr="count[i] == 0 ? 0 : 1",
                scan_expr="a+b", neutral="0",
                output_statement="""
                    if (i + 1 < N) compressed_indices[i + 1] = item;
                    if (prev_item != item) {
                        nonempty_indices[item - 1] = i;
                        compressed_counts[item - 1] = count[i];
                    }
                    if (i + 1 == N) *num_non_empty_list = item;
                    """,
                devices=self.devices)

    def do_not_vectorize(self):
        """Return *True* if *complex_kernel* was requested and any device of
        the context is a CPU."""
        return (self.complex_kernel
                and any(dev.type & cl.device_type.CPU
                    for dev in self.context.devices))

    @memoize_method
    def get_count_kernel(self, index_dtype):
        """Build the kernel for the 'counting' stage.

        The kernel takes one count array per list that does not share its
        count with another list (see *count_sharing*), followed by the user
        arguments and the number of objects.
        """
        index_ctype = dtype_to_ctype(index_dtype)
        from pyopencl.tools import OtherArg, VectorArg
        kernel_list_args = [
                VectorArg(index_dtype, "plb_%s_count" % name)
                for name, _dtype in self.list_names_and_dtypes
                if name not in self.count_sharing]

        user_list_args = []
        for name, _dtype in self.list_names_and_dtypes:
            if name in self.count_sharing:
                continue

            loc_name = "plb_loc_%s_count" % name
            user_list_args.append(OtherArg(f"{index_ctype} *{loc_name}", loc_name))

        kernel_name = self.name_prefix+"_count"

        from pyopencl.characterize import has_double_support
        src = _LIST_BUILDER_TEMPLATE.render(
                is_count_stage=True,
                kernel_name=kernel_name,
                double_support=all(has_double_support(dev) for dev in
                    self.context.devices),
                debug=self.debug,
                do_not_vectorize=self.do_not_vectorize(),
                eliminate_empty_output_lists=self.eliminate_empty_output_lists,

                kernel_list_arg_decl=_get_arg_decl(kernel_list_args),
                kernel_list_arg_values=_get_arg_list(user_list_args, prefix="&"),
                user_list_arg_decl=_get_arg_decl(user_list_args),
                user_list_args=_get_arg_list(user_list_args),
                user_arg_decl_with_offset=_get_arg_decl(self.arg_decls),
                user_arg_decl_no_offset=_get_arg_decl(self.arg_decls_no_offset),
                user_args_no_offset=_get_arg_list(self.arg_decls_no_offset),
                arg_offset_adjustment=get_arg_offset_adjuster_code(self.arg_decls),

                list_names_and_dtypes=self.list_names_and_dtypes,
                count_sharing=self.count_sharing,
                name_prefix=self.name_prefix,
                generate_template=self.generate_template,
                preamble=self.preamble,

                index_type=index_ctype,
                )

        src = str(src)

        prg = cl.Program(self.context, src).build(self.options)
        knl = getattr(prg, kernel_name)

        from pyopencl.tools import get_arg_list_scalar_arg_dtypes
        knl.set_scalar_arg_dtypes([
            *get_arg_list_scalar_arg_dtypes([*kernel_list_args, *self.arg_decls]),
            index_dtype
            ])

        return knl

    @memoize_method
    def get_write_kernel(self, index_dtype):
        """Build the kernel for the 'generation' stage.

        For every list, the kernel takes the list data array. Lists with
        their own counts also take their start indices and, if empty-list
        elimination is enabled for them, their compressed indices. These are
        followed by the user arguments and the number of objects.
        """
        index_ctype = dtype_to_ctype(index_dtype)
        from pyopencl.tools import OtherArg, VectorArg
        kernel_list_args = []
        kernel_list_arg_values = ""
        user_list_args = []

        for name, dtype in self.list_names_and_dtypes:
            list_name = "plb_%s_list" % name
            list_arg = VectorArg(dtype, list_name)

            kernel_list_args.append(list_arg)
            user_list_args.append(list_arg)

            if name in self.count_sharing:
                kernel_list_arg_values += "%s, " % list_name
                continue

            kernel_list_args.append(
                    VectorArg(index_dtype, "plb_%s_start_index" % name))

            if name in self.eliminate_empty_output_lists:
                kernel_list_args.append(
                        VectorArg(index_dtype, "%s_compressed_indices" % name))

            index_name = "plb_%s_index" % name
            user_list_args.append(
                    OtherArg(f"{index_ctype} *{index_name}", index_name))

            kernel_list_arg_values += f"{list_name}, &{index_name}, "

        kernel_name = self.name_prefix+"_write"

        from pyopencl.characterize import has_double_support
        src = _LIST_BUILDER_TEMPLATE.render(
                is_count_stage=False,
                kernel_name=kernel_name,
                double_support=all(has_double_support(dev) for dev in
                    self.context.devices),
                debug=self.debug,
                do_not_vectorize=self.do_not_vectorize(),
                eliminate_empty_output_lists=self.eliminate_empty_output_lists,

                kernel_list_arg_decl=_get_arg_decl(kernel_list_args),
                kernel_list_arg_values=kernel_list_arg_values,
                user_list_arg_decl=_get_arg_decl(user_list_args),
                user_list_args=_get_arg_list(user_list_args),
                user_arg_decl_with_offset=_get_arg_decl(self.arg_decls),
                user_arg_decl_no_offset=_get_arg_decl(self.arg_decls_no_offset),
                user_args_no_offset=_get_arg_list(self.arg_decls_no_offset),
                arg_offset_adjustment=get_arg_offset_adjuster_code(self.arg_decls),

                list_names_and_dtypes=self.list_names_and_dtypes,
                count_sharing=self.count_sharing,
                name_prefix=self.name_prefix,
                generate_template=self.generate_template,
                preamble=self.preamble,

                index_type=index_ctype,
                )

        src = str(src)

        prg = cl.Program(self.context, src).build(self.options)
        knl = getattr(prg, kernel_name)

        from pyopencl.tools import get_arg_list_scalar_arg_dtypes
        knl.set_scalar_arg_dtypes([
            *get_arg_list_scalar_arg_dtypes([*kernel_list_args, *self.arg_decls]),
            index_dtype])

        return knl

    # }}}

    # {{{ driver

    def __call__(self, queue, n_objects, *args, **kwargs):
        """
        :arg args: arguments corresponding to ``arg_decls`` in the constructor.
            Array-like arguments must be either 1D :class:`pyopencl.array.Array`
            objects or :class:`pyopencl.MemoryObject` objects, of which the latter
            can be obtained from a :class:`pyopencl.array.Array` using the
            :attr:`pyopencl.array.Array.data` attribute.
        :arg allocator: optionally, the allocator to use to allocate new
            arrays.
        :arg omit_lists: an iterable of list names that should *not* be built
            with this invocation. The kernel code may *not* call ``APPEND_name``
            for these omitted lists. If it does, undefined behavior will result.
            The returned *lists* dictionary will not contain an entry for names
            in *omit_lists*.
        :arg wait_for: |explain-waitfor|
        :returns: a tuple ``(lists, event)``, where ``lists`` is a mapping from
            (built) list names to objects which have attributes

            * ``count`` for the total number of entries in all lists combined
            * ``lists`` for the array containing all lists.
            * ``starts`` for the array of starting indices in ``lists``.
              ``starts`` is built so that it has *n_objects* + 1 entries, so
              that the *i*'th entry is the start of the *i*'th list, and the
              (*i* + 1)'th entry is the index one past the *i*'th list's end,
              even for the last list.

              This implies that all lists are contiguous.

              If the list name is specified in *eliminate_empty_output_lists*
              constructor argument, the object has two additional attributes
              ``num_nonempty_lists`` and ``nonempty_indices``

              * ``num_nonempty_lists`` for the number of nonempty lists.
              * ``nonempty_indices`` for the index of nonempty list in input objects.

              In this case, ``starts`` has ``num_nonempty_lists + 1`` entries.
              The *i*'th entry is the start of the *i*'th nonempty list, which is
              generated by the object with index ``nonempty_indices[i]``.

            *event* is a :class:`pyopencl.Event` for dependency management.

        .. versionchanged:: 2016.2

            Added omit_lists.
        """
        if n_objects >= int(np.iinfo(np.int32).max):
            index_dtype = np.int64
        else:
            index_dtype = np.int32
        index_dtype = np.dtype(index_dtype)

        allocator = kwargs.pop("allocator", None)
        # Materialize once: the names are validated here and then tested for
        # membership many times below, which would silently fail on an
        # exhausted one-shot iterator.
        omit_lists = frozenset(kwargs.pop("omit_lists", ()))
        wait_for = kwargs.pop("wait_for", None)
        if kwargs:
            raise TypeError("invalid keyword arguments: '%s'" % ", ".join(kwargs))

        for oml in omit_lists:
            if not any(oml == name for name, _ in self.list_names_and_dtypes):
                raise ValueError("invalid list name '%s' in omit_lists" % oml)

        result = {}
        count_list_args = []

        if wait_for is None:
            wait_for = []
        else:
            # We'll be modifying it below.
            wait_for = list(wait_for)

        count_kernel = self.get_count_kernel(index_dtype)
        write_kernel = self.get_write_kernel(index_dtype)
        scan_kernel = self.get_scan_kernel(index_dtype)
        if self.eliminate_empty_output_lists:
            compress_kernel = self.get_compress_kernel(index_dtype)

        from pyopencl import MemoryObject
        from pyopencl.tools import VectorArg

        data_args = []
        for i, (arg_descr, arg_val) in enumerate(
                zip(self.arg_decls, args, strict=True)):
            if isinstance(arg_descr, VectorArg):
                if arg_val is None:
                    data_args.append(arg_val)
                    if arg_descr.with_offset:
                        data_args.append(0)
                    continue

                if isinstance(arg_val, MemoryObject):
                    data_args.append(arg_val)
                    if arg_descr.with_offset:
                        raise ValueError(
                                "with_offset=True specified for argument %d "
                                "but the argument is not an array" % i)
                    continue

                if arg_val.ndim != 1:
                    raise ValueError("argument %d is a multidimensional array" % i)

                data_args.append(arg_val.base_data)
                if arg_descr.with_offset:
                    data_args.append(arg_val.offset)
                wait_for.extend(arg_val.events)
            else:
                data_args.append(arg_val)

        del args
        data_args = tuple(data_args)

        # {{{ allocate memory for counts

        for name, _dtype in self.list_names_and_dtypes:
            if name in self.count_sharing:
                continue
            if name in omit_lists:
                # The count kernel still expects an argument in this slot.
                count_list_args.append(None)
                continue

            counts = cl_array.empty(queue,
                    (n_objects + 1), index_dtype, allocator=allocator)
            counts[-1] = 0
            wait_for = wait_for + counts.events

            # The scan will turn the "counts" array into the "starts" array
            # in-place.
            result[name] = BuiltList(count=None, starts=counts, lists=None)
            count_list_args.append(counts.data)

        # }}}

        if self.debug:
            gsize = (1,)
            lsize = (1,)
        elif self.do_not_vectorize():
            gsize = (4*queue.device.max_compute_units,)
            lsize = (1,)
        else:
            from pyopencl.array import _splay
            gsize, lsize = _splay(queue.device, n_objects)

        count_event = count_kernel(queue, gsize, lsize,
                *(tuple(count_list_args) + data_args + (n_objects,)),
                wait_for=wait_for)

        compress_events = {}
        for name, _dtype in self.list_names_and_dtypes:
            if name in omit_lists:
                continue
            if name in self.count_sharing:
                continue
            if name not in self.eliminate_empty_output_lists:
                continue

            compressed_counts = cl_array.empty(
                    queue, (n_objects + 1,), index_dtype, allocator=allocator)
            info_record = result[name]
            info_record.nonempty_indices = cl_array.empty(
                    queue, (n_objects + 1,), index_dtype, allocator=allocator)
            info_record.num_nonempty_lists = cl_array.empty(
                    queue, (1,), index_dtype, allocator=allocator)
            info_record.compressed_indices = cl_array.empty(
                    queue, (n_objects + 1,), index_dtype, allocator=allocator)
            info_record.compressed_indices[0] = 0

            compress_events[name] = compress_kernel(
                    info_record.starts,
                    compressed_counts,
                    info_record.nonempty_indices,
                    info_record.compressed_indices,
                    info_record.num_nonempty_lists,
                    wait_for=[count_event, *info_record.compressed_indices.events])

            # From here on, only the counts of nonempty lists are scanned.
            info_record.starts = compressed_counts

        # {{{ run scans

        scan_events = []

        for name, _dtype in self.list_names_and_dtypes:
            if name in self.count_sharing:
                continue
            if name in omit_lists:
                continue

            info_record = result[name]
            if name in self.eliminate_empty_output_lists:
                compress_events[name].wait()
                num_nonempty_lists = int(info_record.num_nonempty_lists.get()[0])
                info_record.num_nonempty_lists = num_nonempty_lists
                info_record.starts = info_record.starts[:num_nonempty_lists + 1]
                info_record.nonempty_indices = \
                        info_record.nonempty_indices[:num_nonempty_lists]
                info_record.starts[-1] = 0

            starts_ary = info_record.starts
            if name in self.eliminate_empty_output_lists:
                evt = scan_kernel(
                        starts_ary,
                        size=info_record.num_nonempty_lists,
                        wait_for=starts_ary.events)
            else:
                evt = scan_kernel(starts_ary, wait_for=[count_event],
                        size=n_objects)

            starts_ary.setitem(0, 0, queue=queue, wait_for=[evt])
            scan_events.extend(starts_ary.events)

            # retrieve count
            info_record.count = int(starts_ary[-1].get())

        # }}}

        # {{{ deal with count-sharing lists, allocate memory for lists

        write_list_args = []
        for name, dtype in self.list_names_and_dtypes:
            if name in omit_lists:
                # Placeholders must mirror the write kernel's positional
                # layout: list data; then, for lists with their own counts,
                # start indices and (if compressed) compressed indices.
                write_list_args.append(None)
                if name not in self.count_sharing:
                    write_list_args.append(None)
                    if name in self.eliminate_empty_output_lists:
                        write_list_args.append(None)
                continue

            if name in self.count_sharing:
                sharing_from = self.count_sharing[name]

                info_record = result[name] = BuiltList(
                        count=result[sharing_from].count,
                        starts=result[sharing_from].starts,
                        )

            else:
                info_record = result[name]

            info_record.lists = cl_array.empty(queue,
                    info_record.count, dtype, allocator=allocator)
            write_list_args.append(info_record.lists.data)

            if name not in self.count_sharing:
                write_list_args.append(info_record.starts.data)

                if name in self.eliminate_empty_output_lists:
                    write_list_args.append(info_record.compressed_indices.data)

        # }}}

        evt = write_kernel(queue, gsize, lsize,
                *(tuple(write_list_args) + data_args + (n_objects,)),
                wait_for=scan_events)

        return result, evt

    # }}}
|
|
1328
|
+
|
|
1329
|
+
# }}}
|
|
1330
|
+
|
|
1331
|
+
|
|
1332
|
+
# {{{ key-value sorting
|
|
1333
|
+
|
|
1334
|
+
@dataclass(frozen=True)
class _KernelInfo:
    """Immutable bundle of the kernels that :class:`KeyValueSorter` uses for
    one combination of key, value and starts dtypes."""

    # Sorts *values* together with *keys*, ordered by key.
    by_target_sorter: RadixSort
    # Records, for each key present, the first sorted position holding it.
    start_finder: ElementwiseKernel
    # Backward min-scan giving absent keys the start of the next present key.
    bound_propagation_scan: GenericScanKernel
|
|
1339
|
+
|
|
1340
|
+
|
|
1341
|
+
def _make_cl_int_literal(value, dtype):
    """Spell the integer *value* as an OpenCL C literal matching *dtype*.

    64-bit types receive an ``l`` suffix and unsigned types a ``u`` suffix,
    so the literal carries the width and signedness of *dtype*. This matters
    for values such as the maximum of ``uint64``, which is not a valid
    decimal literal without the ``u`` suffix.

    :arg value: an integer value representable in *dtype*.
    :arg dtype: a numpy integer dtype, or anything :func:`numpy.dtype`
        accepts (e.g. ``np.int64``).
    :returns: the literal as a :class:`str`.
    """
    # Normalize so that scalar type classes (np.int64) get a real itemsize.
    dtype = np.dtype(dtype)
    iinfo = np.iinfo(dtype)
    result = str(int(value))
    if dtype.itemsize == 8:
        result += "l"
    # Unsigned types have a minimum of zero.
    if int(iinfo.min) == 0:
        result += "u"

    return result
|
|
1350
|
+
|
|
1351
|
+
|
|
1352
|
+
class KeyValueSorter:
    """Given arrays *values* and *keys* of equal length
    and a number *nkeys* of keys, returns a tuple ``(starts,
    lists, event)``, as follows: *values* and *keys* are sorted
    by *keys*, and the sorted *values* is returned as
    *lists*. Then for each index *i* in ``range(nkeys)``,
    *starts[i]* is written to indicating where the
    group of *values* belonging to the key with index
    *i* begins. It implicitly ends at *starts[i+1]*.
    *event* is a :class:`pyopencl.Event` for dependency management.

    Every key must lie in ``range(nkeys)``: keys are used directly as
    indices into *starts*.

    ``starts`` is built so that it has ``nkeys + 1`` entries, so that
    the *i*'th entry is the start of the *i*'th list, and the
    (*i* + 1)'th entry is the index one past the *i*'th list's end,
    even for the last list.

    This implies that all lists are contiguous.

    .. note:: This functionality is provided as a preview. Its
        interface is subject to change until this notice is removed.

    .. versionadded:: 2013.1
    """

    def __init__(self, context):
        self.context = context

    @memoize_method
    def get_kernels(self, key_dtype, value_dtype, starts_dtype):
        """Build (and memoize) the sorting and start-finding kernels for the
        given dtype combination.

        :returns: a :class:`_KernelInfo`.
        """
        from pyopencl.tools import ScalarArg, VectorArg

        by_target_sorter = RadixSort(
                self.context, [
                    VectorArg(value_dtype, "values"),
                    VectorArg(key_dtype, "keys"),
                    ],
                key_expr="keys[i]",
                sort_arg_names=["values", "keys"])

        from pyopencl.elementwise import ElementwiseTemplate
        start_finder = ElementwiseTemplate(
                arguments="""//CL//
                starts_t *key_group_starts,
                key_t *keys_sorted_by_key,
                """,

                operation=r"""//CL//
                key_t my_key = keys_sorted_by_key[i];

                if (i == 0 || my_key != keys_sorted_by_key[i-1])
                    key_group_starts[my_key] = i;
                """,
                name="find_starts").build(self.context,
                        type_aliases=(
                            # key_t describes the element type of the sorted
                            # keys buffer, so it must be the keys' dtype.
                            ("key_t", key_dtype),
                            ("starts_t", starts_dtype),
                            ),
                        var_values=())

        # Walks starts from the end, so that every key without values
        # inherits the start of the next key that has values.
        bound_propagation_scan = GenericScanKernel(
                self.context, starts_dtype,
                arguments=[
                    VectorArg(starts_dtype, "starts"),
                    # starts has length n+1
                    ScalarArg(key_dtype, "nkeys"),
                    ],
                input_expr="starts[nkeys-i]",
                scan_expr="min(a, b)",
                neutral=_make_cl_int_literal(
                    np.iinfo(starts_dtype).max, starts_dtype),
                output_statement="starts[nkeys-i] = item;")

        return _KernelInfo(
                by_target_sorter=by_target_sorter,
                start_finder=start_finder,
                bound_propagation_scan=bound_propagation_scan)

    def __call__(self, queue, keys, values, nkeys,
            starts_dtype, allocator=None, wait_for=None):
        """
        :arg keys: array of keys, each in ``range(nkeys)``.
        :arg values: array of values, of the same length as *keys*.
        :arg nkeys: the number of distinct key slots.
        :arg starts_dtype: dtype of the returned *starts* array; it must be
            able to hold ``len(values)``.
        :arg allocator: allocator for *starts*; defaults to
            ``values.allocator``.
        :arg wait_for: events to wait for before sorting.
        :returns: ``(starts, values_sorted_by_key, event)``.
        """
        if allocator is None:
            allocator = values.allocator

        knl_info = self.get_kernels(keys.dtype, values.dtype,
                starts_dtype)

        (values_sorted_by_key, keys_sorted_by_key), evt = knl_info.by_target_sorter(
                values, keys, queue=queue, wait_for=wait_for)

        # Absent keys keep this "one past the end" value until the
        # bound-propagation scan fills them in.
        starts = (cl_array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator)
                .fill(len(values_sorted_by_key), wait_for=[evt]))
        evt, = starts.events

        evt = knl_info.start_finder(starts, keys_sorted_by_key,
                range=slice(len(keys_sorted_by_key)),
                wait_for=[evt])

        evt = knl_info.bound_propagation_scan(starts, nkeys,
                queue=queue, wait_for=[evt])

        return starts, values_sorted_by_key, evt
|
|
1451
|
+
|
|
1452
|
+
# }}}
|
|
1453
|
+
|
|
1454
|
+
# vim: filetype=pyopencl:fdm=marker
|