pyopencl 2024.2__cp312-cp312-macosx_10_14_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyopencl might be problematic. Click here for more details.

Files changed (122)
  1. pyopencl/__init__.py +2393 -0
  2. pyopencl/_cl.cpython-312-darwin.so +0 -0
  3. pyopencl/_cluda.py +54 -0
  4. pyopencl/_mymako.py +14 -0
  5. pyopencl/algorithm.py +1444 -0
  6. pyopencl/array.py +3427 -0
  7. pyopencl/bitonic_sort.py +238 -0
  8. pyopencl/bitonic_sort_templates.py +594 -0
  9. pyopencl/cache.py +534 -0
  10. pyopencl/capture_call.py +176 -0
  11. pyopencl/characterize/__init__.py +433 -0
  12. pyopencl/characterize/performance.py +237 -0
  13. pyopencl/cl/pyopencl-airy.cl +324 -0
  14. pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
  15. pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
  16. pyopencl/cl/pyopencl-bessel-y.cl +435 -0
  17. pyopencl/cl/pyopencl-complex.h +303 -0
  18. pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
  19. pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
  20. pyopencl/cl/pyopencl-random123/array.h +325 -0
  21. pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
  22. pyopencl/cl/pyopencl-random123/philox.cl +486 -0
  23. pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
  24. pyopencl/clmath.py +280 -0
  25. pyopencl/clrandom.py +408 -0
  26. pyopencl/cltypes.py +137 -0
  27. pyopencl/compyte/__init__.py +0 -0
  28. pyopencl/compyte/array.py +214 -0
  29. pyopencl/compyte/dtypes.py +290 -0
  30. pyopencl/compyte/ndarray/__init__.py +0 -0
  31. pyopencl/compyte/ndarray/gen_elemwise.py +1907 -0
  32. pyopencl/compyte/ndarray/gen_reduction.py +1511 -0
  33. pyopencl/compyte/ndarray/setup_opencl.py +101 -0
  34. pyopencl/compyte/ndarray/test_gpu_elemwise.py +411 -0
  35. pyopencl/compyte/ndarray/test_gpu_ndarray.py +487 -0
  36. pyopencl/elementwise.py +1164 -0
  37. pyopencl/invoker.py +418 -0
  38. pyopencl/ipython_ext.py +68 -0
  39. pyopencl/reduction.py +780 -0
  40. pyopencl/scan.py +1898 -0
  41. pyopencl/tools.py +1513 -0
  42. pyopencl/version.py +3 -0
  43. pyopencl-2024.2.data/data/CITATION.cff +74 -0
  44. pyopencl-2024.2.data/data/LICENSE +282 -0
  45. pyopencl-2024.2.data/data/Makefile.in +21 -0
  46. pyopencl-2024.2.data/data/README.rst +70 -0
  47. pyopencl-2024.2.data/data/README_SETUP.txt +34 -0
  48. pyopencl-2024.2.data/data/aksetup_helper.py +1013 -0
  49. pyopencl-2024.2.data/data/configure.py +6 -0
  50. pyopencl-2024.2.data/data/contrib/cldis.py +91 -0
  51. pyopencl-2024.2.data/data/contrib/fortran-to-opencl/README +29 -0
  52. pyopencl-2024.2.data/data/contrib/fortran-to-opencl/translate.py +1441 -0
  53. pyopencl-2024.2.data/data/contrib/pyopencl.vim +84 -0
  54. pyopencl-2024.2.data/data/doc/Makefile +23 -0
  55. pyopencl-2024.2.data/data/doc/algorithm.rst +214 -0
  56. pyopencl-2024.2.data/data/doc/array.rst +305 -0
  57. pyopencl-2024.2.data/data/doc/conf.py +26 -0
  58. pyopencl-2024.2.data/data/doc/howto.rst +105 -0
  59. pyopencl-2024.2.data/data/doc/index.rst +137 -0
  60. pyopencl-2024.2.data/data/doc/make_constants.py +561 -0
  61. pyopencl-2024.2.data/data/doc/misc.rst +885 -0
  62. pyopencl-2024.2.data/data/doc/runtime.rst +51 -0
  63. pyopencl-2024.2.data/data/doc/runtime_const.rst +30 -0
  64. pyopencl-2024.2.data/data/doc/runtime_gl.rst +78 -0
  65. pyopencl-2024.2.data/data/doc/runtime_memory.rst +527 -0
  66. pyopencl-2024.2.data/data/doc/runtime_platform.rst +184 -0
  67. pyopencl-2024.2.data/data/doc/runtime_program.rst +364 -0
  68. pyopencl-2024.2.data/data/doc/runtime_queue.rst +182 -0
  69. pyopencl-2024.2.data/data/doc/subst.rst +36 -0
  70. pyopencl-2024.2.data/data/doc/tools.rst +4 -0
  71. pyopencl-2024.2.data/data/doc/types.rst +42 -0
  72. pyopencl-2024.2.data/data/examples/black-hole-accretion.py +2227 -0
  73. pyopencl-2024.2.data/data/examples/demo-struct-reduce.py +75 -0
  74. pyopencl-2024.2.data/data/examples/demo.py +39 -0
  75. pyopencl-2024.2.data/data/examples/demo_array.py +32 -0
  76. pyopencl-2024.2.data/data/examples/demo_array_svm.py +37 -0
  77. pyopencl-2024.2.data/data/examples/demo_elementwise.py +34 -0
  78. pyopencl-2024.2.data/data/examples/demo_elementwise_complex.py +53 -0
  79. pyopencl-2024.2.data/data/examples/demo_mandelbrot.py +183 -0
  80. pyopencl-2024.2.data/data/examples/demo_meta_codepy.py +56 -0
  81. pyopencl-2024.2.data/data/examples/demo_meta_template.py +55 -0
  82. pyopencl-2024.2.data/data/examples/dump-performance.py +38 -0
  83. pyopencl-2024.2.data/data/examples/dump-properties.py +86 -0
  84. pyopencl-2024.2.data/data/examples/gl_interop_demo.py +84 -0
  85. pyopencl-2024.2.data/data/examples/gl_particle_animation.py +218 -0
  86. pyopencl-2024.2.data/data/examples/ipython-demo.ipynb +203 -0
  87. pyopencl-2024.2.data/data/examples/median-filter.py +99 -0
  88. pyopencl-2024.2.data/data/examples/n-body.py +1070 -0
  89. pyopencl-2024.2.data/data/examples/narray.py +37 -0
  90. pyopencl-2024.2.data/data/examples/noisyImage.jpg +0 -0
  91. pyopencl-2024.2.data/data/examples/pi-monte-carlo.py +1166 -0
  92. pyopencl-2024.2.data/data/examples/svm.py +82 -0
  93. pyopencl-2024.2.data/data/examples/transpose.py +229 -0
  94. pyopencl-2024.2.data/data/pytest.ini +3 -0
  95. pyopencl-2024.2.data/data/src/bitlog.cpp +51 -0
  96. pyopencl-2024.2.data/data/src/bitlog.hpp +83 -0
  97. pyopencl-2024.2.data/data/src/clinfo_ext.h +134 -0
  98. pyopencl-2024.2.data/data/src/mempool.hpp +444 -0
  99. pyopencl-2024.2.data/data/src/pyopencl_ext.h +77 -0
  100. pyopencl-2024.2.data/data/src/tools.hpp +90 -0
  101. pyopencl-2024.2.data/data/src/wrap_cl.cpp +61 -0
  102. pyopencl-2024.2.data/data/src/wrap_cl.hpp +5853 -0
  103. pyopencl-2024.2.data/data/src/wrap_cl_part_1.cpp +369 -0
  104. pyopencl-2024.2.data/data/src/wrap_cl_part_2.cpp +702 -0
  105. pyopencl-2024.2.data/data/src/wrap_constants.cpp +1274 -0
  106. pyopencl-2024.2.data/data/src/wrap_helpers.hpp +213 -0
  107. pyopencl-2024.2.data/data/src/wrap_mempool.cpp +731 -0
  108. pyopencl-2024.2.data/data/test/add-vectors-32.spv +0 -0
  109. pyopencl-2024.2.data/data/test/add-vectors-64.spv +0 -0
  110. pyopencl-2024.2.data/data/test/empty-header.h +1 -0
  111. pyopencl-2024.2.data/data/test/test_algorithm.py +1180 -0
  112. pyopencl-2024.2.data/data/test/test_array.py +2392 -0
  113. pyopencl-2024.2.data/data/test/test_arrays_in_structs.py +100 -0
  114. pyopencl-2024.2.data/data/test/test_clmath.py +529 -0
  115. pyopencl-2024.2.data/data/test/test_clrandom.py +75 -0
  116. pyopencl-2024.2.data/data/test/test_enqueue_copy.py +271 -0
  117. pyopencl-2024.2.data/data/test/test_wrapper.py +1554 -0
  118. pyopencl-2024.2.dist-info/LICENSE +282 -0
  119. pyopencl-2024.2.dist-info/METADATA +105 -0
  120. pyopencl-2024.2.dist-info/RECORD +122 -0
  121. pyopencl-2024.2.dist-info/WHEEL +5 -0
  122. pyopencl-2024.2.dist-info/top_level.txt +1 -0
