numba-cuda 0.18.1__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic.

Files changed (88)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +2 -2
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +1 -1
  5. numba_cuda/numba/cuda/api.py +2 -7
  6. numba_cuda/numba/cuda/compiler.py +7 -4
  7. numba_cuda/numba/cuda/core/interpreter.py +3592 -0
  8. numba_cuda/numba/cuda/core/ir_utils.py +2645 -0
  9. numba_cuda/numba/cuda/core/sigutils.py +55 -0
  10. numba_cuda/numba/cuda/cuda_paths.py +9 -17
  11. numba_cuda/numba/cuda/cudadecl.py +1 -1
  12. numba_cuda/numba/cuda/cudadrv/driver.py +4 -19
  13. numba_cuda/numba/cuda/cudadrv/libs.py +1 -2
  14. numba_cuda/numba/cuda/cudadrv/nvrtc.py +44 -44
  15. numba_cuda/numba/cuda/cudadrv/nvvm.py +3 -18
  16. numba_cuda/numba/cuda/cudadrv/runtime.py +12 -1
  17. numba_cuda/numba/cuda/cudamath.py +1 -1
  18. numba_cuda/numba/cuda/decorators.py +4 -3
  19. numba_cuda/numba/cuda/deviceufunc.py +2 -1
  20. numba_cuda/numba/cuda/dispatcher.py +3 -2
  21. numba_cuda/numba/cuda/extending.py +1 -1
  22. numba_cuda/numba/cuda/itanium_mangler.py +211 -0
  23. numba_cuda/numba/cuda/libdevicedecl.py +1 -1
  24. numba_cuda/numba/cuda/libdevicefuncs.py +1 -1
  25. numba_cuda/numba/cuda/lowering.py +1 -1
  26. numba_cuda/numba/cuda/simulator/api.py +1 -1
  27. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +0 -7
  28. numba_cuda/numba/cuda/target.py +1 -2
  29. numba_cuda/numba/cuda/testing.py +4 -6
  30. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +80 -0
  31. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +1 -1
  32. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  33. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +1 -1
  34. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  35. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +1 -1
  36. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +1 -1
  37. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +1 -1
  38. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +4 -6
  39. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +0 -4
  40. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  41. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +1 -3
  42. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +1 -3
  43. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +146 -3
  44. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +1 -1
  45. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +0 -4
  46. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +1 -1
  47. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +1 -1
  48. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  49. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +1 -284
  50. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +473 -0
  51. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +1 -1
  52. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  53. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -6
  54. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +1 -1
  55. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +1 -1
  56. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +295 -0
  57. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +1 -1
  58. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  59. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +1 -1
  60. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +5 -1
  61. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +1 -1
  62. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +1 -1
  63. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +1 -1
  64. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +1 -1
  65. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +1 -1
  66. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +1 -1
  67. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +1 -1
  68. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +1 -1
  69. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +1 -1
  70. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +1 -1
  71. numba_cuda/numba/cuda/tests/nocuda/test_import.py +1 -1
  72. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -2
  73. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +1 -1
  74. numba_cuda/numba/cuda/tests/support.py +752 -0
  75. numba_cuda/numba/cuda/tests/test_binary_generation/Makefile +3 -3
  76. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +4 -1
  77. numba_cuda/numba/cuda/typing/__init__.py +8 -0
  78. numba_cuda/numba/cuda/typing/templates.py +1453 -0
  79. numba_cuda/numba/cuda/vector_types.py +3 -3
  80. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/METADATA +21 -28
  81. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/RECORD +84 -79
  82. numba_cuda/numba/cuda/include/11/cuda_bf16.h +0 -3749
  83. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +0 -2683
  84. numba_cuda/numba/cuda/include/11/cuda_fp16.h +0 -3794
  85. numba_cuda/numba/cuda/include/11/cuda_fp16.hpp +0 -2614
  86. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/WHEEL +0 -0
  87. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/licenses/LICENSE +0 -0
  88. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2645 @@
1
+ #
2
+ # Copyright (c) 2017 Intel Corporation
3
+ # SPDX-License-Identifier: BSD-2-Clause
4
+ #
5
+
6
+ import numpy
7
+ import math
8
+
9
+ import types as pytypes
10
+ import collections
11
+ import warnings
12
+
13
+ import numba
14
+ from numba.core.extending import _Intrinsic
15
+ from numba.core import types, typing, ir, analysis, postproc, rewrites, config
16
+ from numba.core.typing.templates import signature
17
+ from numba.core.analysis import (
18
+ compute_live_map,
19
+ compute_use_defs,
20
+ compute_cfg_from_blocks,
21
+ )
22
+ from numba.core.errors import (
23
+ TypingError,
24
+ UnsupportedError,
25
+ NumbaPendingDeprecationWarning,
26
+ CompilerError,
27
+ )
28
+
29
+ import copy
30
+
31
_unique_var_count = 0


def mk_unique_var(prefix):
    """Return ``prefix`` followed by ``.N``, where N is a module-wide counter.

    The counter is bumped after every successful call, so repeated calls in
    the same process never hand out the same name twice.
    """
    global _unique_var_count
    suffix = "." + str(_unique_var_count)
    name = prefix + suffix
    _unique_var_count += 1
    return name
39
+
40
+
41
class _MaxLabel:
    """Monotonically increasing block-label counter."""

    def __init__(self, value=0):
        self._value = value

    def next(self):
        """Advance the counter by one and return the new value."""
        self._value = self._value + 1
        return self._value

    def update(self, newval):
        """Raise the counter to ``newval`` if that is larger; never lower it."""
        self._value = max(newval, self._value)


# Single shared instance; the class itself is removed from the module namespace
# so only this instance is used.
_the_max_label = _MaxLabel()
del _MaxLabel
55
+
56
+
57
def get_unused_var_name(prefix, var_table):
    """Return the first ``prefix + str(i)`` (i = 0, 1, 2, ...) not in ``var_table``.

    ``var_table`` only needs to support the ``in`` operator (dict, set, ...).
    """
    suffix = 0
    candidate = prefix + str(suffix)
    while candidate in var_table:
        suffix += 1
        candidate = prefix + str(suffix)
    return candidate
67
+
68
+
69
def next_label():
    """Return a fresh block label from the shared module-level label counter."""
    counter = _the_max_label
    return counter.next()
71
+
72
+
73
def mk_alloc(
    typingctx, typemap, calltypes, lhs, size_var, dtype, scope, loc, lhs_typ
):
    """Generate IR that allocates an array with ``np.empty`` into ``lhs``.

    ``size_var`` is an int/``ir.Var`` or a tuple of them. A 1-tuple is
    unwrapped; a longer tuple is packed into a ``build_tuple`` variable.
    ``lhs_typ`` is the type of the array being allocated. If it defines
    ``__allocate__``, allocation is delegated to it (after size handling)
    and its result is returned.

    For an "F" layout target, a C-ordered array is allocated and then passed
    through ``np.asfortranarray``.

    ``typemap`` / ``calltypes`` are only filled in when they are truthy.

    Returns the list of generated IR statements.
    """
    nodes = []
    size_typ = types.intp
    if isinstance(size_var, tuple):
        if len(size_var) == 1:
            size_var = convert_size_to_var(
                size_var[0], typemap, scope, loc, nodes
            )
        else:
            # tuple_var = build_tuple([size_var...])
            ndims = len(size_var)
            tuple_var = ir.Var(scope, mk_unique_var("$tuple_var"), loc)
            if typemap:
                typemap[tuple_var.name] = types.containers.UniTuple(
                    types.intp, ndims
                )
            # constant sizes need to be assigned to vars first
            dim_vars = [
                convert_size_to_var(dim, typemap, scope, loc, nodes)
                for dim in size_var
            ]
            nodes.append(
                ir.Assign(ir.Expr.build_tuple(dim_vars, loc), tuple_var, loc)
            )
            size_var = tuple_var
            size_typ = types.containers.UniTuple(types.intp, ndims)

    if hasattr(lhs_typ, "__allocate__"):
        # the target type supplies its own allocation sequence
        return lhs_typ.__allocate__(
            typingctx,
            typemap,
            calltypes,
            lhs,
            size_var,
            dtype,
            scope,
            loc,
            lhs_typ,
            size_typ,
            nodes,
        )

    # np_var = Global(numpy)
    np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
    if typemap:
        typemap[np_var.name] = types.misc.Module(numpy)
    np_assign = ir.Assign(ir.Global("np", numpy, loc), np_var, loc)

    # empty_var = getattr(np_var, empty)
    empty_getattr = ir.Expr.getattr(np_var, "empty", loc)
    empty_var = ir.Var(scope, mk_unique_var("$empty_attr_attr"), loc)
    if typemap:
        typemap[empty_var.name] = get_np_ufunc_typ(numpy.empty)
    empty_assign = ir.Assign(empty_getattr, empty_var, loc)

    # Assume str(dtype) names a valid numpy attribute
    dtype_name = str(dtype)
    dtype_var = ir.Var(scope, mk_unique_var("$np_typ_var"), loc)
    if typemap:
        typemap[dtype_var.name] = types.functions.NumberClass(dtype)
    has_unit = (
        isinstance(dtype, (types.NPDatetime, types.NPTimedelta))
        and dtype.unit != ""
    )
    if has_unit:
        # united datetime/timedelta dtypes are not numpy attributes; they
        # are passed as a string instead (e.g. "datetime64[ns]")
        dtype_value = ir.Const(dtype_name, loc)
    else:
        if dtype_name == "bool":
            # empty doesn't like 'bool' sometimes (e.g. kmeans example)
            dtype_name = "bool_"
        dtype_value = ir.Expr.getattr(np_var, dtype_name, loc)
    dtype_assign = ir.Assign(dtype_value, dtype_var, loc)

    # alloc call: empty_var(size_var, dtype_var)
    alloc_call = ir.Expr.call(empty_var, [size_var, dtype_var], (), loc)
    if calltypes:
        empty_sig = typemap[empty_var.name].get_call_type(
            typingctx, [size_typ, types.functions.NumberClass(dtype)], {}
        )
        # "empty" is typed as returning a plain ndarray; retarget the return
        # type to lhs_typ (or its C-ordered twin when Fortran is requested)
        if lhs_typ.layout == "F":
            empty_sig._return_type = lhs_typ.copy(layout="C")
        else:
            empty_sig._return_type = lhs_typ
        calltypes[alloc_call] = empty_sig

    if lhs_typ.layout != "F":
        nodes.extend(
            [
                np_assign,
                empty_assign,
                dtype_assign,
                ir.Assign(alloc_call, lhs, loc),
            ]
        )
        return nodes

    # Fortran layout: allocate C-ordered, then lhs = np.asfortranarray(...)
    c_typ = lhs_typ.copy(layout="C")
    c_var = ir.Var(scope, mk_unique_var("$empty_c_var"), loc)
    if typemap:
        typemap[c_var.name] = c_typ
    c_assign = ir.Assign(alloc_call, c_var, loc)

    afa_getattr = ir.Expr.getattr(np_var, "asfortranarray", loc)
    afa_var = ir.Var(scope, mk_unique_var("$asfortran_array_attr"), loc)
    if typemap:
        typemap[afa_var.name] = get_np_ufunc_typ(numpy.asfortranarray)
    afa_assign = ir.Assign(afa_getattr, afa_var, loc)

    afa_call = ir.Expr.call(afa_var, [c_var], (), loc)
    if calltypes:
        calltypes[afa_call] = typemap[afa_var.name].get_call_type(
            typingctx, [c_typ], {}
        )

    nodes.extend(
        [
            np_assign,
            empty_assign,
            dtype_assign,
            c_assign,
            afa_assign,
            ir.Assign(afa_call, lhs, loc),
        ]
    )
    return nodes
206
+
207
+
208
def convert_size_to_var(size_var, typemap, scope, loc, nodes):
    """Return an ``ir.Var`` holding the size ``size_var``.

    An ``ir.Var`` is returned unchanged. An ``int`` constant gets a fresh
    ``$alloc_size`` variable; its constant assignment is appended to
    ``nodes``, and it is typed as intp when ``typemap`` is truthy.
    """
    if not isinstance(size_var, int):
        assert isinstance(size_var, ir.Var)
        return size_var
    size_holder = ir.Var(scope, mk_unique_var("$alloc_size"), loc)
    if typemap:
        typemap[size_holder.name] = types.intp
    nodes.append(ir.Assign(ir.Const(size_var, loc), size_holder, loc))
    return size_holder
218
+
219
+
220
def get_np_ufunc_typ(func):
    """Return the typing template registered for ``func``.

    The NumPy typing registry is searched first, then the builtin registry.

    Raises:
        RuntimeError: if ``func`` is registered in neither registry.
    """
    for registry in (
        typing.npydecl.registry,
        typing.templates.builtin_registry,
    ):
        for key, typ in registry.globals:
            if key == func:
                return typ
    # Format a single message string. Passing several positional arguments
    # to RuntimeError renders the message as a tuple repr.
    raise RuntimeError("type for func {} not found".format(func))
