pyopencl 2025.1__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyopencl might be problematic. Click here for more details.

Files changed (42)
  1. pyopencl/__init__.py +2410 -0
  2. pyopencl/_cl.cpython-313-darwin.so +0 -0
  3. pyopencl/_cluda.py +54 -0
  4. pyopencl/_mymako.py +14 -0
  5. pyopencl/algorithm.py +1449 -0
  6. pyopencl/array.py +3362 -0
  7. pyopencl/bitonic_sort.py +242 -0
  8. pyopencl/bitonic_sort_templates.py +594 -0
  9. pyopencl/cache.py +535 -0
  10. pyopencl/capture_call.py +177 -0
  11. pyopencl/characterize/__init__.py +456 -0
  12. pyopencl/characterize/performance.py +237 -0
  13. pyopencl/cl/pyopencl-airy.cl +324 -0
  14. pyopencl/cl/pyopencl-bessel-j-complex.cl +238 -0
  15. pyopencl/cl/pyopencl-bessel-j.cl +1084 -0
  16. pyopencl/cl/pyopencl-bessel-y.cl +435 -0
  17. pyopencl/cl/pyopencl-complex.h +303 -0
  18. pyopencl/cl/pyopencl-eval-tbl.cl +120 -0
  19. pyopencl/cl/pyopencl-hankel-complex.cl +444 -0
  20. pyopencl/cl/pyopencl-random123/array.h +325 -0
  21. pyopencl/cl/pyopencl-random123/openclfeatures.h +93 -0
  22. pyopencl/cl/pyopencl-random123/philox.cl +486 -0
  23. pyopencl/cl/pyopencl-random123/threefry.cl +864 -0
  24. pyopencl/clmath.py +280 -0
  25. pyopencl/clrandom.py +409 -0
  26. pyopencl/cltypes.py +137 -0
  27. pyopencl/compyte/.gitignore +21 -0
  28. pyopencl/compyte/__init__.py +0 -0
  29. pyopencl/compyte/array.py +214 -0
  30. pyopencl/compyte/dtypes.py +290 -0
  31. pyopencl/compyte/pyproject.toml +54 -0
  32. pyopencl/elementwise.py +1171 -0
  33. pyopencl/invoker.py +421 -0
  34. pyopencl/ipython_ext.py +68 -0
  35. pyopencl/reduction.py +786 -0
  36. pyopencl/scan.py +1915 -0
  37. pyopencl/tools.py +1527 -0
  38. pyopencl/version.py +9 -0
  39. pyopencl-2025.1.dist-info/METADATA +108 -0
  40. pyopencl-2025.1.dist-info/RECORD +42 -0
  41. pyopencl-2025.1.dist-info/WHEEL +5 -0
  42. pyopencl-2025.1.dist-info/licenses/LICENSE +282 -0