pyopencl/scan.py ADDED
@@ -0,0 +1,1898 @@
1
+ """Scan primitive."""
2
+
3
+
4
+ __copyright__ = """
5
+ Copyright 2011-2012 Andreas Kloeckner
6
+ Copyright 2008-2011 NVIDIA Corporation
7
+ """
8
+
9
+ __license__ = """
10
+ Licensed under the Apache License, Version 2.0 (the "License");
11
+ you may not use this file except in compliance with the License.
12
+ You may obtain a copy of the License at
13
+
14
+ https://www.apache.org/licenses/LICENSE-2.0
15
+
16
+ Unless required by applicable law or agreed to in writing, software
17
+ distributed under the License is distributed on an "AS IS" BASIS,
18
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ See the License for the specific language governing permissions and
20
+ limitations under the License.
21
+
22
+ Derived from code within the Thrust project, https://github.com/NVIDIA/thrust
23
+ """
24
+
25
+ import logging
26
+ from abc import ABC, abstractmethod
27
+ from dataclasses import dataclass
28
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
29
+
30
+ import numpy as np
31
+ from pytools.persistent_dict import WriteOncePersistentDict
32
+
33
+ import pyopencl as cl
34
+ import pyopencl._mymako as mako
35
+ import pyopencl.array
36
+ from pyopencl._cluda import CLUDA_PREAMBLE
37
+ from pyopencl.tools import (
38
+ DtypedArgument, KernelTemplateBase, _NumpyTypesKeyBuilder,
39
+ _process_code_for_macro, bitlog2, context_dependent_memoize, dtype_to_ctype,
40
+ get_arg_list_scalar_arg_dtypes, get_arg_offset_adjuster_code)
41
+
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ # {{{ preamble
47
+
48
+ SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
49
+ #define WG_SIZE ${wg_size}
50
+
51
+ #define SCAN_EXPR(a, b, across_seg_boundary) ${scan_expr}
52
+ #define INPUT_EXPR(i) (${input_expr})
53
+ %if is_segmented:
54
+ #define IS_SEG_START(i, a) (${is_segment_start_expr})
55
+ %endif
56
+
57
+ ${preamble}
58
+
59
+ typedef ${dtype_to_ctype(scan_dtype)} scan_type;
60
+ typedef ${dtype_to_ctype(index_dtype)} index_type;
61
+
62
+ // NO_SEG_BOUNDARY is the largest representable integer in index_type.
63
+ // This assumption is used in code below.
64
+ #define NO_SEG_BOUNDARY ${str(np.iinfo(index_dtype).max)}
65
+ """
66
+
67
+ # }}}
68
+
69
+ # {{{ main scan code
70
+
71
+ # Algorithm: Each work group is responsible for one contiguous
72
+ # 'interval'. There are just enough intervals to fill all compute
73
+ # units. Intervals are split into 'units'. A unit is what gets
74
+ # worked on in parallel by one work group.
75
+ #
76
+ # in index space:
77
+ # interval > unit > local-parallel > k-group
78
+ #
79
+ # (Note that there is also a transpose in here: The data is read
80
+ # with local ids along linear index order.)
81
+ #
82
+ # Each unit has two axes--the local-id axis and the k axis.
83
+ #
84
+ # unit 0:
85
+ # | | | | | | | | | | ----> lid
86
+ # | | | | | | | | | |
87
+ # | | | | | | | | | |
88
+ # | | | | | | | | | |
89
+ # | | | | | | | | | |
90
+ #
91
+ # |
92
+ # v k (fastest-moving in linear index)
93
+ #
94
+ # unit 1:
95
+ # | | | | | | | | | | ----> lid
96
+ # | | | | | | | | | |
97
+ # | | | | | | | | | |
98
+ # | | | | | | | | | |
99
+ # | | | | | | | | | |
100
+ #
101
+ # |
102
+ # v k (fastest-moving in linear index)
103
+ #
104
+ # ...
105
+ #
106
+ # At a device-global level, this is a three-phase algorithm, in
107
+ # which first each interval does its local scan, then a scan
108
+ # across intervals exchanges data globally, and the final update
109
+ # adds the exchanged sums to each interval.
110
+ #
111
+ # Exclusive scan is realized by allowing look-behind (access to the
112
+ # preceding item) in the final update, by means of a local shift.
113
+ #
114
+ # NOTE: All segment_start_in_X indices are relative to the start
115
+ # of the array.
116
+
117
+ SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + r"""//CL//
118
+
119
+ #define K ${k_group_size}
120
+
121
+ // #define DEBUG
122
+ #ifdef DEBUG
123
+ #define pycl_printf(ARGS) printf ARGS
124
+ #else
125
+ #define pycl_printf(ARGS) /* */
126
+ #endif
127
+
128
+ KERNEL
129
+ REQD_WG_SIZE(WG_SIZE, 1, 1)
130
+ void ${kernel_name}(
131
+ ${argument_signature},
132
+ GLOBAL_MEM scan_type *restrict partial_scan_buffer,
133
+ const index_type N,
134
+ const index_type interval_size
135
+ %if is_first_level:
136
+ , GLOBAL_MEM scan_type *restrict interval_results
137
+ %endif
138
+ %if is_segmented and is_first_level:
139
+ // NO_SEG_BOUNDARY if no segment boundary in interval.
140
+ , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
141
+ %endif
142
+ %if store_segment_start_flags:
143
+ , GLOBAL_MEM char *restrict g_segment_start_flags
144
+ %endif
145
+ )
146
+ {
147
+ ${arg_offset_adjustment}
148
+
149
+ // index K in first dimension used for carry storage
150
+ %if use_bank_conflict_avoidance:
151
+ // Avoid bank conflicts by adding a single 32-bit value to the size of
152
+ // the scan type.
153
+ struct __attribute__ ((__packed__)) wrapped_scan_type
154
+ {
155
+ scan_type value;
156
+ int dummy;
157
+ };
158
+ %else:
159
+ struct wrapped_scan_type
160
+ {
161
+ scan_type value;
162
+ };
163
+ %endif
164
+ // padded in WG_SIZE to avoid bank conflicts
165
+ LOCAL_MEM struct wrapped_scan_type ldata[K + 1][WG_SIZE + 1];
166
+
167
+ %if is_segmented:
168
+ LOCAL_MEM char l_segment_start_flags[K][WG_SIZE];
169
+ LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE];
170
+
171
+ // only relevant/populated for local id 0
172
+ index_type first_segment_start_in_interval = NO_SEG_BOUNDARY;
173
+
174
+ index_type first_segment_start_in_k_group, first_segment_start_in_subtree;
175
+ %endif
176
+
177
+ // {{{ declare local data for input_fetch_exprs if any of them are stenciled
178
+
179
+ <%
180
+ fetch_expr_offsets = {}
181
+ for name, arg_name, ife_offset in input_fetch_exprs:
182
+ fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
183
+
184
+ local_fetch_expr_args = set(
185
+ arg_name
186
+ for arg_name, ife_offsets in fetch_expr_offsets.items()
187
+ if -1 in ife_offsets or len(ife_offsets) > 1)
188
+ %>
189
+
190
+ %for arg_name in local_fetch_expr_args:
191
+ LOCAL_MEM ${arg_ctypes[arg_name]} l_${arg_name}[WG_SIZE*K];
192
+ %endfor
193
+
194
+ // }}}
195
+
196
+ const index_type interval_begin = interval_size * GID_0;
197
+ const index_type interval_end = min(interval_begin + interval_size, N);
198
+
199
+ const index_type unit_size = K * WG_SIZE;
200
+
201
+ index_type unit_base = interval_begin;
202
+
203
+ %for is_tail in [False, True]:
204
+
205
+ %if not is_tail:
206
+ for(; unit_base + unit_size <= interval_end; unit_base += unit_size)
207
+ %else:
208
+ if (unit_base < interval_end)
209
+ %endif
210
+
211
+ {
212
+
213
+ // {{{ carry out input_fetch_exprs
214
+ // (if there are ones that need to be fetched into local)
215
+
216
+ %if local_fetch_expr_args:
217
+ for(index_type k = 0; k < K; k++)
218
+ {
219
+ const index_type offset = k*WG_SIZE + LID_0;
220
+ const index_type read_i = unit_base + offset;
221
+
222
+ %for arg_name in local_fetch_expr_args:
223
+ %if is_tail:
224
+ if (read_i < interval_end)
225
+ %endif
226
+ {
227
+ l_${arg_name}[offset] = ${arg_name}[read_i];
228
+ }
229
+ %endfor
230
+ }
231
+
232
+ local_barrier();
233
+ %endif
234
+
235
+ pycl_printf(("after input_fetch_exprs\n"));
236
+
237
+ // }}}
238
+
239
+ // {{{ read a unit's worth of data from global
240
+
241
+ for(index_type k = 0; k < K; k++)
242
+ {
243
+ const index_type offset = k*WG_SIZE + LID_0;
244
+ const index_type read_i = unit_base + offset;
245
+
246
+ %if is_tail:
247
+ if (read_i < interval_end)
248
+ %endif
249
+ {
250
+ %for name, arg_name, ife_offset in input_fetch_exprs:
251
+ ${arg_ctypes[arg_name]} ${name};
252
+
253
+ %if arg_name in local_fetch_expr_args:
254
+ if (offset + ${ife_offset} >= 0)
255
+ ${name} = l_${arg_name}[offset + ${ife_offset}];
256
+ else if (read_i + ${ife_offset} >= 0)
257
+ ${name} = ${arg_name}[read_i + ${ife_offset}];
258
+ /*
259
+ else
260
+ if out of bounds, name is left undefined */
261
+
262
+ %else:
263
+ // ${arg_name} gets fetched directly from global
264
+ ${name} = ${arg_name}[read_i];
265
+
266
+ %endif
267
+ %endfor
268
+
269
+ scan_type scan_value = INPUT_EXPR(read_i);
270
+
271
+ const index_type o_mod_k = offset % K;
272
+ const index_type o_div_k = offset / K;
273
+ ldata[o_mod_k][o_div_k].value = scan_value;
274
+
275
+ %if is_segmented:
276
+ bool is_seg_start = IS_SEG_START(read_i, scan_value);
277
+ l_segment_start_flags[o_mod_k][o_div_k] = is_seg_start;
278
+ %endif
279
+ %if store_segment_start_flags:
280
+ g_segment_start_flags[read_i] = is_seg_start;
281
+ %endif
282
+ }
283
+ }
284
+
285
+ pycl_printf(("after read from global\n"));
286
+
287
+ // }}}
288
+
289
+ // {{{ carry in from previous unit, if applicable
290
+
291
+ %if is_segmented:
292
+ local_barrier();
293
+
294
+ first_segment_start_in_k_group = NO_SEG_BOUNDARY;
295
+ if (l_segment_start_flags[0][LID_0])
296
+ first_segment_start_in_k_group = unit_base + K*LID_0;
297
+ %endif
298
+
299
+ if (LID_0 == 0 && unit_base != interval_begin)
300
+ {
301
+ scan_type tmp = ldata[K][WG_SIZE - 1].value;
302
+ scan_type tmp_aux = ldata[0][0].value;
303
+
304
+ ldata[0][0].value = SCAN_EXPR(
305
+ tmp, tmp_aux,
306
+ %if is_segmented:
307
+ (l_segment_start_flags[0][0])
308
+ %else:
309
+ false
310
+ %endif
311
+ );
312
+ }
313
+
314
+ pycl_printf(("after carry-in\n"));
315
+
316
+ // }}}
317
+
318
+ local_barrier();
319
+
320
+ // {{{ scan along k (sequentially in each work item)
321
+
322
+ scan_type sum = ldata[0][LID_0].value;
323
+
324
+ %if is_tail:
325
+ const index_type offset_end = interval_end - unit_base;
326
+ %endif
327
+
328
+ for (index_type k = 1; k < K; k++)
329
+ {
330
+ %if is_tail:
331
+ if ((index_type) (K * LID_0 + k) < offset_end)
332
+ %endif
333
+ {
334
+ scan_type tmp = ldata[k][LID_0].value;
335
+
336
+ %if is_segmented:
337
+ index_type seq_i = unit_base + K*LID_0 + k;
338
+
339
+ if (l_segment_start_flags[k][LID_0])
340
+ {
341
+ first_segment_start_in_k_group = min(
342
+ first_segment_start_in_k_group,
343
+ seq_i);
344
+ }
345
+ %endif
346
+
347
+ sum = SCAN_EXPR(sum, tmp,
348
+ %if is_segmented:
349
+ (l_segment_start_flags[k][LID_0])
350
+ %else:
351
+ false
352
+ %endif
353
+ );
354
+
355
+ ldata[k][LID_0].value = sum;
356
+ }
357
+ }
358
+
359
+ pycl_printf(("after scan along k\n"));
360
+
361
+ // }}}
362
+
363
+ // store carry in out-of-bounds (padding) array entry (index K) in
364
+ // the K direction
365
+ ldata[K][LID_0].value = sum;
366
+
367
+ %if is_segmented:
368
+ l_first_segment_start_in_subtree[LID_0] =
369
+ first_segment_start_in_k_group;
370
+ %endif
371
+
372
+ local_barrier();
373
+
374
+ // {{{ tree-based local parallel scan
375
+
376
+ // This tree-based scan works as follows:
377
+ // - Each work item adds the previous item to its current state
378
+ // - barrier
379
+ // - Each work item adds in the item from two positions to the left
380
+ // - barrier
381
+ // - Each work item adds in the item from four positions to the left
382
+ // ...
383
+ // At the end, each item has summed all prior items.
384
+
385
+ // across k groups, along local id
386
+ // (uses out-of-bounds k=K array entry for storage)
387
+
388
+ scan_type val = ldata[K][LID_0].value;
389
+
390
+ <% scan_offset = 1 %>
391
+
392
+ % while scan_offset <= wg_size:
393
+ // {{{ reads from local allowed, writes to local not allowed
394
+
395
+ if (LID_0 >= ${scan_offset})
396
+ {
397
+ scan_type tmp = ldata[K][LID_0 - ${scan_offset}].value;
398
+ % if is_tail:
399
+ if (K*LID_0 < offset_end)
400
+ % endif
401
+ {
402
+ val = SCAN_EXPR(tmp, val,
403
+ %if is_segmented:
404
+ (l_first_segment_start_in_subtree[LID_0]
405
+ != NO_SEG_BOUNDARY)
406
+ %else:
407
+ false
408
+ %endif
409
+ );
410
+ }
411
+
412
+ %if is_segmented:
413
+ // Prepare for l_first_segment_start_in_subtree, below.
414
+
415
+ // Note that this update must take place *even* if we're
416
+ // out of bounds.
417
+
418
+ first_segment_start_in_subtree = min(
419
+ l_first_segment_start_in_subtree[LID_0],
420
+ l_first_segment_start_in_subtree
421
+ [LID_0 - ${scan_offset}]);
422
+ %endif
423
+ }
424
+ %if is_segmented:
425
+ else
426
+ {
427
+ first_segment_start_in_subtree =
428
+ l_first_segment_start_in_subtree[LID_0];
429
+ }
430
+ %endif
431
+
432
+ // }}}
433
+
434
+ local_barrier();
435
+
436
+ // {{{ writes to local allowed, reads from local not allowed
437
+
438
+ ldata[K][LID_0].value = val;
439
+ %if is_segmented:
440
+ l_first_segment_start_in_subtree[LID_0] =
441
+ first_segment_start_in_subtree;
442
+ %endif
443
+
444
+ // }}}
445
+
446
+ local_barrier();
447
+
448
+ %if 0:
449
+ if (LID_0 == 0)
450
+ {
451
+ printf("${scan_offset}: ");
452
+ for (int i = 0; i < WG_SIZE; ++i)
453
+ {
454
+ if (l_first_segment_start_in_subtree[i] == NO_SEG_BOUNDARY)
455
+ printf("- ");
456
+ else
457
+ printf("%d ", l_first_segment_start_in_subtree[i]);
458
+ }
459
+ printf("\n");
460
+ }
461
+ %endif
462
+
463
+ <% scan_offset *= 2 %>
464
+ % endwhile
465
+
466
+ pycl_printf(("after tree scan\n"));
467
+
468
+ // }}}
469
+
470
+ // {{{ update local values
471
+
472
+ if (LID_0 > 0)
473
+ {
474
+ sum = ldata[K][LID_0 - 1].value;
475
+
476
+ for(index_type k = 0; k < K; k++)
477
+ {
478
+ %if is_tail:
479
+ if (K * LID_0 + k < offset_end)
480
+ %endif
481
+ {
482
+ scan_type tmp = ldata[k][LID_0].value;
483
+ ldata[k][LID_0].value = SCAN_EXPR(sum, tmp,
484
+ %if is_segmented:
485
+ (unit_base + K * LID_0 + k
486
+ >= first_segment_start_in_k_group)
487
+ %else:
488
+ false
489
+ %endif
490
+ );
491
+ }
492
+ }
493
+ }
494
+
495
+ %if is_segmented:
496
+ if (LID_0 == 0)
497
+ {
498
+ // update interval-wide first-seg variable from current unit
499
+ first_segment_start_in_interval = min(
500
+ first_segment_start_in_interval,
501
+ l_first_segment_start_in_subtree[WG_SIZE-1]);
502
+ }
503
+ %endif
504
+
505
+ pycl_printf(("after local update\n"));
506
+
507
+ // }}}
508
+
509
+ local_barrier();
510
+
511
+ // {{{ write data
512
+
513
+ %if is_gpu:
514
+ {
515
+ // work hard with index math to achieve contiguous 32-bit stores
516
+ __global int *dest =
517
+ (__global int *) (partial_scan_buffer + unit_base);
518
+
519
+ <%
520
+
521
+ assert scan_dtype.itemsize % 4 == 0
522
+
523
+ ints_per_wg = wg_size
524
+ ints_to_store = scan_dtype.itemsize*wg_size*k_group_size // 4
525
+
526
+ %>
527
+
528
+ const index_type scan_types_per_int = ${scan_dtype.itemsize//4};
529
+
530
+ %for store_base in range(0, ints_to_store, ints_per_wg):
531
+ <%
532
+
533
+ # Observe that ints_to_store is divisible by the work group
534
+ # size already, so we won't go out of bounds that way.
535
+ assert store_base + ints_per_wg <= ints_to_store
536
+
537
+ %>
538
+
539
+ %if is_tail:
540
+ if (${store_base} + LID_0 <
541
+ scan_types_per_int*(interval_end - unit_base))
542
+ %endif
543
+ {
544
+ index_type linear_index = ${store_base} + LID_0;
545
+ index_type linear_scan_data_idx =
546
+ linear_index / scan_types_per_int;
547
+ index_type remainder =
548
+ linear_index - linear_scan_data_idx * scan_types_per_int;
549
+
550
+ __local int *src = (__local int *) &(
551
+ ldata
552
+ [linear_scan_data_idx % K]
553
+ [linear_scan_data_idx / K].value);
554
+
555
+ dest[linear_index] = src[remainder];
556
+ }
557
+ %endfor
558
+ }
559
+ %else:
560
+ for (index_type k = 0; k < K; k++)
561
+ {
562
+ const index_type offset = k*WG_SIZE + LID_0;
563
+
564
+ %if is_tail:
565
+ if (unit_base + offset < interval_end)
566
+ %endif
567
+ {
568
+ pycl_printf(("write: %d\n", unit_base + offset));
569
+ partial_scan_buffer[unit_base + offset] =
570
+ ldata[offset % K][offset / K].value;
571
+ }
572
+ }
573
+ %endif
574
+
575
+ pycl_printf(("after write\n"));
576
+
577
+ // }}}
578
+
579
+ local_barrier();
580
+ }
581
+
582
+ % endfor
583
+
584
+ // write interval sum
585
+ %if is_first_level:
586
+ if (LID_0 == 0)
587
+ {
588
+ interval_results[GID_0] = partial_scan_buffer[interval_end - 1];
589
+ %if is_segmented:
590
+ g_first_segment_start_in_interval[GID_0] =
591
+ first_segment_start_in_interval;
592
+ %endif
593
+ }
594
+ %endif
595
+ }
596
+ """
597
+
598
+ # }}}
599
+
600
+ # {{{ update
601
+
602
+ UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL//
603
+
604
+ KERNEL
605
+ REQD_WG_SIZE(WG_SIZE, 1, 1)
606
+ void ${name_prefix}_final_update(
607
+ ${argument_signature},
608
+ const index_type N,
609
+ const index_type interval_size,
610
+ GLOBAL_MEM scan_type *restrict interval_results,
611
+ GLOBAL_MEM scan_type *restrict partial_scan_buffer
612
+ %if is_segmented:
613
+ , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
614
+ %endif
615
+ %if is_segmented and use_lookbehind_update:
616
+ , GLOBAL_MEM char *restrict g_segment_start_flags
617
+ %endif
618
+ )
619
+ {
620
+ ${arg_offset_adjustment}
621
+
622
+ %if use_lookbehind_update:
623
+ LOCAL_MEM scan_type ldata[WG_SIZE];
624
+ %endif
625
+ %if is_segmented and use_lookbehind_update:
626
+ LOCAL_MEM char l_segment_start_flags[WG_SIZE];
627
+ %endif
628
+
629
+ const index_type interval_begin = interval_size * GID_0;
630
+ const index_type interval_end = min(interval_begin + interval_size, N);
631
+
632
+ // carry from last interval
633
+ scan_type carry = ${neutral};
634
+ if (GID_0 != 0)
635
+ carry = interval_results[GID_0 - 1];
636
+
637
+ %if is_segmented:
638
+ const index_type first_seg_start_in_interval =
639
+ g_first_segment_start_in_interval[GID_0];
640
+ %endif
641
+
642
+ %if not is_segmented and 'last_item' in output_statement:
643
+ scan_type last_item = interval_results[GDIM_0-1];
644
+ %endif
645
+
646
+ %if not use_lookbehind_update:
647
+ // {{{ no look-behind ('prev_item' not in output_statement -> simpler)
648
+
649
+ index_type update_i = interval_begin+LID_0;
650
+
651
+ %if is_segmented:
652
+ index_type seg_end = min(first_seg_start_in_interval, interval_end);
653
+ %endif
654
+
655
+ for(; update_i < interval_end; update_i += WG_SIZE)
656
+ {
657
+ scan_type partial_val = partial_scan_buffer[update_i];
658
+ scan_type item = SCAN_EXPR(carry, partial_val,
659
+ %if is_segmented:
660
+ (update_i >= seg_end)
661
+ %else:
662
+ false
663
+ %endif
664
+ );
665
+ index_type i = update_i;
666
+
667
+ { ${output_statement}; }
668
+ }
669
+
670
+ // }}}
671
+ %else:
672
+ // {{{ allow look-behind ('prev_item' in output_statement -> complicated)
673
+
674
+ // We are not allowed to branch across barriers at a granularity smaller
675
+ // than the whole workgroup. Therefore, the for loop is group-global,
676
+ // and there are lots of local ifs.
677
+
678
+ index_type group_base = interval_begin;
679
+ scan_type prev_item = carry; // (A)
680
+
681
+ for(; group_base < interval_end; group_base += WG_SIZE)
682
+ {
683
+ index_type update_i = group_base+LID_0;
684
+
685
+ // load a work group's worth of data
686
+ if (update_i < interval_end)
687
+ {
688
+ scan_type tmp = partial_scan_buffer[update_i];
689
+
690
+ tmp = SCAN_EXPR(carry, tmp,
691
+ %if is_segmented:
692
+ (update_i >= first_seg_start_in_interval)
693
+ %else:
694
+ false
695
+ %endif
696
+ );
697
+
698
+ ldata[LID_0] = tmp;
699
+
700
+ %if is_segmented:
701
+ l_segment_start_flags[LID_0] = g_segment_start_flags[update_i];
702
+ %endif
703
+ }
704
+
705
+ local_barrier();
706
+
707
+ // find prev_item
708
+ if (LID_0 != 0)
709
+ prev_item = ldata[LID_0 - 1];
710
+ /*
711
+ else
712
+ prev_item = carry (see (A)) OR last tail (see (B));
713
+ */
714
+
715
+ if (update_i < interval_end)
716
+ {
717
+ %if is_segmented:
718
+ if (l_segment_start_flags[LID_0])
719
+ prev_item = ${neutral};
720
+ %endif
721
+
722
+ scan_type item = ldata[LID_0];
723
+ index_type i = update_i;
724
+ { ${output_statement}; }
725
+ }
726
+
727
+ if (LID_0 == 0)
728
+ prev_item = ldata[WG_SIZE - 1]; // (B)
729
+
730
+ local_barrier();
731
+ }
732
+
733
+ // }}}
734
+ %endif
735
+ }
736
+ """
737
+
738
+ # }}}
739
+
740
+
741
+ # {{{ driver
742
+
743
+ # {{{ helpers
744
+
745
+ def _round_down_to_power_of_2(val: int) -> int:
746
+ result = 2**bitlog2(val)
747
+ if result > val:
748
+ result >>= 1
749
+
750
+ assert result <= val
751
+ return result
752
+
753
+
754
+ _PREFIX_WORDS = set("""
755
+ ldata partial_scan_buffer global scan_offset
756
+ segment_start_in_k_group carry
757
+ g_first_segment_start_in_interval IS_SEG_START tmp Z
758
+ val l_first_segment_start_in_subtree unit_size
759
+ index_type interval_begin interval_size offset_end K
760
+ SCAN_EXPR do_update WG_SIZE
761
+ first_segment_start_in_k_group scan_type
762
+ segment_start_in_subtree offset interval_results interval_end
763
+ first_segment_start_in_subtree unit_base
764
+ first_segment_start_in_interval k INPUT_EXPR
765
+ prev_group_sum prev pv value partial_val pgs
766
+ is_seg_start update_i scan_item_at_i seq_i read_i
767
+ l_ o_mod_k o_div_k l_segment_start_flags scan_value sum
768
+ first_seg_start_in_interval g_segment_start_flags
769
+ group_base seg_end my_val DEBUG ARGS
770
+ ints_to_store ints_per_wg scan_types_per_int linear_index
771
+ linear_scan_data_idx dest src store_base wrapped_scan_type
772
+ dummy scan_tmp tmp_aux
773
+
774
+ LID_2 LID_1 LID_0
775
+ LDIM_0 LDIM_1 LDIM_2
776
+ GDIM_0 GDIM_1 GDIM_2
777
+ GID_0 GID_1 GID_2
778
+ """.split())
779
+
780
+ _IGNORED_WORDS = set("""
781
+ 4 8 32
782
+
783
+ typedef for endfor if void while endwhile endfor endif else const printf
784
+ None return bool n char true false ifdef pycl_printf str range assert
785
+ np iinfo max itemsize __packed__ struct restrict ptrdiff_t
786
+
787
+ set iteritems len setdefault
788
+
789
+ GLOBAL_MEM LOCAL_MEM_ARG WITHIN_KERNEL LOCAL_MEM KERNEL REQD_WG_SIZE
790
+ local_barrier
791
+ CLK_LOCAL_MEM_FENCE OPENCL EXTENSION
792
+ pragma __attribute__ __global __kernel __local
793
+ get_local_size get_local_id cl_khr_fp64 reqd_work_group_size
794
+ get_num_groups barrier get_group_id
795
+ CL_VERSION_1_1 __OPENCL_C_VERSION__ 120
796
+
797
+ _final_update _debug_scan kernel_name
798
+
799
+ positions all padded integer its previous write based writes 0
800
+ has local worth scan_expr to read cannot not X items False bank
801
+ four beginning follows applicable item min each indices works side
802
+ scanning right summed relative used id out index avoid current state
803
+ boundary True across be This reads groups along Otherwise undetermined
804
+ store of times prior s update first regardless Each number because
805
+ array unit from segment conflicts two parallel 2 empty define direction
806
+ CL padding work tree bounds values and adds
807
+ scan is allowed thus it an as enable at in occur sequentially end no
808
+ storage data 1 largest may representable uses entry Y meaningful
809
+ computations interval At the left dimension know d
810
+ A load B group perform shift tail see last OR
811
+ this add fetched into are directly need
812
+ gets them stenciled that undefined
813
+ there up any ones or name only relevant populated
814
+ even wide we Prepare int seg Note re below place take variable must
815
+ intra Therefore find code assumption
816
+ branch workgroup complicated granularity phase remainder than simpler
817
+ We smaller look ifs lots self behind allow barriers whole loop
818
+ after already Observe achieve contiguous stores hard go with by math
819
+ size won t way divisible bit so Avoid declare adding single type
820
+
821
+ is_tail is_first_level input_expr argument_signature preamble
822
+ double_support neutral output_statement
823
+ k_group_size name_prefix is_segmented index_dtype scan_dtype
824
+ wg_size is_segment_start_expr fetch_expr_offsets
825
+ arg_ctypes ife_offsets input_fetch_exprs def
826
+ ife_offset arg_name local_fetch_expr_args update_body
827
+ update_loop_lookbehind update_loop_plain update_loop
828
+ use_lookbehind_update store_segment_start_flags
829
+ update_loop first_seg scan_dtype dtype_to_ctype
830
+ is_gpu use_bank_conflict_avoidance
831
+
832
+ a b prev_item i last_item prev_value
833
+ N NO_SEG_BOUNDARY across_seg_boundary
834
+
835
+ arg_offset_adjustment
836
+ """.split())
837
+
838
+
839
+ def _make_template(s: str):
840
+ import re
841
+ leftovers = set()
842
+
843
+ def replace_id(match: "re.Match") -> str:
844
+ # avoid name clashes with user code by adding 'psc_' prefix to
845
+ # identifiers.
846
+
847
+ word = match.group(1)
848
+ if word in _IGNORED_WORDS:
849
+ return word
850
+ elif word in _PREFIX_WORDS:
851
+ return f"psc_{word}"
852
+ else:
853
+ leftovers.add(word)
854
+ return word
855
+
856
+ s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s)
857
+ if leftovers:
858
+ from warnings import warn
859
+ warn("Leftover words in identifier prefixing: " + " ".join(leftovers),
860
+ stacklevel=3)
861
+
862
+ return mako.template.Template(s, strict_undefined=True) # type: ignore
863
+
864
+
865
@dataclass(frozen=True)
class _GeneratedScanKernelInfo:
    """Generated-but-not-yet-compiled source for one scan kernel level,
    together with the metadata needed to build and launch it.
    """

    scan_src: str
    kernel_name: str
    scalar_arg_dtypes: List[Optional[np.dtype]]
    wg_size: int
    k_group_size: int

    def build(self, context: cl.Context, options: Any) -> "_BuiltScanKernelInfo":
        """Compile :attr:`scan_src` for *context* and return the built kernel
        bundled with its launch geometry.
        """
        prg = cl.Program(context, self.scan_src).build(options)
        knl = getattr(prg, self.kernel_name)
        knl.set_scalar_arg_dtypes(self.scalar_arg_dtypes)

        return _BuiltScanKernelInfo(
                kernel=knl,
                wg_size=self.wg_size,
                k_group_size=self.k_group_size)
