numba-cuda 0.19.1__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic. Click here for more details.

Files changed (171) hide show
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +4 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  129. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  130. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  134. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  135. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  139. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  140. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  141. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  142. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  143. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  145. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  146. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  147. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  148. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  151. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  152. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  153. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  154. numba_cuda/numba/cuda/tests/support.py +55 -15
  155. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  156. numba_cuda/numba/cuda/types.py +56 -0
  157. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  158. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  159. numba_cuda/numba/cuda/typing/context.py +751 -0
  160. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  161. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  162. numba_cuda/numba/cuda/typing/templates.py +7 -6
  163. numba_cuda/numba/cuda/ufuncs.py +3 -3
  164. numba_cuda/numba/cuda/utils.py +6 -112
  165. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/METADATA +2 -1
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/RECORD +170 -115
  167. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  168. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/WHEEL +0 -0
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE.numba +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,952 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ """
5
+ Implement transformation on Numba IR
6
+ """
7
+
8
+ from collections import namedtuple, defaultdict
9
+ import logging
10
+ import operator
11
+
12
+ from numba.core.analysis import compute_cfg_from_blocks, find_top_level_loops
13
+ from numba.core import ir, errors
14
+ from numba.cuda.core import ir_utils
15
+ from numba.core.analysis import compute_use_defs
16
+
17
+
18
+ _logger = logging.getLogger(__name__)
19
+
20
+
21
def _extract_loop_lifting_candidates(cfg, blocks):
    """
    Return the top-level loops of *cfg* that are eligible for loop lifting.

    A loop qualifies when all of its exits lead to one single location, it
    has exactly one entry, none of its blocks contains a ``yield`` and its
    entry is not the CFG entry point.
    """

    def exits_converge(loop):
        "all exits must point to the same location"
        targets = set()
        for exit_label in loop.exits:
            successors = {dst for dst, _ in cfg.successors(exit_label)}
            if not successors:
                # An exit without any successor contains a return
                # statement, which the looplifting code does not handle,
                # so the loop cannot be a candidate.
                _logger.debug("return-statement in loop.")
                return False
            targets.update(successors)
        ok = len(targets) == 1
        _logger.debug("same_exit_point=%s (%s)", ok, targets)
        return ok

    def has_single_entry(loop):
        "there is one entry"
        ok = len(loop.entries) == 1
        _logger.debug("one_entry=%s", ok)
        return ok

    def is_yield_free(loop):
        "cannot have yield inside the loop"
        labels = set(loop.body) | set(loop.entries) | set(loop.exits)
        for label in labels:
            for inst in blocks[label].body:
                if isinstance(inst, ir.Assign) and isinstance(
                    inst.value, ir.Yield
                ):
                    _logger.debug("has yield")
                    return False
        _logger.debug("no yield")
        return True

    _logger.info("finding looplift candidates")
    # The cfg.entry_point() test prevents a bad rewrite where the prelude of
    # a lifted loop would be written into block -1 if a loop entry were in
    # block 0.
    candidates = []
    for loop in find_top_level_loops(cfg):
        _logger.debug("top-level loop: %s", loop)
        eligible = (
            exits_converge(loop)
            and has_single_entry(loop)
            and is_yield_free(loop)
            and cfg.entry_point() not in loop.entries
        )
        if not eligible:
            continue
        candidates.append(loop)
        _logger.debug("add candidate: %s", loop)
    return candidates
77
+
78
+
79
def find_region_inout_vars(blocks, livemap, callfrom, returnto, body_block_ids):
    """Find input and output variables to a block region.

    The inputs are the variables live at *callfrom* and the outputs are the
    variables live at *returnto*, both restricted to variables that the
    region's blocks actually use or define. Outputs must additionally be
    defined inside the region. Both are returned as sorted lists.
    """
    live_in = livemap[callfrom]
    live_out = livemap[returnto]

    # Only keep live variables that really occur in the region; this saves
    # having to build something valid to run through postproc to get a
    # similar result.
    region_blocks = {label: blocks[label] for label in body_block_ids}
    usedefs = compute_use_defs(region_blocks)
    used_vars = set().union(*usedefs.usemap.values())
    def_vars = set().union(*usedefs.defmap.values())
    touched = used_vars | def_vars

    # note: sorted for stable ordering
    inputs = sorted(set(live_in) & touched)
    outputs = sorted(set(live_out) & touched & def_vars)
    return inputs, outputs
