numba-cuda 0.22.0__cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (487) hide show
  1. _numba_cuda_redirector.pth +4 -0
  2. _numba_cuda_redirector.py +89 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +6 -0
  5. numba_cuda/_version.py +11 -0
  6. numba_cuda/numba/cuda/__init__.py +70 -0
  7. numba_cuda/numba/cuda/_internal/cuda_bf16.py +16394 -0
  8. numba_cuda/numba/cuda/_internal/cuda_fp16.py +8112 -0
  9. numba_cuda/numba/cuda/api.py +580 -0
  10. numba_cuda/numba/cuda/api_util.py +76 -0
  11. numba_cuda/numba/cuda/args.py +72 -0
  12. numba_cuda/numba/cuda/bf16.py +397 -0
  13. numba_cuda/numba/cuda/cache_hints.py +287 -0
  14. numba_cuda/numba/cuda/cext/__init__.py +2 -0
  15. numba_cuda/numba/cuda/cext/_devicearray.cpp +159 -0
  16. numba_cuda/numba/cuda/cext/_devicearray.cpython-313-aarch64-linux-gnu.so +0 -0
  17. numba_cuda/numba/cuda/cext/_devicearray.h +29 -0
  18. numba_cuda/numba/cuda/cext/_dispatcher.cpp +1098 -0
  19. numba_cuda/numba/cuda/cext/_dispatcher.cpython-313-aarch64-linux-gnu.so +0 -0
  20. numba_cuda/numba/cuda/cext/_hashtable.cpp +532 -0
  21. numba_cuda/numba/cuda/cext/_hashtable.h +135 -0
  22. numba_cuda/numba/cuda/cext/_helperlib.c +71 -0
  23. numba_cuda/numba/cuda/cext/_helperlib.cpython-313-aarch64-linux-gnu.so +0 -0
  24. numba_cuda/numba/cuda/cext/_helpermod.c +82 -0
  25. numba_cuda/numba/cuda/cext/_pymodule.h +38 -0
  26. numba_cuda/numba/cuda/cext/_typeconv.cpp +206 -0
  27. numba_cuda/numba/cuda/cext/_typeconv.cpython-313-aarch64-linux-gnu.so +0 -0
  28. numba_cuda/numba/cuda/cext/_typeof.cpp +1159 -0
  29. numba_cuda/numba/cuda/cext/_typeof.h +19 -0
  30. numba_cuda/numba/cuda/cext/capsulethunk.h +111 -0
  31. numba_cuda/numba/cuda/cext/mviewbuf.c +385 -0
  32. numba_cuda/numba/cuda/cext/mviewbuf.cpython-313-aarch64-linux-gnu.so +0 -0
  33. numba_cuda/numba/cuda/cext/typeconv.cpp +212 -0
  34. numba_cuda/numba/cuda/cext/typeconv.hpp +101 -0
  35. numba_cuda/numba/cuda/cg.py +67 -0
  36. numba_cuda/numba/cuda/cgutils.py +1294 -0
  37. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  38. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  39. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  40. numba_cuda/numba/cuda/codegen.py +541 -0
  41. numba_cuda/numba/cuda/compiler.py +1396 -0
  42. numba_cuda/numba/cuda/core/analysis.py +758 -0
  43. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  44. numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +288 -0
  45. numba_cuda/numba/cuda/core/annotations/type_annotations.py +305 -0
  46. numba_cuda/numba/cuda/core/base.py +1332 -0
  47. numba_cuda/numba/cuda/core/boxing.py +1411 -0
  48. numba_cuda/numba/cuda/core/bytecode.py +728 -0
  49. numba_cuda/numba/cuda/core/byteflow.py +2346 -0
  50. numba_cuda/numba/cuda/core/caching.py +744 -0
  51. numba_cuda/numba/cuda/core/callconv.py +392 -0
  52. numba_cuda/numba/cuda/core/codegen.py +171 -0
  53. numba_cuda/numba/cuda/core/compiler.py +199 -0
  54. numba_cuda/numba/cuda/core/compiler_lock.py +85 -0
  55. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  56. numba_cuda/numba/cuda/core/config.py +650 -0
  57. numba_cuda/numba/cuda/core/consts.py +124 -0
  58. numba_cuda/numba/cuda/core/controlflow.py +989 -0
  59. numba_cuda/numba/cuda/core/entrypoints.py +57 -0
  60. numba_cuda/numba/cuda/core/environment.py +66 -0
  61. numba_cuda/numba/cuda/core/errors.py +917 -0
  62. numba_cuda/numba/cuda/core/event.py +511 -0
  63. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  64. numba_cuda/numba/cuda/core/generators.py +387 -0
  65. numba_cuda/numba/cuda/core/imputils.py +509 -0
  66. numba_cuda/numba/cuda/core/inline_closurecall.py +1787 -0
  67. numba_cuda/numba/cuda/core/interpreter.py +3617 -0
  68. numba_cuda/numba/cuda/core/ir.py +1812 -0
  69. numba_cuda/numba/cuda/core/ir_utils.py +2638 -0
  70. numba_cuda/numba/cuda/core/optional.py +129 -0
  71. numba_cuda/numba/cuda/core/options.py +262 -0
  72. numba_cuda/numba/cuda/core/postproc.py +249 -0
  73. numba_cuda/numba/cuda/core/pythonapi.py +1859 -0
  74. numba_cuda/numba/cuda/core/registry.py +46 -0
  75. numba_cuda/numba/cuda/core/removerefctpass.py +123 -0
  76. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  77. numba_cuda/numba/cuda/core/rewrites/ir_print.py +91 -0
  78. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  79. numba_cuda/numba/cuda/core/rewrites/static_binop.py +41 -0
  80. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +189 -0
  81. numba_cuda/numba/cuda/core/rewrites/static_raise.py +100 -0
  82. numba_cuda/numba/cuda/core/sigutils.py +68 -0
  83. numba_cuda/numba/cuda/core/ssa.py +498 -0
  84. numba_cuda/numba/cuda/core/targetconfig.py +330 -0
  85. numba_cuda/numba/cuda/core/tracing.py +231 -0
  86. numba_cuda/numba/cuda/core/transforms.py +956 -0
  87. numba_cuda/numba/cuda/core/typed_passes.py +867 -0
  88. numba_cuda/numba/cuda/core/typeinfer.py +1950 -0
  89. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  90. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  91. numba_cuda/numba/cuda/core/unsafe/eh.py +67 -0
  92. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  93. numba_cuda/numba/cuda/core/untyped_passes.py +1979 -0
  94. numba_cuda/numba/cuda/cpython/builtins.py +1153 -0
  95. numba_cuda/numba/cuda/cpython/charseq.py +1218 -0
  96. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  97. numba_cuda/numba/cuda/cpython/enumimpl.py +103 -0
  98. numba_cuda/numba/cuda/cpython/iterators.py +167 -0
  99. numba_cuda/numba/cuda/cpython/listobj.py +1326 -0
  100. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  101. numba_cuda/numba/cuda/cpython/numbers.py +1475 -0
  102. numba_cuda/numba/cuda/cpython/rangeobj.py +289 -0
  103. numba_cuda/numba/cuda/cpython/slicing.py +322 -0
  104. numba_cuda/numba/cuda/cpython/tupleobj.py +456 -0
  105. numba_cuda/numba/cuda/cpython/unicode.py +2865 -0
  106. numba_cuda/numba/cuda/cpython/unicode_support.py +1597 -0
  107. numba_cuda/numba/cuda/cpython/unsafe/__init__.py +0 -0
  108. numba_cuda/numba/cuda/cpython/unsafe/numbers.py +64 -0
  109. numba_cuda/numba/cuda/cpython/unsafe/tuple.py +92 -0
  110. numba_cuda/numba/cuda/cuda_paths.py +691 -0
  111. numba_cuda/numba/cuda/cudadecl.py +543 -0
  112. numba_cuda/numba/cuda/cudadrv/__init__.py +14 -0
  113. numba_cuda/numba/cuda/cudadrv/devicearray.py +954 -0
  114. numba_cuda/numba/cuda/cudadrv/devices.py +249 -0
  115. numba_cuda/numba/cuda/cudadrv/driver.py +3238 -0
  116. numba_cuda/numba/cuda/cudadrv/drvapi.py +435 -0
  117. numba_cuda/numba/cuda/cudadrv/dummyarray.py +562 -0
  118. numba_cuda/numba/cuda/cudadrv/enums.py +613 -0
  119. numba_cuda/numba/cuda/cudadrv/error.py +48 -0
  120. numba_cuda/numba/cuda/cudadrv/libs.py +220 -0
  121. numba_cuda/numba/cuda/cudadrv/linkable_code.py +184 -0
  122. numba_cuda/numba/cuda/cudadrv/mappings.py +14 -0
  123. numba_cuda/numba/cuda/cudadrv/ndarray.py +26 -0
  124. numba_cuda/numba/cuda/cudadrv/nvrtc.py +193 -0
  125. numba_cuda/numba/cuda/cudadrv/nvvm.py +756 -0
  126. numba_cuda/numba/cuda/cudadrv/rtapi.py +13 -0
  127. numba_cuda/numba/cuda/cudadrv/runtime.py +34 -0
  128. numba_cuda/numba/cuda/cudaimpl.py +983 -0
  129. numba_cuda/numba/cuda/cudamath.py +149 -0
  130. numba_cuda/numba/cuda/datamodel/__init__.py +7 -0
  131. numba_cuda/numba/cuda/datamodel/cuda_manager.py +66 -0
  132. numba_cuda/numba/cuda/datamodel/cuda_models.py +1446 -0
  133. numba_cuda/numba/cuda/datamodel/cuda_packer.py +224 -0
  134. numba_cuda/numba/cuda/datamodel/cuda_registry.py +22 -0
  135. numba_cuda/numba/cuda/datamodel/cuda_testing.py +153 -0
  136. numba_cuda/numba/cuda/datamodel/manager.py +11 -0
  137. numba_cuda/numba/cuda/datamodel/models.py +9 -0
  138. numba_cuda/numba/cuda/datamodel/packer.py +9 -0
  139. numba_cuda/numba/cuda/datamodel/registry.py +11 -0
  140. numba_cuda/numba/cuda/datamodel/testing.py +11 -0
  141. numba_cuda/numba/cuda/debuginfo.py +997 -0
  142. numba_cuda/numba/cuda/decorators.py +294 -0
  143. numba_cuda/numba/cuda/descriptor.py +35 -0
  144. numba_cuda/numba/cuda/device_init.py +155 -0
  145. numba_cuda/numba/cuda/deviceufunc.py +1021 -0
  146. numba_cuda/numba/cuda/dispatcher.py +2463 -0
  147. numba_cuda/numba/cuda/errors.py +72 -0
  148. numba_cuda/numba/cuda/extending.py +697 -0
  149. numba_cuda/numba/cuda/flags.py +178 -0
  150. numba_cuda/numba/cuda/fp16.py +357 -0
  151. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  152. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  153. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  154. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  155. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  156. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  157. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  158. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  159. numba_cuda/numba/cuda/initialize.py +24 -0
  160. numba_cuda/numba/cuda/intrinsics.py +531 -0
  161. numba_cuda/numba/cuda/itanium_mangler.py +214 -0
  162. numba_cuda/numba/cuda/kernels/__init__.py +2 -0
  163. numba_cuda/numba/cuda/kernels/reduction.py +265 -0
  164. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  165. numba_cuda/numba/cuda/libdevice.py +3386 -0
  166. numba_cuda/numba/cuda/libdevicedecl.py +20 -0
  167. numba_cuda/numba/cuda/libdevicefuncs.py +1060 -0
  168. numba_cuda/numba/cuda/libdeviceimpl.py +88 -0
  169. numba_cuda/numba/cuda/locks.py +19 -0
  170. numba_cuda/numba/cuda/lowering.py +1980 -0
  171. numba_cuda/numba/cuda/mathimpl.py +374 -0
  172. numba_cuda/numba/cuda/memory_management/__init__.py +4 -0
  173. numba_cuda/numba/cuda/memory_management/memsys.cu +99 -0
  174. numba_cuda/numba/cuda/memory_management/memsys.cuh +22 -0
  175. numba_cuda/numba/cuda/memory_management/nrt.cu +212 -0
  176. numba_cuda/numba/cuda/memory_management/nrt.cuh +48 -0
  177. numba_cuda/numba/cuda/memory_management/nrt.py +390 -0
  178. numba_cuda/numba/cuda/memory_management/nrt_context.py +438 -0
  179. numba_cuda/numba/cuda/misc/appdirs.py +594 -0
  180. numba_cuda/numba/cuda/misc/cffiimpl.py +24 -0
  181. numba_cuda/numba/cuda/misc/coverage_support.py +43 -0
  182. numba_cuda/numba/cuda/misc/dump_style.py +41 -0
  183. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  184. numba_cuda/numba/cuda/misc/firstlinefinder.py +96 -0
  185. numba_cuda/numba/cuda/misc/gdb_hook.py +240 -0
  186. numba_cuda/numba/cuda/misc/literal.py +28 -0
  187. numba_cuda/numba/cuda/misc/llvm_pass_timings.py +412 -0
  188. numba_cuda/numba/cuda/misc/special.py +94 -0
  189. numba_cuda/numba/cuda/models.py +56 -0
  190. numba_cuda/numba/cuda/np/arraymath.py +5130 -0
  191. numba_cuda/numba/cuda/np/arrayobj.py +7635 -0
  192. numba_cuda/numba/cuda/np/extensions.py +11 -0
  193. numba_cuda/numba/cuda/np/linalg.py +3087 -0
  194. numba_cuda/numba/cuda/np/math/__init__.py +0 -0
  195. numba_cuda/numba/cuda/np/math/cmathimpl.py +558 -0
  196. numba_cuda/numba/cuda/np/math/mathimpl.py +487 -0
  197. numba_cuda/numba/cuda/np/math/numbers.py +1461 -0
  198. numba_cuda/numba/cuda/np/npdatetime.py +969 -0
  199. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  200. numba_cuda/numba/cuda/np/npyfuncs.py +1808 -0
  201. numba_cuda/numba/cuda/np/npyimpl.py +1027 -0
  202. numba_cuda/numba/cuda/np/numpy_support.py +798 -0
  203. numba_cuda/numba/cuda/np/polynomial/__init__.py +4 -0
  204. numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +242 -0
  205. numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +380 -0
  206. numba_cuda/numba/cuda/np/ufunc/__init__.py +4 -0
  207. numba_cuda/numba/cuda/np/ufunc/decorators.py +203 -0
  208. numba_cuda/numba/cuda/np/ufunc/sigparse.py +68 -0
  209. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +65 -0
  210. numba_cuda/numba/cuda/np/ufunc_db.py +1282 -0
  211. numba_cuda/numba/cuda/np/unsafe/__init__.py +0 -0
  212. numba_cuda/numba/cuda/np/unsafe/ndarray.py +84 -0
  213. numba_cuda/numba/cuda/nvvmutils.py +254 -0
  214. numba_cuda/numba/cuda/printimpl.py +126 -0
  215. numba_cuda/numba/cuda/random.py +308 -0
  216. numba_cuda/numba/cuda/reshape_funcs.cu +156 -0
  217. numba_cuda/numba/cuda/serialize.py +267 -0
  218. numba_cuda/numba/cuda/simulator/__init__.py +63 -0
  219. numba_cuda/numba/cuda/simulator/_internal/__init__.py +4 -0
  220. numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
  221. numba_cuda/numba/cuda/simulator/api.py +179 -0
  222. numba_cuda/numba/cuda/simulator/bf16.py +4 -0
  223. numba_cuda/numba/cuda/simulator/compiler.py +38 -0
  224. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +11 -0
  225. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +462 -0
  226. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +122 -0
  227. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +66 -0
  228. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +7 -0
  229. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +7 -0
  230. numba_cuda/numba/cuda/simulator/cudadrv/error.py +10 -0
  231. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +10 -0
  232. numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +61 -0
  233. numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +11 -0
  234. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +32 -0
  235. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +22 -0
  236. numba_cuda/numba/cuda/simulator/dispatcher.py +11 -0
  237. numba_cuda/numba/cuda/simulator/kernel.py +320 -0
  238. numba_cuda/numba/cuda/simulator/kernelapi.py +509 -0
  239. numba_cuda/numba/cuda/simulator/memory_management/__init__.py +4 -0
  240. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +21 -0
  241. numba_cuda/numba/cuda/simulator/reduction.py +19 -0
  242. numba_cuda/numba/cuda/simulator/tests/support.py +4 -0
  243. numba_cuda/numba/cuda/simulator/vector_types.py +65 -0
  244. numba_cuda/numba/cuda/simulator_init.py +18 -0
  245. numba_cuda/numba/cuda/stubs.py +624 -0
  246. numba_cuda/numba/cuda/target.py +505 -0
  247. numba_cuda/numba/cuda/testing.py +347 -0
  248. numba_cuda/numba/cuda/tests/__init__.py +62 -0
  249. numba_cuda/numba/cuda/tests/benchmarks/__init__.py +0 -0
  250. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +119 -0
  251. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  252. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +113 -0
  253. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
  254. numba_cuda/numba/cuda/tests/core/test_serialize.py +371 -0
  255. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +9 -0
  256. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +147 -0
  257. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +161 -0
  258. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +397 -0
  259. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +24 -0
  260. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +180 -0
  261. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +313 -0
  262. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +191 -0
  263. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +621 -0
  264. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +247 -0
  265. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +100 -0
  266. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +200 -0
  267. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +53 -0
  268. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +72 -0
  269. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +138 -0
  270. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +43 -0
  271. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +15 -0
  272. numba_cuda/numba/cuda/tests/cudadrv/test_linkable_code.py +58 -0
  273. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +348 -0
  274. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +128 -0
  275. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +301 -0
  276. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +174 -0
  277. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +28 -0
  278. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +185 -0
  279. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +39 -0
  280. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +23 -0
  281. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +38 -0
  282. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +48 -0
  283. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +44 -0
  284. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +127 -0
  285. numba_cuda/numba/cuda/tests/cudapy/__init__.py +9 -0
  286. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +231 -0
  287. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +50 -0
  288. numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +36 -0
  289. numba_cuda/numba/cuda/tests/cudapy/complex_usecases.py +116 -0
  290. numba_cuda/numba/cuda/tests/cudapy/enum_usecases.py +59 -0
  291. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +62 -0
  292. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +28 -0
  293. numba_cuda/numba/cuda/tests/cudapy/overload_usecases.py +33 -0
  294. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +104 -0
  295. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +47 -0
  296. numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +1122 -0
  297. numba_cuda/numba/cuda/tests/cudapy/test_array.py +344 -0
  298. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +268 -0
  299. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +203 -0
  300. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +63 -0
  301. numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +360 -0
  302. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1815 -0
  303. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +599 -0
  304. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +377 -0
  305. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +160 -0
  306. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +27 -0
  307. numba_cuda/numba/cuda/tests/cudapy/test_byteflow.py +98 -0
  308. numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py +210 -0
  309. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +683 -0
  310. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +265 -0
  311. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +42 -0
  312. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +718 -0
  313. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +370 -0
  314. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +23 -0
  315. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +142 -0
  316. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +178 -0
  317. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +193 -0
  318. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +131 -0
  319. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +438 -0
  320. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +94 -0
  321. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +101 -0
  322. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +105 -0
  323. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +978 -0
  324. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
  325. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +500 -0
  326. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +820 -0
  327. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +152 -0
  328. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +111 -0
  329. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +170 -0
  330. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1088 -0
  331. numba_cuda/numba/cuda/tests/cudapy/test_extending_types.py +71 -0
  332. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +265 -0
  333. numba_cuda/numba/cuda/tests/cudapy/test_flow_control.py +1433 -0
  334. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +57 -0
  335. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +34 -0
  336. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +69 -0
  337. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +62 -0
  338. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +474 -0
  339. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +167 -0
  340. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +92 -0
  341. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +39 -0
  342. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +170 -0
  343. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +255 -0
  344. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1219 -0
  345. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +263 -0
  346. numba_cuda/numba/cuda/tests/cudapy/test_ir.py +598 -0
  347. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +276 -0
  348. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +101 -0
  349. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +68 -0
  350. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +123 -0
  351. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +194 -0
  352. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +220 -0
  353. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +173 -0
  354. numba_cuda/numba/cuda/tests/cudapy/test_make_function_to_jit_function.py +364 -0
  355. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +47 -0
  356. numba_cuda/numba/cuda/tests/cudapy/test_math.py +842 -0
  357. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +76 -0
  358. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +78 -0
  359. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +25 -0
  360. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +145 -0
  361. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +39 -0
  362. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +82 -0
  363. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +53 -0
  364. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +504 -0
  365. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +93 -0
  366. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +402 -0
  367. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +128 -0
  368. numba_cuda/numba/cuda/tests/cudapy/test_print.py +193 -0
  369. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +37 -0
  370. numba_cuda/numba/cuda/tests/cudapy/test_random.py +117 -0
  371. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +614 -0
  372. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +130 -0
  373. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +94 -0
  374. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  375. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +86 -0
  376. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +40 -0
  377. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +457 -0
  378. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +233 -0
  379. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +454 -0
  380. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +56 -0
  381. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +277 -0
  382. numba_cuda/numba/cuda/tests/cudapy/test_tracing.py +200 -0
  383. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +90 -0
  384. numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +333 -0
  385. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  386. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +585 -0
  387. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +42 -0
  388. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +485 -0
  389. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +312 -0
  390. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +23 -0
  391. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +183 -0
  392. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +40 -0
  393. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +40 -0
  394. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +206 -0
  395. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +446 -0
  396. numba_cuda/numba/cuda/tests/cudasim/__init__.py +9 -0
  397. numba_cuda/numba/cuda/tests/cudasim/support.py +9 -0
  398. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +111 -0
  399. numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
  400. numba_cuda/numba/cuda/tests/data/cta_barrier.cu +28 -0
  401. numba_cuda/numba/cuda/tests/data/cuda_include.cu +10 -0
  402. numba_cuda/numba/cuda/tests/data/error.cu +12 -0
  403. numba_cuda/numba/cuda/tests/data/include/add.cuh +8 -0
  404. numba_cuda/numba/cuda/tests/data/jitlink.cu +28 -0
  405. numba_cuda/numba/cuda/tests/data/jitlink.ptx +49 -0
  406. numba_cuda/numba/cuda/tests/data/warn.cu +12 -0
  407. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +9 -0
  408. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
  409. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +54 -0
  410. numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +8 -0
  411. numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +14 -0
  412. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +86 -0
  413. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +68 -0
  414. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +81 -0
  415. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +141 -0
  416. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +160 -0
  417. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +180 -0
  418. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +119 -0
  419. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +66 -0
  420. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +80 -0
  421. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +206 -0
  422. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +53 -0
  423. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +76 -0
  424. numba_cuda/numba/cuda/tests/nocuda/__init__.py +9 -0
  425. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +452 -0
  426. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +48 -0
  427. numba_cuda/numba/cuda/tests/nocuda/test_import.py +63 -0
  428. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +252 -0
  429. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +59 -0
  430. numba_cuda/numba/cuda/tests/nrt/__init__.py +9 -0
  431. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +387 -0
  432. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +124 -0
  433. numba_cuda/numba/cuda/tests/support.py +900 -0
  434. numba_cuda/numba/cuda/typeconv/__init__.py +4 -0
  435. numba_cuda/numba/cuda/typeconv/castgraph.py +137 -0
  436. numba_cuda/numba/cuda/typeconv/rules.py +63 -0
  437. numba_cuda/numba/cuda/typeconv/typeconv.py +121 -0
  438. numba_cuda/numba/cuda/types/__init__.py +233 -0
  439. numba_cuda/numba/cuda/types/__init__.pyi +167 -0
  440. numba_cuda/numba/cuda/types/abstract.py +9 -0
  441. numba_cuda/numba/cuda/types/common.py +9 -0
  442. numba_cuda/numba/cuda/types/containers.py +9 -0
  443. numba_cuda/numba/cuda/types/cuda_abstract.py +533 -0
  444. numba_cuda/numba/cuda/types/cuda_common.py +110 -0
  445. numba_cuda/numba/cuda/types/cuda_containers.py +971 -0
  446. numba_cuda/numba/cuda/types/cuda_function_type.py +230 -0
  447. numba_cuda/numba/cuda/types/cuda_functions.py +798 -0
  448. numba_cuda/numba/cuda/types/cuda_iterators.py +120 -0
  449. numba_cuda/numba/cuda/types/cuda_misc.py +569 -0
  450. numba_cuda/numba/cuda/types/cuda_npytypes.py +690 -0
  451. numba_cuda/numba/cuda/types/cuda_scalars.py +280 -0
  452. numba_cuda/numba/cuda/types/ext_types.py +101 -0
  453. numba_cuda/numba/cuda/types/function_type.py +11 -0
  454. numba_cuda/numba/cuda/types/functions.py +9 -0
  455. numba_cuda/numba/cuda/types/iterators.py +9 -0
  456. numba_cuda/numba/cuda/types/misc.py +9 -0
  457. numba_cuda/numba/cuda/types/npytypes.py +9 -0
  458. numba_cuda/numba/cuda/types/scalars.py +9 -0
  459. numba_cuda/numba/cuda/typing/__init__.py +19 -0
  460. numba_cuda/numba/cuda/typing/arraydecl.py +939 -0
  461. numba_cuda/numba/cuda/typing/asnumbatype.py +130 -0
  462. numba_cuda/numba/cuda/typing/bufproto.py +70 -0
  463. numba_cuda/numba/cuda/typing/builtins.py +1209 -0
  464. numba_cuda/numba/cuda/typing/cffi_utils.py +219 -0
  465. numba_cuda/numba/cuda/typing/cmathdecl.py +47 -0
  466. numba_cuda/numba/cuda/typing/collections.py +138 -0
  467. numba_cuda/numba/cuda/typing/context.py +782 -0
  468. numba_cuda/numba/cuda/typing/ctypes_utils.py +125 -0
  469. numba_cuda/numba/cuda/typing/dictdecl.py +63 -0
  470. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  471. numba_cuda/numba/cuda/typing/listdecl.py +147 -0
  472. numba_cuda/numba/cuda/typing/mathdecl.py +158 -0
  473. numba_cuda/numba/cuda/typing/npdatetime.py +322 -0
  474. numba_cuda/numba/cuda/typing/npydecl.py +749 -0
  475. numba_cuda/numba/cuda/typing/setdecl.py +115 -0
  476. numba_cuda/numba/cuda/typing/templates.py +1446 -0
  477. numba_cuda/numba/cuda/typing/typeof.py +301 -0
  478. numba_cuda/numba/cuda/ufuncs.py +746 -0
  479. numba_cuda/numba/cuda/utils.py +724 -0
  480. numba_cuda/numba/cuda/vector_types.py +214 -0
  481. numba_cuda/numba/cuda/vectorizers.py +260 -0
  482. numba_cuda-0.22.0.dist-info/METADATA +109 -0
  483. numba_cuda-0.22.0.dist-info/RECORD +487 -0
  484. numba_cuda-0.22.0.dist-info/WHEEL +6 -0
  485. numba_cuda-0.22.0.dist-info/licenses/LICENSE +26 -0
  486. numba_cuda-0.22.0.dist-info/licenses/LICENSE.numba +24 -0
  487. numba_cuda-0.22.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,956 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ """
5
+ Implement transformation on Numba IR
6
+ """
7
+
8
+ from collections import namedtuple, defaultdict
9
+ import logging
10
+ import operator
11
+
12
+ from numba.cuda.core.analysis import (
13
+ compute_cfg_from_blocks,
14
+ find_top_level_loops,
15
+ )
16
+ from numba.cuda.core import ir
17
+ from numba.cuda.core import errors
18
+ from numba.cuda.core import ir_utils
19
+ from numba.cuda.core.analysis import compute_use_defs
20
+
21
+
22
+ _logger = logging.getLogger(__name__)
23
+
24
+
25
def _extract_loop_lifting_candidates(cfg, blocks):
    """
    Return the top-level loops of *cfg* that qualify for loop lifting.

    A loop is kept only when all of the following hold:

    * every exit block has successors and they all lead to one single block;
    * the loop has exactly one entry;
    * no block of the loop (body, entries or exits) assigns an ``ir.Yield``;
    * the CFG entry point is not one of the loop entries.
    """

    def same_exit_point(loop):
        "all exits must point to the same location"
        targets = set()
        for exit_label in loop.exits:
            following = {label for label, _ in cfg.successors(exit_label)}
            if not following:
                # An exit without successors holds a return statement, which
                # the looplifting code does not handle: reject the loop.
                _logger.debug("return-statement in loop.")
                return False
            targets.update(following)
        converges = len(targets) == 1
        _logger.debug("same_exit_point=%s (%s)", converges, targets)
        return converges

    def one_entry(loop):
        "there is one entry"
        single = len(loop.entries) == 1
        _logger.debug("one_entry=%s", single)
        return single

    def cannot_yield(loop):
        "cannot have yield inside the loop"
        labels = set(loop.body) | set(loop.entries) | set(loop.exits)
        for label in labels:
            for inst in blocks[label].body:
                if isinstance(inst, ir.Assign) and isinstance(
                    inst.value, ir.Yield
                ):
                    _logger.debug("has yield")
                    return False
        _logger.debug("no yield")
        return True

    _logger.info("finding looplift candidates")
    candidates = []
    for loop in find_top_level_loops(cfg):
        _logger.debug("top-level loop: %s", loop)
        if not same_exit_point(loop):
            continue
        if not one_entry(loop):
            continue
        if not cannot_yield(loop):
            continue
        # Skipping loops entered from the CFG entry point prevents a bad
        # rewrite where the prelude of a lifted loop would be written into
        # block -1 if a loop entry were in block 0.
        if cfg.entry_point() in loop.entries:
            continue
        candidates.append(loop)
        _logger.debug("add candidate: %s", loop)
    return candidates