881
+
882
+
883
@dataclass(frozen=True)
class _BuiltScanKernelInfo:
    """A compiled scan kernel plus the launch parameters it was generated for.

    Produced by :meth:`_GeneratedScanKernelInfo.build`.
    """

    # compiled kernel object ready for enqueueing
    kernel: cl.Kernel
    # work-group size the kernel source was generated for
    wg_size: int
    # number of scan items handled per work item
    k_group_size: int
888
+
889
+
890
@dataclass(frozen=True)
class _GeneratedFinalUpdateKernelInfo:
    """Generated-but-not-yet-compiled source of the final update kernel,
    with the metadata required to build it.
    """

    source: str
    kernel_name: str
    scalar_arg_dtypes: List[Optional[np.dtype]]
    update_wg_size: int

    def build(self,
              context: cl.Context,
              options: Any) -> "_BuiltFinalUpdateKernelInfo":
        """Compile :attr:`source` for *context* and return the built kernel
        together with its work-group size.
        """
        prg = cl.Program(context, self.source).build(options)
        knl = getattr(prg, self.kernel_name)
        knl.set_scalar_arg_dtypes(self.scalar_arg_dtypes)

        return _BuiltFinalUpdateKernelInfo(knl, self.update_wg_size)
904
+
905
+
906
@dataclass(frozen=True)
class _BuiltFinalUpdateKernelInfo:
    """A compiled final-update kernel plus its work-group size.

    Produced by :meth:`_GeneratedFinalUpdateKernelInfo.build`.
    """

    # compiled kernel object ready for enqueueing
    kernel: cl.Kernel
    # work-group size the kernel source was generated for
    update_wg_size: int
