pyomp-0.5.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. numba/openmp/__init__.py +106 -0
  2. numba/openmp/_version.py +34 -0
  3. numba/openmp/analysis.py +251 -0
  4. numba/openmp/compiler.py +402 -0
  5. numba/openmp/config.py +27 -0
  6. numba/openmp/decorators.py +27 -0
  7. numba/openmp/exceptions.py +26 -0
  8. numba/openmp/ir_utils.py +4 -0
  9. numba/openmp/libs/openmp/lib/libgomp.1.dylib +0 -0
  10. numba/openmp/libs/openmp/lib/libgomp.dylib +0 -0
  11. numba/openmp/libs/openmp/lib/libiomp5.dylib +0 -0
  12. numba/openmp/libs/openmp/lib/libomp.dylib +0 -0
  13. numba/openmp/libs/openmp/patches/14.0.6/0001-BACKPORT-Fix-for-CUDA-OpenMP-RTL.patch +39 -0
  14. numba/openmp/libs/openmp/patches/14.0.6/0002-Fix-missing-includes.patch +12 -0
  15. numba/openmp/libs/openmp/patches/14.0.6/0003-Link-static-LLVM-libs.patch +13 -0
  16. numba/openmp/libs/openmp/patches/15.0.7/0001-Fix-missing-includes.patch +14 -0
  17. numba/openmp/libs/openmp/patches/15.0.7/0002-Link-LLVM-statically.patch +101 -0
  18. numba/openmp/libs/openmp/patches/15.0.7/0003-Disable-opaque-pointers-DeviceRTL-bitcode.patch +12 -0
  19. numba/openmp/libs/openmp/patches/16.0.6/0001-Load-plugins-from-install-directory.patch +53 -0
  20. numba/openmp/libs/openmp/patches/16.0.6/0002-Link-LLVM-statically.patch +218 -0
  21. numba/openmp/libs/openmp/patches/20.1.8/0001-Enable-standalone-build.patch +13 -0
  22. numba/openmp/libs/openmp/patches/20.1.8/0002-Link-statically-LLVM.patch +24 -0
  23. numba/openmp/libs/openmp/patches/20.1.8/0003-Do-not-build-liboffload.patch +12 -0
  24. numba/openmp/libs/pass/CGIntrinsicsOpenMP.cpp +2939 -0
  25. numba/openmp/libs/pass/CGIntrinsicsOpenMP.h +606 -0
  26. numba/openmp/libs/pass/CMakeLists.txt +57 -0
  27. numba/openmp/libs/pass/DebugOpenMP.cpp +17 -0
  28. numba/openmp/libs/pass/DebugOpenMP.h +28 -0
  29. numba/openmp/libs/pass/IntrinsicsOpenMP.cpp +837 -0
  30. numba/openmp/libs/pass/IntrinsicsOpenMP.h +13 -0
  31. numba/openmp/libs/pass/IntrinsicsOpenMP_CAPI.h +23 -0
  32. numba/openmp/libs/pass/libIntrinsicsOpenMP.dylib +0 -0
  33. numba/openmp/link_utils.py +126 -0
  34. numba/openmp/llvm_pass.py +48 -0
  35. numba/openmp/llvmlite_extensions.py +75 -0
  36. numba/openmp/omp_context.py +242 -0
  37. numba/openmp/omp_grammar.py +696 -0
  38. numba/openmp/omp_ir.py +2105 -0
  39. numba/openmp/omp_lower.py +3125 -0
  40. numba/openmp/omp_runtime.py +107 -0
  41. numba/openmp/overloads.py +53 -0
  42. numba/openmp/parser.py +6 -0
  43. numba/openmp/tags.py +532 -0
  44. numba/openmp/tests/test_openmp.py +5056 -0
  45. pyomp-0.5.0.dist-info/METADATA +193 -0
  46. pyomp-0.5.0.dist-info/RECORD +52 -0
  47. pyomp-0.5.0.dist-info/WHEEL +6 -0
  48. pyomp-0.5.0.dist-info/licenses/LICENSE +25 -0
  49. pyomp-0.5.0.dist-info/licenses/LICENSE-OPENMP.txt +361 -0
  50. pyomp-0.5.0.dist-info/top_level.txt +3 -0
  51. pyomp.dylibs/libc++.1.0.dylib +0 -0
  52. pyomp.dylibs/libzstd.1.5.7.dylib +0 -0
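
Only one file's contents are reproduced in the hunk below; its +3125-line count matches item 39, numba/openmp/omp_lower.py, the module that parses the directive strings of `with openmp(...)` blocks and rewrites Numba IR into OpenMP regions. For orientation, here is a minimal usage sketch of the package. It assumes the standard Numba `njit` decorator and an `openmp_context` manager exposed as in upstream PyOMP, including support for `reduction` clauses; this diff listing does not itself confirm those entry points.

    from numba import njit
    from numba.openmp import openmp_context as openmp

    @njit
    def sum_squares(n):
        total = 0.0
        # The directive string is parsed via omp_grammar.py / parser.py and
        # lowered by omp_lower.py into a parallel-for region with a reduction
        # over `total` (assumed PyOMP-style API, not confirmed by this diff).
        with openmp("parallel for reduction(+:total)"):
            for i in range(n):
                total += float(i * i)
        return total

    print(sum_squares(1_000))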
@@ -0,0 +1,3125 @@
1
+ from lark import Transformer
2
+ from lark.exceptions import VisitError
3
+
4
+ from numba.core import ir, types, typed_passes
5
+ from numba.core.analysis import (
6
+ compute_cfg_from_blocks,
7
+ compute_use_defs,
8
+ compute_live_map,
9
+ )
10
+ from numba.core.ir_utils import (
11
+ dump_blocks,
12
+ get_call_table,
13
+ visit_vars,
14
+ build_definitions,
15
+ )
16
+ import copy
17
+ import operator
18
+ import sys
19
+ import os
20
+
21
+ from .config import DEBUG_OPENMP
22
+ from .parser import openmp_parser
23
+ from .analysis import (
24
+ remove_ssa,
25
+ user_defined_var,
26
+ is_dsa,
27
+ is_private,
28
+ is_internal_var,
29
+ has_user_defined_var,
30
+ get_user_defined_var,
31
+ get_enclosing_region,
32
+ add_enclosing_region,
33
+ filter_nested_loops,
34
+ remove_privatized,
35
+ get_var_from_enclosing,
36
+ remove_indirections,
37
+ add_tags_to_enclosing,
38
+ get_blocks_between_start_end,
39
+ get_itercount,
40
+ )
41
+ from .exceptions import (
42
+ UnspecifiedVarInDefaultNone,
43
+ ParallelForExtraCode,
44
+ ParallelForWrongLoopCount,
45
+ ParallelForInvalidCollapseCount,
46
+ MultipleNumThreadsClauses,
47
+ NonconstantOpenmpSpecification,
48
+ NonStringOpenmpSpecification,
49
+ )
50
+ from .ir_utils import dump_block
51
+ from .tags import openmp_tag, NameSlice, get_tags_of_type
52
+ from .omp_ir import (
53
+ openmp_region_start,
54
+ openmp_region_end,
55
+ default_shared_val,
56
+ )
57
+
58
+
59
+ class OpenmpVisitor(Transformer):
60
+ target_num = 0
61
+
62
+ def __init__(self, func_ir, blocks, blk_start, blk_end, body_blocks, loc):
63
+ self.func_ir = func_ir
64
+ self.blocks = blocks
65
+ self.blk_start = blk_start
66
+ self.blk_end = blk_end
67
+ self.body_blocks = body_blocks
68
+ self.loc = loc
69
+ super(OpenmpVisitor, self).__init__()
70
+
71
+ # --------- Non-parser functions --------------------
72
+
73
+ def remove_explicit_from_one(
74
+ self, varset, vars_in_explicit_clauses, clauses, scope, loc
75
+ ):
76
+ """Go through a set of variables and see if their non-SSA form is in an explicitly
77
+ provided data clause. If so, remove it from the set and add a clause so that the
78
+ SSA form gets the same data clause.
79
+ """
80
+ if DEBUG_OPENMP >= 1:
81
+ print(
82
+ "remove_explicit start:",
83
+ sorted(varset),
84
+ sorted(vars_in_explicit_clauses),
85
+ )
86
+ diff = set()
87
+ # For each variable in the set.
88
+ for v in sorted(varset):
89
+ # Get the non-SSA form.
90
+ flat = remove_ssa(v, scope, loc)
91
+ # Skip non-SSA introduced variables (i.e., Python vars).
92
+ if flat == v:
93
+ continue
94
+ if DEBUG_OPENMP >= 1:
95
+ print("remove_explicit:", v, flat, flat in vars_in_explicit_clauses)
96
+ # If we have the non-SSA form in an explicit data clause.
97
+ if flat in vars_in_explicit_clauses:
98
+ # We will remove it from the set.
99
+ diff.add(v)
100
+ # Copy the non-SSA variables data clause.
101
+ ccopy = copy.copy(vars_in_explicit_clauses[flat])
102
+ # Change the name in the clause to the SSA form.
103
+ ccopy.arg = ir.Var(scope, v, loc)
104
+ # Add to the clause set.
105
+ clauses.append(ccopy)
106
+ # Remove the vars from the set that we added a clause for.
107
+ varset.difference_update(diff)
108
+ if DEBUG_OPENMP >= 1:
109
+ print("remove_explicit end:", sorted(varset))
110
+
111
+ def remove_explicit_from_io_vars(
112
+ self,
113
+ inputs_to_region,
114
+ def_but_live_out,
115
+ private_to_region,
116
+ vars_in_explicit_clauses,
117
+ clauses,
118
+ non_user_explicits,
119
+ scope,
120
+ loc,
121
+ ):
122
+ """Remove vars in explicit data clauses from the auto-determined vars.
123
+ Then call remove_explicit_from_one to take SSA variants out of the auto-determined sets
124
+ and to create clauses so that SSA versions get the same clause as the explicit Python non-SSA var.
125
+ """
126
+ inputs_to_region.difference_update(vars_in_explicit_clauses.keys())
127
+ def_but_live_out.difference_update(vars_in_explicit_clauses.keys())
128
+ private_to_region.difference_update(vars_in_explicit_clauses.keys())
129
+ inputs_to_region.difference_update(non_user_explicits.keys())
130
+ def_but_live_out.difference_update(non_user_explicits.keys())
131
+ private_to_region.difference_update(non_user_explicits.keys())
132
+ self.remove_explicit_from_one(
133
+ inputs_to_region, vars_in_explicit_clauses, clauses, scope, loc
134
+ )
135
+ self.remove_explicit_from_one(
136
+ def_but_live_out, vars_in_explicit_clauses, clauses, scope, loc
137
+ )
138
+ self.remove_explicit_from_one(
139
+ private_to_region, vars_in_explicit_clauses, clauses, scope, loc
140
+ )
141
+
142
+ def find_io_vars(self, selected_blocks):
143
+ cfg = compute_cfg_from_blocks(self.blocks)
144
+ usedefs = compute_use_defs(self.blocks)
145
+ if DEBUG_OPENMP >= 1:
146
+ print("usedefs:", usedefs)
147
+ live_map = compute_live_map(cfg, self.blocks, usedefs.usemap, usedefs.defmap)
148
+ # Assumes enter_with is first statement in block.
149
+ inputs_to_region = live_map[self.blk_start]
150
+ if DEBUG_OPENMP >= 1:
151
+ print("live_map:", live_map)
152
+ print("inputs_to_region:", sorted(inputs_to_region), type(inputs_to_region))
153
+ print("selected blocks:", sorted(selected_blocks))
154
+ all_uses = set()
155
+ all_defs = set()
156
+ for label in selected_blocks:
157
+ all_uses = all_uses.union(usedefs.usemap[label])
158
+ all_defs = all_defs.union(usedefs.defmap[label])
159
+ # Filter out those vars live to the region but not used within it.
160
+ inputs_to_region = inputs_to_region.intersection(all_uses)
161
+ def_but_live_out = all_defs.difference(inputs_to_region).intersection(
162
+ live_map[self.blk_end]
163
+ )
164
+ private_to_region = all_defs.difference(inputs_to_region).difference(
165
+ live_map[self.blk_end]
166
+ )
167
+
168
+ if DEBUG_OPENMP >= 1:
169
+ print("all_uses:", sorted(all_uses))
170
+ print("inputs_to_region:", sorted(inputs_to_region))
171
+ print("private_to_region:", sorted(private_to_region))
172
+ print("def_but_live_out:", sorted(def_but_live_out))
173
+ return inputs_to_region, def_but_live_out, private_to_region, live_map
174
+
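# --- Editorial illustration, not part of omp_lower.py: find_io_vars above
# classifies variables by liveness relative to the OpenMP region. A
# hypothetical example of the three sets it returns:
#
#     n = 10                           # live into the region and read inside it
#     with openmp("parallel for"):     # region start (blk_start)
#         for i in range(n):
#             t = i * 2                # defined inside, dead afterwards
#             acc = t                  # defined inside, still live afterwards
#     print(acc)                       # region end (blk_end)
#
# Here n lands in inputs_to_region, t in private_to_region, and acc in
# def_but_live_out; these auto-determined sets are later reconciled with the
# explicit data clauses (see remove_explicit_from_io_vars above).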
175
+ def get_explicit_vars(self, clauses):
176
+ user_vars = {}
177
+ non_user_vars = {}
178
+ privates = []
179
+ for c in clauses:
180
+ if DEBUG_OPENMP >= 1:
181
+ print("get_explicit_vars:", c, type(c))
182
+ if isinstance(c, openmp_tag):
183
+ if DEBUG_OPENMP >= 1:
184
+ print("arg:", c.arg, type(c.arg))
185
+ if isinstance(c.arg, list):
186
+ carglist = c.arg
187
+ else:
188
+ carglist = [c.arg]
189
+ # carglist = c.arg if isinstance(c.arg, list) else [c.arg]
190
+ for carg in carglist:
191
+ if DEBUG_OPENMP >= 1:
192
+ print(
193
+ "carg:",
194
+ carg,
195
+ type(carg),
196
+ user_defined_var(carg),
197
+ is_dsa(c.name),
198
+ )
199
+ # Extract the var name from the NameSlice.
200
+ if isinstance(carg, NameSlice):
201
+ carg = carg.name
202
+ if isinstance(carg, str) and is_dsa(c.name):
203
+ if user_defined_var(carg):
204
+ user_vars[carg] = c
205
+ if is_private(c.name):
206
+ privates.append(carg)
207
+ else:
208
+ non_user_vars[carg] = c
209
+ return user_vars, privates, non_user_vars
210
+
211
+ def filter_unused_vars(self, clauses, used_vars):
212
+ new_clauses = []
213
+ for c in clauses:
214
+ if DEBUG_OPENMP >= 1:
215
+ print("filter_unused_vars:", c, type(c))
216
+ if isinstance(c, openmp_tag):
217
+ if DEBUG_OPENMP >= 1:
218
+ print("arg:", c.arg, type(c.arg))
219
+ assert not isinstance(c.arg, list)
220
+ if DEBUG_OPENMP >= 1:
221
+ print(
222
+ "c.arg:",
223
+ c.arg,
224
+ type(c.arg),
225
+ user_defined_var(c.arg),
226
+ is_dsa(c.name),
227
+ )
228
+
229
+ if (
230
+ isinstance(c.arg, str)
231
+ and user_defined_var(c.arg)
232
+ and is_dsa(c.name)
233
+ ):
234
+ if c.arg in used_vars:
235
+ new_clauses.append(c)
236
+ else:
237
+ new_clauses.append(c)
238
+ return new_clauses
239
+
240
+ def get_clause_privates(self, clauses, def_but_live_out, scope, loc):
241
+ # Get all the private clauses from the whole set of clauses.
242
+ private_clauses_vars = [
243
+ remove_privatized(x.arg)
244
+ for x in clauses
245
+ if x.name in ["QUAL.OMP.PRIVATE", "QUAL.OMP.FIRSTPRIVATE"]
246
+ ]
247
+ # private_clauses_vars = [remove_privatized(x.arg) for x in clauses if x.name in ["QUAL.OMP.PRIVATE", "QUAL.OMP.FIRSTPRIVATE", "QUAL.OMP.LASTPRIVATE"]]
248
+ ret = {}
249
+ # Get a mapping of vars in private clauses to the SSA version of variable exiting the region.
250
+ for lo in def_but_live_out:
251
+ without_ssa = remove_ssa(lo, scope, loc)
252
+ if without_ssa in private_clauses_vars:
253
+ ret[without_ssa] = lo
254
+ return ret
255
+
256
+ def make_implicit_explicit(
257
+ self,
258
+ scope,
259
+ vars_in_explicit,
260
+ explicit_clauses,
261
+ gen_shared,
262
+ inputs_to_region,
263
+ def_but_live_out,
264
+ private_to_region,
265
+ for_task=False,
266
+ ):
267
+ if for_task is None:
268
+ for_task = []
269
+ if gen_shared:
270
+ for var_name in sorted(inputs_to_region):
271
+ if (
272
+ for_task != False
273
+ and get_var_from_enclosing(for_task, var_name) != "QUAL.OMP.SHARED"
274
+ ):
275
+ explicit_clauses.append(
276
+ openmp_tag("QUAL.OMP.FIRSTPRIVATE", var_name)
277
+ )
278
+ else:
279
+ explicit_clauses.append(openmp_tag("QUAL.OMP.SHARED", var_name))
280
+ vars_in_explicit[var_name] = explicit_clauses[-1]
281
+
282
+ for var_name in sorted(def_but_live_out):
283
+ if (
284
+ for_task != False
285
+ and get_var_from_enclosing(for_task, var_name) != "QUAL.OMP.SHARED"
286
+ ):
287
+ explicit_clauses.append(
288
+ openmp_tag("QUAL.OMP.FIRSTPRIVATE", var_name)
289
+ )
290
+ else:
291
+ explicit_clauses.append(openmp_tag("QUAL.OMP.SHARED", var_name))
292
+ vars_in_explicit[var_name] = explicit_clauses[-1]
293
+
294
+ # What to do below for task regions?
295
+ for var_name in sorted(private_to_region):
296
+ temp_var = ir.Var(scope, var_name, self.loc)
297
+ if not is_internal_var(temp_var):
298
+ explicit_clauses.append(openmp_tag("QUAL.OMP.PRIVATE", var_name))
299
+ vars_in_explicit[var_name] = explicit_clauses[-1]
300
+
301
+ for var_name in sorted(private_to_region):
302
+ temp_var = ir.Var(scope, var_name, self.loc)
303
+ if is_internal_var(temp_var):
304
+ explicit_clauses.append(openmp_tag("QUAL.OMP.PRIVATE", var_name))
305
+ vars_in_explicit[var_name] = explicit_clauses[-1]
306
+
307
+ def make_implicit_explicit_target(
308
+ self,
309
+ scope,
310
+ vars_in_explicit,
311
+ explicit_clauses,
312
+ gen_shared,
313
+ inputs_to_region,
314
+ def_but_live_out,
315
+ private_to_region,
316
+ ):
317
+ # unversioned_privates = set() # we get rid of SSA on the first openmp region so no SSA forms should be here
318
+ if gen_shared:
319
+ for var_name in sorted(inputs_to_region):
320
+ explicit_clauses.append(
321
+ openmp_tag(
322
+ "QUAL.OMP.TARGET.IMPLICIT"
323
+ if user_defined_var(var_name)
324
+ else "QUAL.OMP.PRIVATE",
325
+ var_name,
326
+ )
327
+ )
328
+ vars_in_explicit[var_name] = explicit_clauses[-1]
329
+ for var_name in sorted(def_but_live_out):
330
+ explicit_clauses.append(
331
+ openmp_tag(
332
+ "QUAL.OMP.TARGET.IMPLICIT"
333
+ if user_defined_var(var_name)
334
+ else "QUAL.OMP.PRIVATE",
335
+ var_name,
336
+ )
337
+ )
338
+ vars_in_explicit[var_name] = explicit_clauses[-1]
339
+ for var_name in sorted(private_to_region):
340
+ temp_var = ir.Var(scope, var_name, self.loc)
341
+ if not is_internal_var(temp_var):
342
+ explicit_clauses.append(openmp_tag("QUAL.OMP.PRIVATE", var_name))
343
+ # explicit_clauses.append(openmp_tag("QUAL.OMP.TARGET.IMPLICIT" if user_defined_var(var_name) else "QUAL.OMP.PRIVATE", var_name))
344
+ vars_in_explicit[var_name] = explicit_clauses[-1]
345
+
346
+ for var_name in sorted(private_to_region):
347
+ temp_var = ir.Var(scope, var_name, self.loc)
348
+ if is_internal_var(temp_var):
349
+ explicit_clauses.append(
350
+ openmp_tag(
351
+ "QUAL.OMP.TARGET.IMPLICIT"
352
+ if user_defined_var(var_name)
353
+ else "QUAL.OMP.PRIVATE",
354
+ var_name,
355
+ )
356
+ )
357
+ vars_in_explicit[var_name] = explicit_clauses[-1]
358
+
359
+ def add_explicits_to_start(
360
+ self,
361
+ scope,
362
+ vars_in_explicit,
363
+ explicit_clauses,
364
+ gen_shared,
365
+ start_tags,
366
+ keep_alive,
367
+ ):
368
+ start_tags.extend(explicit_clauses)
369
+ return []
370
+ # tags_for_enclosing = []
371
+ # for var in vars_in_explicit:
372
+ # if not is_private(vars_in_explicit[var].name):
373
+ # print("EVAR_COPY FOR", var)
374
+ # evar = ir.Var(scope, var, self.loc)
375
+ # evar_copy = scope.redefine("evar_copy_aets", self.loc)
376
+ # keep_alive.append(ir.Assign(evar, evar_copy, self.loc))
377
+ # #keep_alive.append(ir.Assign(evar, evar, self.loc))
378
+ # tags_for_enclosing.append(openmp_tag("QUAL.OMP.PRIVATE", evar_copy))
379
+ # return tags_for_enclosing
380
+
381
+ def flatten(self, all_clauses, start_block):
382
+ if DEBUG_OPENMP >= 1:
383
+ print("flatten", id(start_block))
384
+ incoming_clauses = [remove_indirections(x) for x in all_clauses]
385
+ clauses = []
386
+ default_shared = True
387
+ for clause in incoming_clauses:
388
+ if DEBUG_OPENMP >= 1:
389
+ print("clause:", clause, type(clause))
390
+ if isinstance(clause, openmp_tag):
391
+ clauses.append(clause)
392
+ elif isinstance(clause, list):
393
+ clauses.extend(remove_indirections(clause))
394
+ elif clause == "nowait":
395
+ clauses.append(openmp_tag("QUAL.OMP.NOWAIT"))
396
+ elif isinstance(clause, default_shared_val):
397
+ default_shared = clause.val
398
+ if DEBUG_OPENMP >= 1:
399
+ print("got new default_shared:", clause.val)
400
+ else:
401
+ if DEBUG_OPENMP >= 1:
402
+ print(
403
+ "Unknown clause type in incoming_clauses", clause, type(clause)
404
+ )
405
+ assert 0
406
+
407
+ if hasattr(start_block, "openmp_replace_vardict"):
408
+ for clause in clauses:
409
+ # print("flatten out clause:", clause, clause.arg, type(clause.arg))
410
+ for vardict in start_block.openmp_replace_vardict:
411
+ if clause.arg in vardict:
412
+ # print("clause.arg in vardict:", clause.arg, type(clause.arg), vardict[clause.arg], type(vardict[clause.arg]))
413
+ clause.arg = vardict[clause.arg].name
414
+
415
+ return clauses, default_shared
416
+
417
+ def add_replacement(self, blocks, replace_vardict):
418
+ for b in blocks.values():
419
+ if not hasattr(b, "openmp_replace_vardict"):
420
+ b.openmp_replace_vardict = []
421
+ b.openmp_replace_vardict.append(replace_vardict)
422
+
423
+ def make_consts_unliteral_for_privates(self, privates, blocks):
424
+ for blk in blocks.values():
425
+ for stmt in blk.body:
426
+ if (
427
+ isinstance(stmt, ir.Assign)
428
+ and isinstance(stmt.value, ir.Const)
429
+ and stmt.target.name in privates
430
+ ):
431
+ stmt.value.use_literal_type = False
432
+
433
+ def fix_empty_header(self, block, label):
434
+ if len(block.body) == 1:
435
+ assert isinstance(block.body[0], ir.Jump)
436
+ return self.blocks[block.body[0].target], block.body[0].target
437
+ return block, label
438
+
439
+ def prepare_for_directive(
440
+ self,
441
+ clauses,
442
+ vars_in_explicit_clauses,
443
+ before_start,
444
+ after_start,
445
+ start_tags,
446
+ end_tags,
447
+ scope,
448
+ ):
449
+ start_tags = clauses
450
+ call_table, _ = get_call_table(self.blocks)
451
+ cfg = compute_cfg_from_blocks(self.blocks)
452
+ usedefs = compute_use_defs(self.blocks)
453
+ live_map = compute_live_map(cfg, self.blocks, usedefs.usemap, usedefs.defmap)
454
+
455
+ def get_loops_in_region(all_loops):
456
+ loops = {}
457
+ for k, v in all_loops.items():
458
+ if v.header >= self.blk_start and v.header <= self.blk_end:
459
+ loops[k] = v
460
+ return loops
461
+
462
+ all_loops = cfg.loops()
463
+ if DEBUG_OPENMP >= 1:
464
+ print("all_loops:", all_loops)
465
+ print("live_map:", live_map)
466
+ print("body_blocks:", self.body_blocks)
467
+
468
+ loops = get_loops_in_region(all_loops)
469
+ # Find the outer-most loop in this OpenMP region.
470
+ loops = list(filter_nested_loops(cfg, loops))
471
+
472
+ if DEBUG_OPENMP >= 1:
473
+ print("loops:", loops)
474
+ if len(loops) != 1:
475
+ raise ParallelForWrongLoopCount(
476
+ f"OpenMP parallel for regions must contain exactly one range based loop. The parallel for at line {self.loc} contains {len(loops)} loops."
477
+ )
478
+
479
+ collapse_tags = get_tags_of_type(clauses, "QUAL.OMP.COLLAPSE")
480
+ new_stmts_for_iterspace = []
481
+ collapse_iterspace_block = set()
482
+ iterspace_vars = []
483
+ if len(collapse_tags) > 0:
484
+ # Limit all_loops to just loops within the openmp region.
485
+ all_loops = get_loops_in_region(all_loops)
486
+ # In case of multiple collapse tags, use the last one.
487
+ collapse_tag = collapse_tags[-1]
488
+ # Remove collapse tags from clauses so they don't go to LLVM pass.
489
+ clauses[:] = [x for x in clauses if x not in collapse_tags]
490
+ # Add top level loop to loop_order list.
491
+ loop_order = list(filter_nested_loops(cfg, all_loops))
492
+ if len(loop_order) != 1:
493
+ raise ParallelForWrongLoopCount(
494
+ f"OpenMP parallel for region must have only one top-level loop at line {self.loc}."
495
+ )
496
+ # Determine how many nested loops we need to process.
497
+ collapse_value = collapse_tag.arg - 1
498
+ # Make sure initial collapse value was >= 2.
499
+ if collapse_value <= 0:
500
+ raise ParallelForInvalidCollapseCount(
501
+ f"OpenMP parallel for regions with collapse clauses must be greather than or equal to 2 at line {self.loc}."
502
+ )
503
+
504
+ # Delete top-level loop from all_loops.
505
+ del all_loops[loop_order[-1].header]
506
+ # For remaining nested loops...
507
+ for _ in range(collapse_value):
508
+ # Get the next most top-level loop.
509
+ loops = list(filter_nested_loops(cfg, all_loops))
510
+ # Make sure there is only one.
511
+ if len(loops) != 1:
512
+ raise ParallelForWrongLoopCount(
513
+ f"OpenMP parallel for collapse regions must be perfectly nested for the parallel for at line {self.loc}."
514
+ )
515
+ # Add this loop to the loops to process in order.
516
+ loop_order.append(loops[0])
517
+ # Delete this loop from all_loops.
518
+ del all_loops[loop_order[-1].header]
519
+
520
+ if DEBUG_OPENMP >= 2:
521
+ print("loop_order:", loop_order)
522
+ stmts_to_retain = []
523
+ loop_bounds = []
524
+ for loop in loop_order:
525
+ loop_entry = list(loop.entries)[0]
526
+ loop_exit = list(loop.exits)[0]
527
+ loop_header = loop.header
528
+ loop_entry_block = self.blocks[loop_entry]
529
+ loop_header_block, _ = self.fix_empty_header(
530
+ self.blocks[loop_header], loop_header
531
+ )
532
+
533
+ # Copy all stmts from the loop entry block up to the ir.Global
534
+ # for range.
535
+ call_offset = None
536
+ for entry_block_index, stmt in enumerate(loop_entry_block.body):
537
+ found_range = False
538
+ if (
539
+ isinstance(stmt, ir.Assign)
540
+ and isinstance(stmt.value, ir.Global)
541
+ and stmt.value.name == "range"
542
+ ):
543
+ found_range = True
544
+ range_target = stmt.target
545
+ found_call = False
546
+ for call_index in range(
547
+ entry_block_index + 1, len(loop_entry_block.body)
548
+ ):
549
+ call_stmt = loop_entry_block.body[call_index]
550
+ if (
551
+ isinstance(call_stmt, ir.Assign)
552
+ and isinstance(call_stmt.value, ir.Expr)
553
+ and call_stmt.value.op == "call"
554
+ and call_stmt.value.func == range_target
555
+ ):
556
+ found_call = True
557
+ # Remove stmts that were retained.
558
+ loop_entry_block.body = loop_entry_block.body[
559
+ entry_block_index:
560
+ ]
561
+ call_offset = call_index - entry_block_index
562
+ break
563
+ assert found_call
564
+ break
565
+ stmts_to_retain.append(stmt)
566
+ assert found_range
567
+ for header_block_index, stmt in enumerate(loop_header_block.body):
568
+ if (
569
+ isinstance(stmt, ir.Assign)
570
+ and isinstance(stmt.value, ir.Expr)
571
+ and stmt.value.op == "iternext"
572
+ ):
573
+ iternext_inst = loop_header_block.body[header_block_index]
574
+ pair_first_inst = loop_header_block.body[header_block_index + 1]
575
+ pair_second_inst = loop_header_block.body[
576
+ header_block_index + 2
577
+ ]
578
+
579
+ assert (
580
+ isinstance(iternext_inst, ir.Assign)
581
+ and isinstance(iternext_inst.value, ir.Expr)
582
+ and iternext_inst.value.op == "iternext"
583
+ )
584
+ assert (
585
+ isinstance(pair_first_inst, ir.Assign)
586
+ and isinstance(pair_first_inst.value, ir.Expr)
587
+ and pair_first_inst.value.op == "pair_first"
588
+ )
589
+ assert (
590
+ isinstance(pair_second_inst, ir.Assign)
591
+ and isinstance(pair_second_inst.value, ir.Expr)
592
+ and pair_second_inst.value.op == "pair_second"
593
+ )
594
+ stmts_to_retain.extend(
595
+ loop_header_block.body[header_block_index + 3 : -1]
596
+ )
597
+ loop_index = pair_first_inst.target
598
+ break
599
+ stmts_to_retain.append(stmt)
600
+ loop_bounds.append((call_stmt.value.args[0], loop_index))
601
+ if DEBUG_OPENMP >= 1:
602
+ print("collapse 1")
603
+ dump_blocks(self.blocks)
604
+ # For all the loops except the last...
605
+ for loop in loop_order[:-1]:
606
+ # Change the unneeded headers to just jump to the next block.
607
+ loop_header = loop.header
608
+ loop_header_block, real_loop_header = self.fix_empty_header(
609
+ self.blocks[loop_header], loop_header
610
+ )
611
+ collapse_iterspace_block.add(real_loop_header)
612
+ loop_header_block.body[-1] = ir.Jump(
613
+ loop_header_block.body[-1].truebr, loop_header_block.body[-1].loc
614
+ )
615
+ last_eliminated_loop_header_block = loop_header_block
616
+ self.body_blocks = [
617
+ x for x in self.body_blocks if x not in loop.entries
618
+ ]
619
+ self.body_blocks.remove(loop.header)
620
+ if DEBUG_OPENMP >= 1:
621
+ print("loop order:", loop_order)
622
+ print("loop bounds:", loop_bounds)
623
+ print("collapse 2")
624
+ dump_blocks(self.blocks)
625
+ last_loop = loop_order[-1]
626
+ last_loop_entry = list(last_loop.entries)[0]
627
+ last_loop_exit = list(last_loop.exits)[0]
628
+ last_loop_header = last_loop.header
629
+ last_loop_entry_block = self.blocks[last_loop_entry]
630
+ last_loop_header_block, _ = self.fix_empty_header(
631
+ self.blocks[last_loop_header], loop_header
632
+ )
633
+ last_loop_first_body_block = last_loop_header_block.body[-1].truebr
634
+ self.blocks[last_loop_first_body_block].body = (
635
+ stmts_to_retain + self.blocks[last_loop_first_body_block].body
636
+ )
637
+ last_loop_header_block.body[-1].falsebr = list(loop_order[0].exits)[0]
638
+ new_var_scope = last_loop_entry_block.body[0].target.scope
639
+
640
+ # -------- Add vars to remember cumulative product of iteration space sizes.
641
+ new_iterspace_var = new_var_scope.redefine("new_iterspace0", self.loc)
642
+ start_tags.append(
643
+ openmp_tag("QUAL.OMP.FIRSTPRIVATE", new_iterspace_var.name)
644
+ )
645
+ iterspace_vars.append(new_iterspace_var)
646
+ new_stmts_for_iterspace.append(
647
+ ir.Assign(loop_bounds[0][0], new_iterspace_var, self.loc)
648
+ )
649
+ for lb_num, loop_bound in enumerate(loop_bounds[1:]):
650
+ mul_op = ir.Expr.binop(
651
+ operator.mul, new_iterspace_var, loop_bound[0], self.loc
652
+ )
653
+ new_iterspace_var = new_var_scope.redefine(
654
+ "new_iterspace" + str(lb_num + 1), self.loc
655
+ )
656
+ start_tags.append(
657
+ openmp_tag("QUAL.OMP.FIRSTPRIVATE", new_iterspace_var.name)
658
+ )
659
+ iterspace_vars.append(new_iterspace_var)
660
+ new_stmts_for_iterspace.append(
661
+ ir.Assign(mul_op, new_iterspace_var, self.loc)
662
+ )
663
+ # Change iteration space of innermost loop to the product of all the
664
+ # loops' iteration spaces.
665
+ last_loop_entry_block.body[call_offset].value.args[0] = new_iterspace_var
666
+
667
+ last_eliminated_loop_header_block.body = (
668
+ new_stmts_for_iterspace + last_eliminated_loop_header_block.body
669
+ )
670
+
671
+ deconstruct_indices = []
672
+ new_deconstruct_var = new_var_scope.redefine("deconstruct", self.loc)
673
+ deconstruct_indices.append(
674
+ ir.Assign(loop_bounds[-1][1], new_deconstruct_var, self.loc)
675
+ )
676
+ for deconstruct_index in range(len(loop_bounds) - 1):
677
+ cur_iterspace_var = iterspace_vars[
678
+ len(loop_bounds) - 2 - deconstruct_index
679
+ ]
680
+ cur_loop_bound = loop_bounds[deconstruct_index][1]
681
+ # if DEBUG_OPENMP >= 1:
682
+ # print("deconstructing", cur_iterspace_var)
683
+ # deconstruct_indices.append(ir.Print([new_deconstruct_var, cur_iterspace_var], None, self.loc))
684
+ deconstruct_div = ir.Expr.binop(
685
+ operator.floordiv, new_deconstruct_var, cur_iterspace_var, self.loc
686
+ )
687
+ new_deconstruct_var_loop = new_var_scope.redefine(
688
+ "deconstruct" + str(deconstruct_index), self.loc
689
+ )
690
+ deconstruct_indices.append(
691
+ ir.Assign(deconstruct_div, cur_loop_bound, self.loc)
692
+ )
693
+ # if DEBUG_OPENMP >= 1:
694
+ # deconstruct_indices.append(ir.Print([cur_loop_bound], None, self.loc))
695
+ new_deconstruct_var_mul = new_var_scope.redefine(
696
+ "deconstruct_mul" + str(deconstruct_index), self.loc
697
+ )
698
+ deconstruct_indices.append(
699
+ ir.Assign(
700
+ ir.Expr.binop(
701
+ operator.mul, cur_loop_bound, cur_iterspace_var, self.loc
702
+ ),
703
+ new_deconstruct_var_mul,
704
+ self.loc,
705
+ )
706
+ )
707
+ # if DEBUG_OPENMP >= 1:
708
+ # deconstruct_indices.append(ir.Print([new_deconstruct_var_mul], None, self.loc))
709
+ deconstruct_indices.append(
710
+ ir.Assign(
711
+ ir.Expr.binop(
712
+ operator.sub,
713
+ new_deconstruct_var,
714
+ new_deconstruct_var_mul,
715
+ self.loc,
716
+ ),
717
+ new_deconstruct_var_loop,
718
+ self.loc,
719
+ )
720
+ )
721
+ # if DEBUG_OPENMP >= 1:
722
+ # deconstruct_indices.append(ir.Print([new_deconstruct_var_loop], None, self.loc))
723
+ new_deconstruct_var = new_deconstruct_var_loop
724
+ deconstruct_indices.append(
725
+ ir.Assign(new_deconstruct_var, loop_bounds[-1][1], self.loc)
726
+ )
727
+
728
+ self.blocks[last_loop_first_body_block].body = (
729
+ deconstruct_indices + self.blocks[last_loop_first_body_block].body
730
+ )
731
+
732
+ if DEBUG_OPENMP >= 1:
733
+ print("collapse 3", self.blk_start, self.blk_end)
734
+ dump_blocks(self.blocks)
735
+
736
+ cfg = compute_cfg_from_blocks(self.blocks)
737
+ live_map = compute_live_map(
738
+ cfg, self.blocks, usedefs.usemap, usedefs.defmap
739
+ )
740
+ all_loops = cfg.loops()
741
+ loops = get_loops_in_region(all_loops)
742
+ loops = list(filter_nested_loops(cfg, loops))
743
+ if DEBUG_OPENMP >= 2:
744
+ print("loops after collapse:", loops)
745
+ if DEBUG_OPENMP >= 1:
746
+ print("blocks after collapse", self.blk_start, self.blk_end)
747
+ dump_blocks(self.blocks)
748
+
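# --- Editorial illustration, not part of omp_lower.py: the collapse handling
# above flattens a collapse(n) loop nest into a single loop over the product
# of the trip counts and then rebuilds each loop's index from the flattened
# one. A minimal sketch of that idea for collapse(2), using hypothetical
# bounds n1 and n2 (the IR transformation above assembles the equivalent
# arithmetic out of binop expressions and iterspace variables):
def collapsed_nest_sketch(n1, n2, body):
    for idx in range(n1 * n2):  # one flattened loop over the whole iteration space
        i = idx // n2           # recovered outer index
        j = idx % n2            # recovered inner index
        body(i, j)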
749
+ def _get_loop_kind(func_var, call_table):
750
+ if func_var not in call_table:
751
+ return False
752
+ call = call_table[func_var]
753
+ if len(call) == 0:
754
+ return False
755
+
756
+ return call[0]
757
+
758
+ loop = loops[0]
759
+ entry = list(loop.entries)[0]
760
+ header = loop.header
761
+ exit = list(loop.exits)[0]
762
+
763
+ loop_blocks_for_io = loop.entries.union(loop.body)
764
+ loop_blocks_for_io_minus_entry = loop_blocks_for_io - {entry}
765
+ non_loop_blocks = set(self.body_blocks)
766
+ non_loop_blocks.difference_update(loop_blocks_for_io)
767
+ non_loop_blocks.difference_update(collapse_iterspace_block)
768
+ # non_loop_blocks.difference_update({exit})
769
+
770
+ if DEBUG_OPENMP >= 1:
771
+ print("non_loop_blocks:", non_loop_blocks)
772
+ print("entry:", entry)
773
+ print("header:", header)
774
+ print("exit:", exit)
775
+ print("body_blocks:", self.body_blocks)
776
+ print("loop:", loop)
777
+
778
+ # Find the first statement after any iterspace calculation ones for collapse.
779
+ first_stmt = self.blocks[entry].body[0]
780
+ # first_stmt = self.blocks[entry].body[len(new_stmts_for_iterspace)]
781
+ if (
782
+ not isinstance(first_stmt, ir.Assign)
783
+ or not isinstance(first_stmt.value, ir.Global)
784
+ or first_stmt.value.name != "range"
785
+ ):
786
+ raise ParallelForExtraCode(
787
+ f"Extra code near line {self.loc} is not allowed before or after the loop in an OpenMP parallel for region."
788
+ )
789
+
790
+ for non_loop_block in non_loop_blocks:
791
+ nlb = self.blocks[non_loop_block]
792
+ if isinstance(nlb.body[0], ir.Jump):
793
+ # Non-loop empty blocks are fine.
794
+ continue
795
+ if (
796
+ isinstance(nlb.body[-1], ir.Jump)
797
+ and nlb.body[-1].target == self.blk_end
798
+ ):
799
+ # Loop through all statements in block that jumps to the end of the region.
800
+ # If those are all assignments where the LHS is dead then they are safe.
801
+ for nlb_stmt in nlb.body[:-1]:
802
+ if isinstance(nlb_stmt, ir.PopBlock):
803
+ continue
804
+
805
+ break
806
+ # if not isinstance(nlb_stmt, ir.Assign):
807
+ # break # Non-assignment is not known to be safe...will fallthrough to raise exception.
808
+ # if nlb_stmt.target.name in live_end:
809
+ # break # Non-dead variables in assignment is not safe...will fallthrough to raise exception.
810
+ else:
811
+ continue
812
+ raise ParallelForExtraCode(
813
+ f"Extra code near line {self.loc} is not allowed before or after the loop in an OpenMP parallel for region."
814
+ )
815
+
816
+ if DEBUG_OPENMP >= 1:
817
+ print("loop_blocks_for_io:", loop_blocks_for_io, entry, exit)
818
+ print("non_loop_blocks:", non_loop_blocks)
819
+ print("header:", header)
820
+
821
+ entry_block = self.blocks[entry]
822
+ assert isinstance(entry_block.body[-1], ir.Jump)
823
+ assert entry_block.body[-1].target == header
824
+ exit_block = self.blocks[exit]
825
+ header_block = self.blocks[header]
826
+ extra_block = (
827
+ None if len(header_block.body) > 1 else header_block.body[-1].target
828
+ )
829
+
830
+ latch_block_num = max(self.blocks.keys()) + 1
831
+
832
+ # We have to reformat the Numba style of loop to the form that the LLVM
833
+ # openmp pass supports.
834
+ header_preds = [x[0] for x in cfg.predecessors(header)]
835
+ entry_preds = list(set(header_preds).difference(loop.body))
836
+ back_blocks = list(set(header_preds).intersection(loop.body))
837
+ if DEBUG_OPENMP >= 1:
838
+ print("header_preds:", header_preds)
839
+ print("entry_preds:", entry_preds)
840
+ print("back_blocks:", back_blocks)
841
+ assert len(entry_preds) == 1
842
+ entry_pred_label = entry_preds[0]
843
+ entry_pred = self.blocks[entry_pred_label]
844
+ if extra_block is not None:
845
+ header_block = self.blocks[extra_block]
846
+ header = extra_block
847
+ header_branch = header_block.body[-1]
848
+ post_header = {header_branch.truebr, header_branch.falsebr}
849
+ post_header.remove(exit)
850
+ if DEBUG_OPENMP >= 1:
851
+ print("post_header:", post_header)
852
+ post_header = self.blocks[list(post_header)[0]]
853
+ if DEBUG_OPENMP >= 1:
854
+ print("post_header:", post_header)
855
+
856
+ for inst_num, inst in enumerate(entry_block.body):
857
+ if (
858
+ isinstance(inst, ir.Assign)
859
+ and isinstance(inst.value, ir.Expr)
860
+ and inst.value.op == "call"
861
+ ):
862
+ loop_kind = _get_loop_kind(inst.value.func.name, call_table)
863
+ if DEBUG_OPENMP >= 1:
864
+ print("loop_kind:", loop_kind)
865
+ if loop_kind and loop_kind is range:
866
+ range_inst = inst
867
+ range_args = inst.value.args
868
+ if DEBUG_OPENMP >= 1:
869
+ print("found one", loop_kind, inst, range_args)
870
+
871
+ # ----------------------------------------------
872
+ # Find getiter instruction for this range.
873
+ for entry_inst in entry_block.body[inst_num + 1 :]:
874
+ if (
875
+ isinstance(entry_inst, ir.Assign)
876
+ and isinstance(entry_inst.value, ir.Expr)
877
+ and entry_inst.value.op == "getiter"
878
+ and entry_inst.value.value == range_inst.target
879
+ ):
880
+ getiter_inst = entry_inst
881
+ break
882
+ assert getiter_inst
883
+ if DEBUG_OPENMP >= 1:
884
+ print("getiter_inst:", getiter_inst)
885
+ # ----------------------------------------------
886
+
887
+ assert len(header_block.body) > 3
888
+ if DEBUG_OPENMP >= 1:
889
+ print("header block before removing Numba range vars:")
890
+ dump_block(header, header_block)
891
+
892
+ for ii in range(len(header_block.body)):
893
+ ii_inst = header_block.body[ii]
894
+ if (
895
+ isinstance(ii_inst, ir.Assign)
896
+ and isinstance(ii_inst.value, ir.Expr)
897
+ and ii_inst.value.op == "iternext"
898
+ ):
899
+ iter_num = ii
900
+ break
901
+
902
+ iternext_inst = header_block.body[iter_num]
903
+ pair_first_inst = header_block.body[iter_num + 1]
904
+ pair_second_inst = header_block.body[iter_num + 2]
905
+
906
+ assert (
907
+ isinstance(iternext_inst, ir.Assign)
908
+ and isinstance(iternext_inst.value, ir.Expr)
909
+ and iternext_inst.value.op == "iternext"
910
+ )
911
+ assert (
912
+ isinstance(pair_first_inst, ir.Assign)
913
+ and isinstance(pair_first_inst.value, ir.Expr)
914
+ and pair_first_inst.value.op == "pair_first"
915
+ )
916
+ assert (
917
+ isinstance(pair_second_inst, ir.Assign)
918
+ and isinstance(pair_second_inst.value, ir.Expr)
919
+ and pair_second_inst.value.op == "pair_second"
920
+ )
921
+ # Remove those nodes from the IR.
922
+ header_block.body = (
923
+ header_block.body[:iter_num] + header_block.body[iter_num + 3 :]
924
+ )
925
+ if DEBUG_OPENMP >= 1:
926
+ print("header block after removing Numba range vars:")
927
+ dump_block(header, header_block)
928
+
929
+ loop_index = pair_first_inst.target
930
+ if DEBUG_OPENMP >= 1:
931
+ print("loop_index:", loop_index, type(loop_index))
932
+ # The loop_index from Numba's perspective is not what it is from the
933
+ # programmer's perspective. The OpenMP loop index is always private so
934
+ # we need to start from Numba's loop index (e.g., $48for_iter.3) and
935
+ # trace assignments from that through the header block and then find
936
+ # the first such assignment in the first loop block that the header
937
+ # branches to.
938
+ latest_index = loop_index
939
+ for hinst in header_block.body:
940
+ if isinstance(hinst, ir.Assign) and isinstance(
941
+ hinst.value, ir.Var
942
+ ):
943
+ if hinst.value.name == latest_index.name:
944
+ latest_index = hinst.target
945
+ for phinst in post_header.body:
946
+ if isinstance(phinst, ir.Assign) and isinstance(
947
+ phinst.value, ir.Var
948
+ ):
949
+ if phinst.value.name == latest_index.name:
950
+ latest_index = phinst.target
951
+ break
952
+ if DEBUG_OPENMP >= 1:
953
+ print("latest_index:", latest_index, type(latest_index))
954
+
955
+ if latest_index.name not in vars_in_explicit_clauses:
956
+ new_index_clause = openmp_tag(
957
+ "QUAL.OMP.PRIVATE",
958
+ ir.Var(loop_index.scope, latest_index.name, inst.loc),
959
+ )
960
+ clauses.append(new_index_clause)
961
+ vars_in_explicit_clauses[latest_index.name] = new_index_clause
962
+ else:
963
+ if (
964
+ vars_in_explicit_clauses[latest_index.name].name
965
+ != "QUAL.OMP.PRIVATE"
966
+ ):
967
+ pass
968
+ # throw error? FIX ME
969
+
970
+ if DEBUG_OPENMP >= 1:
971
+ for clause in clauses:
972
+ print("post-latest_index clauses:", clause)
973
+
974
+ start = 0
975
+ step = 1
976
+ size_var = range_args[0]
977
+ if len(range_args) == 2:
978
+ start = range_args[0]
979
+ size_var = range_args[1]
980
+ if len(range_args) == 3:
981
+ start = range_args[0]
982
+ size_var = range_args[1]
983
+ try:
984
+ step = self.func_ir.get_definition(range_args[2])
985
+ # Only use get_definition to get a const if
986
+ # available. Otherwise use the variable.
987
+ if not isinstance(step, (int, ir.Const)):
988
+ step = range_args[2]
989
+ except KeyError:
990
+ # If there is more than one definition possible for the
991
+ # step variable then just use the variable and don't try
992
+ # to convert to a const.
993
+ step = range_args[2]
994
+ if isinstance(step, ir.Const):
995
+ step = step.value
996
+
997
+ if DEBUG_OPENMP >= 1:
998
+ print("size_var:", size_var, type(size_var))
999
+
1000
+ omp_lb_var = loop_index.scope.redefine("$omp_lb", inst.loc)
1001
+ before_start.append(
1002
+ ir.Assign(ir.Const(0, inst.loc), omp_lb_var, inst.loc)
1003
+ )
1004
+
1005
+ omp_iv_var = loop_index.scope.redefine("$omp_iv", inst.loc)
1006
+ # before_start.append(ir.Assign(omp_lb_var, omp_iv_var, inst.loc))
1007
+ # Don't use omp_lb here because that makes a live-in to the region that
1008
+ # becomes a parameter to an outlined target region.
1009
+ after_start.append(
1010
+ ir.Assign(ir.Const(0, inst.loc), omp_iv_var, inst.loc)
1011
+ )
1012
+ # after_start.append(ir.Assign(omp_lb_var, omp_iv_var, inst.loc))
1013
+
1014
+ types_mod_var = loop_index.scope.redefine(
1015
+ "$numba_types_mod", inst.loc
1016
+ )
1017
+ types_mod = ir.Global("types", types, inst.loc)
1018
+ types_mod_assign = ir.Assign(types_mod, types_mod_var, inst.loc)
1019
+ before_start.append(types_mod_assign)
1020
+
1021
+ int64_var = loop_index.scope.redefine("$int64_var", inst.loc)
1022
+ int64_getattr = ir.Expr.getattr(types_mod_var, "int64", inst.loc)
1023
+ int64_assign = ir.Assign(int64_getattr, int64_var, inst.loc)
1024
+ before_start.append(int64_assign)
1025
+
1026
+ get_itercount_var = loop_index.scope.redefine(
1027
+ "$get_itercount", inst.loc
1028
+ )
1029
+ get_itercount_global = ir.Global(
1030
+ "get_itercount", get_itercount, inst.loc
1031
+ )
1032
+ get_itercount_assign = ir.Assign(
1033
+ get_itercount_global, get_itercount_var, inst.loc
1034
+ )
1035
+ before_start.append(get_itercount_assign)
1036
+
1037
+ itercount_var = loop_index.scope.redefine("$itercount", inst.loc)
1038
+ itercount_expr = ir.Expr.call(
1039
+ get_itercount_var, [getiter_inst.target], (), inst.loc
1040
+ )
1041
+ # itercount_expr = ir.Expr.itercount(getiter_inst.target, inst.loc)
1042
+ before_start.append(
1043
+ ir.Assign(itercount_expr, itercount_var, inst.loc)
1044
+ )
1045
+
1046
+ omp_ub_var = loop_index.scope.redefine("$omp_ub", inst.loc)
1047
+ omp_ub_expr = ir.Expr.call(int64_var, [itercount_var], (), inst.loc)
1048
+ before_start.append(ir.Assign(omp_ub_expr, omp_ub_var, inst.loc))
1049
+
1050
+ const1_var = loop_index.scope.redefine("$const1", inst.loc)
1051
+ start_tags.append(openmp_tag("QUAL.OMP.PRIVATE", const1_var))
1052
+ const1_assign = ir.Assign(
1053
+ ir.Const(1, inst.loc), const1_var, inst.loc
1054
+ )
1055
+ before_start.append(const1_assign)
1056
+ count_add_1 = ir.Expr.binop(
1057
+ operator.sub, omp_ub_var, const1_var, inst.loc
1058
+ )
1059
+ before_start.append(ir.Assign(count_add_1, omp_ub_var, inst.loc))
1060
+
1061
+ # before_start.append(ir.Print([omp_ub_var], None, inst.loc))
1062
+
1063
+ omp_start_var = loop_index.scope.redefine("$omp_start", inst.loc)
1064
+ if start == 0:
1065
+ start = ir.Const(start, inst.loc)
1066
+ before_start.append(ir.Assign(start, omp_start_var, inst.loc))
1067
+
1068
+ # ---------- Create latch block -------------------------------
1069
+ latch_iv = omp_iv_var
1070
+
1071
+ latch_block = ir.Block(scope, inst.loc)
1072
+ const1_latch_var = loop_index.scope.redefine(
1073
+ "$const1_latch", inst.loc
1074
+ )
1075
+ start_tags.append(openmp_tag("QUAL.OMP.PRIVATE", const1_latch_var))
1076
+ const1_assign = ir.Assign(
1077
+ ir.Const(1, inst.loc), const1_latch_var, inst.loc
1078
+ )
1079
+ latch_block.body.append(const1_assign)
1080
+ latch_assign = ir.Assign(
1081
+ ir.Expr.binop(
1082
+ operator.add, omp_iv_var, const1_latch_var, inst.loc
1083
+ ),
1084
+ latch_iv,
1085
+ inst.loc,
1086
+ )
1087
+ latch_block.body.append(latch_assign)
1088
+ latch_block.body.append(ir.Jump(header, inst.loc))
1089
+
1090
+ self.blocks[latch_block_num] = latch_block
1091
+ for bb in back_blocks:
1092
+ if False:
1093
+ str_var = scope.redefine("$str_var", inst.loc)
1094
+ str_const = ir.Const("mid start:", inst.loc)
1095
+ str_assign = ir.Assign(str_const, str_var, inst.loc)
1096
+ str_print = ir.Print([str_var, size_var], None, inst.loc)
1097
+ # before_start.append(str_assign)
1098
+ # before_start.append(str_print)
1099
+ self.blocks[bb].body = self.blocks[bb].body[:-1] + [
1100
+ str_assign,
1101
+ str_print,
1102
+ ir.Jump(latch_block_num, inst.loc),
1103
+ ]
1104
+ else:
1105
+ self.blocks[bb].body[-1] = ir.Jump(
1106
+ latch_block_num, inst.loc
1107
+ )
1108
+ # -------------------------------------------------------------
1109
+
1110
+ # ---------- Header Manipulation ------------------------------
1111
+ step_var = loop_index.scope.redefine("$step_var", inst.loc)
1112
+ detect_step_assign = ir.Assign(
1113
+ ir.Const(0, inst.loc), step_var, inst.loc
1114
+ )
1115
+ after_start.append(detect_step_assign)
1116
+
1117
+ if isinstance(step, int):
1118
+ step_assign = ir.Assign(
1119
+ ir.Const(step, inst.loc), step_var, inst.loc
1120
+ )
1121
+ elif isinstance(step, ir.Var):
1122
+ step_assign = ir.Assign(step, step_var, inst.loc)
1123
+ start_tags.append(
1124
+ openmp_tag("QUAL.OMP.FIRSTPRIVATE", step.name)
1125
+ )
1126
+ else:
1127
+ print("Unsupported step:", step, type(step))
1128
+ raise NotImplementedError(
1129
+ f"Unknown step type that isn't a constant or variable but {type(step)} instead."
1130
+ )
1131
+ scale_var = loop_index.scope.redefine("$scale", inst.loc)
1132
+ fake_iternext = ir.Assign(
1133
+ ir.Const(0, inst.loc), iternext_inst.target, inst.loc
1134
+ )
1135
+ fake_second = ir.Assign(
1136
+ ir.Const(0, inst.loc), pair_second_inst.target, inst.loc
1137
+ )
1138
+ scale_assign = ir.Assign(
1139
+ ir.Expr.binop(operator.mul, step_var, omp_iv_var, inst.loc),
1140
+ scale_var,
1141
+ inst.loc,
1142
+ )
1143
+ unnormalize_iv = ir.Assign(
1144
+ ir.Expr.binop(operator.add, omp_start_var, scale_var, inst.loc),
1145
+ loop_index,
1146
+ inst.loc,
1147
+ )
1148
+ cmp_var = loop_index.scope.redefine("$cmp", inst.loc)
1149
+ iv_lte_ub = ir.Assign(
1150
+ ir.Expr.binop(operator.le, omp_iv_var, omp_ub_var, inst.loc),
1151
+ cmp_var,
1152
+ inst.loc,
1153
+ )
1154
+ old_branch = header_block.body[-1]
1155
+ new_branch = ir.Branch(
1156
+ cmp_var, old_branch.truebr, old_branch.falsebr, old_branch.loc
1157
+ )
1158
+ body_label = old_branch.truebr
1159
+ first_body_block = self.blocks[body_label]
1160
+ new_end = [iv_lte_ub, new_branch]
1161
+ # Turn this on to add printing to help debug at runtime.
1162
+ if False:
1163
+ str_var = loop_index.scope.redefine("$str_var", inst.loc)
1164
+ str_const = ir.Const("header1:", inst.loc)
1165
+ str_assign = ir.Assign(str_const, str_var, inst.loc)
1166
+ new_end.append(str_assign)
1167
+ str_print = ir.Print(
1168
+ [str_var, omp_start_var, omp_iv_var], None, inst.loc
1169
+ )
1170
+ new_end.append(str_print)
1171
+
1172
+ # Prepend original contents of header into the first body block minus the comparison
1173
+ first_body_block.body = (
1174
+ [
1175
+ fake_iternext,
1176
+ fake_second,
1177
+ step_assign,
1178
+ scale_assign,
1179
+ unnormalize_iv,
1180
+ ]
1181
+ + header_block.body[:-1]
1182
+ + first_body_block.body
1183
+ )
1184
+
1185
+ header_block.body = new_end
1186
+ # header_block.body = [fake_iternext, fake_second, unnormalize_iv] + header_block.body[:-1] + new_end
1187
+
1188
+ # -------------------------------------------------------------
1189
+
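# --- Editorial illustration, not part of omp_lower.py: the header rewrite
# above normalizes Numba's range loop into the form the LLVM OpenMP pass
# expects, i.e. a private induction variable ($omp_iv) running from $omp_lb
# (0) up to $omp_ub (itercount - 1) in unit steps, with the user-visible
# index recomputed each iteration as start + step * iv. Conceptually, for
# range(start, stop, step):
def normalized_loop_sketch(start, stop, step, body):
    itercount = len(range(start, stop, step))
    for omp_iv in range(itercount):  # normalized loop, lower bound 0, unit step
        i = start + step * omp_iv    # un-normalized user-visible index
        body(i)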
1190
+ # const_start_var = loop_index.scope.redefine("$const_start", inst.loc)
1191
+ # before_start.append(ir.Assign(ir.Const(0, inst.loc), const_start_var, inst.loc))
1192
+ # start_tags.append(openmp_tag("QUAL.OMP.FIRSTPRIVATE", const_start_var.name))
1193
+ start_tags.append(
1194
+ openmp_tag("QUAL.OMP.NORMALIZED.IV", omp_iv_var.name)
1195
+ )
1196
+ start_tags.append(
1197
+ openmp_tag("QUAL.OMP.NORMALIZED.START", omp_start_var.name)
1198
+ )
1199
+ start_tags.append(
1200
+ openmp_tag("QUAL.OMP.NORMALIZED.LB", omp_lb_var.name)
1201
+ )
1202
+ start_tags.append(
1203
+ openmp_tag("QUAL.OMP.NORMALIZED.UB", omp_ub_var.name)
1204
+ )
1205
+ start_tags.append(openmp_tag("QUAL.OMP.PRIVATE", omp_iv_var.name))
1206
+ start_tags.append(
1207
+ openmp_tag("QUAL.OMP.FIRSTPRIVATE", omp_start_var.name)
1208
+ )
1209
+ start_tags.append(
1210
+ openmp_tag("QUAL.OMP.FIRSTPRIVATE", omp_lb_var.name)
1211
+ )
1212
+ start_tags.append(
1213
+ openmp_tag("QUAL.OMP.FIRSTPRIVATE", omp_ub_var.name)
1214
+ )
1215
+ tags_for_enclosing = [
1216
+ cmp_var.name,
1217
+ omp_lb_var.name,
1218
+ omp_start_var.name,
1219
+ omp_iv_var.name,
1220
+ types_mod_var.name,
1221
+ int64_var.name,
1222
+ itercount_var.name,
1223
+ omp_ub_var.name,
1224
+ const1_var.name,
1225
+ const1_latch_var.name,
1226
+ get_itercount_var.name,
1227
+ ] + [x.name for x in iterspace_vars]
1228
+ tags_for_enclosing = [
1229
+ openmp_tag("QUAL.OMP.PRIVATE", x) for x in tags_for_enclosing
1230
+ ]
1231
+ # Don't blindly copy code here...this isn't doing what the other spots are doing with privatization.
1232
+ add_tags_to_enclosing(
1233
+ self.func_ir, self.blk_start, tags_for_enclosing
1234
+ )
1235
+ # start_tags.append(openmp_tag("QUAL.OMP.NORMALIZED.IV", loop_index.name))
1236
+ # start_tags.append(openmp_tag("QUAL.OMP.NORMALIZED.UB", size_var.name))
1237
+ return (
1238
+ True,
1239
+ loop_blocks_for_io,
1240
+ loop_blocks_for_io_minus_entry,
1241
+ entry_pred,
1242
+ exit_block,
1243
+ inst,
1244
+ size_var,
1245
+ step_var,
1246
+ latest_index,
1247
+ loop_index,
1248
+ )
1249
+
1250
+ return False, None, None, None, None, None, None, None, None, None
1251
+
1252
+ def some_for_directive(
1253
+ self, args, main_start_tag, main_end_tag, first_clause, gen_shared
1254
+ ):
1255
+ if DEBUG_OPENMP >= 1:
1256
+ print("some_for_directive", self.body_blocks)
1257
+ start_tags = [openmp_tag(main_start_tag)]
1258
+ end_tags = [openmp_tag(main_end_tag)]
1259
+ clauses = self.some_data_clause_directive(
1260
+ args, start_tags, end_tags, first_clause, has_loop=True
1261
+ )
1262
+
1263
+ if "PARALLEL" in main_start_tag:
1264
+ # ---- Back propagate THREAD_LIMIT to enclosed target region. ----
1265
+ self.parallel_back_prop(clauses)
1266
+
1267
+ if len(list(filter(lambda x: x.name == "QUAL.OMP.NUM_THREADS", clauses))) > 1:
1268
+ raise MultipleNumThreadsClauses(
1269
+ f"Multiple num_threads clauses near line {self.loc} is not allowed in an OpenMP parallel region."
1270
+ )
1271
+
1272
+ # --------- Parser functions ------------------------
1273
+
1274
+ def barrier_directive(self, args):
1275
+ sblk = self.blocks[self.blk_start]
1276
+
1277
+ if DEBUG_OPENMP >= 1:
1278
+ print("visit barrier_directive", args, type(args))
1279
+ or_start = openmp_region_start([openmp_tag("DIR.OMP.BARRIER")], 0, self.loc)
1280
+ or_start.requires_combined_acquire_release()
1281
+ or_end = openmp_region_end(
1282
+ or_start, [openmp_tag("DIR.OMP.END.BARRIER")], self.loc
1283
+ )
1284
+ sblk.body = [or_start] + [or_end] + sblk.body[:]
1285
+
1286
+ def taskwait_directive(self, args):
1287
+ sblk = self.blocks[self.blk_start]
1288
+
1289
+ if DEBUG_OPENMP >= 1:
1290
+ print("visit taskwait_directive", args, type(args))
1291
+ or_start = openmp_region_start([openmp_tag("DIR.OMP.TASKWAIT")], 0, self.loc)
1292
+ or_start.requires_combined_acquire_release()
1293
+ or_end = openmp_region_end(
1294
+ or_start, [openmp_tag("DIR.OMP.END.TASKWAIT")], self.loc
1295
+ )
1296
+ sblk.body = [or_start] + [or_end] + sblk.body[:]
1297
+
1298
+ def taskyield_directive(self, args):
1299
+ raise NotImplementedError("Taskyield currently unsupported.")
1300
+
1301
+ # Don't need a rule for BARRIER.
1302
+ # Don't need a rule for TASKWAIT.
1303
+ # Don't need a rule for TASKYIELD.
1304
+
1305
+ def taskgroup_directive(self, args):
1306
+ raise NotImplementedError("Taskgroup currently unsupported.")
1307
+
1308
+ # Don't need a rule for taskgroup_construct.
1309
+ # Don't need a rule for TASKGROUP.
1310
+
1311
+ # Don't need a rule for openmp_construct.
1312
+
1313
+ # def teams_distribute_parallel_for_simd_clause(self, args):
1314
+ # raise NotImplementedError("""Simd clause for target teams
1315
+ # distribute parallel loop currently unsupported.""")
1316
+ # if DEBUG_OPENMP >= 1:
1317
+ # print("visit device_clause", args, type(args))
1318
+
1319
+ # Don't need a rule for for_simd_construct.
1320
+
1321
+ def for_simd_directive(self, args):
1322
+ raise NotImplementedError("For simd currently unsupported.")
1323
+
1324
+ def for_simd_clause(self, args):
1325
+ if DEBUG_OPENMP >= 1:
1326
+ print("visit for_simd_clause", args, type(args), args[0])
1327
+ return args[0]
1328
+
1329
+ def schedule_clause(self, args):
1330
+ if DEBUG_OPENMP >= 1:
1331
+ print("visit schedule_clause", args, type(args), args[0])
1332
+ return args[0]
1333
+
1334
+ def dist_schedule_clause(self, args):
1335
+ if DEBUG_OPENMP >= 1:
1336
+ print("visit dist_schedule_clause", args, type(args), args[0])
1337
+ return args[0]
1338
+
1339
+ # Don't need a rule for parallel_for_simd_construct.
1340
+
1341
+ def parallel_for_simd_directive(self, args):
1342
+ raise NotImplementedError("Parallel for simd currently unsupported.")
1343
+
1344
+ def parallel_for_simd_clause(self, args):
1345
+ if DEBUG_OPENMP >= 1:
1346
+ print("visit parallel_for_simd_clause", args, type(args), args[0])
1347
+ return args[0]
1348
+
1349
+ # Don't need a rule for target_data_construct.
1350
+
1351
+ def target_data_directive(self, args):
1352
+ sblk = self.blocks[self.blk_start]
1353
+ eblk = self.blocks[self.blk_end]
1354
+
1355
+ if DEBUG_OPENMP >= 1:
1356
+ print("visit target_data_directive", args, type(args))
1357
+
1358
+ before_start = []
1359
+ after_start = []
1360
+
1361
+ clauses, default_shared = self.flatten(args[2:], sblk)
1362
+
1363
+ if DEBUG_OPENMP >= 1:
1364
+ for clause in clauses:
1365
+ print("final clause:", clause)
1366
+
1367
+ inputs_to_region, def_but_live_out, private_to_region, live_map = (
1368
+ self.find_io_vars(self.body_blocks)
1369
+ )
1370
+ used_in_region = inputs_to_region | def_but_live_out | private_to_region
1371
+ clauses = self.filter_unused_vars(clauses, used_in_region)
1372
+
1373
+ start_tags = [openmp_tag("DIR.OMP.TARGET.DATA")] + clauses
1374
+ end_tags = [openmp_tag("DIR.OMP.END.TARGET.DATA")]
1375
+
1376
+ or_start = openmp_region_start(start_tags, 0, self.loc)
1377
+ or_end = openmp_region_end(or_start, end_tags, self.loc)
1378
+ sblk.body = before_start + [or_start] + after_start + sblk.body[:]
1379
+ eblk.body = [or_end] + eblk.body[:]
1380
+
1381
+ add_enclosing_region(self.func_ir, self.body_blocks, or_start)
1382
+
1383
+ # Don't need a rule for DATA.
1384
+
1385
+ def target_data_clause(self, args):
1386
+ if DEBUG_OPENMP >= 1:
1387
+ print("visit target_data_clause", args, type(args), args[0])
1388
+ (val,) = args
1389
+ if isinstance(val, openmp_tag):
1390
+ return [val]
1391
+ elif isinstance(val, list):
1392
+ return val
1393
+ elif val == "nowait":
1394
+ return openmp_tag("QUAL.OMP.NOWAIT")
1395
+ else:
1396
+ return val
1397
+
1398
+ def target_enter_data_clause(self, args):
1399
+ if DEBUG_OPENMP >= 1:
1400
+ print("visit target_enter_data_clause", args, type(args), args[0])
1401
+ (val,) = args
1402
+ if isinstance(val, openmp_tag):
1403
+ return [val]
1404
+ elif isinstance(val, list):
1405
+ return val
1406
+ elif val == "nowait":
1407
+ return openmp_tag("QUAL.OMP.NOWAIT")
1408
+ else:
1409
+ return val
1410
+
1411
+ def target_exit_data_clause(self, args):
1412
+ if DEBUG_OPENMP >= 1:
1413
+ print("visit target_exit_data_clause", args, type(args), args[0])
1414
+ (val,) = args
1415
+ if isinstance(val, openmp_tag):
1416
+ return [val]
1417
+ elif isinstance(val, list):
1418
+ return val
1419
+ elif val == "nowait":
1420
+ return openmp_tag("QUAL.OMP.NOWAIT")
1421
+ else:
1422
+ return val
1423
+
1424
+ def device_clause(self, args):
1425
+ if DEBUG_OPENMP >= 1:
1426
+ print("visit device_clause", args, type(args))
1427
+ return [openmp_tag("QUAL.OMP.DEVICE", args[0])]
1428
+
1429
+ def map_clause(self, args):
1430
+ if DEBUG_OPENMP >= 1:
1431
+ print("visit map_clause", args, type(args), args[0])
1432
+ if args[0] in ["to", "from", "alloc", "tofrom"]:
1433
+ map_type = args[0].upper()
1434
+ var_list = args[1]
1435
+ assert len(args) == 2
1436
+ else:
1437
+ # TODO: is this default right?
1438
+ map_type = "TOFROM"
1439
+ var_list = args[1]
1440
+ ret = []
1441
+ for var in var_list:
1442
+ ret.append(openmp_tag("QUAL.OMP.MAP." + map_type, var))
1443
+ return ret
1444
+
1445
+ def map_enter_clause(self, args):
1446
+ if DEBUG_OPENMP >= 1:
1447
+ print("visit map_enter_clause", args, type(args), args[0])
1448
+ assert args[0] in ["to", "alloc"]
1449
+ map_type = args[0].upper()
1450
+ var_list = args[1]
1451
+ assert len(args) == 2
1452
+ ret = []
1453
+ for var in var_list:
1454
+ ret.append(openmp_tag("QUAL.OMP.MAP." + map_type, var))
1455
+ return ret
1456
+
1457
+ def map_exit_clause(self, args):
1458
+ if DEBUG_OPENMP >= 1:
1459
+ print("visit map_exit_clause", args, type(args), args[0])
1460
+ assert args[0] in ["from", "release", "delete"]
1461
+ map_type = args[0].upper()
1462
+ var_list = args[1]
1463
+ assert len(args) == 2
1464
+ ret = []
1465
+ for var in var_list:
1466
+ ret.append(openmp_tag("QUAL.OMP.MAP." + map_type, var))
1467
+ return ret
1468
+
1469
+ def depend_with_modifier_clause(self, args):
1470
+ if DEBUG_OPENMP >= 1:
1471
+ print("visit depend_with_modifier_clause", args, type(args), args[0])
1472
+ dep_type = args[1].upper()
1473
+ var_list = args[2]
1474
+ assert len(args) == 3
1475
+ ret = []
1476
+ for var in var_list:
1477
+ ret.append(openmp_tag("QUAL.OMP.DEPEND." + dep_type, var))
1478
+ return ret
1479
+
1480
+ def map_type(self, args):
1481
+ if DEBUG_OPENMP >= 1:
1482
+ print("visit map_type", args, type(args), args[0])
1483
+ return str(args[0])
1484
+
1485
+ def map_enter_type(self, args):
1486
+ if DEBUG_OPENMP >= 1:
1487
+ print("visit map_enter_type", args, type(args), args[0])
1488
+ return str(args[0])
1489
+
1490
+ def map_exit_type(self, args):
1491
+ if DEBUG_OPENMP >= 1:
1492
+ print("visit map_exit_type", args, type(args), args[0])
1493
+ return str(args[0])
1494
+
1495
+ def update_motion_type(self, args):
1496
+ if DEBUG_OPENMP >= 1:
1497
+ print("visit update_motion_type", args, type(args), args[0])
1498
+ return str(args[0])
1499
+
1500
+ # Don't need a rule for TO.
1501
+ # Don't need a rule for FROM.
1502
+ # Don't need a rule for ALLOC.
1503
+ # Don't need a rule for TOFROM.
1504
+ # Don't need a rule for parallel_sections_construct.
1505
+
1506
+ def parallel_sections_directive(self, args):
1507
+ raise NotImplementedError("Parallel sections currently unsupported.")
1508
+
1509
+ def parallel_sections_clause(self, args):
1510
+ if DEBUG_OPENMP >= 1:
1511
+ print("visit parallel_sections_clause", args, type(args), args[0])
1512
+ return args[0]
1513
+
1514
+ # Don't need a rule for sections_construct.
1515
+
1516
+ def sections_directive(self, args):
1517
+ raise NotImplementedError("Sections directive currently unsupported.")
1518
+
1519
+ # Don't need a rule for SECTIONS.
1520
+
1521
+ def sections_clause(self, args):
1522
+ if DEBUG_OPENMP >= 1:
1523
+ print("visit sections_clause", args, type(args), args[0])
1524
+ return args[0]
1525
+
1526
+ # Don't need a rule for section_construct.
1527
+
1528
+ def section_directive(self, args):
1529
+ raise NotImplementedError("Section directive currently unsupported.")
1530
+
1531
+ # Don't need a rule for SECTION.
1532
+ # Don't need a rule for atomic_construct.
1533
+
1534
+ def atomic_directive(self, args):
1535
+ raise NotImplementedError("Atomic currently unsupported.")
1536
+
1537
+ # Don't need a rule for ATOMIC.
1538
+
1539
+ def atomic_clause(self, args):
1540
+ if DEBUG_OPENMP >= 1:
1541
+ print("visit atomic_clause", args, type(args), args[0])
1542
+ return args[0]
1543
+
1544
+ # Don't need a rule for READ.
1545
+ # Don't need a rule for WRITE.
1546
+ # Don't need a rule for UPDATE.
1547
+ # Don't need a rule for CAPTURE.
1548
+ # Don't need a rule for seq_cst_clause.
1549
+ # Don't need a rule for critical_construct.
1550
+
1551
+ def critical_directive(self, args):
1552
+ sblk = self.blocks[self.blk_start]
1553
+ eblk = self.blocks[self.blk_end]
1554
+ scope = sblk.scope
1555
+
1556
+ if DEBUG_OPENMP >= 1:
1557
+ print("visit critical_directive", args, type(args))
1558
+ or_start = openmp_region_start([openmp_tag("DIR.OMP.CRITICAL")], 0, self.loc)
1559
+ or_start.requires_acquire_release()
1560
+ or_end = openmp_region_end(
1561
+ or_start, [openmp_tag("DIR.OMP.END.CRITICAL")], self.loc
1562
+ )
1563
+
1564
+ inputs_to_region, def_but_live_out, private_to_region, live_map = (
1565
+ self.find_io_vars(self.body_blocks)
1566
+ )
1567
+ inputs_to_region = {remove_ssa(x, scope, self.loc): x for x in inputs_to_region}
1568
+ def_but_live_out = {remove_ssa(x, scope, self.loc): x for x in def_but_live_out}
1569
+ common_keys = inputs_to_region.keys() & def_but_live_out.keys()
1570
+ in_def_live_out = {
1571
+ inputs_to_region[k]: def_but_live_out[k] for k in common_keys
1572
+ }
1573
+ if DEBUG_OPENMP >= 1:
1574
+ print("inputs_to_region:", sorted(inputs_to_region))
1575
+ print("def_but_live_out:", sorted(def_but_live_out))
1576
+ print("in_def_live_out:", sorted(in_def_live_out))
1577
+
1578
+ reset = []
1579
+ for k, v in in_def_live_out.items():
1580
+ reset.append(
1581
+ ir.Assign(
1582
+ ir.Var(scope, v, self.loc), ir.Var(scope, k, self.loc), self.loc
1583
+ )
1584
+ )
1585
+
1586
+ sblk.body = [or_start] + sblk.body[:]
1587
+ eblk.body = reset + [or_end] + eblk.body[:]
1588
+
1589
+ # Don't need a rule for CRITICAL.
1590
+ # Don't need a rule for target_construct.
1591
+ # Don't need a rule for target_teams_distribute_parallel_for_simd_construct.
1592
+
1593
+ def teams_back_prop(self, clauses):
1594
+ enclosing_regions = get_enclosing_region(self.func_ir, self.blk_start)
1595
+ if DEBUG_OPENMP >= 1:
1596
+ print("teams enclosing_regions:", enclosing_regions)
1597
+ if not enclosing_regions:
1598
+ return
1599
+
1600
+ for enclosing_region in enclosing_regions[::-1]:
1601
+ if not self.get_directive_match(enclosing_region.tags, "DIR.OMP.TARGET"):
1602
+ continue
1603
+
1604
+ nt_tag = self.get_clauses_by_name(
1605
+ enclosing_region.tags, "QUAL.OMP.NUM_TEAMS"
1606
+ )
1607
+ assert len(nt_tag) == 1
1608
+ cur_num_team_clauses = self.get_clauses_by_name(
1609
+ clauses, "QUAL.OMP.NUM_TEAMS", remove_from_orig=True
1610
+ )
1611
+ if len(cur_num_team_clauses) >= 1:
1612
+ nt_tag[-1].arg = cur_num_team_clauses[-1].arg
1613
+ else:
1614
+ nt_tag[-1].arg = 0
1615
+
1616
+ nt_tag = self.get_clauses_by_name(
1617
+ enclosing_region.tags, "QUAL.OMP.THREAD_LIMIT"
1618
+ )
1619
+ assert len(nt_tag) == 1
1620
+ cur_num_team_clauses = self.get_clauses_by_name(
1621
+ clauses, "QUAL.OMP.THREAD_LIMIT", remove_from_orig=True
1622
+ )
1623
+ if len(cur_num_team_clauses) >= 1:
1624
+ nt_tag[-1].arg = cur_num_team_clauses[-1].arg
1625
+ else:
1626
+ nt_tag[-1].arg = 0
1627
+
1628
+ return
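+ # Illustration: for "teams num_teams(4) thread_limit(64)" nested inside a
+ # "target", the enclosing target's QUAL.OMP.NUM_TEAMS and QUAL.OMP.THREAD_LIMIT
+ # tags take the values 4 and 64; any clause that is absent resets the
+ # corresponding tag to 0 so the runtime picks a default.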
1629
+
1630
+ def check_distribute_nesting(self, dir_tag):
1631
+ if "DISTRIBUTE" in dir_tag and "TEAMS" not in dir_tag:
1632
+ enclosing_regions = get_enclosing_region(self.func_ir, self.blk_start)
1633
+ if (
1634
+ len(enclosing_regions) < 1
1635
+ or "TEAMS" not in enclosing_regions[-1].tags[0].name
1636
+ ):
1637
+ raise NotImplementedError(
1638
+ "DISTRIBUTE must be nested under or combined with TEAMS."
1639
+ )
1640
+
1641
+ def teams_directive(self, args):
1642
+ if DEBUG_OPENMP >= 1:
1643
+ print(
1644
+ "visit teams_directive", args, type(args), self.blk_start, self.blk_end
1645
+ )
1646
+ start_tags = [openmp_tag("DIR.OMP.TEAMS")]
1647
+ end_tags = [openmp_tag("DIR.OMP.END.TEAMS")]
1648
+ clauses = self.some_data_clause_directive(args, start_tags, end_tags, 1)
1649
+
1650
+ self.teams_back_prop(clauses)
1651
+
1652
+ def target_directive(self, args):
1653
+ if sys.platform.startswith("darwin"):
1654
+ print("ERROR: OpenMP target offloading is unavailable on Darwin")
1655
+ sys.exit(-1)
1656
+ self.some_target_directive(args, "TARGET", 1)
1657
+
1658
+ def target_teams_directive(self, args):
1659
+ self.some_target_directive(args, "TARGET.TEAMS", 2)
1660
+
1661
+ def target_teams_distribute_directive(self, args):
1662
+ self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE", 3, has_loop=True)
1663
+
1664
+ def target_loop_directive(self, args):
1665
+ self.some_target_directive(
1666
+ args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 2, has_loop=True
1667
+ )
1668
+
1669
+ def target_teams_loop_directive(self, args):
1670
+ self.some_target_directive(
1671
+ args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 3, has_loop=True
1672
+ )
1673
+
1674
+ def target_teams_distribute_parallel_for_directive(self, args):
1675
+ self.some_target_directive(
1676
+ args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 5, has_loop=True
1677
+ )
1678
+
1679
+ def target_teams_distribute_parallel_for_simd_directive(self, args):
1680
+ # Intentionally dropping "SIMD" from string as that typically isn't implemented on GPU.
1681
+ self.some_target_directive(
1682
+ args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP", 6, has_loop=True
1683
+ )
1684
+
1685
+ def get_clauses_by_name(self, clauses, names, remove_from_orig=False):
1686
+ if not isinstance(names, list):
1687
+ names = [names]
1688
+
1689
+ ret = list(filter(lambda x: x.name in names, clauses))
1690
+ if remove_from_orig:
1691
+ clauses[:] = list(filter(lambda x: x.name not in names, clauses))
1692
+ return ret
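+ # For example, get_clauses_by_name(clauses, "QUAL.OMP.NUM_THREADS") returns the
+ # matching openmp_tag objects; with remove_from_orig=True they are also removed
+ # from the caller's clause list in place.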
1693
+
1694
+ def get_clauses_by_start(self, clauses, names, remove_from_orig=False):
1695
+ if not isinstance(names, list):
1696
+ names = [names]
1697
+ ret = list(
1698
+ filter(lambda x: any([x.name.startswith(y) for y in names]), clauses)
1699
+ )
1700
+ if remove_from_orig:
1701
+ clauses[:] = list(
1702
+ filter(
1703
+ lambda x: not any(x.name.startswith(y) for y in names), clauses
1704
+ )
1705
+ )
1706
+ return ret
1707
+
1708
+ def get_clauses_if_contains(self, clauses, names, remove_from_orig=False):
1709
+ if not isinstance(names, list):
1710
+ names = [names]
1711
+ ret = list(filter(lambda x: any([y in x.name for y in names]), clauses))
1712
+ if remove_from_orig:
1713
+ clauses[:] = list(
1714
+ filter(lambda x: not any(y in x.name for y in names), clauses)
1715
+ )
1716
+ return ret
1717
+
1718
+ def get_directive_if_contains(self, tags, name):
1719
+ dir = [x for x in tags if x.name.startswith("DIR")]
1720
+ assert len(dir) == 1, "Expected one directive tag"
1721
+ ret = [x for x in dir if name in x.name]
1722
+ return ret
1723
+
1724
+ def get_directive_match(self, tags, name):
1725
+ dir = [x for x in tags if x.name.startswith("DIR")]
1726
+ assert len(dir) == 1, "Expected one directive tag"
1727
+ ret = [x for x in dir if name == x.name]
1728
+ return ret
1729
+
1730
+ def target_enter_data_directive(self, args):
1731
+ sblk = self.blocks[self.blk_start]
1732
+
1733
+ if DEBUG_OPENMP >= 1:
1734
+ print("visit target_enter_data_directive", args, type(args))
1735
+
1736
+ clauses, _ = self.flatten(args[3:], sblk)
1737
+ or_start = openmp_region_start(
1738
+ [openmp_tag("DIR.OMP.TARGET.ENTER.DATA")] + clauses, 0, self.loc
1739
+ )
1740
+ or_end = openmp_region_end(
1741
+ or_start, [openmp_tag("DIR.OMP.END.TARGET.ENTER.DATA")], self.loc
1742
+ )
1743
+ sblk.body = [or_start] + [or_end] + sblk.body[:]
1744
+
1745
+ def target_exit_data_directive(self, args):
1746
+ sblk = self.blocks[self.blk_start]
1747
+
1748
+ if DEBUG_OPENMP >= 1:
1749
+ print("visit target_exit_data_directive", args, type(args))
1750
+
1751
+ clauses, _ = self.flatten(args[3:], sblk)
1752
+ or_start = openmp_region_start(
1753
+ [openmp_tag("DIR.OMP.TARGET.EXIT.DATA")] + clauses, 0, self.loc
1754
+ )
1755
+ or_end = openmp_region_end(
1756
+ or_start, [openmp_tag("DIR.OMP.END.TARGET.EXIT.DATA")], self.loc
1757
+ )
1758
+ sblk.body = [or_start] + [or_end] + sblk.body[:]
1759
+
1760
+ def teams_distribute_parallel_for_simd_directive(self, args):
1761
+ self.some_distribute_directive(
1762
+ args, "TEAMS.DISTRIBUTE.PARALLEL.LOOP.SIMD", 5, has_loop=True
1763
+ )
1764
+
1765
+ def teams_distribute_parallel_for_directive(self, args):
1766
+ self.some_distribute_directive(
1767
+ args, "TEAMS.DISTRIBUTE.PARALLEL.LOOP", 4, has_loop=True
1768
+ )
1769
+
1770
+ def teams_distribute_directive(self, args):
1771
+ self.some_distribute_directive(args, "TEAMS.DISTRIBUTE", 2, has_loop=True)
1772
+
1773
+ def teams_distribute_simd_directive(self, args):
1774
+ self.some_distribute_directive(args, "TEAMS.DISTRIBUTE.SIMD", 3, has_loop=True)
1775
+
1776
+ def teams_loop_directive(self, args):
1777
+ self.some_distribute_directive(
1778
+ args, "TEAMS.DISTRIBUTE.PARALLEL.LOOP", 2, has_loop=True
1779
+ )
1780
+
1781
+ def loop_directive(self, args):
1782
+ # TODO: Add error checking for clauses that the parser accepts, if we find that
1783
+ # the loop directive can even take clauses, which we are not sure it can.
1784
+ enclosing_regions = get_enclosing_region(self.func_ir, self.blk_start)
1785
+ if not enclosing_regions or len(enclosing_regions) < 1:
1786
+ self.some_for_directive(
1787
+ args, "DIR.OMP.PARALLEL.LOOP", "DIR.OMP.END.PARALLEL.LOOP", 1, True
1788
+ )
1789
+ else:
1790
+ if "DISTRIBUTE" in enclosing_regions[-1].tags[0].name:
1791
+ self.some_distribute_directive(args, "PARALLEL.LOOP", 1, has_loop=True)
1792
+ elif "TEAMS" in enclosing_regions[-1].tags[0].name:
1793
+ self.some_distribute_directive(
1794
+ args, "DISTRIBUTE.PARALLEL.LOOP", 1, has_loop=True
1795
+ )
1796
+ else:
1797
+ if "TARGET" in enclosing_regions[-1].tags[0].name:
1798
+ self.some_distribute_directive(
1799
+ args, "TEAMS.DISTRIBUTE.PARALLEL.LOOP", 1, has_loop=True
1800
+ )
1801
+ else:
1802
+ self.some_for_directive(
1803
+ args,
1804
+ "DIR.OMP.PARALLEL.LOOP",
1805
+ "DIR.OMP.END.PARALLEL.LOOP",
1806
+ 1,
1807
+ True,
1808
+ )
1809
+
1810
+ def distribute_directive(self, args):
1811
+ self.some_distribute_directive(args, "DISTRIBUTE", 1, has_loop=True)
1812
+
1813
+ def distribute_simd_directive(self, args):
1814
+ self.some_distribute_directive(args, "DISTRIBUTE.SIMD", 2, has_loop=True)
1815
+
1816
+ def distribute_parallel_for_directive(self, args):
1817
+ self.some_distribute_directive(
1818
+ args, "DISTRIBUTE.PARALLEL.LOOP", 3, has_loop=True
1819
+ )
1820
+
1821
+ def distribute_parallel_for_simd_directive(self, args):
1822
+ self.some_distribute_directive(
1823
+ args, "DISTRIBUTE.PARALLEL.LOOP.SIMD", 4, has_loop=True
1824
+ )
1825
+
1826
+ def some_distribute_directive(self, args, dir_tag, lexer_count, has_loop=False):
1827
+ if DEBUG_OPENMP >= 1:
1828
+ print(
1829
+ "visit some_distribute_directive",
1830
+ args,
1831
+ type(args),
1832
+ self.blk_start,
1833
+ self.blk_end,
1834
+ )
1835
+
1836
+ self.check_distribute_nesting(dir_tag)
1837
+
1838
+ target_num = OpenmpVisitor.target_num
1839
+ OpenmpVisitor.target_num += 1
1840
+
1841
+ dir_start_tag = "DIR.OMP." + dir_tag
1842
+ dir_end_tag = "DIR.OMP.END." + dir_tag
1843
+ start_tags = [openmp_tag(dir_start_tag, target_num)]
1844
+ end_tags = [openmp_tag(dir_end_tag, target_num)]
1845
+
1846
+ sblk = self.blocks[self.blk_start]
1847
+ clauses, _ = self.flatten(args[lexer_count:], sblk)
1848
+
1849
+ if "TEAMS" in dir_tag:
1850
+ # If NUM_TEAMS and THREAD_LIMIT are not in the clauses, set them to 0 to
1851
+ # use the runtime defaults for teams and thread launching.
1852
+ if len(self.get_clauses_by_name(clauses, "QUAL.OMP.NUM_TEAMS")) == 0:
1853
+ start_tags.append(openmp_tag("QUAL.OMP.NUM_TEAMS", 0))
1854
+ if len(self.get_clauses_by_name(clauses, "QUAL.OMP.THREAD_LIMIT")) == 0:
1855
+ start_tags.append(openmp_tag("QUAL.OMP.THREAD_LIMIT", 0))
1856
+ self.teams_back_prop(clauses)
1857
+ elif "PARALLEL" in dir_tag:
1858
+ self.parallel_back_prop(clauses)
1859
+
1860
+ if DEBUG_OPENMP >= 1:
1861
+ for clause in clauses:
1862
+ print("target clause:", clause)
1863
+
1864
+ self.some_data_clause_directive(
1865
+ clauses, start_tags, end_tags, 0, has_loop=has_loop, for_target=False
1866
+ )
1867
+
1868
+ def some_target_directive(self, args, dir_tag, lexer_count, has_loop=False):
1869
+ if DEBUG_OPENMP >= 1:
1870
+ print(
1871
+ "visit some_target_directive",
1872
+ args,
1873
+ type(args),
1874
+ self.blk_start,
1875
+ self.blk_end,
1876
+ )
1877
+
1878
+ self.check_distribute_nesting(dir_tag)
1879
+
1880
+ target_num = OpenmpVisitor.target_num
1881
+ OpenmpVisitor.target_num += 1
1882
+
1883
+ dir_start_tag = "DIR.OMP." + dir_tag
1884
+ dir_end_tag = "DIR.OMP.END." + dir_tag
1885
+ start_tags = [openmp_tag(dir_start_tag, target_num)]
1886
+ end_tags = [openmp_tag(dir_end_tag, target_num)]
1887
+
1888
+ sblk = self.blocks[self.blk_start]
1889
+ clauses, _ = self.flatten(args[lexer_count:], sblk)
1890
+
1891
+ if "TEAMS" in dir_tag:
1892
+ # When NUM_TEAMS, THREAD_LIMIT are not in clauses, set them to 0 to
1893
+ # use the runtime defaults for teams and thread launching; otherwise use
1894
+ # existing clauses.
1895
+ clause_num_teams = self.get_clauses_by_name(clauses, "QUAL.OMP.NUM_TEAMS")
1896
+ if not clause_num_teams:
1897
+ start_tags.append(openmp_tag("QUAL.OMP.NUM_TEAMS", 0))
1898
+
1899
+ # Use the THREAD_LIMIT clause value if it exists, regardless of a
1900
+ # combined PARALLEL (see
1901
+ # https://www.openmp.org/spec-html/5.0/openmpse15.html) since
1902
+ # THREAD_LIMIT takes precedence. If clause does not exist, set to 0
1903
+ # or to NUM_THREADS of the combined PARALLEL (if this exists).
1904
+ clause_thread_limit = self.get_clauses_by_name(
1905
+ clauses, "QUAL.OMP.THREAD_LIMIT"
1906
+ )
1907
+ if not clause_thread_limit:
1908
+ thread_limit = 0
1909
+ if "PARALLEL" in dir_tag:
1910
+ clause_num_threads = self.get_clauses_by_name(
1911
+ clauses, "QUAL.OMP.NUM_THREADS"
1912
+ )
1913
+ if clause_num_threads:
1914
+ assert len(clause_num_threads) == 1, (
1915
+ "Expected single NUM_THREADS clause"
1916
+ )
1917
+ thread_limit = clause_num_threads[0].arg
1918
+ start_tags.append(openmp_tag("QUAL.OMP.THREAD_LIMIT", thread_limit))
1919
+ elif "PARALLEL" in dir_tag:
1920
+ # PARALLEL in the directive (without TEAMS), set THREAD_LIMIT to NUM_THREADS clause
1921
+ # (if NUM_THREADS exists), or 0 (if NUM_THREADS does not exist)
1922
+ num_threads = 0
1923
+ clause_num_threads = self.get_clauses_by_name(
1924
+ clauses, "QUAL.OMP.NUM_THREADS"
1925
+ )
1926
+ if clause_num_threads:
1927
+ assert len(clause_num_threads) == 1, (
1928
+ "Expected single NUM_THREADS clause"
1929
+ )
1930
+ num_threads = clause_num_threads[0].arg
1931
+
1932
+ # Replace existing THREAD_LIMIT clause.
1933
+ clause_thread_limit = self.get_clauses_by_name(
1934
+ clauses, "QUAL.OMP.THREAD_LIMIT", remove_from_orig=True
1935
+ )
1936
+ clauses.append(openmp_tag("QUAL.OMP.THREAD_LIMIT", num_threads))
1937
+ else:
1938
+ # Neither TEAMS nor PARALLEL in the directive: set the number of teams to 1.
1939
+ start_tags.append(openmp_tag("QUAL.OMP.NUM_TEAMS", 1))
1940
+ # Set thread limit to 0 to use runtime default.
1941
+ start_tags.append(openmp_tag("QUAL.OMP.THREAD_LIMIT", 0))
1942
+
1943
+ if DEBUG_OPENMP >= 1:
1944
+ for clause in clauses:
1945
+ print("target clause:", clause)
1946
+
1947
+ self.some_data_clause_directive(
1948
+ clauses, start_tags, end_tags, 0, has_loop=has_loop, for_target=True
1949
+ )
1950
+ # self.some_data_clause_directive(args, start_tags, end_tags, lexer_count, has_loop=has_loop)
1951
+
1952
+ def add_to_returns(self, stmts):
1953
+ for blk in self.blocks.values():
1954
+ if isinstance(blk.body[-1], ir.Return):
1955
+ blk.body = blk.body[:-1] + stmts + [blk.body[-1]]
1956
+
1957
+ def add_block_in_order(self, new_block, insert_after_block):
1958
+ """Insert a new block after the specified block while maintaining topological order"""
1959
+ new_blocks = {}
1960
+ # Copy blocks up to and including insert_after_block
1961
+ for label, block in self.blocks.items():
1962
+ new_blocks[label] = block
1963
+ if label == insert_after_block:
1964
+ # Insert new block right after
1965
+ # We add a fractional offset to make sure the block is sorted right
1966
+ # after the insert_after_block and before its successor.
1967
+ # TODO: Avoid this fractional addition.
1968
+ new_block_num = label + 0.1
1969
+ new_blocks[new_block_num] = new_block
1970
+ # Copy remaining blocks
1971
+ for label, block in self.blocks.items():
1972
+ if label > insert_after_block:
1973
+ new_blocks[label] = block
1974
+ # new_blocks = flatten_labels(new_blocks)
1975
+ self.blocks.clear()
1976
+ self.blocks.update(new_blocks)
1977
+ return new_block_num
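+ # For example, inserting a new block after label 10 stores it under the
+ # fractional label 10.1, so it orders between label 10 and its successor
+ # whenever the blocks are sorted by label.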
1978
+
1979
+ def some_data_clause_directive(
1980
+ self,
1981
+ args,
1982
+ start_tags,
1983
+ end_tags,
1984
+ lexer_count,
1985
+ has_loop=False,
1986
+ for_target=False,
1987
+ for_task=False,
1988
+ ):
1989
+ if DEBUG_OPENMP >= 1:
1990
+ print(
1991
+ "visit some_data_clause_directive",
1992
+ args,
1993
+ type(args),
1994
+ self.blk_start,
1995
+ self.blk_end,
1996
+ )
1997
+ assert not (for_target and for_task)
1998
+
1999
+ sblk = self.blocks[self.blk_start]
2000
+ eblk = self.blocks[self.blk_end]
2001
+ scope = sblk.scope
2002
+
2003
+ if DEBUG_OPENMP >= 1:
2004
+ for clause in args[lexer_count:]:
2005
+ print("pre clause:", clause)
2006
+ clauses, default_shared = self.flatten(args[lexer_count:], sblk)
2007
+ if DEBUG_OPENMP >= 1:
2008
+ for clause in clauses:
2009
+ print("final clause:", clause)
2010
+
2011
+ before_start = []
2012
+ after_start = []
2013
+ for_before_start = []
2014
+ for_after_start = []
2015
+
2016
+ # Get a dict mapping variables explicitly mentioned in the data clauses above to their openmp_tag.
2017
+ vars_in_explicit_clauses, explicit_privates, non_user_explicits = (
2018
+ self.get_explicit_vars(clauses)
2019
+ )
2020
+ if DEBUG_OPENMP >= 1:
2021
+ print(
2022
+ "vars_in_explicit_clauses:",
2023
+ sorted(vars_in_explicit_clauses),
2024
+ type(vars_in_explicit_clauses),
2025
+ )
2026
+ for v in clauses:
2027
+ print("vars_in_explicit clauses first:", v)
2028
+
2029
+ if has_loop:
2030
+ prepare_out = self.prepare_for_directive(
2031
+ clauses,
2032
+ vars_in_explicit_clauses,
2033
+ for_before_start,
2034
+ for_after_start,
2035
+ start_tags,
2036
+ end_tags,
2037
+ scope,
2038
+ )
2039
+ vars_in_explicit_clauses, explicit_privates, non_user_explicits = (
2040
+ self.get_explicit_vars(clauses)
2041
+ )
2042
+ (
2043
+ found_loop,
2044
+ blocks_for_io,
2045
+ blocks_in_region,
2046
+ entry_pred,
2047
+ exit_block,
2048
+ inst,
2049
+ size_var,
2050
+ step_var,
2051
+ latest_index,
2052
+ loop_index,
2053
+ ) = prepare_out
2054
+ assert found_loop
2055
+ else:
2056
+ blocks_for_io = self.body_blocks
2057
+ blocks_in_region = get_blocks_between_start_end(
2058
+ self.blocks, self.blk_start, self.blk_end
2059
+ )
2060
+ entry_pred = sblk
2061
+ exit_block = eblk
2062
+
2063
+ # Do an analysis to get variable use information coming into and out of the region.
2064
+ inputs_to_region, def_but_live_out, private_to_region, live_map = (
2065
+ self.find_io_vars(blocks_for_io)
2066
+ )
2067
+ live_out_copy = copy.copy(def_but_live_out)
2068
+
2069
+ if DEBUG_OPENMP >= 1:
2070
+ print("inputs_to_region:", sorted(inputs_to_region))
2071
+ print("def_but_live_out:", sorted(def_but_live_out))
2072
+ print("private_to_region:", sorted(private_to_region))
2073
+ for v in clauses:
2074
+ print("clause after find_io_vars:", v)
2075
+
2076
+ # Remove variables the user explicitly added to a clause from the auto-determined variables.
2077
+ # This will also treat SSA forms of vars the same as their explicit Python var clauses.
2078
+ self.remove_explicit_from_io_vars(
2079
+ inputs_to_region,
2080
+ def_but_live_out,
2081
+ private_to_region,
2082
+ vars_in_explicit_clauses,
2083
+ clauses,
2084
+ non_user_explicits,
2085
+ scope,
2086
+ self.loc,
2087
+ )
2088
+
2089
+ if DEBUG_OPENMP >= 1:
2090
+ for v in clauses:
2091
+ print("clause after remove_explicit_from_io_vars:", v)
2092
+
2093
+ if DEBUG_OPENMP >= 1:
2094
+ for k, v in vars_in_explicit_clauses.items():
2095
+ print("vars_in_explicit before:", k, v)
2096
+ for v in clauses:
2097
+ print("vars_in_explicit clauses before:", v)
2098
+ for k, v in non_user_explicits.items():
2099
+ print("non_user_explicits before:", k, v)
2100
+
2101
+ if DEBUG_OPENMP >= 1:
2102
+ print("inputs_to_region after remove_explicit:", sorted(inputs_to_region))
2103
+ print("def_but_live_out after remove_explicit:", sorted(def_but_live_out))
2104
+ print("private_to_region after remove_explicit:", sorted(private_to_region))
2105
+
2106
+ if not default_shared and (
2107
+ has_user_defined_var(inputs_to_region)
2108
+ or has_user_defined_var(def_but_live_out)
2109
+ or has_user_defined_var(private_to_region)
2110
+ ):
2111
+ user_defined_inputs = get_user_defined_var(inputs_to_region)
2112
+ user_defined_def_live = get_user_defined_var(def_but_live_out)
2113
+ user_defined_private = get_user_defined_var(private_to_region)
2114
+ if DEBUG_OPENMP >= 1:
2115
+ print("inputs users:", sorted(user_defined_inputs))
2116
+ print("def users:", sorted(user_defined_def_live))
2117
+ print("private users:", sorted(user_defined_private))
2118
+ raise UnspecifiedVarInDefaultNone(
2119
+ "Variables with no data env clause in OpenMP region: "
2120
+ + str(
2121
+ user_defined_inputs.union(user_defined_def_live).union(
2122
+ user_defined_private
2123
+ )
2124
+ )
2125
+ )
2126
+
2127
+ if for_target:
2128
+ self.make_implicit_explicit_target(
2129
+ scope,
2130
+ vars_in_explicit_clauses,
2131
+ clauses,
2132
+ True,
2133
+ inputs_to_region,
2134
+ def_but_live_out,
2135
+ private_to_region,
2136
+ )
2137
+ elif for_task:
2138
+ self.make_implicit_explicit(
2139
+ scope,
2140
+ vars_in_explicit_clauses,
2141
+ clauses,
2142
+ True,
2143
+ inputs_to_region,
2144
+ def_but_live_out,
2145
+ private_to_region,
2146
+ for_task=get_enclosing_region(self.func_ir, self.blk_start),
2147
+ )
2148
+ else:
2149
+ self.make_implicit_explicit(
2150
+ scope,
2151
+ vars_in_explicit_clauses,
2152
+ clauses,
2153
+ True,
2154
+ inputs_to_region,
2155
+ def_but_live_out,
2156
+ private_to_region,
2157
+ )
2158
+ if DEBUG_OPENMP >= 1:
2159
+ for k, v in vars_in_explicit_clauses.items():
2160
+ print("vars_in_explicit after:", k, v)
2161
+ for v in clauses:
2162
+ print("vars_in_explicit clauses after:", v)
2163
+ vars_in_explicit_clauses, explicit_privates, non_user_explicits = (
2164
+ self.get_explicit_vars(clauses)
2165
+ )
2166
+ if DEBUG_OPENMP >= 1:
2167
+ print("post get_explicit_vars:", explicit_privates)
2168
+ for k, v in vars_in_explicit_clauses.items():
2169
+ print("vars_in_explicit post:", k, v)
2170
+ if DEBUG_OPENMP >= 1:
2171
+ print("blocks_in_region:", blocks_in_region)
2172
+
2173
+ self.make_consts_unliteral_for_privates(explicit_privates, self.blocks)
2174
+
2175
+ # Returns a dict of private clause variables and their potentially SSA form at the end of the region.
2176
+ clause_privates = self.get_clause_privates(
2177
+ clauses, live_out_copy, scope, self.loc
2178
+ )
2179
+
2180
+ if DEBUG_OPENMP >= 1:
2181
+ print("clause_privates:", sorted(clause_privates), type(clause_privates))
2182
+ print("inputs_to_region:", sorted(inputs_to_region))
2183
+ print("def_but_live_out:", sorted(def_but_live_out))
2184
+ print("live_out_copy:", sorted(live_out_copy))
2185
+ print("private_to_region:", sorted(private_to_region))
2186
+
2187
+ keep_alive = []
2188
+ tags_for_enclosing = self.add_explicits_to_start(
2189
+ scope, vars_in_explicit_clauses, clauses, True, start_tags, keep_alive
2190
+ )
2191
+ add_tags_to_enclosing(self.func_ir, self.blk_start, tags_for_enclosing)
2192
+
2193
+ # or_start = openmp_region_start([openmp_tag("DIR.OMP.TARGET", target_num)] + clauses, 0, self.loc)
2194
+ # or_end = openmp_region_end(or_start, [openmp_tag("DIR.OMP.END.TARGET", target_num)], self.loc)
2195
+ # new_header_block_num = max(self.blocks.keys()) + 1
2196
+
2197
+ firstprivate_dead_after = list(
2198
+ filter(
2199
+ lambda x: x.name == "QUAL.OMP.FIRSTPRIVATE"
2200
+ and x.arg not in live_map[self.blk_end],
2201
+ start_tags,
2202
+ )
2203
+ )
2204
+
2205
+ or_start = openmp_region_start(
2206
+ start_tags, 0, self.loc, firstprivate_dead_after=firstprivate_dead_after
2207
+ )
2208
+ or_end = openmp_region_end(or_start, end_tags, self.loc)
2209
+
2210
+ if DEBUG_OPENMP >= 1:
2211
+ for x in keep_alive:
2212
+ print("keep_alive:", x)
2213
+ for x in firstprivate_dead_after:
2214
+ print("firstprivate_dead_after:", x)
2215
+
2216
+ # Adding the openmp tags in topo order to avoid problems with code
2217
+ # generation and with_lifting legalization.
2218
+ # TODO: we should remove the requirement to process in topo order. There
2219
+ # is state that depends on topo-order processing.
2220
+ if has_loop:
2221
+ new_header_block = ir.Block(scope, self.loc)
2222
+ new_header_block.body = (
2223
+ [or_start] + after_start + for_after_start + [entry_pred.body[-1]]
2224
+ )
2225
+ new_block_num = self.add_block_in_order(new_header_block, self.blk_start)
2226
+ entry_pred.body = (
2227
+ entry_pred.body[:-1]
2228
+ + before_start
2229
+ + for_before_start
2230
+ + [ir.Jump(new_block_num, self.loc)]
2231
+ )
2232
+
2233
+ if for_task:
2234
+ exit_block.body = [or_end] + exit_block.body
2235
+ self.add_to_returns(keep_alive)
2236
+ else:
2237
+ exit_block.body = [or_end] + keep_alive + exit_block.body
2238
+ else:
2239
+ new_header_block = ir.Block(scope, self.loc)
2240
+ new_header_block.body = [or_start] + after_start + sblk.body[:]
2241
+ new_header_block_num = self.add_block_in_order(
2242
+ new_header_block, self.blk_start
2243
+ )
2244
+ sblk.body = before_start + [ir.Jump(new_header_block_num, self.loc)]
2245
+
2246
+ # NOTE: or_start could also be inlined for correct codegen as
2247
+ # follows. Favoring the add_block_in_order method for consistency.
2248
+ # sblk.body = before_start + [or_start] + after_start + sblk.body[:]
2249
+
2250
+ if for_task:
2251
+ eblk.body = [or_end] + eblk.body[:]
2252
+ self.add_to_returns(keep_alive)
2253
+ else:
2254
+ eblk.body = [or_end] + keep_alive + eblk.body[:]
2255
+
2256
+ add_enclosing_region(self.func_ir, self.body_blocks, or_start)
2257
+ return clauses
2258
+
2259
+ def target_clause(self, args):
2260
+ if DEBUG_OPENMP >= 1:
2261
+ print("visit target_clause", args, type(args), args[0])
2262
+ if isinstance(args[0], list):
2263
+ print(args[0][0])
2264
+ (val,) = args
2265
+ if isinstance(val, openmp_tag):
2266
+ return [val]
2267
+ elif isinstance(val, list):
2268
+ return val
2269
+ elif val == "nowait":
2270
+ return openmp_tag("QUAL.OMP.NOWAIT")
2271
+ else:
2272
+ return val
2273
+ # return args[0]
2274
+
2275
+ def target_teams_clause(self, args):
2276
+ if DEBUG_OPENMP >= 1:
2277
+ print("visit target_teams_clause", args, type(args), args[0])
2278
+ if isinstance(args[0], list):
2279
+ print(args[0][0])
2280
+ return args[0]
2281
+
2282
+ def target_teams_distribute_parallel_for_simd_clause(self, args):
2283
+ if DEBUG_OPENMP >= 1:
2284
+ print(
2285
+ "visit target_teams_distribute_parallel_for_simd_clause",
2286
+ args,
2287
+ type(args),
2288
+ args[0],
2289
+ )
2290
+ if isinstance(args[0], list):
2291
+ print(args[0][0])
2292
+ return args[0]
2293
+
2294
+ def teams_distribute_parallel_for_simd_clause(self, args):
2295
+ if DEBUG_OPENMP >= 1:
2296
+ print(
2297
+ "visit teams_distribute_parallel_for_simd_clause",
2298
+ args,
2299
+ type(args),
2300
+ args[0],
2301
+ )
2302
+ if isinstance(args[0], list):
2303
+ print(args[0][0])
2304
+ return args[0]
2305
+
2306
+ def teams_distribute_parallel_for_clause(self, args):
2307
+ if DEBUG_OPENMP >= 1:
2308
+ print(
2309
+ "visit teams_distribute_parallel_for_clause", args, type(args), args[0]
2310
+ )
2311
+ if isinstance(args[0], list):
2312
+ print(args[0][0])
2313
+ return args[0]
2314
+
2315
+ def distribute_clause(self, args):
2316
+ if DEBUG_OPENMP >= 1:
2317
+ print("visit distribute_clause", args, type(args), args[0])
2318
+ if isinstance(args[0], list):
2319
+ print(args[0][0])
2320
+ return args[0]
2321
+
2322
+ def teams_distribute_clause(self, args):
2323
+ if DEBUG_OPENMP >= 1:
2324
+ print("visit teams_distribute_clause", args, type(args), args[0])
2325
+ if isinstance(args[0], list):
2326
+ print(args[0][0])
2327
+ return args[0]
2328
+
2329
+ def teams_distribute_simd_clause(self, args):
2330
+ if DEBUG_OPENMP >= 1:
2331
+ print("visit teams_distribute_simd_clause", args, type(args), args[0])
2332
+ if isinstance(args[0], list):
2333
+ print(args[0][0])
2334
+ return args[0]
2335
+
2336
+ def distribute_parallel_for_clause(self, args):
2337
+ if DEBUG_OPENMP >= 1:
2338
+ print("visit distribute_parallel_for_clause", args, type(args), args[0])
2339
+ if isinstance(args[0], list):
2340
+ print(args[0][0])
2341
+ return args[0]
2342
+
2343
+ def target_teams_distribute_clause(self, args):
2344
+ if DEBUG_OPENMP >= 1:
2345
+ print("visit target_teams_distribute_clause", args, type(args), args[0])
2346
+ if isinstance(args[0], list):
2347
+ print(args[0][0])
2348
+ return args[0]
2349
+
2350
+ def target_teams_distribute_parallel_for_clause(self, args):
2351
+ if DEBUG_OPENMP >= 1:
2352
+ print(
2353
+ "visit target_teams_distribute_parallel_for_clause",
2354
+ args,
2355
+ type(args),
2356
+ args[0],
2357
+ )
2358
+ if isinstance(args[0], list):
2359
+ print(args[0][0])
2360
+ return args[0]
2361
+
2362
+ # Don't need a rule for target_update_construct.
2363
+
2364
+ def target_update_directive(self, args):
2365
+ sblk = self.blocks[self.blk_start]
2366
+
2367
+ if DEBUG_OPENMP >= 1:
2368
+ print("visit target_update_directive", args, type(args))
2369
+ clauses, _ = self.flatten(args[2:], sblk)
2370
+ or_start = openmp_region_start(
2371
+ [openmp_tag("DIR.OMP.TARGET.UPDATE")] + clauses, 0, self.loc
2372
+ )
2373
+ or_end = openmp_region_end(
2374
+ or_start, [openmp_tag("DIR.OMP.END.TARGET.UPDATE")], self.loc
2375
+ )
2376
+ sblk.body = [or_start] + [or_end] + sblk.body[:]
2377
+
2378
+ def target_update_clause(self, args):
2379
+ if DEBUG_OPENMP >= 1:
2380
+ print("visit target_update_clause", args, type(args), args[0])
2381
+ # return args[0]
2382
+ (val,) = args
2383
+ if isinstance(val, openmp_tag):
2384
+ return [val]
2385
+ elif isinstance(val, list):
2386
+ return val
2387
+ else:
2388
+ return val
2389
+
2390
+ def motion_clause(self, args):
2391
+ if DEBUG_OPENMP >= 1:
2392
+ print("visit motion_clause", args, type(args))
2393
+ assert args[0] in ["to", "from"]
2394
+ map_type = args[0].upper()
2395
+ var_list = args[1]
2396
+ assert len(args) == 2
2397
+ ret = []
2398
+ for var in var_list:
2399
+ ret.append(openmp_tag("QUAL.OMP.MAP." + map_type, var))
2400
+ return ret
2401
+
2402
+ def variable_array_section_list(self, args):
2403
+ if DEBUG_OPENMP >= 1:
2404
+ print("visit variable_array_section_list", args, type(args))
2405
+ if len(args) == 1:
2406
+ return args
2407
+ else:
2408
+ args[0].append(args[1])
2409
+ return args[0]
2410
+
2411
+ """
2412
+ def array_section(self, args):
2413
+ if DEBUG_OPENMP >= 1:
2414
+ print("visit array_section", args, type(args))
2415
+ return args
2416
+
2417
+ def array_section_subscript(self, args):
2418
+ if DEBUG_OPENMP >= 1:
2419
+ print("visit array_section_subscript", args, type(args))
2420
+ return args
2421
+ """
2422
+
2423
+ # Don't need a rule for TARGET.
2424
+ # Don't need a rule for single_construct.
2425
+
2426
+ def single_directive(self, args):
2427
+ sblk = self.blocks[self.blk_start]
2428
+ eblk = self.blocks[self.blk_end]
2429
+
2430
+ if DEBUG_OPENMP >= 1:
2431
+ print("visit single_directive", args, type(args))
2432
+ or_start = openmp_region_start([openmp_tag("DIR.OMP.SINGLE")], 0, self.loc)
2433
+ or_start.requires_acquire_release()
2434
+ or_end = openmp_region_end(
2435
+ or_start, [openmp_tag("DIR.OMP.END.SINGLE")], self.loc
2436
+ )
2437
+ sblk.body = [or_start] + sblk.body[:]
2438
+ eblk.body = [or_end] + eblk.body[:]
2439
+
2440
+ def single_clause(self, args):
2441
+ if DEBUG_OPENMP >= 1:
2442
+ print("visit single_clause", args, type(args), args[0])
2443
+ return args[0]
2444
+
2445
+ # Don't need a rule for unique_single_clause.
2446
+ # def NOWAIT(self, args):
2447
+ # return "nowait"
2448
+ # Don't need a rule for NOWAIT.
2449
+ # Don't need a rule for master_construct.
2450
+
2451
+ def master_directive(self, args):
2452
+ raise NotImplementedError("Master directive currently unsupported.")
2453
+
2454
+ # Don't need a rule for simd_construct.
2455
+
2456
+ def simd_directive(self, args):
2457
+ raise NotImplementedError("Simd directive currently unsupported.")
2458
+
2459
+ # Don't need a rule for SIMD.
2460
+
2461
+ def simd_clause(self, args):
2462
+ if DEBUG_OPENMP >= 1:
2463
+ print("visit simd_clause", args, type(args), args[0])
2464
+ return args[0]
2465
+
2466
+ def aligned_clause(self, args):
2467
+ raise NotImplementedError("Aligned clause currently unsupported.")
2468
+ if DEBUG_OPENMP >= 1:
2469
+ print("visit aligned_clause", args, type(args))
2470
+
2471
+ # Don't need a rule for declare_simd_construct.
2472
+
2473
+ def declare_simd_directive_seq(self, args):
2474
+ if DEBUG_OPENMP >= 1:
2475
+ print("visit declare_simd_directive_seq", args, type(args), args[0])
2476
+ return args[0]
2477
+
2478
+ def declare_simd_directive(self, args):
2479
+ raise NotImplementedError("Declare simd directive currently unsupported.")
2480
+
2481
+ def declare_simd_clause(self, args):
2482
+ raise NotImplementedError("Declare simd clauses currently unsupported.")
2483
+ if DEBUG_OPENMP >= 1:
2484
+ print("visit declare_simd_clause", args, type(args))
2485
+
2486
+ # Don't need a rule for ALIGNED.
2487
+
2488
+ def inbranch_clause(self, args):
2489
+ if DEBUG_OPENMP >= 1:
2490
+ print("visit inbranch_clause", args, type(args), args[0])
2491
+ return args[0]
2492
+
2493
+ # Don't need a rule for INBRANCH.
2494
+ # Don't need a rule for NOTINBRANCH.
2495
+
2496
+ def uniform_clause(self, args):
2497
+ raise NotImplementedError("Uniform clause currently unsupported.")
2498
+ if DEBUG_OPENMP >= 1:
2499
+ print("visit uniform_clause", args, type(args))
2500
+
2501
+ # Don't need a rule for UNIFORM.
2502
+
2503
+ def collapse_clause(self, args):
2504
+ if DEBUG_OPENMP >= 1:
2505
+ print("visit collapse_clause", args, type(args))
2506
+ return openmp_tag("QUAL.OMP.COLLAPSE", args[1])
2507
+
2508
+ # Don't need a rule for COLLAPSE.
2509
+ # Don't need a rule for task_construct.
2510
+ # Don't need a rule for TASK.
2511
+
2512
+ def task_directive(self, args):
2513
+ if DEBUG_OPENMP >= 1:
2514
+ print("visit task_directive", args, type(args))
2515
+
2516
+ start_tags = [openmp_tag("DIR.OMP.TASK")]
2517
+ end_tags = [openmp_tag("DIR.OMP.END.TASK")]
2518
+ self.some_data_clause_directive(args, start_tags, end_tags, 1, for_task=True)
2519
+
2520
+ def task_clause(self, args):
2521
+ if DEBUG_OPENMP >= 1:
2522
+ print("visit task_clause", args, type(args), args[0])
2523
+ return args[0]
2524
+
2525
+ def unique_task_clause(self, args):
2526
+ raise NotImplementedError("Task-related clauses currently unsupported.")
2527
+ if DEBUG_OPENMP >= 1:
2528
+ print("visit unique_task_clause", args, type(args))
2529
+
2530
+ # Don't need a rule for DEPEND.
2531
+ # Don't need a rule for FINAL.
2532
+ # Don't need a rule for UNTIED.
2533
+ # Don't need a rule for MERGEABLE.
2534
+
2535
+ def dependence_type(self, args):
2536
+ if DEBUG_OPENMP >= 1:
2537
+ print("visit dependence_type", args, type(args), args[0])
2538
+ return args[0]
2539
+
2540
+ # Don't need a rule for IN.
2541
+ # Don't need a rule for OUT.
2542
+ # Don't need a rule for INOUT.
2543
+
2544
+ def data_default_clause(self, args):
2545
+ if DEBUG_OPENMP >= 1:
2546
+ print("visit data_default_clause", args, type(args), args[0])
2547
+ return args[0]
2548
+
2549
+ def data_sharing_clause(self, args):
2550
+ if DEBUG_OPENMP >= 1:
2551
+ print("visit data_sharing_clause", args, type(args), args[0])
2552
+ return args[0]
2553
+
2554
+ def data_clause(self, args):
2555
+ if DEBUG_OPENMP >= 1:
2556
+ print("visit data_clause", args, type(args), args[0])
2557
+ return args[0]
2558
+
2559
+ def private_clause(self, args):
2560
+ if DEBUG_OPENMP >= 1:
2561
+ print("visit private_clause", args, type(args), args[0])
2562
+ (_, var_list) = args
2563
+ ret = []
2564
+ for var in var_list:
2565
+ ret.append(openmp_tag("QUAL.OMP.PRIVATE", var))
2566
+ return ret
2567
+
2568
+ # Don't need a rule for PRIVATE.
2569
+
2570
+ def copyprivate_clause(self, args):
2571
+ if DEBUG_OPENMP >= 1:
2572
+ print("visit copyprivate_clause", args, type(args), args[0])
2573
+ (_, var_list) = args
2574
+ ret = []
2575
+ for var in var_list:
2576
+ ret.append(openmp_tag("QUAL.OMP.COPYPRIVATE", var))
2577
+ return ret
2578
+
2579
+ # Don't need a rule for COPYPRIVATE.
2580
+
2581
+ def firstprivate_clause(self, args):
2582
+ if DEBUG_OPENMP >= 1:
2583
+ print("visit firstprivate_clause", args, type(args), args[0])
2584
+ (_, var_list) = args
2585
+ ret = []
2586
+ for var in var_list:
2587
+ ret.append(openmp_tag("QUAL.OMP.FIRSTPRIVATE", var))
2588
+ return ret
2589
+
2590
+ # Don't need a rule for FIRSTPRIVATE.
2591
+
2592
+ def lastprivate_clause(self, args):
2593
+ if DEBUG_OPENMP >= 1:
2594
+ print("visit lastprivate_clause", args, type(args), args[0])
2595
+ (_, var_list) = args
2596
+ ret = []
2597
+ for var in var_list:
2598
+ ret.append(openmp_tag("QUAL.OMP.LASTPRIVATE", var))
2599
+ return ret
2600
+
2601
+ # Don't need a rule for LASTPRIVATE.
2602
+
2603
+ def shared_clause(self, args):
2604
+ if DEBUG_OPENMP >= 1:
2605
+ print("visit shared_clause", args, type(args), args[0])
2606
+ (_, var_list) = args
2607
+ ret = []
2608
+ for var in var_list:
2609
+ ret.append(openmp_tag("QUAL.OMP.SHARED", var))
2610
+ return ret
2611
+
2612
+ # Don't need a rule for SHARED.
2613
+
2614
+ def copyin_clause(self, args):
2615
+ if DEBUG_OPENMP >= 1:
2616
+ print("visit copyin_clause", args, type(args), args[0])
2617
+ (_, var_list) = args
2618
+ ret = []
2619
+ for var in var_list:
2620
+ ret.append(openmp_tag("QUAL.OMP.COPYIN", var))
2621
+ return ret
2622
+
2623
+ # Don't need a rule for COPYIN.
2624
+ # Don't need a rule for REDUCTION.
2625
+
2626
+ def reduction_clause(self, args):
2627
+ if DEBUG_OPENMP >= 1:
2628
+ print("visit reduction_clause", args, type(args), args[0])
2629
+
2630
+ (_, red_op, red_list) = args
2631
+ ret = []
2632
+ for shared in red_list:
2633
+ ret.append(openmp_tag("QUAL.OMP.REDUCTION." + red_op, shared))
2634
+ return ret
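+ # For example, "reduction(+:acc)" produces
+ # [openmp_tag("QUAL.OMP.REDUCTION.ADD", "acc")] once reduction_operator below
+ # maps "+" to "ADD".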
2635
+
2636
+ def default_shared_clause(self, args):
2637
+ if DEBUG_OPENMP >= 1:
2638
+ print("visit default_shared_clause", args, type(args))
2639
+ return default_shared_val(True)
2640
+
2641
+ def default_none_clause(self, args):
2642
+ if DEBUG_OPENMP >= 1:
2643
+ print("visit default_none", args, type(args))
2644
+ return default_shared_val(False)
2645
+
2646
+ def const_num_or_var(self, args):
2647
+ if DEBUG_OPENMP >= 1:
2648
+ print("visit const_num_or_var", args, type(args))
2649
+ return args[0]
2650
+
2651
+ # Don't need a rule for parallel_construct.
2652
+
2653
+ def parallel_back_prop(self, clauses):
2654
+ enclosing_regions = get_enclosing_region(self.func_ir, self.blk_start)
2655
+ if DEBUG_OPENMP >= 1:
2656
+ print("parallel enclosing_regions:", enclosing_regions)
2657
+ if not enclosing_regions:
2658
+ return
2659
+
2660
+ for enclosing_region in enclosing_regions[::-1]:
2661
+ # If there is TEAMS in the enclosing region then THREAD_LIMIT is
2662
+ # already set, so do nothing.
2663
+ if self.get_directive_if_contains(enclosing_region.tags, "TEAMS"):
2664
+ return
2665
+ if not self.get_directive_if_contains(enclosing_region.tags, "TARGET"):
2666
+ continue
2667
+
2668
+ # Setting it to 0 means "don't care": use an implementation-specific number of threads.
2669
+ num_threads = 0
2670
+ num_threads_clause = self.get_clauses_by_name(
2671
+ clauses, "QUAL.OMP.NUM_THREADS"
2672
+ )
2673
+ if num_threads_clause:
2674
+ assert len(num_threads_clause) == 1, (
2675
+ "Expected num_threads clause defined once"
2676
+ )
2677
+ num_threads = num_threads_clause[0].arg
2678
+ nt_tag = self.get_clauses_by_name(
2679
+ enclosing_region.tags, "QUAL.OMP.THREAD_LIMIT"
2680
+ )
2681
+ assert len(nt_tag) > 0
2682
+
2683
+ # If THREAD_LIMIT is less than requested NUM_THREADS or 1,
2684
+ # increase it. This is still valid if THREAD_LIMIT is 0, since this
2685
+ # means there was a parallel region before that did not specify
2686
+ # NUM_THREADS, so we can set it to the concrete value of the sibling
2687
+ # parallel region with the max value of NUM_THREADS.
2688
+ if num_threads > nt_tag[-1].arg or nt_tag[-1].arg == 1:
2689
+ nt_tag[-1].arg = num_threads
2690
+ return
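+ # Illustration: for "parallel num_threads(8)" nested directly inside a bare
+ # "target" (no teams), the target's QUAL.OMP.THREAD_LIMIT tag is raised to 8
+ # so the device launch reserves enough threads for the parallel region.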
2691
+
2692
+ def parallel_directive(self, args):
2693
+ if DEBUG_OPENMP >= 1:
2694
+ print("visit parallel_directive", args, type(args))
2695
+
2696
+ start_tags = [openmp_tag("DIR.OMP.PARALLEL")]
2697
+ end_tags = [openmp_tag("DIR.OMP.END.PARALLEL")]
2698
+ clauses = self.some_data_clause_directive(args, start_tags, end_tags, 1)
2699
+
2700
+ if len(list(filter(lambda x: x.name == "QUAL.OMP.NUM_THREADS", clauses))) > 1:
2701
+ raise MultipleNumThreadsClauses(
2702
+ f"Multiple num_threads clauses near line {self.loc} is not allowed in an OpenMP parallel region."
2703
+ )
2704
+
2705
+ if DEBUG_OPENMP >= 1:
2706
+ for clause in clauses:
2707
+ print("final clause:", clause)
2708
+
2709
+ # ---- Back propagate THREAD_LIMIT to the enclosing target region. ----
2710
+ self.parallel_back_prop(clauses)
2711
+
2712
+ def parallel_clause(self, args):
2713
+ (val,) = args
2714
+ if DEBUG_OPENMP >= 1:
2715
+ print("visit parallel_clause", args, type(args), args[0])
2716
+ return val
2717
+
2718
+ def unique_parallel_clause(self, args):
2719
+ (val,) = args
2720
+ if DEBUG_OPENMP >= 1:
2721
+ print("visit unique_parallel_clause", args, type(args), args[0])
2722
+ assert isinstance(val, openmp_tag)
2723
+ return val
2724
+
2725
+ def teams_clause(self, args):
2726
+ (val,) = args
2727
+ if DEBUG_OPENMP >= 1:
2728
+ print("visit teams_clause", args, type(args), args[0])
2729
+ return val
2730
+
2731
+ def num_teams_clause(self, args):
2732
+ (_, num_teams) = args
2733
+ if DEBUG_OPENMP >= 1:
2734
+ print("visit num_teams_clause", args, type(args))
2735
+
2736
+ return openmp_tag("QUAL.OMP.NUM_TEAMS", num_teams, load=True)
2737
+
2738
+ def thread_limit_clause(self, args):
2739
+ (_, thread_limit) = args
2740
+ if DEBUG_OPENMP >= 1:
2741
+ print("visit thread_limit_clause", args, type(args))
2742
+
2743
+ return openmp_tag("QUAL.OMP.THREAD_LIMIT", thread_limit, load=True)
2744
+
2745
+ def if_clause(self, args):
2746
+ (_, if_val) = args
2747
+ if DEBUG_OPENMP >= 1:
2748
+ print("visit if_clause", args, type(args))
2749
+
2750
+ return openmp_tag("QUAL.OMP.IF", if_val, load=True)
2751
+
2752
+ # Don't need a rule for IF.
2753
+
2754
+ def num_threads_clause(self, args):
2755
+ (_, num_threads) = args
2756
+ if DEBUG_OPENMP >= 1:
2757
+ print("visit num_threads_clause", args, type(args))
2758
+
2759
+ return openmp_tag("QUAL.OMP.NUM_THREADS", num_threads, load=True)
2760
+
2761
+ # Don't need a rule for NUM_THREADS.
2762
+ # Don't need a rule for PARALLEL.
2763
+ # Don't need a rule for FOR.
2764
+ # Don't need a rule for parallel_for_construct.
2765
+
2766
+ def parallel_for_directive(self, args):
2767
+ return self.some_for_directive(
2768
+ args, "DIR.OMP.PARALLEL.LOOP", "DIR.OMP.END.PARALLEL.LOOP", 2, True
2769
+ )
2770
+
2771
+ def parallel_for_clause(self, args):
2772
+ if DEBUG_OPENMP >= 1:
2773
+ print("visit parallel_for_clause", args, type(args), args[0])
2774
+ return args[0]
2775
+
2776
+ # Don't need a rule for for_construct.
2777
+
2778
+ def for_directive(self, args):
2779
+ return self.some_for_directive(
2780
+ args, "DIR.OMP.LOOP", "DIR.OMP.END.LOOP", 1, False
2781
+ )
2782
+
2783
+ def for_clause(self, args):
2784
+ (val,) = args
2785
+ if DEBUG_OPENMP >= 1:
2786
+ print("visit for_clause", args, type(args))
2787
+ if isinstance(val, openmp_tag):
2788
+ return [val]
2789
+ elif isinstance(val, list):
2790
+ return val
2791
+ elif val == "nowait":
2792
+ return openmp_tag("QUAL.OMP.NOWAIT")
2793
+
2794
+ def unique_for_clause(self, args):
2795
+ (val,) = args
2796
+ if DEBUG_OPENMP >= 1:
2797
+ print("visit unique_for_clause", args, type(args))
2798
+ if isinstance(val, openmp_tag):
2799
+ return val
2800
+ elif val == "ordered":
2801
+ return openmp_tag("QUAL.OMP.ORDERED", 0)
2802
+
2803
+ # Don't need a rule for LINEAR.
2804
+
2805
+ def linear_clause(self, args):
2806
+ if DEBUG_OPENMP >= 1:
2807
+ print("visit linear_clause", args, type(args), args[0])
2808
+ return args[0]
2809
+
2810
+ """
2811
+ Linear_expr not in grammar
2812
+ def linear_expr(self, args):
2813
+ (_, var, step) = args
2814
+ if DEBUG_OPENMP >= 1:
2815
+ print("visit linear_expr", args, type(args))
2816
+ return openmp_tag("QUAL.OMP.LINEAR", [var, step])
2817
+ """
2818
+
2819
+ """
2820
+ def ORDERED(self, args):
2821
+ if DEBUG_OPENMP >= 1:
2822
+ print("visit ordered", args, type(args))
2823
+ return "ordered"
2824
+ """
2825
+
2826
+ def sched_no_expr(self, args):
2827
+ (_, kind) = args
2828
+ if DEBUG_OPENMP >= 1:
2829
+ print("visit sched_no_expr", args, type(args))
2830
+ if kind == "static":
2831
+ return openmp_tag("QUAL.OMP.SCHEDULE.STATIC", 0)
2832
+ elif kind == "dynamic":
2833
+ return openmp_tag("QUAL.OMP.SCHEDULE.DYNAMIC", 0)
2834
+ elif kind == "guided":
2835
+ return openmp_tag("QUAL.OMP.SCHEDULE.GUIDED", 0)
2836
+ elif kind == "runtime":
2837
+ return openmp_tag("QUAL.OMP.SCHEDULE.RUNTIME", 0)
2838
+
2839
+ def sched_expr(self, args):
2840
+ (_, kind, num_or_var) = args
2841
+ if DEBUG_OPENMP >= 1:
2842
+ print("visit sched_expr", args, type(args), num_or_var, type(num_or_var))
2843
+ if kind == "static":
2844
+ return openmp_tag("QUAL.OMP.SCHEDULE.STATIC", num_or_var, load=True)
2845
+ elif kind == "dynamic":
2846
+ return openmp_tag("QUAL.OMP.SCHEDULE.DYNAMIC", num_or_var, load=True)
2847
+ elif kind == "guided":
2848
+ return openmp_tag("QUAL.OMP.SCHEDULE.GUIDED", num_or_var, load=True)
2849
+ elif kind == "runtime":
2850
+ return openmp_tag("QUAL.OMP.SCHEDULE.RUNTIME", num_or_var, load=True)
2851
+
2852
+ def SCHEDULE(self, args):
2853
+ if DEBUG_OPENMP >= 1:
2854
+ print("visit SCHEDULE", args, type(args))
2855
+ return "schedule"
2856
+
2857
+ def schedule_kind(self, args):
2858
+ (kind,) = args
2859
+ if DEBUG_OPENMP >= 1:
2860
+ print("visit schedule_kind", args, type(args))
2861
+ return kind
2862
+
2863
+ # Don't need a rule for STATIC.
2864
+ # Don't need a rule for DYNAMIC.
2865
+ # Don't need a rule for GUIDED.
2866
+ # Don't need a rule for RUNTIME.
2867
+
2868
+ """
2869
+ def STATIC(self, args):
2870
+ if DEBUG_OPENMP >= 1:
2871
+ print("visit STATIC", args, type(args))
2872
+ return "static"
2873
+
2874
+ def DYNAMIC(self, args):
2875
+ if DEBUG_OPENMP >= 1:
2876
+ print("visit DYNAMIC", args, type(args))
2877
+ return "dynamic"
2878
+
2879
+ def GUIDED(self, args):
2880
+ if DEBUG_OPENMP >= 1:
2881
+ print("visit GUIDED", args, type(args))
2882
+ return "guided"
2883
+
2884
+ def RUNTIME(self, args):
2885
+ if DEBUG_OPENMP >= 1:
2886
+ print("visit RUNTIME", args, type(args))
2887
+ return "runtime"
2888
+ """
2889
+
2890
+ def COLON(self, args):
2891
+ if DEBUG_OPENMP >= 1:
2892
+ print("visit COLON", args, type(args))
2893
+ return ":"
2894
+
2895
+ def oslice(self, args):
2896
+ if DEBUG_OPENMP >= 1:
2897
+ print("visit oslice", args, type(args))
2898
+ start = None
2899
+ end = None
2900
+ if args[0] != ":":
2901
+ start = args[0]
2902
+ args = args[2:]
2903
+ else:
2904
+ args = args[1:]
2905
+
2906
+ if len(args) > 0:
2907
+ end = args[0]
2908
+ return slice(start, end)
2909
+
2910
+ def slice_list(self, args):
2911
+ if DEBUG_OPENMP >= 1:
2912
+ print("visit slice_list", args, type(args))
2913
+ if len(args) == 1:
2914
+ return args
2915
+ else:
2916
+ args[0].append(args[1])
2917
+ return args[0]
2918
+
2919
+ def name_slice(self, args):
2920
+ if DEBUG_OPENMP >= 1:
2921
+ print("visit name_slice", args, type(args))
2922
+ if len(args) == 1 or args[1] is None:
2923
+ return args[0]
2924
+ else:
2925
+ return NameSlice(args[0], args[1:])
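+ # Rough illustration: an array-section argument such as "a[0:n]" arrives here
+ # with slices built by oslice/slice_list and returns a NameSlice wrapping "a",
+ # while a bare "a" falls through as the plain name.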
2926
+
2927
+ def var_list(self, args):
2928
+ if DEBUG_OPENMP >= 1:
2929
+ print("visit var_list", args, type(args))
2930
+ if len(args) == 1:
2931
+ return args
2932
+ else:
2933
+ args[0].append(args[1])
2934
+ return args[0]
2935
+
2936
+ def number_list(self, args):
2937
+ if DEBUG_OPENMP >= 1:
2938
+ print("visit number_list", args, type(args))
2939
+ if len(args) == 1:
2940
+ return args
2941
+ else:
2942
+ args[0].append(args[1])
2943
+ return args[0]
2944
+
2945
+ def ompx_attribute(self, args):
2946
+ if DEBUG_OPENMP >= 1:
2947
+ print("visit ompx_attribute", args, type(args), args[0])
2948
+ (_, attr, number_list) = args
2949
+ return openmp_tag("QUAL.OMP.OMPX_ATTRIBUTE", (attr, number_list))
2950
+
2951
+ def PLUS(self, args):
2952
+ if DEBUG_OPENMP >= 1:
2953
+ print("visit PLUS", args, type(args))
2954
+ return "+"
2955
+
2956
+ def MINUS(self, args):
2957
+ if DEBUG_OPENMP >= 1:
2958
+ print("visit MINUS", args, type(args))
2959
+ return "-"
2960
+
2961
+ def STAR(self, args):
2962
+ if DEBUG_OPENMP >= 1:
2963
+ print("visit STAR", args, type(args))
2964
+ return "*"
2965
+
2966
+ def reduction_operator(self, args):
2967
+ arg = args[0]
2968
+ if DEBUG_OPENMP >= 1:
2969
+ print("visit reduction_operator", args, type(args), arg, type(arg))
2970
+ if arg == "+":
2971
+ return "ADD"
2972
+ elif arg == "-":
2973
+ return "SUB"
2974
+ elif arg == "*":
2975
+ return "MUL"
2976
+ assert 0
2977
+
2978
+ def threadprivate_directive(self, args):
2979
+ raise NotImplementedError("Threadprivate currently unsupported.")
2980
+
2981
+ def cancellation_point_directive(self, args):
2982
+ raise NotImplementedError("""Explicit cancellation points
2983
+ currently unsupported.""")
2984
+
2985
+ def construct_type_clause(self, args):
2986
+ if DEBUG_OPENMP >= 1:
2987
+ print("visit construct_type_clause", args, type(args), args[0])
2988
+ return args[0]
2989
+
2990
+ def cancel_directive(self, args):
2991
+ raise NotImplementedError("Cancel directive currently unsupported.")
2992
+
2993
+ # Don't need a rule for ORDERED.
2994
+
2995
+ def flush_directive(self, args):
2996
+ raise NotImplementedError("Flush directive currently unsupported.")
2997
+
2998
+ def region_phrase(self, args):
2999
+ raise NotImplementedError("No implementation for region phrase.")
3000
+
3001
+ def PYTHON_NAME(self, args):
3002
+ if DEBUG_OPENMP >= 1:
3003
+ print("visit PYTHON_NAME", args, type(args), str(args))
3004
+ return str(args)
3005
+
3006
+ def NUMBER(self, args):
3007
+ if DEBUG_OPENMP >= 1:
3008
+ print("visit NUMBER", args, type(args), str(args))
3009
+ return int(args)
3010
+
3011
+
3012
+ # This Transformer visitor class just finds the referenced Python names
3013
+ # and puts them in a list of VarName. The default visitor function
3014
+ # looks for lists of VarName in the args to that tree node and then
3015
+ # concatenates them all together. The final return value is a list of
3016
+ # VarName objects that are variables used in the OpenMP clauses.
3017
+
3018
+
3019
+ class VarName(str):
3020
+ pass
3021
+
3022
+
3023
+ class OnlyClauseVar(VarName):
3024
+ pass
3025
+
3026
+
3027
+ class VarCollector(Transformer):
3028
+ def __init__(self):
3029
+ super(VarCollector, self).__init__()
3030
+
3031
+ def PYTHON_NAME(self, args):
3032
+ return [VarName(args)]
3033
+
3034
+ def const_num_or_var(self, args):
3035
+ return args[0]
3036
+
3037
+ def num_threads_clause(self, args):
3038
+ (_, num_threads) = args
3039
+ if isinstance(num_threads, list):
3040
+ assert len(num_threads) == 1
3041
+ return [OnlyClauseVar(num_threads[0])]
3042
+ else:
3043
+ return None
3044
+
3045
+ def __default__(self, data, children, meta):
3046
+ ret = []
3047
+ for c in children:
3048
+ if isinstance(c, list) and len(c) > 0:
3049
+ if isinstance(c[0], OnlyClauseVar):
3050
+ ret.extend(c)
3051
+ return ret
3052
+
3053
+
3054
+ def replace_ssa_var_callback(var, vardict):
3055
+ assert isinstance(var, ir.Var)
3056
+ while var.unversioned_name in vardict.keys():
3057
+ assert vardict[var.unversioned_name].name != var.unversioned_name
3058
+ new_var = vardict[var.unversioned_name]
3059
+ var = ir.Var(new_var.scope, new_var.name, new_var.loc)
3060
+ return var
3061
+
3062
+
3063
+ def replace_ssa_vars(blocks, vardict):
3064
+ """replace variables (ir.Var to ir.Var) from dictionary (name -> ir.Var)"""
3065
+ # remove identity values to avoid infinite loop
3066
+ new_vardict = {}
3067
+ for n, r in vardict.items():
3068
+ if n != r.name:
3069
+ new_vardict[n] = r
3070
+ visit_vars(blocks, replace_ssa_var_callback, new_vardict)
3071
+
3072
+
3073
+ def remove_ssa_callback(var, unused):
3074
+ assert isinstance(var, ir.Var)
3075
+ new_var = ir.Var(var.scope, var.unversioned_name, var.loc)
3076
+ return new_var
3077
+
3078
+
3079
+ def remove_ssa_from_func_ir(func_ir):
3080
+ typed_passes.PreLowerStripPhis()._strip_phi_nodes(func_ir)
3081
+ visit_vars(func_ir.blocks, remove_ssa_callback, None)
3082
+ func_ir._definitions = build_definitions(func_ir.blocks)
3083
+
3084
+
3085
+ def _add_openmp_ir_nodes(func_ir, blocks, blk_start, blk_end, body_blocks, extra):
3086
+ """Given the starting and ending block of the with-context,
3087
+ replaces the head block with a new block that has the starting
3088
+ openmp ir nodes in it and adds the ending openmp ir nodes to
3089
+ the end block.
3090
+ """
3091
+ sblk = blocks[blk_start]
3092
+ loc = sblk.loc
3093
+ sblk.body = sblk.body[1:]
3094
+
3095
+ args = extra["args"]
3096
+ arg = args[0]
3097
+ # If the OpenMP argument is not a constant or not a string, raise an exception.
3098
+ if not isinstance(arg, (ir.Const, ir.FreeVar)):
3099
+ raise NonconstantOpenmpSpecification(
3100
+ f"Non-constant OpenMP specification at line {arg.loc}"
3101
+ )
3102
+ if not isinstance(arg.value, str):
3103
+ raise NonStringOpenmpSpecification(
3104
+ f"Non-string OpenMP specification at line {arg.loc}"
3105
+ )
3106
+
3107
+ if DEBUG_OPENMP >= 1:
3108
+ print("args:", args, type(args))
3109
+ print("arg:", arg, type(arg), arg.value, type(arg.value))
3110
+ parse_res = openmp_parser.parse(arg.value)
3111
+ if DEBUG_OPENMP >= 1:
3112
+ print(parse_res.pretty())
3113
+ visitor = OpenmpVisitor(func_ir, blocks, blk_start, blk_end, body_blocks, loc)
3114
+ try:
3115
+ visitor.transform(parse_res)
3116
+ except VisitError as e:
3117
+ raise e.__context__
3118
+ except Exception:
3119
+ print("generic transform exception")
3120
+ exc_type, exc_obj, exc_tb = sys.exc_info()
3121
+ fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
3122
+ print(exc_type, fname, exc_tb.tb_lineno)
3123
+ print("Internal error for OpenMp pragma '{}'".format(arg.value))
3124
+ sys.exit(-2)
3125
+ assert blocks is visitor.blocks
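+ # A minimal usage sketch of the pragma strings this transformer lowers. The
+ # njit/openmp_context entry points live outside this file and are assumed here:
+ #
+ #     from numba.openmp import njit, openmp_context as openmp
+ #
+ #     @njit
+ #     def axpy(a, x, y):
+ #         with openmp("parallel for"):
+ #             for i in range(len(x)):
+ #                 y[i] = a * x[i] + y[i]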