81
+
82
+
83
def find_region_inout_vars(blocks, livemap, callfrom, returnto, body_block_ids):
    """Compute the input and output variables of a region of blocks.

    The candidates are the variables live at *callfrom* (inputs) and at
    *returnto* (outputs) according to *livemap*. They are narrowed to the
    variables that the blocks listed in *body_block_ids* actually use or
    define; outputs must additionally be defined inside the region.

    Returns a 2-tuple ``(inputs, outputs)`` of sorted lists.
    """
    live_in = livemap[callfrom]
    live_out = livemap[returnto]

    # Restrict the analysis to the region's own blocks so that live
    # variables the region never touches can be dropped, which saves
    # building something valid to run through postproc.
    region = {label: blocks[label] for label in body_block_ids}
    usedefs = compute_use_defs(region)

    used = set()
    for names in usedefs.usemap.values():
        used |= names
    defined = set()
    for names in usedefs.defmap.values():
        defined |= names
    touched = used | defined

    # Sorting gives a stable ordering.
    region_inputs = sorted(set(live_in) & touched)
    region_outputs = sorted(set(live_out) & touched & defined)
    return region_inputs, region_outputs
108
+
109
+
110
# Record describing one loop-lifting candidate: the loop itself, its input and
# output variables, and the labels of the calling block and the return block.
_loop_lift_info = namedtuple(
    "loop_lift_info",
    ["loop", "inputs", "outputs", "callfrom", "returnto"],
)
113
+
114
+
115
def _loop_lift_get_candidate_infos(cfg, blocks, livemap):
    """
    Build a ``loop_lift_info`` record for every looplifting candidate.

    For each candidate loop this determines the calling block (its single
    entry), the block to return to, and the region's input/output variables.
    """
    infos = []
    for loop in _extract_loop_lifting_candidates(cfg, blocks):
        # Candidates are guaranteed to have exactly one entry.
        (callfrom,) = loop.entries
        some_exit = next(iter(loop.exits))  # any exit block will do
        if len(loop.exits) > 1:
            # Multiple exits: candidates guarantee that they all share one
            # successor, which is where execution resumes.
            ((returnto, _),) = cfg.successors(some_exit)
        else:
            # Single exit: resume at the exit itself.
            returnto = some_exit

        region_ids = set(loop.body) | set(loop.entries) | set(loop.exits)
        inputs, outputs = find_region_inout_vars(
            blocks=blocks,
            livemap=livemap,
            callfrom=callfrom,
            returnto=returnto,
            body_block_ids=region_ids,
        )
        infos.append(
            _loop_lift_info(
                loop=loop,
                inputs=inputs,
                outputs=outputs,
                callfrom=callfrom,
                returnto=returnto,
            )
        )

    return infos