910
+
911
+ # }}}
912
+
913
+
914
class ScanPerformanceWarning(UserWarning):
    """Warning issued when a scan is executed in a way that is expected
    to perform poorly.
    """
    pass
916
+
917
+
918
class GenericScanKernelBase(ABC):
    """Shared setup logic for :class:`GenericScanKernel` and
    :class:`GenericDebugScanKernel`: argument parsing, assembly of the
    template code variables, and computation of a key that uniquely
    identifies the generated kernels for caching. Subclasses implement
    :meth:`finish_setup` to generate and build their kernels.
    """

    # {{{ constructor, argument processing

    def __init__(
            self,
            ctx: cl.Context,
            dtype: Any,
            arguments: Union[str, List[DtypedArgument]],
            input_expr: str,
            scan_expr: str,
            neutral: Optional[str],
            output_statement: str,
            is_segment_start_expr: Optional[str] = None,
            input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
            index_dtype: Any = None,
            name_prefix: str = "scan",
            options: Any = None,
            preamble: str = "",
            devices: Optional[List[cl.Device]] = None) -> None:
        """
        :arg ctx: a :class:`pyopencl.Context` within which the code
            for this scan kernel will be generated.
        :arg dtype: the :class:`numpy.dtype` with which the scan will
            be performed. May be a structured type if that type was registered
            through :func:`pyopencl.tools.get_or_register_dtype`.
        :arg arguments: A string of comma-separated C argument declarations.
            If *arguments* is specified, then *input_expr* must also be
            specified. All types used here must be known to PyOpenCL.
            (see :func:`pyopencl.tools.get_or_register_dtype`).
        :arg scan_expr: The associative, binary operation carrying out the scan,
            represented as a C string. Its two arguments are available as ``a``
            and ``b`` when it is evaluated. ``b`` is guaranteed to be the
            'element being updated', and ``a`` is the increment. Thus,
            if some data is supposed to just propagate along without being
            modified by the scan, it should live in ``b``.

            This expression may call functions given in the *preamble*.

            Another value available to this expression is ``across_seg_boundary``,
            a C `bool` indicating whether this scan update is crossing a
            segment boundary, as defined by ``is_segment_start_expr``.
            The scan routine does not implement segmentation
            semantics on its own. It relies on ``scan_expr`` to do this.
            This value is available (but always ``false``) even for a
            non-segmented scan.

            .. note::

                In early pre-releases of the segmented scan,
                segmentation semantics were implemented *without*
                relying on ``scan_expr``.

        :arg input_expr: A C expression, encoded as a string, resulting
            in the values to which the scan is applied. This may be used
            to apply a mapping to values stored in *arguments* before being
            scanned. The result of this expression must match *dtype*.
            The index intended to be mapped is available as ``i`` in this
            expression. This expression may also use the variables defined
            by *input_fetch_expr*.

            This expression may also call functions given in the *preamble*.
        :arg output_statement: a C statement that writes
            the output of the scan. It has access to the scan result as ``item``,
            the preceding scan result item as ``prev_item``, and the current index
            as ``i``. ``prev_item`` in a segmented scan will be the neutral element
            at a segment boundary, not the immediately preceding item.

            Using *prev_item* in output statement has a small run-time cost.
            ``prev_item`` enables the construction of an exclusive scan.

            For non-segmented scans, *output_statement* may also reference
            ``last_item``, which evaluates to the scan result of the last
            array entry.
        :arg is_segment_start_expr: A C expression, encoded as a string,
            resulting in a C ``bool`` value that determines whether a new
            scan segments starts at index *i*. If given, makes the scan a
            segmented scan. Has access to the current index ``i``, the result
            of *input_expr* as ``a``, and in addition may use *arguments* and
            *input_fetch_expr* variables just like *input_expr*.

            If it returns true, then previous sums will not spill over into the
            item with index *i* or subsequent items.
        :arg input_fetch_exprs: a list of tuples *(NAME, ARG_NAME, OFFSET)*.
            An entry here has the effect of doing the equivalent of the following
            before input_expr::

                ARG_NAME_TYPE NAME = ARG_NAME[i+OFFSET];

            ``OFFSET`` is allowed to be 0 or -1, and ``ARG_NAME_TYPE`` is the type
            of ``ARG_NAME``.
        :arg preamble: |preamble|

        The first array in the argument list determines the size of the index
        space over which the scan is carried out, and thus the values over
        which the index *i* occurring in a number of code fragments in
        arguments above will vary.

        All code fragments further have access to N, the number of elements
        being processed in the scan.
        """

        if index_dtype is None:
            index_dtype = np.dtype(np.int32)

        if input_fetch_exprs is None:
            input_fetch_exprs = []

        self.context = ctx
        dtype = self.dtype = np.dtype(dtype)

        if neutral is None:
            from warnings import warn
            warn("not specifying 'neutral' is deprecated and will lead to "
                    "wrong results if your scan is not in-place or your "
                    "'output_statement' does something otherwise non-trivial",
                    stacklevel=2)

        # The generated kernels move scan values in 4-byte units.
        if dtype.itemsize % 4 != 0:
            raise TypeError("scan value type must have size divisible by 4 bytes")

        self.index_dtype = np.dtype(index_dtype)
        # np.iinfo().min is 0 exactly for unsigned integer types.
        if np.iinfo(self.index_dtype).min >= 0:
            raise TypeError("index_dtype must be signed")

        if devices is None:
            devices = ctx.devices
        self.devices = devices
        self.options = options

        from pyopencl.tools import parse_arg_list
        self.parsed_args = parse_arg_list(arguments)
        from pyopencl.tools import VectorArg
        # The first array argument determines the scan's index space; see the
        # class docstring. Raises IndexError if no array argument is present.
        self.first_array_idx = [
                i for i, arg in enumerate(self.parsed_args)
                if isinstance(arg, VectorArg)][0]

        self.input_expr = input_expr

        self.is_segment_start_expr = is_segment_start_expr
        self.is_segmented = is_segment_start_expr is not None
        # Note: only the macro-processed local is used below (in kernel_key);
        # self.is_segment_start_expr keeps the unprocessed form.
        if self.is_segmented:
            is_segment_start_expr = _process_code_for_macro(is_segment_start_expr)

        self.output_statement = output_statement

        for _name, _arg_name, ife_offset in input_fetch_exprs:
            if ife_offset not in [0, -1]:
                raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
        self.input_fetch_exprs = input_fetch_exprs

        arg_dtypes = {}
        arg_ctypes = {}
        for arg in self.parsed_args:
            arg_dtypes[arg.name] = arg.dtype
            arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)

        self.name_prefix = name_prefix

        # {{{ set up shared code dict

        from pyopencl.characterize import has_double_support

        # Variables handed to every mako template render below.
        self.code_variables = {
            "np": np,
            "dtype_to_ctype": dtype_to_ctype,
            "preamble": preamble,
            "name_prefix": name_prefix,
            "index_dtype": self.index_dtype,
            "scan_dtype": dtype,
            "is_segmented": self.is_segmented,
            "arg_dtypes": arg_dtypes,
            "arg_ctypes": arg_ctypes,
            "scan_expr": _process_code_for_macro(scan_expr),
            "neutral": _process_code_for_macro(neutral),
            "is_gpu": bool(self.devices[0].type & cl.device_type.GPU),
            "double_support": all(
                has_double_support(dev) for dev in devices),
            }

        index_typename = dtype_to_ctype(self.index_dtype)
        scan_typename = dtype_to_ctype(dtype)

        # This key is meant to uniquely identify the non-device parameters for
        # the scan kernel.
        self.kernel_key = (
                self.dtype,
                tuple(arg.declarator() for arg in self.parsed_args),
                self.input_expr,
                scan_expr,
                neutral,
                output_statement,
                is_segment_start_expr,
                tuple(input_fetch_exprs),
                index_dtype,
                name_prefix,
                preamble,
                # These depend on dtype_to_ctype(), so their value is independent of
                # the other variables.
                index_typename,
                scan_typename,
                )

        # }}}

        # Mentioning "prev_item" in the output statement is what opts a scan
        # into the (slightly costlier) look-behind update path.
        self.use_lookbehind_update = "prev_item" in self.output_statement
        self.store_segment_start_flags = (
                self.is_segmented and self.use_lookbehind_update)

        self.finish_setup()

    # }}}

    @abstractmethod
    def finish_setup(self) -> None:
        # Subclasses generate (and possibly cache) and build their kernels here.
        pass