pyopencl/scan.py ADDED
@@ -0,0 +1,1915 @@
1
+ """Scan primitive."""
2
+
3
+
4
+ __copyright__ = """
5
+ Copyright 2011-2012 Andreas Kloeckner
6
+ Copyright 2008-2011 NVIDIA Corporation
7
+ """
8
+
9
+ __license__ = """
10
+ Licensed under the Apache License, Version 2.0 (the "License");
11
+ you may not use this file except in compliance with the License.
12
+ You may obtain a copy of the License at
13
+
14
+ https://www.apache.org/licenses/LICENSE-2.0
15
+
16
+ Unless required by applicable law or agreed to in writing, software
17
+ distributed under the License is distributed on an "AS IS" BASIS,
18
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ See the License for the specific language governing permissions and
20
+ limitations under the License.
21
+
22
+ Derived from code within the Thrust project, https://github.com/NVIDIA/thrust
23
+ """
24
+
25
+ import logging
26
+ from abc import ABC, abstractmethod
27
+ from dataclasses import dataclass
28
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
29
+
30
+ import numpy as np
31
+
32
+ from pytools.persistent_dict import WriteOncePersistentDict
33
+
34
+ import pyopencl as cl
35
+ import pyopencl._mymako as mako
36
+ import pyopencl.array
37
+ from pyopencl._cluda import CLUDA_PREAMBLE
38
+ from pyopencl.tools import (
39
+ DtypedArgument,
40
+ KernelTemplateBase,
41
+ _NumpyTypesKeyBuilder,
42
+ _process_code_for_macro,
43
+ bitlog2,
44
+ context_dependent_memoize,
45
+ dtype_to_ctype,
46
+ get_arg_list_scalar_arg_dtypes,
47
+ get_arg_offset_adjuster_code,
48
+ )
49
+
50
+
51
+ logger = logging.getLogger(__name__)
52
+
53
+
54
+ # {{{ preamble
55
+
56
+ SHARED_PREAMBLE = CLUDA_PREAMBLE + """//CL//
57
+ #define WG_SIZE ${wg_size}
58
+
59
+ #define SCAN_EXPR(a, b, across_seg_boundary) ${scan_expr}
60
+ #define INPUT_EXPR(i) (${input_expr})
61
+ %if is_segmented:
62
+ #define IS_SEG_START(i, a) (${is_segment_start_expr})
63
+ %endif
64
+
65
+ ${preamble}
66
+
67
+ typedef ${dtype_to_ctype(scan_dtype)} scan_type;
68
+ typedef ${dtype_to_ctype(index_dtype)} index_type;
69
+
70
+ // NO_SEG_BOUNDARY is the largest representable integer in index_type.
71
+ // This assumption is used in code below.
72
+ #define NO_SEG_BOUNDARY ${str(np.iinfo(index_dtype).max)}
73
+ """
74
+
75
+ # }}}
76
+
77
+ # {{{ main scan code
78
+
79
+ # Algorithm: Each work group is responsible for one contiguous
80
+ # 'interval'. There are just enough intervals to fill all compute
81
+ # units. Intervals are split into 'units'. A unit is what gets
82
+ # worked on in parallel by one work group.
83
+ #
84
+ # in index space:
85
+ # interval > unit > local-parallel > k-group
86
+ #
87
+ # (Note that there is also a transpose in here: The data is read
88
+ # with local ids along linear index order.)
89
+ #
90
+ # Each unit has two axes--the local-id axis and the k axis.
91
+ #
92
+ # unit 0:
93
+ # | | | | | | | | | | ----> lid
94
+ # | | | | | | | | | |
95
+ # | | | | | | | | | |
96
+ # | | | | | | | | | |
97
+ # | | | | | | | | | |
98
+ #
99
+ # |
100
+ # v k (fastest-moving in linear index)
101
+ #
102
+ # unit 1:
103
+ # | | | | | | | | | | ----> lid
104
+ # | | | | | | | | | |
105
+ # | | | | | | | | | |
106
+ # | | | | | | | | | |
107
+ # | | | | | | | | | |
108
+ #
109
+ # |
110
+ # v k (fastest-moving in linear index)
111
+ #
112
+ # ...
113
+ #
114
+ # At a device-global level, this is a three-phase algorithm, in
115
+ # which first each interval does its local scan, then a scan
116
+ # across intervals exchanges data globally, and the final update
117
+ # adds the exchanged sums to each interval.
118
+ #
119
+ # Exclusive scan is realized by allowing look-behind (access to the
120
+ # preceding item) in the final update, by means of a local shift.
121
+ #
122
+ # NOTE: All segment_start_in_X indices are relative to the start
123
+ # of the array.
124
+
125
+ SCAN_INTERVALS_SOURCE = SHARED_PREAMBLE + r"""//CL//
126
+
127
+ #define K ${k_group_size}
128
+
129
+ // #define DEBUG
130
+ #ifdef DEBUG
131
+ #define pycl_printf(ARGS) printf ARGS
132
+ #else
133
+ #define pycl_printf(ARGS) /* */
134
+ #endif
135
+
136
+ KERNEL
137
+ REQD_WG_SIZE(WG_SIZE, 1, 1)
138
+ void ${kernel_name}(
139
+ ${argument_signature},
140
+ GLOBAL_MEM scan_type *restrict partial_scan_buffer,
141
+ const index_type N,
142
+ const index_type interval_size
143
+ %if is_first_level:
144
+ , GLOBAL_MEM scan_type *restrict interval_results
145
+ %endif
146
+ %if is_segmented and is_first_level:
147
+ // NO_SEG_BOUNDARY if no segment boundary in interval.
148
+ , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
149
+ %endif
150
+ %if store_segment_start_flags:
151
+ , GLOBAL_MEM char *restrict g_segment_start_flags
152
+ %endif
153
+ )
154
+ {
155
+ ${arg_offset_adjustment}
156
+
157
+ // index K in first dimension used for carry storage
158
+ %if use_bank_conflict_avoidance:
159
+ // Avoid bank conflicts by adding a single 32-bit value to the size of
160
+ // the scan type.
161
+ struct __attribute__ ((__packed__)) wrapped_scan_type
162
+ {
163
+ scan_type value;
164
+ int dummy;
165
+ };
166
+ %else:
167
+ struct wrapped_scan_type
168
+ {
169
+ scan_type value;
170
+ };
171
+ %endif
172
+ // padded in WG_SIZE to avoid bank conflicts
173
+ LOCAL_MEM struct wrapped_scan_type ldata[K + 1][WG_SIZE + 1];
174
+
175
+ %if is_segmented:
176
+ LOCAL_MEM char l_segment_start_flags[K][WG_SIZE];
177
+ LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE];
178
+
179
+ // only relevant/populated for local id 0
180
+ index_type first_segment_start_in_interval = NO_SEG_BOUNDARY;
181
+
182
+ index_type first_segment_start_in_k_group, first_segment_start_in_subtree;
183
+ %endif
184
+
185
+ // {{{ declare local data for input_fetch_exprs if any of them are stenciled
186
+
187
+ <%
188
+ fetch_expr_offsets = {}
189
+ for name, arg_name, ife_offset in input_fetch_exprs:
190
+ fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)
191
+
192
+ local_fetch_expr_args = set(
193
+ arg_name
194
+ for arg_name, ife_offsets in fetch_expr_offsets.items()
195
+ if -1 in ife_offsets or len(ife_offsets) > 1)
196
+ %>
197
+
198
+ %for arg_name in local_fetch_expr_args:
199
+ LOCAL_MEM ${arg_ctypes[arg_name]} l_${arg_name}[WG_SIZE*K];
200
+ %endfor
201
+
202
+ // }}}
203
+
204
+ const index_type interval_begin = interval_size * GID_0;
205
+ const index_type interval_end = min(interval_begin + interval_size, N);
206
+
207
+ const index_type unit_size = K * WG_SIZE;
208
+
209
+ index_type unit_base = interval_begin;
210
+
211
+ %for is_tail in [False, True]:
212
+
213
+ %if not is_tail:
214
+ for(; unit_base + unit_size <= interval_end; unit_base += unit_size)
215
+ %else:
216
+ if (unit_base < interval_end)
217
+ %endif
218
+
219
+ {
220
+
221
+ // {{{ carry out input_fetch_exprs
222
+ // (if there are ones that need to be fetched into local)
223
+
224
+ %if local_fetch_expr_args:
225
+ for(index_type k = 0; k < K; k++)
226
+ {
227
+ const index_type offset = k*WG_SIZE + LID_0;
228
+ const index_type read_i = unit_base + offset;
229
+
230
+ %for arg_name in local_fetch_expr_args:
231
+ %if is_tail:
232
+ if (read_i < interval_end)
233
+ %endif
234
+ {
235
+ l_${arg_name}[offset] = ${arg_name}[read_i];
236
+ }
237
+ %endfor
238
+ }
239
+
240
+ local_barrier();
241
+ %endif
242
+
243
+ pycl_printf(("after input_fetch_exprs\n"));
244
+
245
+ // }}}
246
+
247
+ // {{{ read a unit's worth of data from global
248
+
249
+ for(index_type k = 0; k < K; k++)
250
+ {
251
+ const index_type offset = k*WG_SIZE + LID_0;
252
+ const index_type read_i = unit_base + offset;
253
+
254
+ %if is_tail:
255
+ if (read_i < interval_end)
256
+ %endif
257
+ {
258
+ %for name, arg_name, ife_offset in input_fetch_exprs:
259
+ ${arg_ctypes[arg_name]} ${name};
260
+
261
+ %if arg_name in local_fetch_expr_args:
262
+ if (offset + ${ife_offset} >= 0)
263
+ ${name} = l_${arg_name}[offset + ${ife_offset}];
264
+ else if (read_i + ${ife_offset} >= 0)
265
+ ${name} = ${arg_name}[read_i + ${ife_offset}];
266
+ /*
267
+ else
268
+ if out of bounds, name is left undefined */
269
+
270
+ %else:
271
+ // ${arg_name} gets fetched directly from global
272
+ ${name} = ${arg_name}[read_i];
273
+
274
+ %endif
275
+ %endfor
276
+
277
+ scan_type scan_value = INPUT_EXPR(read_i);
278
+
279
+ const index_type o_mod_k = offset % K;
280
+ const index_type o_div_k = offset / K;
281
+ ldata[o_mod_k][o_div_k].value = scan_value;
282
+
283
+ %if is_segmented:
284
+ bool is_seg_start = IS_SEG_START(read_i, scan_value);
285
+ l_segment_start_flags[o_mod_k][o_div_k] = is_seg_start;
286
+ %endif
287
+ %if store_segment_start_flags:
288
+ g_segment_start_flags[read_i] = is_seg_start;
289
+ %endif
290
+ }
291
+ }
292
+
293
+ pycl_printf(("after read from global\n"));
294
+
295
+ // }}}
296
+
297
+ // {{{ carry in from previous unit, if applicable
298
+
299
+ %if is_segmented:
300
+ local_barrier();
301
+
302
+ first_segment_start_in_k_group = NO_SEG_BOUNDARY;
303
+ if (l_segment_start_flags[0][LID_0])
304
+ first_segment_start_in_k_group = unit_base + K*LID_0;
305
+ %endif
306
+
307
+ if (LID_0 == 0 && unit_base != interval_begin)
308
+ {
309
+ scan_type tmp = ldata[K][WG_SIZE - 1].value;
310
+ scan_type tmp_aux = ldata[0][0].value;
311
+
312
+ ldata[0][0].value = SCAN_EXPR(
313
+ tmp, tmp_aux,
314
+ %if is_segmented:
315
+ (l_segment_start_flags[0][0])
316
+ %else:
317
+ false
318
+ %endif
319
+ );
320
+ }
321
+
322
+ pycl_printf(("after carry-in\n"));
323
+
324
+ // }}}
325
+
326
+ local_barrier();
327
+
328
+ // {{{ scan along k (sequentially in each work item)
329
+
330
+ scan_type sum = ldata[0][LID_0].value;
331
+
332
+ %if is_tail:
333
+ const index_type offset_end = interval_end - unit_base;
334
+ %endif
335
+
336
+ for (index_type k = 1; k < K; k++)
337
+ {
338
+ %if is_tail:
339
+ if ((index_type) (K * LID_0 + k) < offset_end)
340
+ %endif
341
+ {
342
+ scan_type tmp = ldata[k][LID_0].value;
343
+
344
+ %if is_segmented:
345
+ index_type seq_i = unit_base + K*LID_0 + k;
346
+
347
+ if (l_segment_start_flags[k][LID_0])
348
+ {
349
+ first_segment_start_in_k_group = min(
350
+ first_segment_start_in_k_group,
351
+ seq_i);
352
+ }
353
+ %endif
354
+
355
+ sum = SCAN_EXPR(sum, tmp,
356
+ %if is_segmented:
357
+ (l_segment_start_flags[k][LID_0])
358
+ %else:
359
+ false
360
+ %endif
361
+ );
362
+
363
+ ldata[k][LID_0].value = sum;
364
+ }
365
+ }
366
+
367
+ pycl_printf(("after scan along k\n"));
368
+
369
+ // }}}
370
+
371
+ // store carry in out-of-bounds (padding) array entry (index K) in
372
+ // the K direction
373
+ ldata[K][LID_0].value = sum;
374
+
375
+ %if is_segmented:
376
+ l_first_segment_start_in_subtree[LID_0] =
377
+ first_segment_start_in_k_group;
378
+ %endif
379
+
380
+ local_barrier();
381
+
382
+ // {{{ tree-based local parallel scan
383
+
384
+ // This tree-based scan works as follows:
385
+ // - Each work item adds the previous item to its current state
386
+ // - barrier
387
+ // - Each work item adds in the item from two positions to the left
388
+ // - barrier
389
+ // - Each work item adds in the item from four positions to the left
390
+ // ...
391
+ // At the end, each item has summed all prior items.
392
+
393
+ // across k groups, along local id
394
+ // (uses out-of-bounds k=K array entry for storage)
395
+
396
+ scan_type val = ldata[K][LID_0].value;
397
+
398
+ <% scan_offset = 1 %>
399
+
400
+ % while scan_offset <= wg_size:
401
+ // {{{ reads from local allowed, writes to local not allowed
402
+
403
+ if (LID_0 >= ${scan_offset})
404
+ {
405
+ scan_type tmp = ldata[K][LID_0 - ${scan_offset}].value;
406
+ % if is_tail:
407
+ if (K*LID_0 < offset_end)
408
+ % endif
409
+ {
410
+ val = SCAN_EXPR(tmp, val,
411
+ %if is_segmented:
412
+ (l_first_segment_start_in_subtree[LID_0]
413
+ != NO_SEG_BOUNDARY)
414
+ %else:
415
+ false
416
+ %endif
417
+ );
418
+ }
419
+
420
+ %if is_segmented:
421
+ // Prepare for l_first_segment_start_in_subtree, below.
422
+
423
+ // Note that this update must take place *even* if we're
424
+ // out of bounds.
425
+
426
+ first_segment_start_in_subtree = min(
427
+ l_first_segment_start_in_subtree[LID_0],
428
+ l_first_segment_start_in_subtree
429
+ [LID_0 - ${scan_offset}]);
430
+ %endif
431
+ }
432
+ %if is_segmented:
433
+ else
434
+ {
435
+ first_segment_start_in_subtree =
436
+ l_first_segment_start_in_subtree[LID_0];
437
+ }
438
+ %endif
439
+
440
+ // }}}
441
+
442
+ local_barrier();
443
+
444
+ // {{{ writes to local allowed, reads from local not allowed
445
+
446
+ ldata[K][LID_0].value = val;
447
+ %if is_segmented:
448
+ l_first_segment_start_in_subtree[LID_0] =
449
+ first_segment_start_in_subtree;
450
+ %endif
451
+
452
+ // }}}
453
+
454
+ local_barrier();
455
+
456
+ %if 0:
457
+ if (LID_0 == 0)
458
+ {
459
+ printf("${scan_offset}: ");
460
+ for (int i = 0; i < WG_SIZE; ++i)
461
+ {
462
+ if (l_first_segment_start_in_subtree[i] == NO_SEG_BOUNDARY)
463
+ printf("- ");
464
+ else
465
+ printf("%d ", l_first_segment_start_in_subtree[i]);
466
+ }
467
+ printf("\n");
468
+ }
469
+ %endif
470
+
471
+ <% scan_offset *= 2 %>
472
+ % endwhile
473
+
474
+ pycl_printf(("after tree scan\n"));
475
+
476
+ // }}}
477
+
478
+ // {{{ update local values
479
+
480
+ if (LID_0 > 0)
481
+ {
482
+ sum = ldata[K][LID_0 - 1].value;
483
+
484
+ for(index_type k = 0; k < K; k++)
485
+ {
486
+ %if is_tail:
487
+ if (K * LID_0 + k < offset_end)
488
+ %endif
489
+ {
490
+ scan_type tmp = ldata[k][LID_0].value;
491
+ ldata[k][LID_0].value = SCAN_EXPR(sum, tmp,
492
+ %if is_segmented:
493
+ (unit_base + K * LID_0 + k
494
+ >= first_segment_start_in_k_group)
495
+ %else:
496
+ false
497
+ %endif
498
+ );
499
+ }
500
+ }
501
+ }
502
+
503
+ %if is_segmented:
504
+ if (LID_0 == 0)
505
+ {
506
+ // update interval-wide first-seg variable from current unit
507
+ first_segment_start_in_interval = min(
508
+ first_segment_start_in_interval,
509
+ l_first_segment_start_in_subtree[WG_SIZE-1]);
510
+ }
511
+ %endif
512
+
513
+ pycl_printf(("after local update\n"));
514
+
515
+ // }}}
516
+
517
+ local_barrier();
518
+
519
+ // {{{ write data
520
+
521
+ %if is_gpu:
522
+ {
523
+ // work hard with index math to achieve contiguous 32-bit stores
524
+ __global int *dest =
525
+ (__global int *) (partial_scan_buffer + unit_base);
526
+
527
+ <%
528
+
529
+ assert scan_dtype.itemsize % 4 == 0
530
+
531
+ ints_per_wg = wg_size
532
+ ints_to_store = scan_dtype.itemsize*wg_size*k_group_size // 4
533
+
534
+ %>
535
+
536
+ const index_type scan_types_per_int = ${scan_dtype.itemsize//4};
537
+
538
+ %for store_base in range(0, ints_to_store, ints_per_wg):
539
+ <%
540
+
541
+ # Observe that ints_to_store is divisible by the work group
542
+ # size already, so we won't go out of bounds that way.
543
+ assert store_base + ints_per_wg <= ints_to_store
544
+
545
+ %>
546
+
547
+ %if is_tail:
548
+ if (${store_base} + LID_0 <
549
+ scan_types_per_int*(interval_end - unit_base))
550
+ %endif
551
+ {
552
+ index_type linear_index = ${store_base} + LID_0;
553
+ index_type linear_scan_data_idx =
554
+ linear_index / scan_types_per_int;
555
+ index_type remainder =
556
+ linear_index - linear_scan_data_idx * scan_types_per_int;
557
+
558
+ __local int *src = (__local int *) &(
559
+ ldata
560
+ [linear_scan_data_idx % K]
561
+ [linear_scan_data_idx / K].value);
562
+
563
+ dest[linear_index] = src[remainder];
564
+ }
565
+ %endfor
566
+ }
567
+ %else:
568
+ for (index_type k = 0; k < K; k++)
569
+ {
570
+ const index_type offset = k*WG_SIZE + LID_0;
571
+
572
+ %if is_tail:
573
+ if (unit_base + offset < interval_end)
574
+ %endif
575
+ {
576
+ pycl_printf(("write: %d\n", unit_base + offset));
577
+ partial_scan_buffer[unit_base + offset] =
578
+ ldata[offset % K][offset / K].value;
579
+ }
580
+ }
581
+ %endif
582
+
583
+ pycl_printf(("after write\n"));
584
+
585
+ // }}}
586
+
587
+ local_barrier();
588
+ }
589
+
590
+ % endfor
591
+
592
+ // write interval sum
593
+ %if is_first_level:
594
+ if (LID_0 == 0)
595
+ {
596
+ interval_results[GID_0] = partial_scan_buffer[interval_end - 1];
597
+ %if is_segmented:
598
+ g_first_segment_start_in_interval[GID_0] =
599
+ first_segment_start_in_interval;
600
+ %endif
601
+ }
602
+ %endif
603
+ }
604
+ """
605
+
606
+ # }}}
607
+
608
+ # {{{ update
609
+
610
+ UPDATE_SOURCE = SHARED_PREAMBLE + r"""//CL//
611
+
612
+ KERNEL
613
+ REQD_WG_SIZE(WG_SIZE, 1, 1)
614
+ void ${name_prefix}_final_update(
615
+ ${argument_signature},
616
+ const index_type N,
617
+ const index_type interval_size,
618
+ GLOBAL_MEM scan_type *restrict interval_results,
619
+ GLOBAL_MEM scan_type *restrict partial_scan_buffer
620
+ %if is_segmented:
621
+ , GLOBAL_MEM index_type *restrict g_first_segment_start_in_interval
622
+ %endif
623
+ %if is_segmented and use_lookbehind_update:
624
+ , GLOBAL_MEM char *restrict g_segment_start_flags
625
+ %endif
626
+ )
627
+ {
628
+ ${arg_offset_adjustment}
629
+
630
+ %if use_lookbehind_update:
631
+ LOCAL_MEM scan_type ldata[WG_SIZE];
632
+ %endif
633
+ %if is_segmented and use_lookbehind_update:
634
+ LOCAL_MEM char l_segment_start_flags[WG_SIZE];
635
+ %endif
636
+
637
+ const index_type interval_begin = interval_size * GID_0;
638
+ const index_type interval_end = min(interval_begin + interval_size, N);
639
+
640
+ // carry from last interval
641
+ scan_type carry = ${neutral};
642
+ if (GID_0 != 0)
643
+ carry = interval_results[GID_0 - 1];
644
+
645
+ %if is_segmented:
646
+ const index_type first_seg_start_in_interval =
647
+ g_first_segment_start_in_interval[GID_0];
648
+ %endif
649
+
650
+ %if not is_segmented and 'last_item' in output_statement:
651
+ scan_type last_item = interval_results[GDIM_0-1];
652
+ %endif
653
+
654
+ %if not use_lookbehind_update:
655
+ // {{{ no look-behind ('prev_item' not in output_statement -> simpler)
656
+
657
+ index_type update_i = interval_begin+LID_0;
658
+
659
+ %if is_segmented:
660
+ index_type seg_end = min(first_seg_start_in_interval, interval_end);
661
+ %endif
662
+
663
+ for(; update_i < interval_end; update_i += WG_SIZE)
664
+ {
665
+ scan_type partial_val = partial_scan_buffer[update_i];
666
+ scan_type item = SCAN_EXPR(carry, partial_val,
667
+ %if is_segmented:
668
+ (update_i >= seg_end)
669
+ %else:
670
+ false
671
+ %endif
672
+ );
673
+ index_type i = update_i;
674
+
675
+ { ${output_statement}; }
676
+ }
677
+
678
+ // }}}
679
+ %else:
680
+ // {{{ allow look-behind ('prev_item' in output_statement -> complicated)
681
+
682
+ // We are not allowed to branch across barriers at a granularity smaller
683
+ // than the whole workgroup. Therefore, the for loop is group-global,
684
+ // and there are lots of local ifs.
685
+
686
+ index_type group_base = interval_begin;
687
+ scan_type prev_item = carry; // (A)
688
+
689
+ for(; group_base < interval_end; group_base += WG_SIZE)
690
+ {
691
+ index_type update_i = group_base+LID_0;
692
+
693
+ // load a work group's worth of data
694
+ if (update_i < interval_end)
695
+ {
696
+ scan_type tmp = partial_scan_buffer[update_i];
697
+
698
+ tmp = SCAN_EXPR(carry, tmp,
699
+ %if is_segmented:
700
+ (update_i >= first_seg_start_in_interval)
701
+ %else:
702
+ false
703
+ %endif
704
+ );
705
+
706
+ ldata[LID_0] = tmp;
707
+
708
+ %if is_segmented:
709
+ l_segment_start_flags[LID_0] = g_segment_start_flags[update_i];
710
+ %endif
711
+ }
712
+
713
+ local_barrier();
714
+
715
+ // find prev_item
716
+ if (LID_0 != 0)
717
+ prev_item = ldata[LID_0 - 1];
718
+ /*
719
+ else
720
+ prev_item = carry (see (A)) OR last tail (see (B));
721
+ */
722
+
723
+ if (update_i < interval_end)
724
+ {
725
+ %if is_segmented:
726
+ if (l_segment_start_flags[LID_0])
727
+ prev_item = ${neutral};
728
+ %endif
729
+
730
+ scan_type item = ldata[LID_0];
731
+ index_type i = update_i;
732
+ { ${output_statement}; }
733
+ }
734
+
735
+ if (LID_0 == 0)
736
+ prev_item = ldata[WG_SIZE - 1]; // (B)
737
+
738
+ local_barrier();
739
+ }
740
+
741
+ // }}}
742
+ %endif
743
+ }
744
+ """
745
+
746
+ # }}}
747
+
748
+
749
+ # {{{ driver
750
+
751
+ # {{{ helpers
752
+
753
+ def _round_down_to_power_of_2(val: int) -> int:
754
+ result = 2**bitlog2(val)
755
+ if result > val:
756
+ result >>= 1
757
+
758
+ assert result <= val
759
+ return result
760
+
761
+
762
+ _PREFIX_WORDS = set("""
763
+ ldata partial_scan_buffer global scan_offset
764
+ segment_start_in_k_group carry
765
+ g_first_segment_start_in_interval IS_SEG_START tmp Z
766
+ val l_first_segment_start_in_subtree unit_size
767
+ index_type interval_begin interval_size offset_end K
768
+ SCAN_EXPR do_update WG_SIZE
769
+ first_segment_start_in_k_group scan_type
770
+ segment_start_in_subtree offset interval_results interval_end
771
+ first_segment_start_in_subtree unit_base
772
+ first_segment_start_in_interval k INPUT_EXPR
773
+ prev_group_sum prev pv value partial_val pgs
774
+ is_seg_start update_i scan_item_at_i seq_i read_i
775
+ l_ o_mod_k o_div_k l_segment_start_flags scan_value sum
776
+ first_seg_start_in_interval g_segment_start_flags
777
+ group_base seg_end my_val DEBUG ARGS
778
+ ints_to_store ints_per_wg scan_types_per_int linear_index
779
+ linear_scan_data_idx dest src store_base wrapped_scan_type
780
+ dummy scan_tmp tmp_aux
781
+
782
+ LID_2 LID_1 LID_0
783
+ LDIM_0 LDIM_1 LDIM_2
784
+ GDIM_0 GDIM_1 GDIM_2
785
+ GID_0 GID_1 GID_2
786
+ """.split())
787
+
788
+ _IGNORED_WORDS = set("""
789
+ 4 8 32
790
+
791
+ typedef for endfor if void while endwhile endfor endif else const printf
792
+ None return bool n char true false ifdef pycl_printf str range assert
793
+ np iinfo max itemsize __packed__ struct restrict ptrdiff_t
794
+
795
+ set iteritems len setdefault
796
+
797
+ GLOBAL_MEM LOCAL_MEM_ARG WITHIN_KERNEL LOCAL_MEM KERNEL REQD_WG_SIZE
798
+ local_barrier
799
+ CLK_LOCAL_MEM_FENCE OPENCL EXTENSION
800
+ pragma __attribute__ __global __kernel __local
801
+ get_local_size get_local_id cl_khr_fp64 reqd_work_group_size
802
+ get_num_groups barrier get_group_id
803
+ CL_VERSION_1_1 __OPENCL_C_VERSION__ 120
804
+
805
+ _final_update _debug_scan kernel_name
806
+
807
+ positions all padded integer its previous write based writes 0
808
+ has local worth scan_expr to read cannot not X items False bank
809
+ four beginning follows applicable item min each indices works side
810
+ scanning right summed relative used id out index avoid current state
811
+ boundary True across be This reads groups along Otherwise undetermined
812
+ store of times prior s update first regardless Each number because
813
+ array unit from segment conflicts two parallel 2 empty define direction
814
+ CL padding work tree bounds values and adds
815
+ scan is allowed thus it an as enable at in occur sequentially end no
816
+ storage data 1 largest may representable uses entry Y meaningful
817
+ computations interval At the left dimension know d
818
+ A load B group perform shift tail see last OR
819
+ this add fetched into are directly need
820
+ gets them stenciled that undefined
821
+ there up any ones or name only relevant populated
822
+ even wide we Prepare int seg Note re below place take variable must
823
+ intra Therefore find code assumption
824
+ branch workgroup complicated granularity phase remainder than simpler
825
+ We smaller look ifs lots self behind allow barriers whole loop
826
+ after already Observe achieve contiguous stores hard go with by math
827
+ size won t way divisible bit so Avoid declare adding single type
828
+
829
+ is_tail is_first_level input_expr argument_signature preamble
830
+ double_support neutral output_statement
831
+ k_group_size name_prefix is_segmented index_dtype scan_dtype
832
+ wg_size is_segment_start_expr fetch_expr_offsets
833
+ arg_ctypes ife_offsets input_fetch_exprs def
834
+ ife_offset arg_name local_fetch_expr_args update_body
835
+ update_loop_lookbehind update_loop_plain update_loop
836
+ use_lookbehind_update store_segment_start_flags
837
+ update_loop first_seg scan_dtype dtype_to_ctype
838
+ is_gpu use_bank_conflict_avoidance
839
+
840
+ a b prev_item i last_item prev_value
841
+ N NO_SEG_BOUNDARY across_seg_boundary
842
+
843
+ arg_offset_adjustment
844
+ """.split())
845
+
846
+
847
def _make_template(s: str):
    """Build a Mako template from *s* after prefixing internal identifiers.

    Every identifier in *s* found in ``_PREFIX_WORDS`` is rewritten with a
    ``psc_`` prefix to avoid name clashes with user-supplied code fragments;
    identifiers in ``_IGNORED_WORDS`` pass through unchanged.  Any identifier
    in neither set triggers a warning, since it may clash with user code.
    """
    import re

    unrecognized = set()

    def _prefix_ident(m: "re.Match") -> str:
        ident = m.group(1)
        if ident in _IGNORED_WORDS:
            return ident
        if ident in _PREFIX_WORDS:
            return f"psc_{ident}"
        unrecognized.add(ident)
        return ident

    prefixed = re.sub(r"\b([a-zA-Z0-9_]+)\b", _prefix_ident, s)

    if unrecognized:
        from warnings import warn
        warn("Leftover words in identifier prefixing: "
                + " ".join(unrecognized),
                stacklevel=3)

    return mako.template.Template(prefixed, strict_undefined=True)  # type: ignore