152
+
153
+
154
def _loop_lift_modify_call_block(liftedloop, block, inputs, outputs, returnto):
    """
    Create the block that replaces *block* in the top-level function.

    The new block reuses the scope and location of *block* and is filled by
    ``ir_utils.fill_block_with_call`` with a call to *liftedloop* using
    *inputs* and *outputs*, continuing at *returnto*.
    """
    call_block = ir.Block(scope=block.scope, loc=block.loc)
    ir_utils.fill_block_with_call(
        newblock=call_block,
        callee=liftedloop,
        label_next=returnto,
        inputs=inputs,
        outputs=outputs,
    )
    return call_block
170
+
171
+
172
def _loop_lift_prepare_loop_func(loopinfo, blocks):
    """
    Turn *blocks* (mutated in place) into the body of a lifted-loop function.

    A prologue block is added in front of all existing blocks and the block
    at ``loopinfo.returnto`` is replaced with an epilogue block.
    """
    entry = blocks[loopinfo.callfrom]
    scope, loc = entry.scope, entry.loc

    def fresh_block():
        # Both new blocks share the scope and location of the loop entry.
        return ir.Block(scope=scope, loc=loc)

    # Lowering assumes the first block to be the one with the smallest
    # offset, so the prologue takes a label below every existing one.
    prologue_label = min(blocks) - 1
    blocks[prologue_label] = ir_utils.fill_callee_prologue(
        block=fresh_block(),
        inputs=loopinfo.inputs,
        label_next=loopinfo.callfrom,
    )
    blocks[loopinfo.returnto] = ir_utils.fill_callee_epilogue(
        block=fresh_block(),
        outputs=loopinfo.outputs,
    )
