numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic. Click here for more details.

Files changed (172)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +7 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
  129. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  130. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  134. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  139. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  141. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  143. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  146. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  147. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  148. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  151. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  152. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  153. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  154. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  155. numba_cuda/numba/cuda/tests/support.py +55 -15
  156. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  157. numba_cuda/numba/cuda/types.py +56 -0
  158. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  159. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  160. numba_cuda/numba/cuda/typing/context.py +751 -0
  161. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  162. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  163. numba_cuda/numba/cuda/typing/templates.py +7 -6
  164. numba_cuda/numba/cuda/ufuncs.py +3 -3
  165. numba_cuda/numba/cuda/utils.py +6 -112
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
  167. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
  168. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
  172. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,26 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ """
5
+ A subpackage hosting Numba IR rewrite passes.
6
+ """
7
+
8
+ from .registry import register_rewrite, rewrite_registry, Rewrite
9
+
10
+ # Register various built-in rewrite passes
11
+ from numba.cuda.core.rewrites import (
12
+ static_getitem,
13
+ static_raise,
14
+ static_binop,
15
+ ir_print,
16
+ )
17
+
18
# Public names of this package: the built-in rewrite submodules (imported
# above so that their rewrites are registered) followed by the registry API.
__all__ = (
    "static_getitem",
    "static_raise",
    "static_binop",
    "ir_print",
) + (
    "register_rewrite",
    "rewrite_registry",
    "Rewrite",
)
@@ -0,0 +1,90 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from numba.core import errors, ir
5
+ from numba.cuda.core.rewrites import register_rewrite, Rewrite
6
+
7
+
8
@register_rewrite("before-inference")
class RewritePrintCalls(Rewrite):
    """
    Turn assignments of the form ``var = call print(...)`` into a dedicated
    ``ir.Print`` node followed by ``var = const(None)``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        found = {}
        self.prints = found
        self.block = block
        # Look at every assignment whose right-hand side is a call to the
        # builtin print().
        for assign in block.find_insts(ir.Assign):
            call = assign.value
            if not (isinstance(call, ir.Expr) and call.op == "call"):
                continue
            try:
                callee = func_ir.infer_constant(call.func)
            except errors.ConstantInferenceError:
                continue
            if callee is not print:
                continue
            if call.kws:
                # Only positional arguments are supported.
                msg = (
                    "Numba's print() function implementation does not "
                    "support keyword arguments."
                )
                raise errors.UnsupportedError(msg, assign.loc)
            found[assign] = call
        return bool(found)

    def apply(self):
        """
        Build a new block where each matched ``var = call print(...)`` is
        replaced by ``print(...)`` then ``var = const(None)``; every other
        statement is carried over unchanged.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            if stmt not in self.prints:
                rewritten.append(stmt)
                continue
            call = self.prints[stmt]
            rewritten.append(
                ir.Print(args=call.args, vararg=call.vararg, loc=call.loc)
            )
            rewritten.append(
                ir.Assign(
                    value=ir.Const(None, loc=call.loc),
                    target=stmt.target,
                    loc=stmt.loc,
                )
            )
        return rewritten
59
+
60
+
61
@register_rewrite("before-inference")
class DetectConstPrintArguments(Rewrite):
    """
    Find the positional arguments of ``ir.Print`` nodes that can be inferred
    as constants and record them, by position, on the node's ``consts``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        found = {}
        self.consts = found
        self.block = block
        for print_node in block.find_insts(ir.Print):
            # A node that already carries constants was rewritten earlier.
            if print_node.consts:
                continue
            for position, arg in enumerate(print_node.args):
                try:
                    value = func_ir.infer_constant(arg)
                except errors.ConstantInferenceError:
                    continue
                found.setdefault(print_node, {})[position] = value

        return bool(found)

    def apply(self):
        """
        Attach the constants collected by match() to their Print nodes
        (in place) and return the same block.
        """
        for stmt in self.block.body:
            if stmt in self.consts:
                stmt.consts = self.consts[stmt]
        return self.block
@@ -0,0 +1,104 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from collections import defaultdict
5
+
6
+ from numba.core import config
7
+
8
+
9
class Rewrite:
    """Abstract base class for Numba IR rewrites.

    Subclasses override ``match`` to inspect a block and report whether the
    rewrite applies, and ``apply`` to produce the rewritten block.
    """

    def __init__(self, state=None):
        """Accept the compiler *state*; the base class does not keep it."""

    def match(self, func_ir, block, typemap, calltypes) -> bool:
        """Report whether this rewrite applies to *block*.

        The base implementation never matches.
        """
        return False

    def apply(self):
        """Return the rewritten block after a successful ``match``.

        Must be overridden; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError("Abstract Rewrite.apply() called!")
