pyopencl 2026.1.1__cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. pyopencl/.libs/libOpenCL-34a55fe4.so.1.0.0 +0 -0
  2. pyopencl/__init__.py +1995 -0
  3. pyopencl/_cl.cpython-314t-aarch64-linux-gnu.so +0 -0
  4. pyopencl/_cl.pyi +2009 -0
  5. pyopencl/_cluda.py +57 -0
  6. pyopencl/_monkeypatch.py +1104 -0
  7. pyopencl/_mymako.py +17 -0
  8. pyopencl/algorithm.py +1454 -0
  9. pyopencl/array.py +3530 -0
  10. pyopencl/bitonic_sort.py +245 -0
  11. pyopencl/bitonic_sort_templates.py +597 -0
  12. pyopencl/cache.py +553 -0
  13. pyopencl/capture_call.py +200 -0
  14. pyopencl/characterize/__init__.py +461 -0
  15. pyopencl/characterize/performance.py +240 -0
  16. pyopencl/cl/pyopencl-airy.cl +324 -0
  17. pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
  18. pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
  19. pyopencl/cl/pyopencl-bessel-y.cl +435 -0
  20. pyopencl/cl/pyopencl-complex.h +303 -0
  21. pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
  22. pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
  23. pyopencl/cl/pyopencl-random123/array.h +325 -0
  24. pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
  25. pyopencl/cl/pyopencl-random123/philox.cl +486 -0
  26. pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
  27. pyopencl/clmath.py +281 -0
  28. pyopencl/clrandom.py +412 -0
  29. pyopencl/cltypes.py +217 -0
  30. pyopencl/compyte/.gitignore +21 -0
  31. pyopencl/compyte/__init__.py +0 -0
  32. pyopencl/compyte/array.py +211 -0
  33. pyopencl/compyte/dtypes.py +314 -0
  34. pyopencl/compyte/pyproject.toml +49 -0
  35. pyopencl/elementwise.py +1288 -0
  36. pyopencl/invoker.py +417 -0
  37. pyopencl/ipython_ext.py +70 -0
  38. pyopencl/py.typed +0 -0
  39. pyopencl/reduction.py +829 -0
  40. pyopencl/scan.py +1921 -0
  41. pyopencl/tools.py +1680 -0
  42. pyopencl/typing.py +61 -0
  43. pyopencl/version.py +11 -0
  44. pyopencl-2026.1.1.dist-info/METADATA +108 -0
  45. pyopencl-2026.1.1.dist-info/RECORD +47 -0
  46. pyopencl-2026.1.1.dist-info/WHEEL +6 -0
  47. pyopencl-2026.1.1.dist-info/licenses/LICENSE +104 -0
pyopencl/scan.py ADDED
@@ -0,0 +1,1921 @@
+ """Scan primitive."""
+ from __future__ import annotations
+
+
+ __copyright__ = """
+ Copyright 2011-2012 Andreas Kloeckner
+ Copyright 2008-2011 NVIDIA Corporation
+ """
+
+ __license__ = """
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ Derived from code within the Thrust project, https://github.com/NVIDIA/thrust
+ """
+
+ import logging
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Any, cast
+
+ import numpy as np
+
+ from pytools.persistent_dict import WriteOncePersistentDict
+
+ import pyopencl as cl
+ import pyopencl._mymako as mako
+ import pyopencl.array as cl_array
+ from pyopencl._cluda import CLUDA_PREAMBLE
+ from pyopencl.tools import (
+     DtypedArgument,
+     KernelTemplateBase,
+     _NumpyTypesKeyBuilder,
+     _process_code_for_macro,
+     bitlog2,
+     context_dependent_memoize,
+     dtype_to_ctype,
+     get_arg_list_scalar_arg_dtypes,
+     get_arg_offset_adjuster_code,
+ )
+
+
+ if TYPE_CHECKING:
+     from collections.abc import Sequence
+
+
+ logger = logging.getLogger(__name__)
+
+
+ # {{{ preamble
+
+ SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
+ #define WG_SIZE ${wg_size}
+
+ #define SCAN_EXPR(a, b, across_seg_boundary) ${scan_expr}
+ #define INPUT_EXPR(i) (${input_expr})
+ %if is_segmented:
+     #define IS_SEG_START(i, a) (${is_segment_start_expr})
+ %endif
+
+ ${preamble}
+
+ typedef ${dtype_to_ctype(scan_dtype)} scan_type;
+ typedef ${dtype_to_ctype(index_dtype)} index_type;
+
+ // NO_SEG_BOUNDARY is the largest representable integer in index_type.
+ // This assumption is used in code below.
+ #define NO_SEG_BOUNDARY ${str(np.iinfo(index_dtype).max)}
+ """
+
+ # }}}
+
+ # {{{ main scan code
+
+ # Algorithm: Each work group is responsible for one contiguous
+ # 'interval'. There are just enough intervals to fill all compute
+ # units. Intervals are split into 'units'. A unit is what gets
+ # worked on in parallel by one work group.
+ #
+ # in index space:
+ # interval > unit > local-parallel > k-group
+ #
+ # (Note that there is also a transpose in here: The data is read
+ # with local ids along linear index order.)
+ #
+ # Each unit has two axes--the local-id axis and the k axis.
+ #
+ # unit 0:
+ # | | | | | | | | | | ----> lid
+ # | | | | | | | | | |
+ # | | | | | | | | | |
+ # | | | | | | | | | |
+ # | | | | | | | | | |
+ #
+ # |
+ # v k (fastest-moving in linear index)
+ #
+ # unit 1:
+ # | | | | | | | | | | ----> lid
+ # | | | | | | | | | |
+ # | | | | | | | | | |
+ # | | | | | | | | | |
+ # | | | | | | | | | |
+ #
+ # |
+ # v k (fastest-moving in linear index)
+ #
+ # ...
+ #
+ # At a device-global level, this is a three-phase algorithm, in
+ # which first each interval does its local scan, then a scan
+ # across intervals exchanges data globally, and the final update
+ # adds the exchanged sums to each interval.
+ #
+ # Exclusive scan is realized by allowing look-behind (access to the
+ # preceding item) in the final update, by means of a local shift.
+ #
+ # NOTE: All segment_start_in_X indices are relative to the start
+ # of the array.
+
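The decomposition above can be modeled on the host with NumPy. The sketch below is an annotation, not part of the packaged file (`three_phase_scan_model` is a hypothetical name); it mirrors the three device-global phases: per-interval scans, a scan over the interval sums, and a final update that folds each interval's carry-in back in:

    import numpy as np

    def three_phase_scan_model(a, num_intervals):
        # Phase 1: each "work group" scans its own contiguous interval.
        intervals = np.array_split(a, num_intervals)
        partial = [np.cumsum(iv) for iv in intervals]
        # Phase 2: scan the per-interval totals to obtain each carry-in.
        sums = np.array([p[-1] for p in partial])
        carries = np.concatenate(([0], np.cumsum(sums)[:-1]))
        # Phase 3: the final update adds the carry-in to every element.
        return np.concatenate([p + c for p, c in zip(partial, carries)])

    a = np.arange(10, dtype=np.int64)
    assert (three_phase_scan_model(a, 3) == np.cumsum(a)).all()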
+ SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + r"""//CL//
+
+ #define K ${k_group_size}
+
+ // #define DEBUG
+ #ifdef DEBUG
+     #define pycl_printf(ARGS) printf ARGS
+ #else
+     #define pycl_printf(ARGS) /* */
+ #endif
+
+ KERNEL
+ REQD_WG_SIZE(WG_SIZE, 1, 1)
+ void ${kernel_name}(
+     ${argument_signature},
+     GLOBAL_MEM scan_type *restrict partial_scan_buffer,
+     const index_type N,
+     const index_type interval_size
+     %if is_first_level:
+         , GLOBAL_MEM scan_type *restrict interval_results
+     %endif
+     %if is_segmented and is_first_level:
+         // NO_SEG_BOUNDARY if no segment boundary in interval.
+         , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
+     %endif
+     %if store_segment_start_flags:
+         , GLOBAL_MEM char *restrict g_segment_start_flags
+     %endif
+     )
+ {
+     ${arg_offset_adjustment}
+
+     // index K in first dimension used for carry storage
+     %if use_bank_conflict_avoidance:
+         // Avoid bank conflicts by adding a single 32-bit value to the size of
+         // the scan type.
+         struct __attribute__ ((__packed__)) wrapped_scan_type
+         {
+             scan_type value;
+             int dummy;
+         };
+     %else:
+         struct wrapped_scan_type
+         {
+             scan_type value;
+         };
+     %endif
+     // padded in WG_SIZE to avoid bank conflicts
+     LOCAL_MEM struct wrapped_scan_type ldata[K + 1][WG_SIZE + 1];
+
+     %if is_segmented:
+         LOCAL_MEM char l_segment_start_flags[K][WG_SIZE];
+         LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE];
+
+         // only relevant/populated for local id 0
+         index_type first_segment_start_in_interval = NO_SEG_BOUNDARY;
+
+         index_type first_segment_start_in_k_group, first_segment_start_in_subtree;
+     %endif
+
+     // {{{ declare local data for input_fetch_exprs if any of them are stenciled
+
+     <%
+         fetch_expr_offsets = {}
+         for name, arg_name, ife_offset in input_fetch_exprs:
+             fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
+
+         local_fetch_expr_args = set(
+             arg_name
+             for arg_name, ife_offsets in fetch_expr_offsets.items()
+             if -1 in ife_offsets or len(ife_offsets) > 1)
+     %>
+
+     %for arg_name in local_fetch_expr_args:
+         LOCAL_MEM ${arg_ctypes[arg_name]} l_${arg_name}[WG_SIZE*K];
+     %endfor
+
+     // }}}
+
+     const index_type interval_begin = interval_size * GID_0;
+     const index_type interval_end = min(interval_begin + interval_size, N);
+
+     const index_type unit_size = K * WG_SIZE;
+
+     index_type unit_base = interval_begin;
+
+     %for is_tail in [False, True]:
+
+         %if not is_tail:
+             for(; unit_base + unit_size <= interval_end; unit_base += unit_size)
+         %else:
+             if (unit_base < interval_end)
+         %endif
+
+         {
+
+             // {{{ carry out input_fetch_exprs
+             // (if there are ones that need to be fetched into local)
+
+             %if local_fetch_expr_args:
+                 for(index_type k = 0; k < K; k++)
+                 {
+                     const index_type offset = k*WG_SIZE + LID_0;
+                     const index_type read_i = unit_base + offset;
+
+                     %for arg_name in local_fetch_expr_args:
+                         %if is_tail:
+                             if (read_i < interval_end)
+                         %endif
+                         {
+                             l_${arg_name}[offset] = ${arg_name}[read_i];
+                         }
+                     %endfor
+                 }
+
+                 local_barrier();
+             %endif
+
+             pycl_printf(("after input_fetch_exprs\n"));
+
+             // }}}
+
+             // {{{ read a unit's worth of data from global
+
+             for(index_type k = 0; k < K; k++)
+             {
+                 const index_type offset = k*WG_SIZE + LID_0;
+                 const index_type read_i = unit_base + offset;
+
+                 %if is_tail:
+                     if (read_i < interval_end)
+                 %endif
+                 {
+                     %for name, arg_name, ife_offset in input_fetch_exprs:
+                         ${arg_ctypes[arg_name]} ${name};
+
+                         %if arg_name in local_fetch_expr_args:
+                             if (offset + ${ife_offset} >= 0)
+                                 ${name} = l_${arg_name}[offset + ${ife_offset}];
+                             else if (read_i + ${ife_offset} >= 0)
+                                 ${name} = ${arg_name}[read_i + ${ife_offset}];
+                             /*
+                             else
+                                 if out of bounds, name is left undefined */
+
+                         %else:
+                             // ${arg_name} gets fetched directly from global
+                             ${name} = ${arg_name}[read_i];
+
+                         %endif
+                     %endfor
+
+                     scan_type scan_value = INPUT_EXPR(read_i);
+
+                     const index_type o_mod_k = offset % K;
+                     const index_type o_div_k = offset / K;
+                     ldata[o_mod_k][o_div_k].value = scan_value;
+
+                     %if is_segmented:
+                         bool is_seg_start = IS_SEG_START(read_i, scan_value);
+                         l_segment_start_flags[o_mod_k][o_div_k] = is_seg_start;
+                     %endif
+                     %if store_segment_start_flags:
+                         g_segment_start_flags[read_i] = is_seg_start;
+                     %endif
+                 }
+             }
+
+             pycl_printf(("after read from global\n"));
+
+             // }}}
+
+             // {{{ carry in from previous unit, if applicable
+
+             %if is_segmented:
+                 local_barrier();
+
+                 first_segment_start_in_k_group = NO_SEG_BOUNDARY;
+                 if (l_segment_start_flags[0][LID_0])
+                     first_segment_start_in_k_group = unit_base + K*LID_0;
+             %endif
+
+             if (LID_0 == 0 && unit_base != interval_begin)
+             {
+                 scan_type tmp = ldata[K][WG_SIZE - 1].value;
+                 scan_type tmp_aux = ldata[0][0].value;
+
+                 ldata[0][0].value = SCAN_EXPR(
+                     tmp, tmp_aux,
+                     %if is_segmented:
+                         (l_segment_start_flags[0][0])
+                     %else:
+                         false
+                     %endif
+                     );
+             }
+
+             pycl_printf(("after carry-in\n"));
+
+             // }}}
+
+             local_barrier();
+
+             // {{{ scan along k (sequentially in each work item)
+
+             scan_type sum = ldata[0][LID_0].value;
+
+             %if is_tail:
+                 const index_type offset_end = interval_end - unit_base;
+             %endif
+
+             for (index_type k = 1; k < K; k++)
+             {
+                 %if is_tail:
+                     if ((index_type) (K * LID_0 + k) < offset_end)
+                 %endif
+                 {
+                     scan_type tmp = ldata[k][LID_0].value;
+
+                     %if is_segmented:
+                         index_type seq_i = unit_base + K*LID_0 + k;
+
+                         if (l_segment_start_flags[k][LID_0])
+                         {
+                             first_segment_start_in_k_group = min(
+                                 first_segment_start_in_k_group,
+                                 seq_i);
+                         }
+                     %endif
+
+                     sum = SCAN_EXPR(sum, tmp,
+                         %if is_segmented:
+                             (l_segment_start_flags[k][LID_0])
+                         %else:
+                             false
+                         %endif
+                         );
+
+                     ldata[k][LID_0].value = sum;
+                 }
+             }
+
+             pycl_printf(("after scan along k\n"));
+
+             // }}}
+
+             // store carry in out-of-bounds (padding) array entry (index K) in
+             // the K direction
+             ldata[K][LID_0].value = sum;
+
+             %if is_segmented:
+                 l_first_segment_start_in_subtree[LID_0] =
+                     first_segment_start_in_k_group;
+             %endif
+
+             local_barrier();
+
+             // {{{ tree-based local parallel scan
+
+             // This tree-based scan works as follows:
+             // - Each work item adds the previous item to its current state
+             // - barrier
+             // - Each work item adds in the item from two positions to the left
+             // - barrier
+             // - Each work item adds in the item from four positions to the left
+             // ...
+             // At the end, each item has summed all prior items.
+
+             // across k groups, along local id
+             // (uses out-of-bounds k=K array entry for storage)
+
+             scan_type val = ldata[K][LID_0].value;
+
+             <% scan_offset = 1 %>
+
+             % while scan_offset <= wg_size:
+                 // {{{ reads from local allowed, writes to local not allowed
+
+                 if (LID_0 >= ${scan_offset})
+                 {
+                     scan_type tmp = ldata[K][LID_0 - ${scan_offset}].value;
+                     % if is_tail:
+                         if (K*LID_0 < offset_end)
+                     % endif
+                     {
+                         val = SCAN_EXPR(tmp, val,
+                             %if is_segmented:
+                                 (l_first_segment_start_in_subtree[LID_0]
+                                     != NO_SEG_BOUNDARY)
+                             %else:
+                                 false
+                             %endif
+                             );
+                     }
+
+                     %if is_segmented:
+                         // Prepare for l_first_segment_start_in_subtree, below.
+
+                         // Note that this update must take place *even* if we're
+                         // out of bounds.
+
+                         first_segment_start_in_subtree = min(
+                             l_first_segment_start_in_subtree[LID_0],
+                             l_first_segment_start_in_subtree
+                                 [LID_0 - ${scan_offset}]);
+                     %endif
+                 }
+                 %if is_segmented:
+                 else
+                 {
+                     first_segment_start_in_subtree =
+                         l_first_segment_start_in_subtree[LID_0];
+                 }
+                 %endif
+
+                 // }}}
+
+                 local_barrier();
+
+                 // {{{ writes to local allowed, reads from local not allowed
+
+                 ldata[K][LID_0].value = val;
+                 %if is_segmented:
+                     l_first_segment_start_in_subtree[LID_0] =
+                         first_segment_start_in_subtree;
+                 %endif
+
+                 // }}}
+
+                 local_barrier();
+
+                 %if 0:
+                 if (LID_0 == 0)
+                 {
+                     printf("${scan_offset}: ");
+                     for (int i = 0; i < WG_SIZE; ++i)
+                     {
+                         if (l_first_segment_start_in_subtree[i] == NO_SEG_BOUNDARY)
+                             printf("- ");
+                         else
+                             printf("%d ", l_first_segment_start_in_subtree[i]);
+                     }
+                     printf("\n");
+                 }
+                 %endif
+
+                 <% scan_offset *= 2 %>
+             % endwhile
+
+             pycl_printf(("after tree scan\n"));
+
+             // }}}
+
+             // {{{ update local values
+
+             if (LID_0 > 0)
+             {
+                 sum = ldata[K][LID_0 - 1].value;
+
+                 for(index_type k = 0; k < K; k++)
+                 {
+                     %if is_tail:
+                         if (K * LID_0 + k < offset_end)
+                     %endif
+                     {
+                         scan_type tmp = ldata[k][LID_0].value;
+                         ldata[k][LID_0].value = SCAN_EXPR(sum, tmp,
+                             %if is_segmented:
+                                 (unit_base + K * LID_0 + k
+                                     >= first_segment_start_in_k_group)
+                             %else:
+                                 false
+                             %endif
+                             );
+                     }
+                 }
+             }
+
+             %if is_segmented:
+                 if (LID_0 == 0)
+                 {
+                     // update interval-wide first-seg variable from current unit
+                     first_segment_start_in_interval = min(
+                         first_segment_start_in_interval,
+                         l_first_segment_start_in_subtree[WG_SIZE-1]);
+                 }
+             %endif
+
+             pycl_printf(("after local update\n"));
+
+             // }}}
+
+             local_barrier();
+
+             // {{{ write data
+
+             %if is_gpu:
+             {
+                 // work hard with index math to achieve contiguous 32-bit stores
+                 __global int *dest =
+                     (__global int *) (partial_scan_buffer + unit_base);
+
+                 <%
+
+                 assert scan_dtype.itemsize % 4 == 0
+
+                 ints_per_wg = wg_size
+                 ints_to_store = scan_dtype.itemsize*wg_size*k_group_size // 4
+
+                 %>
+
+                 const index_type scan_types_per_int = ${scan_dtype.itemsize//4};
+
+                 %for store_base in range(0, ints_to_store, ints_per_wg):
+                     <%
+
+                     # Observe that ints_to_store is divisible by the work group
+                     # size already, so we won't go out of bounds that way.
+                     assert store_base + ints_per_wg <= ints_to_store
+
+                     %>
+
+                     %if is_tail:
+                         if (${store_base} + LID_0 <
+                             scan_types_per_int*(interval_end - unit_base))
+                     %endif
+                     {
+                         index_type linear_index = ${store_base} + LID_0;
+                         index_type linear_scan_data_idx =
+                             linear_index / scan_types_per_int;
+                         index_type remainder =
+                             linear_index - linear_scan_data_idx * scan_types_per_int;
+
+                         __local int *src = (__local int *) &(
+                             ldata
+                                 [linear_scan_data_idx % K]
+                                 [linear_scan_data_idx / K].value);
+
+                         dest[linear_index] = src[remainder];
+                     }
+                 %endfor
+             }
+             %else:
+                 for (index_type k = 0; k < K; k++)
+                 {
+                     const index_type offset = k*WG_SIZE + LID_0;
+
+                     %if is_tail:
+                         if (unit_base + offset < interval_end)
+                     %endif
+                     {
+                         pycl_printf(("write: %d\n", unit_base + offset));
+                         partial_scan_buffer[unit_base + offset] =
+                             ldata[offset % K][offset / K].value;
+                     }
+                 }
+             %endif
+
+             pycl_printf(("after write\n"));
+
+             // }}}
+
+             local_barrier();
+         }
+
+     % endfor
+
+     // write interval sum
+     %if is_first_level:
+         if (LID_0 == 0)
+         {
+             interval_results[GID_0] = partial_scan_buffer[interval_end - 1];
+             %if is_segmented:
+                 g_first_segment_start_in_interval[GID_0] =
+                     first_segment_start_in_interval;
+             %endif
+         }
+     %endif
+ }
+ """
+
+ # }}}
+
+ # {{{ update
+
+ UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL//
+
+ KERNEL
+ REQD_WG_SIZE(WG_SIZE, 1, 1)
+ void ${name_prefix}_final_update(
+     ${argument_signature},
+     const index_type N,
+     const index_type interval_size,
+     GLOBAL_MEM scan_type *restrict interval_results,
+     GLOBAL_MEM scan_type *restrict partial_scan_buffer
+     %if is_segmented:
+         , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
+     %endif
+     %if is_segmented and use_lookbehind_update:
+         , GLOBAL_MEM char *restrict g_segment_start_flags
+     %endif
+     )
+ {
+     ${arg_offset_adjustment}
+
+     %if use_lookbehind_update:
+         LOCAL_MEM scan_type ldata[WG_SIZE];
+     %endif
+     %if is_segmented and use_lookbehind_update:
+         LOCAL_MEM char l_segment_start_flags[WG_SIZE];
+     %endif
+
+     const index_type interval_begin = interval_size * GID_0;
+     const index_type interval_end = min(interval_begin + interval_size, N);
+
+     // carry from last interval
+     scan_type carry = ${neutral};
+     if (GID_0 != 0)
+         carry = interval_results[GID_0 - 1];
+
+     %if is_segmented:
+         const index_type first_seg_start_in_interval =
+             g_first_segment_start_in_interval[GID_0];
+     %endif
+
+     %if not is_segmented and 'last_item' in output_statement:
+         scan_type last_item = interval_results[GDIM_0-1];
+     %endif
+
+     %if not use_lookbehind_update:
+         // {{{ no look-behind ('prev_item' not in output_statement -> simpler)
+
+         index_type update_i = interval_begin+LID_0;
+
+         %if is_segmented:
+             index_type seg_end = min(first_seg_start_in_interval, interval_end);
+         %endif
+
+         for(; update_i < interval_end; update_i += WG_SIZE)
+         {
+             scan_type partial_val = partial_scan_buffer[update_i];
+             scan_type item = SCAN_EXPR(carry, partial_val,
+                 %if is_segmented:
+                     (update_i >= seg_end)
+                 %else:
+                     false
+                 %endif
+                 );
+             index_type i = update_i;
+
+             { ${output_statement}; }
+         }
+
+         // }}}
+     %else:
+         // {{{ allow look-behind ('prev_item' in output_statement -> complicated)
+
+         // We are not allowed to branch across barriers at a granularity smaller
+         // than the whole workgroup. Therefore, the for loop is group-global,
+         // and there are lots of local ifs.
+
+         index_type group_base = interval_begin;
+         scan_type prev_item = carry; // (A)
+
+         for(; group_base < interval_end; group_base += WG_SIZE)
+         {
+             index_type update_i = group_base+LID_0;
+
+             // load a work group's worth of data
+             if (update_i < interval_end)
+             {
+                 scan_type tmp = partial_scan_buffer[update_i];
+
+                 tmp = SCAN_EXPR(carry, tmp,
+                     %if is_segmented:
+                         (update_i >= first_seg_start_in_interval)
+                     %else:
+                         false
+                     %endif
+                     );
+
+                 ldata[LID_0] = tmp;
+
+                 %if is_segmented:
+                     l_segment_start_flags[LID_0] = g_segment_start_flags[update_i];
+                 %endif
+             }
+
+             local_barrier();
+
+             // find prev_item
+             if (LID_0 != 0)
+                 prev_item = ldata[LID_0 - 1];
+             /*
+             else
+                 prev_item = carry (see (A)) OR last tail (see (B));
+             */
+
+             if (update_i < interval_end)
+             {
+                 %if is_segmented:
+                     if (l_segment_start_flags[LID_0])
+                         prev_item = ${neutral};
+                 %endif
+
+                 scan_type item = ldata[LID_0];
+                 index_type i = update_i;
+                 { ${output_statement}; }
+             }
+
+             if (LID_0 == 0)
+                 prev_item = ldata[WG_SIZE - 1]; // (B)
+
+             local_barrier();
+         }
+
+         // }}}
+     %endif
+ }
+ """
+
+ # }}}
+
+
+ # {{{ driver
+
+ # {{{ helpers
+
+ def _round_down_to_power_of_2(val: int) -> int:
+     result = 2**bitlog2(val)
+     if result > val:
+         result >>= 1
+
+     assert result <= val
+     return result
+
+
+ _PREFIX_WORDS = set("""
+         ldata partial_scan_buffer global scan_offset
+         segment_start_in_k_group carry
+         g_first_segment_start_in_interval IS_SEG_START tmp Z
+         val l_first_segment_start_in_subtree unit_size
+         index_type interval_begin interval_size offset_end K
+         SCAN_EXPR do_update WG_SIZE
+         first_segment_start_in_k_group scan_type
+         segment_start_in_subtree offset interval_results interval_end
+         first_segment_start_in_subtree unit_base
+         first_segment_start_in_interval k INPUT_EXPR
+         prev_group_sum prev pv value partial_val pgs
+         is_seg_start update_i scan_item_at_i seq_i read_i
+         l_ o_mod_k o_div_k l_segment_start_flags scan_value sum
+         first_seg_start_in_interval g_segment_start_flags
+         group_base seg_end my_val DEBUG ARGS
+         ints_to_store ints_per_wg scan_types_per_int linear_index
+         linear_scan_data_idx dest src store_base wrapped_scan_type
+         dummy scan_tmp tmp_aux
+
+         LID_2 LID_1 LID_0
+         LDIM_0 LDIM_1 LDIM_2
+         GDIM_0 GDIM_1 GDIM_2
+         GID_0 GID_1 GID_2
+         """.split())
+
+ _IGNORED_WORDS = set("""
+         4 8 32
+
+         typedef for endfor if void while endwhile endfor endif else const printf
+         None return bool n char true false ifdef pycl_printf str range assert
+         np iinfo max itemsize __packed__ struct restrict ptrdiff_t
+
+         set iteritems len setdefault
+
+         GLOBAL_MEM LOCAL_MEM_ARG WITHIN_KERNEL LOCAL_MEM KERNEL REQD_WG_SIZE
+         local_barrier
+         CLK_LOCAL_MEM_FENCE OPENCL EXTENSION
+         pragma __attribute__ __global __kernel __local
+         get_local_size get_local_id cl_khr_fp64 reqd_work_group_size
+         get_num_groups barrier get_group_id
+         CL_VERSION_1_1 __OPENCL_C_VERSION__ 120
+
+         _final_update _debug_scan kernel_name
+
+         positions all padded integer its previous write based writes 0
+         has local worth scan_expr to read cannot not X items False bank
+         four beginning follows applicable item min each indices works side
+         scanning right summed relative used id out index avoid current state
+         boundary True across be This reads groups along Otherwise undetermined
+         store of times prior s update first regardless Each number because
+         array unit from segment conflicts two parallel 2 empty define direction
+         CL padding work tree bounds values and adds
+         scan is allowed thus it an as enable at in occur sequentially end no
+         storage data 1 largest may representable uses entry Y meaningful
+         computations interval At the left dimension know d
+         A load B group perform shift tail see last OR
+         this add fetched into are directly need
+         gets them stenciled that undefined
+         there up any ones or name only relevant populated
+         even wide we Prepare int seg Note re below place take variable must
+         intra Therefore find code assumption
+         branch workgroup complicated granularity phase remainder than simpler
+         We smaller look ifs lots self behind allow barriers whole loop
+         after already Observe achieve contiguous stores hard go with by math
+         size won t way divisible bit so Avoid declare adding single type
+
+         is_tail is_first_level input_expr argument_signature preamble
+         double_support neutral output_statement
+         k_group_size name_prefix is_segmented index_dtype scan_dtype
+         wg_size is_segment_start_expr fetch_expr_offsets
+         arg_ctypes ife_offsets input_fetch_exprs def
+         ife_offset arg_name local_fetch_expr_args update_body
+         update_loop_lookbehind update_loop_plain update_loop
+         use_lookbehind_update store_segment_start_flags
+         update_loop first_seg scan_dtype dtype_to_ctype
+         is_gpu use_bank_conflict_avoidance
+
+         a b prev_item i last_item prev_value
+         N NO_SEG_BOUNDARY across_seg_boundary
+
+         arg_offset_adjustment
+         """.split())
+
+
+ def _make_template(s: str):
+     import re
+     leftovers = set()
+
+     def replace_id(match: re.Match) -> str:
+         # avoid name clashes with user code by adding 'psc_' prefix to
+         # identifiers.
+
+         word = match.group(1)
+         if word in _IGNORED_WORDS:
+             return word
+         elif word in _PREFIX_WORDS:
+             return f"psc_{word}"
+         else:
+             leftovers.add(word)
+             return word
+
+     s = re.sub(r"\b([a-zA-Z0-9_]+)\b", replace_id, s)
+     if leftovers:
+         from warnings import warn
+         warn("Leftover words in identifier prefixing: " + " ".join(leftovers),
+              stacklevel=3)
+
+     return mako.template.Template(s, strict_undefined=True)
+
+
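To see what `_make_template` does, here is the substitution applied to a fragment by hand (a standalone illustration; `prefix_ids` is a hypothetical helper, not part of the package):

    import re

    def prefix_ids(s, prefix_words):
        # Mirrors replace_id above: known internal identifiers gain a
        # 'psc_' prefix so they cannot collide with user-supplied code.
        return re.sub(r"\b([a-zA-Z0-9_]+)\b",
                      lambda m: ("psc_" + m.group(1))
                      if m.group(1) in prefix_words else m.group(1), s)

    src = "ldata[k][lid].value = sum;"
    assert prefix_ids(src, {"ldata", "sum"}) == "psc_ldata[k][lid].value = psc_sum;"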
+ @dataclass(frozen=True)
+ class _GeneratedScanKernelInfo:
+     scan_src: str
+     kernel_name: str
+     scalar_arg_dtypes: list[np.dtype | None]
+     wg_size: int
+     k_group_size: int
+
+     def build(self, context: cl.Context, options: Any) -> _BuiltScanKernelInfo:
+         program = cl.Program(context, self.scan_src).build(options)
+         kernel = getattr(program, self.kernel_name)
+         kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
+         return _BuiltScanKernelInfo(
+             kernel=kernel,
+             wg_size=self.wg_size,
+             k_group_size=self.k_group_size)
+
+
+ @dataclass(frozen=True)
+ class _BuiltScanKernelInfo:
+     kernel: cl.Kernel
+     wg_size: int
+     k_group_size: int
+
+
+ @dataclass(frozen=True)
+ class _GeneratedFinalUpdateKernelInfo:
+     source: str
+     kernel_name: str
+     scalar_arg_dtypes: Sequence[np.dtype | None]
+     update_wg_size: int
+
+     def build(self,
+               context: cl.Context,
+               options: Any) -> _BuiltFinalUpdateKernelInfo:
+         program = cl.Program(context, self.source).build(options)
+         kernel = getattr(program, self.kernel_name)
+         kernel.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
+         return _BuiltFinalUpdateKernelInfo(kernel, self.update_wg_size)
+
+
+ @dataclass(frozen=True)
+ class _BuiltFinalUpdateKernelInfo:
+     kernel: cl.Kernel
+     update_wg_size: int
+
+ # }}}
+
+
+ class ScanPerformanceWarning(UserWarning):
+     pass
+
+
+ class GenericScanKernelBase(ABC):
+     # {{{ constructor, argument processing
+
+     def __init__(
+             self,
+             ctx: cl.Context,
+             dtype: Any,
+             arguments: str | list[DtypedArgument],
+             input_expr: str,
+             scan_expr: str,
+             neutral: str | None,
+             output_statement: str,
+             is_segment_start_expr: str | None = None,
+             input_fetch_exprs: list[tuple[str, str, int]] | None = None,
+             index_dtype: Any = None,
+             name_prefix: str = "scan",
+             options: Any = None,
+             preamble: str = "",
+             devices: Sequence[cl.Device] | None = None) -> None:
+         """
+         :arg ctx: a :class:`pyopencl.Context` within which the code
+             for this scan kernel will be generated.
+         :arg dtype: the :class:`numpy.dtype` with which the scan will
+             be performed. May be a structured type if that type was registered
+             through :func:`pyopencl.tools.get_or_register_dtype`.
+         :arg arguments: A string of comma-separated C argument declarations.
+             If *arguments* is specified, then *input_expr* must also be
+             specified. All types used here must be known to PyOpenCL.
+             (see :func:`pyopencl.tools.get_or_register_dtype`).
+         :arg scan_expr: The associative, binary operation carrying out the scan,
+             represented as a C string. Its two arguments are available as ``a``
+             and ``b`` when it is evaluated. ``b`` is guaranteed to be the
+             'element being updated', and ``a`` is the increment. Thus,
+             if some data is supposed to just propagate along without being
+             modified by the scan, it should live in ``b``.
+
+             This expression may call functions given in the *preamble*.
+
+             Another value available to this expression is ``across_seg_boundary``,
+             a C `bool` indicating whether this scan update is crossing a
+             segment boundary, as defined by ``is_segment_start_expr``.
+             The scan routine does not implement segmentation
+             semantics on its own. It relies on ``scan_expr`` to do this.
+             This value is available (but always ``false``) even for a
+             non-segmented scan.
+
+             .. note::
+
+                 In early pre-releases of the segmented scan,
+                 segmentation semantics were implemented *without*
+                 relying on ``scan_expr``.
+
+         :arg input_expr: A C expression, encoded as a string, resulting
+             in the values to which the scan is applied. This may be used
+             to apply a mapping to values stored in *arguments* before being
+             scanned. The result of this expression must match *dtype*.
+             The index intended to be mapped is available as ``i`` in this
+             expression. This expression may also use the variables defined
+             by *input_fetch_exprs*.
+
+             This expression may also call functions given in the *preamble*.
+         :arg output_statement: a C statement that writes
+             the output of the scan. It has access to the scan result as ``item``,
+             the preceding scan result item as ``prev_item``, and the current index
+             as ``i``. ``prev_item`` in a segmented scan will be the neutral element
+             at a segment boundary, not the immediately preceding item.
+
+             Using *prev_item* in the output statement has a small run-time cost.
+             ``prev_item`` enables the construction of an exclusive scan.
+
+             For non-segmented scans, *output_statement* may also reference
+             ``last_item``, which evaluates to the scan result of the last
+             array entry.
+         :arg is_segment_start_expr: A C expression, encoded as a string,
+             resulting in a C ``bool`` value that determines whether a new
+             scan segment starts at index *i*. If given, makes the scan a
+             segmented scan. Has access to the current index ``i``, the result
+             of *input_expr* as ``a``, and in addition may use *arguments* and
+             *input_fetch_exprs* variables just like *input_expr*.
+
+             If it returns true, then previous sums will not spill over into the
+             item with index *i* or subsequent items.
+         :arg input_fetch_exprs: a list of tuples *(NAME, ARG_NAME, OFFSET)*.
+             An entry here has the effect of doing the equivalent of the following
+             before input_expr::
+
+                 ARG_NAME_TYPE NAME = ARG_NAME[i+OFFSET];
+
+             ``OFFSET`` is allowed to be 0 or -1, and ``ARG_NAME_TYPE`` is the type
+             of ``ARG_NAME``.
+         :arg preamble: |preamble|
+
+         The first array in the argument list determines the size of the index
+         space over which the scan is carried out, and thus the values over
+         which the index *i* occurring in a number of code fragments in
+         arguments above will vary.
+
+         All code fragments further have access to N, the number of elements
+         being processed in the scan.
+         """
+
+         if index_dtype is None:
+             index_dtype = np.dtype(np.int32)
+
+         if input_fetch_exprs is None:
+             input_fetch_exprs = []
+
+         self.context: cl.Context = ctx
+         self.dtype: np.dtype[Any]
+         dtype = self.dtype = np.dtype(dtype)
+
+         if neutral is None:
+             from warnings import warn
+             warn("not specifying 'neutral' is deprecated and will lead to "
+                  "wrong results if your scan is not in-place or your "
+                  "'output_statement' does something otherwise non-trivial",
+                  stacklevel=2)
+
+         if dtype.itemsize % 4 != 0:
+             raise TypeError("scan value type must have size divisible by 4 bytes")
+
+         self.index_dtype: np.dtype[np.integer] = np.dtype(index_dtype)
+         if np.iinfo(self.index_dtype).min >= 0:
+             raise TypeError("index_dtype must be signed")
+
+         if devices is None:
+             devices = ctx.devices
+         self.devices: Sequence[cl.Device] = devices
+         self.options = options
+
+         from pyopencl.tools import parse_arg_list
+         self.parsed_args: Sequence[DtypedArgument] = parse_arg_list(arguments)
+         from pyopencl.tools import VectorArg
+         self.first_array_idx: int = next(
+             i for i, arg in enumerate(self.parsed_args)
+             if isinstance(arg, VectorArg))
+
+         self.input_expr: str = input_expr
+
+         self.is_segment_start_expr: str | None = is_segment_start_expr
+         self.is_segmented: bool = is_segment_start_expr is not None
+         if is_segment_start_expr is not None:
+             is_segment_start_expr = _process_code_for_macro(is_segment_start_expr)
+
+         self.output_statement: str = output_statement
+
+         for _name, _arg_name, ife_offset in input_fetch_exprs:
+             if ife_offset not in [0, -1]:
+                 raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
+         self.input_fetch_exprs: Sequence[tuple[str, str, int]] = input_fetch_exprs
+
+         arg_dtypes = {}
+         arg_ctypes = {}
+         for arg in self.parsed_args:
+             arg_dtypes[arg.name] = arg.dtype
+             arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)
+
+         self.name_prefix: str = name_prefix
+
+         # {{{ set up shared code dict
+
+         from pyopencl.characterize import has_double_support
+
+         self.code_variables = {
+             "np": np,
+             "dtype_to_ctype": dtype_to_ctype,
+             "preamble": preamble,
+             "name_prefix": name_prefix,
+             "index_dtype": self.index_dtype,
+             "scan_dtype": dtype,
+             "is_segmented": self.is_segmented,
+             "arg_dtypes": arg_dtypes,
+             "arg_ctypes": arg_ctypes,
+             "scan_expr": _process_code_for_macro(scan_expr),
+             "neutral": _process_code_for_macro(neutral),
+             "is_gpu": bool(self.devices[0].type & cl.device_type.GPU),
+             "double_support": all(
+                 has_double_support(dev) for dev in devices),
+         }
+
+         index_typename = dtype_to_ctype(self.index_dtype)
+         scan_typename = dtype_to_ctype(dtype)
+
+         # This key is meant to uniquely identify the non-device parameters for
+         # the scan kernel.
+         self.kernel_key = (
+             self.dtype,
+             tuple(arg.declarator() for arg in self.parsed_args),
+             self.input_expr,
+             scan_expr,
+             neutral,
+             output_statement,
+             is_segment_start_expr,
+             tuple(input_fetch_exprs),
+             index_dtype,
+             name_prefix,
+             preamble,
+             # These depend on dtype_to_ctype(), so their value is independent of
+             # the other variables.
+             index_typename,
+             scan_typename,
+         )
+
+         # }}}
+
+         self.use_lookbehind_update: bool = "prev_item" in self.output_statement
+         self.store_segment_start_flags: bool = (
+             self.is_segmented and self.use_lookbehind_update)
+
+         self.finish_setup()
+
+     # }}}
+
+     @abstractmethod
+     def finish_setup(self) -> None:
+         pass
+
+
+ if not cl._PYOPENCL_NO_CACHE:
+     generic_scan_kernel_cache: WriteOncePersistentDict[Any,
+             tuple[_GeneratedScanKernelInfo, _GeneratedScanKernelInfo,
+                 _GeneratedFinalUpdateKernelInfo]] = \
+         WriteOncePersistentDict(
+             "pyopencl-generated-scan-kernel-cache-v1",
+             key_builder=_NumpyTypesKeyBuilder(),
+             in_mem_cache_size=0,
+             safe_sync=False)
+
+
+ class GenericScanKernel(GenericScanKernelBase):
+     """Generates and executes code that performs prefix sums ("scans") on
+     arbitrary types, with many possible tweaks.
+
+     Usage example::
+
+         from pyopencl.scan import GenericScanKernel
+         knl = GenericScanKernel(
+                 context, np.int32,
+                 arguments="__global int *ary",
+                 input_expr="ary[i]",
+                 scan_expr="a+b", neutral="0",
+                 output_statement="ary[i+1] = item;")
+
+         a = cl.array.arange(queue, 10000, dtype=np.int32)
+         knl(a, queue=queue)
+
+     .. automethod:: __init__
+     .. automethod:: __call__
+     """
+
+     def finish_setup(self) -> None:
+         # Before generating the kernel, see if it's cached.
+         from pyopencl.cache import get_device_cache_id
+         devices_key = tuple(get_device_cache_id(device)
+                 for device in self.devices)
+
+         cache_key = (self.kernel_key, devices_key)
+         from_cache = False
+
+         if not cl._PYOPENCL_NO_CACHE:
+             try:
+                 result = generic_scan_kernel_cache[cache_key]
+                 from_cache = True
+                 logger.debug(
+                     "cache hit for generated scan kernel '%s'", self.name_prefix)
+                 (
+                     self.first_level_scan_gen_info,
+                     self.second_level_scan_gen_info,
+                     self.final_update_gen_info) = result
+             except KeyError:
+                 pass
+
+         if not from_cache:
+             logger.debug(
+                 "cache miss for generated scan kernel '%s'", self.name_prefix)
+             self._finish_setup_impl()
+
+             result = (self.first_level_scan_gen_info,
+                       self.second_level_scan_gen_info,
+                       self.final_update_gen_info)
+
+             if not cl._PYOPENCL_NO_CACHE:
+                 generic_scan_kernel_cache.store_if_not_present(cache_key, result)
+
+         # Build the kernels.
+         self.first_level_scan_info = self.first_level_scan_gen_info.build(
+             self.context, self.options)
+         del self.first_level_scan_gen_info
+
+         self.second_level_scan_info = self.second_level_scan_gen_info.build(
+             self.context, self.options)
+         del self.second_level_scan_gen_info
+
+         self.final_update_info = self.final_update_gen_info.build(
+             self.context, self.options)
+         del self.final_update_gen_info
+
+     def _finish_setup_impl(self) -> None:
+         # {{{ find usable workgroup/k-group size, build first-level scan
+
+         trip_count = 0
+
+         avail_local_mem = min(
+                 dev.local_mem_size
+                 for dev in self.devices)
+
+         if "CUDA" in self.devices[0].platform.name:
+             # not sure where these go, but roughly this much seems unavailable.
+             avail_local_mem -= 0x400
+
+         is_cpu = bool(self.devices[0].type & cl.device_type.CPU)
+         is_gpu = bool(self.devices[0].type & cl.device_type.GPU)
+
+         if is_cpu:
+             # (about the widest vector a CPU can support, also taking
+             # into account that CPUs don't hide latency by large work groups)
+             max_scan_wg_size = 16
+             wg_size_multiples = 4
+         else:
+             max_scan_wg_size = min(dev.max_work_group_size for dev in self.devices)
+             wg_size_multiples = 64
+
+         # Intel beignet fails "Out of shared local memory" in test_scan int64
+         # and asserts in test_sort with this enabled:
+         # https://github.com/inducer/pyopencl/pull/238
+         # A beignet bug report (outside of pyopencl) suggests packed structs
+         # (which this is) can even give wrong results:
+         # https://bugs.freedesktop.org/show_bug.cgi?id=98717
+         # TODO: does this also affect Intel Compute Runtime?
+         use_bank_conflict_avoidance = (
+                 self.dtype.itemsize > 4 and self.dtype.itemsize % 8 == 0
+                 and is_gpu
+                 and "beignet" not in self.devices[0].platform.version.lower())
+
+         # k_group_size should be a power of two because of in-kernel
+         # division by that number.
+
+         solutions: list[tuple[int, int, int]] = []
+         for k_exp in range(0, 9):
+             for wg_size in range(wg_size_multiples, max_scan_wg_size+1,
+                     wg_size_multiples):
+
+                 k_group_size = 2**k_exp
+                 lmem_use = self.get_local_mem_use(wg_size, k_group_size,
+                         use_bank_conflict_avoidance)
+                 if lmem_use <= avail_local_mem:
+                     solutions.append((wg_size*k_group_size, k_group_size, wg_size))
+
+         if is_gpu:
+             for wg_size_floor in [256, 192, 128]:
+                 have_sol_above_floor = any(wg_size >= wg_size_floor
+                         for _, _, wg_size in solutions)
+
+                 if have_sol_above_floor:
+                     # delete all solutions not meeting the wg size floor
+                     solutions = [(total, try_k_group_size, try_wg_size)
+                             for total, try_k_group_size, try_wg_size in solutions
+                             if try_wg_size >= wg_size_floor]
+                     break
+
+         _, k_group_size, max_scan_wg_size = max(solutions)
+
+         while True:
+             candidate_scan_gen_info = self.generate_scan_kernel(
+                     max_scan_wg_size, self.parsed_args,
+                     _process_code_for_macro(self.input_expr),
+                     self.is_segment_start_expr,
+                     input_fetch_exprs=self.input_fetch_exprs,
+                     is_first_level=True,
+                     store_segment_start_flags=self.store_segment_start_flags,
+                     k_group_size=k_group_size,
+                     use_bank_conflict_avoidance=use_bank_conflict_avoidance)
+
+             candidate_scan_info = candidate_scan_gen_info.build(
+                     self.context, self.options)
+
+             # Will this device actually let us execute this kernel
+             # at the desired work group size? Building it is the
+             # only way to find out.
+             kernel_max_wg_size = min(
+                     candidate_scan_info.kernel.get_work_group_info(
+                         cl.kernel_work_group_info.WORK_GROUP_SIZE,
+                         dev)
+                     for dev in self.devices)
+
+             if candidate_scan_info.wg_size <= kernel_max_wg_size:
+                 break
+             else:
+                 max_scan_wg_size = min(kernel_max_wg_size, max_scan_wg_size)
+
+             trip_count += 1
+             assert trip_count <= 20
+
+         self.first_level_scan_gen_info = candidate_scan_gen_info
+         assert (_round_down_to_power_of_2(candidate_scan_info.wg_size)
+                 == candidate_scan_info.wg_size)
+
+         # }}}
+
+         # {{{ build second-level scan
+
+         from pyopencl.tools import VectorArg
+         second_level_arguments = [
+             *self.parsed_args,
+             VectorArg(self.dtype, "interval_sums"),
+         ]
+
+         second_level_build_kwargs: dict[str, str | None] = {}
+         if self.is_segmented:
+             second_level_arguments.append(
+                     VectorArg(self.index_dtype,
+                         "g_first_segment_start_in_interval_input"))
+
+             # is_segment_start_expr answers the question "should previous sums
+             # spill over into this item?". Since
+             # g_first_segment_start_in_interval_input records whether a segment
+             # boundary was found in an interval of data, spilling over is fine
+             # exactly if no boundary was found.
+             second_level_build_kwargs["is_segment_start_expr"] = \
+                     "g_first_segment_start_in_interval_input[i] != NO_SEG_BOUNDARY"
+         else:
+             second_level_build_kwargs["is_segment_start_expr"] = None
+
+         self.second_level_scan_gen_info = self.generate_scan_kernel(
+                 max_scan_wg_size,
+                 arguments=second_level_arguments,
+                 input_expr="interval_sums[i]",
+                 input_fetch_exprs=[],
+                 is_first_level=False,
+                 store_segment_start_flags=False,
+                 k_group_size=k_group_size,
+                 use_bank_conflict_avoidance=use_bank_conflict_avoidance,
+                 **second_level_build_kwargs)
+
+         # }}}
+
+         # {{{ generate final update kernel
+
+         update_wg_size = min(max_scan_wg_size, 256)
+
+         final_update_tpl = _make_template(UPDATE_SOURCE)
+         final_update_src = str(final_update_tpl.render(
+             wg_size=update_wg_size,
+             output_statement=self.output_statement,
+             arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
+             argument_signature=", ".join(
+                 arg.declarator() for arg in self.parsed_args),
+             is_segment_start_expr=self.is_segment_start_expr,
+             input_expr=_process_code_for_macro(self.input_expr),
+             use_lookbehind_update=self.use_lookbehind_update,
+             **self.code_variables))
+
+         update_scalar_arg_dtypes = [
+             *get_arg_list_scalar_arg_dtypes(self.parsed_args),
+             self.index_dtype, self.index_dtype, None, None]
+
+         if self.is_segmented:
+             # g_first_segment_start_in_interval
+             update_scalar_arg_dtypes.append(None)
+         if self.store_segment_start_flags:
+             update_scalar_arg_dtypes.append(None)  # g_segment_start_flags
+
+         self.final_update_gen_info = _GeneratedFinalUpdateKernelInfo(
+                 final_update_src,
+                 self.name_prefix + "_final_update",
+                 update_scalar_arg_dtypes,
+                 update_wg_size)
+
+         # }}}
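The parameter search above can be paraphrased: enumerate power-of-two `k_group_size` values and work-group-size multiples, keep the combinations whose local-memory footprint fits, and pick the one that processes the most items per unit. A condensed model (`pick_unit_shape` and `lmem_use` are hypothetical stand-ins; the real constraint is `get_local_mem_use` below):

    def pick_unit_shape(avail_lmem, lmem_use, max_wg=256, wg_step=64):
        solutions = [
            (wg * (2 ** k_exp), 2 ** k_exp, wg)
            for k_exp in range(9)
            for wg in range(wg_step, max_wg + 1, wg_step)
            if lmem_use(wg, 2 ** k_exp) <= avail_lmem]
        # Largest unit (wg_size * k_group_size) wins, as in max(solutions).
        return max(solutions)

    # e.g. pick_unit_shape(32 * 1024, lambda wg, k: 8 * (k + 1) * (wg + 1))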
1400
+
1401
+ # {{{ scan kernel build/properties
1402
+
1403
+ def get_local_mem_use(
1404
+ self, k_group_size: int, wg_size: int,
1405
+ use_bank_conflict_avoidance: bool) -> int:
1406
+ arg_dtypes = {}
1407
+ for arg in self.parsed_args:
1408
+ arg_dtypes[arg.name] = arg.dtype
1409
+
1410
+ fetch_expr_offsets: dict[str, set[int]] = {}
1411
+ for _name, arg_name, ife_offset in self.input_fetch_exprs:
1412
+ fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
1413
+
1414
+ itemsize = self.dtype.itemsize
1415
+ if use_bank_conflict_avoidance:
1416
+ itemsize += 4
1417
+
1418
+ return (
1419
+ # ldata
1420
+ itemsize*(k_group_size+1)*(wg_size+1)
1421
+
1422
+ # l_segment_start_flags
1423
+ + k_group_size*wg_size
1424
+
1425
+ # l_first_segment_start_in_subtree
1426
+ + self.index_dtype.itemsize*wg_size
1427
+
1428
+ + k_group_size*wg_size*sum(
1429
+ arg_dtypes[arg_name].itemsize
1430
+ for arg_name, ife_offsets in list(fetch_expr_offsets.items())
1431
+ if -1 in ife_offsets or len(ife_offsets) > 1))
1432
+
1433
+ def generate_scan_kernel(
1434
+ self,
1435
+ max_wg_size: int,
1436
+ arguments: Sequence[DtypedArgument],
1437
+ input_expr: str,
1438
+ is_segment_start_expr: str | None,
1439
+ input_fetch_exprs: Sequence[tuple[str, str, int]],
1440
+ is_first_level: bool,
1441
+ store_segment_start_flags: bool,
1442
+ k_group_size: int,
1443
+ use_bank_conflict_avoidance: bool) -> _GeneratedScanKernelInfo:
1444
+ scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)
1445
+
1446
+ # Empirically found on Nv hardware: no need to be bigger than this size
1447
+ wg_size = _round_down_to_power_of_2(
1448
+ min(max_wg_size, 256))
1449
+
1450
+ kernel_name = cast("str", self.code_variables["name_prefix"])
1451
+ if is_first_level:
1452
+ kernel_name += "_lev1"
1453
+ else:
1454
+ kernel_name += "_lev2"
1455
+
1456
+ scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
1457
+ scan_src = str(scan_tpl.render(
1458
+ wg_size=wg_size,
1459
+ input_expr=input_expr,
1460
+ k_group_size=k_group_size,
1461
+ arg_offset_adjustment=get_arg_offset_adjuster_code(arguments),
1462
+ argument_signature=", ".join(arg.declarator() for arg in arguments),
1463
+ is_segment_start_expr=is_segment_start_expr,
1464
+ input_fetch_exprs=input_fetch_exprs,
1465
+ is_first_level=is_first_level,
1466
+ store_segment_start_flags=store_segment_start_flags,
1467
+ use_bank_conflict_avoidance=use_bank_conflict_avoidance,
1468
+ kernel_name=kernel_name,
1469
+ **self.code_variables))
1470
+
1471
+ scalar_arg_dtypes.extend(
1472
+ (None, self.index_dtype, self.index_dtype))
1473
+ if is_first_level:
1474
+ scalar_arg_dtypes.append(None) # interval_results
1475
+ if self.is_segmented and is_first_level:
1476
+ scalar_arg_dtypes.append(None) # g_first_segment_start_in_interval
1477
+ if store_segment_start_flags:
1478
+ scalar_arg_dtypes.append(None) # g_segment_start_flags
1479
+
1480
+ return _GeneratedScanKernelInfo(
1481
+ scan_src=scan_src,
1482
+ kernel_name=kernel_name,
1483
+ scalar_arg_dtypes=scalar_arg_dtypes,
1484
+ wg_size=wg_size,
1485
+ k_group_size=k_group_size)
1486
+
1487
+ # }}}
1488
+
1489
+ def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
1490
+ """
1491
+ |std-enqueue-blurb|
1492
+
1493
+ .. note::
1494
+
1495
+ The returned :class:`pyopencl.Event` corresponds only to part of the
1496
+ execution of the scan. It is not suitable for profiling.
1497
+
1498
+ :arg queue: queue on which to execute the scan. If not given, the
1499
+ queue of the first :class:`pyopencl.array.Array` in *args* is used
1500
+ :arg allocator: an allocator for the temporary arrays and results. If
1501
+ not given, the allocator of the first :class:`pyopencl.array.Array`
1502
+ in *args* is used.
1503
+ :arg size: specify the length of the scan to be carried out. If not
1504
+ given, this length is inferred from the first argument
1505
+ :arg wait_for: a :class:`list` of events to wait for.
1506
+ """
1507
+
1508
+ # {{{ argument processing
1509
+
1510
+ allocator = kwargs.get("allocator")
1511
+ queue = kwargs.get("queue")
1512
+ n = kwargs.get("size")
1513
+ wait_for = kwargs.get("wait_for")
1514
+
1515
+ if wait_for is None:
1516
+ wait_for = []
1517
+ else:
1518
+ wait_for = list(wait_for)
1519
+
1520
+ if len(args) != len(self.parsed_args):
1521
+ raise TypeError(
1522
+ f"expected {len(self.parsed_args)} arguments, got {len(args)}")
1523
+
1524
+ first_array = args[self.first_array_idx]
1525
+ allocator = allocator or first_array.allocator
1526
+ queue = queue or first_array.queue
1527
+
1528
+ if n is None:
1529
+ n, = first_array.shape
1530
+
1531
+ if n == 0:
1532
+ # We're done here. (But pretend to return an event.)
1533
+ return cl.enqueue_marker(queue, wait_for=wait_for)
1534
+
1535
+ data_args = []
1536
+ for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
1537
+ from pyopencl.tools import VectorArg
1538
+ if isinstance(arg_descr, VectorArg):
1539
+ data_args.append(arg_val.base_data)
1540
+ if arg_descr.with_offset:
1541
+ data_args.append(arg_val.offset)
1542
+ wait_for.extend(arg_val.events)
1543
+ else:
1544
+ data_args.append(arg_val)
1545
+
1546
+ # }}}
1547
+
1548
+ l1_info = self.first_level_scan_info
1549
+ l2_info = self.second_level_scan_info
1550
+
1551
+ # see CL source above for terminology
1552
+ unit_size = l1_info.wg_size * l1_info.k_group_size
1553
+ max_intervals = 3*max(dev.max_compute_units for dev in self.devices)
1554
+
1555
+ from pytools import uniform_interval_splitting
1556
+ interval_size, num_intervals = uniform_interval_splitting(
1557
+ n, unit_size, max_intervals)
1558
+
1559
+ # {{{ allocate some buffers
1560
+
1561
+ interval_results = cl_array.empty(queue,
1562
+ num_intervals, dtype=self.dtype,
1563
+ allocator=allocator)
1564
+
1565
+ partial_scan_buffer = cl_array.empty(
1566
+ queue, n, dtype=self.dtype,
1567
+ allocator=allocator)
1568
+
1569
+ if self.store_segment_start_flags:
1570
+ segment_start_flags = cl_array.empty(
1571
+ queue, n, dtype=np.bool_,
1572
+ allocator=allocator)
1573
+
1574
+ # }}}
1575
+
1576
+ # {{{ first level scan of interval (one interval per block)
1577
+
1578
+ scan1_args = [
1579
+ *data_args,
1580
+ partial_scan_buffer.data, n, interval_size, interval_results.data,
1581
+ ]
1582
+
1583
+ if self.is_segmented:
1584
+ first_segment_start_in_interval = cl_array.empty(queue,
1585
+ num_intervals, dtype=self.index_dtype,
1586
+ allocator=allocator)
1587
+ scan1_args.append(first_segment_start_in_interval.data)
1588
+
1589
+ if self.store_segment_start_flags:
1590
+ scan1_args.append(segment_start_flags.data)
1591
+
1592
+ l1_evt = l1_info.kernel(
1593
+ queue, (num_intervals,), (l1_info.wg_size,),
1594
+ *scan1_args, g_times_l=True, wait_for=wait_for)
1595
+
1596
+ # }}}
1597
+
1598
+ # {{{ second level scan of per-interval results
1599
+
1600
+ # can scan at most one interval
1601
+ assert interval_size >= num_intervals
1602
+
1603
+ scan2_args = [
1604
+ *data_args,
1605
+ interval_results.data, # interval_sums
1606
+ ]
1607
+
1608
+ if self.is_segmented:
1609
+ scan2_args.append(first_segment_start_in_interval.data)
1610
+ scan2_args = [
1611
+ *scan2_args,
1612
+ interval_results.data, # partial_scan_buffer
1613
+ num_intervals, interval_size]
1614
+
1615
+ l2_evt = l2_info.kernel(
1616
+ queue, (1,), (l1_info.wg_size,),
1617
+ *scan2_args, g_times_l=True, wait_for=[l1_evt])
1618
+
1619
+ # }}}
1620
+
1621
+ # {{{ update intervals with result of interval scan
1622
+
1623
+ upd_args = [
1624
+ *data_args,
1625
+ n, interval_size, interval_results.data, partial_scan_buffer.data]
1626
+ if self.is_segmented:
1627
+ upd_args.append(first_segment_start_in_interval.data)
1628
+ if self.store_segment_start_flags:
1629
+ upd_args.append(segment_start_flags.data)
1630
+
1631
+ return self.final_update_info.kernel(
1632
+ queue, (num_intervals,),
1633
+ (self.final_update_info.update_wg_size,),
1634
+ *upd_args, g_times_l=True, wait_for=[l2_evt])
1635
+
1636
+ # }}}
1637
+
1638
+ # }}}
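`uniform_interval_splitting` (from pytools) picks the interval decomposition used in `__call__`, after which the three kernels chain through events: the first-level scan, then the second-level scan over interval sums, then the final update. A plausible model of the splitting contract, for orientation only (a hedged reconstruction, not the pytools source): it returns `(interval_size, num_intervals)` with `interval_size` a multiple of `granularity` (here, the unit size) and `num_intervals <= max_intervals`, such that the intervals cover all `n` items:

    def uniform_interval_splitting_model(n, granularity, max_intervals):
        # Hedged reconstruction, not the pytools source.
        def ceil_div(a, b):
            return -(-a // b)
        grains = ceil_div(n, granularity)
        grains_per_interval = ceil_div(grains, max_intervals)
        interval_size = grains_per_interval * granularity
        num_intervals = ceil_div(n, interval_size)
        return interval_size, num_intervals

    # e.g. n = 10**6 items, unit_size = 1024, max_intervals = 3 * 32 CUs:
    print(uniform_interval_splitting_model(10**6, 1024, 96))  # (11264, 89)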
+
+
+ # {{{ debug kernel
+
+ DEBUG_SCAN_TEMPLATE = SHARED_PREAMBLE + r"""//CL//
+
+ KERNEL
+ REQD_WG_SIZE(1, 1, 1)
+ void ${name_prefix}_debug_scan(
+     __global scan_type *scan_tmp,
+     ${argument_signature},
+     const index_type N)
+ {
+     scan_type current = ${neutral};
+     scan_type prev;
+
+     ${arg_offset_adjustment}
+
+     for (index_type i = 0; i < N; ++i)
+     {
+         %for name, arg_name, ife_offset in input_fetch_exprs:
+             ${arg_ctypes[arg_name]} ${name};
+             %if ife_offset < 0:
+                 if (i+${ife_offset} >= 0)
+                     ${name} = ${arg_name}[i+${ife_offset}];
+             %else:
+                 ${name} = ${arg_name}[i];
+             %endif
+         %endfor
+
+         scan_type my_val = INPUT_EXPR(i);
+
+         prev = current;
+         %if is_segmented:
+             bool is_seg_start = IS_SEG_START(i, my_val);
+         %endif
+
+         current = SCAN_EXPR(prev, my_val,
+             %if is_segmented:
+                 is_seg_start
+             %else:
+                 false
+             %endif
+             );
+         scan_tmp[i] = current;
+     }
+
+     scan_type last_item = scan_tmp[N-1];
+
+     for (index_type i = 0; i < N; ++i)
+     {
+         scan_type item = scan_tmp[i];
+         scan_type prev_item;
+         if (i)
+             prev_item = scan_tmp[i-1];
+         else
+             prev_item = ${neutral};
+
+         {
+             ${output_statement};
+         }
+     }
+ }
+ """
+
+
+ class GenericDebugScanKernel(GenericScanKernelBase):
+     """
+     Performs the same function and has the same interface as
+     :class:`GenericScanKernel`, but uses a dead-simple, sequential scan. Works
+     best on CPU platforms, and helps isolate bugs in scans by removing the
+     potential for issues originating in parallel execution.
+
+     .. automethod:: __call__
+     """
+
+     def finish_setup(self) -> None:
+         scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
+         scan_src = str(scan_tpl.render(
+             output_statement=self.output_statement,
+             arg_offset_adjustment=get_arg_offset_adjuster_code(self.parsed_args),
+             argument_signature=", ".join(
+                 arg.declarator() for arg in self.parsed_args),
+             is_segment_start_expr=self.is_segment_start_expr,
+             input_expr=_process_code_for_macro(self.input_expr),
+             input_fetch_exprs=self.input_fetch_exprs,
+             wg_size=1,
+             **self.code_variables))
+
+         scan_prg = cl.Program(self.context, scan_src).build(self.options)
+         self.kernel = getattr(scan_prg, f"{self.name_prefix}_debug_scan")
+         scalar_arg_dtypes = [
+             None,
+             *get_arg_list_scalar_arg_dtypes(self.parsed_args),
+             self.index_dtype,
+         ]
+         self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)
+
+     def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
+         """See :meth:`GenericScanKernel.__call__`."""
+
+         # {{{ argument processing
+
+         allocator = kwargs.get("allocator")
+         queue = kwargs.get("queue")
+         n = kwargs.get("size")
+         wait_for = kwargs.get("wait_for")
+
+         if wait_for is None:
+             wait_for = []
+         else:
+             # We'll be modifying it below.
+             wait_for = list(wait_for)
+
+         if len(args) != len(self.parsed_args):
+             raise TypeError(
+                 f"expected {len(self.parsed_args)} arguments, got {len(args)}")
+
+         first_array = args[self.first_array_idx]
+         allocator = allocator or first_array.allocator
+         queue = queue or first_array.queue
+
+         if n is None:
+             n, = first_array.shape
+
+         scan_tmp = cl_array.empty(queue,
+                 n, dtype=self.dtype,
+                 allocator=allocator)
+
+         data_args = [scan_tmp.data]
+         from pyopencl.tools import VectorArg
+         for arg_descr, arg_val in zip(self.parsed_args, args, strict=True):
+             if isinstance(arg_descr, VectorArg):
+                 data_args.append(arg_val.base_data)
+                 if arg_descr.with_offset:
+                     data_args.append(arg_val.offset)
+                 wait_for.extend(arg_val.events)
+             else:
+                 data_args.append(arg_val)
+
+         # }}}
+
+         return self.kernel(queue, (1,), (1,), *([*data_args, n]), wait_for=wait_for)
+
+ # }}}
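Since `GenericDebugScanKernel` accepts the same constructor arguments as `GenericScanKernel`, a misbehaving parallel scan can be cross-checked against the sequential reference by swapping only the class (a sketch; `fast_knl`/`slow_knl` are illustrative names):

    from pyopencl.scan import GenericDebugScanKernel, GenericScanKernel

    kwargs = dict(
        arguments="__global int *ary",
        input_expr="ary[i]",
        scan_expr="a+b", neutral="0",
        output_statement="ary[i] = item;")

    fast_knl = GenericScanKernel(context, np.int32, **kwargs)
    slow_knl = GenericDebugScanKernel(context, np.int32, **kwargs)  # reference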
+
+
+ # {{{ compatibility interface
+
+ class _LegacyScanKernelBase(GenericScanKernel):
+     def __init__(self, ctx, dtype,
+             scan_expr, neutral=None,
+             name_prefix="scan", options=None, preamble="", devices=None):
+         scan_ctype = dtype_to_ctype(dtype)
+         GenericScanKernel.__init__(self,
+                 ctx, dtype,
+                 arguments="__global {} *input_ary, __global {} *output_ary".format(
+                     scan_ctype, scan_ctype),
+                 input_expr="input_ary[i]",
+                 scan_expr=scan_expr,
+                 neutral=neutral,
+                 output_statement=self.ary_output_statement,
+                 options=options, preamble=preamble, devices=devices)
+
+     @property
+     def ary_output_statement(self):
+         raise NotImplementedError
+
+     def __call__(self, input_ary, output_ary=None, allocator=None, queue=None):
+         allocator = allocator or input_ary.allocator
+
+         # Default the output before touching output_ary.queue below.
+         if output_ary is None:
+             output_ary = input_ary
+
+         queue = queue or input_ary.queue or output_ary.queue
+
+         if isinstance(output_ary, str) and output_ary == "new":
+             output_ary = cl_array.empty_like(input_ary, allocator=allocator)
+
+         if input_ary.shape != output_ary.shape:
+             raise ValueError("input and output must have the same shape")
+
+         if not input_ary.flags.forc:
+             raise RuntimeError("ScanKernel cannot "
+                     "deal with non-contiguous arrays")
+
+         n, = input_ary.shape
+
+         if not n:
+             return output_ary
+
+         GenericScanKernel.__call__(self,
+                 input_ary, output_ary, allocator=allocator, queue=queue)
+
+         return output_ary
+
+
+ class InclusiveScanKernel(_LegacyScanKernelBase):
+     ary_output_statement = "output_ary[i] = item;"
+
+
+ class ExclusiveScanKernel(_LegacyScanKernelBase):
+     ary_output_statement = "output_ary[i] = prev_item;"
+
+ # }}}
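Usage of the compatibility classes stays close to their pre-`GenericScanKernel` form; both scan in place when no output array is given (a sketch reusing the `context`/`queue` of the earlier examples):

    from pyopencl.scan import ExclusiveScanKernel, InclusiveScanKernel

    inc_knl = InclusiveScanKernel(context, np.int32, "a+b", neutral="0")
    exc_knl = ExclusiveScanKernel(context, np.int32, "a+b", neutral="0")

    a = cl_array.arange(queue, 8, dtype=np.int32)   # 0 1 2 ... 7
    inc_knl(a)   # a now holds the inclusive prefix sums, and is returned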
1843
+
1844
+
1845
+ # {{{ template
1846
+
1847
+ class ScanTemplate(KernelTemplateBase):
1848
+ def __init__(
1849
+ self,
1850
+ arguments: str | list[DtypedArgument],
1851
+ input_expr: str,
1852
+ scan_expr: str,
1853
+ neutral: str | None,
1854
+ output_statement: str,
1855
+ is_segment_start_expr: str | None = None,
1856
+ input_fetch_exprs: list[tuple[str, str, int]] | None = None,
1857
+ name_prefix: str = "scan",
1858
+ preamble: str = "",
1859
+ template_processor: Any = None) -> None:
1860
+ super().__init__(template_processor=template_processor)
1861
+
1862
+ if input_fetch_exprs is None:
1863
+ input_fetch_exprs = []
1864
+
1865
+ self.arguments = arguments
1866
+ self.input_expr = input_expr
1867
+ self.scan_expr = scan_expr
1868
+ self.neutral = neutral
1869
+ self.output_statement = output_statement
1870
+ self.is_segment_start_expr = is_segment_start_expr
1871
+ self.input_fetch_exprs = input_fetch_exprs
1872
+ self.name_prefix = name_prefix
1873
+ self.preamble = preamble
1874
+
1875
+ def build_inner(self, context, type_aliases=(), var_values=(),
1876
+ more_preamble="", more_arguments=(), declare_types=(),
1877
+ options=None, devices=None, scan_cls=GenericScanKernel):
1878
+ renderer = self.get_renderer(type_aliases, var_values, context, options)
1879
+
1880
+ arg_list = renderer.render_argument_list(self.arguments, more_arguments)
1881
+
1882
+ type_decl_preamble = renderer.get_type_decl_preamble(
1883
+ context.devices[0], declare_types, arg_list)
1884
+
1885
+ return scan_cls(context, renderer.type_aliases["scan_t"],
1886
+ renderer.render_argument_list(self.arguments, more_arguments),
1887
+ renderer(self.input_expr), renderer(self.scan_expr),
1888
+ renderer(self.neutral), renderer(self.output_statement),
1889
+ is_segment_start_expr=renderer(self.is_segment_start_expr),
1890
+ input_fetch_exprs=self.input_fetch_exprs,
1891
+ index_dtype=renderer.type_aliases.get("index_t", np.int32),
1892
+ name_prefix=renderer(self.name_prefix), options=options,
1893
+ preamble=(
1894
+ type_decl_preamble
1895
+ + "\n"
1896
+ + renderer(self.preamble + "\n" + more_preamble)),
1897
+ devices=devices)
1898
+
1899
+ # }}}
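`ScanTemplate` defers type choices to build time: the same template can be instantiated for several dtypes via the `type_aliases` mechanism of `KernelTemplateBase`. A hedged sketch, assuming the `scan_t`/`index_t` aliases consumed by `build_inner` above and a `build()` entry point on the base class (both assumptions about `pyopencl.tools`, not confirmed by this file):

    cumsum_tpl = ScanTemplate(
        arguments="itype *input, otype *output",
        input_expr="input[i]",
        scan_expr="a+b", neutral="0",
        output_statement="output[i] = item;",
        name_prefix="cumsum")

    knl = cumsum_tpl.build(
        context,
        type_aliases=(("scan_t", np.int64), ("index_t", np.int32),
                      ("itype", np.int32), ("otype", np.int64)))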
+
+
+ # {{{ 'canned' scan kernels
+
+ @context_dependent_memoize
+ def get_cumsum_kernel(context, input_dtype, output_dtype):
+     from pyopencl.tools import VectorArg
+     return GenericScanKernel(
+         context, output_dtype,
+         arguments=[
+             VectorArg(input_dtype, "input"),
+             VectorArg(output_dtype, "output"),
+             ],
+         input_expr="input[i]",
+         scan_expr="a+b", neutral="0",
+         output_statement="""
+             output[i] = item;
+             """)
+
+ # }}}
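`get_cumsum_kernel` appears to be the canned kernel that `pyopencl.array`'s cumulative-sum functionality builds on. A usage sketch (illustrative; the kernel scans `input` into an `output` array of possibly wider type):

    knl = get_cumsum_kernel(context, np.int32, np.int64)
    inp = cl_array.arange(queue, 16, dtype=np.int32)
    out = cl_array.empty(queue, 16, dtype=np.int64)
    knl(inp, out, queue=queue)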
+
+ # vim: filetype=pyopencl:fdm=marker