191
+
192
+
193
def _loop_lift_modify_blocks(
    func_ir, loopinfo, blocks, typingctx, targetctx, flags, locals
):
    """
    Modify the block inplace to call to the lifted-loop.

    The loop's blocks are copied into a new IR wrapped in a ``LiftedLoop``,
    removed from *blocks*, and the block at ``loopinfo.callfrom`` is replaced
    by a block that calls the lifted loop.

    Returns the ``LiftedLoop`` object.
    """
    from numba.cuda.dispatcher import LiftedLoop

    # Copy loop blocks
    loop = loopinfo.loop

    loopblockkeys = set(loop.body) | set(loop.entries)
    if len(loop.exits) > 1:
        # has multiple exits
        loopblockkeys |= loop.exits
    loopblocks = {k: blocks[k].copy() for k in loopblockkeys}
    # Location of the loop entry, used as a fallback below. It is the same
    # location that the prologue block of the lifted loop is created with.
    entry_loc = loopblocks[loopinfo.callfrom].loc
    # Modify the loop blocks
    _loop_lift_prepare_loop_func(loopinfo, loopblocks)
    # Since Python 3.13, [END_FOR, POP_TOP] sequence becomes the start of the
    # block causing the block to have line number of the start of previous
    # loop. Fix this using the loc of the first getiter.
    getiter_exprs = [
        expr
        for blk in loopblocks.values()
        for expr in blk.find_exprs(op="getiter")
    ]
    # A loop without any getiter (e.g. a `while` loop) would make a bare
    # min() raise ValueError; fall back to the entry block location instead.
    first_getiter = min(
        getiter_exprs, key=lambda expr: expr.loc.line, default=None
    )
    loop_loc = entry_loc if first_getiter is None else first_getiter.loc
    # Create a new IR for the lifted loop
    lifted_ir = func_ir.derive(
        blocks=loopblocks,
        arg_names=tuple(loopinfo.inputs),
        arg_count=len(loopinfo.inputs),
        force_non_generator=True,
        loc=loop_loc,
    )
    liftedloop = LiftedLoop(lifted_ir, typingctx, targetctx, flags, locals)

    # modify for calling into liftedloop
    callblock = _loop_lift_modify_call_block(
        liftedloop,
        blocks[loopinfo.callfrom],
        loopinfo.inputs,
        loopinfo.outputs,
        loopinfo.returnto,
    )
    # remove blocks
    for k in loopblockkeys:
        del blocks[k]
    # update main interpreter callsite into the liftedloop
    blocks[loopinfo.callfrom] = callblock
    return liftedloop