871
+
872
+
873
@dataclass(frozen=True)
class _GeneratedScanKernelInfo:
    """Generated source and launch metadata for a scan kernel, prior to
    compilation.  Compiled on demand via :meth:`build`."""

    # OpenCL source text of the kernel.
    scan_src: str
    # Name of the kernel entry point within the program.
    kernel_name: str
    # Dtypes for scalar arguments (None for non-scalar ones), as accepted
    # by Kernel.set_scalar_arg_dtypes.
    scalar_arg_dtypes: List[Optional[np.dtype]]
    # Work-group size the kernel was generated for.
    wg_size: int
    # Number of elements each work item processes sequentially ("K").
    k_group_size: int

    def build(self, context: cl.Context, options: Any) -> "_BuiltScanKernelInfo":
        """Compile the stored source for *context* and return the built kernel
        bundled with its launch geometry."""
        prg = cl.Program(context, self.scan_src).build(options)
        knl = getattr(prg, self.kernel_name)
        knl.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
        return _BuiltScanKernelInfo(
                kernel=knl,
                wg_size=self.wg_size,
                k_group_size=self.k_group_size)
889
+
890
+
891
@dataclass(frozen=True)
class _BuiltScanKernelInfo:
    """A compiled scan kernel together with the launch geometry it was
    generated for."""

    # The compiled OpenCL kernel.
    kernel: cl.Kernel
    # Work-group size to launch with.
    wg_size: int
    # Per-work-item sequential element count ("K").
    k_group_size: int
