numba-cuda 0.19.1__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic. Click here for more details.

Files changed (171):
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +4 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  129. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  130. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  134. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  135. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  139. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  140. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  141. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  142. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  143. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  145. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  146. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  147. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  148. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  151. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  152. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  153. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  154. numba_cuda/numba/cuda/tests/support.py +55 -15
  155. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  156. numba_cuda/numba/cuda/types.py +56 -0
  157. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  158. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  159. numba_cuda/numba/cuda/typing/context.py +751 -0
  160. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  161. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  162. numba_cuda/numba/cuda/typing/templates.py +7 -6
  163. numba_cuda/numba/cuda/ufuncs.py +3 -3
  164. numba_cuda/numba/cuda/utils.py +6 -112
  165. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/METADATA +2 -1
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/RECORD +170 -115
  167. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  168. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/WHEEL +0 -0
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE.numba +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,319 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from collections import namedtuple
5
+ from numba import types
6
+ from numba.core import ir
7
+ from numba.cuda.core import consts
8
+ from numba.core.analysis import compute_cfg_from_blocks
9
+
10
+
11
# Used to describe a nullified condition in dead branch pruning
#   condition    - the IR definition of the branch condition that was resolved
#   taken_br     - which side of the branch was taken after pruning
#   rewrite_stmt - True when the assignment holding the condition is to be
#                  rewritten to a constant, False when its value must be kept
nullified = namedtuple("nullified", "condition, taken_br, rewrite_stmt")
13
+
14
+
15
def dead_branch_prune(func_ir, called_args):
    """
    Removes dead branches based on constant inference from function args.
    This directly mutates the IR.

    func_ir is the IR
    called_args are the actual arguments with which the function is called

    Nothing is returned; all results are visible as mutations of ``func_ir``
    (its blocks, ``_definitions`` and, if anything was pruned, ``_consts``).
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from numba.cuda.core.ir_utils import (
        get_definition,
        guard,
        find_const,
        GuardException,
    )

    # Debug verbosity: > 0 prints each pruning decision, > 1 also dumps the IR
    # before and after the pass.
    DEBUG = 0

    def find_branches(func_ir):
        """
        Collect ``(branch, condition, block)`` triples for every block that
        ends in an ``ir.Branch`` whose predicate is a call to the global
        ``bool`` and whose first call argument has a resolvable definition.
        """
        # find *all* branches
        branches = []
        for blk in func_ir.blocks.values():
            branch_or_jump = blk.body[-1]
            if isinstance(branch_or_jump, ir.Branch):
                branch = branch_or_jump
                pred = guard(get_definition, func_ir, branch.cond.name)
                if pred is not None and getattr(pred, "op", None) == "call":
                    function = guard(get_definition, func_ir, pred.func)
                    if (
                        function is not None
                        and isinstance(function, ir.Global)
                        and function.value is bool
                    ):
                        condition = guard(get_definition, func_ir, pred.args[0])
                        if condition is not None:
                            branches.append((branch, condition, blk))
        return branches

    def do_prune(take_truebr, blk):
        """
        Replace the terminator of ``blk`` with a jump to the surviving target
        and return 1 if the true side was kept, 0 otherwise.

        NOTE: ``branch`` is not a parameter; it is read from the enclosing
        scope, i.e. it is whatever the main loop over ``branch_info`` below
        has bound at the time of the call.
        """
        keep = branch.truebr if take_truebr else branch.falsebr
        # replace the branch with a direct jump
        jmp = ir.Jump(keep, loc=branch.loc)
        blk.body[-1] = jmp
        return 1 if keep == branch.truebr else 0

    def prune_by_type(branch, condition, blk, *conds):
        """
        Prune when at least one operand is a ``types.NoneType``.

        Returns ``(True, taken)`` on success, where ``taken`` is the 1/0 value
        from ``do_prune``, or ``(False, None)`` if neither operand is a
        NoneType or evaluating ``condition.fn`` raised.
        """
        # this prunes a given branch and fixes up the IR
        # at least one needs to be a NoneType
        lhs_cond, rhs_cond = conds
        lhs_none = isinstance(lhs_cond, types.NoneType)
        rhs_none = isinstance(rhs_cond, types.NoneType)
        if lhs_none or rhs_none:
            try:
                take_truebr = condition.fn(lhs_cond, rhs_cond)
            except Exception:
                # the comparison could not be evaluated, leave the branch alone
                return False, None
            if DEBUG > 0:
                kill = branch.falsebr if take_truebr else branch.truebr
                print(
                    "Pruning %s" % kill,
                    branch,
                    lhs_cond,
                    rhs_cond,
                    condition.fn,
                )
            taken = do_prune(take_truebr, blk)
            return True, taken
        return False, None

    def prune_by_value(branch, condition, blk, *conds):
        """
        Prune by evaluating ``condition.fn`` on two resolved constant values.

        Returns ``(True, value)`` where ``value`` is the actual result of
        ``condition.fn`` (not the 1/0 from ``do_prune``), or ``(False, None)``
        if evaluating it raised.
        """
        lhs_cond, rhs_cond = conds
        try:
            take_truebr = condition.fn(lhs_cond, rhs_cond)
        except Exception:
            # the comparison could not be evaluated, leave the branch alone
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, lhs_cond, rhs_cond, condition.fn)
        do_prune(take_truebr, blk)
        # It is not safe to rewrite the predicate to a nominal value based on
        # which branch is taken, the rewritten const predicate needs to
        # hold the actual computed const value as something else may refer to
        # it!
        return True, take_truebr

    def prune_by_predicate(branch, pred, blk):
        """
        Prune a branch whose predicate is itself a constant-like IR node
        (``ir.Const``, ``ir.FreeVar`` or ``ir.Global``), using the truthiness
        of ``pred.value``.

        Returns ``(True, taken)`` or ``(False, None)`` when a ``TypeError`` is
        raised (including the explicit one for an unexpected node type).
        """
        try:
            # Just to prevent accidents, whilst already guarded, ensure this
            # is an ir.Const
            if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)):
                raise TypeError("Expected constant Numba IR node")
            take_truebr = bool(pred.value)
        except TypeError:
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, pred)
        taken = do_prune(take_truebr, blk)
        return True, taken

    # Sentinel type: an instance means "could not be resolved to a constant".
    class Unknown(object):
        pass

    def resolve_input_arg_const(input_arg_idx):
        """
        Resolves an input arg to a constant (if possible)

        Returns a ``types.NoneType`` instance for a None argument or an
        omitted argument whose default is None, otherwise the argument type's
        ``literal_type`` attribute if it has one, otherwise an ``Unknown``
        instance.
        """
        input_arg_ty = called_args[input_arg_idx]

        # comparing to None?
        if isinstance(input_arg_ty, types.NoneType):
            return input_arg_ty

        # is it a kwarg default
        if isinstance(input_arg_ty, types.Omitted):
            val = input_arg_ty.value
            if isinstance(val, types.NoneType):
                return val
            elif val is None:
                return types.NoneType("none")

        # literal type, return the type itself so comparisons like `x == None`
        # still work as e.g. x = types.int64 will never be None/NoneType so
        # the branch can still be pruned
        return getattr(input_arg_ty, "literal_type", Unknown())

    if DEBUG > 1:
        print("before".center(80, "-"))
        print(func_ir.dump())

    # Record every phi expression up front, before any pruning:
    #   phi2lbl  maps the phi expression to the label of the block holding it
    #   phi2asgn maps the phi expression to the assignment statement holding it
    phi2lbl = dict()
    phi2asgn = dict()
    for lbl, blk in func_ir.blocks.items():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Expr) and stmt.value.op == "phi":
                    phi2lbl[stmt.value] = lbl
                    phi2asgn[stmt.value] = stmt

    # This looks for branches where:
    # at least one arg of the condition is in input args and const
    # at least one arg of the condition is a const
    # if the condition is met it will replace the branch with a jump
    branch_info = find_branches(func_ir)
    # stores conditions that have no impact post prune
    nullified_conditions = []

    # NOTE: `branch` bound by this loop is the one `do_prune` reads.
    for branch, condition, blk in branch_info:
        const_conds = []
        if isinstance(condition, ir.Expr) and condition.op == "binop":
            # default strategy; switched to prune_by_type as soon as one
            # operand turns out to be a function argument
            prune = prune_by_value
            for arg in [condition.lhs, condition.rhs]:
                resolved_const = Unknown()
                arg_def = guard(get_definition, func_ir, arg)
                if isinstance(arg_def, ir.Arg):
                    # it's an e.g. literal argument to the function
                    resolved_const = resolve_input_arg_const(arg_def.index)
                    prune = prune_by_type
                else:
                    # it's some const argument to the function, cannot use guard
                    # here as the const itself may be None
                    try:
                        resolved_const = find_const(func_ir, arg)
                        if resolved_const is None:
                            resolved_const = types.NoneType("none")
                    except GuardException:
                        pass

                if not isinstance(resolved_const, Unknown):
                    const_conds.append(resolved_const)

            # lhs/rhs are consts
            if len(const_conds) == 2:
                # prune the branch, switch the branch for an unconditional jump
                prune_stat, taken = prune(branch, condition, blk, *const_conds)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(
                        nullified(condition, taken, True)
                    )
        else:
            # see if this is a branch on a constant value predicate
            resolved_const = Unknown()
            try:
                pred_call = get_definition(func_ir, branch.cond)
                resolved_const = find_const(func_ir, pred_call.args[0])
                if resolved_const is None:
                    resolved_const = types.NoneType("none")
            except GuardException:
                pass

            if not isinstance(resolved_const, Unknown):
                prune_stat, taken = prune_by_predicate(branch, condition, blk)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append(
                        nullified(condition, taken, False)
                    )

    # 'ERE BE DRAGONS...
    # It is the evaluation of the condition expression that often trips up type
    # inference, so ideally it would be removed as it is effectively rendered
    # dead by the unconditional jump if a branch was pruned. However, there may
    # be references to the condition that exist in multiple places (e.g. dels)
    # and we cannot run DCE here as typing has not taken place to give enough
    # information to run DCE safely. Upshot of all this is the condition gets
    # rewritten below into a benign const that typing will be happy with and DCE
    # can remove it and its reference post typing when it is safe to do so
    # (if desired). It is required that the const is assigned a value that
    # indicates the branch taken as its mutated value would be read in the case
    # of object mode fall back in place of the condition itself. For
    # completeness the func_ir._definitions and ._consts are also updated to
    # make the IR state self consistent.

    deadcond = [x.condition for x in nullified_conditions]
    for _, cond, blk in branch_info:
        if cond in deadcond:
            for x in blk.body:
                # identity check: only the assignment holding this exact
                # condition object is rewritten
                if isinstance(x, ir.Assign) and x.value is cond:
                    # rewrite the condition as a true/false bit
                    nullified_info = nullified_conditions[deadcond.index(cond)]
                    # only do a rewrite of conditions, predicates need to retain
                    # their value as they may be used later.
                    if nullified_info.rewrite_stmt:
                        branch_bit = nullified_info.taken_br
                        x.value = ir.Const(branch_bit, loc=x.loc)
                        # update the specific definition to the new const
                        defns = func_ir._definitions[x.target.name]
                        repl_idx = defns.index(cond)
                        defns[repl_idx] = x.value

    # Check post dominators of dead nodes from the original CFG for use of
    # vars that are being removed in the dead blocks which might be referred to
    # by phi nodes.
    #
    # Multiple things to fix up:
    #
    # 1. Cases like:
    #
    #    A        A
    #    |\       |
    #    | B  --> B
    #    |/       |
    #    C        C
    #
    # i.e. the branch is dead but the block is still alive. In this case CFG
    # simplification will fuse A-B-C and any phi in C can be updated as a
    # direct assignment from the last assigned version in the dominators of the
    # fused block.
    #
    # 2. Cases like:
    #
    #     A        A
    #    / \       |
    #   B   C  --> B
    #    \ /       |
    #     D        D
    #
    # i.e. the block C is dead. In this case the phis in D need updating to
    # reflect the collapse of the phi condition. This should result in a direct
    # assignment of the surviving version in B to the LHS of the phi in D.

    new_cfg = compute_cfg_from_blocks(func_ir.blocks)
    dead_blocks = new_cfg.dead_nodes()

    # for all phis that are still in live blocks.
    for phi, lbl in phi2lbl.items():
        if lbl in dead_blocks:
            continue
        new_incoming = [x[0] for x in new_cfg.predecessors(lbl)]
        if set(new_incoming) != set(phi.incoming_blocks):
            # Something has changed in the CFG...
            if len(new_incoming) == 1:
                # There's now just one incoming. Replace the PHI node by a
                # direct assignment
                idx = phi.incoming_blocks.index(new_incoming[0])
                phi2asgn[phi].value = phi.incoming_values[idx]
            else:
                # There's more than one incoming still, then look through the
                # incoming and remove dead
                ic_val_tmp = []
                ic_blk_tmp = []
                for ic_val, ic_blk in zip(
                    phi.incoming_values, phi.incoming_blocks
                ):
                    if ic_blk in dead_blocks:
                        continue
                    else:
                        ic_val_tmp.append(ic_val)
                        ic_blk_tmp.append(ic_blk)
                # mutate the phi's lists in place rather than rebinding them
                phi.incoming_values.clear()
                phi.incoming_values.extend(ic_val_tmp)
                phi.incoming_blocks.clear()
                phi.incoming_blocks.extend(ic_blk_tmp)

    # Remove dead blocks, this is safe as it relies on the CFG only.
    for dead in dead_blocks:
        del func_ir.blocks[dead]

    # if conditions were nullified then consts were rewritten, update
    if nullified_conditions:
        func_ir._consts = consts.ConstantInference(func_ir)

    if DEBUG > 1:
        print("after".center(80, "-"))
        print(func_ir.dump())
File without changes (numba_cuda/numba/cuda/core/annotations/__init__.py)
@@ -0,0 +1,304 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from collections import defaultdict, OrderedDict
5
+ from collections.abc import Mapping
6
+ from contextlib import closing
7
+ import copy
8
+ import inspect
9
+ import os
10
+ import re
11
+ import textwrap
12
+ from io import StringIO
13
+
14
+ import numba.core.dispatcher
15
+ from numba.core import ir
16
+
17
+
18
class SourceLines(Mapping):
    """
    Read-only mapping from an absolute source line number to the text of that
    line for a given function.

    The source is fetched with ``inspect.getsourcelines`` and dedented. If the
    source cannot be retrieved (``OSError``) the mapping is empty and
    ``avail`` is False.

    Unlike a regular mapping, looking up a line number that is outside the
    function's source does not raise ``KeyError``; it returns ``""``.
    """

    def __init__(self, func):
        try:
            lines, startno = inspect.getsourcelines(func)
        except OSError:
            # source not retrievable: behave as an empty mapping
            self.lines = ()
            self.startno = 0
        else:
            self.lines = textwrap.dedent("".join(lines)).splitlines()
            self.startno = startno

    def __getitem__(self, lineno):
        """Return the right-stripped source line at ``lineno``, or ``""``."""
        offset = lineno - self.startno
        # A negative offset must not reach the sequence: Python would wrap it
        # around and hand back a line from the *end* of the function for a
        # line number that lies before its start.
        if offset < 0:
            return ""
        try:
            return self.lines[offset].rstrip()
        except IndexError:
            return ""

    def __iter__(self):
        """Iterate over the absolute line numbers covered by the source."""
        return iter((self.startno + i) for i in range(len(self.lines)))

    def __len__(self):
        """Number of source lines held."""
        return len(self.lines)

    @property
    def avail(self):
        """True when at least one source line was retrieved."""
        return bool(self.lines)
44
+
45
+
46
class TypeAnnotation(object):
    """
    Gathers the IR statements and type information of one compiled function
    and renders them either as annotated source text (``annotate`` /
    ``__str__``) or through an HTML template (``html_annotate``).
    """

    # func_data dict stores annotation data for all functions that are
    # compiled. We store the data in the TypeAnnotation class since a new
    # TypeAnnotation instance is created for each function that is compiled.
    # For every function that is compiled, we add the type annotation data to
    # this dict and write the html annotation file to disk (rewrite the html
    # file for every function since we don't know if this is the last function
    # to be compiled).
    func_data = OrderedDict()

    def __init__(
        self,
        func_ir,
        typemap,
        calltypes,
        lifted,
        lifted_from,
        args,
        return_type,
        html_output=None,
    ):
        """
        func_ir supplies ``func_id``, ``blocks`` and ``loc``
        typemap is queried by variable name for the type of an assignment
        calltypes is indexed by call expressions and ``ir.SetItem`` statements
        lifted is the sized collection of lifted loops of this function
        lifted_from is None, or a pair indexed as ``[0]`` (a line number key)
            and ``[1]`` (the func_data dict of the originating function)
        args and return_type are only stringified into ``self.signature``
        html_output is accepted but not used by this constructor
        """
        self.func_id = func_ir.func_id
        self.blocks = func_ir.blocks
        self.typemap = typemap
        self.calltypes = calltypes
        self.filename = func_ir.loc.filename
        self.linenum = str(func_ir.loc.line)
        self.signature = str(args) + " -> " + str(return_type)

        # lifted loop information
        self.lifted = lifted
        self.num_lifted_loops = len(lifted)

        # If this is a lifted loop function that is being compiled, lifted_from
        # points to annotation data from function that this loop lifted function
        # was lifted from. This is used to stick lifted loop annotations back
        # into original function.
        self.lifted_from = lifted_from

    def prepare_annotations(self):
        """
        Group a textual form of every IR statement by source line number.

        Returns a ``defaultdict(list)`` keyed by line number; each block
        contributes a ``"label <id>"`` entry at the block's line followed by
        one entry per statement at the statement's line.
        """
        # Prepare annotations
        groupedinst = defaultdict(list)
        # set when a LiftedLoop constant was just seen so that the *next*
        # assignment is also labelled as a lifted loop
        found_lifted_loop = False
        # for blkid, blk in self.blocks.items():
        for blkid in sorted(self.blocks.keys()):
            blk = self.blocks[blkid]
            groupedinst[blk.loc.line].append("label %s" % blkid)
            for inst in blk.body:
                lineno = inst.loc.line

                if isinstance(inst, ir.Assign):
                    if found_lifted_loop:
                        atype = "XXX Lifted Loop XXX"
                        found_lifted_loop = False
                    elif (
                        isinstance(inst.value, ir.Expr)
                        and inst.value.op == "call"
                    ):
                        atype = self.calltypes[inst.value]
                    elif isinstance(inst.value, ir.Const) and isinstance(
                        inst.value.value, numba.core.dispatcher.LiftedLoop
                    ):
                        atype = "XXX Lifted Loop XXX"
                        found_lifted_loop = True
                    else:
                        # TODO: fix parfor lowering so that typemap is valid.
                        atype = self.typemap.get(inst.target.name, "<missing>")

                    aline = "%s = %s :: %s" % (inst.target, inst.value, atype)
                elif isinstance(inst, ir.SetItem):
                    atype = self.calltypes[inst]
                    aline = "%s :: %s" % (inst, atype)
                else:
                    aline = "%s" % inst
                groupedinst[lineno].append(" %s" % aline)
        return groupedinst

    def annotate(self):
        """
        Return the annotation as text: each source line preceded by comment
        lines holding the IR grouped under it, or, when the source is not
        available, the grouped IR alone.
        """
        source = SourceLines(self.func_id.func)
        # if not source.avail:
        #     return "Source code unavailable"

        groupedinst = self.prepare_annotations()

        # Format annotations
        io = StringIO()
        with closing(io):
            if source.avail:
                print("# File: %s" % self.filename, file=io)
                for num in source:
                    srcline = source[num]
                    # indent the annotation comments like the source line
                    ind = _getindent(srcline)
                    print("%s# --- LINE %d --- " % (ind, num), file=io)
                    for inst in groupedinst[num]:
                        print("%s# %s" % (ind, inst), file=io)
                    print(file=io)
                    print(srcline, file=io)
                    print(file=io)
                if self.lifted:
                    print("# The function contains lifted loops", file=io)
                    for loop in self.lifted:
                        print(
                            "# Loop at line %d" % loop.get_source_location(),
                            file=io,
                        )
                        print(
                            "# Has %d overloads" % len(loop.overloads), file=io
                        )
                        for cres in loop.overloads.values():
                            print(cres.type_annotation, file=io)
            else:
                print("# Source code unavailable", file=io)
                for num in groupedinst:
                    for inst in groupedinst[num]:
                        print("%s" % (inst,), file=io)
                    print(file=io)

            # must be read before the `with closing(io)` block exits
            return io.getvalue()

    def html_annotate(self, outfile):
        """
        Render the collected ``func_data`` with the ``template.html`` file
        located next to this module and write the result to ``outfile``.

        Raises ImportError if ``jinja2`` cannot be imported.
        """
        # ensure that annotation information is assembled
        self.annotate_raw()
        # make a deep copy ahead of the pending mutations
        func_data = copy.deepcopy(self.func_data)

        # turn the numeric indent widths into runs of HTML non-breaking spaces
        key = "python_indent"
        for this_func in func_data.values():
            if key in this_func:
                idents = {}
                for line, amount in this_func[key].items():
                    idents[line] = "&nbsp;" * amount
                this_func[key] = idents

        key = "ir_indent"
        for this_func in func_data.values():
            if key in this_func:
                idents = {}
                for line, ir_id in this_func[key].items():
                    idents[line] = ["&nbsp;" * amount for amount in ir_id]
                this_func[key] = idents

        try:
            from jinja2 import Template
        except ImportError:
            raise ImportError("please install the 'jinja2' package")

        root = os.path.join(os.path.dirname(__file__))
        template_filename = os.path.join(root, "template.html")
        with open(template_filename, "r") as template:
            html = template.read()

        # `template` is rebound here from the file handle to the Template
        template = Template(html)
        rendered = template.render(func_data=func_data)
        outfile.write(rendered)

    def annotate_raw(self):
        """
        This returns "raw" annotation information i.e. it has no output format
        specific markup included.

        The data is stored in (and returned as) the class-level ``func_data``
        dict, which is shared by all instances.
        """
        python_source = SourceLines(self.func_id.func)
        ir_lines = self.prepare_annotations()
        line_nums = [num for num in python_source]
        lifted_lines = [l.get_source_location() for l in self.lifted]

        def add_ir_line(func_data, line):
            # NOTE: `num` is not a parameter; it is read from the enclosing
            # scope, i.e. the current value of the `for num in line_nums`
            # loop that is running when this helper is called.
            line_str = line.strip()
            line_type = ""
            if line_str.endswith("pyobject"):
                line_str = line_str.replace("pyobject", "")
                line_type = "pyobject"
            func_data["ir_lines"][num].append((line_str, line_type))
            indent_len = len(_getindent(line))
            func_data["ir_indent"][num].append(indent_len)

        func_key = (
            self.func_id.filename + ":" + str(self.func_id.firstlineno + 1),
            self.signature,
        )
        if (
            self.lifted_from is not None
            and self.lifted_from[1]["num_lifted_loops"] > 0
        ):
            # This is a lifted loop function that is being compiled. Get the
            # numba ir for lines in loop function to use for annotating
            # original python function that the loop was lifted from.
            func_data = self.lifted_from[1]
            for num in line_nums:
                if num not in ir_lines.keys():
                    continue
                func_data["ir_lines"][num] = []
                func_data["ir_indent"][num] = []
                for line in ir_lines[num]:
                    add_ir_line(func_data, line)
                    if line.strip().endswith("pyobject"):
                        func_data["python_tags"][num] = "object_tag"
                        # If any pyobject line is found, make sure original python
                        # line that was marked as a lifted loop start line is tagged
                        # as an object line instead. Lifted loop start lines should
                        # only be marked as lifted loop lines if the lifted loop
                        # was successfully compiled in nopython mode.
                        func_data["python_tags"][self.lifted_from[0]] = (
                            "object_tag"
                        )

            # We're done with this lifted loop, so decrement lifted loop counter.
            # When lifted loop counter hits zero, that means we're ready to write
            # out annotations to html file.
            self.lifted_from[1]["num_lifted_loops"] -= 1

        elif func_key not in TypeAnnotation.func_data.keys():
            # first time this (location, signature) pair is seen: create and
            # fill its entry; an already recorded key is left untouched
            TypeAnnotation.func_data[func_key] = {}
            func_data = TypeAnnotation.func_data[func_key]

            for i, loop in enumerate(self.lifted):
                # Make sure that when we process each lifted loop function later,
                # we'll know where it originally came from.
                loop.lifted_from = (lifted_lines[i], func_data)
            func_data["num_lifted_loops"] = self.num_lifted_loops

            func_data["filename"] = self.filename
            func_data["funcname"] = self.func_id.func_name
            func_data["python_lines"] = []
            func_data["python_indent"] = {}
            func_data["python_tags"] = {}
            func_data["ir_lines"] = {}
            func_data["ir_indent"] = {}

            for num in line_nums:
                func_data["python_lines"].append(
                    (num, python_source[num].strip())
                )
                indent_len = len(_getindent(python_source[num]))
                func_data["python_indent"][num] = indent_len
                func_data["python_tags"][num] = ""
                func_data["ir_lines"][num] = []
                func_data["ir_indent"][num] = []

                for line in ir_lines[num]:
                    add_ir_line(func_data, line)
                    if num in lifted_lines:
                        func_data["python_tags"][num] = "lifted_tag"
                    elif line.strip().endswith("pyobject"):
                        func_data["python_tags"][num] = "object_tag"
        return self.func_data

    def __str__(self):
        """Same text as ``annotate()``."""
        return self.annotate()
294
+
295
+
296
+ re_longest_white_prefix = re.compile(r"^\s*")
297
+
298
+
299
+ def _getindent(text):
300
+ m = re_longest_white_prefix.match(text)
301
+ if not m:
302
+ return ""
303
+ else:
304
+ return " " * len(m.group(0))