1133
+
1134
+
1135
# Persistent on-disk cache for *generated* (not yet device-built) scan kernel
# sources, keyed by kernel parameters plus a device cache id; only created
# when caching is not disabled via cl._PYOPENCL_NO_CACHE.
if not cl._PYOPENCL_NO_CACHE:
    generic_scan_kernel_cache: WriteOncePersistentDict[Any,
            Tuple[_GeneratedScanKernelInfo, _GeneratedScanKernelInfo,
                _GeneratedFinalUpdateKernelInfo]] = \
            WriteOncePersistentDict(
                    "pyopencl-generated-scan-kernel-cache-v1",
                    key_builder=_NumpyTypesKeyBuilder(),
                    in_mem_cache_size=0)
1143
+
1144
+
1145
class GenericScanKernel(GenericScanKernelBase):
    """Generates and executes code that performs prefix sums ("scans") on
    arbitrary types, with many possible tweaks.

    Usage example::

        from pyopencl.scan import GenericScanKernel
        knl = GenericScanKernel(
                context, np.int32,
                arguments="__global int *ary",
                input_expr="ary[i]",
                scan_expr="a+b", neutral="0",
                output_statement="ary[i+1] = item;")

        a = cl.array.arange(queue, 10000, dtype=np.int32)
        knl(a, queue=queue)

    .. automethod:: __init__
    .. automethod:: __call__
    """

    def finish_setup(self) -> None:
        """Generate the three kernels (two scan levels plus final update),
        consulting/populating the persistent source cache, then build them.
        """
        # Before generating the kernel, see if it's cached.
        from pyopencl.cache import get_device_cache_id
        devices_key = tuple(get_device_cache_id(device)
                for device in self.devices)

        cache_key = (self.kernel_key, devices_key)
        from_cache = False

        if not cl._PYOPENCL_NO_CACHE:
            try:
                result = generic_scan_kernel_cache[cache_key]
                from_cache = True
                # Lazy %-args: only formatted if debug logging is enabled.
                logger.debug(
                        "cache hit for generated scan kernel '%s'",
                        self.name_prefix)
                (
                    self.first_level_scan_gen_info,
                    self.second_level_scan_gen_info,
                    self.final_update_gen_info) = result
            except KeyError:
                pass

        if not from_cache:
            logger.debug(
                    "cache miss for generated scan kernel '%s'",
                    self.name_prefix)
            self._finish_setup_impl()

            result = (self.first_level_scan_gen_info,
                    self.second_level_scan_gen_info,
                    self.final_update_gen_info)

            if not cl._PYOPENCL_NO_CACHE:
                generic_scan_kernel_cache.store_if_not_present(cache_key, result)

        # Build the kernels.
        self.first_level_scan_info = self.first_level_scan_gen_info.build(
                self.context, self.options)
        del self.first_level_scan_gen_info

        self.second_level_scan_info = self.second_level_scan_gen_info.build(
                self.context, self.options)
        del self.second_level_scan_gen_info

        self.final_update_info = self.final_update_gen_info.build(
                self.context, self.options)
        del self.final_update_gen_info

    def _finish_setup_impl(self) -> None:
        """Pick workgroup/k-group sizes and generate all three kernel sources."""

        # {{{ find usable workgroup/k-group size, build first-level scan

        trip_count = 0

        avail_local_mem = min(
                dev.local_mem_size
                for dev in self.devices)

        if "CUDA" in self.devices[0].platform.name:
            # not sure where these go, but roughly this much seems unavailable.
            avail_local_mem -= 0x400

        is_cpu = self.devices[0].type & cl.device_type.CPU
        is_gpu = self.devices[0].type & cl.device_type.GPU

        if is_cpu:
            # (about the widest vector a CPU can support, also taking
            # into account that CPUs don't hide latency by large work groups
            max_scan_wg_size = 16
            wg_size_multiples = 4
        else:
            max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
            wg_size_multiples = 64

        # Intel beignet fails "Out of shared local memory" in test_scan int64
        # and asserts in test_sort with this enabled:
        # https://github.com/inducer/pyopencl/pull/238
        # A beignet bug report (outside of pyopencl) suggests packed structs
        # (which this is) can even give wrong results:
        # https://bugs.freedesktop.org/show_bug.cgi?id=98717
        # TODO: does this also affect Intel Compute Runtime?
        use_bank_conflict_avoidance = (
                self.dtype.itemsize > 4 and self.dtype.itemsize % 8 == 0
                and is_gpu
                and "beignet" not in self.devices[0].platform.version.lower())

        # k_group_size should be a power of two because of in-kernel
        # division by that number.

        solutions = []
        for k_exp in range(0, 9):
            for wg_size in range(wg_size_multiples, max_scan_wg_size+1,
                    wg_size_multiples):

                k_group_size = 2**k_exp
                # FIX: arguments are passed in the order declared by
                # get_local_mem_use (k_group_size first). Previously they
                # were swapped, which mis-sized the index-dtype term of the
                # local memory estimate.
                lmem_use = self.get_local_mem_use(k_group_size, wg_size,
                        use_bank_conflict_avoidance)
                if lmem_use <= avail_local_mem:
                    solutions.append((wg_size*k_group_size, k_group_size, wg_size))

        if is_gpu:
            for wg_size_floor in [256, 192, 128]:
                have_sol_above_floor = any(wg_size >= wg_size_floor
                        for _, _, wg_size in solutions)

                if have_sol_above_floor:
                    # delete all solutions not meeting the wg size floor
                    solutions = [(total, try_k_group_size, try_wg_size)
                            for total, try_k_group_size, try_wg_size in solutions
                            if try_wg_size >= wg_size_floor]
                    break

        _, k_group_size, max_scan_wg_size = max(solutions)

        while True:
            candidate_scan_gen_info = self.generate_scan_kernel(
                    max_scan_wg_size, self.parsed_args,
                    _process_code_for_macro(self.input_expr),
                    self.is_segment_start_expr,
                    input_fetch_exprs=self.input_fetch_exprs,
                    is_first_level=True,
                    store_segment_start_flags=self.store_segment_start_flags,
                    k_group_size=k_group_size,
                    use_bank_conflict_avoidance=use_bank_conflict_avoidance)

            candidate_scan_info = candidate_scan_gen_info.build(
                    self.context, self.options)

            # Will this device actually let us execute this kernel
            # at the desired work group size? Building it is the
            # only way to find out.
            kernel_max_wg_size = min(
                    candidate_scan_info.kernel.get_work_group_info(
                        cl.kernel_work_group_info.WORK_GROUP_SIZE,
                        dev)
                    for dev in self.devices)

            if candidate_scan_info.wg_size <= kernel_max_wg_size:
                break
            else:
                max_scan_wg_size = min(kernel_max_wg_size, max_scan_wg_size)

            trip_count += 1
            assert trip_count <= 20

        self.first_level_scan_gen_info = candidate_scan_gen_info
        assert (_round_down_to_power_of_2(candidate_scan_info.wg_size)
                == candidate_scan_info.wg_size)

        # }}}

        # {{{ build second-level scan

        from pyopencl.tools import VectorArg
        second_level_arguments = self.parsed_args + [
                VectorArg(self.dtype, "interval_sums")]

        second_level_build_kwargs: Dict[str, Optional[str]] = {}
        if self.is_segmented:
            second_level_arguments.append(
                    VectorArg(self.index_dtype,
                        "g_first_segment_start_in_interval_input"))

            # is_segment_start_expr answers the question "should previous sums
            # spill over into this item". And since
            # g_first_segment_start_in_interval_input answers the question if a
            # segment boundary was found in an interval of data, then if not,
            # it's ok to spill over.
            second_level_build_kwargs["is_segment_start_expr"] = \
                    "g_first_segment_start_in_interval_input[i] != NO_SEG_BOUNDARY"
        else:
            second_level_build_kwargs["is_segment_start_expr"] = None

        self.second_level_scan_gen_info = self.generate_scan_kernel(
                max_scan_wg_size,
                arguments=second_level_arguments,
                input_expr="interval_sums[i]",
                input_fetch_exprs=[],
                is_first_level=False,
                store_segment_start_flags=False,
                k_group_size=k_group_size,
                use_bank_conflict_avoidance=use_bank_conflict_avoidance,
                **second_level_build_kwargs)

        # }}}

        # {{{ generate final update kernel

        update_wg_size = min(max_scan_wg_size, 256)

        final_update_tpl = _make_template(UPDATE_SOURCE)
        final_update_src = str(final_update_tpl.render(
            wg_size=update_wg_size,
            output_statement=self.output_statement,
            arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
            argument_signature=", ".join(
                arg.declarator() for arg in self.parsed_args),
            is_segment_start_expr=self.is_segment_start_expr,
            input_expr=_process_code_for_macro(self.input_expr),
            use_lookbehind_update=self.use_lookbehind_update,
            **self.code_variables))

        update_scalar_arg_dtypes = (
                get_arg_list_scalar_arg_dtypes(self.parsed_args)
                + [self.index_dtype, self.index_dtype, None, None])
        if self.is_segmented:
            # g_first_segment_start_in_interval
            update_scalar_arg_dtypes.append(None)
        if self.store_segment_start_flags:
            update_scalar_arg_dtypes.append(None)  # g_segment_start_flags

        self.final_update_gen_info = _GeneratedFinalUpdateKernelInfo(
                final_update_src,
                self.name_prefix + "_final_update",
                update_scalar_arg_dtypes,
                update_wg_size)

        # }}}

    # {{{ scan kernel build/properties

    def get_local_mem_use(
            self, k_group_size: int, wg_size: int,
            use_bank_conflict_avoidance: bool) -> int:
        """Estimate (in bytes) the local memory used by a first-level scan
        kernel generated with the given *k_group_size* and *wg_size*.
        """
        arg_dtypes = {}
        for arg in self.parsed_args:
            arg_dtypes[arg.name] = arg.dtype

        fetch_expr_offsets: Dict[str, Set] = {}
        for _name, arg_name, ife_offset in self.input_fetch_exprs:
            fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)

        itemsize = self.dtype.itemsize
        if use_bank_conflict_avoidance:
            # padding added per item to dodge bank conflicts
            itemsize += 4

        return (
                # ldata
                itemsize*(k_group_size+1)*(wg_size+1)

                # l_segment_start_flags
                + k_group_size*wg_size

                # l_first_segment_start_in_subtree
                + self.index_dtype.itemsize*wg_size

                + k_group_size*wg_size*sum(
                    arg_dtypes[arg_name].itemsize
                    for arg_name, ife_offsets in list(fetch_expr_offsets.items())
                    if -1 in ife_offsets or len(ife_offsets) > 1))

    def generate_scan_kernel(
            self,
            max_wg_size: int,
            arguments: List[DtypedArgument],
            input_expr: str,
            is_segment_start_expr: Optional[str],
            input_fetch_exprs: List[Tuple[str, str, int]],
            is_first_level: bool,
            store_segment_start_flags: bool,
            k_group_size: int,
            use_bank_conflict_avoidance: bool) -> _GeneratedScanKernelInfo:
        """Render the scan-interval template into source for one scan level
        and return it bundled with its launch metadata (not yet compiled).
        """
        scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)

        # Empirically found on Nv hardware: no need to be bigger than this size
        wg_size = _round_down_to_power_of_2(
                min(max_wg_size, 256))

        kernel_name = self.code_variables["name_prefix"]
        if is_first_level:
            kernel_name += "_lev1"
        else:
            kernel_name += "_lev2"

        scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
        scan_src = str(scan_tpl.render(
            wg_size=wg_size,
            input_expr=input_expr,
            k_group_size=k_group_size,
            arg_offset_adjustment=get_arg_offset_adjuster_code(arguments),
            argument_signature=", ".join(arg.declarator() for arg in arguments),
            is_segment_start_expr=is_segment_start_expr,
            input_fetch_exprs=input_fetch_exprs,
            is_first_level=is_first_level,
            store_segment_start_flags=store_segment_start_flags,
            use_bank_conflict_avoidance=use_bank_conflict_avoidance,
            kernel_name=kernel_name,
            **self.code_variables))

        scalar_arg_dtypes.extend(
                (None, self.index_dtype, self.index_dtype))
        if is_first_level:
            scalar_arg_dtypes.append(None)  # interval_results
        if self.is_segmented and is_first_level:
            scalar_arg_dtypes.append(None)  # g_first_segment_start_in_interval
        if store_segment_start_flags:
            scalar_arg_dtypes.append(None)  # g_segment_start_flags

        return _GeneratedScanKernelInfo(
                scan_src=scan_src,
                kernel_name=kernel_name,
                scalar_arg_dtypes=scalar_arg_dtypes,
                wg_size=wg_size,
                k_group_size=k_group_size)

    # }}}

    def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
        """
        |std-enqueue-blurb|

        .. note::

            The returned :class:`pyopencl.Event` corresponds only to part of the
            execution of the scan. It is not suitable for profiling.

        :arg queue: queue on which to execute the scan. If not given, the
            queue of the first :class:`pyopencl.array.Array` in *args* is used
        :arg allocator: an allocator for the temporary arrays and results. If
            not given, the allocator of the first :class:`pyopencl.array.Array`
            in *args* is used.
        :arg size: specify the length of the scan to be carried out. If not
            given, this length is inferred from the first argument
        :arg wait_for: a :class:`list` of events to wait for.
        """

        # {{{ argument processing

        allocator = kwargs.get("allocator")
        queue = kwargs.get("queue")
        n = kwargs.get("size")
        wait_for = kwargs.get("wait_for")

        if wait_for is None:
            wait_for = []
        else:
            # We'll be modifying it below.
            wait_for = list(wait_for)

        if len(args) != len(self.parsed_args):
            raise TypeError(
                f"expected {len(self.parsed_args)} arguments, got {len(args)}")

        first_array = args[self.first_array_idx]
        allocator = allocator or first_array.allocator
        queue = queue or first_array.queue

        if n is None:
            n, = first_array.shape

        if n == 0:
            # We're done here. (But pretend to return an event.)
            return cl.enqueue_marker(queue, wait_for=wait_for)

        # Import hoisted out of the loop below: it is loop-invariant.
        from pyopencl.tools import VectorArg

        data_args = []
        for arg_descr, arg_val in zip(self.parsed_args, args):
            if isinstance(arg_descr, VectorArg):
                data_args.append(arg_val.base_data)
                if arg_descr.with_offset:
                    data_args.append(arg_val.offset)
                wait_for.extend(arg_val.events)
            else:
                data_args.append(arg_val)

        # }}}

        l1_info = self.first_level_scan_info
        l2_info = self.second_level_scan_info

        # see CL source above for terminology
        unit_size = l1_info.wg_size * l1_info.k_group_size
        max_intervals = 3*max(dev.max_compute_units for dev in self.devices)

        from pytools import uniform_interval_splitting
        interval_size, num_intervals = uniform_interval_splitting(
                n, unit_size, max_intervals)

        # {{{ allocate some buffers

        interval_results = cl.array.empty(queue,
                num_intervals, dtype=self.dtype,
                allocator=allocator)

        partial_scan_buffer = cl.array.empty(
                queue, n, dtype=self.dtype,
                allocator=allocator)

        if self.store_segment_start_flags:
            segment_start_flags = cl.array.empty(
                    queue, n, dtype=np.bool_,
                    allocator=allocator)

        # }}}

        # {{{ first level scan of interval (one interval per block)

        scan1_args = data_args + [
                partial_scan_buffer.data, n, interval_size, interval_results.data,
                ]

        if self.is_segmented:
            first_segment_start_in_interval = cl.array.empty(queue,
                    num_intervals, dtype=self.index_dtype,
                    allocator=allocator)
            scan1_args.append(first_segment_start_in_interval.data)

        if self.store_segment_start_flags:
            scan1_args.append(segment_start_flags.data)

        l1_evt = l1_info.kernel(
                queue, (num_intervals,), (l1_info.wg_size,),
                *scan1_args, g_times_l=True, wait_for=wait_for)

        # }}}

        # {{{ second level scan of per-interval results

        # can scan at most one interval
        assert interval_size >= num_intervals

        scan2_args = data_args + [
                interval_results.data,  # interval_sums
                ]
        if self.is_segmented:
            scan2_args.append(first_segment_start_in_interval.data)
        scan2_args = scan2_args + [
                interval_results.data,  # partial_scan_buffer
                num_intervals, interval_size]

        l2_evt = l2_info.kernel(
                queue, (1,), (l1_info.wg_size,),
                *scan2_args, g_times_l=True, wait_for=[l1_evt])

        # }}}

        # {{{ update intervals with result of interval scan

        upd_args = data_args + [
                n, interval_size, interval_results.data, partial_scan_buffer.data]
        if self.is_segmented:
            upd_args.append(first_segment_start_in_interval.data)
        if self.store_segment_start_flags:
            upd_args.append(segment_start_flags.data)

        return self.final_update_info.kernel(
                queue, (num_intervals,),
                (self.final_update_info.update_wg_size,),
                *upd_args, g_times_l=True, wait_for=[l2_evt])

        # }}}

    # }}}
