numba-cuda 0.21.1__cp313-cp313-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (488) hide show
  1. _numba_cuda_redirector.pth +4 -0
  2. _numba_cuda_redirector.py +89 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +6 -0
  5. numba_cuda/_version.py +11 -0
  6. numba_cuda/numba/cuda/__init__.py +70 -0
  7. numba_cuda/numba/cuda/_internal/cuda_bf16.py +16394 -0
  8. numba_cuda/numba/cuda/_internal/cuda_fp16.py +8112 -0
  9. numba_cuda/numba/cuda/api.py +577 -0
  10. numba_cuda/numba/cuda/api_util.py +76 -0
  11. numba_cuda/numba/cuda/args.py +72 -0
  12. numba_cuda/numba/cuda/bf16.py +397 -0
  13. numba_cuda/numba/cuda/cache_hints.py +287 -0
  14. numba_cuda/numba/cuda/cext/__init__.py +2 -0
  15. numba_cuda/numba/cuda/cext/_devicearray.cp313-win_amd64.pyd +0 -0
  16. numba_cuda/numba/cuda/cext/_devicearray.cpp +159 -0
  17. numba_cuda/numba/cuda/cext/_devicearray.h +29 -0
  18. numba_cuda/numba/cuda/cext/_dispatcher.cp313-win_amd64.pyd +0 -0
  19. numba_cuda/numba/cuda/cext/_dispatcher.cpp +1098 -0
  20. numba_cuda/numba/cuda/cext/_hashtable.cpp +532 -0
  21. numba_cuda/numba/cuda/cext/_hashtable.h +135 -0
  22. numba_cuda/numba/cuda/cext/_helperlib.c +71 -0
  23. numba_cuda/numba/cuda/cext/_helperlib.cp313-win_amd64.pyd +0 -0
  24. numba_cuda/numba/cuda/cext/_helpermod.c +82 -0
  25. numba_cuda/numba/cuda/cext/_pymodule.h +38 -0
  26. numba_cuda/numba/cuda/cext/_typeconv.cp313-win_amd64.pyd +0 -0
  27. numba_cuda/numba/cuda/cext/_typeconv.cpp +206 -0
  28. numba_cuda/numba/cuda/cext/_typeof.cpp +1159 -0
  29. numba_cuda/numba/cuda/cext/_typeof.h +19 -0
  30. numba_cuda/numba/cuda/cext/capsulethunk.h +111 -0
  31. numba_cuda/numba/cuda/cext/mviewbuf.c +385 -0
  32. numba_cuda/numba/cuda/cext/mviewbuf.cp313-win_amd64.pyd +0 -0
  33. numba_cuda/numba/cuda/cext/typeconv.cpp +212 -0
  34. numba_cuda/numba/cuda/cext/typeconv.hpp +101 -0
  35. numba_cuda/numba/cuda/cg.py +67 -0
  36. numba_cuda/numba/cuda/cgutils.py +1294 -0
  37. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  38. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  39. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  40. numba_cuda/numba/cuda/codegen.py +541 -0
  41. numba_cuda/numba/cuda/compiler.py +1396 -0
  42. numba_cuda/numba/cuda/core/analysis.py +758 -0
  43. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  44. numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +288 -0
  45. numba_cuda/numba/cuda/core/annotations/type_annotations.py +305 -0
  46. numba_cuda/numba/cuda/core/base.py +1332 -0
  47. numba_cuda/numba/cuda/core/boxing.py +1411 -0
  48. numba_cuda/numba/cuda/core/bytecode.py +728 -0
  49. numba_cuda/numba/cuda/core/byteflow.py +2346 -0
  50. numba_cuda/numba/cuda/core/caching.py +744 -0
  51. numba_cuda/numba/cuda/core/callconv.py +392 -0
  52. numba_cuda/numba/cuda/core/codegen.py +171 -0
  53. numba_cuda/numba/cuda/core/compiler.py +199 -0
  54. numba_cuda/numba/cuda/core/compiler_lock.py +85 -0
  55. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  56. numba_cuda/numba/cuda/core/config.py +650 -0
  57. numba_cuda/numba/cuda/core/consts.py +124 -0
  58. numba_cuda/numba/cuda/core/controlflow.py +989 -0
  59. numba_cuda/numba/cuda/core/entrypoints.py +57 -0
  60. numba_cuda/numba/cuda/core/environment.py +66 -0
  61. numba_cuda/numba/cuda/core/errors.py +917 -0
  62. numba_cuda/numba/cuda/core/event.py +511 -0
  63. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  64. numba_cuda/numba/cuda/core/generators.py +387 -0
  65. numba_cuda/numba/cuda/core/imputils.py +509 -0
  66. numba_cuda/numba/cuda/core/inline_closurecall.py +1787 -0
  67. numba_cuda/numba/cuda/core/interpreter.py +3617 -0
  68. numba_cuda/numba/cuda/core/ir.py +1812 -0
  69. numba_cuda/numba/cuda/core/ir_utils.py +2638 -0
  70. numba_cuda/numba/cuda/core/optional.py +129 -0
  71. numba_cuda/numba/cuda/core/options.py +262 -0
  72. numba_cuda/numba/cuda/core/postproc.py +249 -0
  73. numba_cuda/numba/cuda/core/pythonapi.py +1859 -0
  74. numba_cuda/numba/cuda/core/registry.py +46 -0
  75. numba_cuda/numba/cuda/core/removerefctpass.py +123 -0
  76. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  77. numba_cuda/numba/cuda/core/rewrites/ir_print.py +91 -0
  78. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  79. numba_cuda/numba/cuda/core/rewrites/static_binop.py +41 -0
  80. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +189 -0
  81. numba_cuda/numba/cuda/core/rewrites/static_raise.py +100 -0
  82. numba_cuda/numba/cuda/core/sigutils.py +68 -0
  83. numba_cuda/numba/cuda/core/ssa.py +498 -0
  84. numba_cuda/numba/cuda/core/targetconfig.py +330 -0
  85. numba_cuda/numba/cuda/core/tracing.py +231 -0
  86. numba_cuda/numba/cuda/core/transforms.py +956 -0
  87. numba_cuda/numba/cuda/core/typed_passes.py +867 -0
  88. numba_cuda/numba/cuda/core/typeinfer.py +1950 -0
  89. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  90. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  91. numba_cuda/numba/cuda/core/unsafe/eh.py +67 -0
  92. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  93. numba_cuda/numba/cuda/core/untyped_passes.py +1979 -0
  94. numba_cuda/numba/cuda/cpython/builtins.py +1153 -0
  95. numba_cuda/numba/cuda/cpython/charseq.py +1218 -0
  96. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  97. numba_cuda/numba/cuda/cpython/enumimpl.py +103 -0
  98. numba_cuda/numba/cuda/cpython/iterators.py +167 -0
  99. numba_cuda/numba/cuda/cpython/listobj.py +1326 -0
  100. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  101. numba_cuda/numba/cuda/cpython/numbers.py +1475 -0
  102. numba_cuda/numba/cuda/cpython/rangeobj.py +289 -0
  103. numba_cuda/numba/cuda/cpython/slicing.py +322 -0
  104. numba_cuda/numba/cuda/cpython/tupleobj.py +456 -0
  105. numba_cuda/numba/cuda/cpython/unicode.py +2865 -0
  106. numba_cuda/numba/cuda/cpython/unicode_support.py +1597 -0
  107. numba_cuda/numba/cuda/cpython/unsafe/__init__.py +0 -0
  108. numba_cuda/numba/cuda/cpython/unsafe/numbers.py +64 -0
  109. numba_cuda/numba/cuda/cpython/unsafe/tuple.py +92 -0
  110. numba_cuda/numba/cuda/cuda_paths.py +691 -0
  111. numba_cuda/numba/cuda/cudadecl.py +556 -0
  112. numba_cuda/numba/cuda/cudadrv/__init__.py +14 -0
  113. numba_cuda/numba/cuda/cudadrv/devicearray.py +951 -0
  114. numba_cuda/numba/cuda/cudadrv/devices.py +249 -0
  115. numba_cuda/numba/cuda/cudadrv/driver.py +3222 -0
  116. numba_cuda/numba/cuda/cudadrv/drvapi.py +435 -0
  117. numba_cuda/numba/cuda/cudadrv/dummyarray.py +558 -0
  118. numba_cuda/numba/cuda/cudadrv/enums.py +613 -0
  119. numba_cuda/numba/cuda/cudadrv/error.py +48 -0
  120. numba_cuda/numba/cuda/cudadrv/libs.py +220 -0
  121. numba_cuda/numba/cuda/cudadrv/linkable_code.py +184 -0
  122. numba_cuda/numba/cuda/cudadrv/mappings.py +14 -0
  123. numba_cuda/numba/cuda/cudadrv/ndarray.py +26 -0
  124. numba_cuda/numba/cuda/cudadrv/nvrtc.py +193 -0
  125. numba_cuda/numba/cuda/cudadrv/nvvm.py +756 -0
  126. numba_cuda/numba/cuda/cudadrv/rtapi.py +13 -0
  127. numba_cuda/numba/cuda/cudadrv/runtime.py +34 -0
  128. numba_cuda/numba/cuda/cudaimpl.py +995 -0
  129. numba_cuda/numba/cuda/cudamath.py +149 -0
  130. numba_cuda/numba/cuda/datamodel/__init__.py +7 -0
  131. numba_cuda/numba/cuda/datamodel/cuda_manager.py +66 -0
  132. numba_cuda/numba/cuda/datamodel/cuda_models.py +1446 -0
  133. numba_cuda/numba/cuda/datamodel/cuda_packer.py +224 -0
  134. numba_cuda/numba/cuda/datamodel/cuda_registry.py +22 -0
  135. numba_cuda/numba/cuda/datamodel/cuda_testing.py +153 -0
  136. numba_cuda/numba/cuda/datamodel/manager.py +11 -0
  137. numba_cuda/numba/cuda/datamodel/models.py +9 -0
  138. numba_cuda/numba/cuda/datamodel/packer.py +9 -0
  139. numba_cuda/numba/cuda/datamodel/registry.py +11 -0
  140. numba_cuda/numba/cuda/datamodel/testing.py +11 -0
  141. numba_cuda/numba/cuda/debuginfo.py +903 -0
  142. numba_cuda/numba/cuda/decorators.py +294 -0
  143. numba_cuda/numba/cuda/descriptor.py +35 -0
  144. numba_cuda/numba/cuda/device_init.py +158 -0
  145. numba_cuda/numba/cuda/deviceufunc.py +1021 -0
  146. numba_cuda/numba/cuda/dispatcher.py +2463 -0
  147. numba_cuda/numba/cuda/errors.py +72 -0
  148. numba_cuda/numba/cuda/extending.py +697 -0
  149. numba_cuda/numba/cuda/flags.py +178 -0
  150. numba_cuda/numba/cuda/fp16.py +357 -0
  151. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  152. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  153. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  154. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  155. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  156. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  157. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  158. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  159. numba_cuda/numba/cuda/initialize.py +24 -0
  160. numba_cuda/numba/cuda/intrinsic_wrapper.py +41 -0
  161. numba_cuda/numba/cuda/intrinsics.py +382 -0
  162. numba_cuda/numba/cuda/itanium_mangler.py +214 -0
  163. numba_cuda/numba/cuda/kernels/__init__.py +2 -0
  164. numba_cuda/numba/cuda/kernels/reduction.py +265 -0
  165. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  166. numba_cuda/numba/cuda/libdevice.py +3386 -0
  167. numba_cuda/numba/cuda/libdevicedecl.py +20 -0
  168. numba_cuda/numba/cuda/libdevicefuncs.py +1060 -0
  169. numba_cuda/numba/cuda/libdeviceimpl.py +88 -0
  170. numba_cuda/numba/cuda/locks.py +19 -0
  171. numba_cuda/numba/cuda/lowering.py +1951 -0
  172. numba_cuda/numba/cuda/mathimpl.py +374 -0
  173. numba_cuda/numba/cuda/memory_management/__init__.py +4 -0
  174. numba_cuda/numba/cuda/memory_management/memsys.cu +99 -0
  175. numba_cuda/numba/cuda/memory_management/memsys.cuh +22 -0
  176. numba_cuda/numba/cuda/memory_management/nrt.cu +212 -0
  177. numba_cuda/numba/cuda/memory_management/nrt.cuh +48 -0
  178. numba_cuda/numba/cuda/memory_management/nrt.py +390 -0
  179. numba_cuda/numba/cuda/memory_management/nrt_context.py +438 -0
  180. numba_cuda/numba/cuda/misc/appdirs.py +594 -0
  181. numba_cuda/numba/cuda/misc/cffiimpl.py +24 -0
  182. numba_cuda/numba/cuda/misc/coverage_support.py +43 -0
  183. numba_cuda/numba/cuda/misc/dump_style.py +41 -0
  184. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  185. numba_cuda/numba/cuda/misc/firstlinefinder.py +96 -0
  186. numba_cuda/numba/cuda/misc/gdb_hook.py +240 -0
  187. numba_cuda/numba/cuda/misc/literal.py +28 -0
  188. numba_cuda/numba/cuda/misc/llvm_pass_timings.py +412 -0
  189. numba_cuda/numba/cuda/misc/special.py +94 -0
  190. numba_cuda/numba/cuda/models.py +56 -0
  191. numba_cuda/numba/cuda/np/arraymath.py +5130 -0
  192. numba_cuda/numba/cuda/np/arrayobj.py +7635 -0
  193. numba_cuda/numba/cuda/np/extensions.py +11 -0
  194. numba_cuda/numba/cuda/np/linalg.py +3087 -0
  195. numba_cuda/numba/cuda/np/math/__init__.py +0 -0
  196. numba_cuda/numba/cuda/np/math/cmathimpl.py +558 -0
  197. numba_cuda/numba/cuda/np/math/mathimpl.py +487 -0
  198. numba_cuda/numba/cuda/np/math/numbers.py +1461 -0
  199. numba_cuda/numba/cuda/np/npdatetime.py +969 -0
  200. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  201. numba_cuda/numba/cuda/np/npyfuncs.py +1808 -0
  202. numba_cuda/numba/cuda/np/npyimpl.py +1027 -0
  203. numba_cuda/numba/cuda/np/numpy_support.py +798 -0
  204. numba_cuda/numba/cuda/np/polynomial/__init__.py +4 -0
  205. numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +242 -0
  206. numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +380 -0
  207. numba_cuda/numba/cuda/np/ufunc/__init__.py +4 -0
  208. numba_cuda/numba/cuda/np/ufunc/decorators.py +203 -0
  209. numba_cuda/numba/cuda/np/ufunc/sigparse.py +68 -0
  210. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +65 -0
  211. numba_cuda/numba/cuda/np/ufunc_db.py +1282 -0
  212. numba_cuda/numba/cuda/np/unsafe/__init__.py +0 -0
  213. numba_cuda/numba/cuda/np/unsafe/ndarray.py +84 -0
  214. numba_cuda/numba/cuda/nvvmutils.py +254 -0
  215. numba_cuda/numba/cuda/printimpl.py +126 -0
  216. numba_cuda/numba/cuda/random.py +308 -0
  217. numba_cuda/numba/cuda/reshape_funcs.cu +156 -0
  218. numba_cuda/numba/cuda/serialize.py +267 -0
  219. numba_cuda/numba/cuda/simulator/__init__.py +63 -0
  220. numba_cuda/numba/cuda/simulator/_internal/__init__.py +4 -0
  221. numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
  222. numba_cuda/numba/cuda/simulator/api.py +179 -0
  223. numba_cuda/numba/cuda/simulator/bf16.py +4 -0
  224. numba_cuda/numba/cuda/simulator/compiler.py +38 -0
  225. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +11 -0
  226. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +462 -0
  227. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +122 -0
  228. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +66 -0
  229. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +7 -0
  230. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +7 -0
  231. numba_cuda/numba/cuda/simulator/cudadrv/error.py +10 -0
  232. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +10 -0
  233. numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +61 -0
  234. numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +11 -0
  235. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +32 -0
  236. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +22 -0
  237. numba_cuda/numba/cuda/simulator/dispatcher.py +11 -0
  238. numba_cuda/numba/cuda/simulator/kernel.py +320 -0
  239. numba_cuda/numba/cuda/simulator/kernelapi.py +509 -0
  240. numba_cuda/numba/cuda/simulator/memory_management/__init__.py +4 -0
  241. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +21 -0
  242. numba_cuda/numba/cuda/simulator/reduction.py +19 -0
  243. numba_cuda/numba/cuda/simulator/tests/support.py +4 -0
  244. numba_cuda/numba/cuda/simulator/vector_types.py +65 -0
  245. numba_cuda/numba/cuda/simulator_init.py +18 -0
  246. numba_cuda/numba/cuda/stubs.py +635 -0
  247. numba_cuda/numba/cuda/target.py +505 -0
  248. numba_cuda/numba/cuda/testing.py +347 -0
  249. numba_cuda/numba/cuda/tests/__init__.py +62 -0
  250. numba_cuda/numba/cuda/tests/benchmarks/__init__.py +0 -0
  251. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +119 -0
  252. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  253. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +113 -0
  254. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
  255. numba_cuda/numba/cuda/tests/core/test_serialize.py +371 -0
  256. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +9 -0
  257. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +147 -0
  258. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +161 -0
  259. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +397 -0
  260. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +24 -0
  261. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +180 -0
  262. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +313 -0
  263. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +187 -0
  264. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +621 -0
  265. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +247 -0
  266. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +100 -0
  267. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +198 -0
  268. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +53 -0
  269. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +72 -0
  270. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +138 -0
  271. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +43 -0
  272. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +15 -0
  273. numba_cuda/numba/cuda/tests/cudadrv/test_linkable_code.py +58 -0
  274. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +348 -0
  275. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +128 -0
  276. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +301 -0
  277. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +174 -0
  278. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +28 -0
  279. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +185 -0
  280. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +39 -0
  281. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +23 -0
  282. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +38 -0
  283. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +48 -0
  284. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +44 -0
  285. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +127 -0
  286. numba_cuda/numba/cuda/tests/cudapy/__init__.py +9 -0
  287. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +231 -0
  288. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +50 -0
  289. numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +36 -0
  290. numba_cuda/numba/cuda/tests/cudapy/complex_usecases.py +116 -0
  291. numba_cuda/numba/cuda/tests/cudapy/enum_usecases.py +59 -0
  292. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +62 -0
  293. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +28 -0
  294. numba_cuda/numba/cuda/tests/cudapy/overload_usecases.py +33 -0
  295. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +104 -0
  296. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +47 -0
  297. numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +1122 -0
  298. numba_cuda/numba/cuda/tests/cudapy/test_array.py +344 -0
  299. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +268 -0
  300. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +203 -0
  301. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +63 -0
  302. numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +360 -0
  303. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1815 -0
  304. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +599 -0
  305. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +377 -0
  306. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +160 -0
  307. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +27 -0
  308. numba_cuda/numba/cuda/tests/cudapy/test_byteflow.py +98 -0
  309. numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py +210 -0
  310. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +683 -0
  311. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +265 -0
  312. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +42 -0
  313. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +718 -0
  314. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +370 -0
  315. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +23 -0
  316. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +142 -0
  317. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +178 -0
  318. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +193 -0
  319. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +131 -0
  320. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +438 -0
  321. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +94 -0
  322. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +101 -0
  323. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +105 -0
  324. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +889 -0
  325. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
  326. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +500 -0
  327. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +820 -0
  328. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +152 -0
  329. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +111 -0
  330. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +170 -0
  331. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1088 -0
  332. numba_cuda/numba/cuda/tests/cudapy/test_extending_types.py +71 -0
  333. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +265 -0
  334. numba_cuda/numba/cuda/tests/cudapy/test_flow_control.py +1433 -0
  335. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +57 -0
  336. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +34 -0
  337. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +69 -0
  338. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +62 -0
  339. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +474 -0
  340. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +167 -0
  341. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +92 -0
  342. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +39 -0
  343. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +170 -0
  344. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +255 -0
  345. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1219 -0
  346. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +263 -0
  347. numba_cuda/numba/cuda/tests/cudapy/test_ir.py +598 -0
  348. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +276 -0
  349. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +101 -0
  350. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +68 -0
  351. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +123 -0
  352. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +194 -0
  353. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +220 -0
  354. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +173 -0
  355. numba_cuda/numba/cuda/tests/cudapy/test_make_function_to_jit_function.py +364 -0
  356. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +47 -0
  357. numba_cuda/numba/cuda/tests/cudapy/test_math.py +842 -0
  358. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +76 -0
  359. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +78 -0
  360. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +25 -0
  361. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +145 -0
  362. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +39 -0
  363. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +82 -0
  364. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +53 -0
  365. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +504 -0
  366. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +93 -0
  367. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +402 -0
  368. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +128 -0
  369. numba_cuda/numba/cuda/tests/cudapy/test_print.py +193 -0
  370. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +37 -0
  371. numba_cuda/numba/cuda/tests/cudapy/test_random.py +117 -0
  372. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +614 -0
  373. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +130 -0
  374. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +94 -0
  375. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  376. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +86 -0
  377. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +40 -0
  378. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +457 -0
  379. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +233 -0
  380. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +454 -0
  381. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +56 -0
  382. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +277 -0
  383. numba_cuda/numba/cuda/tests/cudapy/test_tracing.py +200 -0
  384. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +90 -0
  385. numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +333 -0
  386. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  387. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +585 -0
  388. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +42 -0
  389. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +485 -0
  390. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +312 -0
  391. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +23 -0
  392. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +183 -0
  393. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +40 -0
  394. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +40 -0
  395. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +206 -0
  396. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +331 -0
  397. numba_cuda/numba/cuda/tests/cudasim/__init__.py +9 -0
  398. numba_cuda/numba/cuda/tests/cudasim/support.py +9 -0
  399. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +111 -0
  400. numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
  401. numba_cuda/numba/cuda/tests/data/cta_barrier.cu +28 -0
  402. numba_cuda/numba/cuda/tests/data/cuda_include.cu +10 -0
  403. numba_cuda/numba/cuda/tests/data/error.cu +12 -0
  404. numba_cuda/numba/cuda/tests/data/include/add.cuh +8 -0
  405. numba_cuda/numba/cuda/tests/data/jitlink.cu +28 -0
  406. numba_cuda/numba/cuda/tests/data/jitlink.ptx +49 -0
  407. numba_cuda/numba/cuda/tests/data/warn.cu +12 -0
  408. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +9 -0
  409. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
  410. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +54 -0
  411. numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +8 -0
  412. numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +14 -0
  413. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +86 -0
  414. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +68 -0
  415. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +81 -0
  416. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +141 -0
  417. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +160 -0
  418. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +180 -0
  419. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +119 -0
  420. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +66 -0
  421. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +80 -0
  422. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +206 -0
  423. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +53 -0
  424. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +76 -0
  425. numba_cuda/numba/cuda/tests/nocuda/__init__.py +9 -0
  426. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +391 -0
  427. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +48 -0
  428. numba_cuda/numba/cuda/tests/nocuda/test_import.py +63 -0
  429. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +252 -0
  430. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +59 -0
  431. numba_cuda/numba/cuda/tests/nrt/__init__.py +9 -0
  432. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +387 -0
  433. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +124 -0
  434. numba_cuda/numba/cuda/tests/support.py +900 -0
  435. numba_cuda/numba/cuda/typeconv/__init__.py +4 -0
  436. numba_cuda/numba/cuda/typeconv/castgraph.py +137 -0
  437. numba_cuda/numba/cuda/typeconv/rules.py +63 -0
  438. numba_cuda/numba/cuda/typeconv/typeconv.py +121 -0
  439. numba_cuda/numba/cuda/types/__init__.py +233 -0
  440. numba_cuda/numba/cuda/types/__init__.pyi +167 -0
  441. numba_cuda/numba/cuda/types/abstract.py +9 -0
  442. numba_cuda/numba/cuda/types/common.py +9 -0
  443. numba_cuda/numba/cuda/types/containers.py +9 -0
  444. numba_cuda/numba/cuda/types/cuda_abstract.py +533 -0
  445. numba_cuda/numba/cuda/types/cuda_common.py +110 -0
  446. numba_cuda/numba/cuda/types/cuda_containers.py +971 -0
  447. numba_cuda/numba/cuda/types/cuda_function_type.py +230 -0
  448. numba_cuda/numba/cuda/types/cuda_functions.py +798 -0
  449. numba_cuda/numba/cuda/types/cuda_iterators.py +120 -0
  450. numba_cuda/numba/cuda/types/cuda_misc.py +569 -0
  451. numba_cuda/numba/cuda/types/cuda_npytypes.py +690 -0
  452. numba_cuda/numba/cuda/types/cuda_scalars.py +280 -0
  453. numba_cuda/numba/cuda/types/ext_types.py +101 -0
  454. numba_cuda/numba/cuda/types/function_type.py +11 -0
  455. numba_cuda/numba/cuda/types/functions.py +9 -0
  456. numba_cuda/numba/cuda/types/iterators.py +9 -0
  457. numba_cuda/numba/cuda/types/misc.py +9 -0
  458. numba_cuda/numba/cuda/types/npytypes.py +9 -0
  459. numba_cuda/numba/cuda/types/scalars.py +9 -0
  460. numba_cuda/numba/cuda/typing/__init__.py +19 -0
  461. numba_cuda/numba/cuda/typing/arraydecl.py +939 -0
  462. numba_cuda/numba/cuda/typing/asnumbatype.py +130 -0
  463. numba_cuda/numba/cuda/typing/bufproto.py +70 -0
  464. numba_cuda/numba/cuda/typing/builtins.py +1209 -0
  465. numba_cuda/numba/cuda/typing/cffi_utils.py +219 -0
  466. numba_cuda/numba/cuda/typing/cmathdecl.py +47 -0
  467. numba_cuda/numba/cuda/typing/collections.py +138 -0
  468. numba_cuda/numba/cuda/typing/context.py +782 -0
  469. numba_cuda/numba/cuda/typing/ctypes_utils.py +125 -0
  470. numba_cuda/numba/cuda/typing/dictdecl.py +63 -0
  471. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  472. numba_cuda/numba/cuda/typing/listdecl.py +147 -0
  473. numba_cuda/numba/cuda/typing/mathdecl.py +158 -0
  474. numba_cuda/numba/cuda/typing/npdatetime.py +322 -0
  475. numba_cuda/numba/cuda/typing/npydecl.py +749 -0
  476. numba_cuda/numba/cuda/typing/setdecl.py +115 -0
  477. numba_cuda/numba/cuda/typing/templates.py +1446 -0
  478. numba_cuda/numba/cuda/typing/typeof.py +301 -0
  479. numba_cuda/numba/cuda/ufuncs.py +746 -0
  480. numba_cuda/numba/cuda/utils.py +724 -0
  481. numba_cuda/numba/cuda/vector_types.py +214 -0
  482. numba_cuda/numba/cuda/vectorizers.py +260 -0
  483. numba_cuda-0.21.1.dist-info/METADATA +109 -0
  484. numba_cuda-0.21.1.dist-info/RECORD +488 -0
  485. numba_cuda-0.21.1.dist-info/WHEEL +5 -0
  486. numba_cuda-0.21.1.dist-info/licenses/LICENSE +26 -0
  487. numba_cuda-0.21.1.dist-info/licenses/LICENSE.numba +24 -0
  488. numba_cuda-0.21.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1979 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from collections import defaultdict, namedtuple
