pyopencl 2025.2.5__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of pyopencl has been flagged as potentially problematic.

Files changed (47)
  1. pyopencl/.libs/libOpenCL-83a5a7fd.so.1.0.0 +0 -0
  2. pyopencl/__init__.py +1995 -0
  3. pyopencl/_cl.cpython-310-x86_64-linux-gnu.so +0 -0
  4. pyopencl/_cl.pyi +2006 -0
  5. pyopencl/_cluda.py +57 -0
  6. pyopencl/_monkeypatch.py +1069 -0
  7. pyopencl/_mymako.py +17 -0
  8. pyopencl/algorithm.py +1454 -0
  9. pyopencl/array.py +3441 -0
  10. pyopencl/bitonic_sort.py +245 -0
  11. pyopencl/bitonic_sort_templates.py +597 -0
  12. pyopencl/cache.py +535 -0
  13. pyopencl/capture_call.py +200 -0
  14. pyopencl/characterize/__init__.py +463 -0
  15. pyopencl/characterize/performance.py +240 -0
  16. pyopencl/cl/pyopencl-airy.cl +324 -0
  17. pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
  18. pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
  19. pyopencl/cl/pyopencl-bessel-y.cl +435 -0
  20. pyopencl/cl/pyopencl-complex.h +303 -0
  21. pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
  22. pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
  23. pyopencl/cl/pyopencl-random123/array.h +325 -0
  24. pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
  25. pyopencl/cl/pyopencl-random123/philox.cl +486 -0
  26. pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
  27. pyopencl/clmath.py +282 -0
  28. pyopencl/clrandom.py +412 -0
  29. pyopencl/cltypes.py +202 -0
  30. pyopencl/compyte/.gitignore +21 -0
  31. pyopencl/compyte/__init__.py +0 -0
  32. pyopencl/compyte/array.py +241 -0
  33. pyopencl/compyte/dtypes.py +316 -0
  34. pyopencl/compyte/pyproject.toml +52 -0
  35. pyopencl/elementwise.py +1178 -0
  36. pyopencl/invoker.py +417 -0
  37. pyopencl/ipython_ext.py +70 -0
  38. pyopencl/py.typed +0 -0
  39. pyopencl/reduction.py +815 -0
  40. pyopencl/scan.py +1916 -0
  41. pyopencl/tools.py +1565 -0
  42. pyopencl/typing.py +61 -0
  43. pyopencl/version.py +11 -0
  44. pyopencl-2025.2.5.dist-info/METADATA +109 -0
  45. pyopencl-2025.2.5.dist-info/RECORD +47 -0
  46. pyopencl-2025.2.5.dist-info/WHEEL +6 -0
  47. pyopencl-2025.2.5.dist-info/licenses/LICENSE +104 -0
pyopencl/scan.py ADDED
@@ -0,0 +1,1916 @@
+ """Scan primitive."""
+ from __future__ import annotations
+
+
+ __copyright__ = """
+ Copyright 2011-2012 Andreas Kloeckner
+ Copyright 2008-2011 NVIDIA Corporation
+ """
+
+ __license__ = """
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ Derived from code within the Thrust project, https://github.com/NVIDIA/thrust
+ """
+
+ import logging
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Any
+
+ import numpy as np
+
+ from pytools.persistent_dict import WriteOncePersistentDict
+
+ import pyopencl as cl
+ import pyopencl._mymako as mako
+ import pyopencl.array as cl_array
+ from pyopencl._cluda import CLUDA_PREAMBLE
+ from pyopencl.tools import (
+     DtypedArgument,
+     KernelTemplateBase,
+     _NumpyTypesKeyBuilder,
+     _process_code_for_macro,
+     bitlog2,
+     context_dependent_memoize,
+     dtype_to_ctype,
+     get_arg_list_scalar_arg_dtypes,
+     get_arg_offset_adjuster_code,
+ )
+
+
+ logger = logging.getLogger(__name__)
+
+
+ # {{{ preamble
+
+ SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
+ #define WG_SIZE ${wg_size}
+
+ #define SCAN_EXPR(a, b, across_seg_boundary) ${scan_expr}
+ #define INPUT_EXPR(i) (${input_expr})
+ %if is_segmented:
+     #define IS_SEG_START(i, a) (${is_segment_start_expr})
+ %endif
+
+ ${preamble}
+
+ typedef ${dtype_to_ctype(scan_dtype)} scan_type;
+ typedef ${dtype_to_ctype(index_dtype)} index_type;
+
+ // NO_SEG_BOUNDARY is the largest representable integer in index_type.
+ // This assumption is used in code below.
+ #define NO_SEG_BOUNDARY ${str(np.iinfo(index_dtype).max)}
+ """
+
+ # }}}
+
+ # {{{ main scan code
+
+ # Algorithm: Each work group is responsible for one contiguous
+ # 'interval'. There are just enough intervals to fill all compute
+ # units. Intervals are split into 'units'. A unit is what gets
+ # worked on in parallel by one work group.
+ #
+ # in index space:
+ # interval > unit > local-parallel > k-group
+ #
+ # (Note that there is also a transpose in here: The data is read
+ # with local ids along linear index order.)
+ #
+ # Each unit has two axes--the local-id axis and the k axis.
+ #
+ # unit 0:
+ #     | | | | | | | | | | ----> lid
+ #     | | | | | | | | | |
+ #     | | | | | | | | | |
+ #     | | | | | | | | | |
+ #     | | | | | | | | | |
+ #
+ #     |
+ #     v k (fastest-moving in linear index)
+ #
+ # unit 1:
+ #     | | | | | | | | | | ----> lid
+ #     | | | | | | | | | |
+ #     | | | | | | | | | |
+ #     | | | | | | | | | |
+ #     | | | | | | | | | |
+ #
+ #     |
+ #     v k (fastest-moving in linear index)
+ #
+ # ...
+ #
+ # At a device-global level, this is a three-phase algorithm, in
+ # which first each interval does its local scan, then a scan
+ # across intervals exchanges data globally, and the final update
+ # adds the exchanged sums to each interval.
+ #
+ # Exclusive scan is realized by allowing look-behind (access to the
+ # preceding item) in the final update, by means of a local shift.
+ #
+ # NOTE: All segment_start_in_X indices are relative to the start
+ # of the array.
+
+ SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + r"""//CL//
+
+ #define K ${k_group_size}
+
+ // #define DEBUG
+ #ifdef DEBUG
+     #define pycl_printf(ARGS) printf ARGS
+ #else
+     #define pycl_printf(ARGS) /* */
+ #endif
+
+ KERNEL
+ REQD_WG_SIZE(WG_SIZE, 1, 1)
+ void ${kernel_name}(
+     ${argument_signature},
+     GLOBAL_MEM scan_type *restrict partial_scan_buffer,
+     const index_type N,
+     const index_type interval_size
+     %if is_first_level:
+         , GLOBAL_MEM scan_type *restrict interval_results
+     %endif
+     %if is_segmented and is_first_level:
+         // NO_SEG_BOUNDARY if no segment boundary in interval.
+         , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
+     %endif
+     %if store_segment_start_flags:
+         , GLOBAL_MEM char *restrict g_segment_start_flags
+     %endif
+     )
+ {
+     ${arg_offset_adjustment}
+
+     // index K in first dimension used for carry storage
+     %if use_bank_conflict_avoidance:
+         // Avoid bank conflicts by adding a single 32-bit value to the size of
+         // the scan type.
+         struct __attribute__ ((__packed__)) wrapped_scan_type
+         {
+             scan_type value;
+             int dummy;
+         };
+     %else:
+         struct wrapped_scan_type
+         {
+             scan_type value;
+         };
+     %endif
+     // padded in WG_SIZE to avoid bank conflicts
+     LOCAL_MEM struct wrapped_scan_type ldata[K + 1][WG_SIZE + 1];
+
+     %if is_segmented:
+         LOCAL_MEM char l_segment_start_flags[K][WG_SIZE];
+         LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE];
+
+         // only relevant/populated for local id 0
+         index_type first_segment_start_in_interval = NO_SEG_BOUNDARY;
+
+         index_type first_segment_start_in_k_group, first_segment_start_in_subtree;
+     %endif
+
+     // {{{ declare local data for input_fetch_exprs if any of them are stenciled
+
+     <%
+         fetch_expr_offsets = {}
+         for name, arg_name, ife_offset in input_fetch_exprs:
+             fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
+
+         local_fetch_expr_args = set(
+             arg_name
+             for arg_name, ife_offsets in fetch_expr_offsets.items()
+             if -1 in ife_offsets or len(ife_offsets) > 1)
+     %>
+
+     %for arg_name in local_fetch_expr_args:
+         LOCAL_MEM ${arg_ctypes[arg_name]} l_${arg_name}[WG_SIZE*K];
+     %endfor
+
+     // }}}
+
+     const index_type interval_begin = interval_size * GID_0;
+     const index_type interval_end = min(interval_begin + interval_size, N);
+
+     const index_type unit_size = K * WG_SIZE;
+
+     index_type unit_base = interval_begin;
+
+     %for is_tail in [False, True]:
+
+         %if not is_tail:
+             for(; unit_base + unit_size <= interval_end; unit_base += unit_size)
+         %else:
+             if (unit_base < interval_end)
+         %endif
+
+         {
+
+             // {{{ carry out input_fetch_exprs
+             // (if there are ones that need to be fetched into local)
+
+             %if local_fetch_expr_args:
+                 for(index_type k = 0; k < K; k++)
+                 {
+                     const index_type offset = k*WG_SIZE + LID_0;
+                     const index_type read_i = unit_base + offset;
+
+                     %for arg_name in local_fetch_expr_args:
+                         %if is_tail:
+                         if (read_i < interval_end)
+                         %endif
+                         {
+                             l_${arg_name}[offset] = ${arg_name}[read_i];
+                         }
+                     %endfor
+                 }
+
+                 local_barrier();
+             %endif
+
+             pycl_printf(("after input_fetch_exprs\n"));
+
+             // }}}
+
+             // {{{ read a unit's worth of data from global
+
+             for(index_type k = 0; k < K; k++)
+             {
+                 const index_type offset = k*WG_SIZE + LID_0;
+                 const index_type read_i = unit_base + offset;
+
+                 %if is_tail:
+                 if (read_i < interval_end)
+                 %endif
+                 {
+                     %for name, arg_name, ife_offset in input_fetch_exprs:
+                         ${arg_ctypes[arg_name]} ${name};
+
+                         %if arg_name in local_fetch_expr_args:
+                             if (offset + ${ife_offset} >= 0)
+                                 ${name} = l_${arg_name}[offset + ${ife_offset}];
+                             else if (read_i + ${ife_offset} >= 0)
+                                 ${name} = ${arg_name}[read_i + ${ife_offset}];
+                             /*
+                             else
+                                 if out of bounds, name is left undefined */
+
+                         %else:
+                             // ${arg_name} gets fetched directly from global
+                             ${name} = ${arg_name}[read_i];
+
+                         %endif
+                     %endfor
+
+                     scan_type scan_value = INPUT_EXPR(read_i);
+
+                     const index_type o_mod_k = offset % K;
+                     const index_type o_div_k = offset / K;
+                     ldata[o_mod_k][o_div_k].value = scan_value;
+
+                     %if is_segmented:
+                         bool is_seg_start = IS_SEG_START(read_i, scan_value);
+                         l_segment_start_flags[o_mod_k][o_div_k] = is_seg_start;
+                     %endif
+                     %if store_segment_start_flags:
+                         g_segment_start_flags[read_i] = is_seg_start;
+                     %endif
+                 }
+             }
+
+             pycl_printf(("after read from global\n"));
+
+             // }}}
+
+             // {{{ carry in from previous unit, if applicable
+
+             %if is_segmented:
+                 local_barrier();
+
+                 first_segment_start_in_k_group = NO_SEG_BOUNDARY;
+                 if (l_segment_start_flags[0][LID_0])
+                     first_segment_start_in_k_group = unit_base + K*LID_0;
+             %endif
+
+             if (LID_0 == 0 && unit_base != interval_begin)
+             {
+                 scan_type tmp = ldata[K][WG_SIZE - 1].value;
+                 scan_type tmp_aux = ldata[0][0].value;
+
+                 ldata[0][0].value = SCAN_EXPR(
+                     tmp, tmp_aux,
+                     %if is_segmented:
+                         (l_segment_start_flags[0][0])
+                     %else:
+                         false
+                     %endif
+                     );
+             }
+
+             pycl_printf(("after carry-in\n"));
+
+             // }}}
+
+             local_barrier();
+
+             // {{{ scan along k (sequentially in each work item)
+
+             scan_type sum = ldata[0][LID_0].value;
+
+             %if is_tail:
+                 const index_type offset_end = interval_end - unit_base;
+             %endif
+
+             for (index_type k = 1; k < K; k++)
+             {
+                 %if is_tail:
+                 if ((index_type) (K * LID_0 + k) < offset_end)
+                 %endif
+                 {
+                     scan_type tmp = ldata[k][LID_0].value;
+
+                     %if is_segmented:
+                         index_type seq_i = unit_base + K*LID_0 + k;
+
+                         if (l_segment_start_flags[k][LID_0])
+                         {
+                             first_segment_start_in_k_group = min(
+                                 first_segment_start_in_k_group,
+                                 seq_i);
+                         }
+                     %endif
+
+                     sum = SCAN_EXPR(sum, tmp,
+                         %if is_segmented:
+                             (l_segment_start_flags[k][LID_0])
+                         %else:
+                             false
+                         %endif
+                         );
+
+                     ldata[k][LID_0].value = sum;
+                 }
+             }
+
+             pycl_printf(("after scan along k\n"));
+
+             // }}}
+
+             // store carry in out-of-bounds (padding) array entry (index K) in
+             // the K direction
+             ldata[K][LID_0].value = sum;
+
+             %if is_segmented:
+                 l_first_segment_start_in_subtree[LID_0] =
+                     first_segment_start_in_k_group;
+             %endif
+
+             local_barrier();
+
+             // {{{ tree-based local parallel scan
+
+             // This tree-based scan works as follows:
+             // - Each work item adds the previous item to its current state
+             // - barrier
+             // - Each work item adds in the item from two positions to the left
+             // - barrier
+             // - Each work item adds in the item from four positions to the left
+             // ...
+             // At the end, each item has summed all prior items.
+
+             // across k groups, along local id
+             // (uses out-of-bounds k=K array entry for storage)
+
+             scan_type val = ldata[K][LID_0].value;
+
+             <% scan_offset = 1 %>
+
+             % while scan_offset <= wg_size:
+                 // {{{ reads from local allowed, writes to local not allowed
+
+                 if (LID_0 >= ${scan_offset})
+                 {
+                     scan_type tmp = ldata[K][LID_0 - ${scan_offset}].value;
+                     % if is_tail:
+                     if (K*LID_0 < offset_end)
+                     % endif
+                     {
+                         val = SCAN_EXPR(tmp, val,
+                             %if is_segmented:
+                                 (l_first_segment_start_in_subtree[LID_0]
+                                     != NO_SEG_BOUNDARY)
+                             %else:
+                                 false
+                             %endif
+                             );
+                     }
+
+                     %if is_segmented:
+                         // Prepare for l_first_segment_start_in_subtree, below.
+
+                         // Note that this update must take place *even* if we're
+                         // out of bounds.
+
+                         first_segment_start_in_subtree = min(
+                             l_first_segment_start_in_subtree[LID_0],
+                             l_first_segment_start_in_subtree
+                                 [LID_0 - ${scan_offset}]);
+                     %endif
+                 }
+                 %if is_segmented:
+                 else
+                 {
+                     first_segment_start_in_subtree =
+                         l_first_segment_start_in_subtree[LID_0];
+                 }
+                 %endif
+
+                 // }}}
+
+                 local_barrier();
+
+                 // {{{ writes to local allowed, reads from local not allowed
+
+                 ldata[K][LID_0].value = val;
+                 %if is_segmented:
+                     l_first_segment_start_in_subtree[LID_0] =
+                         first_segment_start_in_subtree;
+                 %endif
+
+                 // }}}
+
+                 local_barrier();
+
+                 %if 0:
+                 if (LID_0 == 0)
+                 {
+                     printf("${scan_offset}: ");
+                     for (int i = 0; i < WG_SIZE; ++i)
+                     {
+                         if (l_first_segment_start_in_subtree[i] == NO_SEG_BOUNDARY)
+                             printf("- ");
+                         else
+                             printf("%d ", l_first_segment_start_in_subtree[i]);
+                     }
+                     printf("\n");
+                 }
+                 %endif
+
+                 <% scan_offset *= 2 %>
+             % endwhile
+
+             pycl_printf(("after tree scan\n"));
+
+             // }}}
+
+             // {{{ update local values
+
+             if (LID_0 > 0)
+             {
+                 sum = ldata[K][LID_0 - 1].value;
+
+                 for(index_type k = 0; k < K; k++)
+                 {
+                     %if is_tail:
+                     if (K * LID_0 + k < offset_end)
+                     %endif
+                     {
+                         scan_type tmp = ldata[k][LID_0].value;
+                         ldata[k][LID_0].value = SCAN_EXPR(sum, tmp,
+                             %if is_segmented:
+                                 (unit_base + K * LID_0 + k
+                                     >= first_segment_start_in_k_group)
+                             %else:
+                                 false
+                             %endif
+                             );
+                     }
+                 }
+             }
+
+             %if is_segmented:
+                 if (LID_0 == 0)
+                 {
+                     // update interval-wide first-seg variable from current unit
+                     first_segment_start_in_interval = min(
+                         first_segment_start_in_interval,
+                         l_first_segment_start_in_subtree[WG_SIZE-1]);
+                 }
+             %endif
+
+             pycl_printf(("after local update\n"));
+
+             // }}}
+
+             local_barrier();
+
+             // {{{ write data
+
+             %if is_gpu:
+             {
+                 // work hard with index math to achieve contiguous 32-bit stores
+                 __global int *dest =
+                     (__global int *) (partial_scan_buffer + unit_base);
+
+                 <%
+
+                 assert scan_dtype.itemsize % 4 == 0
+
+                 ints_per_wg = wg_size
+                 ints_to_store = scan_dtype.itemsize*wg_size*k_group_size // 4
+
+                 %>
+
+                 const index_type scan_types_per_int = ${scan_dtype.itemsize//4};
+
+                 %for store_base in range(0, ints_to_store, ints_per_wg):
+                     <%
+
+                     # Observe that ints_to_store is divisible by the work group
+                     # size already, so we won't go out of bounds that way.
+                     assert store_base + ints_per_wg <= ints_to_store
+
+                     %>
+
+                     %if is_tail:
+                     if (${store_base} + LID_0 <
+                         scan_types_per_int*(interval_end - unit_base))
+                     %endif
+                     {
+                         index_type linear_index = ${store_base} + LID_0;
+                         index_type linear_scan_data_idx =
+                             linear_index / scan_types_per_int;
+                         index_type remainder =
+                             linear_index - linear_scan_data_idx * scan_types_per_int;
+
+                         __local int *src = (__local int *) &(
+                             ldata
+                             [linear_scan_data_idx % K]
+                             [linear_scan_data_idx / K].value);
+
+                         dest[linear_index] = src[remainder];
+                     }
+                 %endfor
+             }
+             %else:
+             for (index_type k = 0; k < K; k++)
+             {
+                 const index_type offset = k*WG_SIZE + LID_0;
+
+                 %if is_tail:
+                 if (unit_base + offset < interval_end)
+                 %endif
+                 {
+                     pycl_printf(("write: %d\n", unit_base + offset));
+                     partial_scan_buffer[unit_base + offset] =
+                         ldata[offset % K][offset / K].value;
+                 }
+             }
+             %endif
+
+             pycl_printf(("after write\n"));
+
+             // }}}
+
+             local_barrier();
+         }
+
+     % endfor
+
+     // write interval sum
+     %if is_first_level:
+         if (LID_0 == 0)
+         {
+             interval_results[GID_0] = partial_scan_buffer[interval_end - 1];
+             %if is_segmented:
+                 g_first_segment_start_in_interval[GID_0] =
+                     first_segment_start_in_interval;
+             %endif
+         }
+     %endif
+ }
+ """
+
+ # }}}
+
+ # {{{ update
+
+ UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL//
+
+ KERNEL
+ REQD_WG_SIZE(WG_SIZE, 1, 1)
+ void ${name_prefix}_final_update(
+     ${argument_signature},
+     const index_type N,
+     const index_type interval_size,
+     GLOBAL_MEM scan_type *restrict interval_results,
+     GLOBAL_MEM scan_type *restrict partial_scan_buffer
+     %if is_segmented:
+         , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
+     %endif
+     %if is_segmented and use_lookbehind_update:
+         , GLOBAL_MEM char *restrict g_segment_start_flags
+     %endif
+     )
+ {
+     ${arg_offset_adjustment}
+
+     %if use_lookbehind_update:
+         LOCAL_MEM scan_type ldata[WG_SIZE];
+     %endif
+     %if is_segmented and use_lookbehind_update:
+         LOCAL_MEM char l_segment_start_flags[WG_SIZE];
+     %endif
+
+     const index_type interval_begin = interval_size * GID_0;
+     const index_type interval_end = min(interval_begin + interval_size, N);
+
+     // carry from last interval
+     scan_type carry = ${neutral};
+     if (GID_0 != 0)
+         carry = interval_results[GID_0 - 1];
+
+     %if is_segmented:
+         const index_type first_seg_start_in_interval =
+             g_first_segment_start_in_interval[GID_0];
+     %endif
+
+     %if not is_segmented and 'last_item' in output_statement:
+         scan_type last_item = interval_results[GDIM_0-1];
+     %endif
+
+     %if not use_lookbehind_update:
+         // {{{ no look-behind ('prev_item' not in output_statement -> simpler)
+
+         index_type update_i = interval_begin+LID_0;
+
+         %if is_segmented:
+             index_type seg_end = min(first_seg_start_in_interval, interval_end);
+         %endif
+
+         for(; update_i < interval_end; update_i += WG_SIZE)
+         {
+             scan_type partial_val = partial_scan_buffer[update_i];
+             scan_type item = SCAN_EXPR(carry, partial_val,
+                 %if is_segmented:
+                     (update_i >= seg_end)
+                 %else:
+                     false
+                 %endif
+                 );
+             index_type i = update_i;
+
+             { ${output_statement}; }
+         }
+
+         // }}}
+     %else:
+         // {{{ allow look-behind ('prev_item' in output_statement -> complicated)
+
+         // We are not allowed to branch across barriers at a granularity smaller
+         // than the whole workgroup. Therefore, the for loop is group-global,
+         // and there are lots of local ifs.
+
+         index_type group_base = interval_begin;
+         scan_type prev_item = carry; // (A)
+
+         for(; group_base < interval_end; group_base += WG_SIZE)
+         {
+             index_type update_i = group_base+LID_0;
+
+             // load a work group's worth of data
+             if (update_i < interval_end)
+             {
+                 scan_type tmp = partial_scan_buffer[update_i];
+
+                 tmp = SCAN_EXPR(carry, tmp,
+                     %if is_segmented:
+                         (update_i >= first_seg_start_in_interval)
+                     %else:
+                         false
+                     %endif
+                     );
+
+                 ldata[LID_0] = tmp;
+
+                 %if is_segmented:
+                     l_segment_start_flags[LID_0] = g_segment_start_flags[update_i];
+                 %endif
+             }
+
+             local_barrier();
+
+             // find prev_item
+             if (LID_0 != 0)
+                 prev_item = ldata[LID_0 - 1];
+             /*
+             else
+                 prev_item = carry (see (A)) OR last tail (see (B));
+             */
+
+             if (update_i < interval_end)
+             {
+                 %if is_segmented:
+                     if (l_segment_start_flags[LID_0])
+                         prev_item = ${neutral};
+                 %endif
+
+                 scan_type item = ldata[LID_0];
+                 index_type i = update_i;
+                 { ${output_statement}; }
+             }
+
+             if (LID_0 == 0)
+                 prev_item = ldata[WG_SIZE - 1]; // (B)
+
+             local_barrier();
+         }
+
+         // }}}
+     %endif
+ }
+ """
+
+ # }}}
+
+
+ # {{{ driver
+
+ # {{{ helpers
+
+ def _round_down_to_power_of_2(val: int) -> int:
+     result = 2**bitlog2(val)
+     if result > val:
+         result >>= 1
+
+     assert result <= val
+     return result
+
+
+ _PREFIX_WORDS = set("""
+         ldata partial_scan_buffer global scan_offset
+         segment_start_in_k_group carry
+         g_first_segment_start_in_interval IS_SEG_START tmp Z
+         val l_first_segment_start_in_subtree unit_size
+         index_type interval_begin interval_size offset_end K
+         SCAN_EXPR do_update WG_SIZE
+         first_segment_start_in_k_group scan_type
+         segment_start_in_subtree offset interval_results interval_end
+         first_segment_start_in_subtree unit_base
+         first_segment_start_in_interval k INPUT_EXPR
+         prev_group_sum prev pv value partial_val pgs
+         is_seg_start update_i scan_item_at_i seq_i read_i
+         l_ o_mod_k o_div_k l_segment_start_flags scan_value sum
+         first_seg_start_in_interval g_segment_start_flags
+         group_base seg_end my_val DEBUG ARGS
+         ints_to_store ints_per_wg scan_types_per_int linear_index
+         linear_scan_data_idx dest src store_base wrapped_scan_type
+         dummy scan_tmp tmp_aux
+
+         LID_2 LID_1 LID_0
+         LDIM_0 LDIM_1 LDIM_2
+         GDIM_0 GDIM_1 GDIM_2
+         GID_0 GID_1 GID_2
+         """.split())
+
+ _IGNORED_WORDS = set("""
+         4 8 32
+
+         typedef for endfor if void while endwhile endfor endif else const printf
+         None return bool n char true false ifdef pycl_printf str range assert
+         np iinfo max itemsize __packed__ struct restrict ptrdiff_t
+
+         set iteritems len setdefault
+
+         GLOBAL_MEM LOCAL_MEM_ARG WITHIN_KERNEL LOCAL_MEM KERNEL REQD_WG_SIZE
+         local_barrier
+         CLK_LOCAL_MEM_FENCE OPENCL EXTENSION
+         pragma __attribute__ __global __kernel __local
+         get_local_size get_local_id cl_khr_fp64 reqd_work_group_size
+         get_num_groups barrier get_group_id
+         CL_VERSION_1_1 __OPENCL_C_VERSION__ 120
+
+         _final_update _debug_scan kernel_name
+
+         positions all padded integer its previous write based writes 0
+         has local worth scan_expr to read cannot not X items False bank
+         four beginning follows applicable item min each indices works side
+         scanning right summed relative used id out index avoid current state
+         boundary True across be This reads groups along Otherwise undetermined
+         store of times prior s update first regardless Each number because
+         array unit from segment conflicts two parallel 2 empty define direction
+         CL padding work tree bounds values and adds
+         scan is allowed thus it an as enable at in occur sequentially end no
+         storage data 1 largest may representable uses entry Y meaningful
+         computations interval At the left dimension know d
+         A load B group perform shift tail see last OR
+         this add fetched into are directly need
+         gets them stenciled that undefined
+         there up any ones or name only relevant populated
+         even wide we Prepare int seg Note re below place take variable must
+         intra Therefore find code assumption
+         branch workgroup complicated granularity phase remainder than simpler
+         We smaller look ifs lots self behind allow barriers whole loop
+         after already Observe achieve contiguous stores hard go with by math
+         size won t way divisible bit so Avoid declare adding single type
+
+         is_tail is_first_level input_expr argument_signature preamble
+         double_support neutral output_statement
+         k_group_size name_prefix is_segmented index_dtype scan_dtype
+         wg_size is_segment_start_expr fetch_expr_offsets
+         arg_ctypes ife_offsets input_fetch_exprs def
+         ife_offset arg_name local_fetch_expr_args update_body
+         update_loop_lookbehind update_loop_plain update_loop
+         use_lookbehind_update store_segment_start_flags
+         update_loop first_seg scan_dtype dtype_to_ctype
+         is_gpu use_bank_conflict_avoidance
+
+         a b prev_item i last_item prev_value
+         N NO_SEG_BOUNDARY across_seg_boundary
+
+         arg_offset_adjustment
+         """.split())
+
+
+ def _make_template(s: str):
+     import re
+     leftovers = set()
+
+     def replace_id(match: re.Match) -> str:
+         # avoid name clashes with user code by adding 'psc_' prefix to
+         # identifiers.
+
+         word = match.group(1)
+         if word in _IGNORED_WORDS:
+             return word
+         elif word in _PREFIX_WORDS:
+             return f"psc_{word}"
+         else:
+             leftovers.add(word)
+             return word
+
+     s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s)
+     if leftovers:
+         from warnings import warn
+         warn("Leftover words in identifier prefixing: " + " ".join(leftovers),
+              stacklevel=3)
+
+     return mako.template.Template(s, strict_undefined=True)  # type: ignore
+
+
+ @dataclass(frozen=True)
+ class _GeneratedScanKernelInfo:
+     scan_src: str
+     kernel_name: str
+     scalar_arg_dtypes: list[np.dtype | None]
+     wg_size: int
+     k_group_size: int
+
+     def build(self, context: cl.Context, options: Any) -> _BuiltScanKernelInfo:
+         program = cl.Program(context, self.scan_src).build(options)
+         kernel = getattr(program, self.kernel_name)
+         kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
+         return _BuiltScanKernelInfo(
+             kernel=kernel,
+             wg_size=self.wg_size,
+             k_group_size=self.k_group_size)
+
+
+ @dataclass(frozen=True)
+ class _BuiltScanKernelInfo:
+     kernel: cl.Kernel
+     wg_size: int
+     k_group_size: int
+
+
+ @dataclass(frozen=True)
+ class _GeneratedFinalUpdateKernelInfo:
+     source: str
+     kernel_name: str
+     scalar_arg_dtypes: list[np.dtype | None]
+     update_wg_size: int
+
+     def build(self,
+               context: cl.Context,
+               options: Any) -> _BuiltFinalUpdateKernelInfo:
+         program = cl.Program(context, self.source).build(options)
+         kernel = getattr(program, self.kernel_name)
+         kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
+         return _BuiltFinalUpdateKernelInfo(kernel, self.update_wg_size)
+
+
+ @dataclass(frozen=True)
+ class _BuiltFinalUpdateKernelInfo:
+     kernel: cl.Kernel
+     update_wg_size: int
+
+ # }}}
+
+
+ class ScanPerformanceWarning(UserWarning):
+     pass
+
+
+ class GenericScanKernelBase(ABC):
+     # {{{ constructor, argument processing
+
+     def __init__(
+             self,
+             ctx: cl.Context,
+             dtype: Any,
+             arguments: str | list[DtypedArgument],
+             input_expr: str,
+             scan_expr: str,
+             neutral: str | None,
+             output_statement: str,
+             is_segment_start_expr: str | None = None,
+             input_fetch_exprs: list[tuple[str, str, int]] | None = None,
+             index_dtype: Any = None,
+             name_prefix: str = "scan",
+             options: Any = None,
+             preamble: str = "",
+             devices: cl.Device | None = None) -> None:
+         """
+         :arg ctx: a :class:`pyopencl.Context` within which the code
+             for this scan kernel will be generated.
+         :arg dtype: the :class:`numpy.dtype` with which the scan will
+             be performed. May be a structured type if that type was registered
+             through :func:`pyopencl.tools.get_or_register_dtype`.
+         :arg arguments: A string of comma-separated C argument declarations.
+             If *arguments* is specified, then *input_expr* must also be
+             specified. All types used here must be known to PyOpenCL.
+             (see :func:`pyopencl.tools.get_or_register_dtype`).
+         :arg scan_expr: The associative, binary operation carrying out the scan,
+             represented as a C string. Its two arguments are available as ``a``
+             and ``b`` when it is evaluated. ``b`` is guaranteed to be the
+             'element being updated', and ``a`` is the increment. Thus,
+             if some data is supposed to just propagate along without being
+             modified by the scan, it should live in ``b``.
+
+             This expression may call functions given in the *preamble*.
+
+             Another value available to this expression is ``across_seg_boundary``,
+             a C ``bool`` indicating whether this scan update is crossing a
+             segment boundary, as defined by ``is_segment_start_expr``.
+             The scan routine does not implement segmentation
+             semantics on its own. It relies on ``scan_expr`` to do this.
+             This value is available (but always ``false``) even for a
+             non-segmented scan.
+
+             .. note::
+
+                 In early pre-releases of the segmented scan,
+                 segmentation semantics were implemented *without*
+                 relying on ``scan_expr``.
+
+         :arg input_expr: A C expression, encoded as a string, resulting
+             in the values to which the scan is applied. This may be used
+             to apply a mapping to values stored in *arguments* before being
+             scanned. The result of this expression must match *dtype*.
+             The index intended to be mapped is available as ``i`` in this
+             expression. This expression may also use the variables defined
+             by *input_fetch_expr*.
+
+             This expression may also call functions given in the *preamble*.
+         :arg output_statement: a C statement that writes
+             the output of the scan. It has access to the scan result as ``item``,
+             the preceding scan result item as ``prev_item``, and the current index
+             as ``i``. ``prev_item`` in a segmented scan will be the neutral element
+             at a segment boundary, not the immediately preceding item.
+
+             Using *prev_item* in the output statement has a small run-time cost.
+             ``prev_item`` enables the construction of an exclusive scan.
+
+             For non-segmented scans, *output_statement* may also reference
+             ``last_item``, which evaluates to the scan result of the last
+             array entry.
+         :arg is_segment_start_expr: A C expression, encoded as a string,
+             resulting in a C ``bool`` value that determines whether a new
+             scan segment starts at index *i*. If given, makes the scan a
+             segmented scan. Has access to the current index ``i``, the result
+             of *input_expr* as ``a``, and in addition may use *arguments* and
+             *input_fetch_expr* variables just like *input_expr*.
+
+             If it returns true, then previous sums will not spill over into the
+             item with index *i* or subsequent items.
+         :arg input_fetch_exprs: a list of tuples *(NAME, ARG_NAME, OFFSET)*.
+             An entry here has the effect of doing the equivalent of the following
+             before input_expr::
+
+                 ARG_NAME_TYPE NAME = ARG_NAME[i+OFFSET];
+
+             ``OFFSET`` is allowed to be 0 or -1, and ``ARG_NAME_TYPE`` is the type
+             of ``ARG_NAME``.
+         :arg preamble: |preamble|
+
+         The first array in the argument list determines the size of the index
+         space over which the scan is carried out, and thus the values over
+         which the index *i* occurring in a number of code fragments in
+         arguments above will vary.
+
+         All code fragments further have access to N, the number of elements
+         being processed in the scan.
+         """
+
+         if index_dtype is None:
+             index_dtype = np.dtype(np.int32)
+
+         if input_fetch_exprs is None:
+             input_fetch_exprs = []
+
+         self.context = ctx
+         dtype = self.dtype = np.dtype(dtype)
+
+         if neutral is None:
+             from warnings import warn
+             warn("not specifying 'neutral' is deprecated and will lead to "
+                  "wrong results if your scan is not in-place or your "
+                  "'output_statement' does something otherwise non-trivial",
+                  stacklevel=2)
+
+         if dtype.itemsize % 4 != 0:
+             raise TypeError("scan value type must have size divisible by 4 bytes")
+
+         self.index_dtype = np.dtype(index_dtype)
+         if np.iinfo(self.index_dtype).min >= 0:
+             raise TypeError("index_dtype must be signed")
+
+         if devices is None:
+             devices = ctx.devices
+         self.devices = devices
+         self.options = options
+
+         from pyopencl.tools import parse_arg_list
+         self.parsed_args = parse_arg_list(arguments)
+         from pyopencl.tools import VectorArg
+         self.first_array_idx = next(
+             i for i, arg in enumerate(self.parsed_args)
+             if isinstance(arg, VectorArg))
+
+         self.input_expr = input_expr
+
+         self.is_segment_start_expr = is_segment_start_expr
+         self.is_segmented = is_segment_start_expr is not None
+         if self.is_segmented:
+             is_segment_start_expr = _process_code_for_macro(is_segment_start_expr)
+
+         self.output_statement = output_statement
+
+         for _name, _arg_name, ife_offset in input_fetch_exprs:
+             if ife_offset not in [0, -1]:
+                 raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
+         self.input_fetch_exprs = input_fetch_exprs
+
+         arg_dtypes = {}
+         arg_ctypes = {}
+         for arg in self.parsed_args:
+             arg_dtypes[arg.name] = arg.dtype
+             arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)
+
+         self.name_prefix = name_prefix
+
+         # {{{ set up shared code dict
+
+         from pyopencl.characterize import has_double_support
+
+         self.code_variables = {
+             "np": np,
+             "dtype_to_ctype": dtype_to_ctype,
+             "preamble": preamble,
+             "name_prefix": name_prefix,
+             "index_dtype": self.index_dtype,
+             "scan_dtype": dtype,
+             "is_segmented": self.is_segmented,
+             "arg_dtypes": arg_dtypes,
+             "arg_ctypes": arg_ctypes,
+             "scan_expr": _process_code_for_macro(scan_expr),
+             "neutral": _process_code_for_macro(neutral),
+             "is_gpu": bool(self.devices[0].type & cl.device_type.GPU),
+             "double_support": all(
+                 has_double_support(dev) for dev in devices),
+             }
+
+         index_typename = dtype_to_ctype(self.index_dtype)
+         scan_typename = dtype_to_ctype(dtype)
+
+         # This key is meant to uniquely identify the non-device parameters for
+         # the scan kernel.
+         self.kernel_key = (
+             self.dtype,
+             tuple(arg.declarator() for arg in self.parsed_args),
+             self.input_expr,
+             scan_expr,
+             neutral,
+             output_statement,
+             is_segment_start_expr,
+             tuple(input_fetch_exprs),
+             index_dtype,
+             name_prefix,
+             preamble,
+             # These depend on dtype_to_ctype(), so their value is independent of
+             # the other variables.
+             index_typename,
+             scan_typename,
+             )
+
+         # }}}
+
+         self.use_lookbehind_update = "prev_item" in self.output_statement
+         self.store_segment_start_flags = (
+             self.is_segmented and self.use_lookbehind_update)
+
+         self.finish_setup()
+
+     # }}}
+
+     @abstractmethod
+     def finish_setup(self) -> None:
+         pass
+
+
+ if not cl._PYOPENCL_NO_CACHE:
+     generic_scan_kernel_cache: WriteOncePersistentDict[Any,
+             tuple[_GeneratedScanKernelInfo, _GeneratedScanKernelInfo,
+                 _GeneratedFinalUpdateKernelInfo]] = \
+                     WriteOncePersistentDict(
+                         "pyopencl-generated-scan-kernel-cache-v1",
+                         key_builder=_NumpyTypesKeyBuilder(),
+                         in_mem_cache_size=0,
+                         safe_sync=False)
+
+
+ class GenericScanKernel(GenericScanKernelBase):
+     """Generates and executes code that performs prefix sums ("scans") on
+     arbitrary types, with many possible tweaks.
+
+     Usage example::
+
+         from pyopencl.scan import GenericScanKernel
+         knl = GenericScanKernel(
+             context, np.int32,
+             arguments="__global int *ary",
+             input_expr="ary[i]",
+             scan_expr="a+b", neutral="0",
+             output_statement="ary[i+1] = item;")
+
+         a = cl.array.arange(queue, 10000, dtype=np.int32)
+         knl(a, queue=queue)
+
+     .. automethod:: __init__
+     .. automethod:: __call__
+     """
+
+     def finish_setup(self) -> None:
+         # Before generating the kernel, see if it's cached.
+         from pyopencl.cache import get_device_cache_id
+         devices_key = tuple(get_device_cache_id(device)
+                             for device in self.devices)
+
+         cache_key = (self.kernel_key, devices_key)
+         from_cache = False
+
+         if not cl._PYOPENCL_NO_CACHE:
+             try:
+                 result = generic_scan_kernel_cache[cache_key]
+                 from_cache = True
+                 logger.debug(
+                     "cache hit for generated scan kernel '%s'", self.name_prefix)
+                 (
+                     self.first_level_scan_gen_info,
+                     self.second_level_scan_gen_info,
+                     self.final_update_gen_info) = result
+             except KeyError:
+                 pass
+
+         if not from_cache:
+             logger.debug(
+                 "cache miss for generated scan kernel '%s'", self.name_prefix)
+             self._finish_setup_impl()
+
+             result = (self.first_level_scan_gen_info,
+                       self.second_level_scan_gen_info,
+                       self.final_update_gen_info)
+
+             if not cl._PYOPENCL_NO_CACHE:
+                 generic_scan_kernel_cache.store_if_not_present(cache_key, result)
+
+         # Build the kernels.
+         self.first_level_scan_info = self.first_level_scan_gen_info.build(
+             self.context, self.options)
+         del self.first_level_scan_gen_info
+
+         self.second_level_scan_info = self.second_level_scan_gen_info.build(
+             self.context, self.options)
+         del self.second_level_scan_gen_info
+
+         self.final_update_info = self.final_update_gen_info.build(
+             self.context, self.options)
+         del self.final_update_gen_info
+
+     def _finish_setup_impl(self) -> None:
+         # {{{ find usable workgroup/k-group size, build first-level scan
+
+         trip_count = 0
+
+         avail_local_mem = min(
+             dev.local_mem_size
+             for dev in self.devices)
+
+         if "CUDA" in self.devices[0].platform.name:
+             # not sure where these go, but roughly this much seems unavailable.
+             avail_local_mem -= 0x400
+
+         is_cpu = self.devices[0].type & cl.device_type.CPU
+         is_gpu = self.devices[0].type & cl.device_type.GPU
+
+         if is_cpu:
+             # (about the widest vector a CPU can support, also taking
+             # into account that CPUs don't hide latency by large work groups)
+             max_scan_wg_size = 16
+             wg_size_multiples = 4
+         else:
+             max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
+             wg_size_multiples = 64
+
+         # Intel beignet fails "Out of shared local memory" in test_scan int64
+         # and asserts in test_sort with this enabled:
+         # https://github.com/inducer/pyopencl/pull/238
+         # A beignet bug report (outside of pyopencl) suggests packed structs
+         # (which this is) can even give wrong results:
+         # https://bugs.freedesktop.org/show_bug.cgi?id=98717
+         # TODO: does this also affect Intel Compute Runtime?
+         use_bank_conflict_avoidance = (
+             self.dtype.itemsize > 4 and self.dtype.itemsize % 8 == 0
+             and is_gpu
+             and "beignet" not in self.devices[0].platform.version.lower())
+
+         # k_group_size should be a power of two because of in-kernel
+         # division by that number.
+
+         solutions = []
+         for k_exp in range(0, 9):
+             for wg_size in range(wg_size_multiples, max_scan_wg_size+1,
+                                  wg_size_multiples):
+
+                 k_group_size = 2**k_exp
+                 lmem_use = self.get_local_mem_use(wg_size, k_group_size,
+                                                   use_bank_conflict_avoidance)
+                 if lmem_use <= avail_local_mem:
+                     solutions.append((wg_size*k_group_size, k_group_size, wg_size))
+
+         if is_gpu:
+             for wg_size_floor in [256, 192, 128]:
+                 have_sol_above_floor = any(wg_size >= wg_size_floor
+                                            for _, _, wg_size in solutions)
+
+                 if have_sol_above_floor:
+                     # delete all solutions not meeting the wg size floor
+                     solutions = [(total, try_k_group_size, try_wg_size)
+                                  for total, try_k_group_size, try_wg_size in solutions
+                                  if try_wg_size >= wg_size_floor]
+                     break
+
+         _, k_group_size, max_scan_wg_size = max(solutions)
+
+         while True:
+             candidate_scan_gen_info = self.generate_scan_kernel(
+                 max_scan_wg_size, self.parsed_args,
+                 _process_code_for_macro(self.input_expr),
+                 self.is_segment_start_expr,
+                 input_fetch_exprs=self.input_fetch_exprs,
+                 is_first_level=True,
+                 store_segment_start_flags=self.store_segment_start_flags,
+                 k_group_size=k_group_size,
+                 use_bank_conflict_avoidance=use_bank_conflict_avoidance)
+
+             candidate_scan_info = candidate_scan_gen_info.build(
+                 self.context, self.options)
+
+             # Will this device actually let us execute this kernel
+             # at the desired work group size? Building it is the
+             # only way to find out.
+             kernel_max_wg_size = min(
+                 candidate_scan_info.kernel.get_work_group_info(
+                     cl.kernel_work_group_info.WORK_GROUP_SIZE,
+                     dev)
+                 for dev in self.devices)
+
+             if candidate_scan_info.wg_size <= kernel_max_wg_size:
+                 break
+             else:
+                 max_scan_wg_size = min(kernel_max_wg_size, max_scan_wg_size)
+
+             trip_count += 1
+             assert trip_count <= 20
+
+         self.first_level_scan_gen_info = candidate_scan_gen_info
+         assert (_round_down_to_power_of_2(candidate_scan_info.wg_size)
+                 == candidate_scan_info.wg_size)
+
+         # }}}
+
+         # {{{ build second-level scan
+
+         from pyopencl.tools import VectorArg
+         second_level_arguments = [
+             *self.parsed_args,
+             VectorArg(self.dtype, "interval_sums"),
+             ]
+
+         second_level_build_kwargs: dict[str, str | None] = {}
+         if self.is_segmented:
+             second_level_arguments.append(
+                 VectorArg(self.index_dtype,
+                           "g_first_segment_start_in_interval_input"))
+
+             # is_segment_start_expr answers the question "should previous sums
+             # spill over into this item". And since
+             # g_first_segment_start_in_interval_input answers the question if a
+             # segment boundary was found in an interval of data, then if not,
+             # it's ok to spill over.
+             second_level_build_kwargs["is_segment_start_expr"] = \
+                 "g_first_segment_start_in_interval_input[i] != NO_SEG_BOUNDARY"
+         else:
+             second_level_build_kwargs["is_segment_start_expr"] = None
+
+         self.second_level_scan_gen_info = self.generate_scan_kernel(
+             max_scan_wg_size,
+             arguments=second_level_arguments,
+             input_expr="interval_sums[i]",
+             input_fetch_exprs=[],
+             is_first_level=False,
+             store_segment_start_flags=False,
+             k_group_size=k_group_size,
+             use_bank_conflict_avoidance=use_bank_conflict_avoidance,
+             **second_level_build_kwargs)
+
+         # }}}
+
+         # {{{ generate final update kernel
+
+         update_wg_size = min(max_scan_wg_size, 256)
+
+         final_update_tpl = _make_template(UPDATE_SOURCE)
+         final_update_src = str(final_update_tpl.render(
+             wg_size=update_wg_size,
+             output_statement=self.output_statement,
+             arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
+             argument_signature=", ".join(
+                 arg.declarator() for arg in self.parsed_args),
+             is_segment_start_expr=self.is_segment_start_expr,
+             input_expr=_process_code_for_macro(self.input_expr),
+             use_lookbehind_update=self.use_lookbehind_update,
+             **self.code_variables))
+
+         update_scalar_arg_dtypes = [
+             *get_arg_list_scalar_arg_dtypes(self.parsed_args),
+             self.index_dtype, self.index_dtype, None, None]
+
+         if self.is_segmented:
+             # g_first_segment_start_in_interval
+             update_scalar_arg_dtypes.append(None)
+         if self.store_segment_start_flags:
+             update_scalar_arg_dtypes.append(None)  # g_segment_start_flags
+
+         self.final_update_gen_info = _GeneratedFinalUpdateKernelInfo(
+             final_update_src,
+             self.name_prefix + "_final_update",
+             update_scalar_arg_dtypes,
+             update_wg_size)
+
+         # }}}
+
+     # {{{ scan kernel build/properties
+
+     def get_local_mem_use(
+             self, k_group_size: int, wg_size: int,
+             use_bank_conflict_avoidance: bool) -> int:
+         arg_dtypes = {}
+         for arg in self.parsed_args:
+             arg_dtypes[arg.name] = arg.dtype
+
+         fetch_expr_offsets: dict[str, set] = {}
+         for _name, arg_name, ife_offset in self.input_fetch_exprs:
+             fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
+
+         itemsize = self.dtype.itemsize
+         if use_bank_conflict_avoidance:
+             itemsize += 4
+
+         return (
+             # ldata
+             itemsize*(k_group_size+1)*(wg_size+1)
+
+             # l_segment_start_flags
+             + k_group_size*wg_size
+
+             # l_first_segment_start_in_subtree
+             + self.index_dtype.itemsize*wg_size
+
+             + k_group_size*wg_size*sum(
+                 arg_dtypes[arg_name].itemsize
+                 for arg_name, ife_offsets in list(fetch_expr_offsets.items())
+                 if -1 in ife_offsets or len(ife_offsets) > 1))
+
+     def generate_scan_kernel(
+             self,
+             max_wg_size: int,
+             arguments: list[DtypedArgument],
+             input_expr: str,
+             is_segment_start_expr: str | None,
+             input_fetch_exprs: list[tuple[str, str, int]],
+             is_first_level: bool,
+             store_segment_start_flags: bool,
+             k_group_size: int,
+             use_bank_conflict_avoidance: bool) -> _GeneratedScanKernelInfo:
+         scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)
+
+         # Empirically found on Nv hardware: no need to be bigger than this size
+         wg_size = _round_down_to_power_of_2(
+             min(max_wg_size, 256))
+
+         kernel_name = self.code_variables["name_prefix"]
+         if is_first_level:
+             kernel_name += "_lev1"
+         else:
+             kernel_name += "_lev2"
+
+         scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
+         scan_src = str(scan_tpl.render(
+             wg_size=wg_size,
+             input_expr=input_expr,
+             k_group_size=k_group_size,
+             arg_offset_adjustment=get_arg_offset_adjuster_code(arguments),
+             argument_signature=", ".join(arg.declarator() for arg in arguments),
+             is_segment_start_expr=is_segment_start_expr,
+             input_fetch_exprs=input_fetch_exprs,
+             is_first_level=is_first_level,
+             store_segment_start_flags=store_segment_start_flags,
+             use_bank_conflict_avoidance=use_bank_conflict_avoidance,
+             kernel_name=kernel_name,
+             **self.code_variables))
+
+         scalar_arg_dtypes.extend(
+             (None, self.index_dtype, self.index_dtype))
+         if is_first_level:
+             scalar_arg_dtypes.append(None)  # interval_results
+         if self.is_segmented and is_first_level:
+             scalar_arg_dtypes.append(None)  # g_first_segment_start_in_interval
+         if store_segment_start_flags:
+             scalar_arg_dtypes.append(None)  # g_segment_start_flags
+
+         return _GeneratedScanKernelInfo(
+             scan_src=scan_src,
+             kernel_name=kernel_name,
+             scalar_arg_dtypes=scalar_arg_dtypes,
+             wg_size=wg_size,
+             k_group_size=k_group_size)
+
+     # }}}
+
+     def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
+         """
+         |std-enqueue-blurb|
+
+         .. note::
+
+             The returned :class:`pyopencl.Event` corresponds only to part of the
+             execution of the scan. It is not suitable for profiling.
+
+         :arg queue: queue on which to execute the scan. If not given, the
+             queue of the first :class:`pyopencl.array.Array` in *args* is used.
+         :arg allocator: an allocator for the temporary arrays and results. If
+             not given, the allocator of the first :class:`pyopencl.array.Array`
+             in *args* is used.
+         :arg size: specify the length of the scan to be carried out. If not
+             given, this length is inferred from the first argument.
+         :arg wait_for: a :class:`list` of events to wait for.
+         """
+
+         # {{{ argument processing
+
+         allocator = kwargs.get("allocator")
+         queue = kwargs.get("queue")
+         n = kwargs.get("size")
+         wait_for = kwargs.get("wait_for")
+
+         if wait_for is None:
+             wait_for = []
+         else:
+             wait_for = list(wait_for)
+
+         if len(args) != len(self.parsed_args):
+             raise TypeError(
+                 f"expected {len(self.parsed_args)} arguments, got {len(args)}")
+
+         first_array = args[self.first_array_idx]
+         allocator = allocator or first_array.allocator
+         queue = queue or first_array.queue
+
+         if n is None:
+             n, = first_array.shape
+
+         if n == 0:
+             # We're done here. (But pretend to return an event.)
+             return cl.enqueue_marker(queue, wait_for=wait_for)
+
+         data_args = []
+         for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
+             from pyopencl.tools import VectorArg
+             if isinstance(arg_descr, VectorArg):
+                 data_args.append(arg_val.base_data)
+                 if arg_descr.with_offset:
+                     data_args.append(arg_val.offset)
+                 wait_for.extend(arg_val.events)
+             else:
+                 data_args.append(arg_val)
+
+         # }}}
+
+         l1_info = self.first_level_scan_info
+         l2_info = self.second_level_scan_info
+
+         # see CL source above for terminology
+         unit_size = l1_info.wg_size * l1_info.k_group_size
+         max_intervals = 3*max(dev.max_compute_units for dev in self.devices)
+
+         from pytools import uniform_interval_splitting
+         interval_size, num_intervals = uniform_interval_splitting(
+             n, unit_size, max_intervals)
+
+         # {{{ allocate some buffers
+
+         interval_results = cl_array.empty(queue,
+             num_intervals, dtype=self.dtype,
+             allocator=allocator)
+
+         partial_scan_buffer = cl_array.empty(
+             queue, n, dtype=self.dtype,
+             allocator=allocator)
+
+         if self.store_segment_start_flags:
+             segment_start_flags = cl_array.empty(
+                 queue, n, dtype=np.bool_,
+                 allocator=allocator)
+
+         # }}}
+
+         # {{{ first level scan of interval (one interval per block)
+
+         scan1_args = [
+             *data_args,
+             partial_scan_buffer.data, n, interval_size, interval_results.data,
+             ]
+
+         if self.is_segmented:
+             first_segment_start_in_interval = cl_array.empty(queue,
+                 num_intervals, dtype=self.index_dtype,
+                 allocator=allocator)
+             scan1_args.append(first_segment_start_in_interval.data)
+
+         if self.store_segment_start_flags:
+             scan1_args.append(segment_start_flags.data)
+
+         l1_evt = l1_info.kernel(
+             queue, (num_intervals,), (l1_info.wg_size,),
+             *scan1_args, g_times_l=True, wait_for=wait_for)
+
+         # }}}
+
+         # {{{ second level scan of per-interval results
+
+         # can scan at most one interval
+         assert interval_size >= num_intervals
+
+         scan2_args = [
+             *data_args,
+             interval_results.data,  # interval_sums
+             ]
+
+         if self.is_segmented:
+             scan2_args.append(first_segment_start_in_interval.data)
+         scan2_args = [
+             *scan2_args,
+             interval_results.data,  # partial_scan_buffer
+             num_intervals, interval_size]
+
+         l2_evt = l2_info.kernel(
+             queue, (1,), (l1_info.wg_size,),
+             *scan2_args, g_times_l=True, wait_for=[l1_evt])
+
+         # }}}
+
+         # {{{ update intervals with result of interval scan
+
+         upd_args = [
+             *data_args,
+             n, interval_size, interval_results.data, partial_scan_buffer.data]
+         if self.is_segmented:
+             upd_args.append(first_segment_start_in_interval.data)
+         if self.store_segment_start_flags:
+             upd_args.append(segment_start_flags.data)
+
+         return self.final_update_info.kernel(
+             queue, (num_intervals,),
+             (self.final_update_info.update_wg_size,),
+             *upd_args, g_times_l=True, wait_for=[l2_evt])
+
+         # }}}
+
+     # }}}
+
+
+ # {{{ debug kernel
+
+ DEBUG_SCAN_TEMPLATE = SHARED_PREAMBLE + r"""//CL//
+
+ KERNEL
+ REQD_WG_SIZE(1, 1, 1)
+ void ${name_prefix}_debug_scan(
+     __global scan_type *scan_tmp,
+     ${argument_signature},
+     const index_type N)
+ {
+     scan_type current = ${neutral};
+     scan_type prev;
+
+     ${arg_offset_adjustment}
+
+     for (index_type i = 0; i < N; ++i)
+     {
+         %for name, arg_name, ife_offset in input_fetch_exprs:
+             ${arg_ctypes[arg_name]} ${name};
+             %if ife_offset < 0:
+                 if (i+${ife_offset} >= 0)
+                     ${name} = ${arg_name}[i+${ife_offset}];
+             %else:
+                 ${name} = ${arg_name}[i];
+             %endif
+         %endfor
+
+         scan_type my_val = INPUT_EXPR(i);
+
+         prev = current;
+         %if is_segmented:
+             bool is_seg_start = IS_SEG_START(i, my_val);
+         %endif
+
+         current = SCAN_EXPR(prev, my_val,
+             %if is_segmented:
+                 is_seg_start
+             %else:
+                 false
+             %endif
+             );
+         scan_tmp[i] = current;
+     }
+
+     scan_type last_item = scan_tmp[N-1];
+
+     for (index_type i = 0; i < N; ++i)
+     {
+         scan_type item = scan_tmp[i];
+         scan_type prev_item;
+         if (i)
+             prev_item = scan_tmp[i-1];
+         else
+             prev_item = ${neutral};
+
+         {
+             ${output_statement};
+         }
+     }
+ }
+ """
+
+
+ class GenericDebugScanKernel(GenericScanKernelBase):
+     """
+     Performs the same function and has the same interface as
+     :class:`GenericScanKernel`, but uses a dead-simple, sequential scan. Works
+     best on CPU platforms, and helps isolate bugs in scans by removing the
+     potential for issues originating in parallel execution.
+
+     .. automethod:: __call__
+     """
+
+     def finish_setup(self) -> None:
+         scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
+         scan_src = str(scan_tpl.render(
+             output_statement=self.output_statement,
+             arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
+             argument_signature=", ".join(
+                 arg.declarator() for arg in self.parsed_args),
+             is_segment_start_expr=self.is_segment_start_expr,
+             input_expr=_process_code_for_macro(self.input_expr),
+             input_fetch_exprs=self.input_fetch_exprs,
+             wg_size=1,
+             **self.code_variables))
+
+         scan_prg = cl.Program(self.context, scan_src).build(self.options)
+         self.kernel = getattr(scan_prg, f"{self.name_prefix}_debug_scan")
+         scalar_arg_dtypes = [
+             None,
+             *get_arg_list_scalar_arg_dtypes(self.parsed_args),
+             self.index_dtype,
+             ]
+         self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)
+
+     def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
+         """See :meth:`GenericScanKernel.__call__`."""
+
+         # {{{ argument processing
+
+         allocator = kwargs.get("allocator")
+         queue = kwargs.get("queue")
+         n = kwargs.get("size")
+         wait_for = kwargs.get("wait_for")
+
+         if wait_for is None:
+             wait_for = []
+         else:
+             # We'll be modifying it below.
+             wait_for = list(wait_for)
+
+         if len(args) != len(self.parsed_args):
+             raise TypeError(
+                 f"expected {len(self.parsed_args)} arguments, got {len(args)}")
+
+         first_array = args[self.first_array_idx]
+         allocator = allocator or first_array.allocator
+         queue = queue or first_array.queue
+
+         if n is None:
+             n, = first_array.shape
+
+         scan_tmp = cl_array.empty(queue,
+             n, dtype=self.dtype,
+             allocator=allocator)
+
+         data_args = [scan_tmp.data]
+         from pyopencl.tools import VectorArg
+         for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
+             if isinstance(arg_descr, VectorArg):
+                 data_args.append(arg_val.base_data)
+                 if arg_descr.with_offset:
+                     data_args.append(arg_val.offset)
+                 wait_for.extend(arg_val.events)
+             else:
+                 data_args.append(arg_val)
+
+         # }}}
+
+         return self.kernel(queue, (1,), (1,), *([*data_args, n]), wait_for=wait_for)
+
+ # }}}
+
+
+ # {{{ compatibility interface
+
+ class _LegacyScanKernelBase(GenericScanKernel):
+     def __init__(self, ctx, dtype,
+                  scan_expr, neutral=None,
+                  name_prefix="scan", options=None, preamble="", devices=None):
+         scan_ctype = dtype_to_ctype(dtype)
+         GenericScanKernel.__init__(self,
+             ctx, dtype,
+             arguments="__global {} *input_ary, __global {} *output_ary".format(
+                 scan_ctype, scan_ctype),
+             input_expr="input_ary[i]",
+             scan_expr=scan_expr,
+             neutral=neutral,
+             output_statement=self.ary_output_statement,
+             options=options, preamble=preamble, devices=devices)
+
+     @property
+     def ary_output_statement(self):
+         raise NotImplementedError
+
+     def __call__(self, input_ary, output_ary=None, allocator=None, queue=None):
+         allocator = allocator or input_ary.allocator
+         queue = queue or input_ary.queue or output_ary.queue
+
+         if output_ary is None:
+             output_ary = input_ary
+
+         if isinstance(output_ary, str) and output_ary == "new":
+             output_ary = cl_array.empty_like(input_ary, allocator=allocator)
+
+         if input_ary.shape != output_ary.shape:
+             raise ValueError("input and output must have the same shape")
+
+         if not input_ary.flags.forc:
+             raise RuntimeError("ScanKernel cannot "
+                                "deal with non-contiguous arrays")
+
+         n, = input_ary.shape
+
+         if not n:
+             return output_ary
+
+         GenericScanKernel.__call__(self,
+             input_ary, output_ary, allocator=allocator, queue=queue)
+
+         return output_ary
+
+
+ class InclusiveScanKernel(_LegacyScanKernelBase):
+     ary_output_statement = "output_ary[i] = item;"
+
+
+ class ExclusiveScanKernel(_LegacyScanKernelBase):
+     ary_output_statement = "output_ary[i] = prev_item;"
+
+ # }}}
+
+
+ # {{{ template
+
+ class ScanTemplate(KernelTemplateBase):
+     def __init__(
+             self,
+             arguments: str | list[DtypedArgument],
+             input_expr: str,
+             scan_expr: str,
+             neutral: str | None,
+             output_statement: str,
+             is_segment_start_expr: str | None = None,
+             input_fetch_exprs: list[tuple[str, str, int]] | None = None,
+             name_prefix: str = "scan",
+             preamble: str = "",
+             template_processor: Any = None) -> None:
+         super().__init__(template_processor=template_processor)
+
+         if input_fetch_exprs is None:
+             input_fetch_exprs = []
+
+         self.arguments = arguments
+         self.input_expr = input_expr
+         self.scan_expr = scan_expr
+         self.neutral = neutral
+         self.output_statement = output_statement
+         self.is_segment_start_expr = is_segment_start_expr
+         self.input_fetch_exprs = input_fetch_exprs
+         self.name_prefix = name_prefix
+         self.preamble = preamble
+
+     def build_inner(self, context, type_aliases=(), var_values=(),
+             more_preamble="", more_arguments=(), declare_types=(),
+             options=None, devices=None, scan_cls=GenericScanKernel):
+         renderer = self.get_renderer(type_aliases, var_values, context, options)
+
+         arg_list = renderer.render_argument_list(self.arguments, more_arguments)
+
+         type_decl_preamble = renderer.get_type_decl_preamble(
+             context.devices[0], declare_types, arg_list)
+
+         return scan_cls(context, renderer.type_aliases["scan_t"],
+             renderer.render_argument_list(self.arguments, more_arguments),
+             renderer(self.input_expr), renderer(self.scan_expr),
+             renderer(self.neutral), renderer(self.output_statement),
+             is_segment_start_expr=renderer(self.is_segment_start_expr),
+             input_fetch_exprs=self.input_fetch_exprs,
+             index_dtype=renderer.type_aliases.get("index_t", np.int32),
+             name_prefix=renderer(self.name_prefix), options=options,
+             preamble=(
+                 type_decl_preamble
+                 + "\n"
+                 + renderer(self.preamble + "\n" + more_preamble)),
+             devices=devices)
+
+ # }}}
+
+
+ # {{{ 'canned' scan kernels
+
+ @context_dependent_memoize
+ def get_cumsum_kernel(context, input_dtype, output_dtype):
+     from pyopencl.tools import VectorArg
+     return GenericScanKernel(
+         context, output_dtype,
+         arguments=[
+             VectorArg(input_dtype, "input"),
+             VectorArg(output_dtype, "output"),
+             ],
+         input_expr="input[i]",
+         scan_expr="a+b", neutral="0",
+         output_statement="""
+             output[i] = item;
+             """)
+
+ # }}}
+
+ # vim: filetype=pyopencl:fdm=marker
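
For orientation, the snippet below is a minimal usage sketch of the GenericScanKernel API defined in this file, assembled from its own docstring example. It is not part of the wheel; the array name ary, the array size, and the context/queue setup are illustrative assumptions, and a working OpenCL platform is assumed.

    import numpy as np
    import pyopencl as cl
    import pyopencl.array as cl_array
    from pyopencl.scan import GenericScanKernel

    # Assumed setup: any available OpenCL device.
    ctx = cl.create_some_context()
    queue = cl.CommandQueue(ctx)

    # In-place inclusive prefix sum: after the call, ary[i] == sum(ary[0..i]).
    knl = GenericScanKernel(
        ctx, np.int32,
        arguments="__global int *ary",
        input_expr="ary[i]",
        scan_expr="a+b", neutral="0",
        output_statement="ary[i] = item;")

    ary = cl_array.arange(queue, 16, dtype=np.int32)
    knl(ary, queue=queue)
    print(ary.get())  # running sums of 0..15

Per the constructor docstring above, switching the output statement to use prev_item (for example "ary[i] = prev_item;") turns this into an exclusive scan, at the small additional run-time cost the docstring notes.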