27
+
28
+
29
class RewriteRegistry:
    """Registry of Rewrite subclasses keyed by pipeline stage ("kind").

    Only the kinds listed in ``_kinds`` are accepted; ``register`` rejects
    anything else.
    """

    _kinds = frozenset(["before-inference", "after-inference"])

    def __init__(self):
        """Start with no registered rewrites.

        ``rewrites`` maps each kind to the list of Rewrite subclasses
        registered for it, in registration order.
        """
        self.rewrites = defaultdict(list)

    def register(self, kind):
        """Return a class decorator that registers a Rewrite under *kind*.

        Raises KeyError right away when *kind* is not a known stage. The
        returned decorator raises TypeError for a class that is not a
        subclass of Rewrite, and otherwise returns the class unchanged.
        """
        if kind not in self._kinds:
            raise KeyError("invalid kind %r" % (kind,))

        def do_register(rewrite_cls):
            if issubclass(rewrite_cls, Rewrite):
                # Indexed only here, so an unused decorator adds no entry.
                self.rewrites[kind].append(rewrite_cls)
                return rewrite_cls
            raise TypeError(
                "{0} is not a subclass of Rewrite".format(rewrite_cls)
            )

        return do_register

    def apply(self, kind, state):
        """Run every rewrite registered under *kind* over ``state.func_ir``.

        Each rewrite class is instantiated once with *state*. It is then
        applied to every block, repeatedly, until its ``match`` stops
        returning a truthy value. Rewritten blocks replace the originals in
        ``state.func_ir.blocks``.

        Afterwards, blocks that no longer compare equal to their
        pre-rewrite version are verified, and a PostProcessor is run over
        the function IR.
        """
        assert kind in self._kinds
        blocks = state.func_ir.blocks
        snapshot = blocks.copy()

        for rewrite_cls in self.rewrites[kind]:
            rewriter = rewrite_cls(state)
            # Blocks are taken LIFO; a rewritten block is pushed back so the
            # same rewrite is retried on it until it no longer matches.
            pending = list(blocks.items())
            while pending:
                label, current = pending.pop()
                if not rewriter.match(
                    state.func_ir, current, state.typemap, state.calltypes
                ):
                    continue
                if config.DEBUG or config.DUMP_IR:
                    print("_" * 70)
                    print("REWRITING (%s):" % rewrite_cls.__name__)
                    current.dump()
                    print("_" * 60)
                replacement = rewriter.apply()
                blocks[label] = replacement
                pending.append((label, replacement))
                if config.DEBUG or config.DUMP_IR:
                    replacement.dump()
                    print("_" * 70)

        # Sanity-check every block that differs from its original version.
        for label, current in blocks.items():
            if current != snapshot[label]:
                current.verify()

        # Some passes, e.g. _inline_const_arraycall are known to occasionally
        # do invalid things WRT ir.Del, others, e.g. RewriteArrayExprs do valid
        # things with ir.Del, but the placement is not optimal. The lines below
        # fix-up the IR so that ref counts are valid and optimally placed,
        # see #4093 for context. This has to be run here opposed to in
        # apply() as the CFG needs computing so full IR is needed.
        from numba.core import postproc

        postproc.PostProcessor(state.func_ir).run()


# The registry that the built-in rewrite modules register into.
rewrite_registry = RewriteRegistry()
# Decorator shortcut: ``@register_rewrite(kind)`` is the bound
# ``rewrite_registry.register`` method.
register_rewrite = rewrite_registry.register
@@ -0,0 +1,40 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from numba.core import errors, ir
5
+ from numba.cuda.core.rewrites import register_rewrite, Rewrite
6
+
7
+
8
@register_rewrite("before-inference")
class DetectStaticBinops(Rewrite):
    """
    Record an inferable constant right-hand operand of selected binops on the
    expression's ``static_rhs`` attribute.
    """

    # Operators whose right operand benefits from being known as a constant.
    rhs_operators = {"**"}

    def match(self, func_ir, block, typemap, calltypes):
        # static_lhs is never filled in by this rewrite; it is still reset
        # here and consulted in the return value below.
        self.static_lhs = {}
        self.static_rhs = {}
        self.block = block
        for binop in block.find_exprs(op="binop"):
            if binop.fn not in self.rhs_operators:
                continue
            if binop.static_rhs is not ir.UNDEFINED:
                # Already resolved on a previous pass.
                continue
            try:
                self.static_rhs[binop] = func_ir.infer_constant(binop.rhs)
            except errors.ConstantInferenceError:
                pass

        return bool(self.static_lhs) or bool(self.static_rhs)

    def apply(self):
        """
        Copy each constant found by match() onto its binop expression and
        return the (same, mutated) block.
        """
        for binop, value in self.static_rhs.items():
            binop.static_rhs = value
        return self.block
