pyopencl 2026.1.1__cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyopencl/.libs/libOpenCL-34a55fe4.so.1.0.0 +0 -0
- pyopencl/__init__.py +1995 -0
- pyopencl/_cl.cpython-314t-aarch64-linux-gnu.so +0 -0
- pyopencl/_cl.pyi +2009 -0
- pyopencl/_cluda.py +57 -0
- pyopencl/_monkeypatch.py +1104 -0
- pyopencl/_mymako.py +17 -0
- pyopencl/algorithm.py +1454 -0
- pyopencl/array.py +3530 -0
- pyopencl/bitonic_sort.py +245 -0
- pyopencl/bitonic_sort_templates.py +597 -0
- pyopencl/cache.py +553 -0
- pyopencl/capture_call.py +200 -0
- pyopencl/characterize/__init__.py +461 -0
- pyopencl/characterize/performance.py +240 -0
- pyopencl/cl/pyopencl-airy.cl +324 -0
- pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
- pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
- pyopencl/cl/pyopencl-bessel-y.cl +435 -0
- pyopencl/cl/pyopencl-complex.h +303 -0
- pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
- pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
- pyopencl/cl/pyopencl-random123/array.h +325 -0
- pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
- pyopencl/cl/pyopencl-random123/philox.cl +486 -0
- pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
- pyopencl/clmath.py +281 -0
- pyopencl/clrandom.py +412 -0
- pyopencl/cltypes.py +217 -0
- pyopencl/compyte/.gitignore +21 -0
- pyopencl/compyte/__init__.py +0 -0
- pyopencl/compyte/array.py +211 -0
- pyopencl/compyte/dtypes.py +314 -0
- pyopencl/compyte/pyproject.toml +49 -0
- pyopencl/elementwise.py +1288 -0
- pyopencl/invoker.py +417 -0
- pyopencl/ipython_ext.py +70 -0
- pyopencl/py.typed +0 -0
- pyopencl/reduction.py +829 -0
- pyopencl/scan.py +1921 -0
- pyopencl/tools.py +1680 -0
- pyopencl/typing.py +61 -0
- pyopencl/version.py +11 -0
- pyopencl-2026.1.1.dist-info/METADATA +108 -0
- pyopencl-2026.1.1.dist-info/RECORD +47 -0
- pyopencl-2026.1.1.dist-info/WHEEL +6 -0
- pyopencl-2026.1.1.dist-info/licenses/LICENSE +104 -0
pyopencl/scan.py
ADDED
@@ -0,0 +1,1921 @@
"""Scan primitive."""
from __future__ import annotations


__copyright__ = """
Copyright 2011-2012 Andreas Kloeckner
Copyright 2008-2011 NVIDIA Corporation
"""

__license__ = """
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Derived from code within the Thrust project, https://github.com/NVIDIA/thrust
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast

import numpy as np

from pytools.persistent_dict import WriteOncePersistentDict

import pyopencl as cl
import pyopencl._mymako as mako
import pyopencl.array as cl_array
from pyopencl._cluda import CLUDA_PREAMBLE
from pyopencl.tools import (
    DtypedArgument,
    KernelTemplateBase,
    _NumpyTypesKeyBuilder,
    _process_code_for_macro,
    bitlog2,
    context_dependent_memoize,
    dtype_to_ctype,
    get_arg_list_scalar_arg_dtypes,
    get_arg_offset_adjuster_code,
)


if TYPE_CHECKING:
    from collections.abc import Sequence


logger = logging.getLogger(__name__)


# {{{ preamble

SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
#define WG_SIZE ${wg_size}

#define SCAN_EXPR(a, b, across_seg_boundary) ${scan_expr}
#define INPUT_EXPR(i) (${input_expr})
%if is_segmented:
    #define IS_SEG_START(i, a) (${is_segment_start_expr})
%endif

${preamble}

typedef ${dtype_to_ctype(scan_dtype)} scan_type;
typedef ${dtype_to_ctype(index_dtype)} index_type;

// NO_SEG_BOUNDARY is the largest representable integer in index_type.
// This assumption is used in code below.
#define NO_SEG_BOUNDARY ${str(np.iinfo(index_dtype).max)}
"""

# }}}

# {{{ main scan code

# Algorithm: Each work group is responsible for one contiguous
# 'interval'. There are just enough intervals to fill all compute
# units. Intervals are split into 'units'. A unit is what gets
# worked on in parallel by one work group.
#
# in index space:
# interval > unit > local-parallel > k-group
#
# (Note that there is also a transpose in here: The data is read
# with local ids along linear index order.)
#
# Each unit has two axes--the local-id axis and the k axis.
#
# unit 0:
# | | | | | | | | | | ----> lid
# | | | | | | | | | |
# | | | | | | | | | |
# | | | | | | | | | |
# | | | | | | | | | |
#
# |
# v k (fastest-moving in linear index)
#
# unit 1:
# | | | | | | | | | | ----> lid
# | | | | | | | | | |
# | | | | | | | | | |
# | | | | | | | | | |
# | | | | | | | | | |
#
# |
# v k (fastest-moving in linear index)
#
# ...
#
# At a device-global level, this is a three-phase algorithm, in
# which first each interval does its local scan, then a scan
# across intervals exchanges data globally, and the final update
# adds the exchanged sums to each interval.
#
# Exclusive scan is realized by allowing look-behind (access to the
# preceding item) in the final update, by means of a local shift.
#
# NOTE: All segment_start_in_X indices are relative to the start
# of the array.

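# For orientation, the same three-phase structure in serial NumPy terms
# (an illustrative sketch only: np.cumsum stands in for the parallel
# kernels generated below, and the helper name is made up):
#
#     def _three_phase_scan_sketch(a, num_intervals):
#         chunks = np.array_split(a, num_intervals)
#         # phase 1: each interval scans locally
#         partial = [np.cumsum(c) for c in chunks]
#         # phase 2: scan across the per-interval sums
#         interval_results = np.cumsum([p[-1] for p in partial])
#         # phase 3: the final update adds the preceding intervals' sum
#         return np.concatenate([
#             p + (interval_results[j-1] if j else 0)
#             for j, p in enumerate(partial)])
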
SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + r"""//CL//

#define K ${k_group_size}

// #define DEBUG
#ifdef DEBUG
    #define pycl_printf(ARGS) printf ARGS
#else
    #define pycl_printf(ARGS) /* */
#endif

KERNEL
REQD_WG_SIZE(WG_SIZE, 1, 1)
void ${kernel_name}(
    ${argument_signature},
    GLOBAL_MEM scan_type *restrict partial_scan_buffer,
    const index_type N,
    const index_type interval_size
    %if is_first_level:
        , GLOBAL_MEM scan_type *restrict interval_results
    %endif
    %if is_segmented and is_first_level:
        // NO_SEG_BOUNDARY if no segment boundary in interval.
        , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
    %endif
    %if store_segment_start_flags:
        , GLOBAL_MEM char *restrict g_segment_start_flags
    %endif
    )
{
    ${arg_offset_adjustment}

    // index K in first dimension used for carry storage
    %if use_bank_conflict_avoidance:
        // Avoid bank conflicts by adding a single 32-bit value to the size of
        // the scan type.
        struct __attribute__ ((__packed__)) wrapped_scan_type
        {
            scan_type value;
            int dummy;
        };
    %else:
        struct wrapped_scan_type
        {
            scan_type value;
        };
    %endif
    // padded in WG_SIZE to avoid bank conflicts
    LOCAL_MEM struct wrapped_scan_type ldata[K + 1][WG_SIZE + 1];

    %if is_segmented:
        LOCAL_MEM char l_segment_start_flags[K][WG_SIZE];
        LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE];

        // only relevant/populated for local id 0
        index_type first_segment_start_in_interval = NO_SEG_BOUNDARY;

        index_type first_segment_start_in_k_group, first_segment_start_in_subtree;
    %endif

    // {{{ declare local data for input_fetch_exprs if any of them are stenciled

    <%
        fetch_expr_offsets = {}
        for name, arg_name, ife_offset in input_fetch_exprs:
            fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)

        local_fetch_expr_args = set(
            arg_name
            for arg_name, ife_offsets in fetch_expr_offsets.items()
            if -1 in ife_offsets or len(ife_offsets) > 1)
    %>

    %for arg_name in local_fetch_expr_args:
        LOCAL_MEM ${arg_ctypes[arg_name]} l_${arg_name}[WG_SIZE*K];
    %endfor

    // }}}

    const index_type interval_begin = interval_size * GID_0;
    const index_type interval_end = min(interval_begin + interval_size, N);

    const index_type unit_size = K * WG_SIZE;

    index_type unit_base = interval_begin;

    %for is_tail in [False, True]:

        %if not is_tail:
            for(; unit_base + unit_size <= interval_end; unit_base += unit_size)
        %else:
            if (unit_base < interval_end)
        %endif

        {

            // {{{ carry out input_fetch_exprs
            // (if there are ones that need to be fetched into local)

            %if local_fetch_expr_args:
                for(index_type k = 0; k < K; k++)
                {
                    const index_type offset = k*WG_SIZE + LID_0;
                    const index_type read_i = unit_base + offset;

                    %for arg_name in local_fetch_expr_args:
                        %if is_tail:
                        if (read_i < interval_end)
                        %endif
                        {
                            l_${arg_name}[offset] = ${arg_name}[read_i];
                        }
                    %endfor
                }

                local_barrier();
            %endif

            pycl_printf(("after input_fetch_exprs\n"));

            // }}}

            // {{{ read a unit's worth of data from global

            for(index_type k = 0; k < K; k++)
            {
                const index_type offset = k*WG_SIZE + LID_0;
                const index_type read_i = unit_base + offset;

                %if is_tail:
                if (read_i < interval_end)
                %endif
                {
                    %for name, arg_name, ife_offset in input_fetch_exprs:
                        ${arg_ctypes[arg_name]} ${name};

                        %if arg_name in local_fetch_expr_args:
                            if (offset + ${ife_offset} >= 0)
                                ${name} = l_${arg_name}[offset + ${ife_offset}];
                            else if (read_i + ${ife_offset} >= 0)
                                ${name} = ${arg_name}[read_i + ${ife_offset}];
                            /*
                            else
                                if out of bounds, name is left undefined */

                        %else:
                            // ${arg_name} gets fetched directly from global
                            ${name} = ${arg_name}[read_i];

                        %endif
                    %endfor

                    scan_type scan_value = INPUT_EXPR(read_i);

                    const index_type o_mod_k = offset % K;
                    const index_type o_div_k = offset / K;
                    ldata[o_mod_k][o_div_k].value = scan_value;

                    %if is_segmented:
                        bool is_seg_start = IS_SEG_START(read_i, scan_value);
                        l_segment_start_flags[o_mod_k][o_div_k] = is_seg_start;
                    %endif
                    %if store_segment_start_flags:
                        g_segment_start_flags[read_i] = is_seg_start;
                    %endif
                }
            }

            pycl_printf(("after read from global\n"));

            // }}}

            // {{{ carry in from previous unit, if applicable

            %if is_segmented:
                local_barrier();

                first_segment_start_in_k_group = NO_SEG_BOUNDARY;
                if (l_segment_start_flags[0][LID_0])
                    first_segment_start_in_k_group = unit_base + K*LID_0;
            %endif

            if (LID_0 == 0 && unit_base != interval_begin)
            {
                scan_type tmp = ldata[K][WG_SIZE - 1].value;
                scan_type tmp_aux = ldata[0][0].value;

                ldata[0][0].value = SCAN_EXPR(
                    tmp, tmp_aux,
                    %if is_segmented:
                        (l_segment_start_flags[0][0])
                    %else:
                        false
                    %endif
                    );
            }

            pycl_printf(("after carry-in\n"));

            // }}}

            local_barrier();

            // {{{ scan along k (sequentially in each work item)

            scan_type sum = ldata[0][LID_0].value;

            %if is_tail:
                const index_type offset_end = interval_end - unit_base;
            %endif

            for (index_type k = 1; k < K; k++)
            {
                %if is_tail:
                if ((index_type) (K * LID_0 + k) < offset_end)
                %endif
                {
                    scan_type tmp = ldata[k][LID_0].value;

                    %if is_segmented:
                        index_type seq_i = unit_base + K*LID_0 + k;

                        if (l_segment_start_flags[k][LID_0])
                        {
                            first_segment_start_in_k_group = min(
                                first_segment_start_in_k_group,
                                seq_i);
                        }
                    %endif

                    sum = SCAN_EXPR(sum, tmp,
                        %if is_segmented:
                            (l_segment_start_flags[k][LID_0])
                        %else:
                            false
                        %endif
                        );

                    ldata[k][LID_0].value = sum;
                }
            }

            pycl_printf(("after scan along k\n"));

            // }}}

            // store carry in out-of-bounds (padding) array entry (index K) in
            // the K direction
            ldata[K][LID_0].value = sum;

            %if is_segmented:
                l_first_segment_start_in_subtree[LID_0] =
                    first_segment_start_in_k_group;
            %endif

            local_barrier();

            // {{{ tree-based local parallel scan

            // This tree-based scan works as follows:
            // - Each work item adds the previous item to its current state
            // - barrier
            // - Each work item adds in the item from two positions to the left
            // - barrier
            // - Each work item adds in the item from four positions to the left
            // ...
            // At the end, each item has summed all prior items.

            // across k groups, along local id
            // (uses out-of-bounds k=K array entry for storage)

            scan_type val = ldata[K][LID_0].value;

            <% scan_offset = 1 %>
            % while scan_offset <= wg_size:
                // {{{ reads from local allowed, writes to local not allowed

                if (LID_0 >= ${scan_offset})
                {
                    scan_type tmp = ldata[K][LID_0 - ${scan_offset}].value;
                    % if is_tail:
                    if (K*LID_0 < offset_end)
                    % endif
                    {
                        val = SCAN_EXPR(tmp, val,
                            %if is_segmented:
                                (l_first_segment_start_in_subtree[LID_0]
                                    != NO_SEG_BOUNDARY)
                            %else:
                                false
                            %endif
                            );
                    }

                    %if is_segmented:
                        // Prepare for l_first_segment_start_in_subtree, below.

                        // Note that this update must take place *even* if we're
                        // out of bounds.

                        first_segment_start_in_subtree = min(
                            l_first_segment_start_in_subtree[LID_0],
                            l_first_segment_start_in_subtree
                                [LID_0 - ${scan_offset}]);
                    %endif
                }
                %if is_segmented:
                else
                {
                    first_segment_start_in_subtree =
                        l_first_segment_start_in_subtree[LID_0];
                }
                %endif

                // }}}

                local_barrier();

                // {{{ writes to local allowed, reads from local not allowed

                ldata[K][LID_0].value = val;
                %if is_segmented:
                    l_first_segment_start_in_subtree[LID_0] =
                        first_segment_start_in_subtree;
                %endif

                // }}}

                local_barrier();

                %if 0:
                if (LID_0 == 0)
                {
                    printf("${scan_offset}: ");
                    for (int i = 0; i < WG_SIZE; ++i)
                    {
                        if (l_first_segment_start_in_subtree[i] == NO_SEG_BOUNDARY)
                            printf("- ");
                        else
                            printf("%d ", l_first_segment_start_in_subtree[i]);
                    }
                    printf("\n");
                }
                %endif

                <% scan_offset *= 2 %>
            % endwhile

            pycl_printf(("after tree scan\n"));

            // }}}

            // {{{ update local values

            if (LID_0 > 0)
            {
                sum = ldata[K][LID_0 - 1].value;

                for(index_type k = 0; k < K; k++)
                {
                    %if is_tail:
                    if (K * LID_0 + k < offset_end)
                    %endif
                    {
                        scan_type tmp = ldata[k][LID_0].value;
                        ldata[k][LID_0].value = SCAN_EXPR(sum, tmp,
                            %if is_segmented:
                                (unit_base + K * LID_0 + k
                                    >= first_segment_start_in_k_group)
                            %else:
                                false
                            %endif
                            );
                    }
                }
            }

            %if is_segmented:
                if (LID_0 == 0)
                {
                    // update interval-wide first-seg variable from current unit
                    first_segment_start_in_interval = min(
                        first_segment_start_in_interval,
                        l_first_segment_start_in_subtree[WG_SIZE-1]);
                }
            %endif

            pycl_printf(("after local update\n"));

            // }}}

            local_barrier();

            // {{{ write data

            %if is_gpu:
            {
                // work hard with index math to achieve contiguous 32-bit stores
                __global int *dest =
                    (__global int *) (partial_scan_buffer + unit_base);

                <%

                assert scan_dtype.itemsize % 4 == 0

                ints_per_wg = wg_size
                ints_to_store = scan_dtype.itemsize*wg_size*k_group_size // 4

                %>

                const index_type scan_types_per_int = ${scan_dtype.itemsize//4};

                %for store_base in range(0, ints_to_store, ints_per_wg):
                    <%

                    # Observe that ints_to_store is divisible by the work group
                    # size already, so we won't go out of bounds that way.
                    assert store_base + ints_per_wg <= ints_to_store

                    %>

                    %if is_tail:
                    if (${store_base} + LID_0 <
                        scan_types_per_int*(interval_end - unit_base))
                    %endif
                    {
                        index_type linear_index = ${store_base} + LID_0;
                        index_type linear_scan_data_idx =
                            linear_index / scan_types_per_int;
                        index_type remainder =
                            linear_index - linear_scan_data_idx * scan_types_per_int;

                        __local int *src = (__local int *) &(
                            ldata
                                [linear_scan_data_idx % K]
                                [linear_scan_data_idx / K].value);

                        dest[linear_index] = src[remainder];
                    }
                %endfor
            }
            %else:
            for (index_type k = 0; k < K; k++)
            {
                const index_type offset = k*WG_SIZE + LID_0;

                %if is_tail:
                if (unit_base + offset < interval_end)
                %endif
                {
                    pycl_printf(("write: %d\n", unit_base + offset));
                    partial_scan_buffer[unit_base + offset] =
                        ldata[offset % K][offset / K].value;
                }
            }
            %endif

            pycl_printf(("after write\n"));

            // }}}

            local_barrier();
        }

    % endfor

    // write interval sum
    %if is_first_level:
        if (LID_0 == 0)
        {
            interval_results[GID_0] = partial_scan_buffer[interval_end - 1];
            %if is_segmented:
                g_first_segment_start_in_interval[GID_0] =
                    first_segment_start_in_interval;
            %endif
        }
    %endif
}
"""

# }}}

# {{{ update

UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL//

KERNEL
REQD_WG_SIZE(WG_SIZE, 1, 1)
void ${name_prefix}_final_update(
    ${argument_signature},
    const index_type N,
    const index_type interval_size,
    GLOBAL_MEM scan_type *restrict interval_results,
    GLOBAL_MEM scan_type *restrict partial_scan_buffer
    %if is_segmented:
        , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
    %endif
    %if is_segmented and use_lookbehind_update:
        , GLOBAL_MEM char *restrict g_segment_start_flags
    %endif
    )
{
    ${arg_offset_adjustment}

    %if use_lookbehind_update:
        LOCAL_MEM scan_type ldata[WG_SIZE];
    %endif
    %if is_segmented and use_lookbehind_update:
        LOCAL_MEM char l_segment_start_flags[WG_SIZE];
    %endif

    const index_type interval_begin = interval_size * GID_0;
    const index_type interval_end = min(interval_begin + interval_size, N);

    // carry from last interval
    scan_type carry = ${neutral};
    if (GID_0 != 0)
        carry = interval_results[GID_0 - 1];

    %if is_segmented:
        const index_type first_seg_start_in_interval =
            g_first_segment_start_in_interval[GID_0];
    %endif

    %if not is_segmented and 'last_item' in output_statement:
        scan_type last_item = interval_results[GDIM_0-1];
    %endif

    %if not use_lookbehind_update:
        // {{{ no look-behind ('prev_item' not in output_statement -> simpler)

        index_type update_i = interval_begin+LID_0;

        %if is_segmented:
            index_type seg_end = min(first_seg_start_in_interval, interval_end);
        %endif

        for(; update_i < interval_end; update_i += WG_SIZE)
        {
            scan_type partial_val = partial_scan_buffer[update_i];
            scan_type item = SCAN_EXPR(carry, partial_val,
                %if is_segmented:
                    (update_i >= seg_end)
                %else:
                    false
                %endif
                );
            index_type i = update_i;

            { ${output_statement}; }
        }

        // }}}
    %else:
        // {{{ allow look-behind ('prev_item' in output_statement -> complicated)

        // We are not allowed to branch across barriers at a granularity smaller
        // than the whole workgroup. Therefore, the for loop is group-global,
        // and there are lots of local ifs.

        index_type group_base = interval_begin;
        scan_type prev_item = carry; // (A)

        for(; group_base < interval_end; group_base += WG_SIZE)
        {
            index_type update_i = group_base+LID_0;

            // load a work group's worth of data
            if (update_i < interval_end)
            {
                scan_type tmp = partial_scan_buffer[update_i];

                tmp = SCAN_EXPR(carry, tmp,
                    %if is_segmented:
                        (update_i >= first_seg_start_in_interval)
                    %else:
                        false
                    %endif
                    );

                ldata[LID_0] = tmp;

                %if is_segmented:
                    l_segment_start_flags[LID_0] = g_segment_start_flags[update_i];
                %endif
            }

            local_barrier();

            // find prev_item
            if (LID_0 != 0)
                prev_item = ldata[LID_0 - 1];
            /*
            else
                prev_item = carry (see (A)) OR last tail (see (B));
            */

            if (update_i < interval_end)
            {
                %if is_segmented:
                    if (l_segment_start_flags[LID_0])
                        prev_item = ${neutral};
                %endif

                scan_type item = ldata[LID_0];
                index_type i = update_i;
                { ${output_statement}; }
            }

            if (LID_0 == 0)
                prev_item = ldata[WG_SIZE - 1]; // (B)

            local_barrier();
        }

        // }}}
    %endif
}
"""

# }}}


# {{{ driver

# {{{ helpers

def _round_down_to_power_of_2(val: int) -> int:
    result = 2**bitlog2(val)
    if result > val:
        result >>= 1

    assert result <= val
    return result
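
# For example (illustrative values): _round_down_to_power_of_2(300) == 256,
# and powers of two pass through unchanged, e.g. 256 -> 256.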


_PREFIX_WORDS = set("""
        ldata partial_scan_buffer global scan_offset
        segment_start_in_k_group carry
        g_first_segment_start_in_interval IS_SEG_START tmp Z
        val l_first_segment_start_in_subtree unit_size
        index_type interval_begin interval_size offset_end K
        SCAN_EXPR do_update WG_SIZE
        first_segment_start_in_k_group scan_type
        segment_start_in_subtree offset interval_results interval_end
        first_segment_start_in_subtree unit_base
        first_segment_start_in_interval k INPUT_EXPR
        prev_group_sum prev pv value partial_val pgs
        is_seg_start update_i scan_item_at_i seq_i read_i
        l_ o_mod_k o_div_k l_segment_start_flags scan_value sum
        first_seg_start_in_interval g_segment_start_flags
        group_base seg_end my_val DEBUG ARGS
        ints_to_store ints_per_wg scan_types_per_int linear_index
        linear_scan_data_idx dest src store_base wrapped_scan_type
        dummy scan_tmp tmp_aux

        LID_2 LID_1 LID_0
        LDIM_0 LDIM_1 LDIM_2
        GDIM_0 GDIM_1 GDIM_2
        GID_0 GID_1 GID_2
        """.split())

_IGNORED_WORDS = set("""
        4 8 32

        typedef for endfor if void while endwhile endfor endif else const printf
        None return bool n char true false ifdef pycl_printf str range assert
        np iinfo max itemsize __packed__ struct restrict ptrdiff_t

        set iteritems len setdefault

        GLOBAL_MEM LOCAL_MEM_ARG WITHIN_KERNEL LOCAL_MEM KERNEL REQD_WG_SIZE
        local_barrier
        CLK_LOCAL_MEM_FENCE OPENCL EXTENSION
        pragma __attribute__ __global __kernel __local
        get_local_size get_local_id cl_khr_fp64 reqd_work_group_size
        get_num_groups barrier get_group_id
        CL_VERSION_1_1 __OPENCL_C_VERSION__ 120

        _final_update _debug_scan kernel_name

        positions all padded integer its previous write based writes 0
        has local worth scan_expr to read cannot not X items False bank
        four beginning follows applicable item min each indices works side
        scanning right summed relative used id out index avoid current state
        boundary True across be This reads groups along Otherwise undetermined
        store of times prior s update first regardless Each number because
        array unit from segment conflicts two parallel 2 empty define direction
        CL padding work tree bounds values and adds
        scan is allowed thus it an as enable at in occur sequentially end no
        storage data 1 largest may representable uses entry Y meaningful
        computations interval At the left dimension know d
        A load B group perform shift tail see last OR
        this add fetched into are directly need
        gets them stenciled that undefined
        there up any ones or name only relevant populated
        even wide we Prepare int seg Note re below place take variable must
        intra Therefore find code assumption
        branch workgroup complicated granularity phase remainder than simpler
        We smaller look ifs lots self behind allow barriers whole loop
        after already Observe achieve contiguous stores hard go with by math
        size won t way divisible bit so Avoid declare adding single type

        is_tail is_first_level input_expr argument_signature preamble
        double_support neutral output_statement
        k_group_size name_prefix is_segmented index_dtype scan_dtype
        wg_size is_segment_start_expr fetch_expr_offsets
        arg_ctypes ife_offsets input_fetch_exprs def
        ife_offset arg_name local_fetch_expr_args update_body
        update_loop_lookbehind update_loop_plain update_loop
        use_lookbehind_update store_segment_start_flags
        update_loop first_seg scan_dtype dtype_to_ctype
        is_gpu use_bank_conflict_avoidance

        a b prev_item i last_item prev_value
        N NO_SEG_BOUNDARY across_seg_boundary

        arg_offset_adjustment
        """.split())


def _make_template(s: str):
    import re
    leftovers = set()

    def replace_id(match: re.Match) -> str:
        # avoid name clashes with user code by adding 'psc_' prefix to
        # identifiers.

        word = match.group(1)
        if word in _IGNORED_WORDS:
            return word
        elif word in _PREFIX_WORDS:
            return f"psc_{word}"
        else:
            leftovers.add(word)
            return word

    s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s)
    if leftovers:
        from warnings import warn
        warn("Leftover words in identifier prefixing: " + " ".join(leftovers),
             stacklevel=3)

    return mako.template.Template(s, strict_undefined=True)
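
# For example (illustrative input):
# _make_template("for (int k = 0; k < K; k++) tmp = ldata[k];") builds a
# template whose text reads
# "for (int psc_k = 0; psc_k < psc_K; psc_k++) psc_tmp = psc_ldata[psc_k];"
# since "k", "K", "tmp" and "ldata" are in _PREFIX_WORDS, while "for",
# "int" and "0" are in _IGNORED_WORDS.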


@dataclass(frozen=True)
class _GeneratedScanKernelInfo:
    scan_src: str
    kernel_name: str
    scalar_arg_dtypes: list[np.dtype | None]
    wg_size: int
    k_group_size: int

    def build(self, context: cl.Context, options: Any) -> _BuiltScanKernelInfo:
        program = cl.Program(context, self.scan_src).build(options)
        kernel = getattr(program, self.kernel_name)
        kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
        return _BuiltScanKernelInfo(
                kernel=kernel,
                wg_size=self.wg_size,
                k_group_size=self.k_group_size)


@dataclass(frozen=True)
class _BuiltScanKernelInfo:
    kernel: cl.Kernel
    wg_size: int
    k_group_size: int


@dataclass(frozen=True)
class _GeneratedFinalUpdateKernelInfo:
    source: str
    kernel_name: str
    scalar_arg_dtypes: Sequence[np.dtype | None]
    update_wg_size: int

    def build(self,
              context: cl.Context,
              options: Any) -> _BuiltFinalUpdateKernelInfo:
        program = cl.Program(context, self.source).build(options)
        kernel = getattr(program, self.kernel_name)
        kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
        return _BuiltFinalUpdateKernelInfo(kernel, self.update_wg_size)


@dataclass(frozen=True)
class _BuiltFinalUpdateKernelInfo:
    kernel: cl.Kernel
    update_wg_size: int

# }}}


class ScanPerformanceWarning(UserWarning):
    pass


class GenericScanKernelBase(ABC):
    # {{{ constructor, argument processing

    def __init__(
            self,
            ctx: cl.Context,
            dtype: Any,
            arguments: str | list[DtypedArgument],
            input_expr: str,
            scan_expr: str,
            neutral: str | None,
            output_statement: str,
            is_segment_start_expr: str | None = None,
            input_fetch_exprs: list[tuple[str, str, int]] | None = None,
            index_dtype: Any = None,
            name_prefix: str = "scan",
            options: Any = None,
            preamble: str = "",
            devices: Sequence[cl.Device] | None = None) -> None:
        """
        :arg ctx: a :class:`pyopencl.Context` within which the code
            for this scan kernel will be generated.
        :arg dtype: the :class:`numpy.dtype` with which the scan will
            be performed. May be a structured type if that type was registered
            through :func:`pyopencl.tools.get_or_register_dtype`.
        :arg arguments: A string of comma-separated C argument declarations.
            If *arguments* is specified, then *input_expr* must also be
            specified. All types used here must be known to PyOpenCL.
            (see :func:`pyopencl.tools.get_or_register_dtype`).
        :arg scan_expr: The associative, binary operation carrying out the scan,
            represented as a C string. Its two arguments are available as ``a``
            and ``b`` when it is evaluated. ``b`` is guaranteed to be the
            'element being updated', and ``a`` is the increment. Thus,
            if some data is supposed to just propagate along without being
            modified by the scan, it should live in ``b``.

            This expression may call functions given in the *preamble*.

            Another value available to this expression is ``across_seg_boundary``,
            a C `bool` indicating whether this scan update is crossing a
            segment boundary, as defined by ``is_segment_start_expr``.
            The scan routine does not implement segmentation
            semantics on its own. It relies on ``scan_expr`` to do this.
            This value is available (but always ``false``) even for a
            non-segmented scan.

            .. note::

                In early pre-releases of the segmented scan,
                segmentation semantics were implemented *without*
                relying on ``scan_expr``.

        :arg input_expr: A C expression, encoded as a string, resulting
            in the values to which the scan is applied. This may be used
            to apply a mapping to values stored in *arguments* before being
            scanned. The result of this expression must match *dtype*.
            The index intended to be mapped is available as ``i`` in this
            expression. This expression may also use the variables defined
            by *input_fetch_expr*.

            This expression may also call functions given in the *preamble*.
        :arg output_statement: a C statement that writes
            the output of the scan. It has access to the scan result as ``item``,
            the preceding scan result item as ``prev_item``, and the current index
            as ``i``. ``prev_item`` in a segmented scan will be the neutral element
            at a segment boundary, not the immediately preceding item.

            Using *prev_item* in the output statement has a small run-time cost.
            ``prev_item`` enables the construction of an exclusive scan.

            For non-segmented scans, *output_statement* may also reference
            ``last_item``, which evaluates to the scan result of the last
            array entry.
        :arg is_segment_start_expr: A C expression, encoded as a string,
            resulting in a C ``bool`` value that determines whether a new
            scan segment starts at index *i*. If given, makes the scan a
            segmented scan. Has access to the current index ``i``, the result
            of *input_expr* as ``a``, and in addition may use *arguments* and
            *input_fetch_expr* variables just like *input_expr*.

            If it returns true, then previous sums will not spill over into the
            item with index *i* or subsequent items.
        :arg input_fetch_exprs: a list of tuples *(NAME, ARG_NAME, OFFSET)*.
            An entry here has the effect of doing the equivalent of the following
            before input_expr::

                ARG_NAME_TYPE NAME = ARG_NAME[i+OFFSET];

            ``OFFSET`` is allowed to be 0 or -1, and ``ARG_NAME_TYPE`` is the type
            of ``ARG_NAME``.
        :arg preamble: |preamble|

        The first array in the argument list determines the size of the index
        space over which the scan is carried out, and thus the values over
        which the index *i* occurring in a number of code fragments in
        arguments above will vary.

        All code fragments further have access to N, the number of elements
        being processed in the scan.
        """

        if index_dtype is None:
            index_dtype = np.dtype(np.int32)

        if input_fetch_exprs is None:
            input_fetch_exprs = []

        self.context: cl.Context = ctx
        self.dtype: np.dtype[Any]
        dtype = self.dtype = np.dtype(dtype)

        if neutral is None:
            from warnings import warn
            warn("not specifying 'neutral' is deprecated and will lead to "
                 "wrong results if your scan is not in-place or your "
                 "'output_statement' does something otherwise non-trivial",
                 stacklevel=2)

        if dtype.itemsize % 4 != 0:
            raise TypeError("scan value type must have size divisible by 4 bytes")

        self.index_dtype: np.dtype[np.integer] = np.dtype(index_dtype)
        if np.iinfo(self.index_dtype).min >= 0:
            raise TypeError("index_dtype must be signed")

        if devices is None:
            devices = ctx.devices
        self.devices: Sequence[cl.Device] = devices
        self.options = options

        from pyopencl.tools import parse_arg_list
        self.parsed_args: Sequence[DtypedArgument] = parse_arg_list(arguments)
        from pyopencl.tools import VectorArg
        self.first_array_idx: int = next(
            i for i, arg in enumerate(self.parsed_args)
            if isinstance(arg, VectorArg))

        self.input_expr: str = input_expr

        self.is_segment_start_expr: str | None = is_segment_start_expr
        self.is_segmented: bool = is_segment_start_expr is not None
        if is_segment_start_expr is not None:
            is_segment_start_expr = _process_code_for_macro(is_segment_start_expr)

        self.output_statement: str = output_statement

        for _name, _arg_name, ife_offset in input_fetch_exprs:
            if ife_offset not in [0, -1]:
                raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
        self.input_fetch_exprs: Sequence[tuple[str, str, int]] = input_fetch_exprs

        arg_dtypes = {}
        arg_ctypes = {}
        for arg in self.parsed_args:
            arg_dtypes[arg.name] = arg.dtype
            arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)

        self.name_prefix: str = name_prefix

        # {{{ set up shared code dict

        from pyopencl.characterize import has_double_support

        self.code_variables = {
            "np": np,
            "dtype_to_ctype": dtype_to_ctype,
            "preamble": preamble,
            "name_prefix": name_prefix,
            "index_dtype": self.index_dtype,
            "scan_dtype": dtype,
            "is_segmented": self.is_segmented,
            "arg_dtypes": arg_dtypes,
            "arg_ctypes": arg_ctypes,
            "scan_expr": _process_code_for_macro(scan_expr),
            "neutral": _process_code_for_macro(neutral),
            "is_gpu": bool(self.devices[0].type & cl.device_type.GPU),
            "double_support": all(
                has_double_support(dev) for dev in devices),
        }

        index_typename = dtype_to_ctype(self.index_dtype)
        scan_typename = dtype_to_ctype(dtype)

        # This key is meant to uniquely identify the non-device parameters for
        # the scan kernel.
        self.kernel_key = (
            self.dtype,
            tuple(arg.declarator() for arg in self.parsed_args),
            self.input_expr,
            scan_expr,
            neutral,
            output_statement,
            is_segment_start_expr,
            tuple(input_fetch_exprs),
            index_dtype,
            name_prefix,
            preamble,
            # These depend on dtype_to_ctype(), so their value is independent of
            # the other variables.
            index_typename,
            scan_typename,
        )

        # }}}

        self.use_lookbehind_update: bool = "prev_item" in self.output_statement
        self.store_segment_start_flags: bool = (
            self.is_segmented and self.use_lookbehind_update)

        self.finish_setup()

    # }}}

    @abstractmethod
    def finish_setup(self) -> None:
        pass


if not cl._PYOPENCL_NO_CACHE:
    generic_scan_kernel_cache: WriteOncePersistentDict[Any,
            tuple[_GeneratedScanKernelInfo, _GeneratedScanKernelInfo,
                _GeneratedFinalUpdateKernelInfo]] = \
        WriteOncePersistentDict(
            "pyopencl-generated-scan-kernel-cache-v1",
            key_builder=_NumpyTypesKeyBuilder(),
            in_mem_cache_size=0,
            safe_sync=False)
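
# An illustrative companion to the ``prev_item`` discussion in
# GenericScanKernelBase.__init__: an exclusive prefix sum (sketch only,
# assuming a context ``ctx`` and int32 arrays as in the docstring below):
#
#     knl = GenericScanKernel(
#             ctx, np.int32,
#             arguments="__global int *ary, __global int *out",
#             input_expr="ary[i]",
#             scan_expr="a+b", neutral="0",
#             output_statement="out[i] = prev_item;")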
|
|
1158
|
+
|
|
1159
|
+
|
|
1160
|
+
class GenericScanKernel(GenericScanKernelBase):
|
|
1161
|
+
"""Generates and executes code that performs prefix sums ("scans") on
|
|
1162
|
+
arbitrary types, with many possible tweaks.
|
|
1163
|
+
|
|
1164
|
+
Usage example::
|
|
1165
|
+
|
|
1166
|
+
from pyopencl.scan import GenericScanKernel
|
|
1167
|
+
knl = GenericScanKernel(
|
|
1168
|
+
context, np.int32,
|
|
1169
|
+
arguments="__global int *ary",
|
|
1170
|
+
input_expr="ary[i]",
|
|
1171
|
+
scan_expr="a+b", neutral="0",
|
|
1172
|
+
output_statement="ary[i+1] = item;")
|
|
1173
|
+
|
|
1174
|
+
a = cl.array.arange(queue, 10000, dtype=np.int32)
|
|
1175
|
+
knl(a, queue=queue)
|
|
1176
|
+
|
|
1177
|
+
.. automethod:: __init__
|
|
1178
|
+
.. automethod:: __call__
|
|
1179
|
+
"""
|
|
1180
|
+
|
|
1181
|
+
def finish_setup(self) -> None:
|
|
1182
|
+
# Before generating the kernel, see if it's cached.
|
|
1183
|
+
from pyopencl.cache import get_device_cache_id
|
|
1184
|
+
devices_key = tuple(get_device_cache_id(device)
|
|
1185
|
+
for device in self.devices)
|
|
1186
|
+
|
|
1187
|
+
cache_key = (self.kernel_key, devices_key)
|
|
1188
|
+
from_cache = False
|
|
1189
|
+
|
|
1190
|
+
if not cl._PYOPENCL_NO_CACHE:
|
|
1191
|
+
try:
|
|
1192
|
+
result = generic_scan_kernel_cache[cache_key]
|
|
1193
|
+
from_cache = True
|
|
1194
|
+
logger.debug(
|
|
1195
|
+
"cache hit for generated scan kernel '%s'", self.name_prefix)
|
|
1196
|
+
(
|
|
1197
|
+
self.first_level_scan_gen_info,
|
|
1198
|
+
self.second_level_scan_gen_info,
|
|
1199
|
+
self.final_update_gen_info) = result
|
|
1200
|
+
except KeyError:
|
|
1201
|
+
pass
|
|
1202
|
+
|
|
1203
|
+
if not from_cache:
|
|
1204
|
+
logger.debug(
|
|
1205
|
+
"cache miss for generated scan kernel '%s'", self.name_prefix)
|
|
1206
|
+
self._finish_setup_impl()
|
|
1207
|
+
|
|
1208
|
+
result = (self.first_level_scan_gen_info,
|
|
1209
|
+
self.second_level_scan_gen_info,
|
|
1210
|
+
self.final_update_gen_info)
|
|
1211
|
+
|
|
1212
|
+
if not cl._PYOPENCL_NO_CACHE:
|
|
1213
|
+
generic_scan_kernel_cache.store_if_not_present(cache_key, result)
|
|
1214
|
+
|
|
1215
|
+
# Build the kernels.
|
|
1216
|
+
self.first_level_scan_info = self.first_level_scan_gen_info.build(
|
|
1217
|
+
self.context, self.options)
|
|
1218
|
+
del self.first_level_scan_gen_info
|
|
1219
|
+
|
|
1220
|
+
self.second_level_scan_info = self.second_level_scan_gen_info.build(
|
|
1221
|
+
self.context, self.options)
|
|
1222
|
+
del self.second_level_scan_gen_info
|
|
1223
|
+
|
|
1224
|
+
self.final_update_info = self.final_update_gen_info.build(
|
|
1225
|
+
self.context, self.options)
|
|
1226
|
+
del self.final_update_gen_info
|
|
1227
|
+
|
|
1228
|
+
def _finish_setup_impl(self) -> None:
|
|
1229
|
+
# {{{ find usable workgroup/k-group size, build first-level scan
|
|
1230
|
+
|
|
1231
|
+
trip_count = 0
|
|
1232
|
+
|
|
1233
|
+
avail_local_mem = min(
|
|
1234
|
+
dev.local_mem_size
|
|
1235
|
+
for dev in self.devices)
|
|
1236
|
+
|
|
1237
|
+
if "CUDA" in self.devices[0].platform.name:
|
|
1238
|
+
# not sure where these go, but roughly this much seems unavailable.
|
|
1239
|
+
avail_local_mem -= 0x400
|
|
1240
|
+
|
|
1241
|
+
is_cpu = bool(self.devices[0].type & cl.device_type.CPU)
|
|
1242
|
+
is_gpu = bool(self.devices[0].type & cl.device_type.GPU)
|
|
1243
|
+
|
|
1244
|
+
if is_cpu:
|
|
1245
|
+
# (about the widest vector a CPU can support, also taking
|
|
1246
|
+
# into account that CPUs don't hide latency by large work groups
|
|
1247
|
+
max_scan_wg_size = 16
|
|
1248
|
+
wg_size_multiples = 4
|
|
1249
|
+
else:
|
|
1250
|
+
max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
|
|
1251
|
+
wg_size_multiples = 64
|
|
1252
|
+
|
|
1253
|
+
# Intel beignet fails "Out of shared local memory" in test_scan int64
|
|
1254
|
+
# and asserts in test_sort with this enabled:
|
|
1255
|
+
# https://github.com/inducer/pyopencl/pull/238
|
|
1256
|
+
# A beignet bug report (outside of pyopencl) suggests packed structs
|
|
1257
|
+
# (which this is) can even give wrong results:
|
|
1258
|
+
# https://bugs.freedesktop.org/show_bug.cgi?id=98717
|
|
1259
|
+
# TODO: does this also affect Intel Compute Runtime?
|
|
1260
|
+
use_bank_conflict_avoidance = (
|
|
1261
|
+
self.dtype.itemsize > 4 and self.dtype.itemsize % 8 == 0
|
|
1262
|
+
and is_gpu
|
|
1263
|
+
and "beignet" not in self.devices[0].platform.version.lower())
|
|
1264
|
+
|
|
1265
|
+
# k_group_size should be a power of two because of in-kernel
|
|
1266
|
+
# division by that number.
|
|
1267
|
+
|
|
1268
|
+
solutions: list[tuple[int, int, int]] = []
|
|
1269
|
+
for k_exp in range(0, 9):
|
|
1270
|
+
for wg_size in range(wg_size_multiples, max_scan_wg_size+1,
|
|
1271
|
+
wg_size_multiples):
|
|
1272
|
+
|
|
1273
|
+
k_group_size = 2**k_exp
|
|
1274
|
+
lmem_use = self.get_local_mem_use(wg_size, k_group_size,
|
|
1275
|
+
use_bank_conflict_avoidance)
|
|
1276
|
+
if lmem_use <= avail_local_mem:
|
|
1277
|
+
solutions.append((wg_size*k_group_size, k_group_size, wg_size))
|
|
1278
|
+
|
|
1279
|
+
if is_gpu:
|
|
1280
|
+
for wg_size_floor in [256, 192, 128]:
|
|
1281
|
+
have_sol_above_floor = any(wg_size >= wg_size_floor
|
|
1282
|
+
for _, _, wg_size in solutions)
|
|
1283
|
+
|
|
1284
|
+
if have_sol_above_floor:
|
|
1285
|
+
# delete all solutions not meeting the wg size floor
|
|
1286
|
+
solutions = [(total, try_k_group_size, try_wg_size)
|
|
1287
|
+
for total, try_k_group_size, try_wg_size in solutions
|
|
1288
|
+
if try_wg_size >= wg_size_floor]
|
|
1289
|
+
break
|
|
1290
|
+
|
|
1291
|
+
_, k_group_size, max_scan_wg_size = max(solutions)
|
|
1292
|
+
|
|
1293
|
+
while True:
|
|
1294
|
+
candidate_scan_gen_info = self.generate_scan_kernel(
|
|
1295
|
+
max_scan_wg_size, self.parsed_args,
|
|
1296
|
+
_process_code_for_macro(self.input_expr),
|
|
1297
|
+
self.is_segment_start_expr,
|
|
1298
|
+
input_fetch_exprs=self.input_fetch_exprs,
|
|
1299
|
+
is_first_level=True,
|
|
1300
|
+
store_segment_start_flags=self.store_segment_start_flags,
|
|
1301
|
+
k_group_size=k_group_size,
|
|
1302
|
+
use_bank_conflict_avoidance=use_bank_conflict_avoidance)
|
|
1303
|
+
|
|
1304
|
+
candidate_scan_info = candidate_scan_gen_info.build(
|
|
1305
|
+
self.context, self.options)
|
|
1306
|
+
|
|
1307
|
+
# Will this device actually let us execute this kernel
|
|
1308
|
+
# at the desired work group size? Building it is the
|
|
1309
|
+
# only way to find out.
|
|
1310
|
+
kernel_max_wg_size = min(
|
|
1311
|
+
candidate_scan_info.kernel.get_work_group_info(
|
|
1312
|
+
cl.kernel_work_group_info.WORK_GROUP_SIZE,
|
|
1313
|
+
dev)
|
|
1314
|
+
for dev in self.devices)
|
|
1315
|
+
|
|
1316
|
+
if candidate_scan_info.wg_size <= kernel_max_wg_size:
|
|
1317
|
+
break
|
|
1318
|
+
else:
|
|
1319
|
+
max_scan_wg_size = min(kernel_max_wg_size, max_scan_wg_size)
|
|
1320
|
+
|
|
1321
|
+
trip_count += 1
|
|
1322
|
+
assert trip_count <= 20
|
|
1323
|
+
|
|
1324
|
+
self.first_level_scan_gen_info = candidate_scan_gen_info
|
|
1325
|
+
assert (_round_down_to_power_of_2(candidate_scan_info.wg_size)
|
|
1326
|
+
== candidate_scan_info.wg_size)
|
|
1327
|
+
|
|
1328
|
+
# }}}
|
|
1329
|
+
|
|
1330
|
+
# {{{ build second-level scan
|
|
1331
|
+
|
|
1332
|
+
from pyopencl.tools import VectorArg
|
|
1333
|
+
second_level_arguments = [
|
|
1334
|
+
*self.parsed_args,
|
|
1335
|
+
VectorArg(self.dtype, "interval_sums"),
|
|
1336
|
+
]
|
|
1337
|
+
|
|
1338
|
+
second_level_build_kwargs: dict[str, str | None] = {}
|
|
1339
|
+
if self.is_segmented:
|
|
1340
|
+
second_level_arguments.append(
|
|
1341
|
+
VectorArg(self.index_dtype,
|
|
1342
|
+
"g_first_segment_start_in_interval_input"))
|
|
1343
|
+
|
|
1344
|
+
# is_segment_start_expr answers the question "should previous sums
|
|
1345
|
+
# spill over into this item". And since
|
|
1346
|
+
# g_first_segment_start_in_interval_input answers the question if a
|
|
1347
|
+
# segment boundary was found in an interval of data, then if not,
|
|
1348
|
+
# it's ok to spill over.
|
|
1349
|
+
second_level_build_kwargs["is_segment_start_expr"] = \
|
|
1350
|
+
"g_first_segment_start_in_interval_input[i] != NO_SEG_BOUNDARY"
|
|
1351
|
+
else:
|
|
1352
|
+
second_level_build_kwargs["is_segment_start_expr"] = None
|
|
1353
|
+
|
|
1354
|
+
self.second_level_scan_gen_info = self.generate_scan_kernel(
|
|
1355
|
+
max_scan_wg_size,
|
|
1356
|
+
arguments=second_level_arguments,
|
|
1357
|
+
input_expr="interval_sums[i]",
|
|
1358
|
+
input_fetch_exprs=[],
|
|
1359
|
+
is_first_level=False,
|
|
1360
|
+
store_segment_start_flags=False,
|
|
1361
|
+
k_group_size=k_group_size,
|
|
1362
|
+
use_bank_conflict_avoidance=use_bank_conflict_avoidance,
|
|
1363
|
+
**second_level_build_kwargs)
|
|
1364
|
+
|
|
1365
|
+
# }}}
|
|
1366
|
+
|
|
1367
|
+
# {{{ generate final update kernel
|
|
1368
|
+
|
|
1369
|
+
update_wg_size = min(max_scan_wg_size, 256)
|
|
1370
|
+
|
|
1371
|
+
final_update_tpl = _make_template(UPDATE_SOURCE)
|
|
1372
|
+
final_update_src = str(final_update_tpl.render(
|
|
1373
|
+
wg_size=update_wg_size,
|
|
1374
|
+
output_statement=self.output_statement,
|
|
1375
|
+
arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
|
|
1376
|
+
argument_signature=", ".join(
|
|
1377
|
+
arg.declarator() for arg in self.parsed_args),
|
|
1378
|
+
is_segment_start_expr=self.is_segment_start_expr,
|
|
1379
|
+
input_expr=_process_code_for_macro(self.input_expr),
|
|
1380
|
+
use_lookbehind_update=self.use_lookbehind_update,
|
|
1381
|
+
**self.code_variables))
|
|
1382
|
+
|
|
1383
|
+
update_scalar_arg_dtypes = [
|
|
1384
|
+
*get_arg_list_scalar_arg_dtypes(self.parsed_args),
|
|
1385
|
+
self.index_dtype, self.index_dtype, None, None]
|
|
1386
|
+
|
|
1387
|
+
if self.is_segmented:
|
|
1388
|
+
# g_first_segment_start_in_interval
|
|
1389
|
+
update_scalar_arg_dtypes.append(None)
|
|
1390
|
+
if self.store_segment_start_flags:
|
|
1391
|
+
update_scalar_arg_dtypes.append(None) # g_segment_start_flags
|
|
1392
|
+
|
|
1393
|
+
self.final_update_gen_info = _GeneratedFinalUpdateKernelInfo(
|
|
1394
|
+
final_update_src,
|
|
1395
|
+
self.name_prefix + "_final_update",
|
|
1396
|
+
update_scalar_arg_dtypes,
|
|
1397
|
+
update_wg_size)
|
|
1398
|
+
|
|
1399
|
+
# }}}
|
|
1400
|
+
|
|
1401
|
+
# {{{ scan kernel build/properties
|
|
1402
|
+
|
|
1403
|
+
def get_local_mem_use(
|
|
1404
|
+
self, k_group_size: int, wg_size: int,
|
|
1405
|
+
use_bank_conflict_avoidance: bool) -> int:
|
|
1406
|
+
arg_dtypes = {}
|
|
1407
|
+
for arg in self.parsed_args:
|
|
1408
|
+
arg_dtypes[arg.name] = arg.dtype
|
|
1409
|
+
|
|
1410
|
+
fetch_expr_offsets: dict[str, set[int]] = {}
|
|
1411
|
+
for _name, arg_name, ife_offset in self.input_fetch_exprs:
|
|
1412
|
+
fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
|
|
1413
|
+
|
|
1414
|
+
itemsize = self.dtype.itemsize
|
|
1415
|
+
if use_bank_conflict_avoidance:
|
|
1416
|
+
itemsize += 4
|
|
1417
|
+
|
|
1418
|
+
return (
|
|
1419
|
+
# ldata
|
|
1420
|
+
itemsize*(k_group_size+1)*(wg_size+1)
|
|
1421
|
+
|
|
1422
|
+
# l_segment_start_flags
|
|
1423
|
+
+ k_group_size*wg_size
|
|
1424
|
+
|
|
1425
|
+
# l_first_segment_start_in_subtree
|
|
1426
|
+
+ self.index_dtype.itemsize*wg_size
|
|
1427
|
+
|
|
1428
|
+
+ k_group_size*wg_size*sum(
|
|
1429
|
+
arg_dtypes[arg_name].itemsize
|
|
1430
|
+
for arg_name, ife_offsets in list(fetch_expr_offsets.items())
|
|
1431
|
+
if -1 in ife_offsets or len(ife_offsets) > 1))

    def generate_scan_kernel(
            self,
            max_wg_size: int,
            arguments: Sequence[DtypedArgument],
            input_expr: str,
            is_segment_start_expr: str | None,
            input_fetch_exprs: Sequence[tuple[str, str, int]],
            is_first_level: bool,
            store_segment_start_flags: bool,
            k_group_size: int,
            use_bank_conflict_avoidance: bool) -> _GeneratedScanKernelInfo:
        scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)

        # Empirically found on Nv hardware: no need to be bigger than this size
        wg_size = _round_down_to_power_of_2(min(max_wg_size, 256))

        kernel_name = cast("str", self.code_variables["name_prefix"])
        if is_first_level:
            kernel_name += "_lev1"
        else:
            kernel_name += "_lev2"

        scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
        scan_src = str(scan_tpl.render(
            wg_size=wg_size,
            input_expr=input_expr,
            k_group_size=k_group_size,
            arg_offset_adjustment=get_arg_offset_adjuster_code(arguments),
            argument_signature=", ".join(arg.declarator() for arg in arguments),
            is_segment_start_expr=is_segment_start_expr,
            input_fetch_exprs=input_fetch_exprs,
            is_first_level=is_first_level,
            store_segment_start_flags=store_segment_start_flags,
            use_bank_conflict_avoidance=use_bank_conflict_avoidance,
            kernel_name=kernel_name,
            **self.code_variables))

        scalar_arg_dtypes.extend(
                (None, self.index_dtype, self.index_dtype))
        if is_first_level:
            scalar_arg_dtypes.append(None)  # interval_results
        if self.is_segmented and is_first_level:
            scalar_arg_dtypes.append(None)  # g_first_segment_start_in_interval
        if store_segment_start_flags:
            scalar_arg_dtypes.append(None)  # g_segment_start_flags

        return _GeneratedScanKernelInfo(
                scan_src=scan_src,
                kernel_name=kernel_name,
                scalar_arg_dtypes=scalar_arg_dtypes,
                wg_size=wg_size,
                k_group_size=k_group_size)
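
    # Example (editor's sketch; not in the upstream source): the helper
    # _round_down_to_power_of_2 used above is defined earlier in this module.
    # Assuming it behaves as its name suggests, a minimal equivalent is:
    #
    #     def _round_down_to_power_of_2(n: int) -> int:
    #         # largest power of two <= n, for n >= 1
    #         return 1 << (n.bit_length() - 1)
    #
    # so a device limit of max_wg_size=192 yields a work-group size of 128,
    # while any limit of 256 or more is first clamped to 256 by the min()
    # above.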

    # }}}

    def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
        """
        |std-enqueue-blurb|

        .. note::

            The returned :class:`pyopencl.Event` corresponds only to part of
            the execution of the scan. It is not suitable for profiling.

        :arg queue: queue on which to execute the scan. If not given, the
            queue of the first :class:`pyopencl.array.Array` in *args* is
            used.
        :arg allocator: an allocator for the temporary arrays and results. If
            not given, the allocator of the first :class:`pyopencl.array.Array`
            in *args* is used.
        :arg size: specify the length of the scan to be carried out. If not
            given, this length is inferred from the first argument.
        :arg wait_for: a :class:`list` of events to wait for.
        """

        # {{{ argument processing

        allocator = kwargs.get("allocator")
        queue = kwargs.get("queue")
        n = kwargs.get("size")
        wait_for = kwargs.get("wait_for")

        if wait_for is None:
            wait_for = []
        else:
            # We will be appending to this list below.
            wait_for = list(wait_for)

        if len(args) != len(self.parsed_args):
            raise TypeError(
                f"expected {len(self.parsed_args)} arguments, got {len(args)}")

        first_array = args[self.first_array_idx]
        allocator = allocator or first_array.allocator
        queue = queue or first_array.queue

        if n is None:
            n, = first_array.shape

        if n == 0:
            # We're done here. (But pretend to return an event.)
            return cl.enqueue_marker(queue, wait_for=wait_for)

        from pyopencl.tools import VectorArg
        data_args = []
        for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
            if isinstance(arg_descr, VectorArg):
                data_args.append(arg_val.base_data)
                if arg_descr.with_offset:
                    data_args.append(arg_val.offset)
                wait_for.extend(arg_val.events)
            else:
                data_args.append(arg_val)

        # }}}

        l1_info = self.first_level_scan_info
        l2_info = self.second_level_scan_info

        # see CL source above for terminology
        unit_size = l1_info.wg_size * l1_info.k_group_size
        max_intervals = 3*max(dev.max_compute_units for dev in self.devices)

        from pytools import uniform_interval_splitting
        interval_size, num_intervals = uniform_interval_splitting(
                n, unit_size, max_intervals)
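
        # Editor's sketch (not in the upstream source): pytools'
        # uniform_interval_splitting(n, granularity, max_intervals) is
        # assumed to pick interval_size as a multiple of granularity such
        # that num_intervals <= max_intervals and
        # num_intervals*interval_size >= n, roughly:
        #
        #     def div_ceil(a, b):
        #         return -(-a // b)
        #
        #     grains = div_ceil(n, granularity)
        #     grains_per_interval = div_ceil(grains, max_intervals)
        #     interval_size = grains_per_interval * granularity
        #     num_intervals = div_ceil(n, interval_size)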

        # {{{ allocate some buffers

        interval_results = cl_array.empty(queue,
                num_intervals, dtype=self.dtype,
                allocator=allocator)

        partial_scan_buffer = cl_array.empty(
                queue, n, dtype=self.dtype,
                allocator=allocator)

        if self.store_segment_start_flags:
            segment_start_flags = cl_array.empty(
                    queue, n, dtype=np.bool_,
                    allocator=allocator)

        # }}}

        # {{{ first level scan of interval (one interval per block)

        scan1_args = [
            *data_args,
            partial_scan_buffer.data, n, interval_size, interval_results.data,
        ]

        if self.is_segmented:
            first_segment_start_in_interval = cl_array.empty(queue,
                    num_intervals, dtype=self.index_dtype,
                    allocator=allocator)
            scan1_args.append(first_segment_start_in_interval.data)

        if self.store_segment_start_flags:
            scan1_args.append(segment_start_flags.data)

        l1_evt = l1_info.kernel(
                queue, (num_intervals,), (l1_info.wg_size,),
                *scan1_args, g_times_l=True, wait_for=wait_for)

        # }}}

        # {{{ second level scan of per-interval results

        # can scan at most one interval
        assert interval_size >= num_intervals

        scan2_args = [
            *data_args,
            interval_results.data,  # interval_sums
        ]

        if self.is_segmented:
            scan2_args.append(first_segment_start_in_interval.data)
        scan2_args = [
            *scan2_args,
            interval_results.data,  # partial_scan_buffer
            num_intervals, interval_size]

        l2_evt = l2_info.kernel(
                queue, (1,), (l1_info.wg_size,),
                *scan2_args, g_times_l=True, wait_for=[l1_evt])

        # }}}

        # {{{ update intervals with result of interval scan

        upd_args = [
            *data_args,
            n, interval_size, interval_results.data, partial_scan_buffer.data]
        if self.is_segmented:
            upd_args.append(first_segment_start_in_interval.data)
        if self.store_segment_start_flags:
            upd_args.append(segment_start_flags.data)

        return self.final_update_info.kernel(
                queue, (num_intervals,),
                (self.final_update_info.update_wg_size,),
                *upd_args, g_times_l=True, wait_for=[l2_evt])

        # }}}

# }}}
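
# Example (editor's sketch; not in the upstream source): a minimal in-place
# inclusive prefix sum with GenericScanKernel, assuming an existing context
# `ctx` and a pyopencl.array.Array `a` of int32 on it:
#
#     import numpy as np
#     from pyopencl.scan import GenericScanKernel
#
#     knl = GenericScanKernel(
#             ctx, np.int32,
#             arguments="__global int *ary",
#             input_expr="ary[i]",
#             scan_expr="a+b", neutral="0",
#             output_statement="ary[i] = item;")
#
#     knl(a)  # a now holds its inclusive cumulative sum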


# {{{ debug kernel

DEBUG_SCAN_TEMPLATE = SHARED_PREAMBLE + r"""//CL//

KERNEL
REQD_WG_SIZE(1, 1, 1)
void ${name_prefix}_debug_scan(
    __global scan_type *scan_tmp,
    ${argument_signature},
    const index_type N)
{
    scan_type current = ${neutral};
    scan_type prev;

    ${arg_offset_adjustment}

    for (index_type i = 0; i < N; ++i)
    {
        %for name, arg_name, ife_offset in input_fetch_exprs:
            ${arg_ctypes[arg_name]} ${name};
            %if ife_offset < 0:
                if (i+${ife_offset} >= 0)
                    ${name} = ${arg_name}[i+${ife_offset}];
            %else:
                ${name} = ${arg_name}[i];
            %endif
        %endfor

        scan_type my_val = INPUT_EXPR(i);

        prev = current;
        %if is_segmented:
            bool is_seg_start = IS_SEG_START(i, my_val);
        %endif

        current = SCAN_EXPR(prev, my_val,
            %if is_segmented:
                is_seg_start
            %else:
                false
            %endif
            );
        scan_tmp[i] = current;
    }

    scan_type last_item = scan_tmp[N-1];

    for (index_type i = 0; i < N; ++i)
    {
        scan_type item = scan_tmp[i];
        scan_type prev_item;
        if (i)
            prev_item = scan_tmp[i-1];
        else
            prev_item = ${neutral};

        {
            ${output_statement};
        }
    }
}
"""


class GenericDebugScanKernel(GenericScanKernelBase):
    """
    Performs the same function and has the same interface as
    :class:`GenericScanKernel`, but uses a dead-simple, sequential scan. Works
    best on CPU platforms, and helps isolate bugs in scans by removing the
    potential for issues originating in parallel execution.

    .. automethod:: __call__
    """

    def finish_setup(self) -> None:
        scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
        scan_src = str(scan_tpl.render(
            output_statement=self.output_statement,
            arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
            argument_signature=", ".join(
                arg.declarator() for arg in self.parsed_args),
            is_segment_start_expr=self.is_segment_start_expr,
            input_expr=_process_code_for_macro(self.input_expr),
            input_fetch_exprs=self.input_fetch_exprs,
            wg_size=1,
            **self.code_variables))

        scan_prg = cl.Program(self.context, scan_src).build(self.options)
        self.kernel = getattr(scan_prg, f"{self.name_prefix}_debug_scan")
        scalar_arg_dtypes = [
            None,
            *get_arg_list_scalar_arg_dtypes(self.parsed_args),
            self.index_dtype,
        ]
        self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)

    def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
        """See :meth:`GenericScanKernel.__call__`."""

        # {{{ argument processing

        allocator = kwargs.get("allocator")
        queue = kwargs.get("queue")
        n = kwargs.get("size")
        wait_for = kwargs.get("wait_for")

        if wait_for is None:
            wait_for = []
        else:
            # We'll be modifying it below.
            wait_for = list(wait_for)

        if len(args) != len(self.parsed_args):
            raise TypeError(
                f"expected {len(self.parsed_args)} arguments, got {len(args)}")

        first_array = args[self.first_array_idx]
        allocator = allocator or first_array.allocator
        queue = queue or first_array.queue

        if n is None:
            n, = first_array.shape

        scan_tmp = cl_array.empty(queue,
                n, dtype=self.dtype,
                allocator=allocator)

        data_args = [scan_tmp.data]
        from pyopencl.tools import VectorArg
        for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
            if isinstance(arg_descr, VectorArg):
                data_args.append(arg_val.base_data)
                if arg_descr.with_offset:
                    data_args.append(arg_val.offset)
                wait_for.extend(arg_val.events)
            else:
                data_args.append(arg_val)

        # }}}

        return self.kernel(queue, (1,), (1,), *data_args, n, wait_for=wait_for)

# }}}
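
# Example (editor's sketch; not in the upstream source): since
# GenericDebugScanKernel is a drop-in replacement, debugging a misbehaving
# scan can be as simple as swapping the class in the sketch above:
#
#     from pyopencl.scan import GenericDebugScanKernel
#
#     dbg_knl = GenericDebugScanKernel(
#             ctx, np.int32,
#             arguments="__global int *ary",
#             input_expr="ary[i]",
#             scan_expr="a+b", neutral="0",
#             output_statement="ary[i] = item;")
#
#     dbg_knl(a)  # same result, computed sequentially by one work item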


# {{{ compatibility interface

class _LegacyScanKernelBase(GenericScanKernel):
    def __init__(self, ctx, dtype,
            scan_expr, neutral=None,
            name_prefix="scan", options=None, preamble="", devices=None):
        scan_ctype = dtype_to_ctype(dtype)
        GenericScanKernel.__init__(self,
                ctx, dtype,
                arguments="__global {} *input_ary, __global {} *output_ary".format(
                    scan_ctype, scan_ctype),
                input_expr="input_ary[i]",
                scan_expr=scan_expr,
                neutral=neutral,
                output_statement=self.ary_output_statement,
                options=options, preamble=preamble, devices=devices)

    @property
    def ary_output_statement(self):
        raise NotImplementedError

    def __call__(self, input_ary, output_ary=None, allocator=None, queue=None):
        allocator = allocator or input_ary.allocator

        # Resolve output_ary before taking attributes off of it: it may still
        # be None or the string "new" at this point.
        if output_ary is None:
            output_ary = input_ary

        if isinstance(output_ary, str) and output_ary == "new":
            output_ary = cl_array.empty_like(input_ary, allocator=allocator)

        queue = queue or input_ary.queue or output_ary.queue

        if input_ary.shape != output_ary.shape:
            raise ValueError("input and output must have the same shape")

        if not input_ary.flags.forc:
            raise RuntimeError("ScanKernel cannot "
                    "deal with non-contiguous arrays")

        n, = input_ary.shape

        if not n:
            return output_ary

        GenericScanKernel.__call__(self,
                input_ary, output_ary, allocator=allocator, queue=queue)

        return output_ary


class InclusiveScanKernel(_LegacyScanKernelBase):
    ary_output_statement = "output_ary[i] = item;"


class ExclusiveScanKernel(_LegacyScanKernelBase):
    ary_output_statement = "output_ary[i] = prev_item;"

# }}}
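
# Example (editor's sketch; not in the upstream source): the legacy classes
# above fix the argument list to (input_ary, output_ary). Assuming a context
# `ctx` and a queue `queue`:
#
#     import numpy as np
#     import pyopencl.array as cl_array
#     from pyopencl.scan import ExclusiveScanKernel, InclusiveScanKernel
#
#     a = cl_array.arange(queue, 8, dtype=np.int32)   # 0 1 2 3 4 5 6 7
#     InclusiveScanKernel(ctx, np.int32, "a+b", neutral="0")(a)
#     # a is now 0 1 3 6 10 15 21 28
#
#     b = cl_array.arange(queue, 8, dtype=np.int32)
#     ExclusiveScanKernel(ctx, np.int32, "a+b", neutral="0")(b)
#     # b is now 0 0 1 3 6 10 15 21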


# {{{ template

class ScanTemplate(KernelTemplateBase):
    def __init__(
            self,
            arguments: str | list[DtypedArgument],
            input_expr: str,
            scan_expr: str,
            neutral: str | None,
            output_statement: str,
            is_segment_start_expr: str | None = None,
            input_fetch_exprs: list[tuple[str, str, int]] | None = None,
            name_prefix: str = "scan",
            preamble: str = "",
            template_processor: Any = None) -> None:
        super().__init__(template_processor=template_processor)

        if input_fetch_exprs is None:
            input_fetch_exprs = []

        self.arguments = arguments
        self.input_expr = input_expr
        self.scan_expr = scan_expr
        self.neutral = neutral
        self.output_statement = output_statement
        self.is_segment_start_expr = is_segment_start_expr
        self.input_fetch_exprs = input_fetch_exprs
        self.name_prefix = name_prefix
        self.preamble = preamble

    def build_inner(self, context, type_aliases=(), var_values=(),
            more_preamble="", more_arguments=(), declare_types=(),
            options=None, devices=None, scan_cls=GenericScanKernel):
        renderer = self.get_renderer(type_aliases, var_values, context, options)

        arg_list = renderer.render_argument_list(self.arguments, more_arguments)

        type_decl_preamble = renderer.get_type_decl_preamble(
                context.devices[0], declare_types, arg_list)

        return scan_cls(context, renderer.type_aliases["scan_t"],
                arg_list,
                renderer(self.input_expr), renderer(self.scan_expr),
                renderer(self.neutral), renderer(self.output_statement),
                is_segment_start_expr=renderer(self.is_segment_start_expr),
                input_fetch_exprs=self.input_fetch_exprs,
                index_dtype=renderer.type_aliases.get("index_t", np.int32),
                name_prefix=renderer(self.name_prefix), options=options,
                preamble=(
                    type_decl_preamble
                    + "\n"
                    + renderer(self.preamble + "\n" + more_preamble)),
                devices=devices)

# }}}
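
# Example (editor's sketch; not in the upstream source): ScanTemplate defers
# concrete dtypes to build time via type aliases, using the build() method
# inherited from KernelTemplateBase. Assuming a context `ctx`:
#
#     import numpy as np
#     from pyopencl.scan import ScanTemplate
#
#     cumsum_tpl = ScanTemplate(
#             arguments="scan_t *input, scan_t *output",
#             input_expr="input[i]",
#             scan_expr="a+b", neutral="0",
#             output_statement="output[i] = item;")
#
#     knl = cumsum_tpl.build(ctx,
#             type_aliases=(("scan_t", np.int32), ("index_t", np.int32)))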


# {{{ 'canned' scan kernels

@context_dependent_memoize
def get_cumsum_kernel(context, input_dtype, output_dtype):
    from pyopencl.tools import VectorArg
    return GenericScanKernel(
        context, output_dtype,
        arguments=[
            VectorArg(input_dtype, "input"),
            VectorArg(output_dtype, "output"),
            ],
        input_expr="input[i]",
        scan_expr="a+b", neutral="0",
        output_statement="""
            output[i] = item;
            """)

# }}}
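
# Example (editor's sketch; not in the upstream source): get_cumsum_kernel is
# the building block behind pyopencl.array's cumsum functionality. Used
# directly (context `ctx` and queue `queue` assumed):
#
#     a = cl_array.arange(queue, 10, dtype=np.int32)
#     out = cl_array.empty_like(a)
#     knl = get_cumsum_kernel(ctx, np.int32, np.int32)
#     knl(a, out)  # out[i] = a[0] + ... + a[i]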

# vim: filetype=pyopencl:fdm=marker