896
+
897
+
898
@dataclass(frozen=True)
class _GeneratedFinalUpdateKernelInfo:
    """Generated source and metadata for the final-update kernel of a scan,
    prior to compilation.  Compiled on demand via :meth:`build`."""

    # OpenCL source text of the final-update kernel.
    source: str
    # Name of the kernel entry point within the program.
    kernel_name: str
    # Dtypes for scalar arguments (None for non-scalar ones).
    scalar_arg_dtypes: List[Optional[np.dtype]]
    # Work-group size for the update pass.
    update_wg_size: int

    def build(self,
              context: cl.Context,
              options: Any) -> "_BuiltFinalUpdateKernelInfo":
        """Compile the stored source for *context* and return the built
        final-update kernel with its work-group size."""
        prg = cl.Program(context, self.source).build(options)
        knl = getattr(prg, self.kernel_name)
        knl.set_scalar_arg_dtypes(self.scalar_arg_dtypes)
        return _BuiltFinalUpdateKernelInfo(knl, self.update_wg_size)
912
+
913
+
914
@dataclass(frozen=True)
class _BuiltFinalUpdateKernelInfo:
    """A compiled final-update kernel together with its work-group size."""

    # The compiled OpenCL kernel.
    kernel: cl.Kernel
    # Work-group size to launch the update pass with.
    update_wg_size: int
918
+
919
+ # }}}
920
+
921
+
922
class ScanPerformanceWarning(UserWarning):
    """Warning category used to flag scan configurations expected to
    perform poorly (e.g. unfavorable device characteristics)."""