1616
+
1617
+
1618
+ # {{{ debug kernel
1619
+
1620
# Sequential reference implementation of the scan, run by a single work item
# (REQD_WG_SIZE(1, 1, 1)). Used by GenericDebugScanKernel below to isolate
# scan bugs from parallel-execution issues. Rendered with mako; INPUT_EXPR,
# SCAN_EXPR, IS_SEG_START and the *_type typedefs come from SHARED_PREAMBLE
# and the render-time code variables.
DEBUG_SCAN_TEMPLATE = SHARED_PREAMBLE + r"""//CL//

KERNEL
REQD_WG_SIZE(1, 1, 1)
void ${name_prefix}_debug_scan(
    __global scan_type *scan_tmp,
    ${argument_signature},
    const index_type N)
{
    scan_type current = ${neutral};
    scan_type prev;

    ${arg_offset_adjustment}

    for (index_type i = 0; i < N; ++i)
    {
        %for name, arg_name, ife_offset in input_fetch_exprs:
            ${arg_ctypes[arg_name]} ${name};
            %if ife_offset < 0:
                if (i+${ife_offset} >= 0)
                    ${name} = ${arg_name}[i+${ife_offset}];
            %else:
                ${name} = ${arg_name}[i];
            %endif
        %endfor

        scan_type my_val = INPUT_EXPR(i);

        prev = current;
        %if is_segmented:
            bool is_seg_start = IS_SEG_START(i, my_val);
        %endif

        current = SCAN_EXPR(prev, my_val,
            %if is_segmented:
                is_seg_start
            %else:
                false
            %endif
            );
        scan_tmp[i] = current;
    }

    scan_type last_item = scan_tmp[N-1];

    for (index_type i = 0; i < N; ++i)
    {
        scan_type item = scan_tmp[i];
        scan_type prev_item;
        if (i)
            prev_item = scan_tmp[i-1];
        else
            prev_item = ${neutral};

        {
            ${output_statement};
        }
    }
}
"""
1680
+
1681
+
1682
class GenericDebugScanKernel(GenericScanKernelBase):
    """
    Performs the same function and has the same interface as
    :class:`GenericScanKernel`, but uses a dead-simple, sequential scan. Works
    best on CPU platforms, and helps isolate bugs in scans by removing the
    potential for issues originating in parallel execution.

    .. automethod:: __call__
    """

    def finish_setup(self) -> None:
        """Render and build the single sequential debug kernel."""
        scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
        scan_src = str(scan_tpl.render(
            output_statement=self.output_statement,
            arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
            argument_signature=", ".join(
                arg.declarator() for arg in self.parsed_args),
            is_segment_start_expr=self.is_segment_start_expr,
            input_expr=_process_code_for_macro(self.input_expr),
            input_fetch_exprs=self.input_fetch_exprs,
            wg_size=1,
            **self.code_variables))

        scan_prg = cl.Program(self.context, scan_src).build(self.options)
        self.kernel = getattr(scan_prg, f"{self.name_prefix}_debug_scan")
        # Leading None: the scan_tmp buffer argument precedes the user args.
        scalar_arg_dtypes = (
                [None]
                + get_arg_list_scalar_arg_dtypes(self.parsed_args)
                + [self.index_dtype])
        self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)

    def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
        """See :meth:`GenericScanKernel.__call__`."""

        # {{{ argument processing

        allocator = kwargs.get("allocator")
        queue = kwargs.get("queue")
        n = kwargs.get("size")
        wait_for = kwargs.get("wait_for")

        if wait_for is None:
            wait_for = []
        else:
            # We'll be modifying it below.
            wait_for = list(wait_for)

        if len(args) != len(self.parsed_args):
            raise TypeError(
                f"expected {len(self.parsed_args)} arguments, got {len(args)}")

        first_array = args[self.first_array_idx]
        allocator = allocator or first_array.allocator
        queue = queue or first_array.queue

        if n is None:
            n, = first_array.shape

        # Temporary holding the running scan result; read back by the
        # second pass in the debug kernel.
        scan_tmp = cl.array.empty(queue,
                n, dtype=self.dtype,
                allocator=allocator)

        data_args = [scan_tmp.data]
        from pyopencl.tools import VectorArg
        for arg_descr, arg_val in zip(self.parsed_args, args):
            if isinstance(arg_descr, VectorArg):
                data_args.append(arg_val.base_data)
                if arg_descr.with_offset:
                    data_args.append(arg_val.offset)
                wait_for.extend(arg_val.events)
            else:
                data_args.append(arg_val)

        # }}}

        # Single work item, single work group: the kernel is sequential.
        return self.kernel(queue, (1,), (1,),
                *(data_args + [n]), wait_for=wait_for)
