numba-cuda 0.19.1__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of numba-cuda might be problematic. Click here for more details.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +1 -1
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
- numba_cuda/numba/cuda/api.py +6 -1
- numba_cuda/numba/cuda/bf16.py +285 -2
- numba_cuda/numba/cuda/cgutils.py +2 -2
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +1 -1
- numba_cuda/numba/cuda/compiler.py +373 -30
- numba_cuda/numba/cuda/core/analysis.py +319 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
- numba_cuda/numba/cuda/core/base.py +1289 -0
- numba_cuda/numba/cuda/core/bytecode.py +727 -0
- numba_cuda/numba/cuda/core/caching.py +2 -2
- numba_cuda/numba/cuda/core/compiler.py +6 -14
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +747 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/cpu.py +370 -0
- numba_cuda/numba/cuda/core/environment.py +68 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
- numba_cuda/numba/cuda/core/interpreter.py +48 -26
- numba_cuda/numba/cuda/core/ir_utils.py +15 -26
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
- numba_cuda/numba/cuda/core/ssa.py +496 -0
- numba_cuda/numba/cuda/core/targetconfig.py +329 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +952 -0
- numba_cuda/numba/cuda/core/typed_passes.py +738 -7
- numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
- numba_cuda/numba/cuda/cuda_paths.py +422 -246
- numba_cuda/numba/cuda/cudadecl.py +1 -1
- numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
- numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
- numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
- numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
- numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +5 -1
- numba_cuda/numba/cuda/debuginfo.py +85 -2
- numba_cuda/numba/cuda/decorators.py +3 -3
- numba_cuda/numba/cuda/descriptor.py +3 -4
- numba_cuda/numba/cuda/deviceufunc.py +66 -2
- numba_cuda/numba/cuda/dispatcher.py +18 -39
- numba_cuda/numba/cuda/flags.py +141 -1
- numba_cuda/numba/cuda/fp16.py +0 -2
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/lowering.py +7 -144
- numba_cuda/numba/cuda/mathimpl.py +2 -1
- numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/models.py +9 -1
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
- numba_cuda/numba/cuda/np/numpy_support.py +553 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
- numba_cuda/numba/cuda/nvvmutils.py +1 -1
- numba_cuda/numba/cuda/printimpl.py +12 -1
- numba_cuda/numba/cuda/random.py +1 -1
- numba_cuda/numba/cuda/serialize.py +1 -1
- numba_cuda/numba/cuda/simulator/__init__.py +1 -1
- numba_cuda/numba/cuda/simulator/api.py +1 -1
- numba_cuda/numba/cuda/simulator/compiler.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
- numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
- numba_cuda/numba/cuda/target.py +35 -17
- numba_cuda/numba/cuda/testing.py +4 -19
- numba_cuda/numba/cuda/tests/__init__.py +1 -1
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
- numba_cuda/numba/cuda/tests/support.py +55 -15
- numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
- numba_cuda/numba/cuda/types.py +56 -0
- numba_cuda/numba/cuda/typing/__init__.py +9 -1
- numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
- numba_cuda/numba/cuda/typing/context.py +751 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/npydecl.py +658 -0
- numba_cuda/numba/cuda/typing/templates.py +7 -6
- numba_cuda/numba/cuda/ufuncs.py +3 -3
- numba_cuda/numba/cuda/utils.py +6 -112
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/METADATA +2 -1
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/RECORD +170 -115
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/WHEEL +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE.numba +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from collections import namedtuple
|
|
5
|
+
from numba import types
|
|
6
|
+
from numba.core import ir
|
|
7
|
+
from numba.cuda.core import consts
|
|
8
|
+
from numba.core.analysis import compute_cfg_from_blocks
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Record describing a branch condition neutralised by dead branch pruning:
#   condition    -- the IR expression that fed the branch
#   taken_br     -- which way the branch went (value stored on rewrite)
#   rewrite_stmt -- whether the defining assignment should be rewritten
nullified = namedtuple("nullified", ["condition", "taken_br", "rewrite_stmt"])
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def dead_branch_prune(func_ir, called_args):
    """
    Removes dead branches based on constant inference from function args.
    This directly mutates the IR.

    func_ir is the IR
    called_args are the actual arguments with which the function is called

    Only branches of the form ``branch bool(<cond>)`` are considered. A
    branch is pruned (replaced by an unconditional ``ir.Jump``) when:

    * ``<cond>`` is a binop whose two operands both resolve to constants
      (literal/omitted/None function arguments, or IR constants), or
    * ``<cond>`` itself resolves to a constant predicate.

    After pruning, this function:

    * rewrites nullified binop conditions to an ``ir.Const``;
    * repairs phi nodes whose incoming blocks changed;
    * deletes blocks that became unreachable;
    * recomputes ``func_ir._consts`` if anything was nullified.

    Returns None.

    NOTE(review): ``called_args`` is indexed by ``ir.Arg.index`` and its
    entries are checked against ``types.NoneType`` / ``types.Omitted``.
    Presumably it is the sequence of Numba argument types; confirm against
    callers.
    """
    from numba.cuda.core.ir_utils import (
        get_definition,
        guard,
        find_const,
        GuardException,
    )

    # Debug verbosity: >0 prints each pruned branch, >1 also dumps the IR
    # before and after pruning.
    DEBUG = 0

    def find_branches(func_ir):
        # find *all* branches
        # Returns (branch, condition, blk) triples. Only branches whose
        # predicate is a call to the global ``bool`` qualify; ``condition`` is
        # the definition of that call's first argument.
        branches = []
        for blk in func_ir.blocks.values():
            branch_or_jump = blk.body[-1]
            if isinstance(branch_or_jump, ir.Branch):
                branch = branch_or_jump
                pred = guard(get_definition, func_ir, branch.cond.name)
                if pred is not None and getattr(pred, "op", None) == "call":
                    function = guard(get_definition, func_ir, pred.func)
                    if (
                        function is not None
                        and isinstance(function, ir.Global)
                        and function.value is bool
                    ):
                        condition = guard(get_definition, func_ir, pred.args[0])
                        if condition is not None:
                            branches.append((branch, condition, blk))
        return branches

    def do_prune(take_truebr, blk):
        # Replace blk's terminating branch with a jump to the surviving
        # target. Returns 1 if the true branch was kept, else 0.
        # NOTE: ``branch`` is NOT a parameter here; it is read from the
        # enclosing scope (the loop variable of the main pruning loop below).
        # That is only correct because every caller runs inside that loop
        # iteration for the same branch.
        keep = branch.truebr if take_truebr else branch.falsebr
        # replace the branch with a direct jump
        jmp = ir.Jump(keep, loc=branch.loc)
        blk.body[-1] = jmp
        return 1 if keep == branch.truebr else 0

    def prune_by_type(branch, condition, blk, *conds):
        # this prunes a given branch and fixes up the IR
        # at least one needs to be a NoneType
        # Returns (pruned?, taken_branch_bit_or_None).
        lhs_cond, rhs_cond = conds
        lhs_none = isinstance(lhs_cond, types.NoneType)
        rhs_none = isinstance(rhs_cond, types.NoneType)
        if lhs_none or rhs_none:
            try:
                # Evaluate the binop's operator on the resolved operands.
                take_truebr = condition.fn(lhs_cond, rhs_cond)
            except Exception:
                return False, None
            if DEBUG > 0:
                kill = branch.falsebr if take_truebr else branch.truebr
                print(
                    "Pruning %s" % kill,
                    branch,
                    lhs_cond,
                    rhs_cond,
                    condition.fn,
                )
            taken = do_prune(take_truebr, blk)
            return True, taken
        return False, None

    def prune_by_value(branch, condition, blk, *conds):
        # Both operands are constant values: evaluate the operator directly.
        # Returns (pruned?, computed_condition_value_or_None).
        lhs_cond, rhs_cond = conds
        try:
            take_truebr = condition.fn(lhs_cond, rhs_cond)
        except Exception:
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, lhs_cond, rhs_cond, condition.fn)
        do_prune(take_truebr, blk)
        # It is not safe to rewrite the predicate to a nominal value based on
        # which branch is taken, the rewritten const predicate needs to
        # hold the actual computed const value as something else may refer to
        # it!
        return True, take_truebr

    def prune_by_predicate(branch, pred, blk):
        # The predicate itself is a constant IR node; its truthiness decides.
        # Returns (pruned?, taken_branch_bit_or_None).
        try:
            # Just to prevent accidents, whilst already guarded, ensure this
            # is a constant IR node (ir.Const, ir.FreeVar or ir.Global)
            if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)):
                raise TypeError("Expected constant Numba IR node")
            take_truebr = bool(pred.value)
        except TypeError:
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, pred)
        taken = do_prune(take_truebr, blk)
        return True, taken

    # Sentinel meaning "could not resolve to a constant"; ``None`` cannot be
    # used for this because None is itself a legitimate constant here.
    class Unknown(object):
        pass

    def resolve_input_arg_const(input_arg_idx):
        """
        Resolves an input arg to a constant (if possible)

        Returns a NoneType instance for None/omitted-None arguments, the
        argument's ``literal_type`` when it has one, and otherwise an
        ``Unknown`` instance.
        """
        input_arg_ty = called_args[input_arg_idx]

        # comparing to None?
        if isinstance(input_arg_ty, types.NoneType):
            return input_arg_ty

        # is it a kwarg default
        if isinstance(input_arg_ty, types.Omitted):
            val = input_arg_ty.value
            if isinstance(val, types.NoneType):
                return val
            elif val is None:
                return types.NoneType("none")

        # literal type, return the type itself so comparisons like `x == None`
        # still work as e.g. x = types.int64 will never be None/NoneType so
        # the branch can still be pruned
        return getattr(input_arg_ty, "literal_type", Unknown())

    if DEBUG > 1:
        print("before".center(80, "-"))
        print(func_ir.dump())

    # Record every phi expression with its block label and its assignment,
    # so phis can be repaired once the CFG has changed.
    phi2lbl = dict()
    phi2asgn = dict()
    for lbl, blk in func_ir.blocks.items():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Expr) and stmt.value.op == "phi":
                    phi2lbl[stmt.value] = lbl
                    phi2asgn[stmt.value] = stmt

    # This looks for branches where:
    # at least one arg of the condition is in input args and const
    # at least one arg of the condition is a const
    # if the condition is met it will replace the branch with a jump
    branch_info = find_branches(func_ir)
    # stores conditions that have no impact post prune
    nullified_conditions = []

    for branch, condition, blk in branch_info:
        const_conds = []
        if isinstance(condition, ir.Expr) and condition.op == "binop":
            # Default to value-based pruning; switch to type-based pruning as
            # soon as either operand is a function argument.
            prune = prune_by_value
            for arg in [condition.lhs, condition.rhs]:
                resolved_const = Unknown()
                arg_def = guard(get_definition, func_ir, arg)
                if isinstance(arg_def, ir.Arg):
                    # it's an e.g. literal argument to the function
                    resolved_const = resolve_input_arg_const(arg_def.index)
                    prune = prune_by_type
                else:
                    # it's some const argument to the function, cannot use guard
                    # here as the const itself may be None
                    try:
                        resolved_const = find_const(func_ir, arg)
                        if resolved_const is None:
                            resolved_const = types.NoneType("none")
                    except GuardException:
                        pass

                if not isinstance(resolved_const, Unknown):
                    const_conds.append(resolved_const)

            # lhs/rhs are consts
            if len(const_conds) == 2:
                # prune the branch, switch the branch for an unconditional jump
                prune_stat, taken = prune(branch, condition, blk, *const_conds)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(
                        nullified(condition, taken, True)
                    )
        else:
            # see if this is a branch on a constant value predicate
            resolved_const = Unknown()
            try:
                pred_call = get_definition(func_ir, branch.cond)
                resolved_const = find_const(func_ir, pred_call.args[0])
                if resolved_const is None:
                    resolved_const = types.NoneType("none")
            except GuardException:
                pass

            if not isinstance(resolved_const, Unknown):
                prune_stat, taken = prune_by_predicate(branch, condition, blk)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(
                        nullified(condition, taken, False)
                    )

    # 'ERE BE DRAGONS...
    # It is the evaluation of the condition expression that often trips up type
    # inference, so ideally it would be removed as it is effectively rendered
    # dead by the unconditional jump if a branch was pruned. However, there may
    # be references to the condition that exist in multiple places (e.g. dels)
    # and we cannot run DCE here as typing has not taken place to give enough
    # information to run DCE safely. Upshot of all this is the condition gets
    # rewritten below into a benign const that typing will be happy with and DCE
    # can remove it and its reference post typing when it is safe to do so
    # (if desired). It is required that the const is assigned a value that
    # indicates the branch taken as its mutated value would be read in the case
    # of object mode fall back in place of the condition itself. For
    # completeness the func_ir._definitions and ._consts are also updated to
    # make the IR state self consistent.

    deadcond = [x.condition for x in nullified_conditions]
    for _, cond, blk in branch_info:
        if cond in deadcond:
            for x in blk.body:
                if isinstance(x, ir.Assign) and x.value is cond:
                    # rewrite the condition as a true/false bit
                    nullified_info = nullified_conditions[deadcond.index(cond)]
                    # only do a rewrite of conditions, predicates need to retain
                    # their value as they may be used later.
                    if nullified_info.rewrite_stmt:
                        branch_bit = nullified_info.taken_br
                        x.value = ir.Const(branch_bit, loc=x.loc)
                        # update the specific definition to the new const
                        defns = func_ir._definitions[x.target.name]
                        repl_idx = defns.index(cond)
                        defns[repl_idx] = x.value

    # Check post dominators of dead nodes from in the original CFG for use of
    # vars that are being removed in the dead blocks which might be referred to
    # by phi nodes.
    #
    # Multiple things to fix up:
    #
    # 1. Cases like:
    #
    # A        A
    # |\       |
    # | B  --> B
    # |/       |
    # C        C
    #
    # i.e. the branch is dead but the block is still alive. In this case CFG
    # simplification will fuse A-B-C and any phi in C can be updated as a
    # direct assignment from the last assigned version in the dominators of the
    # fused block.
    #
    # 2. Cases like:
    #
    #   A        A
    #  / \       |
    # B   C  --> B
    #  \ /       |
    #   D        D
    #
    # i.e. the block C is dead. In this case the phis in D need updating to
    # reflect the collapse of the phi condition. This should result in a direct
    # assignment of the surviving version in B to the LHS of the phi in D.

    new_cfg = compute_cfg_from_blocks(func_ir.blocks)
    dead_blocks = new_cfg.dead_nodes()

    # for all phis that are still in live blocks.
    for phi, lbl in phi2lbl.items():
        if lbl in dead_blocks:
            continue
        new_incoming = [x[0] for x in new_cfg.predecessors(lbl)]
        if set(new_incoming) != set(phi.incoming_blocks):
            # Something has changed in the CFG...
            if len(new_incoming) == 1:
                # There's now just one incoming. Replace the PHI node by a
                # direct assignment
                idx = phi.incoming_blocks.index(new_incoming[0])
                phi2asgn[phi].value = phi.incoming_values[idx]
            else:
                # There's more than one incoming still, then look through the
                # incoming and remove dead
                ic_val_tmp = []
                ic_blk_tmp = []
                for ic_val, ic_blk in zip(
                    phi.incoming_values, phi.incoming_blocks
                ):
                    if ic_blk in dead_blocks:
                        continue
                    else:
                        ic_val_tmp.append(ic_val)
                        ic_blk_tmp.append(ic_blk)
                # Mutate the lists in place rather than rebinding them.
                phi.incoming_values.clear()
                phi.incoming_values.extend(ic_val_tmp)
                phi.incoming_blocks.clear()
                phi.incoming_blocks.extend(ic_blk_tmp)

    # Remove dead blocks, this is safe as it relies on the CFG only.
    for dead in dead_blocks:
        del func_ir.blocks[dead]

    # if conditions were nullified then consts were rewritten, update
    if nullified_conditions:
        func_ir._consts = consts.ConstantInference(func_ir)

    if DEBUG > 1:
        print("after".center(80, "-"))
        print(func_ir.dump())