924
+
925
+
926
+ class GenericScanKernelBase(ABC):
927
+ # {{{ constructor, argument processing
928
+
929
+ def __init__(
930
+ self,
931
+ ctx: cl.Context,
932
+ dtype: Any,
933
+ arguments: Union[str, List[DtypedArgument]],
934
+ input_expr: str,
935
+ scan_expr: str,
936
+ neutral: Optional[str],
937
+ output_statement: str,
938
+ is_segment_start_expr: Optional[str] = None,
939
+ input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
940
+ index_dtype: Any = None,
941
+ name_prefix: str = "scan",
942
+ options: Any = None,
943
+ preamble: str = "",
944
+ devices: Optional[cl.Device] = None) -> None:
945
+ """
946
+ :arg ctx: a :class:`pyopencl.Context` within which the code
947
+ for this scan kernel will be generated.
948
+ :arg dtype: the :class:`numpy.dtype` with which the scan will
949
+ be performed. May be a structured type if that type was registered
950
+ through :func:`pyopencl.tools.get_or_register_dtype`.
951
+ :arg arguments: A string of comma-separated C argument declarations.
952
+ If *arguments* is specified, then *input_expr* must also be
953
+ specified. All types used here must be known to PyOpenCL.
954
+ (see :func:`pyopencl.tools.get_or_register_dtype`).
955
+ :arg scan_expr: The associative, binary operation carrying out the scan,
956
+ represented as a C string. Its two arguments are available as ``a``
957
+ and ``b`` when it is evaluated. ``b`` is guaranteed to be the
958
+ 'element being updated', and ``a`` is the increment. Thus,
959
+ if some data is supposed to just propagate along without being
960
+ modified by the scan, it should live in ``b``.
961
+
962
+ This expression may call functions given in the *preamble*.
963
+
964
+ Another value available to this expression is ``across_seg_boundary``,
965
+ a C `bool` indicating whether this scan update is crossing a
966
+ segment boundary, as defined by ``is_segment_start_expr``.
967
+ The scan routine does not implement segmentation
968
+ semantics on its own. It relies on ``scan_expr`` to do this.
969
+ This value is available (but always ``false``) even for a
970
+ non-segmented scan.
971
+
972
+ .. note::
973
+
974
+ In early pre-releases of the segmented scan,
975
+ segmentation semantics were implemented *without*
976
+ relying on ``scan_expr``.
977
+
978
+ :arg input_expr: A C expression, encoded as a string, resulting
979
+ in the values to which the scan is applied. This may be used
980
+ to apply a mapping to values stored in *arguments* before being
981
+ scanned. The result of this expression must match *dtype*.
982
+ The index intended to be mapped is available as ``i`` in this
983
+ expression. This expression may also use the variables defined
984
+ by *input_fetch_expr*.
985
+
986
+ This expression may also call functions given in the *preamble*.
987
+ :arg output_statement: a C statement that writes
988
+ the output of the scan. It has access to the scan result as ``item``,
989
+ the preceding scan result item as ``prev_item``, and the current index
990
+ as ``i``. ``prev_item`` in a segmented scan will be the neutral element
991
+ at a segment boundary, not the immediately preceding item.
992
+
993
+ Using *prev_item* in output statement has a small run-time cost.
994
+ ``prev_item`` enables the construction of an exclusive scan.
995
+
996
+ For non-segmented scans, *output_statement* may also reference
997
+ ``last_item``, which evaluates to the scan result of the last
998
+ array entry.
999
+ :arg is_segment_start_expr: A C expression, encoded as a string,
1000
+ resulting in a C ``bool`` value that determines whether a new
1001
+ scan segments starts at index *i*. If given, makes the scan a
1002
+ segmented scan. Has access to the current index ``i``, the result
1003
+ of *input_expr* as ``a``, and in addition may use *arguments* and
1004
+ *input_fetch_expr* variables just like *input_expr*.
1005
+
1006
+ If it returns true, then previous sums will not spill over into the
1007
+ item with index *i* or subsequent items.
1008
+ :arg input_fetch_exprs: a list of tuples *(NAME, ARG_NAME, OFFSET)*.
1009
+ An entry here has the effect of doing the equivalent of the following
1010
+ before input_expr::
1011
+
1012
+ ARG_NAME_TYPE NAME = ARG_NAME[i+OFFSET];
1013
+
1014
+ ``OFFSET`` is allowed to be 0 or -1, and ``ARG_NAME_TYPE`` is the type
1015
+ of ``ARG_NAME``.
1016
+ :arg preamble: |preamble|
1017
+
1018
+ The first array in the argument list determines the size of the index
1019
+ space over which the scan is carried out, and thus the values over
1020
+ which the index *i* occurring in a number of code fragments in
1021
+ arguments above will vary.
1022
+
1023
+ All code fragments further have access to N, the number of elements
1024
+ being processed in the scan.
1025
+ """
1026
+
1027
+ if index_dtype is None:
1028
+ index_dtype = np.dtype(np.int32)
1029
+
1030
+ if input_fetch_exprs is None:
1031
+ input_fetch_exprs = []
1032
+
1033
+ self.context = ctx
1034
+ dtype = self.dtype = np.dtype(dtype)
1035
+
1036
+ if neutral is None:
1037
+ from warnings import warn
1038
+ warn("not specifying 'neutral' is deprecated and will lead to "
1039
+ "wrong results if your scan is not in-place or your "
1040
+ "'output_statement' does something otherwise non-trivial",
1041
+ stacklevel=2)
1042
+
1043
+ if dtype.itemsize % 4 != 0:
1044
+ raise TypeError("scan value type must have size divisible by 4 bytes")
1045
+
1046
+ self.index_dtype = np.dtype(index_dtype)
1047
+ if np.iinfo(self.index_dtype).min >= 0:
1048
+ raise TypeError("index_dtype must be signed")
1049
+
1050
+ if devices is None:
1051
+ devices = ctx.devices
1052
+ self.devices = devices
1053
+ self.options = options
1054
+
1055
+ from pyopencl.tools import parse_arg_list
1056
+ self.parsed_args = parse_arg_list(arguments)
1057
+ from pyopencl.tools import VectorArg
1058
+ self.first_array_idx = next(
1059
+ i for i, arg in enumerate(self.parsed_args)
1060
+ if isinstance(arg, VectorArg))
1061
+
1062
+ self.input_expr = input_expr
1063
+
1064
+ self.is_segment_start_expr = is_segment_start_expr
1065
+ self.is_segmented = is_segment_start_expr is not None
1066
+ if self.is_segmented:
1067
+ is_segment_start_expr = _process_code_for_macro(is_segment_start_expr)
1068
+
1069
+ self.output_statement = output_statement
1070
+
1071
+ for _name, _arg_name, ife_offset in input_fetch_exprs:
1072
+ if ife_offset not in [0, -1]:
1073
+ raise RuntimeError("input_fetch_expr offsets must either be 0 or -1")
1074
+ self.input_fetch_exprs = input_fetch_exprs
1075
+
1076
+ arg_dtypes = {}
1077
+ arg_ctypes = {}
1078
+ for arg in self.parsed_args:
1079
+ arg_dtypes[arg.name] = arg.dtype
1080
+ arg_ctypes[arg.name] = dtype_to_ctype(arg.dtype)
1081
+
1082
+ self.name_prefix = name_prefix
1083
+
1084
+ # {{{ set up shared code dict
1085
+
1086
+ from pyopencl.characterize import has_double_support
1087
+
1088
+ self.code_variables = {
1089
+ "np": np,
1090
+ "dtype_to_ctype": dtype_to_ctype,
1091
+ "preamble": preamble,
1092
+ "name_prefix": name_prefix,
1093
+ "index_dtype": self.index_dtype,
1094
+ "scan_dtype": dtype,
1095
+ "is_segmented": self.is_segmented,
1096
+ "arg_dtypes": arg_dtypes,
1097
+ "arg_ctypes": arg_ctypes,
1098
+ "scan_expr": _process_code_for_macro(scan_expr),
1099
+ "neutral": _process_code_for_macro(neutral),
1100
+ "is_gpu": bool(self.devices[0].type & cl.device_type.GPU),
1101
+ "double_support": all(
1102
+ has_double_support(dev) for dev in devices),
1103
+ }
1104
+
1105
+ index_typename = dtype_to_ctype(self.index_dtype)
1106
+ scan_typename = dtype_to_ctype(dtype)
1107
+
1108
+ # This key is meant to uniquely identify the non-device parameters for
1109
+ # the scan kernel.
1110
+ self.kernel_key = (
1111
+ self.dtype,
1112
+ tuple(arg.declarator() for arg in self.parsed_args),
1113
+ self.input_expr,
1114
+ scan_expr,
1115
+ neutral,
1116
+ output_statement,
1117
+ is_segment_start_expr,
1118
+ tuple(input_fetch_exprs),
1119
+ index_dtype,
1120
+ name_prefix,
1121
+ preamble,
1122
+ # These depend on dtype_to_ctype(), so their value is independent of
1123
+ # the other variables.
1124
+ index_typename,
1125
+ scan_typename,
1126
+ )
1127
+
1128
+ # }}}
1129
+
1130
+ self.use_lookbehind_update = "prev_item" in self.output_statement
1131
+ self.store_segment_start_flags = (
1132
+ self.is_segmented and self.use_lookbehind_update)
1133
+
1134
+ self.finish_setup()
1135
+
1136
+ # }}}
1137
+
1138
    @abstractmethod
    def finish_setup(self) -> None:
        """Subclass hook: generate/build the kernels needed by ``__call__``.

        Invoked at the end of ``__init__`` once all shared state (parsed
        arguments, ``code_variables``, ``kernel_key``, ...) has been set up.
        """
        pass
1141
+
1142
+
1143
# On-disk, write-once cache mapping a (kernel_key, devices_key) pair to the
# three generated source packages (first-level scan, second-level scan,
# final update) so that kernel *generation* can be skipped on later runs.
# Only defined when caching is not disabled via cl._PYOPENCL_NO_CACHE.
if not cl._PYOPENCL_NO_CACHE:
    generic_scan_kernel_cache: WriteOncePersistentDict[Any,
            Tuple[_GeneratedScanKernelInfo, _GeneratedScanKernelInfo,
                _GeneratedFinalUpdateKernelInfo]] = \
            WriteOncePersistentDict(
                "pyopencl-generated-scan-kernel-cache-v1",
                key_builder=_NumpyTypesKeyBuilder(),
                in_mem_cache_size=0,
                safe_sync=False)