229
+
230
+
231
def mk_range_block(typemap, start, stop, step, calltypes, scope, loc):
    """Build a block that sets up ``range(start, stop, step)`` iteration.

    The block calls ``range`` on the (materialised) bounds, takes its
    iterator, copies that into a fresh ``$phi`` variable and ends with a jump
    to the placeholder label -1. The caller must set the jump target.
    ``typemap`` and ``calltypes`` are updated in place.
    """
    # range_fn_var = Global(range)
    range_fn_var = ir.Var(scope, mk_unique_var("$range_g_var"), loc)
    typemap[range_fn_var.name] = get_global_func_typ(range)
    range_fn_assign = ir.Assign(
        ir.Global("range", range, loc), range_fn_var, loc
    )

    arg_nodes, range_args = _mk_range_args(
        typemap, start, stop, step, scope, loc
    )

    # range_var = range_fn_var(*range_args)
    range_expr = ir.Expr.call(range_fn_var, range_args, (), loc)
    arg_types = [types.intp] * len(range_args)
    calltypes[range_expr] = typemap[range_fn_var.name].get_call_type(
        typing.Context(), arg_types, {}
    )
    range_var = ir.Var(scope, mk_unique_var("$range_c_var"), loc)
    typemap[range_var.name] = types.iterators.RangeType(types.intp)
    range_assign = ir.Assign(range_expr, range_var, loc)

    # iter_var = getiter(range_var)
    getiter_expr = ir.Expr.getiter(range_var, loc)
    if config.USE_LEGACY_TYPE_SYSTEM:
        getiter_sig = signature(
            types.range_iter64_type, types.range_state64_type
        )
    else:
        getiter_sig = signature(types.range_iter_type, types.range_state_type)
    calltypes[getiter_expr] = getiter_sig
    iter_var = ir.Var(scope, mk_unique_var("$iter_var"), loc)
    typemap[iter_var.name] = types.iterators.RangeIteratorType(types.intp)
    iter_assign = ir.Assign(getiter_expr, iter_var, loc)

    # $phi = iter_var
    phi_var = ir.Var(scope, mk_unique_var("$phi"), loc)
    typemap[phi_var.name] = types.iterators.RangeIteratorType(types.intp)
    phi_assign = ir.Assign(iter_var, phi_var, loc)

    block = ir.Block(scope, loc)
    block.body = list(arg_nodes)
    block.body.extend(
        [
            range_fn_assign,
            range_assign,
            iter_assign,
            phi_assign,
            ir.Jump(-1, loc),  # header label is patched by the caller
        ]
    )
    return block
277
+
278
+
279
def _mk_range_args(typemap, start, stop, step, scope, loc):
    """Materialise the arguments for a ``range`` call.

    Returns ``(nodes, args)``. ``nodes`` holds the constant assignments
    needed for int bounds. ``args`` is ``[stop]``, ``[start, stop]`` or
    ``[start, stop, step]``, matching the shortest ``range`` form equivalent
    to the given bounds.
    """
    prelude = []

    def _as_var(value, prefix):
        # ir.Var bounds are used as-is; int constants get a fresh intp var
        if isinstance(value, ir.Var):
            return value
        assert isinstance(value, int)
        holder = ir.Var(scope, mk_unique_var(prefix), loc)
        if typemap:
            typemap[holder.name] = types.intp
        prelude.append(ir.Assign(ir.Const(value, loc), holder, loc))
        return holder

    stop_var = _as_var(stop, "$range_stop")
    if start == 0 and step == 1:
        return prelude, [stop_var]

    start_var = _as_var(start, "$range_start")
    if step == 1:
        return prelude, [start_var, stop_var]

    step_var = _as_var(step, "$range_step")
    return prelude, [start_var, stop_var, step_var]
316
+
317
+
318
def get_global_func_typ(func):
    """Return the typing template registered for ``func`` in the builtin registry.

    Raises:
        RuntimeError: if ``func`` has no entry in the builtin registry.
    """
    for key, typ in typing.templates.builtin_registry.globals:
        if key == func:
            break
    else:
        raise RuntimeError("func type not found {}".format(func))
    return typ
324
+
325
+
326
def mk_loop_header(typemap, phi_var, calltypes, scope, loc):
    """Build a range-loop header block.

    The block calls ``iternext`` on ``phi_var`` and splits the resulting pair
    into a value (intp) and a validity flag (boolean). It copies the value
    into a fresh ``$phi`` variable and branches on the flag. Both branch
    targets are -1 placeholders that the caller must set.
    ``typemap`` and ``calltypes`` are updated in place.
    """
    pair_typ = types.containers.Pair(types.intp, types.boolean)
    if config.USE_LEGACY_TYPE_SYSTEM:
        iter_typ = types.range_iter64_type
    else:
        iter_typ = types.range_iter_type

    # pair_var = iternext(phi_var)
    pair_var = ir.Var(scope, mk_unique_var("$iternext_var"), loc)
    typemap[pair_var.name] = pair_typ
    iternext_expr = ir.Expr.iternext(phi_var, loc)
    calltypes[iternext_expr] = signature(pair_typ, iter_typ)
    iternext_assign = ir.Assign(iternext_expr, pair_var, loc)

    # value_var = pair_first(pair_var)
    value_var = ir.Var(scope, mk_unique_var("$pair_first_var"), loc)
    typemap[value_var.name] = types.intp
    value_assign = ir.Assign(ir.Expr.pair_first(pair_var, loc), value_var, loc)

    # valid_var = pair_second(pair_var)
    valid_var = ir.Var(scope, mk_unique_var("$pair_second_var"), loc)
    typemap[valid_var.name] = types.boolean
    valid_assign = ir.Assign(
        ir.Expr.pair_second(pair_var, loc), valid_var, loc
    )

    # new_phi = value_var
    new_phi = ir.Var(scope, mk_unique_var("$phi"), loc)
    typemap[new_phi.name] = types.intp
    phi_assign = ir.Assign(value_var, new_phi, loc)

    header = ir.Block(scope, loc)
    header.body = [
        iternext_assign,
        value_assign,
        valid_assign,
        phi_assign,
        # branch valid_var body_block out_block (labels patched by caller)
        ir.Branch(valid_var, -1, -1, loc),
    ]
    return header
369
+
370
+
371
def legalize_names(varnames):
    """Return a dict mapping each IR variable name to a legal parameter name.

    Every "_" is doubled first, then "$" and "." become "_". So "a_b" maps to
    "a__b" and "$x.1" maps to "_x_1". A name repeated in ``varnames`` maps to
    the same legal name and is simply recorded once.

    Raises:
        AssertionError: if two distinct input names would produce the same
            legal name (e.g. "a$1" and "a.1" both map to "a_1").
    """
    var_map = {}
    # Legal names handed out so far. Uniqueness must be checked against these
    # VALUES, not against var_map's keys, which hold the original names.
    used = set()
    for var in varnames:
        if var in var_map:
            continue  # deterministic mapping; nothing new to record
        new_name = var.replace("_", "__").replace("$", "_").replace(".", "_")
        assert new_name not in used, "legalized name collision for {!r}".format(
            var
        )
        used.add(new_name)
        var_map[var] = new_name
    return var_map
381
+
382
+
383
def get_name_var_table(blocks):
    """Map every variable name used in ``blocks`` to an ``ir.Var`` with that name.

    If a name occurs several times, the last occurrence visited wins.
    """
    table = {}

    def _record(var, acc):
        acc[var.name] = var
        return var  # return unchanged so visit_vars leaves the IR intact

    visit_vars(blocks, _record, table)
    return table
393
+
394
+
395
def replace_var_names(blocks, namedict):
    """Rename variables in ``blocks`` in place using ``namedict`` (old -> new name).

    Renames are followed transitively (a -> b -> c yields c). Identity
    entries are dropped first, because they would otherwise loop forever.
    """
    renames = {old: new for old, new in namedict.items() if old != new}

    def _rename(var, mapping):
        assert isinstance(var, ir.Var)
        while var.name in mapping:
            var = ir.Var(var.scope, mapping[var.name], var.loc)
        return var

    visit_vars(blocks, _rename, renames)
410
+
411
+
412
def replace_var_callback(var, vardict):
    """``visit_vars`` callback that follows ``vardict`` (name -> ir.Var) chains.

    Returns a new ``ir.Var`` copied from the last replacement in the chain,
    or ``var`` itself if its name is not a key of ``vardict``.
    """
    assert isinstance(var, ir.Var)
    while var.name in vardict:
        replacement = vardict[var.name]
        # a self-mapping would make this loop spin forever
        assert replacement.name != var.name
        var = ir.Var(replacement.scope, replacement.name, replacement.loc)
    return var
419
+
420
+
421
def replace_vars(blocks, vardict):
    """Replace variables in ``blocks`` in place using ``vardict`` (name -> ir.Var)."""
    # identity entries (a name mapped to a Var of the same name) would make
    # replace_var_callback loop forever, so they are filtered out
    effective = {
        name: var for name, var in vardict.items() if name != var.name
    }
    visit_vars(blocks, replace_var_callback, effective)
429
+
430
+
431
def replace_vars_stmt(stmt, vardict):
    """Apply ``vardict`` (name -> ir.Var) replacements to one statement in place.

    Unlike ``replace_vars``, identity entries are not filtered out here.
    """
    visit_vars_stmt(stmt, callback=replace_var_callback, cbdata=vardict)
433
+
434
+
435
def replace_vars_inner(node, vardict):
    """Apply ``vardict`` (name -> ir.Var) replacements inside ``node``.

    Returns the visited node (see ``visit_vars_inner``).
    """
    return visit_vars_inner(
        node, callback=replace_var_callback, cbdata=vardict
    )
437
+
438
+
439
# Packages that define new IR statement types register a variable visitor here,
# as {statement_type: function(stmt, callback, cbdata)}.
visit_vars_extensions = {}


def visit_vars(blocks, callback, cbdata):
    """Apply ``callback(var, cbdata)`` to the variables of every statement in ``blocks``.

    Each statement is updated in place with whatever the callback returns.
    """
    all_stmts = (stmt for blk in blocks.values() for stmt in blk.body)
    for stmt in all_stmts:
        visit_vars_stmt(stmt, callback, cbdata)
452
+
453
+
454
def visit_vars_stmt(stmt, callback, cbdata):
    """Apply ``callback(var, cbdata)`` to the variables of a single statement.

    Each variable-bearing attribute of ``stmt`` is replaced in place with the
    result of ``visit_vars_inner``. Statement types registered in
    ``visit_vars_extensions`` are delegated to their handler. Other
    statement types are left untouched.
    """
    # let external calls handle stmt if type matches
    for t, f in visit_vars_extensions.items():
        if isinstance(stmt, t):
            f(stmt, callback, cbdata)
            return
    if isinstance(stmt, ir.Assign):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.Arg):
        stmt.name = visit_vars_inner(stmt.name, callback, cbdata)
    elif isinstance(stmt, ir.Return):
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.Raise):
        stmt.exception = visit_vars_inner(stmt.exception, callback, cbdata)
    elif isinstance(stmt, ir.Branch):
        stmt.cond = visit_vars_inner(stmt.cond, callback, cbdata)
    elif isinstance(stmt, ir.Jump):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
    elif isinstance(stmt, ir.Del):
        # Because Del takes only a var name, we make up by
        # constructing a temporary variable.
        var = ir.Var(None, stmt.value, stmt.loc)
        var = visit_vars_inner(var, callback, cbdata)
        stmt.value = var.name
    elif isinstance(stmt, ir.DelAttr):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.attr = visit_vars_inner(stmt.attr, callback, cbdata)
    elif isinstance(stmt, ir.SetAttr):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.attr = visit_vars_inner(stmt.attr, callback, cbdata)
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.DelItem):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.index = visit_vars_inner(stmt.index, callback, cbdata)
    elif isinstance(stmt, ir.StaticSetItem):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.index_var = visit_vars_inner(stmt.index_var, callback, cbdata)
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.SetItem):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.index = visit_vars_inner(stmt.index, callback, cbdata)
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.Print):
        stmt.args = [visit_vars_inner(x, callback, cbdata) for x in stmt.args]
        # The star-argument of print(*x) lives in ``vararg`` (None when
        # absent). It must be visited too; otherwise renames leave it
        # pointing at a stale variable. visit_vars_inner passes None through.
        stmt.vararg = visit_vars_inner(stmt.vararg, callback, cbdata)
    # TODO: other statement types are not visited; consider raising
    # NotImplementedError("no replacement for IR node: ", stmt)
503
+
504
+
505
def visit_vars_inner(node, callback, cbdata):
    """Recursively apply ``callback(var, cbdata)`` to the variables inside ``node``.

    - ``ir.Var``: returns the callback's result.
    - ``list`` / ``tuple``: returns a new list / tuple of visited items.
    - ``ir.Expr``: its operand dict ``_kws`` is updated in place.
    - ``ir.Yield``: its ``value`` is updated in place.

    Any other node (constants, strings, None, ...) is returned unchanged.
    """
    if isinstance(node, ir.Var):
        return callback(node, cbdata)
    if isinstance(node, (list, tuple)):
        visited = [visit_vars_inner(item, callback, cbdata) for item in node]
        return visited if isinstance(node, list) else tuple(visited)
    if isinstance(node, ir.Expr):
        operands = node._kws
        for key in operands:
            operands[key] = visit_vars_inner(operands[key], callback, cbdata)
    elif isinstance(node, ir.Yield):
        node.value = visit_vars_inner(node.value, callback, cbdata)
    return node
523
+
524
+
525
# Packages with custom IR nodes that embed block labels register a relabeling
# handler here, as {node_type: function(inst, offset)}.
add_offset_to_labels_extensions = {}


def add_offset_to_labels(blocks, offset):
    """Shift every block label and jump/branch target by ``offset``.

    Returns a new dict keyed by the shifted labels. The block objects are
    reused: their terminators are replaced and extension nodes are updated
    in place, so the blocks reachable from the input mapping change as well.
    """
    shifted = {}
    for label, block in blocks.items():
        # some parfor last blocks might be empty, so there may be no terminator
        term = block.body[-1] if block.body else None
        for inst in block.body:
            for node_type, handler in add_offset_to_labels_extensions.items():
                if isinstance(inst, node_type):
                    handler(inst, offset)
        if isinstance(term, ir.Jump):
            block.body[-1] = ir.Jump(term.target + offset, term.loc)
        elif isinstance(term, ir.Branch):
            block.body[-1] = ir.Branch(
                term.cond,
                term.truebr + offset,
                term.falsebr + offset,
                term.loc,
            )
        shifted[label + offset] = block
    return shifted
548
+
549
+
550
# Custom IR nodes that reference labels internally register
# {node_type: function(inst) -> label} here so those labels are counted.
find_max_label_extensions = {}


def find_max_label(blocks):
    """Return the largest label used by ``blocks``, never less than 0.

    Both the block keys and the labels reported by
    ``find_max_label_extensions`` handlers for instructions inside the blocks
    are considered.
    """
    best = 0
    for label, block in blocks.items():
        for inst in block.body or ():
            for node_type, label_of in find_max_label_extensions.items():
                if isinstance(inst, node_type):
                    best = max(best, label_of(inst))
        best = max(best, label)
    return best