1759
+
1760
+ # }}}
1761
+
1762
+
1763
+ # {{{ compatibility interface
1764
+
1765
class _LegacyScanKernelBase(GenericScanKernel):
    """Compatibility shim providing the old ``(input_ary, output_ary)``
    scan interface on top of :class:`GenericScanKernel`.

    Subclasses provide :attr:`ary_output_statement` to select inclusive or
    exclusive semantics.
    """

    def __init__(self, ctx, dtype,
            scan_expr, neutral=None,
            name_prefix="scan", options=None, preamble="", devices=None):
        scan_ctype = dtype_to_ctype(dtype)
        GenericScanKernel.__init__(self,
                ctx, dtype,
                arguments="__global {} *input_ary, __global {} *output_ary".format(
                    scan_ctype, scan_ctype),
                input_expr="input_ary[i]",
                scan_expr=scan_expr,
                neutral=neutral,
                output_statement=self.ary_output_statement,
                options=options, preamble=preamble, devices=devices)

    @property
    def ary_output_statement(self):
        # Supplied by subclasses, e.g. "output_ary[i] = item;".
        raise NotImplementedError

    def __call__(self, input_ary, output_ary=None, allocator=None, queue=None):
        """Scan *input_ary* into *output_ary* and return the output array.

        :arg output_ary: may be *None* (scan writes back into *input_ary*),
            the string ``"new"`` (a fresh array is allocated), or an array of
            the same shape as *input_ary*.
        """
        allocator = allocator or input_ary.allocator

        # Resolve the output array *before* touching its attributes.
        # (Previously, output_ary.queue could be evaluated while output_ary
        # was still None or the placeholder string "new", raising
        # AttributeError whenever input_ary.queue was unset.)
        if output_ary is None:
            output_ary = input_ary

        # Note: was isinstance(output_ary, (str, str)) — a Python-2
        # (str, unicode) leftover with a duplicated type.
        if isinstance(output_ary, str) and output_ary == "new":
            output_ary = cl.array.empty_like(input_ary, allocator=allocator)

        queue = queue or input_ary.queue or output_ary.queue

        if input_ary.shape != output_ary.shape:
            raise ValueError("input and output must have the same shape")

        if not input_ary.flags.forc:
            raise RuntimeError("ScanKernel cannot "
                    "deal with non-contiguous arrays")

        n, = input_ary.shape

        # Zero-length scan: nothing to enqueue.
        if not n:
            return output_ary

        GenericScanKernel.__call__(self,
                input_ary, output_ary, allocator=allocator, queue=queue)

        return output_ary