1152
+
1153
+
1154
class GenericScanKernel(GenericScanKernelBase):
    """Generates and executes code that performs prefix sums ("scans") on
    arbitrary types, with many possible tweaks.

    Usage example::

        from pyopencl.scan import GenericScanKernel
        knl = GenericScanKernel(
                context, np.int32,
                arguments="__global int *ary",
                input_expr="ary[i]",
                scan_expr="a+b", neutral="0",
                output_statement="ary[i+1] = item;")

        a = cl.array.arange(queue, 10000, dtype=np.int32)
        knl(a, queue=queue)

    .. automethod:: __init__
    .. automethod:: __call__
    """

    def finish_setup(self) -> None:
        """Generate (or fetch from the persistent cache) and build the three
        kernels: first-level scan, second-level scan, and final update.
        """
        # Before generating the kernel, see if it's cached.
        from pyopencl.cache import get_device_cache_id
        devices_key = tuple(get_device_cache_id(device)
                for device in self.devices)

        cache_key = (self.kernel_key, devices_key)
        from_cache = False

        if not cl._PYOPENCL_NO_CACHE:
            try:
                result = generic_scan_kernel_cache[cache_key]
                from_cache = True
                logger.debug(
                        "cache hit for generated scan kernel '%s'",
                        self.name_prefix)
                (
                    self.first_level_scan_gen_info,
                    self.second_level_scan_gen_info,
                    self.final_update_gen_info) = result
            except KeyError:
                pass

        if not from_cache:
            logger.debug(
                    "cache miss for generated scan kernel '%s'",
                    self.name_prefix)
            self._finish_setup_impl()

            result = (self.first_level_scan_gen_info,
                    self.second_level_scan_gen_info,
                    self.final_update_gen_info)

            if not cl._PYOPENCL_NO_CACHE:
                generic_scan_kernel_cache.store_if_not_present(
                        cache_key, result)

        # Build the kernels. The *_gen_info source packages are deleted
        # afterwards; only the built *_info objects are kept on self.
        self.first_level_scan_info = self.first_level_scan_gen_info.build(
                self.context, self.options)
        del self.first_level_scan_gen_info

        self.second_level_scan_info = self.second_level_scan_gen_info.build(
                self.context, self.options)
        del self.second_level_scan_gen_info

        self.final_update_info = self.final_update_gen_info.build(
                self.context, self.options)
        del self.final_update_gen_info

    def _finish_setup_impl(self) -> None:
        """Pick workgroup/k-group sizes and generate all three kernel sources."""

        # {{{ find usable workgroup/k-group size, build first-level scan

        trip_count = 0

        avail_local_mem = min(
                dev.local_mem_size
                for dev in self.devices)

        if "CUDA" in self.devices[0].platform.name:
            # not sure where these go, but roughly this much seems unavailable.
            avail_local_mem -= 0x400

        is_cpu = self.devices[0].type & cl.device_type.CPU
        is_gpu = self.devices[0].type & cl.device_type.GPU

        if is_cpu:
            # (about the widest vector a CPU can support, also taking
            # into account that CPUs don't hide latency by large work groups
            max_scan_wg_size = 16
            wg_size_multiples = 4
        else:
            max_scan_wg_size = min(
                    dev.max_work_group_size for dev in self.devices)
            wg_size_multiples = 64

        # Intel beignet fails "Out of shared local memory" in test_scan int64
        # and asserts in test_sort with this enabled:
        # https://github.com/inducer/pyopencl/pull/238
        # A beignet bug report (outside of pyopencl) suggests packed structs
        # (which this is) can even give wrong results:
        # https://bugs.freedesktop.org/show_bug.cgi?id=98717
        # TODO: does this also affect Intel Compute Runtime?
        use_bank_conflict_avoidance = (
                self.dtype.itemsize > 4 and self.dtype.itemsize % 8 == 0
                and is_gpu
                and "beignet" not in self.devices[0].platform.version.lower())

        # k_group_size should be a power of two because of in-kernel
        # division by that number.

        # Enumerate all (wg_size, k_group_size) combinations that fit in
        # local memory; 'total' (= wg_size*k_group_size) is maximized below.
        solutions = []
        for k_exp in range(0, 9):
            for wg_size in range(wg_size_multiples, max_scan_wg_size+1,
                    wg_size_multiples):

                k_group_size = 2**k_exp
                lmem_use = self.get_local_mem_use(wg_size, k_group_size,
                        use_bank_conflict_avoidance)
                if lmem_use <= avail_local_mem:
                    solutions.append(
                            (wg_size*k_group_size, k_group_size, wg_size))

        if is_gpu:
            # Prefer solutions with a reasonably large work group; apply the
            # largest floor for which at least one solution survives.
            for wg_size_floor in [256, 192, 128]:
                have_sol_above_floor = any(wg_size >= wg_size_floor
                        for _, _, wg_size in solutions)

                if have_sol_above_floor:
                    # delete all solutions not meeting the wg size floor
                    solutions = [(total, try_k_group_size, try_wg_size)
                            for total, try_k_group_size, try_wg_size
                            in solutions
                            if try_wg_size >= wg_size_floor]
                    break

        _, k_group_size, max_scan_wg_size = max(solutions)

        while True:
            candidate_scan_gen_info = self.generate_scan_kernel(
                    max_scan_wg_size, self.parsed_args,
                    _process_code_for_macro(self.input_expr),
                    self.is_segment_start_expr,
                    input_fetch_exprs=self.input_fetch_exprs,
                    is_first_level=True,
                    store_segment_start_flags=self.store_segment_start_flags,
                    k_group_size=k_group_size,
                    use_bank_conflict_avoidance=use_bank_conflict_avoidance)

            candidate_scan_info = candidate_scan_gen_info.build(
                    self.context, self.options)

            # Will this device actually let us execute this kernel
            # at the desired work group size? Building it is the
            # only way to find out.
            kernel_max_wg_size = min(
                    candidate_scan_info.kernel.get_work_group_info(
                        cl.kernel_work_group_info.WORK_GROUP_SIZE,
                        dev)
                    for dev in self.devices)

            if candidate_scan_info.wg_size <= kernel_max_wg_size:
                break
            else:
                max_scan_wg_size = min(kernel_max_wg_size, max_scan_wg_size)

            trip_count += 1
            assert trip_count <= 20

        self.first_level_scan_gen_info = candidate_scan_gen_info
        assert (_round_down_to_power_of_2(candidate_scan_info.wg_size)
                == candidate_scan_info.wg_size)

        # }}}

        # {{{ build second-level scan

        from pyopencl.tools import VectorArg
        second_level_arguments = [
            *self.parsed_args,
            VectorArg(self.dtype, "interval_sums"),
            ]

        second_level_build_kwargs: Dict[str, Optional[str]] = {}
        if self.is_segmented:
            second_level_arguments.append(
                    VectorArg(self.index_dtype,
                        "g_first_segment_start_in_interval_input"))

            # is_segment_start_expr answers the question "should previous sums
            # spill over into this item". And since
            # g_first_segment_start_in_interval_input answers the question if a
            # segment boundary was found in an interval of data, then if not,
            # it's ok to spill over.
            second_level_build_kwargs["is_segment_start_expr"] = \
                    "g_first_segment_start_in_interval_input[i]" \
                    " != NO_SEG_BOUNDARY"
        else:
            second_level_build_kwargs["is_segment_start_expr"] = None

        self.second_level_scan_gen_info = self.generate_scan_kernel(
                max_scan_wg_size,
                arguments=second_level_arguments,
                input_expr="interval_sums[i]",
                input_fetch_exprs=[],
                is_first_level=False,
                store_segment_start_flags=False,
                k_group_size=k_group_size,
                use_bank_conflict_avoidance=use_bank_conflict_avoidance,
                **second_level_build_kwargs)

        # }}}

        # {{{ generate final update kernel

        update_wg_size = min(max_scan_wg_size, 256)

        final_update_tpl = _make_template(UPDATE_SOURCE)
        final_update_src = str(final_update_tpl.render(
            wg_size=update_wg_size,
            output_statement=self.output_statement,
            arg_offset_adjustment=get_arg_offset_adjuster_code(
                self.parsed_args),
            argument_signature=", ".join(
                arg.declarator() for arg in self.parsed_args),
            is_segment_start_expr=self.is_segment_start_expr,
            input_expr=_process_code_for_macro(self.input_expr),
            use_lookbehind_update=self.use_lookbehind_update,
            **self.code_variables))

        update_scalar_arg_dtypes = [
            *get_arg_list_scalar_arg_dtypes(self.parsed_args),
            self.index_dtype, self.index_dtype, None, None]

        if self.is_segmented:
            # g_first_segment_start_in_interval
            update_scalar_arg_dtypes.append(None)
        if self.store_segment_start_flags:
            update_scalar_arg_dtypes.append(None)  # g_segment_start_flags

        self.final_update_gen_info = _GeneratedFinalUpdateKernelInfo(
                final_update_src,
                self.name_prefix + "_final_update",
                update_scalar_arg_dtypes,
                update_wg_size)

        # }}}

    # {{{ scan kernel build/properties

    def get_local_mem_use(
            self, k_group_size: int, wg_size: int,
            use_bank_conflict_avoidance: bool) -> int:
        """Estimate (in bytes) the local memory used by the scan kernel for
        the given configuration. Accounts for ``ldata``, the segment-start
        flags, the first-segment-start tree, and double-fetched inputs.

        NOTE(review): the only call site (in ``_finish_setup_impl``) passes
        ``(wg_size, k_group_size)`` positionally, i.e. swapped relative to
        this signature. Most terms below are symmetric in the two values,
        but the ``index_dtype`` term is not -- confirm the intended order.
        """
        arg_dtypes = {}
        for arg in self.parsed_args:
            arg_dtypes[arg.name] = arg.dtype

        fetch_expr_offsets: Dict[str, Set] = {}
        for _name, arg_name, ife_offset in self.input_fetch_exprs:
            fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)

        itemsize = self.dtype.itemsize
        if use_bank_conflict_avoidance:
            itemsize += 4

        return (
                # ldata
                itemsize*(k_group_size+1)*(wg_size+1)

                # l_segment_start_flags
                + k_group_size*wg_size

                # l_first_segment_start_in_subtree
                + self.index_dtype.itemsize*wg_size

                + k_group_size*wg_size*sum(
                    arg_dtypes[arg_name].itemsize
                    for arg_name, ife_offsets
                    in list(fetch_expr_offsets.items())
                    if -1 in ife_offsets or len(ife_offsets) > 1))

    def generate_scan_kernel(
            self,
            max_wg_size: int,
            arguments: List[DtypedArgument],
            input_expr: str,
            is_segment_start_expr: Optional[str],
            input_fetch_exprs: List[Tuple[str, str, int]],
            is_first_level: bool,
            store_segment_start_flags: bool,
            k_group_size: int,
            use_bank_conflict_avoidance: bool) -> _GeneratedScanKernelInfo:
        """Render the interval-scan template into a source package
        (:class:`_GeneratedScanKernelInfo`) for either the first- or
        second-level scan pass. Does not build the kernel.
        """
        scalar_arg_dtypes = get_arg_list_scalar_arg_dtypes(arguments)

        # Empirically found on Nv hardware: no need to be bigger than this size
        wg_size = _round_down_to_power_of_2(
                min(max_wg_size, 256))

        kernel_name = self.code_variables["name_prefix"]
        if is_first_level:
            kernel_name += "_lev1"
        else:
            kernel_name += "_lev2"

        scan_tpl = _make_template(SCAN_INTERVALS_SOURCE)
        scan_src = str(scan_tpl.render(
            wg_size=wg_size,
            input_expr=input_expr,
            k_group_size=k_group_size,
            arg_offset_adjustment=get_arg_offset_adjuster_code(arguments),
            argument_signature=", ".join(
                arg.declarator() for arg in arguments),
            is_segment_start_expr=is_segment_start_expr,
            input_fetch_exprs=input_fetch_exprs,
            is_first_level=is_first_level,
            store_segment_start_flags=store_segment_start_flags,
            use_bank_conflict_avoidance=use_bank_conflict_avoidance,
            kernel_name=kernel_name,
            **self.code_variables))

        # Trailing kernel arguments; 'None' marks non-scalar (buffer) args.
        scalar_arg_dtypes.extend(
                (None, self.index_dtype, self.index_dtype))
        if is_first_level:
            scalar_arg_dtypes.append(None)  # interval_results
        if self.is_segmented and is_first_level:
            scalar_arg_dtypes.append(None)  # g_first_segment_start_in_interval
        if store_segment_start_flags:
            scalar_arg_dtypes.append(None)  # g_segment_start_flags

        return _GeneratedScanKernelInfo(
                scan_src=scan_src,
                kernel_name=kernel_name,
                scalar_arg_dtypes=scalar_arg_dtypes,
                wg_size=wg_size,
                k_group_size=k_group_size)

    # }}}

    def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
        """
        |std-enqueue-blurb|

        .. note::

            The returned :class:`pyopencl.Event` corresponds only to part of the
            execution of the scan. It is not suitable for profiling.

        :arg queue: queue on which to execute the scan. If not given, the
            queue of the first :class:`pyopencl.array.Array` in *args* is used
        :arg allocator: an allocator for the temporary arrays and results. If
            not given, the allocator of the first :class:`pyopencl.array.Array`
            in *args* is used.
        :arg size: specify the length of the scan to be carried out. If not
            given, this length is inferred from the first argument
        :arg wait_for: a :class:`list` of events to wait for.
        """

        # {{{ argument processing

        allocator = kwargs.get("allocator")
        queue = kwargs.get("queue")
        n = kwargs.get("size")
        wait_for = kwargs.get("wait_for")

        if wait_for is None:
            wait_for = []
        else:
            # We'll be appending to it below; don't mutate the caller's list.
            wait_for = list(wait_for)

        if len(args) != len(self.parsed_args):
            raise TypeError(
                f"expected {len(self.parsed_args)} arguments, got {len(args)}")

        first_array = args[self.first_array_idx]
        allocator = allocator or first_array.allocator
        queue = queue or first_array.queue

        if n is None:
            n, = first_array.shape

        if n == 0:
            # We're done here. (But pretend to return an event.)
            return cl.enqueue_marker(queue, wait_for=wait_for)

        data_args = []
        for arg_descr, arg_val in zip(self.parsed_args, args):
            from pyopencl.tools import VectorArg
            if isinstance(arg_descr, VectorArg):
                data_args.append(arg_val.base_data)
                if arg_descr.with_offset:
                    data_args.append(arg_val.offset)
                wait_for.extend(arg_val.events)
            else:
                data_args.append(arg_val)

        # }}}

        l1_info = self.first_level_scan_info
        l2_info = self.second_level_scan_info

        # see CL source above for terminology
        unit_size = l1_info.wg_size * l1_info.k_group_size
        max_intervals = 3*max(dev.max_compute_units for dev in self.devices)

        from pytools import uniform_interval_splitting
        interval_size, num_intervals = uniform_interval_splitting(
                n, unit_size, max_intervals)

        # {{{ allocate some buffers

        interval_results = cl.array.empty(queue,
                num_intervals, dtype=self.dtype,
                allocator=allocator)

        partial_scan_buffer = cl.array.empty(
                queue, n, dtype=self.dtype,
                allocator=allocator)

        if self.store_segment_start_flags:
            segment_start_flags = cl.array.empty(
                    queue, n, dtype=np.bool_,
                    allocator=allocator)

        # }}}

        # {{{ first level scan of interval (one interval per block)

        scan1_args = [
            *data_args,
            partial_scan_buffer.data, n, interval_size, interval_results.data,
            ]

        if self.is_segmented:
            first_segment_start_in_interval = cl.array.empty(queue,
                    num_intervals, dtype=self.index_dtype,
                    allocator=allocator)
            scan1_args.append(first_segment_start_in_interval.data)

        if self.store_segment_start_flags:
            scan1_args.append(segment_start_flags.data)

        l1_evt = l1_info.kernel(
                queue, (num_intervals,), (l1_info.wg_size,),
                *scan1_args, g_times_l=True, wait_for=wait_for)

        # }}}

        # {{{ second level scan of per-interval results

        # can scan at most one interval
        assert interval_size >= num_intervals

        scan2_args = [
            *data_args,
            interval_results.data,  # interval_sums
            ]

        if self.is_segmented:
            scan2_args.append(first_segment_start_in_interval.data)
        # The second level scans the interval results in place.
        scan2_args = [
            *scan2_args,
            interval_results.data,  # partial_scan_buffer
            num_intervals, interval_size]

        l2_evt = l2_info.kernel(
                queue, (1,), (l1_info.wg_size,),
                *scan2_args, g_times_l=True, wait_for=[l1_evt])

        # }}}

        # {{{ update intervals with result of interval scan

        upd_args = [
            *data_args,
            n, interval_size, interval_results.data, partial_scan_buffer.data]
        if self.is_segmented:
            upd_args.append(first_segment_start_in_interval.data)
        if self.store_segment_start_flags:
            upd_args.append(segment_start_flags.data)

        return self.final_update_info.kernel(
                queue, (num_intervals,),
                (self.final_update_info.update_wg_size,),
                *upd_args, g_times_l=True, wait_for=[l2_evt])

        # }}}