5
+ from contextlib import contextmanager
6
+ from copy import deepcopy, copy
7
+ import warnings
8
+
9
+ from numba.cuda.core.compiler_machinery import (
10
+ FunctionPass,
11
+ AnalysisPass,
12
+ SSACompliantMixin,
13
+ register_pass,
14
+ )
15
+ from numba import cuda
16
+ from numba.cuda.core import postproc, bytecode, transforms, inline_closurecall
17
+ from numba.cuda.core import (
18
+ errors,
19
+ )
20
+ from numba.cuda.core import ir
21
+ from numba.cuda import types
22
+ from numba.cuda.core import consts, rewrites, config
23
+ from numba.cuda.core.interpreter import Interpreter
24
+
25
+
26
+ from numba.cuda.misc.special import literal_unroll
27
+ from numba.cuda.core.analysis import dead_branch_prune
28
+ from numba.cuda.core.analysis import (
29
+ rewrite_semantic_constants,
30
+ find_literally_calls,
31
+ compute_cfg_from_blocks,
32
+ compute_use_defs,
33
+ )
34
+ from numba.cuda.core.ir_utils import (
35
+ guard,
36
+ resolve_func_from_module,
37
+ simplify_CFG,
38
+ GuardException,
39
+ convert_code_obj_to_function,
40
+ build_definitions,
41
+ replace_var_names,
42
+ get_name_var_table,
43
+ compile_to_numba_ir,
44
+ get_definition,
45
+ find_max_label,
46
+ rename_labels,
47
+ transfer_scope,
48
+ fixup_var_define_in_scope,
49
+ )
50
+ from numba.cuda.core.ssa import reconstruct_ssa
51
+
52
+
53
@contextmanager
def fallback_context(state, msg):
    """
    Guard a region of compilation that may trigger a fallback to object mode.

    Any ``Exception`` escaping the wrapped region is always re-raised. When
    ``state.status.can_fallback`` is true, a ``NumbaWarning`` describing the
    failure (including ``msg`` and the exception text) is emitted first,
    attributed to the file and first line of the function being compiled.
    """
    try:
        yield
    except Exception as exc:
        if state.status.can_fallback:
            # Clear all references attached to the traceback
            exc = exc.with_traceback(None)
            # The warning carries the error message body for the
            # npm -> objmode fallback case.
            suffix = "" if state.flags.enable_looplift else "OUT"
            warnings.warn_explicit(
                "Compilation is falling back to object mode "
                f"WITH{suffix} looplifting enabled because {msg} due to: {exc}",
                errors.NumbaWarning,
                state.func_id.filename,
                state.func_id.firstlineno,
            )
        # With or without the warning, the active exception propagates.
        raise