244
+
245
+
246
def _has_multiple_loop_exits(cfg, lpinfo):
    """Tell whether the loop *lpinfo* has more than one distinct exit.

    Exits that have another block among their post-dominators cause that
    other block to be discarded from the count ("common exits": a loop exit
    whose successor is another loop exit does not need to be altered).
    """
    if len(lpinfo.exits) <= 1:
        return False

    candidates = set(lpinfo.exits)
    postdoms = cfg.post_dominators()

    # Drop every exit that post-dominates another (still counted) exit.
    visited = set()
    pending = set(candidates)  # independent copy used as the work set
    while pending:
        current = pending.pop()
        visited.add(current)
        candidates -= postdoms[current] - {current}
        pending = candidates - visited

    return len(candidates) > 1
267
+
268
+
269
def _pre_looplift_transform(func_ir):
    """Canonicalize loops for looplifting.

    Every loop found to have multiple exits gets them merged into a single
    one; afterwards the analysis variables of the IR are reset and the
    post-processor is run again. Returns the resulting IR.
    """
    from numba.cuda.core.postproc import PostProcessor

    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for loop_info in cfg.loops().values():
        if not _has_multiple_loop_exits(cfg, loop_info):
            continue
        # Merge all exits of this loop into one common exit block.
        func_ir, _ = _fix_multi_exit_blocks(func_ir, loop_info.exits)

    # Reset and reprocess the func_ir
    func_ir._reset_analysis_variables()
    PostProcessor(func_ir).run()
    return func_ir