@@ -0,0 +1,187 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from numba.core import errors, types, ir
5
+ from numba.cuda.core.rewrites import register_rewrite, Rewrite
6
+
7
+
8
@register_rewrite("before-inference")
class RewriteConstGetitems(Rewrite):
    """
    Replace ``getitem(value=arr, index=$constXX)``, where ``$constXX`` can be
    inferred as a constant, with
    ``static_getitem(value=arr, index=<constant value>)``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        found = {}
        self.getitems = found
        self.block = block
        # Remember every getitem whose index resolves to a constant.
        for getitem in block.find_exprs(op="getitem"):
            if getitem.op != "getitem":
                continue
            try:
                found[getitem] = func_ir.infer_constant(getitem.index)
            except errors.ConstantInferenceError:
                pass

        return bool(found)

    def apply(self):
        """
        Build a new block in which each matched getitem assignment uses a
        static_getitem expression; all other statements are reused as-is.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            if isinstance(stmt, ir.Assign) and stmt.value in self.getitems:
                old = stmt.value
                static = ir.Expr.static_getitem(
                    value=old.value,
                    index=self.getitems[old],
                    index_var=old.index,
                    loc=old.loc,
                )
                stmt = ir.Assign(
                    value=static, target=stmt.target, loc=stmt.loc
                )
            rewritten.append(stmt)
        return rewritten
53
+
54
+
55
@register_rewrite("after-inference")
class RewriteStringLiteralGetitems(Rewrite):
    """
    Replace ``getitem(value=arr, index=$XX)``, where ``$XX`` is typed as a
    StringLiteral, with ``static_getitem(value=arr, index=<literal value>)``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Collect the getitem expressions whose index variable is typed as a
        string literal.
        """
        found = {}
        self.getitems = found
        self.block = block
        self.calltypes = calltypes
        for getitem in block.find_exprs(op="getitem"):
            if getitem.op != "getitem":
                continue
            index_ty = typemap[getitem.index.name]
            if isinstance(index_ty, types.StringLiteral):
                found[getitem] = (getitem.index, index_ty.literal_value)

        return bool(found)

    def apply(self):
        """
        Build a new block where each matched getitem becomes a static_getitem
        indexed by the string's literal value; the call type of the old
        expression is copied to the new one.
        """
        rewritten = ir.Block(self.block.scope, self.block.loc)
        for stmt in self.block.body:
            if isinstance(stmt, ir.Assign) and stmt.value in self.getitems:
                old = stmt.value
                _, literal = self.getitems[old]
                static = ir.Expr.static_getitem(
                    value=old.value,
                    index=literal,
                    index_var=old.index,
                    loc=old.loc,
                )
                self.calltypes[static] = self.calltypes[old]
                stmt = ir.Assign(
                    value=static, target=stmt.target, loc=stmt.loc
                )
            rewritten.append(stmt)
        return rewritten
102
+
103
+
104
@register_rewrite("after-inference")
class RewriteStringLiteralSetitems(Rewrite):
    """
    Replace ``setitem(value=arr, index=$XX, value=)``, where ``$XX`` is typed
    as a StringLiteral, with
    ``static_setitem(value=arr, index=<literal value>, value=)``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Collect the SetItem statements whose index variable is typed as a
        string literal.
        """
        found = {}
        self.setitems = found
        self.block = block
        self.calltypes = calltypes
        for setitem in block.find_insts(ir.SetItem):
            index_ty = typemap[setitem.index.name]
            if isinstance(index_ty, types.StringLiteral):
                found[setitem] = (setitem.index, index_ty.literal_value)

        return bool(found)

    def apply(self):
        """
        Build a new block where each matched SetItem becomes a StaticSetItem
        indexed by the string's literal value; the call type of the old
        statement is copied to the new one.
        """
        rewritten = ir.Block(self.block.scope, self.block.loc)
        for stmt in self.block.body:
            if isinstance(stmt, ir.SetItem) and stmt in self.setitems:
                _, literal = self.setitems[stmt]
                static = ir.StaticSetItem(
                    target=stmt.target,
                    index=literal,
                    index_var=stmt.index,
                    value=stmt.value,
                    loc=stmt.loc,
                )
                self.calltypes[static] = self.calltypes[stmt]
                stmt = static
            rewritten.append(stmt)
        return rewritten