77
+
78
+
79
@register_pass(mutates_CFG=True, analysis_only=False)
class ExtractByteCode(FunctionPass):
    """Pass that builds the ``ByteCode`` object for the function being compiled."""

    _name = "extract_bytecode"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Build bytecode from ``state["func_id"]`` and store it in ``state["bc"]``.

        The bytecode dump is printed when ``config.DUMP_BYTECODE`` is set.
        Always returns True.
        """
        extracted = bytecode.ByteCode(state["func_id"])
        if config.DUMP_BYTECODE:
            print(extracted.dump())

        state["bc"] = extracted
        return True
97
+
98
+
99
@register_pass(mutates_CFG=True, analysis_only=False)
class TranslateByteCode(FunctionPass):
    """Pass that turns the extracted bytecode into Numba IR."""

    _name = "translate_bytecode"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Interpret ``state["bc"]`` and store the resulting IR in
        ``state["func_ir"]``. Always returns True.
        """
        # Both inputs are looked up before the interpreter is constructed.
        func_id, bc = state["func_id"], state["bc"]
        state["func_ir"] = Interpreter(func_id).interpret(bc)
        return True
116
+
117
+
118
@register_pass(mutates_CFG=True, analysis_only=False)
class FixupArgs(FunctionPass):
    """Pass that records the argument count and validates the signature."""

    _name = "fixup_args"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Store the IR argument count in ``state["nargs"]`` and check it against
        ``state["args"]``.

        Raises ``TypeError`` when the number of argument types given differs
        from the number of arguments the function takes. Always returns True
        otherwise.
        """
        nargs = state["func_ir"].arg_count
        state["nargs"] = nargs
        given = state["args"]
        if not given and state["flags"].force_pyobject:
            # Allow an empty argument types specification when object mode
            # is explicitly requested.
            state["args"] = (types.pyobject,) * nargs
        elif len(given) != nargs:
            raise TypeError(
                "Signature mismatch: %d argument types given, "
                "but function takes %d arguments" % (len(given), nargs)
            )
        return True
138
+
139
+
140
@register_pass(mutates_CFG=True, analysis_only=False)
class IRProcessing(FunctionPass):
    """Pass that post-processes the freshly built IR and optionally dumps it."""

    _name = "ir_processing"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Run ``postproc.PostProcessor`` over ``state["func_ir"]``.

        When ``config.DEBUG`` or ``config.DUMP_IR`` is set, the IR (and the
        generator info for generator functions) is dumped under a banner.
        Always returns True.
        """
        func_ir = state["func_ir"]
        postproc.PostProcessor(func_ir).run()

        if not (config.DEBUG or config.DUMP_IR):
            return True

        qualname = func_ir.func_id.func_qualname
        print(("IR DUMP: %s" % qualname).center(80, "-"))
        func_ir.dump()
        if func_ir.is_generator:
            print(("GENERATOR INFO: %s" % qualname).center(80, "-"))
            func_ir.dump_generator_info()
        return True
160
+
161
+
162
@register_pass(mutates_CFG=True, analysis_only=False)
class RewriteSemanticConstants(FunctionPass):
    """Pass that applies ``rewrite_semantic_constants`` to the function IR."""

    _name = "rewrite_semantic_constants"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Call ``rewrite_semantic_constants`` with the function IR and the
        argument types, inside a ``fallback_context``.

        Always returns True.
        """
        assert state.func_ir
        func_name = state.func_id.func_name
        msg = (
            "Internal error in pre-inference dead branch pruning "
            "pass encountered during compilation of "
            'function "%s"' % (func_name,)
        )
        with fallback_context(state, msg):
            rewrite_semantic_constants(state.func_ir, state.args)

        return True
184
+
185
+
186
@register_pass(mutates_CFG=True, analysis_only=False)
class DeadBranchPrune(SSACompliantMixin, FunctionPass):
    """Pass that removes branches which are known at compile time not to be taken."""

    _name = "dead_branch_prune"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Prune dead branches: a dead branch is one that can be shown as not
        taken at compile time purely from const/literal evaluation.

        Always returns True.
        """
        # Purely for demonstration purposes: obtain the analysis from a pass
        # declared as a required dependent. The result is not used.
        self.get_analysis(type(self))

        assert state.func_ir
        func_name = state.func_id.func_name
        msg = (
            "Internal error in pre-inference dead branch pruning "
            "pass encountered during compilation of "
            'function "%s"' % (func_name,)
        )
        with fallback_context(state, msg):
            dead_branch_prune(state.func_ir, state.args)

        return True

    def get_analysis_usage(self, AU):
        """Declare ``RewriteSemanticConstants`` as a required pass."""
        AU.add_required(RewriteSemanticConstants)
216
+
217
+
218
@register_pass(mutates_CFG=True, analysis_only=False)
class InlineClosureLikes(FunctionPass):
    """Pass that runs ``InlineClosureCallPass`` over the function IR."""

    _name = "inline_closure_likes"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Run closure inlining, then re-run post-processing and fix up variable
        definitions in the block scopes. Always returns True.
        """
        # An IR must exist before anything can be inlined.
        assert state.func_ir

        # If the return type is a pyobject there is no type info available and
        # certain typed function calls in the array inlining code cannot be
        # resolved; this flag tells the inliner which situation applies.
        has_type_info = not isinstance(state.return_type, types.misc.PyObject)

        inline_closurecall.InlineClosureCallPass(
            state.func_ir,
            state.flags.auto_parallel,
            None,
            has_type_info,
        ).run()

        # Remove all Dels, and re-run postproc
        postproc.PostProcessor(state.func_ir).run()

        fixup_var_define_in_scope(state.func_ir.blocks)

        return True
249
+
250
+
251
@register_pass(mutates_CFG=True, analysis_only=False)
class GenericRewrites(FunctionPass):
    """Pass that applies the registered "before-inference" IR rewrites."""

    _name = "generic_rewrites"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Apply every intermediate representation rewrite registered for the
        "before-inference" stage. Always returns True.
        """
        assert state.func_ir
        func_name = state.func_id.func_name
        msg = (
            "Internal error in pre-inference rewriting "
            "pass encountered during compilation of "
            'function "%s"' % (func_name,)
        )
        with fallback_context(state, msg):
            rewrites.rewrite_registry.apply("before-inference", state)
        return True