284
+
285
+
286
def loop_lifting(func_ir, typingctx, targetctx, flags, locals):
    """
    Loop lifting transformation.

    Given an interpreter `func_ir`, returns a 2-tuple of
    `(toplevel_interp, [loop0_interp, loop1_interp, ....])`, where the list
    holds one lifted-loop object per loop extracted from the function.
    """
    func_ir = _pre_looplift_transform(func_ir)
    blocks = func_ir.blocks.copy()
    cfg = compute_cfg_from_blocks(blocks)
    loopinfos = _loop_lift_get_candidate_infos(
        cfg, blocks, func_ir.variable_lifetime.livemap
    )
    loops = []
    if loopinfos:
        # Dumping the whole IR to a string is only worth doing when the
        # message is actually going to be emitted.
        if _logger.isEnabledFor(logging.DEBUG):
            _logger.debug(
                "loop lifting this IR with %d candidates:\n%s",
                len(loopinfos),
                func_ir.dump_to_string(),
            )
        for loopinfo in loopinfos:
            lifted = _loop_lift_modify_blocks(
                func_ir, loopinfo, blocks, typingctx, targetctx, flags, locals
            )
            loops.append(lifted)

    # Make main IR
    main = func_ir.derive(blocks=blocks)

    return main, loops
316
+
317
+
318
def canonicalize_cfg_single_backedge(blocks):
    """
    Give every loop with several backedges a single backedge.

    For each such loop a new tail block is appended that jumps to the loop
    header, and every block of the loop that targeted the header is copied
    and redirected to that tail block. *blocks* itself is not modified; a
    new dictionary of blocks is returned.
    """
    cfg = compute_cfg_from_blocks(blocks)
    rewritten = blocks.copy()

    def has_multiple_backedges(loop):
        # Stops as soon as a second block targeting the header is found.
        found_one = False
        for label in loop.body:
            if loop.header in blocks[label].terminator.get_targets():
                if found_one:
                    return True
                found_one = True
        return False

    def replace_target(term, src, dst):
        def swap(target):
            if target == src:
                return dst
            return target

        if isinstance(term, ir.Branch):
            return ir.Branch(
                cond=term.cond,
                truebr=swap(term.truebr),
                falsebr=swap(term.falsebr),
                loc=term.loc,
            )
        if isinstance(term, ir.Jump):
            return ir.Jump(target=swap(term.target), loc=term.loc)
        # Any other terminator must not have targets at all.
        assert not term.get_targets()
        return term

    def rewrite_single_backedge(loop):
        """
        Add new tail block that gathers all the backedges
        """
        header = loop.header
        tail_label = max(rewritten.keys()) + 1
        for label in loop.body:
            current = rewritten[label]
            if header not in current.terminator.get_targets():
                continue
            # Redirect the backedge to the new tail block, on a copy.
            redirected = current.copy()
            redirected.body[-1] = replace_target(
                current.terminator, header, tail_label
            )
            rewritten[label] = redirected
        # The tail block holds the one remaining backedge.
        header_block = rewritten[header]
        tail_block = ir.Block(scope=header_block.scope, loc=header_block.loc)
        tail_block.append(ir.Jump(target=header, loc=tail_block.loc))
        rewritten[tail_label] = tail_block

    for loop in cfg.loops().values():
        if has_multiple_backedges(loop):
            rewrite_single_backedge(loop)

    return rewritten
389
+
390
+
391
def canonicalize_cfg(blocks):
    """
    Canonicalize the CFG described by *blocks*.

    The only rewrite applied is `canonicalize_cfg_single_backedge`, whose
    result (a new dictionary of blocks) is returned.
    """
    canonical_blocks = canonicalize_cfg_single_backedge(blocks)
    return canonical_blocks
397
+
398
+
399
def with_lifting(func_ir, typingctx, targetctx, flags, locals):
    """With-lifting transformation

    Extract the top-level with-regions of *func_ir*. Each region is handed
    to the ``mutate_with_body`` method of its context manager.

    Returns ``(the_new_ir, the_lifted_with_ir)`` where the second item is the
    list of values returned by the ``mutate_with_body`` calls. When no
    with-region is found the IR from `find_setupwiths` is returned together
    with an empty list.
    """
    from numba.cuda.core import postproc

    def dispatcher_factory(func_ir, objectmode=False, **kwargs):
        from numba.cuda.dispatcher import LiftedWith, ObjModeLiftedWith

        myflags = flags.copy()
        cls = LiftedWith
        if objectmode:
            # Lifted with-block cannot looplift
            myflags.enable_looplift = False
            # Lifted with-block uses object mode
            myflags.enable_pyobject = True
            myflags.force_pyobject = True
            myflags.no_cpython_wrapper = False
            cls = ObjModeLiftedWith
        return cls(func_ir, typingctx, targetctx, myflags, locals, **kwargs)

    # Locate the with-context regions.
    withs, func_ir = find_setupwiths(func_ir)
    if not withs:
        return func_ir, []

    # Run the post-processor so that variable lifetime is available.
    postproc.PostProcessor(func_ir).run()
    assert func_ir.variable_lifetime
    lifetime = func_ir.variable_lifetime
    blocks = func_ir.blocks.copy()
    cfg = lifetime.cfg

    # Mutate every with-region according to its kind of contextmanager.
    sub_irs = []
    for blk_start, blk_end in withs:
        body_blocks = list(_cfg_nodes_in_region(cfg, blk_start, blk_end))
        _legalize_with_head(blocks[blk_start])
        # Find the contextmanager
        cmkind, extra = _get_with_contextmanager(func_ir, blocks, blk_start)
        # Mutate the body and get new IR
        sub_irs.append(
            cmkind.mutate_with_body(
                func_ir,
                blocks,
                blk_start,
                blk_end,
                body_blocks,
                dispatcher_factory,
                extra,
            )
        )

    # Without any lifted region the IR is returned unchanged.
    new_ir = func_ir.derive(blocks) if sub_irs else func_ir
    return new_ir, sub_irs