104
+
105
+
106
+ _loop_lift_info = namedtuple(
107
+ "loop_lift_info", "loop,inputs,outputs,callfrom,returnto"
108
+ )
109
+
110
+
111
def _loop_lift_get_candidate_infos(cfg, blocks, livemap):
    """
    Build a ``_loop_lift_info`` record for every looplifting candidate.
    """
    infos = []
    for loop in _extract_loop_lifting_candidates(cfg, blocks):
        # Candidates have exactly one entry (checked during extraction).
        (callfrom,) = loop.entries
        exit_label = next(iter(loop.exits))  # any one of the exit blocks
        if len(loop.exits) <= 1:
            # single exit: control returns to the exit block itself
            returnto = exit_label
        else:
            # multiple exits share one successor (checked during extraction)
            ((returnto, _),) = cfg.successors(exit_label)

        region_labels = set(loop.body) | set(loop.entries) | set(loop.exits)
        inputs, outputs = find_region_inout_vars(
            blocks=blocks,
            livemap=livemap,
            callfrom=callfrom,
            returnto=returnto,
            body_block_ids=region_labels,
        )

        infos.append(
            _loop_lift_info(
                loop=loop,
                inputs=inputs,
                outputs=outputs,
                callfrom=callfrom,
                returnto=returnto,
            )
        )

    return infos
148
+
149
+
150
def _loop_lift_modify_call_block(liftedloop, block, inputs, outputs, returnto):
    """
    Create the block that replaces the loop in the top-level function.

    The new block takes its scope and location from *block* and is filled
    with a call to *liftedloop* via ``ir_utils.fill_block_with_call``, using
    *returnto* as the next label.
    """
    call_block = ir.Block(scope=block.scope, loc=block.loc)
    ir_utils.fill_block_with_call(
        newblock=call_block,
        callee=liftedloop,
        label_next=returnto,
        inputs=inputs,
        outputs=outputs,
    )
    return call_block
166
+
167
+
168
def _loop_lift_prepare_loop_func(loopinfo, blocks):
    """
    Transform *blocks* in place so that they can serve as a lifted loop.

    A prologue block is added under a new smallest label and the block at
    ``loopinfo.returnto`` is replaced by an epilogue block.
    """
    entry = blocks[loopinfo.callfrom]
    scope, loc = entry.scope, entry.loc

    # Lowering assumes the first block to be the one with the smallest offset
    prologue_label = min(blocks) - 1
    prologue = ir_utils.fill_callee_prologue(
        block=ir.Block(scope=scope, loc=loc),
        inputs=loopinfo.inputs,
        label_next=loopinfo.callfrom,
    )
    blocks[prologue_label] = prologue

    epilogue = ir_utils.fill_callee_epilogue(
        block=ir.Block(scope=scope, loc=loc),
        outputs=loopinfo.outputs,
    )
    blocks[loopinfo.returnto] = epilogue