272
+
273
+
274
@register_pass(mutates_CFG=True, analysis_only=False)
class WithLifting(FunctionPass):
    """Pass that extracts with-contexts from the function IR."""

    _name = "with_lifting"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Extract with-contexts.

        When ``transforms.with_lifting`` reports lifted with-blocks, the main
        IR is compiled with those blocks attached and the result is delivered
        by raising ``_EarlyPipelineCompletion``. Otherwise returns True.
        """
        main, withs = transforms.with_lifting(
            func_ir=state.func_ir,
            typingctx=state.typingctx,
            targetctx=state.targetctx,
            flags=state.flags,
            locals=state.locals,
        )
        if not withs:
            return True

        # Imported here, and only when something was actually lifted.
        from numba.cuda.compiler import compile_ir
        from numba.cuda.core.compiler import _EarlyPipelineCompletion

        cres = compile_ir(
            state.typingctx,
            state.targetctx,
            main,
            state.args,
            state.return_type,
            state.flags,
            state.locals,
            lifted=tuple(withs),
            lifted_from=None,
            pipeline_class=type(state.pipeline),
        )
        raise _EarlyPipelineCompletion(cres)
310
+
311
+
312
@register_pass(mutates_CFG=True, analysis_only=False)
class InlineInlinables(FunctionPass):
    """
    This pass will inline a function wrapped by the numba.jit decorator directly
    into the site of its call depending on the value set in the 'inline' kwarg
    to the decorator.

    This is an untyped pass. CFG simplification is performed at the end of the
    pass but no block level clean up is performed on the mutated IR (typing
    information is not available to do so).
    """

    _name = "inline_inlinables"
    _DEBUG = False

    def __init__(self):
        FunctionPass.__init__(self)

    def _debug_dump(self, state, title):
        """Dump the current IR under a ``title`` banner (debug aid only)."""
        print(title.center(80, "-"))
        # ``dump()`` is called as a bare statement, as the other passes in
        # this module do; wrapping it in ``print`` would emit whatever it
        # returns as an extra line.
        state.func_ir.dump()
        print("".center(80, "-"))

    def run_pass(self, state):
        """
        Run inlining of inlinables.

        Walks every ``call`` assignment in the IR and hands it to
        ``_do_work``. If anything was inlined, dead blocks are removed, the
        IR is post-processed again and the CFG is simplified. Always returns
        True.
        """
        if self._DEBUG:
            self._debug_dump(state, "before inline")

        inline_worker = inline_closurecall.InlineWorker(
            state.typingctx,
            state.targetctx,
            state.locals,
            state.pipeline,
            state.flags,
            validator=inline_closurecall.callee_ir_validator,
        )

        modified = False
        # use a work list, look for call sites via `ir.Expr.op == call` and
        # then pass these to `self._do_work` to make decisions about inlining.
        work_list = list(state.func_ir.blocks.items())
        while work_list:
            _, block = work_list.pop()
            for i, instr in enumerate(block.body):
                if not isinstance(instr, ir.Assign):
                    continue
                expr = instr.value
                if not (isinstance(expr, ir.Expr) and expr.op == "call"):
                    continue
                # `guard` turns a GuardException raised by `_do_work` into a
                # falsy result, i.e. "nothing inlined here".
                if guard(
                    self._do_work,
                    state,
                    work_list,
                    block,
                    i,
                    expr,
                    inline_worker,
                ):
                    modified = True
                    break  # because block structure changed

        if modified:
            # clean up unconditional branches that appear due to inlined
            # functions introducing blocks
            cfg = compute_cfg_from_blocks(state.func_ir.blocks)
            for dead in cfg.dead_nodes():
                del state.func_ir.blocks[dead]
            post_proc = postproc.PostProcessor(state.func_ir)
            post_proc.run()
            state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

        if self._DEBUG:
            self._debug_dump(state, "after inline")
        return True

    def _do_work(self, state, work_list, block, i, expr, inline_worker):
        """
        Decide whether the call ``expr`` at ``block.body[i]`` should be
        inlined and, if so, inline it.

        Returns True when the callee was inlined (any new blocks are appended
        to ``work_list`` when it is not None), False otherwise. Raises
        ``GuardException`` when the ``value`` of the callee definition cannot
        be read.
        """
        from numba.cuda.compiler import run_frontend
        from numba.cuda.core.options import InlineOptions

        # try and get a definition for the call, this isn't always possible as
        # it might be a eval(str)/part generated awaiting update etc. (parfors)
        to_inline = None
        try:
            to_inline = state.func_ir.get_definition(expr.func)
        except Exception:
            if self._DEBUG:
                print("Cannot find definition for %s" % expr.func)
            return False

        definition_op = getattr(to_inline, "op", False)
        # do not handle closure inlining here, another pass deals with that.
        if definition_op == "make_function":
            return False

        # see if the definition is a "getattr", in which case walk the IR to
        # try and find the python function via the module from which it's
        # imported, this should all be encoded in the IR.
        if definition_op == "getattr":
            val = resolve_func_from_module(state.func_ir, to_inline)
        else:
            # This is likely a freevar or global
            #
            # NOTE: getattr 'value' on a call may fail if it's an ir.Expr as
            # getattr is overloaded to look in _kws.
            try:
                val = getattr(to_inline, "value", False)
            except Exception:
                raise GuardException

        # nothing was found
        if not val:
            return False

        # check it's dispatcher-like, the targetoptions attr holds the
        # kwargs supplied in the jit decorator and is where 'inline' will
        # be if it is present.
        topt = getattr(val, "targetoptions", False)
        if not topt:
            return False

        # has 'inline' been specified?
        inline_type = topt.get("inline", None)
        if inline_type is None:
            return False

        # Could this be inlinable?
        inline_opt = InlineOptions(inline_type)
        if inline_opt.is_never_inline:
            return False

        # yes, it could be inlinable
        do_inline = True
        pyfunc = val.py_func
        if inline_opt.has_cost_model:
            # it has a cost model, use it to determine whether to inline
            py_func_ir = run_frontend(pyfunc)
            do_inline = inline_type(expr, state.func_ir, py_func_ir)

        if not do_inline:
            return False

        _, _, _, new_blocks = inline_worker.inline_function(
            state.func_ir,
            block,
            i,
            pyfunc,
        )
        if work_list is not None:
            work_list.extend(new_blocks)
        return True
453
+
454
+
455
@register_pass(mutates_CFG=False, analysis_only=False)
class PreserveIR(AnalysisPass):
    """
    Stores a copy of the current function IR in the compilation metadata.
    """

    _name = "preserve_ir"

    def __init__(self):
        AnalysisPass.__init__(self)

    def run_pass(self, state):
        """Save ``state.func_ir.copy()`` under ``"preserved_ir"``; returns False."""
        snapshot = state.func_ir.copy()
        state.metadata["preserved_ir"] = snapshot
        return False
469
+
470
+
471
@register_pass(mutates_CFG=False, analysis_only=True)
class FindLiterallyCalls(FunctionPass):
    """Locate calls to `numba.literally()` and signal when its requirement
    is not satisfied.
    """

    _name = "find_literally"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Delegate to ``find_literally_calls``; the IR is reported as unchanged."""
        func_ir, args = state.func_ir, state.args
        find_literally_calls(func_ir, args)
        return False
485
+
486
+
487
@register_pass(mutates_CFG=True, analysis_only=False)
class CanonicalizeLoopExit(FunctionPass):
    """A pass to canonicalize loop exit by splitting it from function exit."""

    _name = "canonicalize_loop_exit"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Split every loop exit block that is also a function exit point.

        Analysis variables are reset and the variable lifetime is recomputed
        afterwards. Returns True if at least one block was split.
        """
        fir = state.func_ir
        cfg = compute_cfg_from_blocks(fir.blocks)
        status = False
        # NOTE(review): `cfg` is computed once, but `_split_exit_block`
        # relabels `fir.blocks`; labels taken from `cfg` after the first split
        # may no longer match - confirm whether more than one split can occur.
        for loop in cfg.loops().values():
            for exit_label in loop.exits:
                if exit_label in cfg.exit_points():
                    self._split_exit_block(fir, cfg, exit_label)
                    status = True

        fir._reset_analysis_variables()

        vlt = postproc.VariableLifetime(fir.blocks)
        fir.variable_lifetime = vlt
        return status

    def _split_exit_block(self, fir, cfg, exit_label):
        """
        Move the block at ``exit_label`` to a fresh label and put a block that
        only jumps to it in its place, then relabel all blocks.
        """
        curblock = fir.blocks[exit_label]
        # A label above every existing one cannot collide with another block.
        newlabel = find_max_label(fir.blocks) + 1
        fir.blocks[newlabel] = curblock
        newblock = ir.Block(scope=curblock.scope, loc=curblock.loc)
        newblock.append(ir.Jump(newlabel, loc=curblock.loc))
        fir.blocks[exit_label] = newblock
        # Rename all labels
        fir.blocks = rename_labels(fir.blocks)
522
+
523
+
524
@register_pass(mutates_CFG=True, analysis_only=False)
class CanonicalizeLoopEntry(FunctionPass):
    """A pass to canonicalize loop header by splitting it from function entry.

    This is needed for loop-lifting; esp in py3.8
    """

    _name = "canonicalize_loop_entry"
    _supported_globals = {range, enumerate, zip}

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        For each loop with exactly one entry, where that entry is the function
        entry point, call ``_split_entry_block``.

        Analysis variables are reset and the variable lifetime is recomputed
        afterwards. Returns True if ``_split_entry_block`` was invoked.
        """
        func_ir = state.func_ir
        cfg = compute_cfg_from_blocks(func_ir.blocks)
        changed = False
        for loop in cfg.loops().values():
            if len(loop.entries) != 1:
                continue
            [entry_label] = loop.entries
            if entry_label == cfg.entry_point():
                self._split_entry_block(func_ir, cfg, loop, entry_label)
                changed = True
        func_ir._reset_analysis_variables()

        func_ir.variable_lifetime = postproc.VariableLifetime(func_ir.blocks)
        return changed

    def _split_entry_block(self, fir, cfg, loop, entry_label):
        """
        Split the entry block at the first statement that sets up the loop
        iterator, moving the tail into a new block reached by a jump.

        Does nothing when no suitable split point is found.
        """
        # Find iterator inputs into the for-loop header
        header = fir.blocks[loop.header]
        wanted = {expr.value for expr in header.find_exprs(op="iternext")}

        # Find the getiter for each iterator
        entry_block = fir.blocks[entry_label]

        # Find the start of loop entry statement that needs to be included.
        split_stmt = None
        assignments = list(entry_block.find_insts(ir.Assign))
        for assign in reversed(assignments):
            if assign.target not in wanted:
                continue
            rhs = assign.value
            if isinstance(rhs, ir.Var):
                if rhs.is_temp:
                    wanted.add(rhs)
            elif isinstance(rhs, ir.Expr):
                if rhs.op == "getiter":
                    split_stmt = assign
                    if rhs.value.is_temp:
                        wanted.add(rhs.value)
                elif rhs.op == "call":
                    callee_def = guard(get_definition, fir, rhs.func)
                    if isinstance(callee_def, ir.Global):
                        if rhs.func.is_temp:
                            wanted.add(rhs.func)
            elif (
                isinstance(rhs, ir.Global)
                and rhs.value in self._supported_globals
            ):
                split_stmt = assign

        if split_stmt is None:
            return

        cut = entry_block.body.index(split_stmt)
        tail_block = entry_block.copy()
        tail_block.body = tail_block.body[cut:]
        tail_block.loc = tail_block.body[0].loc
        tail_label = find_max_label(fir.blocks) + 1
        entry_block.body = entry_block.body[:cut]
        entry_block.append(ir.Jump(tail_label, loc=tail_block.loc))

        fir.blocks[tail_label] = tail_block
        # Rename all labels
        fir.blocks = rename_labels(fir.blocks)
602
+
603
+
604
@register_pass(mutates_CFG=False, analysis_only=True)
class PrintIRCFG(FunctionPass):
    """Pass that renders the IR as a dot graph, one versioned file per run."""

    _name = "print_ir_cfg"

    def __init__(self):
        FunctionPass.__init__(self)
        # Counter used to build a distinct filename prefix on every run.
        self._ver = 0

    def run_pass(self, state):
        """Render the current IR with prefix ``v<N>``; returns False."""
        self._ver += 1
        prefix = f"v{self._ver}"
        graph = state.func_ir.render_dot(filename_prefix=prefix)
        graph.render()
        return False