462
+
463
+
464
def _get_with_contextmanager(func_ir, blocks, blk_start):
    """Resolve the context manager used by the with-region at *blk_start*.

    The first ``ir.EnterWith`` statement of the head block is inspected.

    Returns ``(ctxobj, extra)``. *extra* is ``None`` unless the context
    manager is used as a call, in which case it is a dict with the
    definitions of the call's positional (``"args"``) and keyword
    (``"kwargs"``) arguments.

    Raises ``errors.CompilerError`` when the context manager is undefined,
    cannot be resolved, has no ``mutate_with_body`` attribute, or when the
    head block holds no ``ir.EnterWith`` at all.
    """

    def resolve(var_ref):
        """Return the context-manager object and extra info for *var_ref*."""
        dfn = func_ir.get_definition(var_ref)
        extra = None
        if isinstance(dfn, ir.Expr) and dfn.op == "call":
            # Used as a call: record the argument definitions and resolve
            # the callee instead.
            positional = [func_ir.get_definition(arg) for arg in dfn.args]
            keywords = {
                key: func_ir.get_definition(val) for key, val in dfn.kws
            }
            extra = {"args": positional, "kwargs": keywords}
            var_ref = dfn.func

        ctxobj = ir_utils.guard(ir_utils.find_outer_value, func_ir, var_ref)

        # Validate the resolved object.
        if ctxobj is ir.UNDEFINED:
            raise errors.CompilerError(
                "Undefined variable used as context manager",
                loc=blocks[blk_start].loc,
            )
        if ctxobj is None:
            raise errors.CompilerError(
                "Illegal use of context-manager.", loc=dfn.loc
            )
        return ctxobj, extra

    # Scan the head of the with-region for the contextmanager.
    for stmt in blocks[blk_start].body:
        if not isinstance(stmt, ir.EnterWith):
            continue
        ctxobj, extra = resolve(stmt.contextmanager)
        if hasattr(ctxobj, "mutate_with_body"):
            return ctxobj, extra
        raise errors.CompilerError(
            "Unsupported context manager in use",
            loc=blocks[blk_start].loc,
        )

    # The head block has no EnterWith statement.
    raise errors.CompilerError(
        "malformed with-context usage",
        loc=blocks[blk_start].loc,
    )
518
+
519
+
520
def _legalize_with_head(blk):
    """Given *blk*, the head block of the with-context, check that it doesn't
    do anything else.

    The block must contain exactly one ``ir.EnterWith``, exactly one
    ``ir.Jump``, any number of ``ir.Del`` and nothing else.

    Raises ``errors.CompilerError`` when any of these rules is violated.
    """
    counters = defaultdict(int)
    for stmt in blk.body:
        counters[type(stmt)] += 1
    # NOTE: pop() on a defaultdict does not use the default factory, so an
    # explicit default is required; otherwise a head block without any
    # EnterWith would raise KeyError instead of the CompilerError below.
    if counters.pop(ir.EnterWith, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 ENTER_WITH",
            loc=blk.loc,
        )
    if counters.pop(ir.Jump, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 JUMP",
            loc=blk.loc,
        )
    # Can have any number of del
    counters.pop(ir.Del, None)
    # There MUST NOT be any other statements
    if counters:
        raise errors.CompilerError(
            "illegal statements in with's head-block",
            loc=blk.loc,
        )
545
+
546
+
547
def _cfg_nodes_in_region(cfg, region_begin, region_end):
    """Collect the CFG nodes reachable from *region_begin* without passing
    through *region_end*.

    *region_end* is never part of the result; *region_begin* is included
    only when it is reached again as a successor of another node.
    """
    found = set()
    pending = [region_begin]
    while pending:
        current = pending.pop()
        # A single block function yields no successors here, in which case
        # nothing is added.
        fresh = {
            succ
            for succ, _ in cfg.successors(current)
            if succ not in found and succ != region_end
        }
        pending.extend(fresh)
        found |= fresh

    return found
568
+
569
+
570
def find_setupwiths(func_ir):
    """Find all top-level with.

    Returns a 2-tuple ``(with_ranges, func_ir)``: the list of
    ``(start, end)`` block labels of the with-regions, and the function IR
    after the rewrites applied while searching (merging of multiple exits,
    rewriting of returns).

    Raises ``errors.CompilerError`` for a ``raise`` found while searching
    for the end of a with-block and for a POP_BLOCK block that does not have
    exactly one jump target.
    """

    def find_ranges(blocks):
        cfg = compute_cfg_from_blocks(blocks)
        # Labels of the blocks holding suspected SETUP_WITH and POP_BLOCK
        # statements.
        setup_labels, pop_labels = set(), set()
        for label, block in blocks.items():
            for stmt in block.body:
                if ir_utils.is_setup_with(stmt):
                    setup_labels.add(label)
                if ir_utils.is_pop_block(stmt):
                    pop_labels.add(label)

        # Walk the setups in reverse topological order and, from each one,
        # search forward for the POP_BLOCK blocks belonging to it.
        pops_by_setup = defaultdict(set)
        for setup_label in cfg.topo_sort(setup_labels, reverse=True):
            worklist = [setup_label]
            visited = set()
            while worklist:
                label = worklist.pop()
                visited.add(label)
                for stmt in blocks[label].body:
                    # raise detected before pop_block
                    if ir_utils.is_raise(stmt):
                        raise errors.CompilerError(
                            "unsupported control flow due to raise "
                            "statements inside with block"
                        )
                    if ir_utils.is_pop_block(stmt) and label in pop_labels:
                        # This block closes the current setup: record it,
                        # make it unavailable to other setups and stop
                        # scanning, the frontier has been reached.
                        pops_by_setup[setup_label].add(label)
                        pop_labels.remove(label)
                        break
                    # Reaching the terminator means the search goes on in
                    # the targets that have not been visited yet.
                    if ir_utils.is_terminator(stmt):
                        worklist.extend(
                            target
                            for target in stmt.get_targets()
                            if target not in visited
                        )

        return pops_by_setup

    blocks = func_ir.blocks
    # Maps labels of blocks containing SETUP_WITH to the set of labels of
    # blocks containing the matching POP_BLOCK statements.
    with_ranges_dict = find_ranges(blocks)
    # Rewrite the CFG in case one with has several POP_BLOCK statements.
    func_ir = consolidate_multi_exit_withs(with_ranges_dict, blocks, func_ir)
    # Flatten into (setup, pop) pairs for the remaining steps.
    setup_pop_pairs = [
        (setup, list(pops)[0]) for (setup, pops) in with_ranges_dict.items()
    ]

    # Reject POP_BLOCK blocks that do not have exactly one outgoing edge.
    for _, pop_label in setup_pop_pairs:
        if len(blocks[pop_label].terminator.get_targets()) != 1:
            raise errors.CompilerError(
                "unsupported control flow: with-context contains branches "
                "(i.e. break/return/raise) that can leave the block "
            )
    # A POP_BLOCK block followed directly by a returning block is rewritten.
    for _, pop_label in setup_pop_pairs:
        successor_label = blocks[pop_label].terminator.get_targets()[0]
        if ir_utils.is_return(func_ir.blocks[successor_label].terminator):
            _rewrite_return(func_ir, pop_label)

    # Pair each SETUP_WITH with the successor of the block that contains
    # the POP_BLOCK.
    regions = [
        (setup, func_ir.blocks[pop_label].terminator.get_targets()[0])
        for (setup, pop_label) in setup_pop_pairs
    ]

    # Finally drop the regions nested inside another one.
    regions = _eliminate_nested_withs(regions)

    return regions, func_ir
