pyomp 0.5.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. numba/openmp/__init__.py +106 -0
  2. numba/openmp/_version.py +34 -0
  3. numba/openmp/analysis.py +251 -0
  4. numba/openmp/compiler.py +402 -0
  5. numba/openmp/config.py +27 -0
  6. numba/openmp/decorators.py +27 -0
  7. numba/openmp/exceptions.py +26 -0
  8. numba/openmp/ir_utils.py +4 -0
  9. numba/openmp/libs/openmp/lib/libgomp.1.dylib +0 -0
  10. numba/openmp/libs/openmp/lib/libgomp.dylib +0 -0
  11. numba/openmp/libs/openmp/lib/libiomp5.dylib +0 -0
  12. numba/openmp/libs/openmp/lib/libomp.dylib +0 -0
  13. numba/openmp/libs/openmp/patches/14.0.6/0001-BACKPORT-Fix-for-CUDA-OpenMP-RTL.patch +39 -0
  14. numba/openmp/libs/openmp/patches/14.0.6/0002-Fix-missing-includes.patch +12 -0
  15. numba/openmp/libs/openmp/patches/14.0.6/0003-Link-static-LLVM-libs.patch +13 -0
  16. numba/openmp/libs/openmp/patches/15.0.7/0001-Fix-missing-includes.patch +14 -0
  17. numba/openmp/libs/openmp/patches/15.0.7/0002-Link-LLVM-statically.patch +101 -0
  18. numba/openmp/libs/openmp/patches/15.0.7/0003-Disable-opaque-pointers-DeviceRTL-bitcode.patch +12 -0
  19. numba/openmp/libs/openmp/patches/16.0.6/0001-Load-plugins-from-install-directory.patch +53 -0
  20. numba/openmp/libs/openmp/patches/16.0.6/0002-Link-LLVM-statically.patch +218 -0
  21. numba/openmp/libs/openmp/patches/20.1.8/0001-Enable-standalone-build.patch +13 -0
  22. numba/openmp/libs/openmp/patches/20.1.8/0002-Link-statically-LLVM.patch +24 -0
  23. numba/openmp/libs/openmp/patches/20.1.8/0003-Do-not-build-liboffload.patch +12 -0
  24. numba/openmp/libs/pass/CGIntrinsicsOpenMP.cpp +2939 -0
  25. numba/openmp/libs/pass/CGIntrinsicsOpenMP.h +606 -0
  26. numba/openmp/libs/pass/CMakeLists.txt +57 -0
  27. numba/openmp/libs/pass/DebugOpenMP.cpp +17 -0
  28. numba/openmp/libs/pass/DebugOpenMP.h +28 -0
  29. numba/openmp/libs/pass/IntrinsicsOpenMP.cpp +837 -0
  30. numba/openmp/libs/pass/IntrinsicsOpenMP.h +13 -0
  31. numba/openmp/libs/pass/IntrinsicsOpenMP_CAPI.h +23 -0
  32. numba/openmp/libs/pass/libIntrinsicsOpenMP.dylib +0 -0
  33. numba/openmp/link_utils.py +126 -0
  34. numba/openmp/llvm_pass.py +48 -0
  35. numba/openmp/llvmlite_extensions.py +75 -0
  36. numba/openmp/omp_context.py +242 -0
  37. numba/openmp/omp_grammar.py +696 -0
  38. numba/openmp/omp_ir.py +2105 -0
  39. numba/openmp/omp_lower.py +3125 -0
  40. numba/openmp/omp_runtime.py +107 -0
  41. numba/openmp/overloads.py +53 -0
  42. numba/openmp/parser.py +6 -0
  43. numba/openmp/tags.py +532 -0
  44. numba/openmp/tests/test_openmp.py +5056 -0
  45. pyomp-0.5.0.dist-info/METADATA +193 -0
  46. pyomp-0.5.0.dist-info/RECORD +52 -0
  47. pyomp-0.5.0.dist-info/WHEEL +6 -0
  48. pyomp-0.5.0.dist-info/licenses/LICENSE +25 -0
  49. pyomp-0.5.0.dist-info/licenses/LICENSE-OPENMP.txt +361 -0
  50. pyomp-0.5.0.dist-info/top_level.txt +3 -0
  51. pyomp.dylibs/libc++.1.0.dylib +0 -0
  52. pyomp.dylibs/libzstd.1.5.7.dylib +0 -0
numba/openmp/omp_ir.py ADDED
@@ -0,0 +1,2105 @@
+ from numba.core import (
+     ir,
+     types,
+     cgutils,
+     typing,
+     transforms,
+     bytecode,
+     compiler,
+     typeinfer,
+ )
+ from numba.core.ir_utils import (
+     dprint_func_ir,
+     find_topo_order,
+     mk_unique_var,
+     apply_copy_propagate_extensions,
+     visit_vars_extensions,
+     visit_vars_inner,
+ )
+ from numba import cuda as numba_cuda
+ from numba.cuda import descriptor as cuda_descriptor, compiler as cuda_compiler
+ from numba.core.types.functions import Dispatcher
+ from numba.core.analysis import ir_extension_usedefs, _use_defs_result
+ import numba
+ import llvmlite.ir as lir
+ import llvmlite.binding as ll
+ import sys
+ import os
+ import copy
+ import tempfile
+ import subprocess
+ import operator
+ import numpy as np
+ from pathlib import Path
+ import types as python_types
+
+ from .analysis import (
+     is_dsa,
+     typemap_lookup,
+     is_target_tag,
+     is_target_arg,
+     in_openmp_region,
+     get_blocks_between_start_end,
+     get_name_var_table,
+     is_pointer_target_arg,
+ )
+ from .tags import (
+     openmp_tag_list_to_str,
+     list_vars_from_tags,
+     get_tags_of_type,
+     StringLiteral,
+     openmp_tag,
+     NameSlice,
+ )
+ from .llvmlite_extensions import TokenType, CallInstrWithOperandBundle
+ from .config import (
+     libpath,
+     DEBUG_OPENMP,
+     DEBUG_OPENMP_LLVM_PASS,
+     OPENMP_DEVICE_TOOLCHAIN,
+ )
+ from .link_utils import link_shared_library
+ from .llvm_pass import run_intrinsics_openmp_pass
+ from .compiler import (
+     OnlyLower,
+     OnlyLowerCUDA,
+     OpenmpCPUTargetContext,
+     OpenmpCUDATargetContext,
+     CustomAOTCPUCodeLibrary,
+     CustomCPUCodeLibrary,
+     CustomContext,
+ )
+
+ unique = 0
+
+
+ def get_unique():
+     global unique
+     ret = unique
+     unique += 1
+     return ret
+
+
+ def openmp_region_alloca(obj, alloca_instr, typ):
+     obj.alloca(alloca_instr, typ)
+
+
+ def push_alloca_callback(lowerer, callback, data, builder):
+     # cgutils.push_alloca_callbacks(callback, data)
+     if not hasattr(builder, "_lowerer_push_alloca_callbacks"):
+         builder._lowerer_push_alloca_callbacks = 0
+     builder._lowerer_push_alloca_callbacks += 1
+
+
+ def pop_alloca_callback(lowerer, builder):
+     # cgutils.pop_alloca_callbacks()
+     builder._lowerer_push_alloca_callbacks -= 1
+
+
+ def get_dotted_type(x, typemap, lowerer):
+     xsplit = x.split("*")
+     cur_typ = typemap_lookup(typemap, xsplit[0])
+     # print("xsplit:", xsplit, cur_typ, type(cur_typ))
+     for field in xsplit[1:]:
+         dm = lowerer.context.data_model_manager.lookup(cur_typ)
+         findex = dm._fields.index(field)
+         cur_typ = dm._members[findex]
+         # print("dm:", dm, type(dm), dm._members, type(dm._members), dm._fields, type(dm._fields), findex, cur_typ, type(cur_typ))
+     return cur_typ
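+ # Illustrative note (names here are hypothetical): get_dotted_type resolves
+ # dotted member paths such as "arr*data".  With typemap mapping "arr" to an
+ # npytypes.Array, each "*" segment is looked up in the current type's data
+ # model, so the call returns the type of the "data" member, i.e. the
+ # underlying data pointer type.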
+
+
+ class OpenMPCUDACodegen:
+     def __init__(self):
+         import numba.cuda.api as cudaapi
+         import numba.cuda.cudadrv.libs as cudalibs
+         from numba.cuda.codegen import CUDA_TRIPLE
+         from numba.cuda.cudadrv import driver, enums
+
+         # The OpenMP target runtime prefers the blocking sync flag, so we set it
+         # here before creating the CUDA context.
+         driver.driver.cuDevicePrimaryCtxSetFlags(0, enums.CU_CTX_SCHED_BLOCKING_SYNC)
+         self.cc = cudaapi.get_current_device().compute_capability
+         self.sm = "sm_" + str(self.cc[0]) + str(self.cc[1])
+
+         # Read the libdevice bitcode for the architecture to link with the module.
+         self.libdevice_path = cudalibs.get_libdevice()
+         with open(self.libdevice_path, "rb") as f:
+             self.libdevice_mod = ll.parse_bitcode(f.read())
+
+         # Read the OpenMP device RTL for the architecture to link with the module.
+         self.libomptarget_arch = libpath / "openmp" / "lib" / "libomptarget-nvptx.bc"
+         try:
+             with open(self.libomptarget_arch, "rb") as f:
+                 self.libomptarget_mod = ll.parse_bitcode(f.read())
+         except FileNotFoundError:
+             raise RuntimeError(
+                 f"Device RTL for architecture {self.sm} not found. Check compute capability with LLVM version {'.'.join(map(str, ll.llvm_version_info))}."
+             )
+
+         # Initialize asm printers to codegen ptx.
+         ll.initialize_all_targets()
+         ll.initialize_all_asmprinters()
+         target = ll.Target.from_triple(CUDA_TRIPLE)
+         # We pick opt=2 as a reasonable optimization level for codegen.
+         self.tm = target.create_target_machine(cpu=self.sm, opt=2)
+
+     def _get_target_image(self, mod, filename_prefix, ompx_attrs, use_toolchain=False):
+         from numba.cuda.cudadrv import driver
+         from numba.core.llvm_bindings import create_pass_builder
+
+         if DEBUG_OPENMP_LLVM_PASS >= 1:
+             with open(filename_prefix + ".ll", "w") as f:
+                 f.write(str(mod))
+
+         # Lower openmp intrinsics.
+         mod = run_intrinsics_openmp_pass(mod)
+         if DEBUG_OPENMP_LLVM_PASS >= 1:
+             with open(filename_prefix + "-intr.ll", "w") as f:
+                 f.write(str(mod))
+
+         def _internalize():
+             # Internalize non-kernel function definitions.
+             for func in mod.functions:
+                 if func.is_declaration:
+                     continue
+                 if func.linkage != ll.Linkage.external:
+                     continue
+                 if "__omp_offload_numba" in func.name:
+                     continue
+                 func.linkage = "internal"
+
+         # Link libdevice first and optimize aggressively with opt=2 as a
+         # reasonable default.
+         mod.link_in(self.libdevice_mod, preserve=True)
+         # Internalize non-kernel function definitions.
+         _internalize()
+         # Run passes for optimization, including target-specific passes.
+         # Run function passes.
+         with create_pass_builder(
+             self.tm, opt=2, slp_vectorize=True, loop_vectorize=True
+         ) as pb:
+             pm = pb.getFunctionPassManager()
+             for func in mod.functions:
+                 pm.run(func, pb)
+
+         # Run module passes.
+         with create_pass_builder(
+             self.tm, opt=2, slp_vectorize=True, loop_vectorize=True
+         ) as pb:
+             pm = pb.getModulePassManager()
+             pm.run(mod, pb)
+
+         if DEBUG_OPENMP_LLVM_PASS >= 1:
+             mod.verify()
+             with open(filename_prefix + "-intr-dev.ll", "w") as f:
+                 f.write(str(mod))
+
+         # Link in the OpenMP device RTL and optimize lightly with opt=1, since
+         # aggressive optimization can break OpenMP execution synchronization
+         # for target regions.
+         mod.link_in(self.libomptarget_mod, preserve=True)
+         # Internalize non-kernel function definitions.
+         _internalize()
+         # Run module passes.
+         with create_pass_builder(
+             self.tm, opt=1, slp_vectorize=True, loop_vectorize=True
+         ) as pb:
+             pm = pb.getModulePassManager()
+             pm.run(mod, pb)
+
+         if DEBUG_OPENMP_LLVM_PASS >= 1:
+             mod.verify()
+             with open(filename_prefix + "-intr-dev-rtl.ll", "w") as f:
+                 f.write(str(mod))
+
+         # Generate PTX assembly.
+         ptx = self.tm.emit_assembly(mod)
+         if use_toolchain:
+             # ptxas normally does file I/O; prefer piping PTX to stdin to avoid
+             # writing the .s file unless debug is enabled.
+             if DEBUG_OPENMP_LLVM_PASS >= 1:
+                 with open(filename_prefix + "-intr-dev-rtl.s", "w") as f:
+                     f.write(ptx)
+
+             # Invoke ptxas reading PTX from stdin ('-') and writing output to
+             # a temporary file so we can capture the object in-memory without
+             # leaving it in the working directory.
+             with tempfile.NamedTemporaryFile(suffix=".o", delete=False) as tmpf:
+                 outname = tmpf.name
+             try:
+                 subprocess.run(
+                     [
+                         "ptxas",
+                         "-m64",
+                         "--gpu-name",
+                         self.sm,
+                         "-",
+                         "-o",
+                         outname,
+                     ],
+                     input=ptx.encode(),
+                     check=True,
+                 )
+
+                 with open(outname, "rb") as f:
+                     cubin = f.read()
+
+                 # If debug is enabled, also write a named copy for inspection.
+                 if DEBUG_OPENMP_LLVM_PASS >= 1:
+                     with open(
+                         filename_prefix + "-intr-dev-rtl.o",
+                         "wb",
+                     ) as f:
+                         f.write(cubin)
+             finally:
+                 try:
+                     os.remove(outname)
+                 except OSError:
+                     pass
+         else:
+             if DEBUG_OPENMP_LLVM_PASS >= 1:
+                 with open(
+                     filename_prefix + "-intr-dev-rtl.s",
+                     "w",
+                 ) as f:
+                     f.write(ptx)
+
+             linker_kwargs = {}
+             for x in ompx_attrs:
+                 linker_kwargs[x.arg[0]] = (
+                     tuple(x.arg[1]) if len(x.arg[1]) > 1 else x.arg[1][0]
+                 )
+             # NOTE: DO NOT set cc, since the linker will always
+             # compile for the existing GPU context and it is
+             # incompatible with the launch_bounds ompx_attribute.
+             linker = driver.Linker.new(**linker_kwargs)
+             linker.add_ptx(ptx.encode())
+             cubin = linker.complete()
+
+             if DEBUG_OPENMP_LLVM_PASS >= 1:
+                 with open(
+                     filename_prefix + "-intr-dev-rtl.o",
+                     "wb",
+                 ) as f:
+                     f.write(cubin)
+
+         return cubin
+
+     def get_target_image(self, cres, ompx_attrs):
+         filename_prefix = cres.library.name
+         allmods = cres.library.modules
+         linked_mod = ll.parse_assembly(str(allmods[0]))
+         for mod in allmods[1:]:
+             linked_mod.link_in(ll.parse_assembly(str(mod)))
+         if OPENMP_DEVICE_TOOLCHAIN >= 1:
+             return self._get_target_image(
+                 linked_mod, filename_prefix, ompx_attrs, use_toolchain=True
+             )
+         else:
+             return self._get_target_image(linked_mod, filename_prefix, ompx_attrs)
+
+
+ _omp_cuda_codegen = None
+
+
+ # Accessor for the singleton OpenMPCUDACodegen instance. Initializes the
+ # instance on first use to ensure a single CUDA context and codegen setup
+ # per process.
+ def get_omp_cuda_codegen():
+     global _omp_cuda_codegen
+     if _omp_cuda_codegen is None:
+         _omp_cuda_codegen = OpenMPCUDACodegen()
+     return _omp_cuda_codegen
+
+
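+ # Illustrative usage (hypothetical caller): code that needs a device image
+ # goes through the accessor rather than constructing OpenMPCUDACodegen
+ # directly, e.g.
+ #     cubin = get_omp_cuda_codegen().get_target_image(cres, ompx_attrs)
+ # so the CUDA context flags and target machine are configured only once per
+ # process.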
+ def copy_one(x, calltypes):
+     if DEBUG_OPENMP >= 2:
+         print("copy_one:", x, type(x))
+     if isinstance(x, ir.Loc):
+         return copy.copy(x)
+     elif isinstance(x, ir.Expr):
+         if x in calltypes:
+             ctyp = calltypes[x]
+         else:
+             ctyp = None
+         ret = ir.Expr(
+             copy_one(x.op, calltypes),
+             copy_one(x.loc, calltypes),
+             **copy_one(x._kws, calltypes),
+         )
+         if ctyp and ret not in calltypes:
+             calltypes[ret] = ctyp
+         return ret
+     elif isinstance(x, dict):
+         return {k: copy_one(v, calltypes) for k, v in x.items()}
+     elif isinstance(x, list):
+         return [copy_one(v, calltypes) for v in x]
+     elif isinstance(x, tuple):
+         return tuple([copy_one(v, calltypes) for v in x])
+     elif isinstance(x, ir.Const):
+         return ir.Const(
+             copy_one(x.value, calltypes), copy_one(x.loc, calltypes), x.use_literal_type
+         )
+     elif isinstance(
+         x,
+         (
+             int,
+             float,
+             str,
+             ir.Global,
+             python_types.BuiltinFunctionType,
+             ir.UndefinedType,
+             type(None),
+             types.functions.ExternalFunction,
+         ),
+     ):
+         return x
+     elif isinstance(x, ir.Var):
+         return ir.Var(x.scope, copy_one(x.name, calltypes), copy_one(x.loc, calltypes))
+     elif isinstance(x, ir.Del):
+         return ir.Del(copy_one(x.value, calltypes), copy_one(x.loc, calltypes))
+     elif isinstance(x, ir.Jump):
+         return ir.Jump(copy_one(x.target, calltypes), copy_one(x.loc, calltypes))
+     elif isinstance(x, ir.Return):
+         return ir.Return(copy_one(x.value, calltypes), copy_one(x.loc, calltypes))
+     elif isinstance(x, ir.Branch):
+         return ir.Branch(
+             copy_one(x.cond, calltypes),
+             copy_one(x.truebr, calltypes),
+             copy_one(x.falsebr, calltypes),
+             copy_one(x.loc, calltypes),
+         )
+     elif isinstance(x, ir.Print):
+         ctyp = calltypes[x]
+         ret = copy.copy(x)
+         calltypes[ret] = ctyp
+         return ret
+     elif isinstance(x, ir.Assign):
+         return ir.Assign(
+             copy_one(x.value, calltypes),
+             copy_one(x.target, calltypes),
+             copy_one(x.loc, calltypes),
+         )
+     elif isinstance(x, ir.Arg):
+         return ir.Arg(
+             copy_one(x.name, calltypes),
+             copy_one(x.index, calltypes),
+             copy_one(x.loc, calltypes),
+         )
+     elif isinstance(x, ir.SetItem):
+         ctyp = calltypes[x]
+         ret = ir.SetItem(
+             copy_one(x.target, calltypes),
+             copy_one(x.index, calltypes),
+             copy_one(x.value, calltypes),
+             copy_one(x.loc, calltypes),
+         )
+         calltypes[ret] = ctyp
+         return ret
+     elif isinstance(x, ir.StaticSetItem):
+         ctyp = calltypes[x]
+         ret = ir.StaticSetItem(
+             copy_one(x.target, calltypes),
+             copy_one(x.index, calltypes),
+             copy_one(x.index_var, calltypes),
+             copy_one(x.value, calltypes),
+             copy_one(x.loc, calltypes),
+         )
+         calltypes[ret] = ctyp
+         return ret
+     elif isinstance(x, ir.FreeVar):
+         return ir.FreeVar(
+             copy_one(x.index, calltypes),
+             copy_one(x.name, calltypes),
+             copy_one(x.value, calltypes),
+             copy_one(x.loc, calltypes),
+         )
+     elif isinstance(x, slice):
+         return slice(
+             copy_one(x.start, calltypes),
+             copy_one(x.stop, calltypes),
+             copy_one(x.step, calltypes),
+         )
+     elif isinstance(x, ir.PopBlock):
+         return ir.PopBlock(copy_one(x.loc, calltypes))
+     elif isinstance(x, ir.SetAttr):
+         ctyp = calltypes[x]
+         ret = ir.SetAttr(
+             copy_one(x.target, calltypes),
+             copy_one(x.attr, calltypes),
+             copy_one(x.value, calltypes),
+             copy_one(x.loc, calltypes),
+         )
+         calltypes[ret] = ctyp
+         return ret
+     elif isinstance(x, ir.DelAttr):
+         return ir.DelAttr(
+             copy_one(x.target, calltypes),
+             copy_one(x.attr, calltypes),
+             copy_one(x.loc, calltypes),
+         )
+     elif isinstance(x, types.Type):
+         return x  # Don't copy types.
+     print("Failed to handle the following type when copying target IR.", type(x), x)
+     assert False
+
+
+ def copy_ir(input_ir, calltypes, depth=1):
+     assert depth >= 0 and depth <= 1
+
+     # This is a depth 0 copy.
+     cur_ir = input_ir.copy()
+     if depth == 1:
+         for blk in cur_ir.blocks.values():
+             for i in range(len(blk.body)):
+                 if not isinstance(
+                     blk.body[i], (openmp_region_start, openmp_region_end)
+                 ):
+                     blk.body[i] = copy_one(blk.body[i], calltypes)
+
+     return cur_ir
+
+
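+ # Illustrative usage (hypothetical variables): a depth-1 copy duplicates every
+ # statement except the openmp_region_start/openmp_region_end markers, which
+ # are intentionally left shared so region bookkeeping stays consistent, e.g.
+ #     func_ir_copy = copy_ir(lowerer.func_ir, lowerer.fndesc.calltypes)
+ # copy_one also re-registers copied call expressions in calltypes so they
+ # keep their original signatures.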
+ def replace_np_empty_with_cuda_shared(
+     outlined_ir, typemap, calltypes, prefix, typingctx
+ ):
+     if DEBUG_OPENMP >= 2:
+         print("starting replace_np_empty_with_cuda_shared")
+     outlined_ir = outlined_ir.blocks
+     converted_arrays = []
+     consts = {}
+     topo_order = find_topo_order(outlined_ir)
+     mode = 0  # 0 = non-target region, 1 = target region, 2 = teams region, 3 = teams parallel region
+     # For each block in topological order...
+     for label in topo_order:
+         block = outlined_ir[label]
+         new_block_body = []
+         blen = len(block.body)
+         index = 0
+         # For each statement in the block.
+         while index < blen:
+             stmt = block.body[index]
+             # Adjust mode based on the start of an openmp region.
+             if isinstance(stmt, openmp_region_start):
+                 if "TARGET" in stmt.tags[0].name:
+                     assert mode == 0
+                     mode = 1
+                 if "TEAMS" in stmt.tags[0].name and mode == 1:
+                     mode = 2
+                 if "PARALLEL" in stmt.tags[0].name and mode == 2:
+                     mode = 3
+                 new_block_body.append(stmt)
+             # Adjust mode based on the end of an openmp region.
+             elif isinstance(stmt, openmp_region_end):
+                 if mode == 3 and "PARALLEL" in stmt.tags[0].name:
+                     mode = 2
+                 if mode == 2 and "TEAMS" in stmt.tags[0].name:
+                     mode = 1
+                 if mode == 1 and "TARGET" in stmt.tags[0].name:
+                     mode = 0
+                 new_block_body.append(stmt)
+             # Fix calltype for the np.empty call to have literal as first
+             # arg and include explicit dtype.
+             elif (
+                 isinstance(stmt, ir.Assign)
+                 and isinstance(stmt.value, ir.Expr)
+                 and stmt.value.op == "call"
+                 and stmt.value.func in converted_arrays
+             ):
+                 size = consts[stmt.value.args[0].name]
+                 # The 1D case where the dimension size is directly a const.
+                 if isinstance(size, ir.Const):
+                     size = size.value
+                     signature = calltypes[stmt.value]
+                     signature_args = (
+                         types.scalars.IntegerLiteral(size),
+                         types.functions.NumberClass(signature.return_type.dtype),
+                     )
+                     del calltypes[stmt.value]
+                     calltypes[stmt.value] = typing.templates.Signature(
+                         signature.return_type, signature_args, signature.recvr
+                     )
+                 # The 2D+ case where the dimension sizes are in a tuple.
+                 elif isinstance(size, ir.Expr):
+                     signature = calltypes[stmt.value]
+                     signature_args = (
+                         types.Tuple(
+                             [
+                                 types.scalars.IntegerLiteral(consts[x.name].value)
+                                 for x in size.items
+                             ]
+                         ),
+                         types.functions.NumberClass(signature.return_type.dtype),
+                     )
+                     del calltypes[stmt.value]
+                     calltypes[stmt.value] = typing.templates.Signature(
+                         signature.return_type, signature_args, signature.recvr
+                     )
+
+                 # These lines will force the function to be in the data structures that lowering uses.
+                 afnty = typemap[stmt.value.func.name]
+                 afnty.get_call_type(typingctx, signature_args, {})
+                 if len(stmt.value.args) == 1:
+                     dtype_to_use = signature.return_type.dtype
+                     # If dtype in kwargs then remove it.
+                     if len(stmt.value.kws) > 0:
+                         for kwarg in stmt.value.kws:
+                             if kwarg[0] == "dtype":
+                                 stmt.value.kws = list(
+                                     filter(lambda x: x[0] != "dtype", stmt.value.kws)
+                                 )
+                                 break
+                     new_block_body.append(
+                         ir.Assign(
+                             ir.Global("np", np, stmt.loc),
+                             ir.Var(
+                                 stmt.target.scope, mk_unique_var(".np_global"), stmt.loc
+                             ),
+                             stmt.loc,
+                         )
+                     )
+                     typemap[new_block_body[-1].target.name] = types.Module(np)
+                     new_block_body.append(
+                         ir.Assign(
+                             ir.Expr.getattr(
+                                 new_block_body[-1].target, str(dtype_to_use), stmt.loc
+                             ),
+                             ir.Var(
+                                 stmt.target.scope, mk_unique_var(".np_dtype"), stmt.loc
+                             ),
+                             stmt.loc,
+                         )
+                     )
+                     typemap[new_block_body[-1].target.name] = (
+                         types.functions.NumberClass(signature.return_type.dtype)
+                     )
+                     stmt.value.args.append(new_block_body[-1].target)
+                 else:
+                     raise NotImplementedError(
+                         "np.empty having more than shape and dtype arguments not yet supported."
+                     )
+                 new_block_body.append(stmt)
+             # Keep track of variables assigned from consts or from build_tuples
+             # made up exclusively of variables assigned from consts.
+             elif isinstance(stmt, ir.Assign) and (
+                 isinstance(stmt.value, ir.Const)
+                 or (
+                     isinstance(stmt.value, ir.Expr)
+                     and stmt.value.op == "build_tuple"
+                     and all([x.name in consts for x in stmt.value.items])
+                 )
+             ):
+                 consts[stmt.target.name] = stmt.value
+                 new_block_body.append(stmt)
+             # If we see a global for the numpy module.
+             elif (
+                 isinstance(stmt, ir.Assign)
+                 and isinstance(stmt.value, ir.Global)
+                 and isinstance(stmt.value.value, python_types.ModuleType)
+                 and stmt.value.value.__name__ == "numpy"
+             ):
+                 lhs = stmt.target
+                 index += 1
+                 next_stmt = block.body[index]
+                 # And the next statement is a getattr for the name "empty" on the numpy module
+                 # and we are in a target region.
+                 if (
+                     isinstance(next_stmt, ir.Assign)
+                     and isinstance(next_stmt.value, ir.Expr)
+                     and next_stmt.value.value == lhs
+                     and next_stmt.value.op == "getattr"
+                     and next_stmt.value.attr == "empty"
+                     and mode > 0
+                 ):
+                     # Remember that we are converting this np.empty into a CUDA call.
+                     converted_arrays.append(next_stmt.target)
+
+                     # Create numba.cuda module variable.
+                     new_block_body.append(
+                         ir.Assign(
+                             ir.Global("numba", numba, lhs.loc),
+                             ir.Var(
+                                 lhs.scope, mk_unique_var(".cuda_shared_global"), lhs.loc
+                             ),
+                             lhs.loc,
+                         )
+                     )
+                     typemap[new_block_body[-1].target.name] = types.Module(numba)
+                     new_block_body.append(
+                         ir.Assign(
+                             ir.Expr.getattr(new_block_body[-1].target, "cuda", lhs.loc),
+                             ir.Var(
+                                 lhs.scope,
+                                 mk_unique_var(".cuda_shared_getattr"),
+                                 lhs.loc,
+                             ),
+                             lhs.loc,
+                         )
+                     )
+                     typemap[new_block_body[-1].target.name] = types.Module(numba.cuda)
+
+                     if mode == 1:
+                         raise NotImplementedError(
+                             "np.empty used in non-teams or parallel target region"
+                         )
+                         pass
+                     elif mode == 2:
+                         # Create numba.cuda.shared module variable.
+                         new_block_body.append(
+                             ir.Assign(
+                                 ir.Expr.getattr(
+                                     new_block_body[-1].target, "shared", lhs.loc
+                                 ),
+                                 ir.Var(
+                                     lhs.scope,
+                                     mk_unique_var(".cuda_shared_getattr"),
+                                     lhs.loc,
+                                 ),
+                                 lhs.loc,
+                             )
+                         )
+                         typemap[new_block_body[-1].target.name] = types.Module(
+                             numba.cuda.stubs.shared
+                         )
+                     elif mode == 3:
+                         # Create numba.cuda.local module variable.
+                         new_block_body.append(
+                             ir.Assign(
+                                 ir.Expr.getattr(
+                                     new_block_body[-1].target, "local", lhs.loc
+                                 ),
+                                 ir.Var(
+                                     lhs.scope,
+                                     mk_unique_var(".cuda_local_getattr"),
+                                     lhs.loc,
+                                 ),
+                                 lhs.loc,
+                             )
+                         )
+                         typemap[new_block_body[-1].target.name] = types.Module(
+                             numba.cuda.stubs.local
+                         )
+
+                     # Change the typemap for the original function variable for np.empty.
+                     afnty = typingctx.resolve_getattr(
+                         typemap[new_block_body[-1].target.name], "array"
+                     )
+                     del typemap[next_stmt.target.name]
+                     typemap[next_stmt.target.name] = afnty
+                     # Change the variable that previously was assigned np.empty to now be one of
+                     # the CUDA array allocators.
+                     new_block_body.append(
+                         ir.Assign(
+                             ir.Expr.getattr(
+                                 new_block_body[-1].target, "array", lhs.loc
+                             ),
+                             next_stmt.target,
+                             lhs.loc,
+                         )
+                     )
+                 else:
+                     new_block_body.append(stmt)
+                     new_block_body.append(next_stmt)
+             else:
+                 new_block_body.append(stmt)
+             index += 1
+         block.body = new_block_body
+
+
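+ # Conceptual example (hypothetical user code): inside a target teams region,
+ # an allocation such as
+ #     a = np.empty((4, 8), dtype=np.float32)
+ # is rewritten so its callee becomes the CUDA allocator for the nesting
+ # level, roughly
+ #     a = numba.cuda.shared.array((4, 8), np.float32)
+ # (or numba.cuda.local.array inside a teams+parallel region), with the shape
+ # turned into integer literals taken from the tracked consts.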
+ def remove_dels(blocks):
+     """remove ir.Del nodes"""
+     for block in blocks.values():
+         new_body = []
+         for stmt in block.body:
+             if not isinstance(stmt, ir.Del):
+                 new_body.append(stmt)
+         block.body = new_body
+     return
+
+
+ def find_target_start_end(func_ir, target_num):
+     start_block = None
+     end_block = None
+
+     for label, block in func_ir.blocks.items():
+         if isinstance(block.body[0], openmp_region_start):
+             block_target_num = block.body[0].has_target()
+             if target_num == block_target_num:
+                 start_block = label
+                 if start_block is not None and end_block is not None:
+                     return start_block, end_block
+         elif isinstance(block.body[0], openmp_region_end):
+             block_target_num = block.body[0].start_region.has_target()
+             if target_num == block_target_num:
+                 end_block = label
+                 if start_block is not None and end_block is not None:
+                     return start_block, end_block
+
+     dprint_func_ir(func_ir, "find_target_start_end")
+     print("target_num:", target_num)
+     assert False
+
+
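+ # Note (illustrative): each offloaded region is delimited in the IR by an
+ # openmp_region_start node at the top of its first block and a matching
+ # openmp_region_end whose start_region field points back at it; both report
+ # the region's target number through has_target(), which is what the scan in
+ # find_target_start_end above matches on.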
+ class openmp_region_start(ir.Stmt):
+     def __init__(self, tags, region_number, loc, firstprivate_dead_after=None):
+         if DEBUG_OPENMP >= 2:
+             print("region ids openmp_region_start::__init__", id(self))
+         self.tags = tags
+         self.region_number = region_number
+         self.loc = loc
+         self.omp_region_var = None
+         self.omp_metadata = None
+         self.tag_vars = set()
+         self.normal_iv = None
+         self.target_copy = False
+         self.firstprivate_dead_after = (
+             [] if firstprivate_dead_after is None else firstprivate_dead_after
+         )
+         for tag in self.tags:
+             if isinstance(tag.arg, ir.Var):
+                 self.tag_vars.add(tag.arg.name)
+             elif isinstance(tag.arg, str):
+                 self.tag_vars.add(tag.arg)
+             elif isinstance(tag.arg, NameSlice):
+                 self.tag_vars.add(tag.arg.name)
+
+             if tag.name == "QUAL.OMP.NORMALIZED.IV":
+                 self.normal_iv = tag.arg
+         if DEBUG_OPENMP >= 1:
+             print("tags:", self.tags)
+             print("tag_vars:", sorted(self.tag_vars))
+         self.acq_res = False
+         self.acq_rel = False
+         self.alloca_queue = []
+         self.end_region = None
+
+     def __getstate__(self):
+         state = self.__dict__.copy()
+         return state
+
+     def __setstate__(self, state):
+         self.__dict__.update(state)
+
+     def replace_var_names(self, namedict):
+         for i in range(len(self.tags)):
+             if isinstance(self.tags[i].arg, ir.Var):
+                 if self.tags[i].arg.name in namedict:
+                     var = self.tags[i].arg
+                     self.tags[i].arg = ir.Var(var.scope, namedict[var.name], var.loc)
+             elif isinstance(self.tags[i].arg, str):
+                 if "*" in self.tags[i].arg:
+                     xsplit = self.tags[i].arg.split("*")
+                     assert len(xsplit) == 2
+                     if xsplit[0] in namedict:
+                         self.tags[i].arg = namedict[xsplit[0]] + "*" + xsplit[1]
+                 else:
+                     if self.tags[i].arg in namedict:
+                         self.tags[i].arg = namedict[self.tags[i].arg]
+
+     def add_tag(self, tag):
+         tag_arg_str = None
+         if isinstance(tag.arg, ir.Var):
+             tag_arg_str = tag.arg.name
+         elif isinstance(tag.arg, str):
+             tag_arg_str = tag.arg
+         elif isinstance(tag.arg, lir.instructions.AllocaInstr):
+             tag_arg_str = tag.arg._get_name()
+         else:
+             assert False
+         if isinstance(tag_arg_str, str):
+             self.tag_vars.add(tag_arg_str)
+         self.tags.append(tag)
+
+     def get_var_dsa(self, var):
+         assert isinstance(var, str)
+         for tag in self.tags:
+             if is_dsa(tag.name) and tag.var_in(var):
+                 return tag.name
+         return None
+
+     def requires_acquire_release(self):
+         pass
+         # self.acq_res = True
+
+     def requires_combined_acquire_release(self):
+         pass
+         # self.acq_rel = True
+
+     def has_target(self):
+         for t in self.tags:
+             if is_target_tag(t.name):
+                 return t.arg
+         return None
+
+     def list_vars(self):
+         return list_vars_from_tags(self.tags)
+
+     def update_tags(self):
+         with self.builder.goto_block(self.block):
+             cur_instr = -1
+
+             while True:
+                 last_instr = self.builder.block.instructions[cur_instr]
+                 if (
+                     isinstance(last_instr, lir.instructions.CallInstr)
+                     and last_instr.tags is not None
+                     and len(last_instr.tags) > 0
+                 ):
+                     break
+                 cur_instr -= 1
+
+             last_instr.tags = openmp_tag_list_to_str(self.tags, self.lowerer, False)
+             if DEBUG_OPENMP >= 1:
+                 print("last_tags:", last_instr.tags, type(last_instr.tags))
+
+     def alloca(self, alloca_instr, typ):
+         # We can't process these right away since the required processing can
+         # lead to infinite recursion.  Instead, we accumulate them in a queue
+         # and process them later at the end_region marker, by which point the
+         # variables are guaranteed to exist in their full form and processing
+         # them no longer recurses.
+         self.alloca_queue.append((alloca_instr, typ))
+
+     def process_alloca_queue(self):
+         # This is believed to be dead code; the assertion verifies that the
+         # queue is already empty.
+         assert len(self.alloca_queue) == 0
+         has_update = False
+         for alloca_instr, typ in self.alloca_queue:
+             has_update = self.process_one_alloca(alloca_instr, typ) or has_update
+         if has_update:
+             self.update_tags()
+         self.alloca_queue = []
+
+     def post_lowering_process_alloca_queue(self, enter_directive):
+         has_update = False
+         if DEBUG_OPENMP >= 1:
+             print("starting post_lowering_process_alloca_queue")
+         for alloca_instr, typ in self.alloca_queue:
+             has_update = self.process_one_alloca(alloca_instr, typ) or has_update
+         if has_update:
+             if DEBUG_OPENMP >= 1:
+                 print(
+                     "post_lowering_process_alloca_queue has update:",
+                     enter_directive.tags,
+                 )
+             enter_directive.tags = openmp_tag_list_to_str(
+                 self.tags, self.lowerer, False
+             )
+             # LLVM IR is doing some string caching and the following line is necessary to
+             # reset that caching so that the original tag text can be overwritten above.
+             enter_directive._clear_string_cache()
+             if DEBUG_OPENMP >= 1:
+                 print(
+                     "post_lowering_process_alloca_queue updated tags:",
+                     enter_directive.tags,
+                 )
+         self.alloca_queue = []
+
+     def process_one_alloca(self, alloca_instr, typ):
+         avar = alloca_instr.name
+         if DEBUG_OPENMP >= 1:
+             print(
+                 "openmp_region_start process_one_alloca:",
+                 id(self),
+                 alloca_instr,
+                 avar,
+                 typ,
+                 type(alloca_instr),
+                 self.tag_vars,
+             )
+
+         has_update = False
+         if (
+             self.normal_iv is not None
+             and avar != self.normal_iv
+             and avar.startswith(self.normal_iv)
+         ):
+             for i in range(len(self.tags)):
+                 if DEBUG_OPENMP >= 1:
+                     print("Replacing normalized iv with", avar)
+                 self.tags[i].arg = avar
+                 has_update = True
+                 break
+
+         if not self.needs_implicit_vars():
+             return has_update
+         if avar not in self.tag_vars:
+             if DEBUG_OPENMP >= 1:
+                 print(
+                     f"LLVM variable {avar} didn't previously exist in the list of vars so adding as private."
+                 )
+             self.add_tag(
+                 openmp_tag("QUAL.OMP.PRIVATE", alloca_instr)
+             )  # is FIRSTPRIVATE right here?
+             has_update = True
+         return has_update
+
+     def needs_implicit_vars(self):
+         first_tag = self.tags[0]
+         if (
+             first_tag.name == "DIR.OMP.PARALLEL"
+             or first_tag.name == "DIR.OMP.PARALLEL.LOOP"
+             or first_tag.name == "DIR.OMP.TASK"
+         ):
+             return True
+         return False
+
+     def update_context(self, context, builder):
+         cctyp = type(context.call_conv)
+         # print("start update_context id(context)", id(context), "id(const.call_conv)", id(context.call_conv), "cctyp", cctyp, "id(cctyp)", id(cctyp))
+
+         if (
+             not hasattr(cctyp, "pyomp_patch_installed")
+             or not cctyp.pyomp_patch_installed
+         ):
+             cctyp.pyomp_patch_installed = True
+             # print("update_context", "id(cctyp.return_user_exec)", id(cctyp.return_user_exc), "id(context)", id(context))
+             setattr(cctyp, "orig_return_user_exc", cctyp.return_user_exc)
+
+             def pyomp_return_user_exc(self, builder, *args, **kwargs):
+                 # print("pyomp_return_user_exc")
+                 # Handle exceptions in OpenMP regions by emitting a trap and an
+                 # unreachable terminator.
+                 if in_openmp_region(builder):
+                     fnty = lir.types.FunctionType(lir.types.VoidType(), [])
+                     fn = builder.module.declare_intrinsic("llvm.trap", (), fnty)
+                     builder.call(fn, [])
+                     builder.unreachable()
+                     return
+                 self.orig_return_user_exc(builder, *args, **kwargs)
+
+             setattr(cctyp, "return_user_exc", pyomp_return_user_exc)
+             # print("after", id(pyomp_return_user_exc), id(cctyp.return_user_exc))
+
+             setattr(
+                 cctyp, "orig_return_status_propagate", cctyp.return_status_propagate
+             )
+
+             def pyomp_return_status_propagate(self, builder, *args, **kwargs):
+                 if in_openmp_region(builder):
+                     return
+                 self.orig_return_status_propagate(builder, *args, **kwargs)
+
+             setattr(cctyp, "return_status_propagate", pyomp_return_status_propagate)
+
+         cemtyp = type(context.error_model)
+         # print("start update_context id(context)", id(context), "id(const.error_model)", id(context.error_model), "cemtyp", cemtyp, "id(cemtyp)", id(cemtyp))
+
+         if (
+             not hasattr(cemtyp, "pyomp_patch_installed")
+             or not cemtyp.pyomp_patch_installed
+         ):
+             cemtyp.pyomp_patch_installed = True
+             # print("update_context", "id(cemtyp.return_user_exec)", id(cemtyp.fp_zero_division), "id(context)", id(context))
+             setattr(cemtyp, "orig_fp_zero_division", cemtyp.fp_zero_division)
+
+             def pyomp_fp_zero_division(self, builder, *args, **kwargs):
+                 # print("pyomp_fp_zero_division")
+                 if in_openmp_region(builder):
+                     return False
+                 return self.orig_fp_zero_division(builder, *args, **kwargs)
+
+             setattr(cemtyp, "fp_zero_division", pyomp_fp_zero_division)
+             # print("after", id(pyomp_fp_zero_division), id(cemtyp.fp_zero_division))
+
+         pyapi = context.get_python_api(builder)
+         ptyp = type(pyapi)
+
+         if not hasattr(ptyp, "pyomp_patch_installed") or not ptyp.pyomp_patch_installed:
+             ptyp.pyomp_patch_installed = True
+             # print("update_context", "id(ptyp.emit_environment_sentry)", id(ptyp.emit_environment_sentry), "id(context)", id(context))
+             setattr(ptyp, "orig_emit_environment_sentry", ptyp.emit_environment_sentry)
+
+             def pyomp_emit_environment_sentry(self, *args, **kwargs):
+                 builder = self.builder
+                 # print("pyomp_emit_environment_sentry")
+                 if in_openmp_region(builder):
+                     return False
+                 return self.orig_emit_environment_sentry(*args, **kwargs)
+
+             setattr(ptyp, "emit_environment_sentry", pyomp_emit_environment_sentry)
+             # print("after", id(pyomp_emit_environment_sentry), id(ptyp.emit_environment_sentry))
+
+     def fix_dispatchers(self, typemap, typingctx, cuda_target):
+         fixup_dict = {}
+         for k, v in typemap.items():
+             if isinstance(v, Dispatcher) and not isinstance(
+                 v, numba_cuda.types.CUDADispatcher
+             ):
+                 # targetoptions = v.targetoptions.copy()
+                 # targetoptions['device'] = True
+                 # targetoptions['debug'] = targetoptions.get('debug', False)
+                 # targetoptions['opt'] = targetoptions.get('opt', True)
+                 vdispatcher = v.dispatcher
+                 vdispatcher.targetoptions.pop("nopython", None)
+                 vdispatcher.targetoptions.pop("boundscheck", None)
+                 disp = typingctx.resolve_value_type(vdispatcher)
+                 fixup_dict[k] = disp
+                 for sig in vdispatcher.overloads.keys():
+                     disp.dispatcher.compile_device(sig, cuda_target=cuda_target)
+
+         for k, v in fixup_dict.items():
+             del typemap[k]
+             typemap[k] = v
+
+     def lower(self, lowerer):
+         targetctx = lowerer.context
+         typemap = lowerer.fndesc.typemap
+         calltypes = lowerer.fndesc.calltypes
+         context = lowerer.context
+         builder = lowerer.builder
+         mod = builder.module
+         library = lowerer.library
+         library.openmp = True
+         self.block = builder.block
+         self.builder = builder
+         self.lowerer = lowerer
+         self.update_context(context, builder)
+         if DEBUG_OPENMP >= 1:
+             print(
+                 "region ids lower:block",
+                 id(self),
+                 self,
+                 id(self.block),
+                 self.block,
+                 type(self.block),
+                 self.tags,
+                 len(self.tags),
+                 "builder_id:",
+                 id(self.builder),
+                 "block_id:",
+                 id(self.block),
+             )
+             for k, v in lowerer.func_ir.blocks.items():
+                 print("block post copy:", k, id(v), id(v.body))
+
+         # Convert implicit tags to explicit form now that we have typing info.
+         for i in range(len(self.tags)):
+             cur_tag = self.tags[i]
+             if cur_tag.name == "QUAL.OMP.TARGET.IMPLICIT":
+                 if isinstance(
+                     typemap_lookup(typemap, cur_tag.arg), types.npytypes.Array
+                 ):
+                     cur_tag.name = "QUAL.OMP.MAP.TOFROM"
+                 else:
+                     cur_tag.name = "QUAL.OMP.FIRSTPRIVATE"
+
+         if DEBUG_OPENMP >= 1:
+             for otag in self.tags:
+                 print("otag:", otag, type(otag.arg))
+
+         # Remove LLVM vars that might have been added if this is an OpenMP
+         # region inside a target region.
+         count_alloca_instr = len(
+             list(
+                 filter(
+                     lambda x: isinstance(x.arg, lir.instructions.AllocaInstr), self.tags
+                 )
+             )
+         )
+         assert count_alloca_instr == 0
+         # self.tags = list(filter(lambda x: not isinstance(x.arg, lir.instructions.AllocaInstr), self.tags))
+         if DEBUG_OPENMP >= 1:
+             print("after LLVM tag filter", self.tags, len(self.tags))
+             for otag in self.tags:
+                 print("otag:", otag, type(otag.arg))
+
+         host_side_target_tags = []
+         target_num = self.has_target()
+
+         def add_struct_tags(self, var_table):
+             extras_before = []
+             struct_tags = []
+             for i in range(len(self.tags)):
+                 cur_tag = self.tags[i]
+                 if cur_tag.name in [
+                     "QUAL.OMP.MAP.TOFROM",
+                     "QUAL.OMP.MAP.TO",
+                     "QUAL.OMP.MAP.FROM",
+                     "QUAL.OMP.MAP.ALLOC",
+                 ]:
+                     cur_tag_var = cur_tag.arg
+                     if isinstance(cur_tag_var, NameSlice):
+                         cur_tag_var = cur_tag_var.name
+                     assert isinstance(cur_tag_var, str)
+                     cur_tag_typ = typemap_lookup(typemap, cur_tag_var)
+                     if isinstance(cur_tag_typ, types.npytypes.Array):
+                         stride_typ = lowerer.context.get_value_type(
+                             types.intp
+                         )  # lir.Type.int(64)
+                         stride_abi_size = context.get_abi_sizeof(stride_typ)
+                         array_var = var_table[cur_tag_var]
+                         if DEBUG_OPENMP >= 1:
+                             print(
+                                 "Found array mapped:",
+                                 cur_tag.name,
+                                 cur_tag.arg,
+                                 cur_tag_typ,
+                                 type(cur_tag_typ),
+                                 stride_typ,
+                                 type(stride_typ),
+                                 stride_abi_size,
+                                 array_var,
+                                 type(array_var),
+                             )
+                         uniqueness = get_unique()
+                         if isinstance(cur_tag.arg, NameSlice):
+                             the_slice = cur_tag.arg.the_slice[0][0]
+                             assert the_slice.step is None
+                             if isinstance(the_slice.start, int):
+                                 start_index_var = ir.Var(
+                                     None,
+                                     f"{cur_tag_var}_start_index_var{target_num}{uniqueness}",
+                                     array_var.loc,
+                                 )
+                                 start_assign = ir.Assign(
+                                     ir.Const(the_slice.start, array_var.loc),
+                                     start_index_var,
+                                     array_var.loc,
+                                 )
+
+                                 typemap[start_index_var.name] = types.int64
+                                 lowerer.lower_inst(start_assign)
+                                 extras_before.append(start_assign)
+                                 lowerer._alloca_var(
+                                     start_index_var.name, typemap[start_index_var.name]
+                                 )
+                                 lowerer.loadvar(start_index_var.name)
+                             else:
+                                 start_index_var = the_slice.start
+                                 assert isinstance(start_index_var, str)
+                                 start_index_var = ir.Var(
+                                     None, start_index_var, array_var.loc
+                                 )
+                             if isinstance(the_slice.stop, int):
+                                 end_index_var = ir.Var(
+                                     None,
+                                     f"{cur_tag_var}_end_index_var{target_num}{uniqueness}",
+                                     array_var.loc,
+                                 )
+                                 end_assign = ir.Assign(
+                                     ir.Const(the_slice.stop, array_var.loc),
+                                     end_index_var,
+                                     array_var.loc,
+                                 )
+                                 typemap[end_index_var.name] = types.int64
+                                 lowerer.lower_inst(end_assign)
+                                 extras_before.append(end_assign)
+                                 lowerer._alloca_var(
+                                     end_index_var.name, typemap[end_index_var.name]
+                                 )
+                                 lowerer.loadvar(end_index_var.name)
+                             else:
+                                 end_index_var = the_slice.stop
+                                 assert isinstance(end_index_var, str)
+                                 end_index_var = ir.Var(
+                                     None, end_index_var, array_var.loc
+                                 )
+
+                             num_elements_var = ir.Var(
+                                 None,
+                                 f"{cur_tag_var}_num_elements_var{target_num}{uniqueness}",
+                                 array_var.loc,
+                             )
+                             size_binop = ir.Expr.binop(
+                                 operator.sub,
+                                 end_index_var,
+                                 start_index_var,
+                                 array_var.loc,
+                             )
+                             size_assign = ir.Assign(
+                                 size_binop, num_elements_var, array_var.loc
+                             )
+                             calltypes[size_binop] = typing.signature(
+                                 types.int64, types.int64, types.int64
+                             )
+                         else:
+                             start_index_var = 0
+                             num_elements_var = ir.Var(
+                                 None,
+                                 f"{cur_tag_var}_num_elements_var{target_num}{uniqueness}",
+                                 array_var.loc,
+                             )
+                             size_getattr = ir.Expr.getattr(
+                                 array_var, "size", array_var.loc
+                             )
+                             size_assign = ir.Assign(
+                                 size_getattr, num_elements_var, array_var.loc
+                             )
+
+                         typemap[num_elements_var.name] = types.int64
+                         lowerer.lower_inst(size_assign)
+                         extras_before.append(size_assign)
+                         lowerer._alloca_var(
+                             num_elements_var.name, typemap[num_elements_var.name]
+                         )
+
+                         # see core/datamodel/models.py
+                         lowerer.loadvar(num_elements_var.name)  # alloca the var
+
+                         # see core/datamodel/models.py
+                         if isinstance(start_index_var, ir.Var):
+                             lowerer.loadvar(start_index_var.name)  # alloca the var
+                         if isinstance(num_elements_var, ir.Var):
+                             lowerer.loadvar(num_elements_var.name)  # alloca the var
+                         struct_tags.append(
+                             openmp_tag(
+                                 cur_tag.name + ".STRUCT",
+                                 cur_tag_var + "*data",
+                                 non_arg=True,
+                                 omp_slice=(start_index_var, num_elements_var),
+                             )
+                         )
+                         struct_tags.append(
+                             openmp_tag(
+                                 "QUAL.OMP.MAP.TO.STRUCT",
+                                 cur_tag_var + "*shape",
+                                 non_arg=True,
+                                 omp_slice=(0, 1),
+                             )
+                         )
+                         struct_tags.append(
+                             openmp_tag(
+                                 "QUAL.OMP.MAP.TO.STRUCT",
+                                 cur_tag_var + "*strides",
+                                 non_arg=True,
+                                 omp_slice=(0, 1),
+                             )
+                         )
+                         # Peel off NameSlice, it served its purpose and is not
+                         # needed by the rest of compilation.
+                         if isinstance(cur_tag.arg, NameSlice):
+                             cur_tag.arg = cur_tag.arg.name
+
+             return struct_tags, extras_before
+
+         if self.tags[0].name in [
+             "DIR.OMP.TARGET.DATA",
+             "DIR.OMP.TARGET.ENTER.DATA",
+             "DIR.OMP.TARGET.EXIT.DATA",
+             "DIR.OMP.TARGET.UPDATE",
+         ]:
+             var_table = get_name_var_table(lowerer.func_ir.blocks)
+             struct_tags, extras_before = add_struct_tags(self, var_table)
+             self.tags.extend(struct_tags)
+             for extra in extras_before:
+                 lowerer.lower_inst(extra)
+
+         elif target_num is not None and not self.target_copy:
+             var_table = get_name_var_table(lowerer.func_ir.blocks)
+
+             ompx_attrs = list(
+                 filter(lambda x: x.name == "QUAL.OMP.OMPX_ATTRIBUTE", self.tags)
+             )
+             self.tags = list(
+                 filter(lambda x: x.name != "QUAL.OMP.OMPX_ATTRIBUTE", self.tags)
+             )
+             selected_device = 0
+             device_tags = get_tags_of_type(self.tags, "QUAL.OMP.DEVICE")
+             if len(device_tags) > 0:
+                 device_tag = device_tags[-1]
+                 if isinstance(device_tag.arg, int):
+                     selected_device = device_tag.arg
+                 else:
+                     assert False
+                 if DEBUG_OPENMP >= 1:
+                     print("new selected device:", selected_device)
+
+             struct_tags, extras_before = add_struct_tags(self, var_table)
+             self.tags.extend(struct_tags)
+             if DEBUG_OPENMP >= 1:
+                 for otag in self.tags:
+                     print("tag in target:", otag, type(otag.arg))
+
+             from numba.core.compiler import Flags
+
+             if DEBUG_OPENMP >= 1:
+                 print("openmp start region lower has target", type(lowerer.func_ir))
+             # Make a copy of the host IR being lowered.
+             dprint_func_ir(lowerer.func_ir, "original func_ir")
+             func_ir = copy_ir(lowerer.func_ir, calltypes)
+             dprint_func_ir(func_ir, "copied func_ir")
+             if DEBUG_OPENMP >= 1:
+                 for k, v in lowerer.func_ir.blocks.items():
+                     print(
+                         "region ids block post copy:",
+                         k,
+                         id(v),
+                         id(func_ir.blocks[k]),
+                         id(v.body),
+                         id(func_ir.blocks[k].body),
+                     )
+
+             remove_dels(func_ir.blocks)
+
+             dprint_func_ir(func_ir, "func_ir after remove_dels")
+
+             def fixup_openmp_pairs(blocks):
+                 """The Numba IR nodes for the start and end of an OpenMP region
+                 contain references to each other.  When a target region is
+                 outlined that contains these pairs of IR nodes then if we
+                 simply shallow copy them then they'll point to their original
+                 matching pair in the original IR.  In this function, we go
+                 through and find what should be matching pairs in the
+                 outlined (target) IR and make those copies point to each
+                 other.
+                 """
+                 start_dict = {}
+                 end_dict = {}
+
+                 # Go through the blocks in the original IR and create a mapping
+                 # between the id of the start nodes with their block label and
+                 # position in the block.  Likewise, do the same for end nodes.
+                 for label, block in func_ir.blocks.items():
+                     for bindex, bstmt in enumerate(block.body):
+                         if isinstance(bstmt, openmp_region_start):
+                             if DEBUG_OPENMP >= 2:
+                                 print("region ids found region start", id(bstmt))
+                             start_dict[id(bstmt)] = (label, bindex)
+                         elif isinstance(bstmt, openmp_region_end):
+                             if DEBUG_OPENMP >= 2:
+                                 print(
+                                     "region ids found region end",
+                                     id(bstmt.start_region),
+                                     id(bstmt),
+                                 )
+                             end_dict[id(bstmt.start_region)] = (label, bindex)
+                 assert len(start_dict) == len(end_dict)
+
+                 # For each start node that we found above, create a copy in the target IR
+                 # and fixup the references of the copies to point at each other.
+                 for start_id, blockindex in start_dict.items():
+                     start_block, sbindex = blockindex
+
+                     end_block_index = end_dict[start_id]
+                     end_block, ebindex = end_block_index
+
+                     if DEBUG_OPENMP >= 2:
+                         start_pre_copy = blocks[start_block].body[sbindex]
+                         end_pre_copy = blocks[end_block].body[ebindex]
+
+                     # Create copy of the OpenMP start and end nodes in the target outlined IR.
+                     blocks[start_block].body[sbindex] = copy.copy(
+                         blocks[start_block].body[sbindex]
+                     )
+                     blocks[end_block].body[ebindex] = copy.copy(
+                         blocks[end_block].body[ebindex]
+                     )
+                     # Reset some fields in the start OpenMP region because the target IR
+                     # has not been lowered yet.
+                     start_region = blocks[start_block].body[sbindex]
+                     start_region.builder = None
+                     start_region.block = None
+                     start_region.lowerer = None
+                     start_region.target_copy = True
+                     start_region.tags = copy.deepcopy(start_region.tags)
+                     # Remove unnecessary num_teams, thread_limit tags when
+                     # emitting a target directive within a kernel to avoid
+                     # extraneous arguments in the kernel function.
+                     if start_region.has_target() == target_num:
+                         start_region.tags.append(openmp_tag("OMP.DEVICE"))
+                     end_region = blocks[end_block].body[ebindex]
+                     # assert(start_region.omp_region_var is None)
+                     assert len(start_region.alloca_queue) == 0
+                     # Make start and end copies point at each other.
+                     end_region.start_region = start_region
+                     start_region.end_region = end_region
+                     if DEBUG_OPENMP >= 2:
+                         print(
+                             f"region ids fixup start: {id(start_pre_copy)}->{id(start_region)} end: {id(end_pre_copy)}->{id(end_region)}"
+                         )
+
+             fixup_openmp_pairs(func_ir.blocks)
+             state = compiler.StateDict()
+             fndesc = lowerer.fndesc
+             state.typemap = fndesc.typemap
+             state.calltypes = fndesc.calltypes
+             state.argtypes = fndesc.argtypes
+             state.return_type = fndesc.restype
+             if DEBUG_OPENMP >= 1:
+                 print("context:", context, type(context))
+                 print("targetctx:", targetctx, type(targetctx))
+                 print("state:", state, dir(state))
+                 print("fndesc:", fndesc, type(fndesc))
+                 print("func_ir type:", type(func_ir))
+                 dprint_func_ir(func_ir, "target func_ir")
+
+             # Find the start and end IR blocks for this offloaded region.
+             start_block, end_block = find_target_start_end(func_ir, target_num)
+             end_target_node = func_ir.blocks[end_block].body[0]
+
+             if DEBUG_OPENMP >= 1:
+                 print("start_block:", start_block)
+                 print("end_block:", end_block)
+
+             blocks_in_region = get_blocks_between_start_end(
+                 func_ir.blocks, start_block, end_block
+             )
+             if DEBUG_OPENMP >= 1:
+                 print("lower blocks_in_region:", blocks_in_region)
+
+             # Find the variables that cross the boundary between the target
+             # region and the non-target host-side code.
+             ins, outs = transforms.find_region_inout_vars(
+                 blocks=func_ir.blocks,
+                 livemap=func_ir.variable_lifetime.livemap,
+                 callfrom=start_block,
+                 returnto=end_block,
+                 body_block_ids=blocks_in_region,
+             )
+
+             def add_mapped_to_ins(ins, tags):
+                 for tag in tags:
+                     if tag.arg in ins:
+                         continue
+
+                     if tag.name in ["QUAL.OMP.FIRSTPRIVATE", "QUAL.OMP.MAP.FROM"]:
+                         ins.append(tag.arg)
+
+             add_mapped_to_ins(ins, self.tags)
+
+             normalized_ivs = get_tags_of_type(self.tags, "QUAL.OMP.NORMALIZED.IV")
+             if DEBUG_OPENMP >= 1:
+                 print("ivs ins", normalized_ivs, ins, outs)
+             for niv in normalized_ivs:
+                 if DEBUG_OPENMP >= 1:
+                     print("Removing normalized iv from ins", niv.arg)
+                 if niv.arg in ins:
+                     ins.remove(niv.arg)
+             # Get the types of the variables live-in to the target region.
+             target_args_unordered = ins + list(set(outs) - set(ins))
+             if DEBUG_OPENMP >= 1:
+                 print("ins:", ins, type(ins))
+                 print("outs:", outs, type(outs))
+                 # print("args:", state.args)
+                 print("rettype:", state.return_type, type(state.return_type))
+                 print("target_args_unordered:", target_args_unordered)
+             # Re-use Numba loop lifting code to extract the target region as
+             # its own function.
+             region_info = transforms._loop_lift_info(
+                 loop=None,
+                 inputs=ins,
+                 # outputs=outs,
+                 outputs=(),
+                 callfrom=start_block,
+                 returnto=end_block,
+             )
+
+             region_blocks = dict((k, func_ir.blocks[k]) for k in blocks_in_region)
+
+             if DEBUG_OPENMP >= 1:
+                 print("region_info:", region_info)
+             transforms._loop_lift_prepare_loop_func(region_info, region_blocks)
+             # exit_block_label = max(region_blocks.keys())
+             # region_blocks[exit_block_label].body = []
+             # exit_scope = region_blocks[exit_block_label].scope
+             # tmp = exit_scope.make_temp(loc=func_ir.loc)
+             # region_blocks[exit_block_label].append(ir.Assign(value=ir.Const(0, func_ir.loc), target=tmp, loc=func_ir.loc))
+             # region_blocks[exit_block_label].append(ir.Return(value=tmp, loc=func_ir.loc))
+
+             target_args = []
+             outline_arg_typs = []
+             # outline_arg_typs = [None] * len(target_args_unordered)
+             for tag in self.tags:
+                 if DEBUG_OPENMP >= 1:
+                     print(1, "target_arg?", tag, tag.non_arg, is_target_arg(tag.name))
+                 if (
+                     tag.arg in target_args_unordered
+                     and not tag.non_arg
+                     and is_target_arg(tag.name)
+                 ):
+                     target_args.append(tag.arg)
+                     # target_arg_index = target_args.index(tag.arg)
+                     atyp = get_dotted_type(tag.arg, typemap, lowerer)
+                     if is_pointer_target_arg(tag.name, atyp):
+                         outline_arg_typs.append(types.CPointer(atyp))
+                         if DEBUG_OPENMP >= 1:
+                             print(1, "found cpointer target_arg", tag, atyp, id(atyp))
+                     else:
+                         # outline_arg_typs[target_arg_index] = atyp
+                         outline_arg_typs.append(atyp)
+                         if DEBUG_OPENMP >= 1:
+                             print(1, "found target_arg", tag, atyp, id(atyp))
+
+             if DEBUG_OPENMP >= 1:
+                 print("target_args:", target_args)
+                 print("target_args_unordered:", target_args_unordered)
+                 print("outline_arg_typs:", outline_arg_typs)
+                 print("extras_before:", extras_before, start_block)
+                 for eb in extras_before:
+                     print(eb)
+
+             # NOTE: workaround for python 3.10 lowering in numba that may
+             # include a branch converging variable $cp.  Remove it to avoid the
+             # assert since the openmp region must be single-entry, single-exit.
+             if sys.version_info >= (3, 10) and sys.version_info < (3, 11):
+                 assert len(target_args) == len(
+                     [x for x in target_args_unordered if x != "$cp"]
+                 )
+             else:
+                 assert len(target_args) == len(target_args_unordered)
+             assert len(target_args) == len(outline_arg_typs)
+
+             # Create the outlined IR from the blocks in the region, making the
+             # variables that cross into the region the arguments.
+             outlined_ir = func_ir.derive(
+                 blocks=region_blocks,
+                 arg_names=tuple(target_args),
+                 arg_count=len(target_args),
+                 force_non_generator=True,
+             )
+             outlined_ir.blocks[start_block].body = (
+                 extras_before + outlined_ir.blocks[start_block].body
+             )
+             for stmt in outlined_ir.blocks[min(outlined_ir.blocks.keys())].body:
+                 if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Arg):
+                     stmt.value.index = target_args.index(stmt.value.name)
+
+             def prepend_device_to_func_name(outlined_ir):
+                 # Change the name of the outlined function to prepend the
+                 # word "device" to the function name.
+                 fparts = outlined_ir.func_id.func_qualname.split(".")
+                 fparts[-1] = "device" + str(target_num) + fparts[-1]
+                 outlined_ir.func_id.func_qualname = ".".join(fparts)
+                 outlined_ir.func_id.func_name = fparts[-1]
+                 uid = next(bytecode.FunctionIdentity._unique_ids)
+                 outlined_ir.func_id.unique_name = "{}${}".format(
+                     outlined_ir.func_id.func_qualname, uid
+                 )
+
+             prepend_device_to_func_name(outlined_ir)
+             device_func_name = outlined_ir.func_id.func_qualname
+             if DEBUG_OPENMP >= 1:
+                 print(
+                     "outlined_ir:",
+                     type(outlined_ir),
+                     type(outlined_ir.func_id),
+                     outlined_ir.arg_names,
+                     device_func_name,
+                 )
+                 dprint_func_ir(outlined_ir, "outlined_ir")
+
+ # Create a copy of the state and the typemap inside it so that changes
1583
+ # made while compiling the outlined IR don't affect the original compilation
1584
+ # state of the host.
1585
+ state_copy = copy.copy(state)
1586
+ state_copy.typemap = copy.copy(typemap)
1587
+
1588
+ entry_block_num = min(outlined_ir.blocks.keys())
1589
+ entry_block = outlined_ir.blocks[entry_block_num]
1590
+ if DEBUG_OPENMP >= 1:
1591
+ print("entry_block:", entry_block)
1592
+ for x in entry_block.body:
1593
+ print(x)
1594
+ rev_arg_assigns = []
1595
+ # Add entries in the copied typemap for the arguments to the outlined IR.
1596
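+ # The "arg.<name>" keys match the argument variable naming used in the
+ # derived IR, so the outlined arguments pick up the chosen types.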
+ for idx, zipvar in enumerate(zip(target_args, outline_arg_typs)):
1597
+ var_in, vtyp = zipvar
1598
+ arg_name = "arg." + var_in
1599
+ state_copy.typemap.pop(arg_name, None)
1600
+ state_copy.typemap[arg_name] = vtyp
1601
+
1602
+ last_block = outlined_ir.blocks[end_block]
1603
+ last_block.body = (
1604
+ [end_target_node]
1605
+ + last_block.body[:-1]
1606
+ + rev_arg_assigns
1607
+ + last_block.body[-1:]
1608
+ )
1609
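+ # The end_target_node is spliced in at the top of the exit block, and any
+ # reverse argument copy-backs (rev_arg_assigns, empty here) are placed just
+ # before the final return.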
+
1610
+ assert isinstance(last_block.body[-1], ir.Return)
1611
+ # Add a typemap entry typing the return value as none.
1612
+ state_copy.typemap[last_block.body[-1].value.name] = types.none
1614
+
1615
+ if DEBUG_OPENMP >= 1:
1616
+ print("selected_device:", selected_device)
1617
+
1618
+ if selected_device == 1:
1619
+ flags = Flags()
1620
+ flags.enable_ssa = False
1621
+ device_lowerer_pipeline = OnlyLower
1622
+
1623
+ subtarget = OpenmpCPUTargetContext(
1624
+ device_func_name, targetctx.typing_context
1625
+ )
1626
+ # Copy everything (e.g., registries) from the CPU context into the new
1627
+ # OpenmpCPUTargetContext subtarget, except call_conv: using the minimal call convention is the whole point of that class.
1628
+ subtarget.__dict__.update(
1629
+ {
1630
+ k: targetctx.__dict__[k]
1631
+ for k in targetctx.__dict__.keys() - {"call_conv"}
1632
+ }
1633
+ )
1634
+ # subtarget.install_registry(imputils.builtin_registry)
1635
+ # Turn off the Numba runtime (incref and decref mostly) for the target compilation.
1636
+ subtarget.enable_nrt = False
1637
+ typingctx_outlined = targetctx.typing_context
1638
+
1639
+ import numba.core.codegen as codegen
1640
+
1641
+ subtarget._internal_codegen = codegen.AOTCPUCodegen(
1642
+ mod.name + f"$device{selected_device}"
1643
+ )
1644
+ subtarget._internal_codegen._library_class = CustomAOTCPUCodeLibrary
1645
+ subtarget._internal_codegen._engine.set_object_cache(None, None)
1646
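+ # The dedicated AOT codegen keeps the device module separate from the host
+ # JIT engine, and clearing the object cache hooks avoids reusing any cached
+ # objects for the device compilation.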
+ device_target = subtarget
1647
+ elif selected_device == 0:
1648
+ from numba.core import target_extension
1649
+
1650
+ orig_target = getattr(
1651
+ target_extension._active_context,
1652
+ "target",
1653
+ target_extension._active_context_default,
1654
+ )
1655
+ target_extension._active_context.target = "cuda"
1656
+
1657
+ flags = cuda_compiler.CUDAFlags()
1658
+
1659
+ typingctx_outlined = cuda_descriptor.cuda_target.typing_context
1660
+ device_target = OpenmpCUDATargetContext(
1661
+ device_func_name, typingctx_outlined
1662
+ )
1663
+ device_target.fndesc = fndesc
1664
+ # device_target = cuda_descriptor.cuda_target.target_context
1665
+
1666
+ device_lowerer_pipeline = OnlyLowerCUDA
1667
+ openmp_cuda_target = numba_cuda.descriptor.CUDATarget("openmp_cuda")
1668
+ openmp_cuda_target._typingctx = typingctx_outlined
1669
+ openmp_cuda_target._targetctx = device_target
1670
+ self.fix_dispatchers(
1671
+ state_copy.typemap, typingctx_outlined, openmp_cuda_target
1672
+ )
1673
+
1674
+ typingctx_outlined.refresh()
1675
+ device_target.refresh()
1676
+ dprint_func_ir(outlined_ir, "outlined_ir before replace np.empty")
1677
+ replace_np_empty_with_cuda_shared(
1678
+ outlined_ir,
1679
+ state_copy.typemap,
1680
+ calltypes,
1681
+ device_func_name,
1682
+ typingctx_outlined,
1683
+ )
1684
+ dprint_func_ir(outlined_ir, "outlined_ir after replace np.empty")
1685
+ else:
1686
+ raise NotImplementedError("Unsupported OpenMP device number")
1687
+
1688
+ device_target.state_copy = state_copy
1689
+ # Do not compile (generate native code), just lower (to LLVM)
1690
+ flags.no_compile = True
1691
+ flags.no_cpython_wrapper = True
1692
+ flags.no_cfunc_wrapper = True
1693
+ # TODO: decide whether forceinline should always be set here.
1694
+ flags.forceinline = True
1695
+ # Propagate fastmath flag on the outer function to the inner outlined compile.
1696
+ # TODO: find a good way to handle fastmath. Clang has
1697
+ # fp-contractions on by default for GPU code.
1698
+ # flags.fastmath = True  # state_copy.flags.fastmath
1699
+ flags.release_gil = True
1700
+ flags.inline = "always"
1701
+ # Create a pipeline that only lowers the outlined target code; there is no
1702
+ # need to recompile since the IR has already gone through those passes.
1703
+ if DEBUG_OPENMP >= 1:
1704
+ print(
1705
+ "outlined_ir:",
1706
+ outlined_ir,
1707
+ type(outlined_ir),
1708
+ outlined_ir.arg_names,
1709
+ )
1710
+ dprint_func_ir(outlined_ir, "outlined_ir")
1711
+ dprint_func_ir(func_ir, "target after outline func_ir")
1712
+ dprint_func_ir(lowerer.func_ir, "original func_ir")
1713
+ print("state_copy.typemap:", state_copy.typemap)
1714
+ print("region ids before compile_ir")
1715
+ sep = "==================================================================================="
+ for _ in range(7):
+     print(sep)
1736
+
1737
+ cres = compiler.compile_ir(
1738
+ typingctx_outlined,
1739
+ device_target,
1740
+ outlined_ir,
1741
+ outline_arg_typs,
1742
+ types.none,
1743
+ flags,
1744
+ {},
1745
+ pipeline_class=device_lowerer_pipeline,
1746
+ is_lifted_loop=False,
1747
+ ) # is_lifted_loop=False: True was tried (this code derives from the loop-lifting code) but it sends the IR through the pipeline twice and corrupts state.
1748
+
1749
+ if DEBUG_OPENMP >= 2:
1750
+ print("cres:", type(cres))
1751
+ print("fndesc:", cres.fndesc, cres.fndesc.mangled_name)
1752
+ print("metadata:", cres.metadata)
1753
+ cres_library = cres.library
1754
+ if DEBUG_OPENMP >= 2:
1755
+ print("cres_library:", type(cres_library))
1756
+ sys.stdout.flush()
1757
+ cres_library._ensure_finalized()
1758
+ if DEBUG_OPENMP >= 2:
1759
+ print("ensure_finalized:")
1760
+ sys.stdout.flush()
1761
+
1762
+ if DEBUG_OPENMP >= 1:
1763
+ print("region ids compile_ir")
1764
+ sep = "==================================================================================="
+ for _ in range(7):
+     print(sep)
1785
+
1786
+ for k, v in lowerer.func_ir.blocks.items():
1787
+ print(
1788
+ "block post copy:",
1789
+ k,
1790
+ id(v),
1791
+ id(func_ir.blocks[k]),
1792
+ id(v.body),
1793
+ id(func_ir.blocks[k].body),
1794
+ )
1795
+
1796
+ # TODO: move device pipelines into numba proper.
1797
+ if selected_device == 1:
1798
+ if DEBUG_OPENMP >= 1:
1799
+ with open(cres_library.name + ".ll", "w") as f:
1800
+ f.write(cres_library.get_llvm_str())
1801
+
1802
+ fd_o, filename_o = tempfile.mkstemp(".o")
1803
+ os.close(fd_o)
1804
+ filename_so = Path(filename_o).with_suffix(".so")
1805
+
1806
+ target_elf = cres_library.emit_native_object()
1807
+ with open(filename_o, "wb") as f:
1808
+ f.write(target_elf)
1809
+
1810
+ # Create a shared library, as required by the libomptarget host
1811
+ # plugin.
1812
+
1813
+ link_shared_library(obj_path=filename_o, out_path=filename_so)
1814
+
1815
+ with open(filename_so, "rb") as f:
1816
+ target_elf = f.read()
1817
+ if DEBUG_OPENMP >= 1:
1818
+ print("filename_o", filename_o, "filename_so", filename_so)
1819
+
1820
+ # Remove the temporary files.
1821
+ os.remove(filename_o)
1822
+ os.remove(filename_so)
1823
+
1824
+ if DEBUG_OPENMP >= 1:
1825
+ print("target_elf:", type(target_elf), len(target_elf))
1826
+ sys.stdout.flush()
1827
+ elif selected_device == 0:
1828
+ target_extension._active_context.target = orig_target
1829
+ omp_cuda_cg = get_omp_cuda_codegen()
1830
+ target_elf = omp_cuda_cg.get_target_image(cres, ompx_attrs)
1831
+ else:
1832
+ raise NotImplementedError("Unsupported OpenMP device number")
1833
+
1834
+ # if cuda then run ptxas on the cres and pass that
1835
+
1836
+ # bytes_array_typ = lir.ArrayType(cgutils.voidptr_t, len(target_elf))
1837
+ # bytes_array_typ = lir.ArrayType(cgutils.int8_t, len(target_elf))
1838
+ # dev_image = cgutils.add_global_variable(mod, bytes_array_typ, ".omp_offloading.device_image")
1839
+ # dev_image.initializer = lir.Constant.array(cgutils.int8_t, target_elf)
1841
+ add_target_globals_in_numba = int(
1842
+ os.environ.get("NUMBA_OPENMP_ADD_TARGET_GLOBALS", 0)
1843
+ )
1844
+ if add_target_globals_in_numba != 0:
1845
+ elftext = cgutils.make_bytearray(target_elf)
1846
+ dev_image = targetctx.insert_unique_const(
1847
+ mod, ".omp_offloading.device_image", elftext
1848
+ )
1849
+ mangled_name = cgutils.make_bytearray(
1850
+ cres.fndesc.mangled_name.encode("utf-8") + b"\x00"
1851
+ )
1852
+ mangled_var = targetctx.insert_unique_const(
1853
+ mod, ".omp_offloading.entry_name", mangled_name
1854
+ )
1855
+
1856
+ llvmused_typ = lir.ArrayType(cgutils.voidptr_t, 2)
1857
+ llvmused_gv = cgutils.add_global_variable(
1858
+ mod, llvmused_typ, "llvm.used"
1859
+ )
1860
+ llvmused_syms = [
1861
+ lir.Constant.bitcast(dev_image, cgutils.voidptr_t),
1862
+ lir.Constant.bitcast(mangled_var, cgutils.voidptr_t),
1863
+ ]
1864
+ llvmused_gv.initializer = lir.Constant.array(
1865
+ cgutils.voidptr_t, llvmused_syms
1866
+ )
1867
+ llvmused_gv.linkage = "appending"
1868
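+ # llvm.used requires appending linkage; listing the image and entry-name
+ # globals in it keeps the optimizer and linker from discarding them even
+ # though nothing in the module references them directly.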
+ else:
1869
+ host_side_target_tags.append(
1870
+ openmp_tag(
1871
+ "QUAL.OMP.TARGET.DEV_FUNC",
1872
+ StringLiteral(cres.fndesc.mangled_name.encode("utf-8")),
1873
+ )
1874
+ )
1875
+ host_side_target_tags.append(
1876
+ openmp_tag("QUAL.OMP.TARGET.ELF", StringLiteral(target_elf))
1877
+ )
1878
+
1879
+ if DEBUG_OPENMP >= 1:
1880
+ dprint_func_ir(func_ir, "target after outline compiled func_ir")
1881
+
1882
+ llvm_token_t = TokenType()
1883
+ fnty = lir.FunctionType(llvm_token_t, [])
1884
+ tags_to_include = self.tags + host_side_target_tags
1885
+ # tags_to_include = list(filter(lambda x: x.name != "DIR.OMP.TARGET", tags_to_include))
1886
+ self.filtered_tag_length = len(tags_to_include)
1887
+ if DEBUG_OPENMP >= 1:
1888
+ print("filtered_tag_length:", self.filtered_tag_length)
1889
+
1890
+ if len(tags_to_include) > 0:
1891
+ if DEBUG_OPENMP >= 1:
1892
+ print("push_alloca_callbacks")
1893
+
1894
+ push_alloca_callback(lowerer, openmp_region_alloca, self, builder)
1895
+ tag_str = openmp_tag_list_to_str(tags_to_include, lowerer, True)
1896
+ pre_fn = builder.module.declare_intrinsic(
1897
+ "llvm.directive.region.entry", (), fnty
1898
+ )
1899
+ assert self.omp_region_var is None
1900
+ self.omp_region_var = builder.call(pre_fn, [], tail=False)
1901
+ self.omp_region_var.__class__ = CallInstrWithOperandBundle
1902
+ self.omp_region_var.set_tags(tag_str)
1903
+ # This is used by the post-lowering pass over the LLVM IR to add LLVM
1904
+ # alloca vars to the Numba IR openmp node; when the exit of the region
1905
+ # is detected, the tags in the entry directive are updated.
1906
+ self.omp_region_var.save_orig_numba_openmp = self
1907
+ if DEBUG_OPENMP >= 2:
1908
+ print("setting omp_region_var", self.omp_region_var._get_name())
1909
+ if self.acq_res:
1910
+ builder.fence("acquire")
1911
+ if self.acq_rel:
1912
+ builder.fence("acq_rel")
1913
+
1914
+ for otag in self.tags: # should be tags_to_include?
1915
+ otag.post_entry(lowerer)
1916
+
1917
+ if DEBUG_OPENMP >= 1:
1918
+ sys.stdout.flush()
1919
+
1920
+ def __str__(self):
1921
+ return (
1922
+ "openmp_region_start "
1923
+ + ", ".join([str(x) for x in self.tags])
1924
+ + " target="
1925
+ + str(self.target_copy)
1926
+ )
1927
+
1928
+
1929
+ class openmp_region_end(ir.Stmt):
1930
+ def __init__(self, start_region, tags, loc):
1931
+ if DEBUG_OPENMP >= 1:
1932
+ print("region ids openmp_region_end::__init__", id(self), id(start_region))
1933
+ self.start_region = start_region
1934
+ self.tags = tags
1935
+ self.loc = loc
1936
+ self.start_region.end_region = self
1937
+
1938
+ def __new__(cls, *args, **kwargs):
1939
+ instance = super(openmp_region_end, cls).__new__(cls)
1940
+ # print("openmp_region_end::__new__", id(instance))
1941
+ return instance
1942
+
1943
+ def list_vars(self):
1944
+ return list_vars_from_tags(self.tags)
1945
+
1946
+ def lower(self, lowerer):
1947
+ builder = lowerer.builder
1948
+
1949
+ if DEBUG_OPENMP >= 2:
1950
+ print("openmp_region_end::lower", id(self), id(self.start_region))
1951
+ sys.stdout.flush()
1952
+
1953
+ if self.start_region.acq_res:
1954
+ builder.fence("release")
1955
+
1956
+ if DEBUG_OPENMP >= 1:
1957
+ print("pop_alloca_callbacks")
1958
+
1959
+ if DEBUG_OPENMP >= 2:
1960
+ print("start_region tag length:", self.start_region.filtered_tag_length)
1961
+
1962
+ if self.start_region.filtered_tag_length > 0:
1963
+ llvm_token_t = TokenType()
1964
+ fnty = lir.FunctionType(lir.VoidType(), [llvm_token_t])
1965
+ # The callback is only needed if llvm.directive.region.entry was added
1966
+ # which only happens when the tag length is > 0.
1967
+ pop_alloca_callback(lowerer, builder)
1968
+
1969
+ # Process the accumulated allocas in the start region.
1970
+ self.start_region.process_alloca_queue()
1971
+
1972
+ assert self.start_region.omp_region_var is not None
1973
+ if DEBUG_OPENMP >= 2:
1974
+ print(
1975
+ "before adding exit", self.start_region.omp_region_var._get_name()
1976
+ )
1977
+
1978
+ for fp in filter(
1979
+ lambda x: x.name == "QUAL.OMP.FIRSTPRIVATE", self.start_region.tags
1980
+ ):
1981
+ new_del = ir.Del(fp.arg, self.loc)
1982
+ lowerer.lower_inst(new_del)
1983
+
1984
+ pre_fn = builder.module.declare_intrinsic(
1985
+ "llvm.directive.region.exit", (), fnty
1986
+ )
1987
+ or_end_call = builder.call(
1988
+ pre_fn, [self.start_region.omp_region_var], tail=True
1989
+ )
1990
+ or_end_call.__class__ = CallInstrWithOperandBundle
1991
+ or_end_call.set_tags(openmp_tag_list_to_str(self.tags, lowerer, True))
1992
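+ # The exit intrinsic consumes the token produced by the matching
+ # llvm.directive.region.entry call, pairing the region's start and end
+ # in the LLVM IR for the OpenMP lowering pass.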
+
1993
+ if DEBUG_OPENMP >= 1:
1994
+ print(
1995
+ "OpenMP end lowering firstprivate_dead_after len:",
1996
+ len(self.start_region.firstprivate_dead_after),
1997
+ )
1998
+
1999
+ for fp in self.start_region.firstprivate_dead_after:
2000
+ new_del = ir.Del(fp.arg, self.loc)
2001
+ lowerer.lower_inst(new_del)
2002
+
2003
+ def __str__(self):
2004
+ return "openmp_region_end " + ", ".join([str(x) for x in self.tags])
2005
+
2006
+ def has_target(self):
2007
+ for t in self.tags:
2008
+ if is_target_tag(t.name):
2009
+ return t.arg
2010
+ return None
2011
+
2012
+
2013
+ # Callback for ir_extension_usedefs
2014
+ def openmp_region_start_defs(region, use_set=None, def_set=None):
2015
+ assert isinstance(region, openmp_region_start)
2016
+ if use_set is None:
2017
+ use_set = set()
2018
+ if def_set is None:
2019
+ def_set = set()
2020
+ for tag in region.tags:
2021
+ tag.add_to_usedef_set(use_set, def_set, start=True)
2022
+ return _use_defs_result(usemap=use_set, defmap=def_set)
2023
+
2024
+
2025
+ def openmp_region_end_defs(region, use_set=None, def_set=None):
2026
+ assert isinstance(region, openmp_region_end)
2027
+ if use_set is None:
2028
+ use_set = set()
2029
+ if def_set is None:
2030
+ def_set = set()
2031
+ # We refer to the clauses from the corresponding start of the region.
2032
+ start_region = region.start_region
2033
+ for tag in start_region.tags:
2034
+ tag.add_to_usedef_set(use_set, def_set, start=False)
2035
+ return _use_defs_result(usemap=use_set, defmap=def_set)
2036
+
2037
+
2038
+ # Extend usedef analysis to support openmp_region_start/end nodes.
2039
+ ir_extension_usedefs[openmp_region_start] = openmp_region_start_defs
2040
+ ir_extension_usedefs[openmp_region_end] = openmp_region_end_defs
2041
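+ # With these hooks registered, Numba's use-def/liveness analysis sees the
+ # variables referenced by OpenMP clause tags instead of treating them as
+ # dead across the region.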
+
2042
+
2043
+ def openmp_region_start_infer(prs, typeinferer):
2044
+ pass
2045
+
2046
+
2047
+ def openmp_region_end_infer(pre, typeinferer):
2048
+ pass
2049
+
2050
+
2051
+ typeinfer.typeinfer_extensions[openmp_region_start] = openmp_region_start_infer
2052
+ typeinfer.typeinfer_extensions[openmp_region_end] = openmp_region_end_infer
2053
+
2054
+
2055
+ class default_shared_val:
2056
+ def __init__(self, val):
2057
+ self.val = val
2058
+
2059
+
2060
+ def _lower_openmp_region_start(lowerer, prs):
2061
+ # TODO: if we always set this in numba_fixups, it can be removed here.
+ if not isinstance(
+     lowerer.context, (OpenmpCPUTargetContext, OpenmpCUDATargetContext)
+ ):
+     lowerer.library.__class__ = CustomCPUCodeLibrary
+     lowerer.context.__class__ = CustomContext
2069
+ prs.lower(lowerer)
2070
+
2071
+
2072
+ def _lower_openmp_region_end(lowerer, pre):
2073
+ # TODO: if we always set this in numba_fixups, it can be removed here.
+ if not isinstance(
+     lowerer.context, (OpenmpCPUTargetContext, OpenmpCUDATargetContext)
+ ):
+     lowerer.library.__class__ = CustomCPUCodeLibrary
+     lowerer.context.__class__ = CustomContext
2081
+ pre.lower(lowerer)
2082
+
2083
+
2084
+ def apply_copies_openmp_region(
2085
+ region, var_dict, name_var_table, typemap, calltypes, save_copies
2086
+ ):
2087
+ for i in range(len(region.tags)):
2088
+ region.tags[i].replace_vars_inner(var_dict)
2089
+
2090
+
2091
+ apply_copy_propagate_extensions[openmp_region_start] = apply_copies_openmp_region
2092
+ apply_copy_propagate_extensions[openmp_region_end] = apply_copies_openmp_region
2093
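+ # Likewise, copy propagation rewrites the variable references inside the
+ # clause tags so they stay consistent with the surrounding IR.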
+
2094
+
2095
+ def visit_vars_openmp_region(region, callback, cbdata):
2096
+ for i in range(len(region.tags)):
2097
+ if DEBUG_OPENMP >= 1:
2098
+ print("visit_vars before", region.tags[i], type(region.tags[i].arg))
2099
+ region.tags[i].arg = visit_vars_inner(region.tags[i].arg, callback, cbdata)
2100
+ if DEBUG_OPENMP >= 1:
2101
+ print("visit_vars after", region.tags[i])
2102
+
2103
+
2104
+ visit_vars_extensions[openmp_region_start] = visit_vars_openmp_region
2105
+ visit_vars_extensions[openmp_region_end] = visit_vars_openmp_region
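+ # For orientation, a user-level region that exercises this machinery might
+ # look like the sketch below. This is a hypothetical example, not code from
+ # this module; the import path and decorator are assumptions based on the
+ # package layout (see numba/openmp/tests/test_openmp.py for real usage):
+ #
+ #     from numba.openmp import njit, openmp_context as openmp
+ #
+ #     @njit
+ #     def axpy(a, x, y):
+ #         with openmp("target teams distribute parallel for"):
+ #             for i in range(len(x)):
+ #                 y[i] += a * x[i]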