|
|
File without changes
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from collections import defaultdict, OrderedDict
|
|
5
|
+
from collections.abc import Mapping
|
|
6
|
+
from contextlib import closing
|
|
7
|
+
import copy
|
|
8
|
+
import inspect
|
|
9
|
+
import os
|
|
10
|
+
import re
|
|
11
|
+
import textwrap
|
|
12
|
+
from io import StringIO
|
|
13
|
+
|
|
14
|
+
import numba.core.dispatcher
|
|
15
|
+
from numba.core import ir
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SourceLines(Mapping):
    """Read-only mapping from absolute file line numbers to source text.

    Wraps the dedented source of *func*. Keys are the line numbers *func*
    occupies in its source file, starting at ``startno``. Looking up a line
    outside that range yields ``""`` instead of raising. When the source
    cannot be retrieved, the mapping is empty and ``avail`` is False.
    """

    def __init__(self, func):
        try:
            lines, startno = inspect.getsourcelines(func)
        except (OSError, TypeError):
            # OSError: source file not available (e.g. code defined via exec
            # or in an interactive session).
            # TypeError: *func* is not a Python-level code object (e.g. a
            # builtin).
            # Either way, degrade to "no source" instead of failing.
            self.lines = ()
            self.startno = 0
        else:
            self.lines = textwrap.dedent("".join(lines)).splitlines()
            self.startno = startno

    def __getitem__(self, lineno):
        """Return the right-stripped text of file line *lineno*, or ``""``."""
        offset = lineno - self.startno
        # Bounds-check explicitly: a negative offset would otherwise wrap
        # around and return a line from the *end* of the function.
        if 0 <= offset < len(self.lines):
            return self.lines[offset].rstrip()
        return ""

    def __iter__(self):
        """Iterate over the file line numbers covered by the source."""
        return iter(range(self.startno, self.startno + len(self.lines)))

    def __len__(self):
        return len(self.lines)

    @property
    def avail(self):
        """True when the function's source code could be retrieved."""
        return bool(self.lines)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TypeAnnotation(object):
    """Collect and render Numba IR type annotations for one compiled function.

    Renderings:

    * ``annotate()`` (also ``str(self)``): plain text in which each source
      line is preceded by the typed IR statements attributed to it.
    * ``annotate_raw()``: markup-free per-line data merged into the
      class-level ``func_data`` dict, which is returned.
    * ``html_annotate(outfile)``: ``func_data`` rendered through a jinja2
      HTML template and written to *outfile*.
    """

    # func_data dict stores annotation data for all functions that are
    # compiled. We store the data in the TypeAnnotation class since a new
    # TypeAnnotation instance is created for each function that is compiled.
    # For every function that is compiled, we add the type annotation data to
    # this dict and write the html annotation file to disk (rewrite the html
    # file for every function since we don't know if this is the last function
    # to be compiled).
    func_data = OrderedDict()

    def __init__(
        self,
        func_ir,
        typemap,
        calltypes,
        lifted,
        lifted_from,
        args,
        return_type,
        html_output=None,
    ):
        """Capture everything needed to annotate *func_ir*.

        :param func_ir: function IR; its ``func_id``, ``blocks`` and ``loc``
            are read.
        :param typemap: mapping of IR variable name to inferred type.
        :param calltypes: mapping of call expressions and ``SetItem``
            statements to their resolved signatures.
        :param lifted: sequence of lifted-loop objects (may be empty).
        :param lifted_from: ``(line, func_data)`` for the function this loop
            was lifted from, or ``None``.
        :param args: argument types; only used to build ``self.signature``.
        :param return_type: return type; only used to build ``self.signature``.
        :param html_output: accepted for interface compatibility; not used by
            this class.
        """
        self.func_id = func_ir.func_id
        self.blocks = func_ir.blocks
        self.typemap = typemap
        self.calltypes = calltypes
        self.filename = func_ir.loc.filename
        self.linenum = str(func_ir.loc.line)
        self.signature = str(args) + " -> " + str(return_type)

        # lifted loop information
        self.lifted = lifted
        self.num_lifted_loops = len(lifted)

        # If this is a lifted loop function that is being compiled, lifted_from
        # points to annotation data from function that this loop lifted function
        # was lifted from. This is used to stick lifted loop annotations back
        # into original function.
        self.lifted_from = lifted_from

    def prepare_annotations(self):
        """Group annotated IR statements by source line number.

        Returns a ``defaultdict(list)`` mapping each line number to a list of
        strings. Each list holds a ``"label <id>"`` entry for every block
        whose location is that line, plus one entry per IR statement.
        """
        groupedinst = defaultdict(list)
        # Set when a LiftedLoop constant is seen, so the *next* assignment
        # (presumably the call to that loop) is also labelled as lifted.
        found_lifted_loop = False
        for blkid in sorted(self.blocks.keys()):
            blk = self.blocks[blkid]
            groupedinst[blk.loc.line].append("label %s" % blkid)
            for inst in blk.body:
                lineno = inst.loc.line

                if isinstance(inst, ir.Assign):
                    if found_lifted_loop:
                        atype = "XXX Lifted Loop XXX"
                        found_lifted_loop = False
                    elif (
                        isinstance(inst.value, ir.Expr)
                        and inst.value.op == "call"
                    ):
                        atype = self.calltypes[inst.value]
                    elif isinstance(inst.value, ir.Const) and isinstance(
                        inst.value.value, numba.core.dispatcher.LiftedLoop
                    ):
                        atype = "XXX Lifted Loop XXX"
                        found_lifted_loop = True
                    else:
                        # TODO: fix parfor lowering so that typemap is valid.
                        atype = self.typemap.get(inst.target.name, "<missing>")

                    aline = "%s = %s :: %s" % (inst.target, inst.value, atype)
                elif isinstance(inst, ir.SetItem):
                    atype = self.calltypes[inst]
                    aline = "%s :: %s" % (inst, atype)
                else:
                    aline = "%s" % inst
                groupedinst[lineno].append(" %s" % aline)
        return groupedinst

    def annotate(self):
        """Return a plain-text listing of source interleaved with typed IR.

        When the source is unavailable, only the grouped IR is listed.
        """
        source = SourceLines(self.func_id.func)
        groupedinst = self.prepare_annotations()

        # Format annotations
        io = StringIO()
        with closing(io):
            if source.avail:
                print("# File: %s" % self.filename, file=io)
                for num in source:
                    srcline = source[num]
                    ind = _getindent(srcline)
                    print("%s# --- LINE %d --- " % (ind, num), file=io)
                    for inst in groupedinst[num]:
                        print("%s# %s" % (ind, inst), file=io)
                    print(file=io)
                    print(srcline, file=io)
                    print(file=io)
                if self.lifted:
                    print("# The function contains lifted loops", file=io)
                    for loop in self.lifted:
                        print(
                            "# Loop at line %d" % loop.get_source_location(),
                            file=io,
                        )
                        print(
                            "# Has %d overloads" % len(loop.overloads), file=io
                        )
                        for cres in loop.overloads.values():
                            print(cres.type_annotation, file=io)
            else:
                print("# Source code unavailable", file=io)
                for num in groupedinst:
                    for inst in groupedinst[num]:
                        print("%s" % (inst,), file=io)
                    print(file=io)

            # Must be read before ``closing`` closes the buffer.
            return io.getvalue()

    def html_annotate(self, outfile):
        """Render all accumulated annotation data as HTML into *outfile*.

        :param outfile: object with a ``write(str)`` method.
        :raises ImportError: if the ``jinja2`` package is not installed.
        """
        # ensure that annotation information is assembled
        self.annotate_raw()
        # make a deep copy ahead of the pending mutations
        func_data = copy.deepcopy(self.func_data)

        # annotate_raw() stores indentation as integer widths. The template
        # needs whitespace that survives HTML rendering; runs of plain spaces
        # would collapse, so emit non-breaking-space entities instead.
        key = "python_indent"
        for this_func in func_data.values():
            if key in this_func:
                this_func[key] = {
                    line: "&nbsp;" * amount
                    for line, amount in this_func[key].items()
                }

        key = "ir_indent"
        for this_func in func_data.values():
            if key in this_func:
                this_func[key] = {
                    line: ["&nbsp;" * amount for amount in ir_id]
                    for line, ir_id in this_func[key].items()
                }

        try:
            from jinja2 import Template
        except ImportError:
            raise ImportError("please install the 'jinja2' package")

        with open(self._template_path(), "r", encoding="utf-8") as template:
            html = template.read()

        template = Template(html)
        rendered = template.render(func_data=func_data)
        outfile.write(rendered)

    @staticmethod
    def _template_path():
        """Return the path of the ``template.html`` used by html_annotate()."""
        template_filename = os.path.join(
            os.path.dirname(__file__), "template.html"
        )
        if not os.path.exists(template_filename):
            # This package may not ship its own copy of the template. Fall
            # back to the one upstream numba ships next to its own
            # type_annotations module.
            import numba.core.annotations as _upstream_annotations

            template_filename = os.path.join(
                os.path.dirname(_upstream_annotations.__file__),
                "template.html",
            )
        return template_filename

    def annotate_raw(self):
        """
        This returns "raw" annotation information i.e. it has no output format
        specific markup included.

        The data is merged into the class-level ``func_data`` dict, which is
        returned. Entries are keyed by ``("<filename>:<line>", signature)``.
        For a lifted-loop function, its IR is merged instead into the data of
        the function it was lifted from, and that function's
        ``num_lifted_loops`` counter is decremented.
        """
        python_source = SourceLines(self.func_id.func)
        ir_lines = self.prepare_annotations()
        line_nums = list(python_source)
        lifted_lines = [loop.get_source_location() for loop in self.lifted]

        def add_ir_line(func_data, num, line):
            # Record one IR line (with a "pyobject" marker split off into its
            # own field) and its indentation width under source line *num*.
            line_str = line.strip()
            line_type = ""
            if line_str.endswith("pyobject"):
                line_str = line_str.replace("pyobject", "")
                line_type = "pyobject"
            func_data["ir_lines"][num].append((line_str, line_type))
            indent_len = len(_getindent(line))
            func_data["ir_indent"][num].append(indent_len)

        func_key = (
            self.func_id.filename + ":" + str(self.func_id.firstlineno + 1),
            self.signature,
        )
        if (
            self.lifted_from is not None
            and self.lifted_from[1]["num_lifted_loops"] > 0
        ):
            # This is a lifted loop function that is being compiled. Get the
            # numba ir for lines in loop function to use for annotating
            # original python function that the loop was lifted from.
            func_data = self.lifted_from[1]
            for num in line_nums:
                if num not in ir_lines.keys():
                    continue
                func_data["ir_lines"][num] = []
                func_data["ir_indent"][num] = []
                for line in ir_lines[num]:
                    add_ir_line(func_data, num, line)
                    if line.strip().endswith("pyobject"):
                        func_data["python_tags"][num] = "object_tag"
                        # If any pyobject line is found, make sure original python
                        # line that was marked as a lifted loop start line is tagged
                        # as an object line instead. Lifted loop start lines should
                        # only be marked as lifted loop lines if the lifted loop
                        # was successfully compiled in nopython mode.
                        func_data["python_tags"][self.lifted_from[0]] = (
                            "object_tag"
                        )

            # We're done with this lifted loop, so decrement lifted loop counter.
            # When lifted loop counter hits zero, that means we're ready to write
            # out annotations to html file.
            self.lifted_from[1]["num_lifted_loops"] -= 1

        elif func_key not in TypeAnnotation.func_data.keys():
            TypeAnnotation.func_data[func_key] = {}
            func_data = TypeAnnotation.func_data[func_key]

            for i, loop in enumerate(self.lifted):
                # Make sure that when we process each lifted loop function later,
                # we'll know where it originally came from.
                loop.lifted_from = (lifted_lines[i], func_data)
            func_data["num_lifted_loops"] = self.num_lifted_loops

            func_data["filename"] = self.filename
            func_data["funcname"] = self.func_id.func_name
            func_data["python_lines"] = []
            func_data["python_indent"] = {}
            func_data["python_tags"] = {}
            func_data["ir_lines"] = {}
            func_data["ir_indent"] = {}

            for num in line_nums:
                func_data["python_lines"].append(
                    (num, python_source[num].strip())
                )
                indent_len = len(_getindent(python_source[num]))
                func_data["python_indent"][num] = indent_len
                func_data["python_tags"][num] = ""
                func_data["ir_lines"][num] = []
                func_data["ir_indent"][num] = []

                for line in ir_lines[num]:
                    add_ir_line(func_data, num, line)
                    if num in lifted_lines:
                        func_data["python_tags"][num] = "lifted_tag"
                    elif line.strip().endswith("pyobject"):
                        func_data["python_tags"][num] = "object_tag"
        return self.func_data

    def __str__(self):
        return self.annotate()
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
re_longest_white_prefix = re.compile(r"^\s*")
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _getindent(text):
|
|
300
|
+
m = re_longest_white_prefix.match(text)
|
|
301
|
+
if not m:
|
|
302
|
+
return ""
|
|
303
|
+
else:
|
|
304
|
+
return " " * len(m.group(0))
|