617
+
618
+
619
@register_pass(mutates_CFG=True, analysis_only=False)
class MakeFunctionToJitFunction(FunctionPass):
    """
    Replaces an ir.Expr.op == "make_function" (a closure) with an ir.Global
    holding a compiled function built from the closure body. The statement
    value is swapped 1:1. `make_function` is already untyped
    """

    _name = "make_function_op_code_to_jit_function"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Convert eligible ``make_function`` assignments to jitted globals.

        A closure is eligible when its defaults are absent, a constant, or a
        tuple made only of constants. Returns True if any statement changed.
        """
        func_ir = state.func_ir
        mutated = False
        for blk in func_ir.blocks.values():
            for stmt in blk.body:
                if not isinstance(stmt, ir.Assign):
                    continue
                node = stmt.value
                if not isinstance(node, ir.Expr):
                    continue
                if node.op != "make_function":
                    continue

                getdef = func_ir.get_definition
                kw_default = getdef(node.defaults)
                if kw_default is None or isinstance(kw_default, ir.Const):
                    ok = True
                elif isinstance(kw_default, tuple):
                    ok = all(
                        [isinstance(getdef(x), ir.Const) for x in kw_default]
                    )
                elif isinstance(kw_default, ir.Expr):
                    if kw_default.op != "build_tuple":
                        continue
                    ok = all(
                        [
                            isinstance(getdef(x), ir.Const)
                            for x in kw_default.items
                        ]
                    )
                else:
                    ok = False
                if not ok:
                    continue

                pyfunc = convert_code_obj_to_function(node, func_ir)
                jitted = cuda.jit()(pyfunc)
                stmt.value = ir.Global(node.code.co_name, jitted, stmt.loc)
                mutated = True

        # if a change was made the del ordering is probably wrong, patch up
        if mutated:
            postproc.PostProcessor(func_ir).run()

        return mutated
681
+
682
+
683
@register_pass(mutates_CFG=True, analysis_only=False)
class TransformLiteralUnrollConstListToTuple(FunctionPass):
    """This pass spots a `literal_unroll([<constant values>])` and rewrites it
    as a `literal_unroll(tuple(<constant values>))`.
    """

    _name = "transform_literal_unroll_const_list_to_tuple"

    # Types accepted as the argument of literal_unroll when it is a function
    # argument or the source of a getitem.
    _accepted_types = (types.BaseTuple, types.LiteralList)

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Rewrite constant-list arguments of ``literal_unroll`` into tuples.

        Every ``call`` expression whose callee resolves to ``literal_unroll``
        is inspected. A ``build_list`` argument made only of constants has its
        defining assignment replaced by a ``build_tuple`` of the same items.
        Tuples (built, global or freevar) are left alone; function arguments
        and getitem sources are checked against ``_accepted_types``.

        Returns True if any assignment was rewritten.

        Raises ``errors.UnsupportedError`` for unsupported usages and
        ``errors.CompilerError`` if the assignment of a known variable cannot
        be located.
        """
        mutated = False
        func_ir = state.func_ir
        for label, blk in func_ir.blocks.items():
            # materialise the calls first, the loop body may edit assignments
            calls = [_ for _ in blk.find_exprs("call")]
            for call in calls:
                # `guard` yields None if the definition cannot be resolved
                glbl = guard(get_definition, func_ir, call.func)
                if glbl and isinstance(glbl, (ir.Global, ir.FreeVar)):
                    # find a literal_unroll
                    if glbl.value is literal_unroll:
                        if len(call.args) > 1:
                            msg = "literal_unroll takes one argument, found %s"
                            raise errors.UnsupportedError(
                                msg % len(call.args), call.loc
                            )
                        # get the arg, make sure its a build_list
                        unroll_var = call.args[0]
                        to_unroll = guard(get_definition, func_ir, unroll_var)
                        if (
                            isinstance(to_unroll, ir.Expr)
                            and to_unroll.op == "build_list"
                        ):
                            # make sure they are all const items in the list
                            for i, item in enumerate(to_unroll.items):
                                val = guard(get_definition, func_ir, item)
                                if not val:
                                    msg = (
                                        "multiple definitions for variable "
                                        "%s, cannot resolve constant"
                                    )
                                    raise errors.UnsupportedError(
                                        msg % item, to_unroll.loc
                                    )
                                if not isinstance(val, ir.Const):
                                    msg = (
                                        "Found non-constant value at "
                                        "position %s in a list argument to "
                                        "literal_unroll" % i
                                    )
                                    raise errors.UnsupportedError(
                                        msg, to_unroll.loc
                                    )
                            # The above appears ok, now swap the build_list for
                            # a built tuple.

                            # find the assignment for the unroll target
                            to_unroll_lhs = guard(
                                get_definition,
                                func_ir,
                                unroll_var,
                                lhs_only=True,
                            )

                            if to_unroll_lhs is None:
                                msg = (
                                    "multiple definitions for variable "
                                    "%s, cannot resolve constant"
                                )
                                raise errors.UnsupportedError(
                                    msg % unroll_var, to_unroll.loc
                                )
                            # scan all blocks looking for the LHS
                            for b in func_ir.blocks.values():
                                asgn = b.find_variable_assignment(
                                    to_unroll_lhs.name
                                )
                                if asgn is not None:
                                    break
                            else:
                                # for/else: no block held the assignment
                                msg = (
                                    "Cannot find assignment for known "
                                    "variable %s"
                                ) % to_unroll_lhs.name
                                raise errors.CompilerError(msg, to_unroll.loc)

                            # Create a tuple with the list items as contents
                            tup = ir.Expr.build_tuple(
                                to_unroll.items, to_unroll.loc
                            )

                            # swap the list for the tuple
                            asgn.value = tup
                            mutated = True
                        elif (
                            isinstance(to_unroll, ir.Expr)
                            and to_unroll.op == "build_tuple"
                        ):
                            # this is fine, do nothing
                            pass
                        elif isinstance(
                            to_unroll, (ir.Global, ir.FreeVar)
                        ) and isinstance(to_unroll.value, tuple):
                            # this is fine, do nothing
                            pass
                        elif isinstance(to_unroll, ir.Arg):
                            # this is only fine if the arg is a tuple
                            # NOTE(review): reads state.typemap in an untyped
                            # pass - confirm it is populated at this point.
                            ty = state.typemap[to_unroll.name]
                            if not isinstance(ty, self._accepted_types):
                                msg = (
                                    "Invalid use of literal_unroll with a "
                                    "function argument, only tuples are "
                                    "supported as function arguments, found "
                                    "%s"
                                ) % ty
                                raise errors.UnsupportedError(
                                    msg, to_unroll.loc
                                )
                        else:
                            # `extra` stays None unless a problem is found
                            extra = None
                            if isinstance(to_unroll, ir.Expr):
                                # probably a slice
                                if to_unroll.op == "getitem":
                                    ty = state.typemap[to_unroll.value.name]
                                    # check if this is a tuple slice
                                    if not isinstance(ty, self._accepted_types):
                                        extra = "operation %s" % to_unroll.op
                                        loc = to_unroll.loc
                            elif isinstance(to_unroll, ir.Arg):
                                # NOTE(review): ir.Arg is already handled by
                                # the elif above, so this looks unreachable.
                                extra = "non-const argument %s" % to_unroll.name
                                loc = to_unroll.loc
                            else:
                                if to_unroll is None:
                                    extra = (
                                        "multiple definitions of "
                                        'variable "%s".' % unroll_var.name
                                    )
                                    loc = unroll_var.loc
                                else:
                                    loc = to_unroll.loc
                                    extra = "unknown problem"

                            if extra:
                                msg = (
                                    "Invalid use of literal_unroll, "
                                    "argument should be a tuple or a list "
                                    "of constant values. Failure reason: "
                                    "found %s" % extra
                                )
                                raise errors.UnsupportedError(msg, loc)
        return mutated
836
+
837
+
838
+ @register_pass(mutates_CFG=True, analysis_only=False)
839
+ class MixedContainerUnroller(FunctionPass):
840
+ _name = "mixed_container_unroller"
841
+
842
+ _DEBUG = False
843
+
844
+ _accepted_types = (types.BaseTuple, types.LiteralList)
845
+
846
def __init__(self):
    """Initialise the pass via the ``FunctionPass`` initialiser."""
    FunctionPass.__init__(self)
848
+
849
def analyse_tuple(self, tup):
    """
    Group the positions of a typed tuple by member type.

    Returns a mapping of type -> list of indexes at which that type occurs,
    with indexes in ascending order.
    """
    positions_by_type = defaultdict(list)
    for position, member_ty in enumerate(tup):
        positions_by_type[member_ty].append(position)
    return positions_by_type
857
+
858
def add_offset_to_labels_w_ignore(self, blocks, offset, ignore=None):
    """Shift every block label and jump/branch target by ``offset``.

    Targets listed in ``ignore`` are left as they are. The terminators of the
    given blocks are replaced in place; a new label -> block mapping is
    returned.
    """
    skip = set() if ignore is None else ignore

    shifted = {}
    for label, block in blocks.items():
        # some parfor last blocks might be empty
        terminator = block.body[-1] if block.body else None
        if isinstance(terminator, ir.Jump):
            if terminator.target not in skip:
                block.body[-1] = ir.Jump(
                    terminator.target + offset, terminator.loc
                )
        if isinstance(terminator, ir.Branch):
            true_target = (
                terminator.truebr
                if terminator.truebr in skip
                else terminator.truebr + offset
            )
            false_target = (
                terminator.falsebr
                if terminator.falsebr in skip
                else terminator.falsebr + offset
            )
            block.body[-1] = ir.Branch(
                terminator.cond, true_target, false_target, terminator.loc
            )
        shifted[label + offset] = block
    return shifted