566
+
567
+
568
def flatten_labels(blocks):
    """Relabel ``blocks`` as 0, 1, 2, ... in topological order.

    This is useful for comparing CFGs. Returns a new dict. The block objects
    are reused and their jump/branch terminators are rewritten.
    """
    # first move every label above the current maximum so that the
    # 0..N-1 relabeling below cannot collide with an existing label
    shifted = add_offset_to_labels(blocks, find_max_label(blocks) + 1)
    order = find_topo_order(shifted)
    new_label = {old: idx for idx, old in enumerate(order)}

    relabeled = {}
    for old in order:
        block = shifted[old]
        # some parfor last blocks might be empty
        term = block.body[-1] if block.body else None
        if isinstance(term, ir.Jump):
            block.body[-1] = ir.Jump(new_label[term.target], term.loc)
        elif isinstance(term, ir.Branch):
            block.body[-1] = ir.Branch(
                term.cond,
                new_label[term.truebr],
                new_label[term.falsebr],
                term.loc,
            )
        relabeled[new_label[old]] = block
    return relabeled
595
+
596
+
597
def remove_dels(blocks):
    """Strip every ``ir.Del`` statement from ``blocks`` in place."""
    for block in blocks.values():
        block.body = [
            stmt for stmt in block.body if not isinstance(stmt, ir.Del)
        ]
606
+
607
+
608
def remove_args(blocks):
    """Drop assignments whose value is an ``ir.Arg`` from ``blocks`` in place."""

    def _is_arg_assign(stmt):
        return isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Arg)

    for block in blocks.values():
        block.body = [stmt for stmt in block.body if not _is_arg_assign(stmt)]
618
+
619
+
620
def dead_code_elimination(
    func_ir, typemap=None, alias_map=None, arg_aliases=None
):
    """Run ``remove_dead`` on ``func_ir`` repeatedly until nothing more is removed.

    If anything was removed, ``postproc.PostProcessor`` is run afterwards so
    the IR is left in a valid state.
    """
    changed = False
    while True:
        removed_any = remove_dead(
            func_ir.blocks,
            func_ir.arg_names,
            func_ir,
            typemap,
            alias_map,
            arg_aliases,
        )
        if not removed_any:
            break
        changed = True

    if changed:
        postproc.PostProcessor(func_ir).run()
640
+
641
+
642
def remove_dead(
    blocks, args, func_ir, typemap=None, alias_map=None, arg_aliases=None
):
    """Dead-code elimination over ``blocks`` using liveness and CFG info.

    If either ``alias_map`` or ``arg_aliases`` is None, both are recomputed
    with ``find_potential_aliases``. Each block is then pruned by
    ``remove_dead_block``. The variables live at the end of a block are its
    terminator's variables plus everything live into its CFG successors.

    Returns True if something was removed, False otherwise.
    """
    cfg = compute_cfg_from_blocks(blocks)
    usedefs = compute_use_defs(blocks)
    live_map = compute_live_map(cfg, blocks, usedefs.usemap, usedefs.defmap)
    call_table, _ = get_call_table(blocks)
    if alias_map is None or arg_aliases is None:
        alias_map, arg_aliases = find_potential_aliases(
            blocks, args, typemap, func_ir
        )
    if config.DEBUG_ARRAY_OPT >= 1:
        for title, value in (
            ("args:", args),
            ("alias map:", alias_map),
            ("arg_aliases:", arg_aliases),
            ("live_map:", live_map),
            ("usemap:", usedefs.usemap),
            ("defmap:", usedefs.defmap),
        ):
            print(title, value)
    # key set kept separately for fast intersection with live sets
    alias_set = set(alias_map)

    any_removed = False
    for label, block in blocks.items():
        # variables used by the terminator are live at the end of the block
        lives = {var.name for var in block.terminator.list_vars()}
        if config.DEBUG_ARRAY_OPT >= 2:
            print("remove_dead processing block", label, lives)
        # ... as is everything live on entry to a successor
        for succ, _data in cfg.successors(label):
            if config.DEBUG_ARRAY_OPT >= 2:
                print("succ live_map", succ, live_map[succ])
            lives |= live_map[succ]
        any_removed |= remove_dead_block(
            block,
            lives,
            call_table,
            arg_aliases,
            alias_map,
            alias_set,
            func_ir,
            typemap,
        )

    return any_removed
689
+
690
+
691
# Packages that define new IR statement types can register a dead-code handler
# here, as {statement_type: handler}. The handler is called as
# handler(stmt, lives, lives_n_aliases, arg_aliases, alias_map, func_ir,
# typemap) and returns the statement to keep (possibly rewritten), or None to
# drop it.
remove_dead_extensions = {}


def remove_dead_block(
    block,
    lives,
    call_table,
    arg_aliases,
    alias_map,
    alias_set,
    func_ir,
    typemap,
):
    """Remove dead statements from ``block`` in place using liveness info.

    The body is walked backwards from the terminator, which is always kept.
    ``lives`` is the set of variable names live at the end of the block; it
    is updated in place as statements are processed. Aliases of live
    variables, plus every name in ``arg_aliases``, count as live when
    deciding whether a write may be dropped. Mutable arguments (e.g. arrays)
    that are not definitely assigned are therefore live after the function
    returns.

    Returns True if at least one statement was removed.
    """
    # TODO: find mutable args that are not definitely assigned instead of
    # assuming all args are live after return
    removed = False

    # statements are collected in reverse order and flipped at the end
    new_body = [block.terminator]
    # for each statement in reverse order, excluding terminator
    for stmt in reversed(block.body[:-1]):
        if config.DEBUG_ARRAY_OPT >= 2:
            print("remove_dead_block", stmt)
        # aliases of lives are also live
        alias_lives = set()
        init_alias_lives = lives & alias_set
        for v in init_alias_lives:
            alias_lives |= alias_map[v]
        lives_n_aliases = lives | alias_lives | arg_aliases

        # let external calls handle stmt if type matches
        if type(stmt) in remove_dead_extensions:
            f = remove_dead_extensions[type(stmt)]
            stmt = f(
                stmt,
                lives,
                lives_n_aliases,
                arg_aliases,
                alias_map,
                func_ir,
                typemap,
            )
            if stmt is None:
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                removed = True
                continue

        # drop assignments whose lhs is not live (when the rhs has no side
        # effects) and self-assignments lhs = lhs
        if isinstance(stmt, ir.Assign):
            lhs = stmt.target
            rhs = stmt.value
            if lhs.name not in lives and has_no_side_effect(
                rhs, lives_n_aliases, call_table
            ):
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                removed = True
                continue
            if isinstance(rhs, ir.Var) and lhs.name == rhs.name:
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                removed = True
                continue
            # TODO: remove other nodes like SetItem etc.

        # ir.Del names a variable directly (stmt.value is the name)
        if isinstance(stmt, ir.Del):
            if stmt.value not in lives:
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                removed = True
                continue

        # a store into a container that is neither live nor aliased to
        # something live (or to an argument) is unobservable
        if isinstance(stmt, ir.SetItem):
            name = stmt.target.name
            if name not in lives_n_aliases:
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                # This removal used to go unreported, so callers saw "no
                # change" even though the block body shrank.
                removed = True
                continue

        # update liveness: uses become live, definitions stop being live
        if type(stmt) in analysis.ir_extension_usedefs:
            def_func = analysis.ir_extension_usedefs[type(stmt)]
            uses, defs = def_func(stmt)
            lives -= defs
            lives |= uses
        else:
            lives |= {v.name for v in stmt.list_vars()}
            if isinstance(stmt, ir.Assign):
                # make sure lhs is not used in rhs, e.g. a = g(a)
                if isinstance(stmt.value, ir.Expr):
                    rhs_vars = {v.name for v in stmt.value.list_vars()}
                    if lhs.name not in rhs_vars:
                        lives.remove(lhs.name)
                else:
                    lives.remove(lhs.name)

        new_body.append(stmt)
    new_body.reverse()
    block.body = new_body
    return removed
797
+
798
+
799
# Extra predicates consulted by has_no_side_effect() for calls it does not
# recognise itself; each is invoked as handler(rhs, lives, call_list) and a
# truthy result marks the call as removable.
remove_call_handlers = []


def remove_dead_random_call(rhs, lives, call_list):
    """Allow dead ``numpy.random.<func>`` calls to be removed.

    Every ``numpy.random`` function except ``seed`` and ``shuffle`` (which
    mutate state) is treated as side-effect free.
    """
    targets_np_random = len(call_list) == 3 and call_list[1:] == ["random", numpy]
    if not targets_np_random:
        return False
    return call_list[0] not in {"seed", "shuffle"}


remove_call_handlers.append(remove_dead_random_call)
810
+
811
+
812
def has_no_side_effect(rhs, lives, call_table):
    """Returns True if this expression has no side effects that
    would prevent re-ordering.

    ``lives`` is the set of live variable names.  ``call_table`` maps a
    call's function-variable name to its resolved call chain (as produced
    by ``get_call_table``).  Calls that cannot be resolved, yields and
    ``pair_first`` are treated as side-effecting.  An ``inplace_binop`` is
    side-effect free only when its left operand is dead.
    """
    from numba.parfors import array_analysis, parfor
    from numba.misc.special import prange

    is_expr = isinstance(rhs, ir.Expr)
    if is_expr and rhs.op == "call":
        callee = rhs.func.name
        if callee not in call_table or call_table[callee] == []:
            return False
        call_list = call_table[callee]
        known_pure_calls = (
            ["empty", numpy],
            [slice],
            ["stencil", numba],
            ["log", numpy],
            ["dtype", numpy],
            [array_analysis.wrap_index],
            [prange],
            ["prange", numba],
            ["pndindex", numba],
            [parfor.internal_prange],
            ["ceil", math],
            [max],
            [int],
        )
        if any(call_list == known for known in known_pure_calls):
            return True
        head = call_list[0]
        if isinstance(head, _Intrinsic) and head._name in (
            "empty_inferred",
            "unsafe_empty_inferred",
        ):
            return True
        from numba.core.registry import CPUDispatcher
        from numba.np.linalg import dot_3_mv_check_args

        if isinstance(head, CPUDispatcher) and head.py_func == dot_3_mv_check_args:
            return True
        # Fall back to externally registered predicates.
        return any(
            handler(rhs, lives, call_list) for handler in remove_call_handlers
        )
    if is_expr and rhs.op == "inplace_binop":
        # The in-place update only matters if its target is still live.
        return rhs.lhs.name not in lives
    if isinstance(rhs, ir.Yield):
        return False
    if is_expr and rhs.op == "pair_first":
        # don't remove pair_first since prange looks for it
        return False
    return True
864
+
865
+
866
# Extra predicates consulted by is_pure() for unrecognised calls; each is
# invoked as ext(rhs, lives, call_list) and a truthy result means "pure".
is_pure_extensions = []


def is_pure(rhs, lives, call_table):
    """Returns True if every time this expression is evaluated it
    returns the same result. This is not the case for things
    like calls to numpy.random.

    Unresolvable calls, ``getiter``/``iternext`` expressions and yields are
    reported as impure; all other non-call values are reported as pure.
    """
    if isinstance(rhs, ir.Expr):
        if rhs.op == "call":
            callee = rhs.func.name
            if callee not in call_table or call_table[callee] == []:
                return False
            call_list = call_table[callee]
            known_pure_calls = (
                [slice],
                ["log", numpy],
                ["empty", numpy],
                ["ceil", math],
                [max],
                [int],
            )
            if any(call_list == known for known in known_pure_calls):
                return True
            return any(ext(rhs, lives, call_list) for ext in is_pure_extensions)
        if rhs.op in ("getiter", "iternext"):
            # Iterator protocol advances state on every evaluation.
            return False
    return not isinstance(rhs, ir.Yield)
898
+
899
+
900
def is_const_call(module_name, func_name):
    """Return True if ``module_name.func_name`` changes no state in that module.

    Only ``numpy.empty`` is currently recognised.
    """
    if module_name != "numpy":
        return False
    return func_name in ("empty",)
906
+
907
+
908
# Hooks used by find_potential_aliases(): per-statement-type analysers and
# per-callee (name, module) alias functions.
alias_analysis_extensions = {}
alias_func_extensions = {}


def get_canonical_alias(v, alias_map):
    """Return the representative name for ``v``'s alias set.

    The representative is the smallest name in ``alias_map[v]``.  A name
    with no entry in ``alias_map`` is its own representative.
    """
    if v in alias_map:
        return sorted(alias_map[v])[0]
    return v