1631
+
1632
+ # }}}
1633
+
1634
+
1635
+ # {{{ debug kernel
1636
+
1637
# Mako template for a single-work-item, fully sequential reference scan.
# Rendered per kernel configuration and used by GenericDebugScanKernel below.
# First loop computes the inclusive scan into scan_tmp; second loop runs the
# user's output_statement with item/prev_item/last_item available.
DEBUG_SCAN_TEMPLATE = SHARED_PREAMBLE + r"""//CL//

KERNEL
REQD_WG_SIZE(1, 1, 1)
void ${name_prefix}_debug_scan(
    __global scan_type *scan_tmp,
    ${argument_signature},
    const index_type N)
{
    scan_type current = ${neutral};
    scan_type prev;

    ${arg_offset_adjustment}

    for (index_type i = 0; i < N; ++i)
    {
        %for name, arg_name, ife_offset in input_fetch_exprs:
            ${arg_ctypes[arg_name]} ${name};
            %if ife_offset < 0:
                if (i+${ife_offset} >= 0)
                    ${name} = ${arg_name}[i+${ife_offset}];
            %else:
                ${name} = ${arg_name}[i];
            %endif
        %endfor

        scan_type my_val = INPUT_EXPR(i);

        prev = current;
        %if is_segmented:
            bool is_seg_start = IS_SEG_START(i, my_val);
        %endif

        current = SCAN_EXPR(prev, my_val,
            %if is_segmented:
                is_seg_start
            %else:
                false
            %endif
            );
        scan_tmp[i] = current;
    }

    scan_type last_item = scan_tmp[N-1];

    for (index_type i = 0; i < N; ++i)
    {
        scan_type item = scan_tmp[i];
        scan_type prev_item;
        if (i)
            prev_item = scan_tmp[i-1];
        else
            prev_item = ${neutral};

        {
            ${output_statement};
        }
    }
}
"""
1697
+
1698
+
1699
class GenericDebugScanKernel(GenericScanKernelBase):
    """
    Performs the same function and has the same interface as
    :class:`GenericScanKernel`, but uses a dead-simple, sequential scan. Works
    best on CPU platforms, and helps isolate bugs in scans by removing the
    potential for issues originating in parallel execution.

    .. automethod:: __call__
    """

    def finish_setup(self) -> None:
        """Render, build, and set up the single sequential debug kernel."""
        scan_tpl = _make_template(DEBUG_SCAN_TEMPLATE)
        scan_src = str(scan_tpl.render(
            output_statement=self.output_statement,
            arg_offset_adjustment=get_arg_offset_adjuster_code(
                self.parsed_args),
            argument_signature=", ".join(
                arg.declarator() for arg in self.parsed_args),
            is_segment_start_expr=self.is_segment_start_expr,
            input_expr=_process_code_for_macro(self.input_expr),
            input_fetch_exprs=self.input_fetch_exprs,
            wg_size=1,
            **self.code_variables))

        scan_prg = cl.Program(self.context, scan_src).build(self.options)
        self.kernel = getattr(scan_prg, f"{self.name_prefix}_debug_scan")
        # Leading None: the scan_tmp buffer argument prepended by __call__.
        scalar_arg_dtypes = [
            None,
            *get_arg_list_scalar_arg_dtypes(self.parsed_args),
            self.index_dtype,
            ]
        self.kernel.set_scalar_arg_dtypes(scalar_arg_dtypes)

    def __call__(self, *args: Any, **kwargs: Any) -> cl.Event:
        """See :meth:`GenericScanKernel.__call__`."""

        # {{{ argument processing

        allocator = kwargs.get("allocator")
        queue = kwargs.get("queue")
        n = kwargs.get("size")
        wait_for = kwargs.get("wait_for")

        if wait_for is None:
            wait_for = []
        else:
            # We'll be modifying it below.
            wait_for = list(wait_for)

        if len(args) != len(self.parsed_args):
            raise TypeError(
                f"expected {len(self.parsed_args)} arguments, got {len(args)}")

        first_array = args[self.first_array_idx]
        allocator = allocator or first_array.allocator
        queue = queue or first_array.queue

        if n is None:
            n, = first_array.shape

        if n == 0:
            # Nothing to scan. Mirrors GenericScanKernel.__call__ and avoids
            # the out-of-bounds read of scan_tmp[N-1] the debug kernel would
            # perform for N == 0. (But pretend to return an event.)
            return cl.enqueue_marker(queue, wait_for=wait_for)

        # Temporary buffer holding the inclusive scan result.
        scan_tmp = cl.array.empty(queue,
                n, dtype=self.dtype,
                allocator=allocator)

        data_args = [scan_tmp.data]
        from pyopencl.tools import VectorArg
        for arg_descr, arg_val in zip(self.parsed_args, args):
            if isinstance(arg_descr, VectorArg):
                data_args.append(arg_val.base_data)
                if arg_descr.with_offset:
                    data_args.append(arg_val.offset)
                wait_for.extend(arg_val.events)
            else:
                data_args.append(arg_val)

        # }}}

        # Single work item: the kernel is intentionally sequential.
        return self.kernel(
                queue, (1,), (1,), *([*data_args, n]), wait_for=wait_for)