887
+
888
def inject_loop_body(
    self, switch_ir, loop_ir, caller_max_label, dont_replace, switch_data
):
    """
    Injects the "loop body" held in `loop_ir` into `switch_ir` where ever
    there is a statement of the form `SENTINEL.<int> = RHS`. It also:
    * Finds and then deliberately does not relabel non-local jumps so as to
      make the switch table suitable for injection into the IR from which
      the loop body was derived.
    * Looks for `typed_getitem` and wires them up to loop body version
      specific variables or, if possible, directly writes in their constant
      value at their use site.

    Args:
    - switch_ir, the switch table with SENTINELS as generated by
      self.gen_switch
    - loop_ir, the IR of the loop blocks (derived from the original func_ir)
    - caller_max_label, the maximum label in the func_ir caller
    - dont_replace, variables that should not be renamed (to handle
      references to variables that are incoming at the loop head/escaping at
      the loop exit. Note: this list is appended to in place when a literal
      branch constant is introduced.
    - switch_data, the switch table data used to generated the switch_ir,
      can be generated by self.analyse_tuple.

    Returns:
    - A type specific switch table with each case containing a versioned
      loop body suitable for injection as a replacement for the loop_ir.
      This is the same `switch_ir` object that was passed in, mutated.
    """

    # Switch IR came from code gen, immediately relabel to prevent
    # collisions with IR derived from the user code (caller)
    switch_ir.blocks = self.add_offset_to_labels_w_ignore(
        switch_ir.blocks, caller_max_label + 1
    )

    # Find the sentinels and validate the form
    sentinel_exits = set()
    sentinel_blocks = []
    for lbl, blk in switch_ir.blocks.items():
        for i, stmt in enumerate(blk.body):
            if isinstance(stmt, ir.Assign):
                if "SENTINEL" in stmt.target.name:
                    sentinel_blocks.append(lbl)
                    # the sentinel block's terminator is read for its target
                    sentinel_exits.add(blk.body[-1].target)
                    break

    assert len(sentinel_exits) == 1  # should only be 1 exit
    switch_ir.blocks.pop(sentinel_exits.pop())  # kill the exit, it's dead

    # find jumps that are non-local, we won't relabel these
    ignore_set = set()
    local_lbl = [x for x in loop_ir.blocks.keys()]
    for lbl, blk in loop_ir.blocks.items():
        for i, stmt in enumerate(blk.body):
            if isinstance(stmt, ir.Jump):
                if stmt.target not in local_lbl:
                    ignore_set.add(stmt.target)
            if isinstance(stmt, ir.Branch):
                if stmt.truebr not in local_lbl:
                    ignore_set.add(stmt.truebr)
                if stmt.falsebr not in local_lbl:
                    ignore_set.add(stmt.falsebr)

    # make sure the generated switch table matches the switch data
    assert len(sentinel_blocks) == len(switch_data)

    # replace the sentinel_blocks with the loop body
    for lbl, branch_ty in zip(sentinel_blocks, switch_data.keys()):
        # each case gets its own private copy of the loop body
        loop_blocks = deepcopy(loop_ir.blocks)
        # relabel blocks WRT switch table, each block replacement will shift
        # the maximum label
        max_label = max(switch_ir.blocks.keys())
        loop_blocks = self.add_offset_to_labels_w_ignore(
            loop_blocks, max_label + 1, ignore_set
        )

        # start label
        loop_start_lbl = min(loop_blocks.keys())

        # fix the typed_getitem locations in the loop blocks
        for blk in loop_blocks.values():
            new_body = []
            for stmt in blk.body:
                if isinstance(stmt, ir.Assign):
                    if (
                        isinstance(stmt.value, ir.Expr)
                        and stmt.value.op == "typed_getitem"
                    ):
                        if isinstance(branch_ty, types.Literal):
                            # literal type: write the constant value in
                            # directly via a fresh "branch_const" variable
                            scope = switch_ir.blocks[lbl].scope
                            new_const_name = scope.redefine(
                                "branch_const", stmt.loc
                            ).name
                            new_const_var = ir.Var(
                                blk.scope, new_const_name, stmt.loc
                            )
                            new_const_val = ir.Const(
                                branch_ty.literal_value, stmt.loc
                            )
                            const_assign = ir.Assign(
                                new_const_val, new_const_var, stmt.loc
                            )
                            new_assign = ir.Assign(
                                new_const_var, stmt.target, stmt.loc
                            )
                            new_body.append(const_assign)
                            new_body.append(new_assign)
                            # keep the new constant out of the rename below
                            dont_replace.append(new_const_name)
                        else:
                            # non-literal type: rebuild the typed_getitem
                            # with this case's type as its dtype
                            orig = stmt.value
                            new_typed_getitem = ir.Expr.typed_getitem(
                                value=orig.value,
                                dtype=branch_ty,
                                index=orig.index,
                                loc=orig.loc,
                            )
                            new_assign = ir.Assign(
                                new_typed_getitem, stmt.target, stmt.loc
                            )
                            new_body.append(new_assign)
                    else:
                        new_body.append(stmt)
                else:
                    new_body.append(stmt)
            blk.body = new_body

        # rename
        var_table = get_name_var_table(loop_blocks)
        # collect first, pop afterwards: the table is not mutated while it
        # is being iterated
        drop_keys = []
        for k, v in var_table.items():
            if v.name in dont_replace:
                drop_keys.append(k)
        for k in drop_keys:
            var_table.pop(k)

        new_var_dict = {}
        for name, var in var_table.items():
            scope = switch_ir.blocks[lbl].scope
            try:
                scope.get_exact(name)
            except errors.NotDefinedError:
                # In case the scope doesn't have the variable, we need to
                # define it prior creating new copies of it! This is
                # because the scope of the function and the scope of the
                # loop are different and the variable needs to be redefined
                # within the scope of the loop.
                scope.define(name, var.loc)
            new_var_dict[name] = scope.redefine(name, var.loc).name
        replace_var_names(loop_blocks, new_var_dict)

        # clobber the sentinel body and then stuff in the rest
        switch_ir.blocks[lbl] = deepcopy(loop_blocks[loop_start_lbl])
        remaining_keys = [y for y in loop_blocks.keys()]
        remaining_keys.remove(loop_start_lbl)
        for k in remaining_keys:
            switch_ir.blocks[k] = deepcopy(loop_blocks[k])

    if self._DEBUG:
        print("-" * 80 + "EXIT STUFFER")
        switch_ir.dump()
        print("-" * 80)

    return switch_ir
1051
+
1052
def gen_switch(self, data, index):
    """
    Build the IR of a synthetic "switch table" function of the form

        def foo():
            if PLACEHOLDER_INDEX in (<integers>):
                SENTINEL = None
            elif PLACEHOLDER_INDEX in (<integers>):
                SENTINEL = None
            ...
            else:
                raise RuntimeError

    ``data`` maps each type to the list of tuple indexes holding that
    type, e.g. for ``(int64, int64, float64)`` it may be
    ``{int64: [0, 1], float64: [2]}``. One branch is generated per key,
    in the mapping's iteration order.

    ``index`` is the IR node for the induction variable of the range loop
    driving the mixed tuple; it replaces every ``PLACEHOLDER_INDEX``
    global in the generated IR.

    Returns the IR produced by ``compile_to_numba_ir`` for the generated
    function, with the placeholder substituted.
    """
    # Template for every branch after the first one.
    elif_tplt = "\n\telif PLACEHOLDER_INDEX in (%s,):\n\t\tSENTINEL = None"

    # Why the trailing ``py310_defeat*`` assignments exist:
    #
    # They defeat a CPython optimisation introduced in Python 3.10
    # (https://bugs.python.org/issue44626, implemented at
    # https://github.com/python/cpython/blob/d41abe8/Python/compile.c#L7533-L7557).
    # Under certain circumstances the optimiser inlines a small exit block,
    # turning the jump to it into a return. The unroller looks for a jump,
    # not a return, when it inserts the generated switch table, so that
    # breaks it.
    #
    # The optimisation only applies when the exit block has no more than a
    # certain number (4 at the time of writing) of bytecode instructions.
    # The extra assignments make the exit block big enough to be left
    # alone. They are harmless: the generated exit block is discarded when
    # the switch table is smashed into the original function.
    #
    # Without them the final lines of the stacktrace look like:
    #
    #  File "/numba/numba/core/untyped_passes.py", line 830, \
    #       in inject_loop_body
    #   sentinel_exits.add(blk.body[-1].target)
    #  AttributeError: Failed in nopython mode pipeline \
    #       (step: handles literal_unroll)
    #  Failed in literal_unroll_subpipeline mode pipeline \
    #       (step: performs mixed container unroll)
    #  'Return' object has no attribute 'target'
    #
    # i.e. a Return was found where a Jump was expected.
    switch_tplt = (
        "def foo():\n\tif PLACEHOLDER_INDEX in (%s,):\n\t\t"
        "SENTINEL = None\n%s\n\telse:\n\t\t"
        'raise RuntimeError("Unreachable")\n\t'
        "py310_defeat1 = 1\n\t"
        "py310_defeat2 = 2\n\t"
        "py310_defeat3 = 3\n\t"
        "py310_defeat4 = 4\n\t"
    )

    # One comma-separated index string per type, in mapping order.
    index_groups = [",".join(map(str, idxs)) for idxs in data.values()]
    first_group = index_groups[0]
    elif_src = "".join(elif_tplt % group for group in index_groups[1:])
    src = switch_tplt % (first_group, elif_src)

    # Materialise the function object from the generated source.
    namespace = {}
    exec(src, {}, namespace)
    switch_func = namespace["foo"]

    branches = compile_to_numba_ir(switch_func, {})

    # Swap each load of the PLACEHOLDER_INDEX global for the real index.
    for blk in branches.blocks.values():
        for stmt in blk.body:
            if not isinstance(stmt, ir.Assign):
                continue
            rhs = stmt.value
            if isinstance(rhs, ir.Global) and rhs.name == "PLACEHOLDER_INDEX":
                stmt.value = index
    return branches
1138
+
1139
def apply_transform(self, state):
    """Unroll at most one ``literal_unroll`` loop found in ``state.func_ir``.

    Returns ``True`` when a loop was unrolled (the IR was rewritten and
    post-processed) and ``False`` when there was no conforming loop left to
    transform. Raises ``errors.UnsupportedError`` if ``literal_unroll``
    loops are nested.
    """
    func_ir = state.func_ir
    # compute a fresh CFG and locate its loops
    cfg = compute_cfg_from_blocks(func_ir.blocks)
    loops = cfg.loops()

    # 0. Find the loops containing literal_unroll and store this
    #    information
    unroll_info = namedtuple(
        "unroll_info", ["loop", "call", "arg", "getitem"]
    )

    def get_call_args(init_arg, want):
        # Chases the definition of `init_arg` back to a call expression
        # whose callee is the global function `want`, and returns that
        # call expression. Raises GuardException on any mismatch.
        some_call = get_definition(func_ir, init_arg)
        if not isinstance(some_call, ir.Expr) or not some_call.op == "call":
            raise GuardException
        the_global = get_definition(func_ir, some_call.func)
        if (
            not isinstance(the_global, ir.Global)
            or the_global.value is not want
        ):
            raise GuardException
        return some_call

    def find_unroll_loops(loops):
        """Finds loops which are compliant with the form:
        for i in range(len(literal_unroll(<something>)))
        and returns a dict mapping each such loop to its literal_unroll
        call expression."""
        unroll_loops = {}
        for loop in loops.values():
            # TODO: check the loop head has literal_unroll, if it does but
            # does not conform to the following then raise

            # scan loop header, it needs to be a single iternext driven
            # loop
            iternexts = list(
                func_ir.blocks[loop.header].find_exprs("iternext")
            )
            if len(iternexts) != 1:
                continue
            (iternext,) = iternexts

            # Walk the canonicalised loop structure and check it has the
            # form range(len(literal_unroll(container)))
            phi = guard(get_definition, func_ir, iternext.value)
            if phi is None:
                continue

            # check call global "range"
            range_call = guard(get_call_args, phi.value, range)
            if range_call is None:
                continue
            range_arg = range_call.args[0]

            # check call global "len"
            len_call = guard(get_call_args, range_arg, len)
            if len_call is None:
                continue
            len_arg = len_call.args[0]

            # check literal_unroll
            literal_unroll_call = guard(get_definition, func_ir, len_arg)
            if literal_unroll_call is None:
                continue
            if not isinstance(literal_unroll_call, ir.Expr):
                continue
            if literal_unroll_call.op != "call":
                continue
            literal_func = getattr(literal_unroll_call, "func", None)
            if not literal_func:
                continue
            call_func = guard(
                get_definition, func_ir, literal_unroll_call.func
            )
            if call_func is None:
                continue

            if call_func.value is literal_unroll:
                assert len(literal_unroll_call.args) == 1
                unroll_loops[loop] = literal_unroll_call
        return unroll_loops

    def ensure_no_nested_unroll(unroll_loops):
        # Validate loop nests, nested literal_unroll loops are unsupported.
        # This only checks for nesting; it does not check for a getitem or
        # anything else the transform requires.
        for test_loop in unroll_loops:
            for ref_loop in unroll_loops:
                if test_loop == ref_loop:  # comparing to self! skip
                    continue
                if test_loop.header in ref_loop.body:
                    msg = "Nesting of literal_unroll is unsupported"
                    loc = func_ir.blocks[test_loop.header].loc
                    raise errors.UnsupportedError(msg, loc)

    def first_conforming_getitem(loop, arg):
        # Returns the first getitem assignment in the loop body that
        # targets `arg` (spelled a[i], or via a definition whose first
        # argument is `arg`) and whose container type is accepted;
        # returns None if there is none. Only the blocks belonging to this
        # loop are searched (nested literal_unroll loops are unsupported).
        for lbli in loop.body:
            for stmt in func_ir.blocks[lbli].body:
                if not isinstance(stmt, ir.Assign):
                    continue
                expr = stmt.value
                if not (isinstance(expr, ir.Expr) and expr.op == "getitem"):
                    continue
                # check for something like a[i]
                if expr.value != arg:
                    # that failed, so check for the definition
                    dfn = guard(get_definition, func_ir, expr.value)
                    if dfn is None:
                        continue
                    try:
                        args = getattr(dfn, "args", False)
                    except KeyError:
                        continue
                    if not args or not args[0] == arg:
                        continue
                target_ty = state.typemap[arg.name]
                if not isinstance(target_ty, self._accepted_types):
                    continue
                return stmt
        return None

    def collect_literal_unroll_info(literal_unroll_loops):
        """Finds the loops induced by `literal_unroll`, returns a list of
        unroll_info namedtuples for use in the transform pass.
        """
        literal_unroll_info = []
        for loop, literal_unroll_call in literal_unroll_loops.items():
            arg = literal_unroll_call.args[0]
            typemap = state.typemap
            resolved_arg = guard(
                get_definition, func_ir, arg, lhs_only=True
            )
            ty = typemap[resolved_arg.name]
            assert isinstance(ty, self._accepted_types)

            # loop header is spelled ok, now make sure the body actually
            # contains a getitem
            tuple_getitem = first_conforming_getitem(loop, arg)
            if not tuple_getitem:
                continue  # no getitem in this loop

            literal_unroll_info.append(
                unroll_info(loop, literal_unroll_call, arg, tuple_getitem)
            )
        return literal_unroll_info

    # 1. Collect info about the literal_unroll loops, ensure they are legal
    literal_unroll_loops = find_unroll_loops(loops)
    # validate
    ensure_no_nested_unroll(literal_unroll_loops)
    # assemble info
    literal_unroll_info = collect_literal_unroll_info(literal_unroll_loops)
    if not literal_unroll_info:
        return False

    # 2. Do the unroll, get a loop and process it!
    self.unroll_loop(state, literal_unroll_info[0])

    # 3. Rebuild the state, the IR has taken a hammering
    func_ir.blocks = simplify_CFG(func_ir.blocks)
    postproc.PostProcessor(func_ir).run()
    if self._DEBUG:
        print("-" * 80 + "END OF PASS, SIMPLIFY DONE")
        func_ir.dump()
    func_ir._definitions = build_definitions(func_ir.blocks)
    return True