918
+
919
+
920
def find_potential_aliases(
    blocks, args, typemap, func_ir, alias_map=None, arg_aliases=None
):
    """Find all array aliases and argument aliases to avoid removing them as dead.

    Parameters
    ----------
    blocks : dict
        Label -> ir.Block mapping whose statements are scanned.
    args : iterable of str
        Argument names; those not provably immutable seed ``arg_aliases``
        when it is not supplied.
    typemap : dict or None
        Variable name -> type, consulted via ``is_immutable_type``.
    func_ir :
        Function IR.  Side effect: its ``_definitions`` table is rebuilt
        from ``func_ir.blocks``.
    alias_map : dict, optional
        Existing name -> set-of-alias-names mapping to extend in place.
    arg_aliases : set, optional
        Existing set of names that alias function arguments, extended in place.

    Returns
    -------
    tuple
        ``(alias_map, arg_aliases)``.
    """
    if alias_map is None:
        alias_map = {}
    if arg_aliases is None:
        arg_aliases = set(a for a in args if not is_immutable_type(a, typemap))

    # update definitions since they are not guaranteed to be up-to-date
    # FIXME keep definitions up-to-date to avoid the need for rebuilding
    func_ir._definitions = build_definitions(func_ir.blocks)
    # numpy functions whose result may share memory with their first input
    np_alias_funcs = ["ravel", "transpose", "reshape"]

    for bl in blocks.values():
        for instr in bl.body:
            # extension node types run their own alias analysis
            if type(instr) in alias_analysis_extensions:
                f = alias_analysis_extensions[type(instr)]
                f(instr, args, typemap, func_ir, alias_map, arg_aliases)
            if isinstance(instr, ir.Assign):
                expr = instr.value
                lhs = instr.target.name
                # only mutable types can alias
                if is_immutable_type(lhs, typemap):
                    continue
                # plain copy: a = b
                if isinstance(expr, ir.Var) and lhs != expr.name:
                    _add_alias(lhs, expr.name, alias_map, arg_aliases)
                # subarrays like A = B[0] for 2D B
                if isinstance(expr, ir.Expr) and (
                    expr.op == "cast"
                    or expr.op in ["getitem", "static_getitem"]
                ):
                    _add_alias(lhs, expr.value.name, alias_map, arg_aliases)
                # a += b may rebind a to the (mutated) left operand
                if isinstance(expr, ir.Expr) and expr.op == "inplace_binop":
                    _add_alias(lhs, expr.lhs.name, alias_map, arg_aliases)
                # array attributes like A.T
                if (
                    isinstance(expr, ir.Expr)
                    and expr.op == "getattr"
                    and expr.attr in ["T", "ctypes", "flat"]
                ):
                    _add_alias(lhs, expr.value.name, alias_map, arg_aliases)
                # a = b.c. a should alias b
                if (
                    isinstance(expr, ir.Expr)
                    and expr.op == "getattr"
                    and expr.attr not in ["shape"]
                    and expr.value.name in arg_aliases
                ):
                    _add_alias(lhs, expr.value.name, alias_map, arg_aliases)
                # calls that can create aliases such as B = A.ravel()
                if isinstance(expr, ir.Expr) and expr.op == "call":
                    fdef = guard(find_callname, func_ir, expr, typemap)
                    # TODO: sometimes gufunc backend creates duplicate code
                    # causing find_callname to fail. Example: test_argmax
                    # ignored here since those cases don't create aliases
                    # but should be fixed in general
                    if fdef is None:
                        continue
                    fname, fmod = fdef
                    if fdef in alias_func_extensions:
                        alias_func = alias_func_extensions[fdef]
                        alias_func(lhs, expr.args, alias_map, arg_aliases)
                    # np.ravel(A) style call
                    if fmod == "numpy" and fname in np_alias_funcs:
                        _add_alias(
                            lhs, expr.args[0].name, alias_map, arg_aliases
                        )
                    # A.ravel() style call: the module slot holds the array var
                    if isinstance(fmod, ir.Var) and fname in np_alias_funcs:
                        _add_alias(lhs, fmod.name, alias_map, arg_aliases)

    # copy to avoid changing size during iteration
    old_alias_map = copy.deepcopy(alias_map)
    # combine all aliases transitively
    # NOTE(review): this is a single pass; after it, names merged together
    # share the same set object (alias_map[w] = alias_map[v]), so later
    # in-place updates to one entry are visible through the others.
    for v in old_alias_map:
        for w in old_alias_map[v]:
            alias_map[v] |= alias_map[w]
        for w in old_alias_map[v]:
            alias_map[w] = alias_map[v]

    return alias_map, arg_aliases
1000
+
1001
+
1002
+ def _add_alias(lhs, rhs, alias_map, arg_aliases):
1003
+ if rhs in arg_aliases:
1004
+ arg_aliases.add(lhs)
1005
+ else:
1006
+ if rhs not in alias_map:
1007
+ alias_map[rhs] = set()
1008
+ if lhs not in alias_map:
1009
+ alias_map[lhs] = set()
1010
+ alias_map[rhs].add(lhs)
1011
+ alias_map[lhs].add(rhs)
1012
+ return
1013
+
1014
+
1015
def is_immutable_type(var, typemap):
    """Return True only when ``var``'s recorded type is known to be immutable.

    Numbers, NumPy datetime/timedelta scalars, range types and strings count
    as immutable.  Anything else, including a missing type, is conservatively
    treated as mutable.
    """
    if typemap is None or var not in typemap:
        return False
    typ = typemap[var]
    # TODO: add more immutable types
    immutable_classes = (
        types.Number,
        types.scalars._NPDatetimeBase,
        types.iterators.RangeType,
    )
    if isinstance(typ, immutable_classes) or typ == types.string:
        return True
    return False
1034
+
1035
+
1036
def copy_propagate(blocks, typemap):
    """Compute copy-propagation information for each block.

    Uses fixed-point iteration on the data-flow equations::

        in_b  = intersect(out_p for each predecessor p of b)
        out_b = gen_b | (in_b - kill_b)

    Parameters
    ----------
    blocks : dict
        Mapping of block label -> ``ir.Block``.
    typemap : dict
        Variable name -> type, used to decide which assignments are copies.

    Returns
    -------
    tuple
        ``(in_copies, out_copies)``: dicts mapping each label to the set of
        ``(lhs, rhs)`` name pairs available on entry to / exit from the block.
    """
    cfg = compute_cfg_from_blocks(blocks)
    entry = cfg.entry_point()

    # format: dict of block labels to copies as tuples
    # label -> (l,r)
    gen_copies, _all_copies, kill_copies, in_copies, out_copies = (
        init_copy_propagate_data(blocks, entry, typemap)
    )

    old_point = None
    new_point = copy.deepcopy(out_copies)
    # comparison works since dictionary of built-in types
    while old_point != new_point:
        for label in blocks.keys():
            if label == entry:
                continue
            predecs = [p for p, _data in cfg.predecessors(label)]
            if predecs:
                # in_b = intersect(predec(B))
                in_b = out_copies[predecs[0]].copy()
                for p in predecs[1:]:
                    in_b &= out_copies[p]
            else:
                # A non-entry block without predecessors is unreachable, so
                # no copies flow into it (predecs[0] would raise IndexError).
                in_b = set()
            in_copies[label] = in_b

            # out_b = gen_b | (in_b - kill_b)
            out_copies[label] = gen_copies[label] | (in_b - kill_copies[label])
        old_point = new_point
        new_point = copy.deepcopy(out_copies)
    if config.DEBUG_ARRAY_OPT >= 1:
        print("copy propagate out_copies:", out_copies)
    return in_copies, out_copies
1072
+
1073
+
1074
def init_copy_propagate_data(blocks, entry, typemap):
    """Get the initial condition of the copy-propagation data flow.

    Parameters
    ----------
    blocks : dict
        Mapping of block label -> ``ir.Block``.
    entry :
        Label of the CFG entry block.
    typemap : dict
        Variable name -> type, forwarded to ``get_block_copies``.

    Returns
    -------
    tuple
        ``(gen_copies, all_copies, kill_copies, in_copies, out_copies)``.
        The first entry and the last three are dicts keyed by block label;
        ``all_copies`` is the union of every block's generated copies.
    """
    # gen is all definite copies, extra_kill is additional ones that may hit
    # for example, parfors can have control flow so they may hit extra copies
    gen_copies, extra_kill = get_block_copies(blocks, typemap)
    # set of all program copies
    all_copies = set()
    for block_gen in gen_copies.values():
        all_copies |= block_gen
    kill_copies = {}
    for label, gen_set in gen_copies.items():
        kill_copies[label] = set()
        label_extra_kill = extra_kill[label]
        # Names assigned by a copy in this block.  This depends only on the
        # block, so compute it once rather than once per program copy.
        assigned = {lhs for lhs, _rhs in gen_set}
        for lhs, rhs in all_copies:
            if lhs in label_extra_kill or rhs in label_extra_kill:
                kill_copies[label].add((lhs, rhs))
            # a copy is killed if it is not in this block and lhs or rhs are
            # assigned in this block
            if (lhs, rhs) not in gen_set and (
                lhs in assigned or rhs in assigned
            ):
                kill_copies[label].add((lhs, rhs))
    # set initial values
    # all copies are in for all blocks except entry
    in_copies = {l: all_copies.copy() for l in blocks.keys()}
    in_copies[entry] = set()
    out_copies = {}
    for label in blocks.keys():
        # out_b = gen_b | (in_b - kill_b)
        out_copies[label] = gen_copies[label] | (
            in_copies[label] - kill_copies[label]
        )
    # Nothing flows into the entry block, so its out set is just what it
    # generates (this is the same set object as gen_copies[entry]).
    out_copies[entry] = gen_copies[entry]
    return (gen_copies, all_copies, kill_copies, in_copies, out_copies)
1108
+
1109
+
1110
# other packages that define new nodes add calls to get copies in them
# format: {type:function}
# Each function is called as f(stmt, typemap) and returns (gen_set, kill_set).
copy_propagate_extensions = {}


def get_block_copies(blocks, typemap):
    """Get the copies generated and the names killed by each block.

    Parameters
    ----------
    blocks : dict
        Mapping of block label -> ``ir.Block``.
    typemap : dict
        Variable name -> type.  A ``lhs = rhs`` assignment counts as a copy
        only if both sides have the same type.

    Returns
    -------
    tuple
        ``(block_copies, extra_kill)``.  ``block_copies[label]`` is the set of
        ``(lhs, rhs)`` copies still valid at the end of the block.
        ``extra_kill[label]`` is the set of names assigned or otherwise
        invalidated in the block.
    """
    block_copies = {}
    extra_kill = {}
    for label, block in blocks.items():
        assign_dict = {}
        extra_kill[label] = set()
        # assignments as dict to replace with latest value
        for stmt in block.body:
            for T, f in copy_propagate_extensions.items():
                if isinstance(stmt, T):
                    gen_set, kill_set = f(stmt, typemap)
                    for lhs, rhs in gen_set:
                        assign_dict[lhs] = rhs
                    # if a=b is in dict and b is killed, a is also killed
                    new_assign_dict = {}
                    for l, r in assign_dict.items():
                        if l not in kill_set and r not in kill_set:
                            new_assign_dict[l] = r
                        if r in kill_set:
                            extra_kill[label].add(l)
                    assign_dict = new_assign_dict
                    extra_kill[label] |= kill_set
            if isinstance(stmt, ir.Assign):
                lhs = stmt.target.name
                if isinstance(stmt.value, ir.Var):
                    rhs = stmt.value.name
                    # copy is valid only if same type (see
                    # TestCFunc.test_locals)
                    # Some transformations can produce assignments of the
                    # form A = A. We don't put these mapping in the
                    # copy propagation set because then you get cycles and
                    # infinite loops in the replacement phase.
                    if typemap[lhs] == typemap[rhs] and lhs != rhs:
                        assign_dict[lhs] = rhs
                        # a recorded copy does not kill lhs
                        continue
                if (
                    isinstance(stmt.value, ir.Expr)
                    and stmt.value.op == "inplace_binop"
                ):
                    in1_var = stmt.value.lhs.name
                    in1_typ = typemap[in1_var]
                    # inplace_binop assigns first operand if mutable
                    if not (
                        isinstance(in1_typ, types.Number)
                        or in1_typ == types.string
                    ):
                        extra_kill[label].add(in1_var)
                        # if a=b is in dict and b is killed, a is also killed
                        new_assign_dict = {}
                        for l, r in assign_dict.items():
                            if l != in1_var and r != in1_var:
                                new_assign_dict[l] = r
                            if r == in1_var:
                                extra_kill[label].add(l)
                        assign_dict = new_assign_dict
                # any other assignment to lhs kills it
                extra_kill[label].add(lhs)
        block_cps = set(assign_dict.items())
        block_copies[label] = block_cps
    return block_copies, extra_kill
1175
+
1176
+
1177
# other packages that define new nodes add calls to apply copy propagate in them
# format: {type:function}
# Each function is called as
# f(stmt, var_dict, name_var_table, typemap, calltypes, save_copies).
apply_copy_propagate_extensions = {}


def apply_copy_propagate(
    blocks, in_copies, name_var_table, typemap, calltypes, save_copies=None
):
    """Apply copy propagation to the IR, replacing variables when copies are available.

    Parameters
    ----------
    blocks : dict
        Mapping of block label -> ``ir.Block``; statements are rewritten in place.
    in_copies : dict
        Label -> set of ``(lhs, rhs)`` copy name pairs available on block entry.
    name_var_table : dict
        Variable name -> ``ir.Var`` used as the replacement object.
    typemap : dict
        Variable name -> type.  Copies are only tracked between equal types.
    calltypes : dict
        Statement -> call signature; setitem signatures may be refreshed.
    save_copies : list, optional
        List to extend (a new one otherwise).

    Returns
    -------
    list
        ``save_copies`` extended with the ``(name, ir.Var)`` pairs still
        available at the end of each block.
    """
    # save_copies keeps an approximation of the copies that were applied, so
    # that the variable names of removed user variables can be recovered to some
    # extent.
    if save_copies is None:
        save_copies = []

    for label, block in blocks.items():
        # name -> ir.Var replacement map, seeded from the copies flowing in
        var_dict = {l: name_var_table[r] for l, r in in_copies[label]}
        # assignments as dict to replace with latest value
        for stmt in block.body:
            if type(stmt) in apply_copy_propagate_extensions:
                f = apply_copy_propagate_extensions[type(stmt)]
                f(
                    stmt,
                    var_dict,
                    name_var_table,
                    typemap,
                    calltypes,
                    save_copies,
                )
            # only rhs of assignments should be replaced
            # e.g. if x=y is available, x in x=z shouldn't be replaced
            elif isinstance(stmt, ir.Assign):
                stmt.value = replace_vars_inner(stmt.value, var_dict)
            else:
                replace_vars_stmt(stmt, var_dict)
            fix_setitem_type(stmt, typemap, calltypes)
            # extension nodes may generate new copies and kill existing ones
            for T, f in copy_propagate_extensions.items():
                if isinstance(stmt, T):
                    gen_set, kill_set = f(stmt, typemap)
                    for lhs, rhs in gen_set:
                        if rhs in name_var_table:
                            var_dict[lhs] = name_var_table[rhs]
                    for l, r in var_dict.copy().items():
                        if l in kill_set or r.name in kill_set:
                            var_dict.pop(l)
            if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Var):
                lhs = stmt.target.name
                rhs = stmt.value.name
                # rhs could be replaced with lhs from previous copies
                if lhs != rhs:
                    # copy is valid only if same type (see
                    # TestCFunc.test_locals)
                    if typemap[lhs] == typemap[rhs] and rhs in name_var_table:
                        var_dict[lhs] = name_var_table[rhs]
                    else:
                        var_dict.pop(lhs, None)
                    # a=b kills previous t=a
                    lhs_kill = []
                    for k, v in var_dict.items():
                        if v.name == lhs:
                            lhs_kill.append(k)
                    for k in lhs_kill:
                        var_dict.pop(k, None)
            if isinstance(stmt, ir.Assign) and not isinstance(
                stmt.value, ir.Var
            ):
                lhs = stmt.target.name
                var_dict.pop(lhs, None)
                # previous t=a is killed if a is killed
                lhs_kill = []
                for k, v in var_dict.items():
                    if v.name == lhs:
                        lhs_kill.append(k)
                for k in lhs_kill:
                    var_dict.pop(k, None)
        save_copies.extend(var_dict.items())

    return save_copies