187
+
188
+
189
def _loop_lift_modify_blocks(
    func_ir, loopinfo, blocks, typingctx, targetctx, flags, locals
):
    """
    Modify *blocks* in place so that the loop is replaced by a call to a
    lifted loop.

    The loop's blocks are copied into a new IR derived from *func_ir*,
    wrapped in a ``LiftedLoop`` and removed from *blocks*; the block at
    ``loopinfo.callfrom`` is replaced by one that calls the lifted loop.

    Returns the ``LiftedLoop`` object.
    """
    from numba.core.dispatcher import LiftedLoop

    # Copy loop blocks
    loop = loopinfo.loop

    loopblockkeys = set(loop.body) | set(loop.entries)
    if len(loop.exits) > 1:
        # has multiple exits
        loopblockkeys |= loop.exits
    loopblocks = {k: blocks[k].copy() for k in loopblockkeys}
    # Modify the loop blocks
    _loop_lift_prepare_loop_func(loopinfo, loopblocks)
    # Since Python 3.13, [END_FOR, POP_TOP] sequence becomes the start of the
    # block causing the block to have line number of the start of previous loop.
    # Fix this using the loc of the first getiter.
    getiter_exprs = [
        expr
        for blk in loopblocks.values()
        for expr in blk.find_exprs(op="getiter")
    ]
    first_getiter = min(
        getiter_exprs, key=lambda x: x.loc.line, default=None
    )
    if first_getiter is None:
        # No getiter in the loop blocks (presumably e.g. a `while` loop):
        # fall back to the location of the loop's entry block rather than
        # failing with "min() arg is an empty sequence".
        loop_loc = blocks[loopinfo.callfrom].loc
    else:
        loop_loc = first_getiter.loc
    # Create a new IR for the lifted loop
    lifted_ir = func_ir.derive(
        blocks=loopblocks,
        arg_names=tuple(loopinfo.inputs),
        arg_count=len(loopinfo.inputs),
        force_non_generator=True,
        loc=loop_loc,
    )
    liftedloop = LiftedLoop(lifted_ir, typingctx, targetctx, flags, locals)

    # modify for calling into liftedloop
    callblock = _loop_lift_modify_call_block(
        liftedloop,
        blocks[loopinfo.callfrom],
        loopinfo.inputs,
        loopinfo.outputs,
        loopinfo.returnto,
    )
    # remove blocks
    for k in loopblockkeys:
        del blocks[k]
    # update main interpreter callsite into the liftedloop
    blocks[loopinfo.callfrom] = callblock
    return liftedloop
240
+
241
+
242
+ def _has_multiple_loop_exits(cfg, lpinfo):
243
+ """Returns True if there is more than one exit in the loop.
244
+
245
+ NOTE: "common exits" refers to the situation where a loop exit has another
246
+ loop exit as its successor. In that case, we do not need to alter it.
247
+ """
248
+ if len(lpinfo.exits) <= 1:
249
+ return False
250
+ exits = set(lpinfo.exits)
251
+ pdom = cfg.post_dominators()
252
+
253
+ # Eliminate blocks that have other blocks as post-dominators.
254
+ processed = set()
255
+ remain = set(exits) # create a copy to work on
256
+ while remain:
257
+ node = remain.pop()
258
+ processed.add(node)
259
+ exits -= pdom[node] - {node}
260
+ remain = exits - processed
261
+
262
+ return len(exits) > 1
263
+
264
+
265
def _pre_looplift_transform(func_ir):
    """Canonicalize loops for looplifting.

    Every loop with multiple exits gets its exits combined into one, then
    the IR's analysis variables are reset and post-processing is re-run.
    Returns the (possibly new) function IR.
    """
    from numba.core.postproc import PostProcessor

    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for loop_info in cfg.loops().values():
        if not _has_multiple_loop_exits(cfg, loop_info):
            continue
        # combine the exits of this loop into one
        func_ir, _ = _fix_multi_exit_blocks(func_ir, loop_info.exits)

    # Reset and reprocess the func_ir
    func_ir._reset_analysis_variables()
    PostProcessor(func_ir).run()
    return func_ir
280
+
281
+
282
def loop_lifting(func_ir, typingctx, targetctx, flags, locals):
    """
    Loop lifting transformation.

    Given an interpreter `func_ir`, returns a 2-tuple of
    `(toplevel_interp, [loop0_interp, loop1_interp, ....])`
    """
    func_ir = _pre_looplift_transform(func_ir)
    blocks = func_ir.blocks.copy()
    candidates = _loop_lift_get_candidate_infos(
        compute_cfg_from_blocks(blocks),
        blocks,
        func_ir.variable_lifetime.livemap,
    )
    if candidates:
        _logger.debug(
            "loop lifting this IR with %d candidates:\n%s",
            len(candidates),
            func_ir.dump_to_string(),
        )

    # Each call rewrites `blocks` in place and yields the lifted loop.
    lifted_loops = [
        _loop_lift_modify_blocks(
            func_ir, loopinfo, blocks, typingctx, targetctx, flags, locals
        )
        for loopinfo in candidates
    ]

    # Make main IR
    return func_ir.derive(blocks=blocks), lifted_loops