1328
+
1329
def unroll_loop(self, state, loop_info):
    """Unroll the loop described by ``loop_info`` in ``state.func_ir``.

    The general idea is to:
    1. Find the getitems that conform to the literal_unroll semantic,
       i.e. that target the tuple with a loop induced index
    2. Compute a structure from the tuple that describes which
       iterations of the loop will have which type
    3. Generate a switch table in IR form for the structure in 2
    4. Switch out getitems for the tuple for a `typed_getitem`
    5. Inject switch table as replacement loop body
    6. Patch up

    Raises ``errors.CompilerError`` if no conforming getitem is found.
    """
    func_ir = state.func_ir
    getitem_target = loop_info.arg
    target_ty = state.typemap[getitem_target.name]
    assert isinstance(target_ty, self._accepted_types)

    # 1. find the "getitem"s that conform
    tuple_getitem = []
    for lbl in loop_info.loop.body:
        for stmt in func_ir.blocks[lbl].body:
            if not isinstance(stmt, ir.Assign):
                continue
            expr = stmt.value
            if not (isinstance(expr, ir.Expr) and expr.op == "getitem"):
                continue
            # try a couple of spellings... a[i] and ref(a)[i]
            if expr.value != getitem_target:
                dfn = func_ir.get_definition(expr.value)
                try:
                    args = getattr(dfn, "args", False)
                except KeyError:
                    continue
                if not args or not args[0] == getitem_target:
                    continue
            target_ty = state.typemap[getitem_target.name]
            if not isinstance(target_ty, self._accepted_types):
                continue
            tuple_getitem.append(stmt)

    if not tuple_getitem:
        msg = (
            "Loop unrolling analysis has failed, there's no getitem "
            "in loop body that conforms to literal_unroll "
            "requirements."
        )
        LOC = func_ir.blocks[loop_info.loop.header].loc
        raise errors.CompilerError(msg, LOC)

    # 2. get switch data
    switch_data = self.analyse_tuple(target_ty)

    # 3. generate switch IR, keyed on the definition of the index used by
    #    the first conforming getitem
    index_name = tuple_getitem[0].value.index.name
    index = func_ir._definitions[index_name][0]
    branches = self.gen_switch(switch_data, index)

    # 4. swap getitems for a typed_getitem, these are actually just
    # placeholders at this point. When the loop is duplicated they can
    # be swapped for a typed_getitem of the correct type or if the item
    # is literal it can be shoved straight into the duplicated loop body
    for item in tuple_getitem:
        old = item.value
        item.value = ir.Expr.typed_getitem(
            old.value, types.void, old.index, old.loc
        )

    # 5. Inject switch table

    # Find the actual loop without the header (that won't get replaced)
    # and derive some new IR for this set of blocks
    this_loop = loop_info.loop
    header_block = this_loop.header
    this_loop_body = this_loop.body - set([header_block])
    new_ir = func_ir.derive({x: func_ir.blocks[x] for x in this_loop_body})

    # Work out what is live on entry and exit so as to prevent
    # replacement (defined vars can escape, used vars live at the header
    # need to remain as-is so their references are correct, they can
    # also escape).
    usedefs = compute_use_defs(func_ir.blocks)
    keep = set()
    keep |= usedefs.usemap[header_block] | usedefs.defmap[header_block]
    keep |= func_ir.variable_lifetime.livemap[header_block]
    dont_replace = list(keep)

    # compute the unrolled body
    unroll = self.inject_loop_body(
        branches,
        new_ir,
        max(func_ir.blocks.keys()) + 1,
        dont_replace,
        switch_data,
    )

    # 6. Patch in the unrolled body and fix up
    blks = state.func_ir.blocks
    the_scope = next(iter(blks.values())).scope

    # the first original body label is reused for the entry of the
    # unrolled body, the remaining original body blocks are dropped
    replace, *delete = tuple(this_loop_body)
    unroll_lbl = sorted(unroll.blocks.keys())
    blks[replace] = transfer_scope(unroll.blocks[unroll_lbl[0]], the_scope)
    for dead in delete:
        blks.pop(dead)
    for k in unroll_lbl[1:]:
        blks[k] = transfer_scope(unroll.blocks[k], the_scope)
    # stitch up the loop predicate true -> new loop body jump
    blks[header_block].body[-1].truebr = replace
1441
+
1442
def run_pass(self, state):
    """Repeatedly apply the unroll transform until nothing changes.

    Returns ``True`` if any application of the transform mutated the IR.
    The partial typing results held on ``state`` are reset before
    returning.
    """
    func_ir = state.func_ir
    # first limit the work by squashing the CFG if possible
    func_ir.blocks = simplify_CFG(func_ir.blocks)

    if self._DEBUG:
        print("-" * 80 + "PASS ENTRY")
        func_ir.dump()
        print("-" * 80)

    # limitations:
    # 1. No nested unrolls
    # 2. Opt in via `numba.literal_unroll`
    # 3. No multiple mix-tuple use

    # keep running the transform until it reports no more changes
    mutated = False
    while self.apply_transform(state):
        mutated = True

    # reset type inference now we are done with the partial results
    state.typemap = {}
    state.calltypes = None

    return mutated
1470
+
1471
+
1472
@register_pass(mutates_CFG=True, analysis_only=False)
class IterLoopCanonicalization(FunctionPass):
    """Rewrites loops induced by `getiter` as range() driven loops.

    When a partial typemap is available only loops over accepted container
    types (reached through an accepted call) are rewritten; when it is not
    available every matching loop is rewritten.
    """

    _name = "iter_loop_canonicalisation"

    _DEBUG = False

    # if partial typing info is available it will only look at these types
    _accepted_types = (types.BaseTuple, types.LiteralList)
    _accepted_calls = (literal_unroll,)

    def __init__(self):
        FunctionPass.__init__(self)

    def assess_loop(self, loop, func_ir, partial_typemap=None):
        """Decide whether ``loop`` is an iter loop that should be rewritten.

        It's an iter loop if:
         - the loop header is driven by a single iternext
         - the iternext value is derived from getiter()

        Returns a truthy value only for loops that qualify and have exactly
        one entry; otherwise returns ``False`` or ``None``.
        """
        # check header
        iternexts = list(func_ir.blocks[loop.header].find_exprs("iternext"))
        if len(iternexts) != 1:
            return False
        (iternext,) = iternexts

        phi = guard(get_definition, func_ir, iternext.value)
        if phi is None:
            return False
        if getattr(phi, "op", False) != "getiter":
            return None

        if not partial_typemap:
            # no typing information: accept any single entry getiter loop
            return len(loop.entries) == 1

        # check that the call site is accepted, until we're
        # confident that tuple unrolling is behaving require opt-in
        # guard of `literal_unroll`, remove this later!
        phi_val_defn = guard(get_definition, func_ir, phi.value)
        if not isinstance(phi_val_defn, ir.Expr):
            return False
        if not phi_val_defn.op == "call":
            return False
        call = guard(get_definition, func_ir, phi_val_defn)
        if call is None or len(call.args) != 1:
            return False
        func_var = guard(get_definition, func_ir, call.func)
        func = guard(get_definition, func_ir, func_var)
        if func is None or not isinstance(func, (ir.Global, ir.FreeVar)):
            return False
        if func.value is None or func.value not in self._accepted_calls:
            return False

        # now check the type is supported
        ty = partial_typemap.get(call.args[0].name, None)
        if ty and isinstance(ty, self._accepted_types):
            return len(loop.entries) == 1
        return None

    def transform(self, loop, func_ir, cfg):
        """Rewrite ``loop`` in ``func_ir`` to iterate over
        ``range(len(container))`` and replace uses of the induction variable
        with a getitem on the container. ``cfg`` is the control flow graph
        of ``func_ir``; it is used to find the successors of the loop exits.
        """

        def get_range(a):
            return range(len(a))

        header_lbl = loop.header
        iternext = list(func_ir.blocks[header_lbl].find_exprs("iternext"))[0]
        LOC = func_ir.blocks[header_lbl].loc
        scope = func_ir.blocks[header_lbl].scope

        # temporarily bind `get_range` as a global at the top of the loop
        # entry block so that a call to it can be injected and inlined
        get_range_var = scope.redefine("CANONICALISER_get_range_gbl", LOC)
        get_range_global = ir.Global("get_range", get_range, LOC)
        assgn = ir.Assign(get_range_global, get_range_var, LOC)

        entry_block = func_ir.blocks[tuple(loop.entries)[0]]
        entry_block.body.insert(0, assgn)

        iterarg = guard(get_definition, func_ir, iternext.value)
        if iterarg is not None:
            iterarg = iterarg.value

        # look for the getiter assignment in the entry block
        for idx, getiter_stmt in enumerate(entry_block.body):
            if (
                isinstance(getiter_stmt, ir.Assign)
                and isinstance(getiter_stmt.value, ir.Expr)
                and getiter_stmt.value.op == "getiter"
            ):
                break
        else:
            raise ValueError("problem")

        # create a range(len(tup)) and inject it ahead of the getiter, then
        # point the getiter at the result of that call
        call_get_range_var = scope.redefine("CANONICALISER_call_get_range", LOC)
        make_call = ir.Expr.call(
            get_range_var, (getiter_stmt.value.value,), (), LOC
        )
        assgn_call = ir.Assign(make_call, call_get_range_var, LOC)
        entry_block.body.insert(idx, assgn_call)
        entry_block.body[idx + 1].value.value = call_get_range_var

        glbls = copy(func_ir.func_id.func.__globals__)

        inline_closurecall.inline_closure_call(
            func_ir,
            glbls,
            entry_block,
            idx,
            get_range,
        )
        # the temporary global binding is no longer needed
        entry_block.body.remove(assgn)

        # find the induction variable + references in the loop header
        # fixed point iter to do this, it's a bit clunky
        header_block = func_ir.blocks[header_lbl]

        # find induction var
        induction_vars = set()
        for pair_first in list(header_block.find_exprs("pair_first")):
            induction_vars.add(func_ir.get_assignee(pair_first, header_lbl))
        # find aliases of the induction var
        aliases = set()
        for var in induction_vars:
            try:  # there's not always an alias, e.g. loop from inlined closure
                aliases.add(func_ir.get_assignee(var, header_lbl))
            except ValueError:
                pass
        induction_vars |= aliases
        induction_var_names = {var.name for var in induction_vars}

        # Find the downstream blocks that might reference the induction var
        succ = set()
        for lbl in loop.exits:
            succ |= {edge[0] for edge in cfg.successors(lbl)}
        check_blocks = (loop.body | loop.exits | succ) ^ {header_lbl}

        # replace RHS use of induction var with getitem
        for lbl in check_blocks:
            for stmt in func_ir.blocks[lbl].body:
                if not isinstance(stmt, ir.Assign):
                    continue
                # check for aliases
                try:
                    lookup = getattr(stmt.value, "name", None)
                except KeyError:
                    continue
                if lookup and lookup in induction_var_names:
                    stmt.value = ir.Expr.getitem(iterarg, stmt.value, stmt.loc)

        postproc.PostProcessor(func_ir).run()

    def run_pass(self, state):
        """Canonicalise every qualifying loop in ``state.func_ir``.

        Returns ``True`` if at least one loop was transformed.
        """
        func_ir = state.func_ir
        cfg = compute_cfg_from_blocks(func_ir.blocks)

        mutated = False
        for loop in cfg.loops().values():
            if self.assess_loop(loop, func_ir, state.typemap):
                if self._DEBUG:
                    print("Canonicalising loop", loop)
                self.transform(loop, func_ir, cfg)
                mutated = True
            elif self._DEBUG:
                print("NOT Canonicalising loop", loop)

        func_ir.blocks = simplify_CFG(func_ir.blocks)
        return mutated