1255
+
1256
+
1257
def fix_setitem_type(stmt, typemap, calltypes):
    """Tighten a setitem call signature after copy propagation.

    Copy propagation can replace a setitem target that had an 'A' layout
    array type with a variable whose array type is 'C' or 'F'.  When that
    happens, the first argument type in ``calltypes[stmt]`` is updated to the
    target's layout (from the matrix power test).  Other statements, and
    non-array types (e.g. Optional, see test_optional), are left untouched.
    """
    if isinstance(stmt, (ir.SetItem, ir.StaticSetItem)):
        target_typ = typemap[stmt.target.name]
        sig = calltypes[stmt]
        sig_arr_typ = sig.args[0]
        both_arrays = isinstance(
            sig_arr_typ, types.npytypes.Array
        ) and isinstance(target_typ, types.npytypes.Array)
        if both_arrays and sig_arr_typ.layout == "A" and target_typ.layout != "A":
            sig.args = (
                sig_arr_typ.copy(layout=target_typ.layout),
                sig.args[1],
                sig.args[2],
            )
1279
+
1280
+
1281
def dprint_func_ir(func_ir, title, blocks=None):
    """Debug print function IR, with an optional blocks argument
    that may differ from the IR's original blocks.

    Only prints when ``config.DEBUG_ARRAY_OPT >= 1``.  ``func_ir.blocks`` is
    temporarily replaced by ``blocks`` (if given) for the dump and always
    restored afterwards, even if dumping raises.
    """
    if config.DEBUG_ARRAY_OPT >= 1:
        ir_blocks = func_ir.blocks
        if blocks is not None:
            func_ir.blocks = blocks
        try:
            name = func_ir.func_id.func_qualname
            print(("IR %s: %s" % (title, name)).center(80, "-"))
            func_ir.dump()
            print("-" * 40)
        finally:
            # Never leave the IR pointing at the caller's temporary blocks.
            func_ir.blocks = ir_blocks
1293
+
1294
+
1295
def find_topo_order(blocks, cfg=None):
    """Find a topological order of blocks such that true branches are visited
    first (e.g. for_break test in test_dataflow).

    This is written as an iterative post-order traversal (reversed at the end)
    to avoid recursion limit issues.  Edges listed in ``cfg._back_edges`` are
    not followed, so loops do not block the ordering.

    Parameters
    ----------
    blocks : dict
        Mapping of block label -> ``ir.Block``; each block's last statement is
        inspected for an ``ir.Branch``.
    cfg : optional
        Precomputed CFG; built with ``compute_cfg_from_blocks`` when omitted.

    Returns
    -------
    list
        Block labels in reverse post order starting from the entry point.
        Blocks not reachable from the entry point do not appear.
    """
    if cfg is None:
        cfg = compute_cfg_from_blocks(blocks)

    post_order = []
    # Has the node already added its children?
    seen = set()
    # Has the node already been pushed to post order?
    visited = set()
    stack = [cfg.entry_point()]

    while len(stack) > 0:
        node = stack[-1]
        if node not in visited and node not in seen:
            # We haven't added a node or its children.
            seen.add(node)
            succs = cfg._succs[node]
            last_inst = blocks[node].body[-1]
            if isinstance(last_inst, ir.Branch):
                # Explicit order: truebr is pushed before falsebr, so falsebr
                # is explored (and finished) first; after the final reverse
                # this puts the true branch ahead of the false one.
                succs = [last_inst.truebr, last_inst.falsebr]
            for dest in succs:
                if (node, dest) not in cfg._back_edges:
                    if dest not in seen:
                        stack.append(dest)
        else:
            # This node has already added its children. We either need
            # to visit the node or it has been added multiple times in
            # which case we should just skip the node.
            node = stack.pop()
            if node not in visited:
                post_order.append(node)
                visited.add(node)
            if node in seen:
                # Remove the node from seen if it exists to limit the memory
                # usage to 1 entry per node. Otherwise the memory requirement
                # can double the recursive version.
                seen.remove(node)

    post_order.reverse()
    return post_order
1339
+
1340
+
1341
# other packages that define new nodes add calls to get call table
# format: {type:function}
# Each function is called as f(inst, call_table, reverse_call_table).
call_table_extensions = {}


def get_call_table(
    blocks, call_table=None, reverse_call_table=None, topological_ordering=True
):
    """Returns a dictionary of call variables and their references.

    Statements are walked backwards (blocks in reversed order, statements in
    reversed body order).  A call is therefore seen before the getattr /
    Global / FreeVar / Var assignments that build its callee, and each of
    those appends one link to the callee's chain.

    Parameters
    ----------
    blocks : dict
        Mapping of block label -> ``ir.Block``.
    call_table : dict, optional
        Existing table to extend: call function-variable name -> list of
        chain elements (attribute names, global/free-var values, var names).
    reverse_call_table : dict, optional
        Existing map from an intermediate variable name to the call variable
        whose chain it feeds.
    topological_ordering : bool
        If True, blocks are ordered with ``find_topo_order``; otherwise dict
        order of ``blocks`` is used.

    Returns
    -------
    tuple
        ``(call_table, reverse_call_table)``.
    """
    # call_table example: c = np.zeros becomes c:["zeros", np]
    # reverse_call_table example: c = np.zeros becomes np_var:c
    if call_table is None:
        call_table = {}
    if reverse_call_table is None:
        reverse_call_table = {}

    if topological_ordering:
        order = find_topo_order(blocks)
    else:
        order = list(blocks.keys())

    for label in reversed(order):
        for inst in reversed(blocks[label].body):
            if isinstance(inst, ir.Assign):
                lhs = inst.target.name
                rhs = inst.value
                # start a fresh chain for the called function variable
                if isinstance(rhs, ir.Expr) and rhs.op == "call":
                    call_table[rhs.func.name] = []
                if isinstance(rhs, ir.Expr) and rhs.op == "getattr":
                    if lhs in call_table:
                        call_table[lhs].append(rhs.attr)
                        reverse_call_table[rhs.value.name] = lhs
                    if lhs in reverse_call_table:
                        call_var = reverse_call_table[lhs]
                        call_table[call_var].append(rhs.attr)
                        reverse_call_table[rhs.value.name] = call_var
                if isinstance(rhs, ir.Global):
                    if lhs in call_table:
                        call_table[lhs].append(rhs.value)
                    if lhs in reverse_call_table:
                        call_var = reverse_call_table[lhs]
                        call_table[call_var].append(rhs.value)
                if isinstance(rhs, ir.FreeVar):
                    if lhs in call_table:
                        call_table[lhs].append(rhs.value)
                    if lhs in reverse_call_table:
                        call_var = reverse_call_table[lhs]
                        call_table[call_var].append(rhs.value)
                if isinstance(rhs, ir.Var):
                    if lhs in call_table:
                        call_table[lhs].append(rhs.name)
                        reverse_call_table[rhs.name] = lhs
                    if lhs in reverse_call_table:
                        call_var = reverse_call_table[lhs]
                        call_table[call_var].append(rhs.name)
            for T, f in call_table_extensions.items():
                if isinstance(inst, T):
                    f(inst, call_table, reverse_call_table)
    return call_table, reverse_call_table
1400
+
1401
+
1402
# other packages that define new nodes add calls to get tuple table
# format: {type:function}
# Each function is called as f(inst, tuple_table).
tuple_table_extensions = {}


def get_tuple_table(blocks, tuple_table=None):
    """Returns a dictionary of tuple variables and their values.

    Records ``build_tuple`` expressions (mapped to their item vars) and
    constant tuples (mapped to the tuple value).  ``tuple_table``, if given,
    is extended in place and returned.
    """
    table = {} if tuple_table is None else tuple_table

    for block in blocks.values():
        for inst in block.body:
            if isinstance(inst, ir.Assign):
                target_name = inst.target.name
                value = inst.value
                if isinstance(value, ir.Expr) and value.op == "build_tuple":
                    table[target_name] = value.items
                if isinstance(value, ir.Const) and isinstance(value.value, tuple):
                    table[target_name] = value.value
            for node_type, handler in tuple_table_extensions.items():
                if isinstance(inst, node_type):
                    handler(inst, table)
    return table
1425
+
1426
+
1427
def get_stmt_writes(stmt):
    """Return the set of variable names written by ``stmt``.

    Assignments and (static) setitems write their target; every other
    statement yields an empty set.
    """
    if isinstance(stmt, (ir.Assign, ir.SetItem, ir.StaticSetItem)):
        return {stmt.target.name}
    return set()
1432
+
1433
+
1434
def rename_labels(blocks):
    """Rename labels of function body blocks according to topological sort.

    The set of labels of these blocks will remain unchanged: the sorted label
    values are reassigned in topological order.  A block ending in
    ``ir.Return`` is placed last for readability.  Jump/Branch terminators are
    replaced by new nodes pointing at the renamed targets.

    Returns
    -------
    dict
        A new label -> block mapping (the block objects are reused).
    """
    topo_order = find_topo_order(blocks)

    # make a block with return last if available (just for readability)
    # None (rather than a sentinel label value like -1) means "not found", so
    # a block that happens to be labelled -1 is still handled correctly.
    return_label = None
    for label, block in blocks.items():
        if isinstance(block.body[-1], ir.Return):
            return_label = label
    # some cases like generators can have no return blocks
    if return_label is not None:
        topo_order.remove(return_label)
        topo_order.append(return_label)

    # i-th block in topological order receives the i-th smallest label
    label_map = dict(zip(topo_order, sorted(topo_order)))
    # update target labels in jumps/branches
    for b in blocks.values():
        term = b.terminator
        # create new IR nodes instead of mutating the existing one as copies of
        # the IR may also refer to the same nodes!
        if isinstance(term, ir.Jump):
            b.body[-1] = ir.Jump(label_map[term.target], term.loc)
        if isinstance(term, ir.Branch):
            b.body[-1] = ir.Branch(
                term.cond,
                label_map[term.truebr],
                label_map[term.falsebr],
                term.loc,
            )

    # update blocks dictionary keys
    new_blocks = {}
    for k, b in blocks.items():
        new_blocks[label_map[k]] = b

    return new_blocks
1476
+
1477
+
1478
def simplify_CFG(blocks):
    """Transform chains of blocks that have no loop into a single block.

    Blocks consisting only of an ``ir.Branch`` are first inlined into every
    predecessor that reaches them through an ``ir.Jump``.  Such a block is
    deleted once all of its predecessors were rewritten.  Adjacent blocks
    are then merged and labels renamed; the renamed mapping is returned.
    """
    cfg = compute_cfg_from_blocks(blocks)

    branch_only_labels = [
        label
        for label, blk in blocks.items()
        if len(blk.body) == 1 and isinstance(blk.body[0], ir.Branch)
    ]
    to_delete = set()
    for label in branch_only_labels:
        branch = blocks[label].body[0]
        fully_inlined = True
        for pred_label, _edge_data in cfg.predecessors(label):
            pred_body = blocks[pred_label].body
            if isinstance(pred_body[-1], ir.Jump):
                # the predecessor now branches directly
                pred_body[-1] = copy.copy(branch)
            else:
                fully_inlined = False
        if fully_inlined:
            to_delete.add(label)

    for label in to_delete:
        del blocks[label]
    merge_adjacent_blocks(blocks)
    return rename_labels(blocks)
1506
+
1507
+
1508
# Array method names that canonicalize_array_math rewrites from the method
# form ``A.func(...)`` into the NumPy function form ``np.func(A, ...)``
# (the array is prepended to the call's arguments).
arr_math = [
    "min",
    "max",
    "sum",
    "prod",
    "mean",
    "var",
    "std",
    "cumsum",
    "cumprod",
    "argmax",
    "argmin",
    "argsort",
    "nonzero",
    "ravel",
]
1524
+
1525
+
1526
def canonicalize_array_math(func_ir, typemap, calltypes, typingctx):
    """Rewrite array-method calls such as ``A.sum()`` into ``np.sum(A)``.

    For every ``getattr`` of a name in ``arr_math`` on an array-typed value,
    the object is replaced by a fresh ``numpy`` global variable.  The
    matching call then gets the array prepended to its arguments.
    ``typemap``, ``calltypes`` and ``func_ir._definitions`` are updated in
    place, and each block's body is rebuilt.  Returns None.
    """
    # save array arg to call
    # call_varname -> array
    blocks = func_ir.blocks
    saved_arr_arg = {}
    # topological order so the getattr is seen before the call that uses it
    topo_order = find_topo_order(blocks)
    for label in topo_order:
        block = blocks[label]
        new_body = []
        for stmt in block.body:
            if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr):
                lhs = stmt.target.name
                rhs = stmt.value
                # replace A.func with np.func, and save A in saved_arr_arg
                if (
                    rhs.op == "getattr"
                    and rhs.attr in arr_math
                    and isinstance(
                        typemap[rhs.value.name], types.npytypes.Array
                    )
                ):
                    rhs = stmt.value
                    arr = rhs.value
                    saved_arr_arg[lhs] = arr
                    scope = arr.scope
                    loc = arr.loc
                    # g_np_var = Global(numpy)
                    g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
                    typemap[g_np_var.name] = types.misc.Module(numpy)
                    g_np = ir.Global("np", numpy, loc)
                    g_np_assign = ir.Assign(g_np, g_np_var, loc)
                    rhs.value = g_np_var
                    # the numpy global must be defined before this getattr
                    new_body.append(g_np_assign)
                    func_ir._definitions[g_np_var.name] = [g_np]
                    # update func var type
                    func = getattr(numpy, rhs.attr)
                    func_typ = get_np_ufunc_typ(func)
                    typemap.pop(lhs)
                    typemap[lhs] = func_typ
                if rhs.op == "call" and rhs.func.name in saved_arr_arg:
                    # add array as first arg
                    arr = saved_arr_arg[rhs.func.name]
                    # update call type signature to include array arg
                    old_sig = calltypes.pop(rhs)
                    # argsort requires kws for typing so sig.args can't be used
                    # reusing sig.args since some types become Const in sig
                    argtyps = old_sig.args[: len(rhs.args)]
                    kwtyps = {name: typemap[v.name] for name, v in rhs.kws}
                    calltypes[rhs] = typemap[rhs.func.name].get_call_type(
                        typingctx, [typemap[arr.name]] + list(argtyps), kwtyps
                    )
                    rhs.args = [arr] + rhs.args

            new_body.append(stmt)
        block.body = new_body
    return