312
+
313
+
314
def canonicalize_cfg_single_backedge(blocks):
    """
    Rewrite loops that have multiple backedges.

    Returns a copy of *blocks* in which every such loop has its backedges
    redirected to one new tail block that jumps back to the loop header.
    """
    cfg = compute_cfg_from_blocks(blocks)
    rewritten = blocks.copy()

    def counts_several_backedges(loop):
        # A backedge is a terminator of a body block that targets the header.
        seen_one = False
        for label in loop.body:
            if loop.header in blocks[label].terminator.get_targets():
                if seen_one:
                    # early exit
                    return True
                seen_one = True
        return False

    def retarget(term, src, dst):
        # Build a terminator equal to `term` but with `src` targets
        # replaced by `dst`.
        def swap(target):
            return dst if target == src else target

        if isinstance(term, ir.Branch):
            return ir.Branch(
                cond=term.cond,
                truebr=swap(term.truebr),
                falsebr=swap(term.falsebr),
                loc=term.loc,
            )
        if isinstance(term, ir.Jump):
            return ir.Jump(target=swap(term.target), loc=term.loc)
        assert not term.get_targets()
        return term

    def funnel_backedges(loop):
        """
        Add new tail block that gathers all the backedges
        """
        header = loop.header
        tail_label = max(rewritten.keys()) + 1
        for label in loop.body:
            current = rewritten[label]
            if header not in current.terminator.get_targets():
                continue
            # rewrite backedge into jumps to new tail block
            patched = current.copy()
            patched.body[-1] = retarget(current.terminator, header, tail_label)
            rewritten[label] = patched
        # create new tail block carrying the single backedge
        header_block = rewritten[header]
        tail = ir.Block(scope=header_block.scope, loc=header_block.loc)
        tail.append(ir.Jump(target=header, loc=tail.loc))
        rewritten[tail_label] = tail

    for loop in cfg.loops().values():
        if counts_several_backedges(loop):
            funnel_backedges(loop)

    return rewritten
385
+
386
+
387
def canonicalize_cfg(blocks):
    """
    Canonicalize the CFG of the given blocks.

    Currently this only rewrites loops with multiple backedges.
    Returns a new dictionary of blocks.
    """
    canonical_blocks = canonicalize_cfg_single_backedge(blocks)
    return canonical_blocks
393
+
394
+
395
def with_lifting(func_ir, typingctx, targetctx, flags, locals):
    """With-lifting transformation

    Rewrite the IR to extract all withs.
    Only the top-level withs are extracted.
    Returns the (the_new_ir, the_lifted_with_ir)
    """
    from numba.core import postproc

    def dispatcher_factory(func_ir, objectmode=False, **kwargs):
        from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith

        myflags = flags.copy()
        if not objectmode:
            dispatcher_cls = LiftedWith
        else:
            # Lifted with-block cannot looplift
            myflags.enable_looplift = False
            # Lifted with-block uses object mode
            myflags.enable_pyobject = True
            myflags.force_pyobject = True
            myflags.no_cpython_wrapper = False
            dispatcher_cls = ObjModeLiftedWith
        return dispatcher_cls(
            func_ir, typingctx, targetctx, myflags, locals, **kwargs
        )

    # find where with-contexts regions are
    withs, func_ir = find_setupwiths(func_ir)
    if not withs:
        return func_ir, []

    postproc.PostProcessor(func_ir).run()  # ensure we have variable lifetime
    assert func_ir.variable_lifetime
    cfg = func_ir.variable_lifetime.cfg
    blocks = func_ir.blocks.copy()

    # Mutate each with-region according to its kind of contextmanager
    sub_irs = []
    for blk_start, blk_end in withs:
        body_blocks = list(_cfg_nodes_in_region(cfg, blk_start, blk_end))
        _legalize_with_head(blocks[blk_start])
        # Find the contextmanager
        cmkind, extra = _get_with_contextmanager(func_ir, blocks, blk_start)
        # Mutate the body and get new IR
        sub_irs.append(
            cmkind.mutate_with_body(
                func_ir,
                blocks,
                blk_start,
                blk_end,
                body_blocks,
                dispatcher_factory,
                extra,
            )
        )

    # Without any lifted region the IR is returned unchanged
    new_ir = func_ir.derive(blocks) if sub_irs else func_ir
    return new_ir, sub_irs