1776
+
1777
+ # }}}
1778
+
1779
+
1780
+ # {{{ compatibility interface
1781
+
1782
class _LegacyScanKernelBase(GenericScanKernel):
    """Base for the deprecated two-array scan interface
    (:class:`InclusiveScanKernel`/:class:`ExclusiveScanKernel`).

    Subclasses supply :attr:`ary_output_statement` to choose between
    inclusive (``item``) and exclusive (``prev_item``) output.
    """

    def __init__(self, ctx, dtype,
            scan_expr, neutral=None,
            name_prefix="scan", options=None, preamble="", devices=None):
        scan_ctype = dtype_to_ctype(dtype)
        GenericScanKernel.__init__(self,
                ctx, dtype,
                arguments="__global {} *input_ary, __global {} *output_ary".format(
                    scan_ctype, scan_ctype),
                input_expr="input_ary[i]",
                scan_expr=scan_expr,
                neutral=neutral,
                output_statement=self.ary_output_statement,
                options=options, preamble=preamble, devices=devices)

    @property
    def ary_output_statement(self):
        # Overridden as a plain class attribute in subclasses.
        raise NotImplementedError

    def __call__(self, input_ary, output_ary=None, allocator=None, queue=None):
        """Scan *input_ary* into *output_ary* (in place if *output_ary* is
        not given; into a fresh array if it is the string ``"new"``) and
        return the output array.
        """
        allocator = allocator or input_ary.allocator

        # Resolve the None/"new" cases *before* touching output_ary's
        # attributes; the previous ordering dereferenced output_ary.queue
        # while output_ary could still be None or the string "new".
        if output_ary is None:
            output_ary = input_ary

        if isinstance(output_ary, str) and output_ary == "new":
            output_ary = cl.array.empty_like(input_ary, allocator=allocator)

        queue = queue or input_ary.queue or output_ary.queue

        if input_ary.shape != output_ary.shape:
            raise ValueError("input and output must have the same shape")

        if not input_ary.flags.forc:
            raise RuntimeError("ScanKernel cannot "
                    "deal with non-contiguous arrays")

        n, = input_ary.shape

        if not n:
            # Empty input: nothing to enqueue.
            return output_ary

        GenericScanKernel.__call__(self,
                input_ary, output_ary, allocator=allocator, queue=queue)

        return output_ary
1827
+
1828
+
1829
class InclusiveScanKernel(_LegacyScanKernelBase):
    """Legacy-interface inclusive scan: ``output_ary[i]`` receives the scan
    result *including* ``input_ary[i]`` (the ``item`` value)."""
    ary_output_statement = "output_ary[i] = item;"
1831
+
1832
+
1833
class ExclusiveScanKernel(_LegacyScanKernelBase):
    """Legacy-interface exclusive scan: ``output_ary[i]`` receives the scan
    result of everything *before* ``input_ary[i]`` (the ``prev_item``
    value, i.e. the neutral element at index 0)."""
    ary_output_statement = "output_ary[i] = prev_item;"
1835
+
1836
+ # }}}
1837
+
1838
+
1839
+ # {{{ template
1840
+
1841
class ScanTemplate(KernelTemplateBase):
    """Stores the textual ingredients of a scan kernel so that concrete
    :class:`GenericScanKernel` instances can be built later, once concrete
    types and template variable values are known (via :meth:`build_inner`).
    """

    def __init__(
            self,
            arguments: Union[str, List[DtypedArgument]],
            input_expr: str,
            scan_expr: str,
            neutral: Optional[str],
            output_statement: str,
            is_segment_start_expr: Optional[str] = None,
            input_fetch_exprs: Optional[List[Tuple[str, str, int]]] = None,
            name_prefix: str = "scan",
            preamble: str = "",
            template_processor: Any = None) -> None:
        super().__init__(template_processor=template_processor)

        # Avoid a shared mutable default.
        if input_fetch_exprs is None:
            input_fetch_exprs = []

        self.arguments = arguments
        self.input_expr = input_expr
        self.scan_expr = scan_expr
        self.neutral = neutral
        self.output_statement = output_statement
        self.is_segment_start_expr = is_segment_start_expr
        self.input_fetch_exprs = input_fetch_exprs
        self.name_prefix = name_prefix
        self.preamble = preamble

    def build_inner(self, context, type_aliases=(), var_values=(),
            more_preamble="", more_arguments=(), declare_types=(),
            options=None, devices=None, scan_cls=GenericScanKernel):
        """Render all stored code fragments with the given type aliases and
        variable values and construct a *scan_cls* instance from them.

        The scan value type is taken from the ``scan_t`` type alias; the
        index type from ``index_t`` (default :class:`numpy.int32`).
        """
        renderer = self.get_renderer(type_aliases, var_values, context, options)

        # NOTE(review): the argument list is rendered twice -- once here (to
        # feed get_type_decl_preamble) and once in the scan_cls call below.
        # Presumably idempotent; confirm against KernelTemplateBase.
        arg_list = renderer.render_argument_list(self.arguments, more_arguments)

        type_decl_preamble = renderer.get_type_decl_preamble(
                context.devices[0], declare_types, arg_list)

        return scan_cls(context, renderer.type_aliases["scan_t"],
                renderer.render_argument_list(self.arguments, more_arguments),
                renderer(self.input_expr), renderer(self.scan_expr),
                renderer(self.neutral), renderer(self.output_statement),
                is_segment_start_expr=renderer(self.is_segment_start_expr),
                input_fetch_exprs=self.input_fetch_exprs,
                index_dtype=renderer.type_aliases.get("index_t", np.int32),
                name_prefix=renderer(self.name_prefix), options=options,
                preamble=(
                    type_decl_preamble
                    + "\n"
                    + renderer(self.preamble + "\n" + more_preamble)),
                devices=devices)
1892
+
1893
+ # }}}
1894
+
1895
+
1896
+ # {{{ 'canned' scan kernels
1897
+
1898
@context_dependent_memoize
def get_cumsum_kernel(context, input_dtype, output_dtype):
    """Return a per-context-memoized :class:`GenericScanKernel` computing an
    inclusive cumulative sum of ``input`` (of *input_dtype*) into ``output``
    (of *output_dtype*).
    """
    from pyopencl.tools import VectorArg
    return GenericScanKernel(
            context, output_dtype,
            arguments=[
                VectorArg(input_dtype, "input"),
                VectorArg(output_dtype, "output"),
                ],
            input_expr="input[i]",
            scan_expr="a+b", neutral="0",
            output_statement="""
                output[i] = item;
                """)
1912
+
1913
+ # }}}
1914
+
1915
+ # vim: filetype=pyopencl:fdm=marker