1582
+
1583
+
1584
# format: {type:function}
# Each function is called as f(inst, accesses).
array_accesses_extensions = {}


def get_array_accesses(blocks, accesses=None):
    """Returns a set of arrays accessed and their indices.

    Each entry is ``(array_name, index)``.  ``index`` is the index variable's
    name, except for ``static_getitem`` with a non-slice constant index,
    where the constant itself is stored.  ``accesses``, if given, is
    extended in place and returned.
    """
    found = set() if accesses is None else accesses

    for block in blocks.values():
        for inst in block.body:
            if isinstance(inst, ir.SetItem):
                found.add((inst.target.name, inst.index.name))
            if isinstance(inst, ir.StaticSetItem):
                found.add((inst.target.name, inst.index_var.name))
            if isinstance(inst, ir.Assign):
                expr = inst.value
                if isinstance(expr, ir.Expr):
                    if expr.op == "getitem":
                        found.add((expr.value.name, expr.index.name))
                    elif expr.op == "static_getitem":
                        idx = expr.index
                        # slice is unhashable, so just keep the variable
                        if idx is None or is_slice_index(idx):
                            idx = expr.index_var.name
                        found.add((expr.value.name, idx))
            for node_type, handler in array_accesses_extensions.items():
                if isinstance(inst, node_type):
                    handler(inst, found)
    return found
1613
+
1614
+
1615
def is_slice_index(index):
    """Return True if ``index`` is a slice or a tuple containing a slice."""
    if isinstance(index, slice):
        return True
    return isinstance(index, tuple) and any(
        isinstance(item, slice) for item in index
    )
1624
+
1625
+
1626
def merge_adjacent_blocks(blocks):
    """Merge chains of blocks in place.

    A block is merged with its successor when it has exactly one successor,
    and that successor's only predecessor is the block itself.  The
    block's terminator is dropped, the successor's body is appended, and
    the successor is deleted from ``blocks``.  The CFG is computed once up
    front.  Returns None.
    """
    cfg = compute_cfg_from_blocks(blocks)
    # merge adjacent blocks
    removed = set()
    for label in list(blocks.keys()):
        if label in removed:
            continue
        block = blocks[label]
        succs = list(cfg.successors(label))
        # keep absorbing successors while the chain stays linear
        while True:
            if len(succs) != 1:
                break
            next_label = succs[0][0]
            if next_label in removed:
                break
            preds = list(cfg.predecessors(next_label))
            succs = list(cfg.successors(next_label))
            if len(preds) != 1 or preds[0][0] != label:
                break
            next_block = blocks[next_label]
            # XXX: commented out since scope objects are not consistent
            # throughout the compiler. for example, pieces of code are compiled
            # and inlined on the fly without proper scope merge.
            # if block.scope != next_block.scope:
            #     break
            # merge
            block.body.pop()  # remove Jump
            block.body += next_block.body
            del blocks[next_label]
            removed.add(next_label)
            # the merged block now plays next_label's role in the CFG
            label = next_label
1657
+
1658
+
1659
def restore_copy_var_names(blocks, save_copies, typemap):
    """
    restores variable names of user variables after applying copy propagation

    For each ``(user_name, var)`` pair in ``save_copies`` where ``user_name``
    is a user variable (no leading "$") and ``var`` is a generated temporary
    (leading "$"), the temporary is renamed to a unique name derived from
    the user name, and its typemap entry is moved.  Returns a dict mapping
    each new name to the original user name.
    """
    if not save_copies:
        return {}

    rename_dict = {}
    var_rename_map = {}
    for user_name, tmp_var in save_copies:
        if user_name.startswith("$"):
            continue
        # only rename generated temporaries, and each one at most once
        if not tmp_var.name.startswith("$") or tmp_var.name in rename_dict:
            continue
        fresh_name = mk_unique_var("${}".format(user_name))
        rename_dict[tmp_var.name] = fresh_name
        var_rename_map[fresh_name] = user_name
        typemap[fresh_name] = typemap.pop(tmp_var.name)

    replace_var_names(blocks, rename_dict)
    return var_rename_map
1685
+
1686
+
1687
def simplify(func_ir, typemap, calltypes, metadata):
    """Copy-propagate, restore user names, remove dead code and simplify the CFG.

    ``func_ir`` is modified in place (its ``blocks`` is replaced by the
    simplified CFG).  Renamings of temporaries back to user-derived names
    are merged into ``metadata["var_rename_map"]``.
    """
    # copies available on entry to each block
    in_cps = copy_propagate(func_ir.blocks, typemap)[0]
    # table mapping variable names to ir.Var objects to help replacement
    name_var_table = get_name_var_table(func_ir.blocks)
    save_copies = apply_copy_propagate(
        func_ir.blocks, in_cps, name_var_table, typemap, calltypes
    )
    restored_names = restore_copy_var_names(func_ir.blocks, save_copies, typemap)
    metadata.setdefault("var_rename_map", {}).update(restored_names)
    # remove dead code to enable fusion
    if config.DEBUG_ARRAY_OPT >= 1:
        dprint_func_ir(func_ir, "after copy prop")
    remove_dead(func_ir.blocks, func_ir.arg_names, func_ir, typemap)
    func_ir.blocks = simplify_CFG(func_ir.blocks)
    if config.DEBUG_ARRAY_OPT >= 1:
        dprint_func_ir(func_ir, "after simplify")
1708
+
1709
+
1710
class GuardException(Exception):
    """
    Signals that a pattern-matching precondition does not hold.

    Raised by ``require`` and ``get_definition``. ``guard`` turns it into a
    ``None`` result.
    """
1712
+
1713
+
1714
def require(cond):
    """
    Raise GuardException unless *cond* is truthy; otherwise return None.
    """
    if cond:
        return
    raise GuardException
1720
+
1721
+
1722
def guard(func, *args, **kwargs):
    """
    Call ``func(*args, **kwargs)`` and return its result.

    Returns None instead if the call raised GuardException. Any other
    exception propagates unchanged.
    """
    try:
        result = func(*args, **kwargs)
    except GuardException:
        result = None
    return result
1732
+
1733
+
1734
def get_definition(func_ir, name, **kwargs):
    """
    Return ``func_ir.get_definition(name, **kwargs)``.

    A KeyError from the lookup is re-raised as GuardException, so that
    ``guard`` can handle a missing or ambiguous definition.
    """
    try:
        definition = func_ir.get_definition(name, **kwargs)
    except KeyError:
        raise GuardException
    return definition
1743
+
1744
+
1745
def build_definitions(blocks, definitions=None):
    """Build the definitions table of the given blocks.

    Scans every instruction of every block; useful when the definitions
    table is out-of-sync. For each ``ir.Assign`` the assigned value is
    appended to the list stored under the target's name. Statement types
    registered in ``build_defs_extensions`` are handed to their callback
    together with the table.

    A new ``collections.defaultdict(list)`` is created and returned when
    *definitions* is not passed; otherwise *definitions* is updated in
    place and returned.
    """
    if definitions is None:
        definitions = collections.defaultdict(list)

    for block in blocks.values():
        for inst in block.body:
            if isinstance(inst, ir.Assign):
                definitions.setdefault(inst.target.name, []).append(
                    inst.value
                )
            inst_type = type(inst)
            if inst_type in build_defs_extensions:
                build_defs_extensions[inst_type](inst, definitions)

    return definitions
1767
+
1768
+
1769
# Registry: statement class -> callback ``f(inst, definitions)``, invoked by
# build_definitions() for every instruction of that exact class.
build_defs_extensions = dict()
1770
+
1771
+
1772
def find_callname(
    func_ir, expr, typemap=None, definition_finder=get_definition
):
    """Try to find a call expression's function and module names.

    For unbound calls, return them as strings. For a bound call, return the
    self object instead of the module name. Raise GuardException if the
    names cannot be determined.

    Providing typemap can make the call matching more accurate in corner
    cases, such as a bound call on an object that is inside another object.
    """
    require(isinstance(expr, ir.Expr) and expr.op == "call")
    callee_def = definition_finder(func_ir, expr.func)
    attrs = []  # collected innermost (the called name) first
    obj = None

    # Walk back through an ``a.b.c`` getattr chain towards its root.
    while not isinstance(callee_def, (ir.Global, ir.FreeVar)):
        if not (
            isinstance(callee_def, ir.Expr) and callee_def.op == "getattr"
        ):
            # obj.func calls where obj is not a NumPy array
            if obj is not None:
                return ".".join(reversed(attrs)), obj
            raise GuardException
        obj = callee_def.value
        attrs.append(callee_def.attr)
        if typemap and obj.name in typemap:
            if not isinstance(typemap[obj.name], types.Module):
                # method call on a typed (non-module) object
                return attrs[0], obj
        callee_def = definition_finder(func_ir, obj)

    # The chain root is a global or freevar. Modules (numpy, numpy.random),
    # builtins such as len(), and intrinsics such as assertEquiv all expose
    # one of these name attributes.
    root = callee_def.value
    value = None
    for key in ("name", "_name", "__name__"):
        if hasattr(root, key):
            value = getattr(root, key)
            break
    if not (value and isinstance(value, str)):
        raise GuardException
    attrs.append(value)

    # Look through an Intrinsic to its underlying definition so the real
    # module is found (otherwise it would be numba.extending).
    def_val = root._defn if isinstance(root, _Intrinsic) else root
    if hasattr(def_val, "__module__"):
        mod_name = def_val.__module__
        # Check NumPy's namespace by module first. Some deprecated NumPy
        # names alias other common names, and probing them would raise
        # deprecation warnings, e.g. getattr(numpy, 'bool') when the target
        # is the builtin bool. For context see #6175, impacts NumPy>=1.20.
        from_numpy = mod_name is not None and (
            mod_name == "numpy" or mod_name.startswith("numpy.")
        )
        if (
            from_numpy
            and hasattr(numpy, value)
            and def_val == getattr(numpy, value)
        ):
            # a numpy function imported directly
            attrs.append("numpy")
        elif hasattr(numpy.random, value) and def_val == getattr(
            numpy.random, value
        ):
            # a numpy.random function imported directly
            attrs.extend(["random", "numpy"])
        elif mod_name is not None:
            attrs.append(mod_name)
    else:
        class_name = def_val.__class__.__name__
        if class_name == "builtin_function_or_method":
            class_name = "builtin"
        if class_name != "module":
            attrs.append(class_name)
    return attrs[0], ".".join(reversed(attrs[1:]))
1854
+
1855
+
1856
def find_build_sequence(func_ir, var):
    """Return ``(items, op)`` for *var* when it is defined by a
    ``build_tuple``, ``build_list`` or ``build_set`` expression; raise
    GuardException otherwise.

    Note: only build_tuple is immutable, so use with care.
    """
    require(isinstance(var, ir.Var))
    var_def = get_definition(func_ir, var)
    require(
        isinstance(var_def, ir.Expr)
        and var_def.op in ("build_tuple", "build_list", "build_set")
    )
    return var_def.items, var_def.op
1868
+
1869
+
1870
def find_const(func_ir, var):
    """Return the value *var* is defined as when its definition is an
    ir.Const, ir.Global or ir.FreeVar; raise GuardException otherwise.
    """
    require(isinstance(var, ir.Var))
    definition = get_definition(func_ir, var)
    is_constant = isinstance(definition, (ir.Const, ir.Global, ir.FreeVar))
    require(is_constant)
    return definition.value
1878
+
1879
+
1880
def compile_to_numba_ir(
    mk_func,
    glbls,
    typingctx=None,
    targetctx=None,
    arg_typs=None,
    typemap=None,
    calltypes=None,
):
    """
    Compile a function or a make_function node to Numba IR.

    Block labels are shifted past the global label counter and every
    variable is given a unique name, so the IR can be inlined elsewhere
    without conflicts. When *typingctx* is given, type inference is run
    (with *targetctx* and *arg_typs*). The results, minus ``arg.*``
    entries, are merged into *typemap* and *calltypes*, which must then be
    mutable mappings.
    """
    from numba.core import typed_passes

    # make_function nodes carry ``code``; Python (and njit) functions carry
    # ``__code__``
    if hasattr(mk_func, "code"):
        code = mk_func.code
    elif hasattr(mk_func, "__code__"):
        code = mk_func.__code__
    else:
        raise NotImplementedError(
            "function type not recognized {}".format(mk_func)
        )
    f_ir = get_ir_of_code(glbls, code)
    remove_dels(f_ir.blocks)

    # relabel by adding an offset, then record the new maximum label
    f_ir.blocks = add_offset_to_labels(f_ir.blocks, _the_max_label.next())
    _the_max_label.update(max(f_ir.blocks.keys()))

    # rename all variables to avoid conflict
    var_table = get_name_var_table(f_ir.blocks)
    replace_var_names(
        f_ir.blocks, {name: mk_unique_var(name) for name in var_table}
    )

    if not typingctx:
        return f_ir

    f_typemap, _, f_calltypes, _ = typed_passes.type_inference_stage(
        typingctx, targetctx, f_ir, arg_typs, None
    )
    # argument entries such as "arg.a" must not leak into the caller's map
    for vname in [n for n in f_typemap if n.startswith("arg.")]:
        del f_typemap[vname]
    typemap.update(f_typemap)
    calltypes.update(f_calltypes)
    return f_ir