148
+
149
+
150
@register_rewrite("before-inference")
class RewriteConstSetitems(Rewrite):
    """
    Replace ``setitem(target=arr, index=$constXX, ...)``, where ``$constXX``
    can be inferred as a constant, with
    ``static_setitem(target=arr, index=<constant value>, ...)``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        found = {}
        self.setitems = found
        self.block = block
        # Remember every SetItem whose index resolves to a constant.
        for setitem in block.find_insts(ir.SetItem):
            try:
                found[setitem] = func_ir.infer_constant(setitem.index)
            except errors.ConstantInferenceError:
                pass

        return bool(found)

    def apply(self):
        """
        Build a new block in which each matched SetItem is replaced by a
        StaticSetItem; all other statements are reused as-is.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            if stmt in self.setitems:
                stmt = ir.StaticSetItem(
                    stmt.target,
                    self.setitems[stmt],
                    stmt.index,
                    stmt.value,
                    stmt.loc,
                )
            rewritten.append(stmt)
        return rewritten
@@ -0,0 +1,98 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from numba.core import errors, consts, ir
5
+ from numba.cuda.core.rewrites import register_rewrite, Rewrite
6
+
7
+
8
@register_rewrite("before-inference")
class RewriteConstRaises(Rewrite):
    """
    Rewrite ``ir.Raise`` / ``ir.TryRaise`` statements of the kind
    `raise(value)`, where `value` is the result of instantiating an exception
    with constant arguments, into ``ir.StaticRaise`` / ``ir.StaticTryRaise``
    statements of the kind `static_raise(exception_type, constant args)`.

    This allows lowering in nopython mode, where one can't instantiate
    exception instances from runtime data.
    """

    def _is_exception_type(self, const):
        """Return True if *const* is a class deriving from Exception."""
        return isinstance(const, type) and issubclass(const, Exception)

    def _break_constant(self, const, loc):
        """
        Break down a constant exception into ``(exc_type, exc_args)``.

        A ``(exception class, args)`` tuple yields ``(cls, tuple(args))``;
        a bare exception class yields ``(cls, None)``.

        Raises UnsupportedError (reported at *loc*) for any other constant,
        with a dedicated message for string constants.
        """
        if isinstance(const, tuple):  # it's a tuple(exception class, args)
            if not self._is_exception_type(const[0]):
                msg = "Encountered unsupported exception constant %r"
                raise errors.UnsupportedError(msg % (const[0],), loc)
            return const[0], tuple(const[1])
        elif self._is_exception_type(const):
            return const, None
        else:
            if isinstance(const, str):
                msg = (
                    "Directly raising a string constant as an exception is "
                    "not supported."
                )
            else:
                msg = "Encountered unsupported constant type used for exception"
            raise errors.UnsupportedError(msg, loc)

    def _try_infer_constant(self, func_ir, inst):
        """
        Return the constant value of ``inst.exception``, or None when it
        cannot be inferred (i.e. the exception is not static).
        """
        try:
            return func_ir.infer_constant(inst.exception)
        except errors.ConstantInferenceError:
            # Same exception the sibling rewrites catch around
            # infer_constant(); not a static exception.
            return None

    def match(self, func_ir, block, typemap, calltypes):
        self.raises = raises = {}
        self.tryraises = tryraises = {}
        self.block = block
        # Detect all raise statements and find which ones can be
        # rewritten
        for inst in block.find_insts((ir.Raise, ir.TryRaise)):
            if inst.exception is None:
                # re-reraise
                exc_type, exc_args = None, None
            else:
                # raise <something> => find the definition site for <something>
                const = self._try_infer_constant(func_ir, inst)

                # failure to infer constant indicates this isn't a static
                # exception (NOTE(review): an inferred constant that is
                # itself None is treated the same way)
                if const is None:
                    continue

                loc = inst.exception.loc
                exc_type, exc_args = self._break_constant(const, loc)

            if isinstance(inst, ir.Raise):
                raises[inst] = exc_type, exc_args
            elif isinstance(inst, ir.TryRaise):
                tryraises[inst] = exc_type, exc_args
            else:
                raise ValueError("unexpected: {}".format(type(inst)))
        return (len(raises) + len(tryraises)) > 0

    def apply(self):
        """
        Rewrite all matching raises as static raises: each matched
        ``ir.Raise`` becomes an ``ir.StaticRaise`` and each matched
        ``ir.TryRaise`` becomes an ``ir.StaticTryRaise``. Other statements
        are carried over unchanged into the new block.
        """
        new_block = self.block.copy()
        new_block.clear()
        for inst in self.block.body:
            if inst in self.raises:
                exc_type, exc_args = self.raises[inst]
                new_inst = ir.StaticRaise(exc_type, exc_args, inst.loc)
                new_block.append(new_inst)
            elif inst in self.tryraises:
                exc_type, exc_args = self.tryraises[inst]
                new_inst = ir.StaticTryRaise(exc_type, exc_args, inst.loc)
                new_block.append(new_inst)
            else:
                new_block.append(inst)
        return new_block