458
+
459
+
460
def _get_with_contextmanager(func_ir, blocks, blk_start):
    """Get the global object used for the context manager

    Returns a ``(ctxobj, extra)`` pair for the first ``EnterWith`` statement
    in the block labelled *blk_start*. Raises ``errors.CompilerError`` when
    there is no such statement or the context manager is not usable.
    """
    _illegal_cm_msg = "Illegal use of context-manager."
    head_block = blocks[blk_start]

    def resolve_ctxmgr(var_ref):
        """Return the context-manager object and extra info.

        The extra contains the arguments if the context-manager is used
        as a call.
        """
        dfn = func_ir.get_definition(var_ref)
        extra = None
        # Is the contextmanager used as a call?
        if isinstance(dfn, ir.Expr) and dfn.op == "call":
            extra = {
                "args": [func_ir.get_definition(arg) for arg in dfn.args],
                "kwargs": {
                    key: func_ir.get_definition(val) for key, val in dfn.kws
                },
            }
            var_ref = dfn.func

        ctxobj = ir_utils.guard(ir_utils.find_outer_value, func_ir, var_ref)

        # check the contextmanager object
        if ctxobj is ir.UNDEFINED:
            raise errors.CompilerError(
                "Undefined variable used as context manager",
                loc=blocks[blk_start].loc,
            )
        if ctxobj is None:
            raise errors.CompilerError(_illegal_cm_msg, loc=dfn.loc)

        return ctxobj, extra

    # Scan the start of the with-region for the contextmanager
    for stmt in head_block.body:
        if not isinstance(stmt, ir.EnterWith):
            continue
        ctxobj, extra = resolve_ctxmgr(stmt.contextmanager)
        if hasattr(ctxobj, "mutate_with_body"):
            return ctxobj, extra
        raise errors.CompilerError(
            "Unsupported context manager in use",
            loc=blocks[blk_start].loc,
        )

    # No contextmanager found?
    raise errors.CompilerError(
        "malformed with-context usage",
        loc=blocks[blk_start].loc,
    )
514
+
515
+
516
def _legalize_with_head(blk):
    """Given *blk*, the head block of the with-context, check that it doesn't
    do anything else.

    The block must contain exactly one ``ir.EnterWith``, exactly one
    ``ir.Jump`` and, besides those, only ``ir.Del`` statements.

    Raises ``errors.CompilerError`` if any of these conditions is violated.
    """
    counters = defaultdict(int)
    for stmt in blk.body:
        counters[type(stmt)] += 1
    # NOTE: ``pop`` bypasses the defaultdict factory, so a default of 0 is
    # required; otherwise a block without any EnterWith would raise KeyError
    # instead of the intended CompilerError.
    if counters.pop(ir.EnterWith, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 ENTER_WITH",
            loc=blk.loc,
        )
    if counters.pop(ir.Jump, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 JUMP",
            loc=blk.loc,
        )
    # Can have any number of del
    counters.pop(ir.Del, None)
    # There MUST NOT be any other statements
    if counters:
        raise errors.CompilerError(
            "illegal statements in with's head-block",
            loc=blk.loc,
        )