664
+
665
+
666
def _rewrite_return(func_ir, target_block_label):
    """Rewrite a return block inside a with statement.

    Arguments
    ---------

    func_ir: Function IR
      the CFG to transform
    target_block_label: int
      the block index/label of the block containing the POP_BLOCK statement

    The block containing the POP_BLOCK jumps straight to a block ending in a
    return, which means there is a `return` statement within a `with`
    context::

        [ top ; POP_BLOCK ; bottom ; JUMP ] --> [ RETURN ]

    The CFG is rewritten by inserting a block between the two::

        [ top ; JUMP ] --> [ POP_BLOCK ; bottom ; JUMP ] --> [ RETURN ]

    As implemented here, 'top' is every statement before the POP_BLOCK and
    stays in the original block. The POP_BLOCK and every statement after it
    (except the final jump) move into the block that used to hold the
    return, and the return statements move into a newly created block with
    a fresh label.

    Returns *func_ir*, which is modified in place.
    """
    # The block holding the POP_BLOCK, and its successor (the return block).
    pop_block = func_ir.blocks[target_block_label]
    return_label = pop_block.terminator.get_targets()[0]
    return_block = func_ir.blocks[return_label]

    # New block that will receive the return statements.
    new_label = ir_utils.find_max_label(func_ir.blocks) + 1
    new_loc = return_block.loc
    new_block = ir.Block(ir.Scope(None, loc=new_loc), loc=new_loc)

    # The block must be of the form:
    # -----------------
    # <some stmts>
    # POP_BLOCK
    # <some more stmts>
    # JUMP
    # -----------------
    markers = [*pop_block.find_insts(ir.PopBlock)]
    assert len(markers) == 1
    assert len([*pop_block.find_insts(ir.Jump)]) == 1
    assert isinstance(pop_block.body[-1], ir.Jump)
    split_at = pop_block.body.index(markers[0])
    top_body = [
        *pop_block.body[:split_at],
        ir.Jump(return_label, pop_block.loc),
    ]
    bottom_body = [
        *pop_block.body[split_at:-1],
        ir.Jump(new_label, pop_block.loc),
    ]

    # Re-assign the statements; the existing body lists are updated in
    # place so that their identity is kept.
    new_block.body.extend(return_block.body)
    return_block.body[:] = bottom_body
    pop_block.body[:] = top_body

    # Register the new return block and rebuild the IR definitions.
    func_ir.blocks[new_label] = new_block
    func_ir._definitions = ir_utils.build_definitions(func_ir.blocks)
    return func_ir
771
+
772
+
773
def _eliminate_nested_withs(with_ranges):
    """Drop the with-ranges that lie strictly inside an already kept range.

    *with_ranges* is an iterable of ``(start, end)`` pairs. The pairs are
    visited in sorted order and the kept ones are returned as a list.
    """
    kept = []
    for start, end in sorted(with_ranges):
        # FIXME: this should be a comparison in topological order, right
        # now we are comparing the integers of the blocks, stuff probably
        # works by accident.
        nested = any(start > a and end < b for a, b in kept)
        if not nested:
            kept.append((start, end))

    return kept
790
+
791
+
792
def consolidate_multi_exit_withs(withs: dict, blocks, func_ir):
    """Merge the exit blocks of every with construct that has several.

    *withs* maps a setup label to the set of its exit labels and is updated
    in place: a set with more than one exit is replaced with a one-element
    set holding the label of the new common exit. *blocks* is not used.

    Returns the function IR.
    """
    for setup_label in withs:
        exit_labels: set = withs[setup_label]
        if len(exit_labels) <= 1:
            continue
        func_ir, common = _fix_multi_exit_blocks(
            func_ir,
            exit_labels,
            split_condition=ir_utils.is_pop_block,
        )
        withs[setup_label] = {common}
    return func_ir
804
+
805
+
806
def _fix_multi_exit_blocks(func_ir, exit_nodes, *, split_condition=None):
    """Modify the FunctionIR to create a single common exit node given the
    original exit nodes.

    Parameters
    ----------
    func_ir :
        The FunctionIR. Mutated inplace.
    exit_nodes :
        The original exit nodes. A sequence of block keys.
    split_condition : callable or None
        If not None, it is a callable with the signature
        `split_condition(statement)` that determines if the `statement` is the
        splitting point (e.g. `POP_BLOCK`) in an exit node.
        If it's None, the exit node is not split.

    Returns
    -------
    A 2-tuple of the FunctionIR and the label of the new common block.
    """

    # Every exit node is cut short and made to jump to one new "common"
    # block after storing its own index in the control-point variable
    # ``$cp``. The common block jumps to a new "post" block, which starts a
    # chain of comparisons on ``$cp`` dispatching to the original target of
    # each exit node:
    #
    #   exit0   exit1            exit0   exit1
    #     |       |                 \     /
    #   after0  after1    ==>       common
    #                                  |
    #                                post
    #                               /     \
    #                           after0   after1

    blocks = func_ir.blocks
    # Getting the scope
    any_blk = min(func_ir.blocks.values())
    scope = any_blk.scope
    # Labels for new blocks start right after the maximum block label.
    next_label = max(func_ir.blocks) + 1

    # The new common block for the new exit.
    common_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    common_label = next_label
    next_label += 1
    blocks[common_label] = common_block
    # The new block after the exit.
    post_block = ir.Block(any_blk.scope, loc=ir.unknown_loc)
    post_label = next_label
    next_label += 1
    blocks[post_label] = post_block

    # Adjust each exit node
    tails = []
    for index, exit_label in enumerate(exit_nodes):
        exit_blk = blocks[exit_label]

        if split_condition is None:
            # no splitting: only the last statement is cut off
            cut = -1
        else:
            # split at the first statement satisfying the condition
            for cut, stmt in enumerate(exit_blk.body):
                if split_condition(stmt):
                    break

        head = exit_blk.body[:cut]
        tails.append(exit_blk.body[cut:])

        # Add control-point variable to mark which exit block this is.
        exit_blk.body = head
        exit_loc = exit_blk.loc
        marker_value = ir.Const(index, loc=exit_loc)
        marker_target = scope.get_or_define("$cp", loc=exit_loc)
        exit_blk.body.append(
            ir.Assign(value=marker_value, target=marker_target, loc=exit_loc)
        )
        # Replace terminator with a jump to the common block
        assert not exit_blk.is_terminated
        exit_blk.body.append(ir.Jump(common_label, loc=ir.unknown_loc))

    if split_condition is not None:
        # Move the splitting statement to the common block
        common_block.body.append(tails[0][0])
    assert not common_block.is_terminated
    # Append jump from common block to post block; the location is the one
    # of the last exit node processed above.
    common_block.body.append(ir.Jump(post_label, loc=exit_loc))

    # Make if-else tree to jump to target: reserve one label per exit node
    # for the blocks of the comparison chain.
    chain_labels = list(range(next_label, next_label + len(tails)))

    switch_block = post_block
    loc = ir.unknown_loc
    for index, tail in enumerate(tails):
        match_expr = scope.redefine("$cp_check", loc=loc)
        match_rhs = scope.redefine("$cp_rhs", loc=loc)

        # Constant to compare the control-point variable against
        switch_block.body.append(
            ir.Assign(
                value=ir.Const(index, loc=loc), target=match_rhs, loc=loc
            ),
        )

        # Assignment of the comparison result
        comparison = ir.Expr.binop(
            fn=operator.eq,
            lhs=scope.get("$cp"),
            rhs=match_rhs,
            loc=loc,
        )
        switch_block.body.append(
            ir.Assign(value=comparison, target=match_expr, loc=loc),
        )

        # Branch to the original target of this exit node on a match,
        # otherwise fall through to the next case.
        (jump_target,) = tail[-1].get_targets()
        switch_block.body.append(
            ir.Branch(match_expr, jump_target, chain_labels[index], loc=loc),
        )
        switch_block = ir.Block(scope=scope, loc=loc)
        blocks[chain_labels[index]] = switch_block

    # Add the final jump
    switch_block.body.append(ir.Jump(jump_target, loc=loc))

    return func_ir, common_label