1938
+
1939
+
1940
+ def _create_function_from_code_obj(fcode, func_env, func_arg, func_clo, glbls):
1941
+ """
1942
+ Creates a function from a code object. Args:
1943
+ * fcode - the code object
1944
+ * func_env - string for the freevar placeholders
1945
+ * func_arg - string for the function args (e.g. "a, b, c, d=None")
1946
+ * func_clo - string for the closure args
1947
+ * glbls - the function globals
1948
+ """
1949
+ sanitized_co_name = fcode.co_name.replace("<", "_").replace(">", "_")
1950
+ func_text = (
1951
+ f"def closure():\n{func_env}\n"
1952
+ f"\tdef {sanitized_co_name}({func_arg}):\n"
1953
+ f"\t\treturn ({func_clo})\n"
1954
+ f"\treturn {sanitized_co_name}"
1955
+ )
1956
+ loc = {}
1957
+ exec(func_text, glbls, loc)
1958
+
1959
+ f = loc["closure"]()
1960
+ # replace the code body
1961
+ f.__code__ = fcode
1962
+ f.__name__ = fcode.co_name
1963
+ return f
1964
+
1965
+
1966
def get_ir_of_code(glbls, fcode):
    """
    Compile a code object to get its IR, ir.Del nodes are emitted

    Pipeline:
    1. Wrap the code object in a real function. Free variables become
       ``None`` placeholders and arguments are named ``x_0``, ``x_1``, ...
    2. Run the frontend.
    3. Apply the "before-inference" rewrites and the closure-inlining pass.
    4. Rebuild the IR in SSA form and strip the phis.
    5. Post-process, emitting the ir.Del nodes.
    """
    free_names = ["c_%d" % i for i in range(len(fcode.co_freevars))]
    func_env = "\n".join("\t%s = None" % name for name in free_names)
    func_clo = ",".join(free_names)
    func_arg = ",".join("x_%d" % i for i in range(fcode.co_argcount))

    fn = _create_function_from_code_obj(
        fcode, func_env, func_arg, func_clo, glbls
    )

    from numba.core import compiler

    func_ir = compiler.run_frontend(fn)

    # The before-inference rewrite pass normalizes the IR; for example,
    # Raise nodes need to become StaticRaise before type inference. It only
    # needs a minimal pipeline state.
    # XXX: check rewrite pass flag?
    state = compiler.StateDict()
    state.typingctx = None
    state.targetctx = None
    state.args = None
    state.func_ir = func_ir
    state.typemap = None
    state.return_type = None
    state.calltypes = None
    rewrites.rewrite_registry.apply("before-inference", state)

    # inline closures to handle cases like stencils and comprehensions
    import numba.core.inline_closurecall

    swapped = {}  # TODO: get this from diagnostics store
    numba.core.inline_closurecall.InlineClosureCallPass(
        func_ir, numba.core.cpu.ParallelOptions(False), swapped
    ).run()

    # TODO: DO NOT ADD MORE THINGS HERE!
    # If adding more things here is being contemplated, it really is time to
    # retire this function and work on getting the InlineWorker class from
    # numba.core.inline_closurecall into sufficient shape as a replacement.
    # The issue with `get_ir_of_code` is that it doesn't run a full
    # compilation pipeline and as a result various additional things keep
    # needing to be added to create valid IR.

    # rebuild IR in SSA form, then strip the phi nodes again
    from numba.core.untyped_passes import ReconstructSSA
    from numba.core.typed_passes import PreLowerStripPhis

    ReconstructSSA().run_pass(state)
    PreLowerStripPhis().run_pass(state)

    postproc.PostProcessor(func_ir).run(True)
    return func_ir
2028
+
2029
+
2030
def replace_arg_nodes(block, args):
    """
    Replace ir.Arg(...) with variables.

    Each ``x = arg(i)`` assignment in *block* becomes ``x = args[i]``.
    An assertion checks that the index is within *args*.
    """
    arg_assigns = (
        stmt
        for stmt in block.body
        if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Arg)
    )
    for stmt in arg_assigns:
        position = stmt.value.index
        assert position < len(args)
        stmt.value = args[position]
    return None
2040
+
2041
+
2042
def replace_returns(blocks, target, return_label):
    """
    Replace return statements by a direct assignment to *target* and a jump
    to *return_label*.

    Each returning block must end with ``$x = cast(v); return $x``. That
    pair becomes ``target = v; jump return_label``. Empty blocks are
    skipped.
    """
    for block in blocks.values():
        # some blocks may be empty during transformations
        if not block.body:
            continue
        ret = block.terminator
        if not isinstance(ret, ir.Return):
            continue
        del block.body[-1]  # the Return itself
        cast_stmt = block.body.pop()
        assert (
            isinstance(cast_stmt, ir.Assign)
            and isinstance(cast_stmt.value, ir.Expr)
            and cast_stmt.value.op == "cast"
        ), "invalid return cast"
        block.body.extend(
            [
                ir.Assign(cast_stmt.value.value, target, ret.loc),
                ir.Jump(return_label, ret.loc),
            ]
        )
2063
+
2064
+
2065
def gen_np_call(func_as_str, func, lhs, args, typingctx, typemap, calltypes):
    """
    Generate typed IR computing ``lhs = numpy.<func_as_str>(*args)``.

    Registers the types of the new temporaries in *typemap* and the call
    signature in *calltypes*. Returns the three statements: global assign,
    getattr assign, call assign.
    """
    first_arg = args[0]
    scope, loc = first_arg.scope, first_arg.loc

    # $np_g_var = global(np)
    np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
    typemap[np_var.name] = types.misc.Module(numpy)
    np_global_stmt = ir.Assign(ir.Global("np", numpy, loc), np_var, loc)

    # $np_attr_attr = getattr($np_g_var, func_as_str)
    attr_var = ir.Var(scope, mk_unique_var("$np_attr_attr"), loc)
    attr_typ = get_np_ufunc_typ(func)
    typemap[attr_var.name] = attr_typ
    attr_stmt = ir.Assign(
        ir.Expr.getattr(np_var, func_as_str, loc), attr_var, loc
    )

    # lhs = $np_attr_attr(*args)
    call_expr = ir.Expr.call(attr_var, args, (), loc)
    arg_types = [typemap[arg.name] for arg in args]
    calltypes[call_expr] = attr_typ.get_call_type(typingctx, arg_types, {})
    call_stmt = ir.Assign(call_expr, lhs, loc)

    return [np_global_stmt, attr_stmt, call_stmt]
2087
+
2088
+
2089
def dump_block(label, block):
    """Print *label*, then each statement of *block* on its own indented
    line (a debugging aid)."""
    print(label, ":")
    indent = " "
    for statement in block.body:
        print(indent, statement)
2093
+
2094
+
2095
def dump_blocks(blocks):
    """Print every ``label -> block`` entry of *blocks* with dump_block."""
    for label in blocks:
        dump_block(label, blocks[label])
2098
+
2099
+
2100
def is_operator_or_getitem(expr):
    """Return True if *expr* is an ir.Expr for a unary, binary or in-place
    binary operator, or a (static) getitem; return False otherwise.

    Always returns a real bool. The previous ``and`` chain could leak a
    falsy ``op`` value such as None.
    """
    if not isinstance(expr, ir.Expr):
        return False
    op = getattr(expr, "op", None)
    return op in (
        "unary",
        "binop",
        "inplace_binop",
        "getitem",
        "static_getitem",
    )
2108
+
2109
+
2110
def is_get_setitem(stmt):
    """Return True if *stmt* is a (static) getitem assignment or a (static)
    setitem statement."""
    if is_getitem(stmt):
        return True
    return is_setitem(stmt)
2113
+
2114
+
2115
def is_getitem(stmt):
    """Return True if *stmt* assigns the result of a getitem or
    static_getitem expression."""
    if not isinstance(stmt, ir.Assign):
        return False
    value = stmt.value
    return isinstance(value, ir.Expr) and value.op in (
        "getitem",
        "static_getitem",
    )
2122
+
2123
+
2124
def is_setitem(stmt):
    """Return True if *stmt* is an ir.SetItem or ir.StaticSetItem node."""
    return isinstance(stmt, ir.SetItem) or isinstance(stmt, ir.StaticSetItem)
2127
+
2128
+
2129
def index_var_of_get_setitem(stmt):
    """Return the index variable of a getitem/setitem node (static cases
    included), or None for any other statement.

    Regular nodes store it in ``index``; static ones in ``index_var``.
    """
    if is_getitem(stmt):
        expr = stmt.value
        return expr.index if expr.op == "getitem" else expr.index_var
    if is_setitem(stmt):
        return stmt.index if isinstance(stmt, ir.SetItem) else stmt.index_var
    return None
2144
+
2145
+
2146
def set_index_var_of_get_setitem(stmt, new_index):
    """Set the index variable of a getitem/setitem node (static cases
    included) to *new_index*.

    Raises ValueError for any other statement.
    """
    if is_getitem(stmt):
        field = "index" if stmt.value.op == "getitem" else "index_var"
        setattr(stmt.value, field, new_index)
    elif is_setitem(stmt):
        field = "index" if isinstance(stmt, ir.SetItem) else "index_var"
        setattr(stmt, field, new_index)
    else:
        raise ValueError(
            "getitem or setitem node expected but received {}".format(stmt)
        )
2161
+
2162
+
2163
def is_namedtuple_class(c):
    """Return True if *c* looks like a namedtuple class.

    That means: a class whose only base is ``tuple``, which has a ``_make``
    attribute, and whose ``_fields`` is a tuple of strings.
    """
    if not isinstance(c, type):
        return False
    bases = c.__bases__
    only_tuple_base = len(bases) == 1 and bases[0] is tuple
    if not only_tuple_base or not hasattr(c, "_make"):
        return False
    fields = getattr(c, "_fields", None)
    return isinstance(fields, tuple) and all(
        isinstance(field, str) for field in fields
    )
2179
+
2180
+
2181
def fill_block_with_call(newblock, callee, label_next, inputs, outputs):
    """Fill *newblock* to call *callee* with arguments listed in *inputs*.

    The returned tuple is unpacked (by static_getitem) into the variables
    named in *outputs*. The block then jumps to *label_next*. Returns
    *newblock*.
    """
    scope = newblock.scope
    loc = newblock.loc

    # $fn = const(callee)
    fnvar = scope.make_temp(loc=loc)
    newblock.append(
        ir.Assign(target=fnvar, value=ir.Const(value=callee, loc=loc), loc=loc)
    )

    # $res = $fn(*inputs)
    call = ir.Expr.call(
        func=fnvar,
        args=[scope.get_exact(name) for name in inputs],
        kws=(),
        loc=loc,
    )
    result = scope.make_temp(loc=loc)
    newblock.append(ir.Assign(target=result, value=call, loc=loc))

    # <output i> = static_getitem($res, i)
    for position, name in enumerate(outputs):
        newblock.append(
            ir.Assign(
                target=scope.get_exact(name),
                value=ir.Expr.static_getitem(
                    value=result, index=position, index_var=None, loc=loc
                ),
                loc=loc,
            )
        )

    newblock.append(ir.Jump(target=label_next, loc=loc))
    return newblock
2207
+
2208
+
2209
def fill_callee_prologue(block, inputs, label_next):
    """
    Fill a new block *block* that binds each name in *inputs* to the
    argument at the same position, then jumps to *label_next*.

    Expected to use with *fill_block_with_call()*. Returns *block*.
    """
    scope, loc = block.scope, block.loc
    # name_i = arg(i)
    for position, name in enumerate(inputs):
        arg_value = ir.Arg(name=name, index=position, loc=loc)
        arg_var = ir.Var(scope=scope, name=name, loc=loc)
        block.append(ir.Assign(target=arg_var, value=arg_value, loc=loc))
    # jump to loop entry
    block.append(ir.Jump(target=label_next, loc=loc))
    return block
2226
+
2227
+
2228
def fill_callee_epilogue(block, outputs):
    """
    Fill a new block *block* that returns the variables named in *outputs*
    as one tuple. This block is the last block of the function.

    Expected to use with *fill_block_with_call()*. Returns *block*.
    """
    scope, loc = block.scope, block.loc
    items = [scope.get_exact(name=name) for name in outputs]
    packed = ir.Expr.build_tuple(items=items, loc=loc)
    result = scope.make_temp(loc=loc)
    block.append(ir.Assign(target=result, value=packed, loc=loc))
    block.append(ir.Return(value=result, loc=loc))
    return block
2245
+
2246
+
2247
def find_outer_value(func_ir, var):
    """Return the Python value *var* refers to.

    Handles a global/freevar, or a chain of getattrs on one. Raises
    GuardException otherwise, including when an attribute is missing.
    """
    dfn = get_definition(func_ir, var)
    if isinstance(dfn, (ir.Global, ir.FreeVar)):
        return dfn.value

    require(isinstance(dfn, ir.Expr) and dfn.op == "getattr")
    owner = find_outer_value(func_ir, dfn.value)
    try:
        return getattr(owner, dfn.attr)
    except AttributeError:
        raise GuardException