541
+
542
+
543
+ def _cfg_nodes_in_region(cfg, region_begin, region_end):
544
+ """Find the set of CFG nodes that are in the given region"""
545
+ region_nodes = set()
546
+ stack = [region_begin]
547
+ while stack:
548
+ tos = stack.pop()
549
+ succlist = list(cfg.successors(tos))
550
+ # a single block function will have a empty successor list
551
+ if succlist:
552
+ succs, _ = zip(*succlist)
553
+ nodes = set(
554
+ [
555
+ node
556
+ for node in succs
557
+ if node not in region_nodes and node != region_end
558
+ ]
559
+ )
560
+ stack.extend(nodes)
561
+ region_nodes |= nodes
562
+
563
+ return region_nodes
564
+
565
+
566
def find_setupwiths(func_ir):
    """Find all top-level with.

    Returns a list of ranges for the with-regions, together with the
    (possibly rewritten) function IR.
    """

    def find_ranges(blocks):
        cfg = compute_cfg_from_blocks(blocks)
        # Collect the labels of all blocks suspected to hold SETUP_WITH and
        # POP_BLOCK statements so that we can iterate over them.
        sus_setups, sus_pops = set(), set()
        for label, block in blocks.items():
            for stmt in block.body:
                if ir_utils.is_setup_with(stmt):
                    sus_setups.add(label)
                if ir_utils.is_pop_block(stmt):
                    sus_pops.add(label)

        # Walk the setups in reverse topo order and, starting from each one,
        # search for its pop_blocks.
        setup_with_to_pop_blocks_map = defaultdict(set)
        for setup_block in cfg.topo_sort(sus_setups, reverse=True):
            pending = [setup_block]
            seen = set()
            while pending:
                # take whatever is next and record that we have seen it
                label = pending.pop()
                seen.add(label)
                for stmt in blocks[label].body:
                    # raise detected before pop_block
                    if ir_utils.is_raise(stmt):
                        raise errors.CompilerError(
                            "unsupported control flow due to raise "
                            "statements inside with block"
                        )
                    if ir_utils.is_pop_block(stmt) and label in sus_pops:
                        # this pop block belongs to the current setup and
                        # must not be matched again
                        setup_with_to_pop_blocks_map[setup_block].add(label)
                        sus_pops.remove(label)
                        # stop looking, we have reached the frontier
                        break
                    # Still here at the block terminator: queue all of its
                    # targets that have not been seen yet.
                    if ir_utils.is_terminator(stmt):
                        pending.extend(
                            t for t in stmt.get_targets() if t not in seen
                        )

        return setup_with_to_pop_blocks_map

    blocks = func_ir.blocks
    # Maps labels of blocks containing SETUP_WITH statements to the set of
    # labels of blocks containing the matching POP_BLOCK statements.
    with_ranges_dict = find_ranges(blocks)
    # rewrite the CFG in case there are multiple POP_BLOCK statements for one
    # with
    func_ir = consolidate_multi_exit_withs(with_ranges_dict, blocks, func_ir)
    # Turn the withs back into a list of tuples so that the rest of the code
    # can cope.
    with_ranges_tuple = [
        (setup, next(iter(pops))) for setup, pops in with_ranges_dict.items()
    ]

    # check for POP_BLOCKS with multiple outgoing edges and reject
    for _, pop_label in with_ranges_tuple:
        if len(blocks[pop_label].terminator.get_targets()) != 1:
            raise errors.CompilerError(
                "unsupported control flow: with-context contains branches "
                "(i.e. break/return/raise) that can leave the block "
            )
    # Rewrite with-regions whose POP_BLOCK block jumps straight to a return
    for _, pop_label in with_ranges_tuple:
        successor_label = blocks[pop_label].terminator.get_targets()[0]
        if ir_utils.is_return(func_ir.blocks[successor_label].terminator):
            _rewrite_return(func_ir, pop_label)

    # Rewrite the tuples so that each SETUP_WITH is matched with the
    # successor of the block that contains the POP_BLOCK.
    with_ranges_tuple = [
        (setup, func_ir.blocks[pop_label].terminator.get_targets()[0])
        for setup, pop_label in with_ranges_tuple
    ]

    # finally handle nested with statements
    with_ranges_tuple = _eliminate_nested_withs(with_ranges_tuple)

    return with_ranges_tuple, func_ir