1651
+
1652
+
1653
@register_pass(mutates_CFG=False, analysis_only=False)
class PropagateLiterals(FunctionPass):
    """Propagates literal values found by partial type inference by
    replacing the right hand side of assignments with constants."""

    _name = "PropagateLiterals"

    def __init__(self):
        FunctionPass.__init__(self)

    def get_analysis_usage(self, AU):
        AU.add_required(ReconstructSSA)

    def run_pass(self, state):
        """Replace assignments whose target is typed as a literal with an
        assignment of the corresponding ``ir.Const``.

        Returns ``True`` if any assignment was replaced. Raises
        ``errors.NumbaTypeError`` when an ``isinstance``/``hasattr`` call
        receives an argument defined by a PHI node. The partial typing
        results held on ``state`` are reset before returning.
        """
        func_ir = state.func_ir
        typemap = state.typemap
        flags = state.flags

        accepted_functions = ("isinstance", "hasattr")
        # expression kinds that must be left untouched:
        # 1) Don't change return stmt in the form
        #    $return_xyz = cast(value=ABC)
        # 2) Don't propagate literal values that are not primitives
        skipped_ops = (
            "cast",
            "build_map",
            "build_list",
            "build_tuple",
            "build_set",
        )

        if not hasattr(func_ir, "_definitions") and not flags.enable_ssa:
            func_ir._definitions = build_definitions(func_ir.blocks)

        changed = False

        for block in func_ir.blocks.values():
            for assign in block.find_insts(ir.Assign):
                value = assign.value
                if isinstance(value, (ir.Arg, ir.Const, ir.FreeVar, ir.Global)):
                    continue

                is_expr = isinstance(value, ir.Expr)
                if is_expr and value.op in skipped_ops:
                    continue

                target = assign.target
                # SSA is disabled when doing inlining
                if (
                    not flags.enable_ssa
                    and guard(get_definition, func_ir, target.name) is None
                ):
                    continue

                # Numba cannot safely determine if an isinstance call
                # with a PHI node is True/False. For instance, in
                # the case below, the partial type inference step can coerce
                # '$z' to float, so any call to 'isinstance(z, int)' would fail.
                #
                #   def fn(x):
                #       if x > 4:
                #           z = 1
                #       else:
                #           z = 3.14
                #       if isinstance(z, int):
                #           print('int')
                #       else:
                #           print('float')
                #
                # At the moment, an error is raised rather than propagating
                # the literal value if an argument is a PHI node
                if is_expr and value.op == "call":
                    fn = guard(get_definition, func_ir, value.func.name)
                    if fn is None:
                        continue
                    if not isinstance(fn, ir.Global):
                        continue
                    if fn.name not in accepted_functions:
                        continue

                    for arg in value.args:
                        # check if any of the args to the call is a PHI node
                        iv = func_ir._definitions[arg.name]
                        assert len(iv) == 1  # SSA!
                        if isinstance(iv[0], ir.Expr) and iv[0].op == "phi":
                            msg = (
                                f"{fn.name}() cannot determine the "
                                f'type of variable "{arg.unversioned_name}" '
                                "due to a branch."
                            )
                            raise errors.NumbaTypeError(msg, loc=assign.loc)

                # Only propagate a PHI node if all arguments are the same
                # constant
                if is_expr and value.op == "phi":
                    # typemap.get gives None when `inc.name` is not typed
                    v = [typemap.get(inc.name) for inc in value.incoming_values]
                    # stop if the elements in `v` do not hold the same value
                    if v[0] is not None and any(v[0] != vi for vi in v):
                        continue

                lit = typemap.get(target.name, None)
                if lit and isinstance(lit, types.Literal):
                    # replace assign instruction by ir.Const(lit) iff
                    # lit is a literal value
                    rhs = ir.Const(lit.literal_value, assign.loc)
                    new_assign = ir.Assign(rhs, target, assign.loc)

                    # replace instruction
                    block.insert_after(new_assign, assign)
                    block.remove(assign)

                    changed = True

        # reset type inference now we are done with the partial results
        state.typemap = None
        state.calltypes = None

        if changed:
            # Rebuild definitions
            func_ir._definitions = build_definitions(func_ir.blocks)

        return changed
1773
+
1774
+
1775
@register_pass(mutates_CFG=True, analysis_only=False)
class LiteralPropagationSubPipelinePass(FunctionPass):
    """Runs literal propagation, based on partial type inference, as a
    sub-pipeline."""

    _name = "LiteralPropagation"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Run the literal propagation sub-pipeline on ``state``.

        Returns ``False`` without doing anything when the IR holds no
        ``isinstance``/``hasattr`` global or freevar; otherwise runs the
        sub-pipeline and returns ``True``.
        """

        def binds_trigger(asgn):
            # True if the assignment binds `isinstance` or `hasattr` as a
            # global or as a freevar
            if not isinstance(asgn.value, (ir.Global, ir.FreeVar)):
                return False
            value = asgn.value.value
            return value is isinstance or value is hasattr

        # Determine whether to even attempt this pass... if there's no
        # trigger present then just skip.
        func_ir = state.func_ir
        found = any(
            binds_trigger(asgn)
            for blk in func_ir.blocks.values()
            for asgn in blk.find_insts(ir.Assign)
        )
        if not found:
            return False

        # run as subpipeline
        from numba.cuda.core.compiler_machinery import PassManager
        from numba.cuda.core.typed_passes import PartialTypeInference

        pm = PassManager("literal_propagation_subpipeline")

        pm.add_pass(PartialTypeInference, "performs partial type inference")
        pm.add_pass(PropagateLiterals, "performs propagation of literal values")

        # rewrite consts / dead branch pruning
        pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
        pm.add_pass(DeadBranchPrune, "dead branch pruning")

        pm.finalize()
        pm.run(state)
        return True

    def get_analysis_usage(self, AU):
        AU.add_required(ReconstructSSA)
1821
+
1822
+
1823
@register_pass(mutates_CFG=True, analysis_only=False)
class LiteralUnroll(FunctionPass):
    """Implements the literal_unroll semantics by running a sub-pipeline."""

    _name = "literal_unroll"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Run the literal_unroll sub-pipeline on ``state``.

        Returns ``False`` without doing anything when the IR holds no
        `literal_unroll` global or freevar; otherwise runs the sub-pipeline
        and returns ``True``.
        """
        # Determine whether to even attempt this pass... if there's no
        # `literal_unroll` as a global or as a freevar then just skip.
        func_ir = state.func_ir
        found = any(
            isinstance(asgn.value, (ir.Global, ir.FreeVar))
            and asgn.value.value is literal_unroll
            for blk in func_ir.blocks.values()
            for asgn in blk.find_insts(ir.Assign)
        )
        if not found:
            return False

        # run as subpipeline
        from numba.cuda.core.compiler_machinery import PassManager
        from numba.cuda.core.typed_passes import PartialTypeInference

        # (pass, description) pairs in execution order
        stages = (
            # get types where possible to help with list->tuple change
            (PartialTypeInference, "performs partial type inference"),
            # make const lists tuples
            (
                TransformLiteralUnrollConstListToTuple,
                "switch const list for tuples",
            ),
            # recompute partial typemap following IR change
            (PartialTypeInference, "performs partial type inference"),
            # canonicalise loops
            (
                IterLoopCanonicalization,
                "switch iter loops for range driven loops",
            ),
            # rewrite consts
            (RewriteSemanticConstants, "rewrite semantic constants"),
            # do the unroll
            (MixedContainerUnroller, "performs mixed container unroll"),
            # rewrite dynamic getitem to static getitem as it's possible
            # some more getitems will now be statically resolvable
            (GenericRewrites, "Generic Rewrites"),
            (RewriteSemanticConstants, "rewrite semantic constants"),
        )

        pm = PassManager("literal_unroll_subpipeline")
        for pass_cls, description in stages:
            pm.add_pass(pass_cls, description)
        pm.finalize()
        pm.run(state)
        return True
1877
+
1878
+
1879
@register_pass(mutates_CFG=True, analysis_only=False)
class SimplifyCFG(FunctionPass):
    """Simplifies the control flow graph of the function IR."""

    _name = "simplify_cfg"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Replace ``state.func_ir.blocks`` with the result of
        ``simplify_CFG`` and return whether the result compares unequal to
        the original blocks.
        """
        original = state.func_ir.blocks
        simplified = simplify_CFG(original)
        state.func_ir.blocks = simplified
        return original != simplified
1894
+
1895
+
1896
@register_pass(mutates_CFG=False, analysis_only=False)
class ReconstructSSA(FunctionPass):
    """Perform SSA-reconstruction

    Produces minimal SSA.
    """

    _name = "reconstruct_ssa"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Rebuild ``state.func_ir`` in SSA form and refresh the metadata
        derived from it. Always returns ``True``.
        """
        state.func_ir = reconstruct_ssa(state.func_ir)
        self._patch_locals(state)

        func_ir = state.func_ir

        # Rebuild definitions
        func_ir._definitions = build_definitions(func_ir.blocks)

        # Rerun postprocessor to update metadata
        # example generator_info
        postproc.PostProcessor(func_ir).run(emit_dels=False)

        if config.DEBUG or config.DUMP_SSA:
            name = func_ir.func_id.func_qualname
            print(f"SSA IR DUMP: {name}".center(80, "-"))
            func_ir.dump()

        return True  # XXX detect if it actually got changed

    def _patch_locals(self, state):
        """Fix the dispatcher locals dictionary type annotation: every
        redefinition of a variable present in the locals dictionary is given
        the same type as that variable.
        """
        locals_dict = state.get("locals")
        if locals_dict is None:
            return

        # the scope is taken from the first block
        first_blk, *_ = state.func_ir.blocks.values()
        for parent, redefs in first_blk.scope.var_redefinitions.items():
            if parent not in locals_dict:
                continue
            typ = locals_dict[parent]
            for derived in redefs:
                locals_dict[derived] = typ
1940
+
1941
+
1942
@register_pass(mutates_CFG=False, analysis_only=False)
class RewriteDynamicRaises(FunctionPass):
    """Replaces existing raise statements with dynamic raises in Numba IR."""

    _name = "Rewrite dynamic raises"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Swap each ``ir.Raise``/``ir.TryRaise`` whose exception has a
        resolvable definition for its dynamic counterpart.

        Exception arguments that can be inferred as constants are stored as
        constants; the others are kept as they appear in the call. Returns
        ``True`` if any statement was replaced.
        """
        func_ir = state.func_ir
        # static raise statement type -> dynamic replacement type
        dynamic_cls_for = {
            ir.TryRaise: ir.DynamicTryRaise,
            ir.Raise: ir.DynamicRaise,
        }
        changed = False

        for block in func_ir.blocks.values():
            for raise_ in block.find_insts((ir.Raise, ir.TryRaise)):
                call_inst = guard(get_definition, func_ir, raise_.exception)
                if call_inst is None:
                    continue
                exc_type = func_ir.infer_constant(call_inst.func.name)

                exc_args = []
                for exc_arg in call_inst.args:
                    try:
                        resolved = func_ir.infer_constant(exc_arg)
                    except consts.ConstantInferenceError:
                        # not a compile time constant, keep the argument
                        resolved = exc_arg
                    exc_args.append(resolved)

                dyn_cls = dynamic_cls_for[type(raise_)]
                dyn_raise = dyn_cls(exc_type, tuple(exc_args), raise_.loc)
                block.insert_after(dyn_raise, raise_)
                block.remove(raise_)
                changed = True
        return changed