2264
+
2265
+
2266
def raise_on_unsupported_feature(func_ir, typemap):
    """
    Helper function to walk IR and raise if it finds op codes
    that are unsupported. Could be extended to cover IR sequences
    as well as op codes. Intended use is to call it as a pipeline
    stage just prior to lowering to prevent LoweringErrors for known
    unsupported features.

    Raises UnsupportedError for:
    - UniTuple arguments with more than 1000 elements
    - make_function expressions
    - yield in a function not identified as a generator
    - more than one gdb/gdb_init call

    Raises TypingError for:
    - ``.view`` on a non-array value that was not constructed by a NumPy
      integer/floating scalar type
    - reflected or typed list/dict globals
    """
    gdb_calls = []  # accumulate calls to gdb/gdb_init

    # issue 2195: check for excessively large tuples
    for arg_name in func_ir.arg_names:
        if (
            arg_name in typemap
            and isinstance(typemap[arg_name], types.containers.UniTuple)
            and typemap[arg_name].count > 1000
        ):
            # Raise an exception when len(tuple) > 1000. The choice of this
            # number (1000) was entirely arbitrary
            msg = (
                "Tuple '{}' length must be smaller than 1000.\n"
                "Large tuples lead to the generation of a prohibitively large "
                "LLVM IR which causes excessive memory pressure "
                "and large compile times.\n"
                "As an alternative, the use of a 'list' is recommended in "
                "place of a 'tuple' as lists do not suffer from this problem.".format(
                    arg_name
                )
            )
            raise UnsupportedError(msg, func_ir.loc)

    for blk in func_ir.blocks.values():
        for stmt in blk.find_insts(ir.Assign):
            # This raises on finding `make_function`
            if isinstance(stmt.value, ir.Expr):
                if stmt.value.op == "make_function":
                    val = stmt.value

                    # See if the construct name can be refined
                    code = getattr(val, "code", None)
                    if code is not None:
                        # check if this is a closure, the co_name will
                        # be the captured function name which is not
                        # useful so be explicit
                        if getattr(val, "closure", None) is not None:
                            use = "<creating a function from a closure>"
                            expr = ""
                        else:
                            use = code.co_name
                            expr = "(%s) " % use
                    else:
                        use = "<could not ascertain use case>"
                        expr = ""

                    msg = (
                        "Numba encountered the use of a language "
                        "feature it does not support in this context: "
                        "%s (op code: make_function not supported). If "
                        "the feature is explicitly supported it is "
                        "likely that the result of the expression %s"
                        "is being used in an unsupported manner."
                    ) % (use, expr)
                    raise UnsupportedError(msg, stmt.value.loc)

            # this checks for gdb initialization calls, only one is permitted
            if isinstance(stmt.value, (ir.Global, ir.FreeVar)):
                val = stmt.value
                val = getattr(val, "value", None)
                if val is None:
                    # note: skips the remaining checks for this statement
                    continue

                # check global function
                found = False
                if isinstance(val, pytypes.FunctionType):
                    found = val in {numba.gdb, numba.gdb_init}
                if not found:  # freevar bind to intrinsic
                    found = getattr(val, "_name", "") == "gdb_internal"
                if found:
                    gdb_calls.append(stmt.loc)  # report last seen location

            # this checks that np.<type> was called if view is called
            if isinstance(stmt.value, ir.Expr):
                if stmt.value.op == "getattr" and stmt.value.attr == "view":
                    var = stmt.value.value.name
                    if isinstance(typemap[var], types.Array):
                        continue
                    df = func_ir.get_definition(var)
                    cn = guard(find_callname, func_ir, df)
                    if cn and cn[1] == "numpy":
                        ty = getattr(numpy, cn[0])
                        # Only NumPy scalar *types* qualify. Other numpy
                        # callables (e.g. np.sum) would make issubdtype raise
                        # a bare TypeError instead of the TypingError below.
                        if (
                            isinstance(ty, type)
                            and issubclass(ty, numpy.generic)
                            and (
                                numpy.issubdtype(ty, numpy.integer)
                                or numpy.issubdtype(ty, numpy.floating)
                            )
                        ):
                            continue

                    vardescr = (
                        "" if var.startswith("$") else "'{}' ".format(var)
                    )
                    raise TypingError(
                        "'view' can only be called on NumPy dtypes, "
                        "try wrapping the variable {}with 'np.<dtype>()'".format(
                            vardescr
                        ),
                        loc=stmt.loc,
                    )

            # checks for globals that are also reflected
            if isinstance(stmt.value, ir.Global):
                ty = typemap[stmt.target.name]
                msg = (
                    "The use of a %s type, assigned to variable '%s' in "
                    "globals, is not supported as globals are considered "
                    "compile-time constants and there is no known way to "
                    "compile a %s type as a constant."
                )
                if getattr(ty, "reflected", False) or isinstance(
                    ty, (types.DictType, types.ListType)
                ):
                    raise TypingError(
                        msg % (ty, stmt.value.name, ty), loc=stmt.loc
                    )

            # checks for generator expressions (yield in use when func_ir has
            # not been identified as a generator).
            if isinstance(stmt.value, ir.Yield) and not func_ir.is_generator:
                msg = "The use of generator expressions is unsupported."
                raise UnsupportedError(msg, loc=stmt.loc)

    # There is more than one call to function gdb/gdb_init
    if len(gdb_calls) > 1:
        msg = (
            "Calling either numba.gdb() or numba.gdb_init() more than once "
            "in a function is unsupported (strange things happen!), use "
            "numba.gdb_breakpoint() to create additional breakpoints "
            "instead.\n\nRelevant documentation is available here:\n"
            "https://numba.readthedocs.io/en/stable/user/troubleshoot.html"
            "#using-numba-s-direct-gdb-bindings-in-nopython-mode\n\n"
            "Conflicting calls found at:\n %s"
        )
        buf = "\n".join([x.strformat() for x in gdb_calls])
        raise UnsupportedError(msg % buf)
2407
+
2408
+
2409
def warn_deprecated(func_ir, typemap):
    """
    Emit a NumbaPendingDeprecationWarning for every function argument
    (``arg.<name>`` entry in *typemap*) whose type is reflected.

    The warning names the type as a reflected list when it is a
    ``types.List``, and as a reflected set otherwise.
    """
    for name, ty in typemap.items():
        # the Type metaclass gives every type a ``reflected`` member; only
        # argument entries are reported
        if not (ty.reflected and name.startswith("arg.")):
            continue
        arg = name.split(".")[1]
        fname = func_ir.func_id.func_qualname
        tyname = "list" if isinstance(ty, types.List) else "set"
        url = (
            "https://numba.readthedocs.io/en/stable/reference/"
            "deprecation.html#deprecation-of-reflection-for-list-and"
            "-set-types"
        )
        msg = (
            "\nEncountered the use of a type that is scheduled for "
            "deprecation: type 'reflected %s' found for argument "
            "'%s' of function '%s'.\n\nFor more information visit "
            "%s" % (tyname, arg, fname, url)
        )
        warnings.warn(NumbaPendingDeprecationWarning(msg, loc=func_ir.loc))
2432
+
2433
+
2434
def resolve_func_from_module(func_ir, node):
    """
    This returns the python function that is being getattr'd from a module
    in some IR. It resolves import chains/submodules by following getattr
    definitions back to a module global/freevar. If the python function
    cannot be found, None is returned.

    func_ir - the FunctionIR object
    node - the IR node from which to start resolving (should be a `getattr`).
    """
    attr_names = []  # collected outermost attribute first, i.e. reversed
    current = node
    while getattr(current, "op", False) == "getattr":
        attr_names.append(current.attr)
        try:
            current = func_ir.get_definition(current.value)
        except KeyError:  # multiple definitions
            return None

    is_module_ref = isinstance(
        current, (ir.Global, ir.FreeVar)
    ) and isinstance(current.value, pytypes.ModuleType)
    if not is_module_ref:
        return None

    defn = current.value
    for attr in reversed(attr_names):
        defn = getattr(defn, attr, False)
        if not defn:
            return None
    return defn
2469
+
2470
+
2471
def enforce_no_dels(func_ir):
    """
    Enforce there being no ir.Del nodes in the IR.

    Raises CompilerError at the first ir.Del of the first offending block.
    """
    for blk in func_ir.blocks.values():
        dels = list(blk.find_insts(ir.Del))
        if not dels:
            continue
        first = dels[0]
        raise CompilerError(
            "Illegal IR, del found at: %s" % first, loc=first.loc
        )
2480
+
2481
+
2482
def enforce_no_phis(func_ir):
    """
    Enforce there being no ir.Expr.phi nodes in the IR.

    Raises CompilerError at the first phi of the first offending block.
    """
    for blk in func_ir.blocks.values():
        phis = list(blk.find_exprs(op="phi"))
        if not phis:
            continue
        first = phis[0]
        raise CompilerError(
            "Illegal IR, phi found at: %s" % first, loc=first.loc
        )
2491
+
2492
+
2493
def legalize_single_scope(blocks):
    """Return True if every ir.Block in *blocks* uses one and the same
    scope (False for an empty mapping)."""
    scopes = set()
    for blk in blocks.values():
        scopes.add(blk.scope)
    return len(scopes) == 1
2496
+
2497
+
2498
def check_and_legalize_ir(func_ir, flags: "numba.core.compiler.Flags"):
    """
    This checks that the IR presented is legal.

    The IR must contain neither phi nor del nodes. It is then
    post-processed, which emits ir.Del nodes (honouring
    ``flags.dbg_extend_lifetimes``).
    """
    enforce_no_phis(func_ir)
    enforce_no_dels(func_ir)
    postproc.PostProcessor(func_ir).run(
        True, extend_lifetimes=flags.dbg_extend_lifetimes
    )
2507
+
2508
+
2509
def _const_to_source(value):
    """Render constant *value* as Python source for a generated function.

    ``str`` constants go through ``repr`` so they stay string literals;
    plain ``%s`` formatting would turn ``"abc"`` into the bare name
    ``abc``. Every other value keeps its ``str()`` rendering as before.
    """
    if isinstance(value, str):
        return repr(value)
    return str(value)


def convert_code_obj_to_function(code_obj, caller_ir):
    """
    Convert a code object from a `make_function.code` attr in the IR into a
    python function.

    caller_ir is the FunctionIR of the caller. It is used to resolve
    freevars, which must each have a single ir.Const definition, and
    keyword defaults. The new function shares the caller's globals.

    Raises TypingError when a freevar has multiple definitions or is not
    constant.
    """
    fcode = code_obj.code
    nfree = len(fcode.co_freevars)

    # try and resolve freevars if they are consts in the caller's IR
    # these can be baked into the new function
    freevars = []
    for x in fcode.co_freevars:
        # not using guard here to differentiate between multiple definition
        # and non-const variable
        try:
            freevar_def = caller_ir.get_definition(x)
        except KeyError:
            msg = (
                "Cannot capture a constant value for variable '%s' as there "
                "are multiple definitions present." % x
            )
            raise TypingError(msg, loc=code_obj.loc)
        if isinstance(freevar_def, ir.Const):
            freevars.append(freevar_def.value)
        else:
            msg = (
                "Cannot capture the non-constant value associated with "
                "variable '%s' in a function that may escape." % x
            )
            raise TypingError(msg, loc=code_obj.loc)

    func_env = "\n".join(
        "\tc_%d = %s" % (i, _const_to_source(x)) for i, x in enumerate(freevars)
    )
    func_clo = ",".join("c_%d" % i for i in range(nfree))
    co_varnames = list(fcode.co_varnames)

    # The code object knows the total number of args and their names, but
    # the names are bundled in with other vars in `co_varnames`. The
    # make_function IR node knows the defaults, which are consts in the IR.
    # The number of "kwargs with defaults" is the number of default values;
    # the remaining leading args are the plain positional ones.
    kwarg_defaults_tup = ()
    n_allargs = fcode.co_argcount
    kwarg_defaults = caller_ir.get_definition(code_obj.defaults)
    if kwarg_defaults is not None:
        if isinstance(kwarg_defaults, tuple):
            default_vars = kwarg_defaults
        else:
            default_vars = kwarg_defaults.items
        kwarg_defaults_tup = tuple(
            caller_ir.get_definition(x).value for x in default_vars
        )
    n_kwargs = len(kwarg_defaults_tup)
    nargs = n_allargs - n_kwargs

    # Join all parameters at once: appending ", <kwargs>" to an empty
    # positional list used to yield an invalid "def f(, b = 1)".
    params = ["%s" % co_varnames[i] for i in range(nargs)]
    params.extend(
        "%s = %s"
        % (co_varnames[i + nargs], _const_to_source(kwarg_defaults_tup[i]))
        for i in range(n_kwargs)
    )
    func_arg = ", ".join(params)

    # globals are the same as those in the caller
    glbls = caller_ir.func_id.func.__globals__

    # create the function and return it
    return _create_function_from_code_obj(
        fcode, func_env, func_arg, func_clo, glbls
    )
2585
+
2586
+
2587
def fixup_var_define_in_scope(blocks):
    """Fixes the mapping of ir.Block to ensure all referenced ir.Var are
    defined in every scope used by the function, so that looking up a
    variable from any scope in this function will not fail.

    Note: This is a workaround. Ideally, all the blocks should refer to the
    same ir.Scope, but that property is not maintained by all the passes.

    Parameters
    ----------
    blocks : mapping
        Label -> ir.Block mapping. Each block's ``scope.localvars`` is
        updated in place; nothing is returned.
    """
    # Scan for all used variables, keyed by name. The first Var object seen
    # for a given name is the one that gets defined; later Vars with the same
    # name would be skipped by the membership check below anyway.
    used_vars = {}
    for blk in blocks.values():
        for inst in blk.body:
            for var in inst.list_vars():
                used_vars.setdefault(var.name, var)
    # Note: not all blocks share a single scope even though they should.
    # Ensure the scope of each block defines all used variables.
    for blk in blocks.values():
        localvars = blk.scope.localvars
        for name, var in used_vars.items():
            # Only add names the scope does not already have, so existing
            # definitions are never overwritten.
            if name not in localvars:
                localvars.define(name, var)
2611
+
2612
+
2613
def transfer_scope(block, scope):
    """Make *block* use *scope* as its ir.Scope and return the block.

    Every variable known to the block's current scope that is missing from
    *scope* is first defined in *scope*, so lookups keep working after the
    switch. If the block already uses *scope*, it is returned unchanged.
    """
    source_scope = block.scope
    if source_scope is not scope:
        # Copy over variables the target scope does not know yet; reads the
        # VarMap's internal ``_con`` dict to enumerate the existing Vars.
        for existing in source_scope.localvars._con.values():
            if existing.name in scope.localvars:
                continue
            scope.localvars.define(existing.name, existing)
        block.scope = scope
    return block
2626
+
2627
+
2628
def is_setup_with(stmt):
    """Return whether *stmt* is an ``ir.EnterWith`` instance."""
    enter_with_cls = ir.EnterWith
    return isinstance(stmt, enter_with_cls)
2630
+
2631
+
2632
def is_terminator(stmt):
    """Return whether *stmt* is an ``ir.Terminator`` (or subclass) instance."""
    terminator_cls = ir.Terminator
    return isinstance(stmt, terminator_cls)
2634
+
2635
+
2636
def is_raise(stmt):
    """Return whether *stmt* is an ``ir.Raise`` instance."""
    raise_cls = ir.Raise
    return isinstance(stmt, raise_cls)
2638
+
2639
+
2640
def is_return(stmt):
    """Return whether *stmt* is an ``ir.Return`` instance."""
    return_cls = ir.Return
    return isinstance(stmt, return_cls)
2642
+
2643
+
2644
def is_pop_block(stmt):
    """Return whether *stmt* is an ``ir.PopBlock`` instance."""
    pop_block_cls = ir.PopBlock
    return isinstance(stmt, pop_block_cls)