660
+
661
+
662
def _rewrite_return(func_ir, target_block_label):
    """Insert an extra block between a POP_BLOCK block and its return block.

    Arguments
    ---------

    func_ir: Function IR
        the CFG to transform (mutated in place and also returned)
    target_block_label: int
        the block index/label of the block containing the POP_BLOCK statement

    The pattern handled here means there is a `return` statement within a
    `with` context. On entry the CFG looks like this:

    ┌───────────────┐
    │   head        │   label: target_block_label
    │   POP_BLOCK   │
    │   tail        │
    └───────┬───────┘
            │
    ┌───────▼───────┐
    │    RETURN     │   label: successor of the block above
    └───────────────┘

    and on exit it looks like this:

    ┌───────────────┐
    │   head        │   label: target_block_label
    └───────┬───────┘
            │
    ┌───────▼───────┐
    │   POP_BLOCK   │   label: the old successor label
    │   tail        │
    └───────┬───────┘
            │
    ┌───────▼───────┐
    │    RETURN     │   label: max existing label + 1
    └───────────────┘

    'head' is every statement before the `POP_BLOCK`; the `POP_BLOCK` itself
    and everything after it (except the trailing jump) form the middle block.
    The old successor block object is reused for the middle block, and its
    previous contents are copied into a freshly created block. Jumps are
    re-wired accordingly and the IR definitions are rebuilt.

    """
    blocks = func_ir.blocks
    pop_holder = blocks[target_block_label]
    # label and block object of the block that currently holds the return
    successor_label = pop_holder.terminator.get_targets()[0]
    successor = blocks[successor_label]

    # the block that will receive the return statements gets a fresh label
    fresh_label = ir_utils.find_max_label(blocks) + 1
    fresh_loc = successor.loc
    fresh_block = ir.Block(ir.Scope(None, loc=fresh_loc), loc=fresh_loc)

    # The POP_BLOCK block must have the shape
    #   <some stmts> / POP_BLOCK / <some more stmts> / JUMP
    pop_stmts = list(pop_holder.find_insts(ir.PopBlock))
    assert len(pop_stmts) == 1
    assert len(list(pop_holder.find_insts(ir.Jump))) == 1
    assert isinstance(pop_holder.body[-1], ir.Jump)

    old_body = pop_holder.body
    split_at = old_body.index(pop_stmts[0])
    head = [
        *old_body[:split_at],
        ir.Jump(successor_label, pop_holder.loc),
    ]
    tail = [
        *old_body[split_at:-1],
        ir.Jump(fresh_label, pop_holder.loc),
    ]

    # Copy the return statements out *before* overwriting the successor's
    # body; all body lists are updated in place so that existing references
    # to them stay valid.
    fresh_block.body.extend(successor.body)
    successor.body[:] = tail
    pop_holder.body[:] = head

    # register the new return block and rebuild the IR properties
    blocks[fresh_label] = fresh_block
    func_ir._definitions = ir_utils.build_definitions(blocks)
    return func_ir
767
+
768
+
769
+ def _eliminate_nested_withs(with_ranges):
770
+ known_ranges = []
771
+
772
+ def within_known_range(start, end, known_ranges):
773
+ for a, b in known_ranges:
774
+ # FIXME: this should be a comparison in topological order, right
775
+ # now we are comparing the integers of the blocks, stuff probably
776
+ # works by accident.
777
+ if start > a and end < b:
778
+ return True
779
+ return False
780
+
781
+ for s, e in sorted(with_ranges):
782
+ if not within_known_range(s, e, known_ranges):
783
+ known_ranges.append((s, e))
784
+
785
+ return known_ranges
786
+
787
+
788
def consolidate_multi_exit_withs(withs: dict, blocks, func_ir):
    """Merge the exit blocks of every with construct that has several.

    `withs` maps a key to a set of exit block labels. Each entry holding more
    than one label is handed to `_fix_multi_exit_blocks` (splitting at
    POP_BLOCK statements) and its value is replaced, in place, by a
    one-element set containing the common exit label that call returns.
    `blocks` is accepted but not used here. Returns the resulting function IR.
    """
    for setup_label in withs:
        exit_labels: set = withs[setup_label]
        if len(exit_labels) <= 1:
            # a single exit needs no consolidation
            continue
        func_ir, common = _fix_multi_exit_blocks(
            func_ir,
            exit_labels,
            split_condition=ir_utils.is_pop_block,
        )
        withs[setup_label] = {common}
    return func_ir