1810
+
1811
+
1812
class InclusiveScanKernel(_LegacyScanKernelBase):
    """Legacy-interface inclusive scan: output_ary[i] is the scan result
    *including* input_ary[i] (``item`` in GenericScanKernel terms)."""
    ary_output_statement = "output_ary[i] = item;"
1814
+
1815
+
1816
class ExclusiveScanKernel(_LegacyScanKernelBase):
    """Legacy-interface exclusive scan: output_ary[i] is the scan result
    *excluding* input_ary[i] (``prev_item`` in GenericScanKernel terms)."""
    ary_output_statement = "output_ary[i] = prev_item;"
1818
+
1819
+ # }}}
1820
+
1821
+
1822
+ # {{{ template
1823
+
1824
class ScanTemplate(KernelTemplateBase):
    """A build-time template for :class:`GenericScanKernel` instances.

    Stores the textual pieces of a scan (argument list, input/scan
    expressions, output statement, ...) and renders them with concrete type
    aliases and variable values in :meth:`build_inner`.
    """

    def __init__(
            self,
            arguments: Union[str, List[DtypedArgument]],
            input_expr: str,
            scan_expr: str,
            neutral: Optional[str],
            output_statement: str,
            is_segment_start_expr: Optional[str] = None,
            input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
            name_prefix: str = "scan",
            preamble: str = "",
            template_processor: Any = None) -> None:
        super().__init__(template_processor=template_processor)

        # Normalize here instead of using a mutable default argument.
        if input_fetch_exprs is None:
            input_fetch_exprs = []

        self.arguments = arguments
        self.input_expr = input_expr
        self.scan_expr = scan_expr
        self.neutral = neutral
        self.output_statement = output_statement
        self.is_segment_start_expr = is_segment_start_expr
        self.input_fetch_exprs = input_fetch_exprs
        self.name_prefix = name_prefix
        self.preamble = preamble

    def build_inner(self, context, type_aliases=(), var_values=(),
            more_preamble="", more_arguments=(), declare_types=(),
            options=None, devices=None, scan_cls=GenericScanKernel):
        """Render the template and construct a *scan_cls* instance.

        The scan dtype is taken from the ``scan_t`` type alias; the index
        dtype from ``index_t`` (defaulting to int32).
        """
        renderer = self.get_renderer(type_aliases, var_values, context, options)

        # Render the argument list once and reuse it below (the original
        # code rendered the identical list a second time).
        arg_list = renderer.render_argument_list(self.arguments, more_arguments)

        type_decl_preamble = renderer.get_type_decl_preamble(
                context.devices[0], declare_types, arg_list)

        return scan_cls(context, renderer.type_aliases["scan_t"],
                arg_list,
                renderer(self.input_expr), renderer(self.scan_expr),
                renderer(self.neutral), renderer(self.output_statement),
                is_segment_start_expr=renderer(self.is_segment_start_expr),
                input_fetch_exprs=self.input_fetch_exprs,
                index_dtype=renderer.type_aliases.get("index_t", np.int32),
                name_prefix=renderer(self.name_prefix), options=options,
                preamble=(
                    type_decl_preamble
                    + "\n"
                    + renderer(self.preamble + "\n" + more_preamble)),
                devices=devices)
1875
+
1876
+ # }}}
1877
+
1878
+
1879
+ # {{{ 'canned' scan kernels
1880
+
1881
@context_dependent_memoize
def get_cumsum_kernel(context, input_dtype, output_dtype):
    """Return a (per-context memoized) inclusive prefix-sum kernel reading
    from an *input_dtype* array and writing into an *output_dtype* array.
    """
    from pyopencl.tools import VectorArg

    scan_arguments = [
        VectorArg(input_dtype, "input"),
        VectorArg(output_dtype, "output"),
    ]

    return GenericScanKernel(
        context, output_dtype,
        arguments=scan_arguments,
        input_expr="input[i]",
        scan_expr="a+b", neutral="0",
        output_statement="""
            output[i] = item;
            """)
1895
+
1896
+ # }}}
1897
+
1898
+ # vim: filetype=pyopencl:fdm=marker