800
+
801
+
802
def _fix_multi_exit_blocks(func_ir, exit_nodes, *, split_condition=None):
    """Modify the FunctionIR to create a single common exit node given the
    original exit nodes.

    Parameters
    ----------
    func_ir :
        The FunctionIR. Mutated inplace.
    exit_nodes :
        The original exit nodes. A sequence of block keys.
        NOTE(review): the code below indexes ``remainings[0]`` and reuses
        loop variables after the loops, so it assumes at least one exit
        node -- confirm against callers.
    split_condition : callable or None
        If not None, it is a callable with the signature
        `split_condition(statement)` that determines if the `statement` is the
        splitting point (e.g. `POP_BLOCK`) in an exit node.
        If it's None, the exit node is not split.

    Returns
    -------
    A 2-tuple ``(func_ir, common_label)`` where ``common_label`` is the key
    of the newly created common exit block in ``func_ir.blocks``.
    """

    # Convert the following:
    #
    #          |           |
    #      +-------+   +-------+
    #      | exit0 |   | exit1 |
    #      +-------+   +-------+
    #          |           |
    #      +-------+   +-------+
    #      | after0|   | after1|
    #      +-------+   +-------+
    #          |           |
    #
    # To roughly:
    #
    #          |           |
    #      +-------+   +-------+
    #      | exit0 |   | exit1 |
    #      +-------+   +-------+
    #          |           |
    #          +-----+-----+
    #                |
    #           +---------+
    #           | common  |
    #           +---------+
    #                |
    #            +-------+
    #            | post  |
    #            +-------+
    #                |
    #          +-----+-----+
    #          |           |
    #      +-------+   +-------+
    #      | after0|   | after1|
    #      +-------+   +-------+

    blocks = func_ir.blocks
    # Getting the scope; the chosen block is only used for its `.scope`.
    any_blk = min(func_ir.blocks.values())
    scope = any_blk.scope
    # First unused block label: one past the current maximum. It is bumped
    # every time a new block label is handed out below.
    max_label = max(func_ir.blocks) + 1
    # Define the new common block for the new exit.
    common_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    common_label = max_label
    max_label += 1
    blocks[common_label] = common_block
    # Define the new block after the exit.
    post_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    post_label = max_label
    max_label += 1
    blocks[post_label] = post_block

    # Adjust each exit node. `remainings[i]` collects the statements cut off
    # the i-th exit node (from the split point to the end of its body).
    remainings = []
    for i, k in enumerate(exit_nodes):
        blk = blocks[k]

        # split the block if needed
        if split_condition is not None:
            # `pt` ends up as the index of the first statement satisfying
            # the condition; if none does, it is left at the last index.
            for pt, stmt in enumerate(blk.body):
                if split_condition(stmt):
                    break
        else:
            # no splitting: cut off only the last statement
            pt = -1

        before = blk.body[:pt]
        after = blk.body[pt:]
        remainings.append(after)

        # Add control-point variable to mark which exit block this is.
        blk.body = before
        loc = blk.loc
        blk.body.append(
            ir.Assign(
                value=ir.Const(i, loc=loc),
                target=scope.get_or_define("$cp", loc=loc),
                loc=loc,
            )
        )
        # Replace terminator with a jump to the common block
        assert not blk.is_terminated
        blk.body.append(ir.Jump(common_label, loc=ir.unknown_loc))

    if split_condition is not None:
        # Move the splitting statement to the common block. Only the first
        # statement of the first exit node's remainder is used.
        common_block.body.append(remainings[0][0])
    assert not common_block.is_terminated
    # Append jump from common block to post block.
    # NOTE: `loc` here is whatever the loop above left behind, i.e. the
    # location of the last exit node processed.
    common_block.body.append(ir.Jump(post_label, loc=loc))

    # Make if-else tree to jump to target.
    # Reserve one fresh label per exit node for the chain of switch blocks.
    remain_blocks = []
    for remain in remainings:
        remain_blocks.append(max_label)
        max_label += 1

    # The first comparison lives in the post block; every later one lives in
    # the block created at the end of the previous iteration.
    switch_block = post_block
    loc = ir.unknown_loc
    for i, remain in enumerate(remainings):
        match_expr = scope.redefine("$cp_check", loc=loc)
        match_rhs = scope.redefine("$cp_rhs", loc=loc)

        # Do comparison to match control-point variable to the exit block
        switch_block.body.append(
            ir.Assign(value=ir.Const(i, loc=loc), target=match_rhs, loc=loc),
        )

        # Add assignment for the comparison
        switch_block.body.append(
            ir.Assign(
                value=ir.Expr.binop(
                    fn=operator.eq,
                    lhs=scope.get("$cp"),
                    rhs=match_rhs,
                    loc=loc,
                ),
                target=match_expr,
                loc=loc,
            ),
        )

        # Insert jump to the next case: on a match go to the single target of
        # this exit node's original last statement, otherwise fall through to
        # the next switch block. The unpacking requires exactly one target.
        [jump_target] = remain[-1].get_targets()
        switch_block.body.append(
            ir.Branch(match_expr, jump_target, remain_blocks[i], loc=loc),
        )
        switch_block = ir.Block(scope=scope, loc=loc)
        blocks[remain_blocks[i]] = switch_block

    # Add the final jump: the last (still empty) switch block jumps to the
    # target of the last exit node, so that it is properly terminated.
    switch_block.body.append(ir.Jump(jump_target, loc=loc))

    return func_ir, common_label