numba-cuda 0.22.0__cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (487)
  1. _numba_cuda_redirector.pth +4 -0
  2. _numba_cuda_redirector.py +89 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +6 -0
  5. numba_cuda/_version.py +11 -0
  6. numba_cuda/numba/cuda/__init__.py +70 -0
  7. numba_cuda/numba/cuda/_internal/cuda_bf16.py +16394 -0
  8. numba_cuda/numba/cuda/_internal/cuda_fp16.py +8112 -0
  9. numba_cuda/numba/cuda/api.py +580 -0
  10. numba_cuda/numba/cuda/api_util.py +76 -0
  11. numba_cuda/numba/cuda/args.py +72 -0
  12. numba_cuda/numba/cuda/bf16.py +397 -0
  13. numba_cuda/numba/cuda/cache_hints.py +287 -0
  14. numba_cuda/numba/cuda/cext/__init__.py +2 -0
  15. numba_cuda/numba/cuda/cext/_devicearray.cpp +159 -0
  16. numba_cuda/numba/cuda/cext/_devicearray.cpython-313-aarch64-linux-gnu.so +0 -0
  17. numba_cuda/numba/cuda/cext/_devicearray.h +29 -0
  18. numba_cuda/numba/cuda/cext/_dispatcher.cpp +1098 -0
  19. numba_cuda/numba/cuda/cext/_dispatcher.cpython-313-aarch64-linux-gnu.so +0 -0
  20. numba_cuda/numba/cuda/cext/_hashtable.cpp +532 -0
  21. numba_cuda/numba/cuda/cext/_hashtable.h +135 -0
  22. numba_cuda/numba/cuda/cext/_helperlib.c +71 -0
  23. numba_cuda/numba/cuda/cext/_helperlib.cpython-313-aarch64-linux-gnu.so +0 -0
  24. numba_cuda/numba/cuda/cext/_helpermod.c +82 -0
  25. numba_cuda/numba/cuda/cext/_pymodule.h +38 -0
  26. numba_cuda/numba/cuda/cext/_typeconv.cpp +206 -0
  27. numba_cuda/numba/cuda/cext/_typeconv.cpython-313-aarch64-linux-gnu.so +0 -0
  28. numba_cuda/numba/cuda/cext/_typeof.cpp +1159 -0
  29. numba_cuda/numba/cuda/cext/_typeof.h +19 -0
  30. numba_cuda/numba/cuda/cext/capsulethunk.h +111 -0
  31. numba_cuda/numba/cuda/cext/mviewbuf.c +385 -0
  32. numba_cuda/numba/cuda/cext/mviewbuf.cpython-313-aarch64-linux-gnu.so +0 -0
  33. numba_cuda/numba/cuda/cext/typeconv.cpp +212 -0
  34. numba_cuda/numba/cuda/cext/typeconv.hpp +101 -0
  35. numba_cuda/numba/cuda/cg.py +67 -0
  36. numba_cuda/numba/cuda/cgutils.py +1294 -0
  37. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  38. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  39. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  40. numba_cuda/numba/cuda/codegen.py +541 -0
  41. numba_cuda/numba/cuda/compiler.py +1396 -0
  42. numba_cuda/numba/cuda/core/analysis.py +758 -0
  43. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  44. numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +288 -0
  45. numba_cuda/numba/cuda/core/annotations/type_annotations.py +305 -0
  46. numba_cuda/numba/cuda/core/base.py +1332 -0
  47. numba_cuda/numba/cuda/core/boxing.py +1411 -0
  48. numba_cuda/numba/cuda/core/bytecode.py +728 -0
  49. numba_cuda/numba/cuda/core/byteflow.py +2346 -0
  50. numba_cuda/numba/cuda/core/caching.py +744 -0
  51. numba_cuda/numba/cuda/core/callconv.py +392 -0
  52. numba_cuda/numba/cuda/core/codegen.py +171 -0
  53. numba_cuda/numba/cuda/core/compiler.py +199 -0
  54. numba_cuda/numba/cuda/core/compiler_lock.py +85 -0
  55. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  56. numba_cuda/numba/cuda/core/config.py +650 -0
  57. numba_cuda/numba/cuda/core/consts.py +124 -0
  58. numba_cuda/numba/cuda/core/controlflow.py +989 -0
  59. numba_cuda/numba/cuda/core/entrypoints.py +57 -0
  60. numba_cuda/numba/cuda/core/environment.py +66 -0
  61. numba_cuda/numba/cuda/core/errors.py +917 -0
  62. numba_cuda/numba/cuda/core/event.py +511 -0
  63. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  64. numba_cuda/numba/cuda/core/generators.py +387 -0
  65. numba_cuda/numba/cuda/core/imputils.py +509 -0
  66. numba_cuda/numba/cuda/core/inline_closurecall.py +1787 -0
  67. numba_cuda/numba/cuda/core/interpreter.py +3617 -0
  68. numba_cuda/numba/cuda/core/ir.py +1812 -0
  69. numba_cuda/numba/cuda/core/ir_utils.py +2638 -0
  70. numba_cuda/numba/cuda/core/optional.py +129 -0
  71. numba_cuda/numba/cuda/core/options.py +262 -0
  72. numba_cuda/numba/cuda/core/postproc.py +249 -0
  73. numba_cuda/numba/cuda/core/pythonapi.py +1859 -0
  74. numba_cuda/numba/cuda/core/registry.py +46 -0
  75. numba_cuda/numba/cuda/core/removerefctpass.py +123 -0
  76. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  77. numba_cuda/numba/cuda/core/rewrites/ir_print.py +91 -0
  78. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  79. numba_cuda/numba/cuda/core/rewrites/static_binop.py +41 -0
  80. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +189 -0
  81. numba_cuda/numba/cuda/core/rewrites/static_raise.py +100 -0
  82. numba_cuda/numba/cuda/core/sigutils.py +68 -0
  83. numba_cuda/numba/cuda/core/ssa.py +498 -0
  84. numba_cuda/numba/cuda/core/targetconfig.py +330 -0
  85. numba_cuda/numba/cuda/core/tracing.py +231 -0
  86. numba_cuda/numba/cuda/core/transforms.py +956 -0
  87. numba_cuda/numba/cuda/core/typed_passes.py +867 -0
  88. numba_cuda/numba/cuda/core/typeinfer.py +1950 -0
  89. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  90. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  91. numba_cuda/numba/cuda/core/unsafe/eh.py +67 -0
  92. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  93. numba_cuda/numba/cuda/core/untyped_passes.py +1979 -0
  94. numba_cuda/numba/cuda/cpython/builtins.py +1153 -0
  95. numba_cuda/numba/cuda/cpython/charseq.py +1218 -0
  96. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  97. numba_cuda/numba/cuda/cpython/enumimpl.py +103 -0
  98. numba_cuda/numba/cuda/cpython/iterators.py +167 -0
  99. numba_cuda/numba/cuda/cpython/listobj.py +1326 -0
  100. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  101. numba_cuda/numba/cuda/cpython/numbers.py +1475 -0
  102. numba_cuda/numba/cuda/cpython/rangeobj.py +289 -0
  103. numba_cuda/numba/cuda/cpython/slicing.py +322 -0
  104. numba_cuda/numba/cuda/cpython/tupleobj.py +456 -0
  105. numba_cuda/numba/cuda/cpython/unicode.py +2865 -0
  106. numba_cuda/numba/cuda/cpython/unicode_support.py +1597 -0
  107. numba_cuda/numba/cuda/cpython/unsafe/__init__.py +0 -0
  108. numba_cuda/numba/cuda/cpython/unsafe/numbers.py +64 -0
  109. numba_cuda/numba/cuda/cpython/unsafe/tuple.py +92 -0
  110. numba_cuda/numba/cuda/cuda_paths.py +691 -0
  111. numba_cuda/numba/cuda/cudadecl.py +543 -0
  112. numba_cuda/numba/cuda/cudadrv/__init__.py +14 -0
  113. numba_cuda/numba/cuda/cudadrv/devicearray.py +954 -0
  114. numba_cuda/numba/cuda/cudadrv/devices.py +249 -0
  115. numba_cuda/numba/cuda/cudadrv/driver.py +3238 -0
  116. numba_cuda/numba/cuda/cudadrv/drvapi.py +435 -0
  117. numba_cuda/numba/cuda/cudadrv/dummyarray.py +562 -0
  118. numba_cuda/numba/cuda/cudadrv/enums.py +613 -0
  119. numba_cuda/numba/cuda/cudadrv/error.py +48 -0
  120. numba_cuda/numba/cuda/cudadrv/libs.py +220 -0
  121. numba_cuda/numba/cuda/cudadrv/linkable_code.py +184 -0
  122. numba_cuda/numba/cuda/cudadrv/mappings.py +14 -0
  123. numba_cuda/numba/cuda/cudadrv/ndarray.py +26 -0
  124. numba_cuda/numba/cuda/cudadrv/nvrtc.py +193 -0
  125. numba_cuda/numba/cuda/cudadrv/nvvm.py +756 -0
  126. numba_cuda/numba/cuda/cudadrv/rtapi.py +13 -0
  127. numba_cuda/numba/cuda/cudadrv/runtime.py +34 -0
  128. numba_cuda/numba/cuda/cudaimpl.py +983 -0
  129. numba_cuda/numba/cuda/cudamath.py +149 -0
  130. numba_cuda/numba/cuda/datamodel/__init__.py +7 -0
  131. numba_cuda/numba/cuda/datamodel/cuda_manager.py +66 -0
  132. numba_cuda/numba/cuda/datamodel/cuda_models.py +1446 -0
  133. numba_cuda/numba/cuda/datamodel/cuda_packer.py +224 -0
  134. numba_cuda/numba/cuda/datamodel/cuda_registry.py +22 -0
  135. numba_cuda/numba/cuda/datamodel/cuda_testing.py +153 -0
  136. numba_cuda/numba/cuda/datamodel/manager.py +11 -0
  137. numba_cuda/numba/cuda/datamodel/models.py +9 -0
  138. numba_cuda/numba/cuda/datamodel/packer.py +9 -0
  139. numba_cuda/numba/cuda/datamodel/registry.py +11 -0
  140. numba_cuda/numba/cuda/datamodel/testing.py +11 -0
  141. numba_cuda/numba/cuda/debuginfo.py +997 -0
  142. numba_cuda/numba/cuda/decorators.py +294 -0
  143. numba_cuda/numba/cuda/descriptor.py +35 -0
  144. numba_cuda/numba/cuda/device_init.py +155 -0
  145. numba_cuda/numba/cuda/deviceufunc.py +1021 -0
  146. numba_cuda/numba/cuda/dispatcher.py +2463 -0
  147. numba_cuda/numba/cuda/errors.py +72 -0
  148. numba_cuda/numba/cuda/extending.py +697 -0
  149. numba_cuda/numba/cuda/flags.py +178 -0
  150. numba_cuda/numba/cuda/fp16.py +357 -0
  151. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  152. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  153. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  154. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  155. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  156. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  157. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  158. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  159. numba_cuda/numba/cuda/initialize.py +24 -0
  160. numba_cuda/numba/cuda/intrinsics.py +531 -0
  161. numba_cuda/numba/cuda/itanium_mangler.py +214 -0
  162. numba_cuda/numba/cuda/kernels/__init__.py +2 -0
  163. numba_cuda/numba/cuda/kernels/reduction.py +265 -0
  164. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  165. numba_cuda/numba/cuda/libdevice.py +3386 -0
  166. numba_cuda/numba/cuda/libdevicedecl.py +20 -0
  167. numba_cuda/numba/cuda/libdevicefuncs.py +1060 -0
  168. numba_cuda/numba/cuda/libdeviceimpl.py +88 -0
  169. numba_cuda/numba/cuda/locks.py +19 -0
  170. numba_cuda/numba/cuda/lowering.py +1980 -0
  171. numba_cuda/numba/cuda/mathimpl.py +374 -0
  172. numba_cuda/numba/cuda/memory_management/__init__.py +4 -0
  173. numba_cuda/numba/cuda/memory_management/memsys.cu +99 -0
  174. numba_cuda/numba/cuda/memory_management/memsys.cuh +22 -0
  175. numba_cuda/numba/cuda/memory_management/nrt.cu +212 -0
  176. numba_cuda/numba/cuda/memory_management/nrt.cuh +48 -0
  177. numba_cuda/numba/cuda/memory_management/nrt.py +390 -0
  178. numba_cuda/numba/cuda/memory_management/nrt_context.py +438 -0
  179. numba_cuda/numba/cuda/misc/appdirs.py +594 -0
  180. numba_cuda/numba/cuda/misc/cffiimpl.py +24 -0
  181. numba_cuda/numba/cuda/misc/coverage_support.py +43 -0
  182. numba_cuda/numba/cuda/misc/dump_style.py +41 -0
  183. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  184. numba_cuda/numba/cuda/misc/firstlinefinder.py +96 -0
  185. numba_cuda/numba/cuda/misc/gdb_hook.py +240 -0
  186. numba_cuda/numba/cuda/misc/literal.py +28 -0
  187. numba_cuda/numba/cuda/misc/llvm_pass_timings.py +412 -0
  188. numba_cuda/numba/cuda/misc/special.py +94 -0
  189. numba_cuda/numba/cuda/models.py +56 -0
  190. numba_cuda/numba/cuda/np/arraymath.py +5130 -0
  191. numba_cuda/numba/cuda/np/arrayobj.py +7635 -0
  192. numba_cuda/numba/cuda/np/extensions.py +11 -0
  193. numba_cuda/numba/cuda/np/linalg.py +3087 -0
  194. numba_cuda/numba/cuda/np/math/__init__.py +0 -0
  195. numba_cuda/numba/cuda/np/math/cmathimpl.py +558 -0
  196. numba_cuda/numba/cuda/np/math/mathimpl.py +487 -0
  197. numba_cuda/numba/cuda/np/math/numbers.py +1461 -0
  198. numba_cuda/numba/cuda/np/npdatetime.py +969 -0
  199. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  200. numba_cuda/numba/cuda/np/npyfuncs.py +1808 -0
  201. numba_cuda/numba/cuda/np/npyimpl.py +1027 -0
  202. numba_cuda/numba/cuda/np/numpy_support.py +798 -0
  203. numba_cuda/numba/cuda/np/polynomial/__init__.py +4 -0
  204. numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +242 -0
  205. numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +380 -0
  206. numba_cuda/numba/cuda/np/ufunc/__init__.py +4 -0
  207. numba_cuda/numba/cuda/np/ufunc/decorators.py +203 -0
  208. numba_cuda/numba/cuda/np/ufunc/sigparse.py +68 -0
  209. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +65 -0
  210. numba_cuda/numba/cuda/np/ufunc_db.py +1282 -0
  211. numba_cuda/numba/cuda/np/unsafe/__init__.py +0 -0
  212. numba_cuda/numba/cuda/np/unsafe/ndarray.py +84 -0
  213. numba_cuda/numba/cuda/nvvmutils.py +254 -0
  214. numba_cuda/numba/cuda/printimpl.py +126 -0
  215. numba_cuda/numba/cuda/random.py +308 -0
  216. numba_cuda/numba/cuda/reshape_funcs.cu +156 -0
  217. numba_cuda/numba/cuda/serialize.py +267 -0
  218. numba_cuda/numba/cuda/simulator/__init__.py +63 -0
  219. numba_cuda/numba/cuda/simulator/_internal/__init__.py +4 -0
  220. numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
  221. numba_cuda/numba/cuda/simulator/api.py +179 -0
  222. numba_cuda/numba/cuda/simulator/bf16.py +4 -0
  223. numba_cuda/numba/cuda/simulator/compiler.py +38 -0
  224. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +11 -0
  225. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +462 -0
  226. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +122 -0
  227. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +66 -0
  228. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +7 -0
  229. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +7 -0
  230. numba_cuda/numba/cuda/simulator/cudadrv/error.py +10 -0
  231. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +10 -0
  232. numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +61 -0
  233. numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +11 -0
  234. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +32 -0
  235. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +22 -0
  236. numba_cuda/numba/cuda/simulator/dispatcher.py +11 -0
  237. numba_cuda/numba/cuda/simulator/kernel.py +320 -0
  238. numba_cuda/numba/cuda/simulator/kernelapi.py +509 -0
  239. numba_cuda/numba/cuda/simulator/memory_management/__init__.py +4 -0
  240. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +21 -0
  241. numba_cuda/numba/cuda/simulator/reduction.py +19 -0
  242. numba_cuda/numba/cuda/simulator/tests/support.py +4 -0
  243. numba_cuda/numba/cuda/simulator/vector_types.py +65 -0
  244. numba_cuda/numba/cuda/simulator_init.py +18 -0
  245. numba_cuda/numba/cuda/stubs.py +624 -0
  246. numba_cuda/numba/cuda/target.py +505 -0
  247. numba_cuda/numba/cuda/testing.py +347 -0
  248. numba_cuda/numba/cuda/tests/__init__.py +62 -0
  249. numba_cuda/numba/cuda/tests/benchmarks/__init__.py +0 -0
  250. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +119 -0
  251. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  252. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +113 -0
  253. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
  254. numba_cuda/numba/cuda/tests/core/test_serialize.py +371 -0
  255. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +9 -0
  256. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +147 -0
  257. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +161 -0
  258. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +397 -0
  259. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +24 -0
  260. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +180 -0
  261. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +313 -0
  262. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +191 -0
  263. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +621 -0
  264. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +247 -0
  265. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +100 -0
  266. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +200 -0
  267. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +53 -0
  268. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +72 -0
  269. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +138 -0
  270. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +43 -0
  271. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +15 -0
  272. numba_cuda/numba/cuda/tests/cudadrv/test_linkable_code.py +58 -0
  273. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +348 -0
  274. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +128 -0
  275. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +301 -0
  276. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +174 -0
  277. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +28 -0
  278. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +185 -0
  279. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +39 -0
  280. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +23 -0
  281. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +38 -0
  282. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +48 -0
  283. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +44 -0
  284. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +127 -0
  285. numba_cuda/numba/cuda/tests/cudapy/__init__.py +9 -0
  286. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +231 -0
  287. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +50 -0
  288. numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +36 -0
  289. numba_cuda/numba/cuda/tests/cudapy/complex_usecases.py +116 -0
  290. numba_cuda/numba/cuda/tests/cudapy/enum_usecases.py +59 -0
  291. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +62 -0
  292. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +28 -0
  293. numba_cuda/numba/cuda/tests/cudapy/overload_usecases.py +33 -0
  294. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +104 -0
  295. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +47 -0
  296. numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +1122 -0
  297. numba_cuda/numba/cuda/tests/cudapy/test_array.py +344 -0
  298. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +268 -0
  299. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +203 -0
  300. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +63 -0
  301. numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +360 -0
  302. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1815 -0
  303. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +599 -0
  304. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +377 -0
  305. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +160 -0
  306. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +27 -0
  307. numba_cuda/numba/cuda/tests/cudapy/test_byteflow.py +98 -0
  308. numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py +210 -0
  309. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +683 -0
  310. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +265 -0
  311. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +42 -0
  312. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +718 -0
  313. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +370 -0
  314. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +23 -0
  315. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +142 -0
  316. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +178 -0
  317. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +193 -0
  318. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +131 -0
  319. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +438 -0
  320. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +94 -0
  321. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +101 -0
  322. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +105 -0
  323. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +978 -0
  324. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
  325. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +500 -0
  326. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +820 -0
  327. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +152 -0
  328. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +111 -0
  329. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +170 -0
  330. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1088 -0
  331. numba_cuda/numba/cuda/tests/cudapy/test_extending_types.py +71 -0
  332. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +265 -0
  333. numba_cuda/numba/cuda/tests/cudapy/test_flow_control.py +1433 -0
  334. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +57 -0
  335. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +34 -0
  336. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +69 -0
  337. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +62 -0
  338. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +474 -0
  339. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +167 -0
  340. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +92 -0
  341. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +39 -0
  342. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +170 -0
  343. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +255 -0
  344. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1219 -0
  345. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +263 -0
  346. numba_cuda/numba/cuda/tests/cudapy/test_ir.py +598 -0
  347. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +276 -0
  348. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +101 -0
  349. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +68 -0
  350. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +123 -0
  351. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +194 -0
  352. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +220 -0
  353. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +173 -0
  354. numba_cuda/numba/cuda/tests/cudapy/test_make_function_to_jit_function.py +364 -0
  355. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +47 -0
  356. numba_cuda/numba/cuda/tests/cudapy/test_math.py +842 -0
  357. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +76 -0
  358. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +78 -0
  359. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +25 -0
  360. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +145 -0
  361. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +39 -0
  362. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +82 -0
  363. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +53 -0
  364. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +504 -0
  365. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +93 -0
  366. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +402 -0
  367. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +128 -0
  368. numba_cuda/numba/cuda/tests/cudapy/test_print.py +193 -0
  369. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +37 -0
  370. numba_cuda/numba/cuda/tests/cudapy/test_random.py +117 -0
  371. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +614 -0
  372. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +130 -0
  373. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +94 -0
  374. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  375. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +86 -0
  376. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +40 -0
  377. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +457 -0
  378. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +233 -0
  379. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +454 -0
  380. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +56 -0
  381. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +277 -0
  382. numba_cuda/numba/cuda/tests/cudapy/test_tracing.py +200 -0
  383. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +90 -0
  384. numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +333 -0
  385. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  386. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +585 -0
  387. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +42 -0
  388. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +485 -0
  389. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +312 -0
  390. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +23 -0
  391. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +183 -0
  392. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +40 -0
  393. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +40 -0
  394. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +206 -0
  395. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +446 -0
  396. numba_cuda/numba/cuda/tests/cudasim/__init__.py +9 -0
  397. numba_cuda/numba/cuda/tests/cudasim/support.py +9 -0
  398. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +111 -0
  399. numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
  400. numba_cuda/numba/cuda/tests/data/cta_barrier.cu +28 -0
  401. numba_cuda/numba/cuda/tests/data/cuda_include.cu +10 -0
  402. numba_cuda/numba/cuda/tests/data/error.cu +12 -0
  403. numba_cuda/numba/cuda/tests/data/include/add.cuh +8 -0
  404. numba_cuda/numba/cuda/tests/data/jitlink.cu +28 -0
  405. numba_cuda/numba/cuda/tests/data/jitlink.ptx +49 -0
  406. numba_cuda/numba/cuda/tests/data/warn.cu +12 -0
  407. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +9 -0
  408. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
  409. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +54 -0
  410. numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +8 -0
  411. numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +14 -0
  412. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +86 -0
  413. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +68 -0
  414. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +81 -0
  415. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +141 -0
  416. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +160 -0
  417. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +180 -0
  418. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +119 -0
  419. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +66 -0
  420. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +80 -0
  421. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +206 -0
  422. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +53 -0
  423. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +76 -0
  424. numba_cuda/numba/cuda/tests/nocuda/__init__.py +9 -0
  425. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +452 -0
  426. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +48 -0
  427. numba_cuda/numba/cuda/tests/nocuda/test_import.py +63 -0
  428. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +252 -0
  429. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +59 -0
  430. numba_cuda/numba/cuda/tests/nrt/__init__.py +9 -0
  431. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +387 -0
  432. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +124 -0
  433. numba_cuda/numba/cuda/tests/support.py +900 -0
  434. numba_cuda/numba/cuda/typeconv/__init__.py +4 -0
  435. numba_cuda/numba/cuda/typeconv/castgraph.py +137 -0
  436. numba_cuda/numba/cuda/typeconv/rules.py +63 -0
  437. numba_cuda/numba/cuda/typeconv/typeconv.py +121 -0
  438. numba_cuda/numba/cuda/types/__init__.py +233 -0
  439. numba_cuda/numba/cuda/types/__init__.pyi +167 -0
  440. numba_cuda/numba/cuda/types/abstract.py +9 -0
  441. numba_cuda/numba/cuda/types/common.py +9 -0
  442. numba_cuda/numba/cuda/types/containers.py +9 -0
  443. numba_cuda/numba/cuda/types/cuda_abstract.py +533 -0
  444. numba_cuda/numba/cuda/types/cuda_common.py +110 -0
  445. numba_cuda/numba/cuda/types/cuda_containers.py +971 -0
  446. numba_cuda/numba/cuda/types/cuda_function_type.py +230 -0
  447. numba_cuda/numba/cuda/types/cuda_functions.py +798 -0
  448. numba_cuda/numba/cuda/types/cuda_iterators.py +120 -0
  449. numba_cuda/numba/cuda/types/cuda_misc.py +569 -0
  450. numba_cuda/numba/cuda/types/cuda_npytypes.py +690 -0
  451. numba_cuda/numba/cuda/types/cuda_scalars.py +280 -0
  452. numba_cuda/numba/cuda/types/ext_types.py +101 -0
  453. numba_cuda/numba/cuda/types/function_type.py +11 -0
  454. numba_cuda/numba/cuda/types/functions.py +9 -0
  455. numba_cuda/numba/cuda/types/iterators.py +9 -0
  456. numba_cuda/numba/cuda/types/misc.py +9 -0
  457. numba_cuda/numba/cuda/types/npytypes.py +9 -0
  458. numba_cuda/numba/cuda/types/scalars.py +9 -0
  459. numba_cuda/numba/cuda/typing/__init__.py +19 -0
  460. numba_cuda/numba/cuda/typing/arraydecl.py +939 -0
  461. numba_cuda/numba/cuda/typing/asnumbatype.py +130 -0
  462. numba_cuda/numba/cuda/typing/bufproto.py +70 -0
  463. numba_cuda/numba/cuda/typing/builtins.py +1209 -0
  464. numba_cuda/numba/cuda/typing/cffi_utils.py +219 -0
  465. numba_cuda/numba/cuda/typing/cmathdecl.py +47 -0
  466. numba_cuda/numba/cuda/typing/collections.py +138 -0
  467. numba_cuda/numba/cuda/typing/context.py +782 -0
  468. numba_cuda/numba/cuda/typing/ctypes_utils.py +125 -0
  469. numba_cuda/numba/cuda/typing/dictdecl.py +63 -0
  470. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  471. numba_cuda/numba/cuda/typing/listdecl.py +147 -0
  472. numba_cuda/numba/cuda/typing/mathdecl.py +158 -0
  473. numba_cuda/numba/cuda/typing/npdatetime.py +322 -0
  474. numba_cuda/numba/cuda/typing/npydecl.py +749 -0
  475. numba_cuda/numba/cuda/typing/setdecl.py +115 -0
  476. numba_cuda/numba/cuda/typing/templates.py +1446 -0
  477. numba_cuda/numba/cuda/typing/typeof.py +301 -0
  478. numba_cuda/numba/cuda/ufuncs.py +746 -0
  479. numba_cuda/numba/cuda/utils.py +724 -0
  480. numba_cuda/numba/cuda/vector_types.py +214 -0
  481. numba_cuda/numba/cuda/vectorizers.py +260 -0
  482. numba_cuda-0.22.0.dist-info/METADATA +109 -0
  483. numba_cuda-0.22.0.dist-info/RECORD +487 -0
  484. numba_cuda-0.22.0.dist-info/WHEEL +6 -0
  485. numba_cuda-0.22.0.dist-info/licenses/LICENSE +26 -0
  486. numba_cuda-0.22.0.dist-info/licenses/LICENSE.numba +24 -0
  487. numba_cuda-0.22.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2638 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2017 Intel Corporation
2
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: BSD-2-Clause
4
+
5
+ import numpy
6
+ import math
7
+
8
+ import types as pytypes
9
+ import collections
10
+ import warnings
11
+
12
+ import numba.cuda
13
+ from numba.cuda import HAS_NUMBA
14
+ from numba.cuda import types
15
+ from numba.cuda.core import ir
16
+ from numba.cuda import typing
17
+ from numba.cuda.core import analysis, postproc, rewrites, config
18
+ from numba.cuda.typing.templates import signature
19
+ from numba.cuda.core.analysis import (
20
+ compute_live_map,
21
+ compute_use_defs,
22
+ compute_cfg_from_blocks,
23
+ )
24
+ from numba.cuda.core.errors import (
25
+ TypingError,
26
+ UnsupportedError,
27
+ NumbaPendingDeprecationWarning,
28
+ CompilerError,
29
+ )
30
+
31
+ import copy
32
+
33
# Process-wide counter consumed by mk_unique_var(); never reset.
_unique_var_count = 0


def mk_unique_var(prefix):
    """Return a fresh variable name of the form ``"<prefix>.<n>"``.

    ``n`` is the current value of the module-level ``_unique_var_count``
    counter, which is advanced by one on every call so that successive calls
    never hand out the same name.
    """
    global _unique_var_count
    name = prefix + "." + str(_unique_var_count)
    _unique_var_count += 1
    return name
41
+
42
+
43
class _MaxLabel:
    """Counter tracking the largest block label issued or observed so far."""

    def __init__(self, value=0):
        # Current maximum; next() hands out the value one past it.
        self._value = value

    def next(self):
        """Advance the maximum by one and return it as a new label."""
        self._value = self._value + 1
        return self._value

    def update(self, newval):
        """Raise the stored maximum to ``newval`` when ``newval`` is larger."""
        self._value = max(newval, self._value)


# The single, module-wide label source.  The class is deleted right away so
# that no second instance can be created through this module.
_the_max_label = _MaxLabel()
del _MaxLabel
57
+
58
+
59
def get_unused_var_name(prefix, var_table):
    """Return ``prefix`` followed by the smallest non-negative integer suffix
    such that the resulting name is not already in ``var_table``.
    """
    suffix = 0
    candidate = prefix + str(suffix)
    while candidate in var_table:
        suffix += 1
        candidate = prefix + str(suffix)
    return candidate
69
+
70
+
71
def next_label():
    """Return a new block label, one past the largest label issued so far."""
    label = _the_max_label.next()
    return label
73
+
74
+
75
def mk_alloc(
    typingctx, typemap, calltypes, lhs, size_var, dtype, scope, loc, lhs_typ
):
    """generate an array allocation with np.empty() and return list of nodes.
    size_var can be an int variable or tuple of int variables.
    lhs_typ is the type of the array being allocated.

    The returned statements assign ``np.empty(size, dtype)`` to ``lhs``. For an
    ``"F"`` layout ``lhs_typ`` the C-ordered result is additionally passed
    through ``np.asfortranarray`` before being assigned to ``lhs``.

    If ``lhs_typ`` defines ``__allocate__``, allocation is delegated entirely
    to that hook, which receives the (possibly converted) size variable, its
    type and the statements generated so far in ``out``.

    ``typemap`` / ``calltypes`` are only updated when they are truthy.
    """
    out = []
    ndims = 1
    size_typ = types.intp
    # Only tuple sizes are normalised below; a non-tuple size_var is used
    # as-is, so it must already be an ir.Var.
    if isinstance(size_var, tuple):
        if len(size_var) == 1:
            size_var = size_var[0]
            size_var = convert_size_to_var(size_var, typemap, scope, loc, out)
        else:
            # tuple_var = build_tuple([size_var...])
            ndims = len(size_var)
            tuple_var = ir.Var(scope, mk_unique_var("$tuple_var"), loc)
            if typemap:
                typemap[tuple_var.name] = types.containers.UniTuple(
                    types.intp, ndims
                )
            # constant sizes need to be assigned to vars
            new_sizes = [
                convert_size_to_var(s, typemap, scope, loc, out)
                for s in size_var
            ]
            tuple_call = ir.Expr.build_tuple(new_sizes, loc)
            tuple_assign = ir.Assign(tuple_call, tuple_var, loc)
            out.append(tuple_assign)
            size_var = tuple_var
            size_typ = types.containers.UniTuple(types.intp, ndims)
    # Type-specific allocation hook: bypasses the np.empty lowering below.
    if hasattr(lhs_typ, "__allocate__"):
        return lhs_typ.__allocate__(
            typingctx,
            typemap,
            calltypes,
            lhs,
            size_var,
            dtype,
            scope,
            loc,
            lhs_typ,
            size_typ,
            out,
        )
    # g_np_var = Global(numpy)
    g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
    # NOTE(review): `if typemap:` tests truthiness, so an empty dict passed by
    # the caller is treated like None and nothing is recorded -- confirm
    # callers always pass a populated typemap when they want types recorded.
    if typemap:
        typemap[g_np_var.name] = types.misc.Module(numpy)
    g_np = ir.Global("np", numpy, loc)
    g_np_assign = ir.Assign(g_np, g_np_var, loc)
    # attr call: empty_attr = getattr(g_np_var, empty)
    empty_attr_call = ir.Expr.getattr(g_np_var, "empty", loc)
    attr_var = ir.Var(scope, mk_unique_var("$empty_attr_attr"), loc)
    if typemap:
        typemap[attr_var.name] = get_np_ufunc_typ(numpy.empty)
    attr_assign = ir.Assign(empty_attr_call, attr_var, loc)
    # Assume str(dtype) returns a valid type
    dtype_str = str(dtype)
    # alloc call: lhs = empty_attr(size_var, typ_var)
    typ_var = ir.Var(scope, mk_unique_var("$np_typ_var"), loc)
    if typemap:
        typemap[typ_var.name] = types.functions.NumberClass(dtype)
    # If dtype is a datetime/timedelta with a unit,
    # then it won't return a valid type and instead can be created
    # with a string. i.e. "datetime64[ns]")
    if (
        isinstance(dtype, (types.NPDatetime, types.NPTimedelta))
        and dtype.unit != ""
    ):
        typename_const = ir.Const(dtype_str, loc)
        typ_var_assign = ir.Assign(typename_const, typ_var, loc)
    else:
        if dtype_str == "bool":
            # empty doesn't like 'bool' sometimes (e.g. kmeans example)
            dtype_str = "bool_"
        # typ_var = getattr(np, <dtype name>)
        np_typ_getattr = ir.Expr.getattr(g_np_var, dtype_str, loc)
        typ_var_assign = ir.Assign(np_typ_getattr, typ_var, loc)
    alloc_call = ir.Expr.call(attr_var, [size_var, typ_var], (), loc)

    # NOTE(review): this branch reads typemap[attr_var.name], which is only
    # written above when typemap is truthy -- a non-empty calltypes with an
    # empty typemap would raise KeyError here.
    if calltypes:
        cac = typemap[attr_var.name].get_call_type(
            typingctx, [size_typ, types.functions.NumberClass(dtype)], {}
        )
        # By default, all calls to "empty" are typed as returning a standard
        # NumPy ndarray. If we are allocating a ndarray subclass here then
        # just change the return type to be that of the subclass.
        cac._return_type = (
            lhs_typ.copy(layout="C") if lhs_typ.layout == "F" else lhs_typ
        )
        calltypes[alloc_call] = cac
    if lhs_typ.layout == "F":
        # Fortran order: allocate C-ordered into a temporary, then convert
        # with np.asfortranarray into lhs.
        empty_c_typ = lhs_typ.copy(layout="C")
        empty_c_var = ir.Var(scope, mk_unique_var("$empty_c_var"), loc)
        if typemap:
            typemap[empty_c_var.name] = lhs_typ.copy(layout="C")
        empty_c_assign = ir.Assign(alloc_call, empty_c_var, loc)

        # attr call: asfortranarray = getattr(g_np_var, asfortranarray)
        asfortranarray_attr_call = ir.Expr.getattr(
            g_np_var, "asfortranarray", loc
        )
        afa_attr_var = ir.Var(
            scope, mk_unique_var("$asfortran_array_attr"), loc
        )
        if typemap:
            typemap[afa_attr_var.name] = get_np_ufunc_typ(numpy.asfortranarray)
        afa_attr_assign = ir.Assign(asfortranarray_attr_call, afa_attr_var, loc)
        # call asfortranarray
        asfortranarray_call = ir.Expr.call(afa_attr_var, [empty_c_var], (), loc)
        if calltypes:
            calltypes[asfortranarray_call] = typemap[
                afa_attr_var.name
            ].get_call_type(typingctx, [empty_c_typ], {})

        asfortranarray_assign = ir.Assign(asfortranarray_call, lhs, loc)

        out.extend(
            [
                g_np_assign,
                attr_assign,
                typ_var_assign,
                empty_c_assign,
                afa_attr_assign,
                asfortranarray_assign,
            ]
        )
    else:
        alloc_assign = ir.Assign(alloc_call, lhs, loc)
        out.extend([g_np_assign, attr_assign, typ_var_assign, alloc_assign])

    return out
208
+
209
+
210
def convert_size_to_var(size_var, typemap, scope, loc, nodes):
    """Return an ``ir.Var`` holding the allocation size ``size_var``.

    A Python ``int`` is materialised into a fresh ``$alloc_size`` variable:
    its constant assignment is appended to ``nodes`` and, when ``typemap`` is
    truthy, the variable is typed as ``intp``. An ``ir.Var`` is returned
    unchanged; any other value fails the assertion.
    """
    if not isinstance(size_var, int):
        assert isinstance(size_var, ir.Var)
        return size_var
    size_holder = ir.Var(scope, mk_unique_var("$alloc_size"), loc)
    if typemap:
        typemap[size_holder.name] = types.intp
    nodes.append(ir.Assign(ir.Const(size_var, loc), size_holder, loc))
    return size_holder
220
+
221
+
222
def get_np_ufunc_typ(func):
    """get type of the incoming function from builtin registry

    The NumPy declaration registry is searched first, then the builtin
    registry; the type registered under the first key equal to ``func`` is
    returned.

    Raises:
        RuntimeError: if ``func`` is registered in neither registry.
    """
    for k, v in typing.npydecl.registry.globals:
        if k == func:
            return v
    for k, v in typing.templates.builtin_registry.globals:
        if k == func:
            return v
    # Build the message explicitly: passing several positional arguments to
    # RuntimeError stores them as ``args`` and str() renders them as a tuple.
    raise RuntimeError(f"type for func {func} not found")
231
+
232
+
233
def mk_range_block(typemap, start, stop, step, calltypes, scope, loc):
    """make a block that initializes loop range and iteration variables.
    target label in jump needs to be set.

    The block evaluates ``range(start, stop, step)`` (omitting default
    arguments, see _mk_range_args), takes its iterator, copies it into a
    fresh ``$phi`` variable and ends in ``ir.Jump(-1)``; the caller must patch
    that placeholder target. Unlike mk_alloc, ``typemap`` and ``calltypes``
    are written unconditionally, so both must be mutable mappings.
    """
    # g_range_var = Global(range)
    g_range_var = ir.Var(scope, mk_unique_var("$range_g_var"), loc)
    typemap[g_range_var.name] = get_global_func_typ(range)
    g_range = ir.Global("range", range, loc)
    g_range_assign = ir.Assign(g_range, g_range_var, loc)
    arg_nodes, args = _mk_range_args(typemap, start, stop, step, scope, loc)
    # range_call_var = call g_range_var(start, stop, step)
    range_call = ir.Expr.call(g_range_var, args, (), loc)
    # The call type is resolved against a freshly constructed typing.Context,
    # with every argument typed as intp.
    calltypes[range_call] = typemap[g_range_var.name].get_call_type(
        typing.Context(), [types.intp] * len(args), {}
    )
    # signature(types.range_state64_type, types.intp)
    range_call_var = ir.Var(scope, mk_unique_var("$range_c_var"), loc)
    typemap[range_call_var.name] = types.iterators.RangeType(types.intp)
    range_call_assign = ir.Assign(range_call, range_call_var, loc)
    # iter_var = getiter(range_call_var)
    iter_call = ir.Expr.getiter(range_call_var, loc)
    calltype_sig = signature(types.range_iter64_type, types.range_state64_type)
    calltypes[iter_call] = calltype_sig
    iter_var = ir.Var(scope, mk_unique_var("$iter_var"), loc)
    typemap[iter_var.name] = types.iterators.RangeIteratorType(types.intp)
    iter_call_assign = ir.Assign(iter_call, iter_var, loc)
    # $phi = iter_var
    phi_var = ir.Var(scope, mk_unique_var("$phi"), loc)
    typemap[phi_var.name] = types.iterators.RangeIteratorType(types.intp)
    phi_assign = ir.Assign(iter_var, phi_var, loc)
    # jump to header (-1 is a placeholder label the caller must replace)
    jump_header = ir.Jump(-1, loc)
    range_block = ir.Block(scope, loc)
    # Constant-argument assignments (if any) come first, then the range setup.
    range_block.body = arg_nodes + [
        g_range_assign,
        range_call_assign,
        iter_call_assign,
        phi_assign,
        jump_header,
    ]
    return range_block
274
+
275
+
276
def _mk_range_args(typemap, start, stop, step, scope, loc):
    """Build the argument list for a ``range`` call.

    Returns ``(nodes, args)``: ``args`` holds ``ir.Var`` objects for
    ``[stop]`` when ``start == 0 and step == 1``, for ``[start, stop]`` when
    only ``step == 1``, and for ``[start, stop, step]`` otherwise. Integer
    arguments are materialised into fresh variables whose constant
    assignments are returned in ``nodes`` (typed ``intp`` when ``typemap`` is
    truthy); ``ir.Var`` arguments are used directly.
    """
    nodes = []

    def as_var(value, prefix):
        # Reuse an existing variable, or spill an int constant into a new one.
        if isinstance(value, ir.Var):
            return value
        assert isinstance(value, int)
        holder = ir.Var(scope, mk_unique_var(prefix), loc)
        if typemap:
            typemap[holder.name] = types.intp
        nodes.append(ir.Assign(ir.Const(value, loc), holder, loc))
        return holder

    stop_var = as_var(stop, "$range_stop")
    if start == 0 and step == 1:
        return nodes, [stop_var]

    start_var = as_var(start, "$range_start")
    if step == 1:
        return nodes, [start_var, stop_var]

    step_var = as_var(step, "$range_step")
    return nodes, [start_var, stop_var, step_var]
313
+
314
+
315
def get_global_func_typ(func):
    """Return the type registered for the global ``func`` (e.g. ``range``)
    in the builtin typing registry.

    Raises:
        RuntimeError: if no registry key equals ``func``.
    """
    registry_entries = typing.templates.builtin_registry.globals
    for key, func_typ in registry_entries:
        if key == func:
            return func_typ
    raise RuntimeError("func type not found {}".format(func))
321
+
322
+
323
def mk_loop_header(typemap, phi_var, calltypes, scope, loc):
    """make a block that is a loop header updating iteration variables.
    target labels in branch need to be set.

    The block calls ``iternext`` on ``phi_var``, splits the resulting pair
    into the value (``pair_first``, typed ``intp``) and the "has more" flag
    (``pair_second``, typed ``boolean``), copies the value into a fresh
    ``$phi`` variable, and ends in ``ir.Branch(flag, -1, -1)``. The caller
    must patch both placeholder targets. ``typemap`` and ``calltypes`` are
    written unconditionally.
    """
    # iternext_var = iternext(phi_var)
    iternext_var = ir.Var(scope, mk_unique_var("$iternext_var"), loc)
    typemap[iternext_var.name] = types.containers.Pair(
        types.intp, types.boolean
    )
    iternext_call = ir.Expr.iternext(phi_var, loc)
    range_iter_type = types.range_iter64_type
    # The iternext call is typed for a 64-bit range iterator argument.
    calltypes[iternext_call] = signature(
        types.containers.Pair(types.intp, types.boolean), range_iter_type
    )
    iternext_assign = ir.Assign(iternext_call, iternext_var, loc)
    # pair_first_var = pair_first(iternext_var)
    pair_first_var = ir.Var(scope, mk_unique_var("$pair_first_var"), loc)
    typemap[pair_first_var.name] = types.intp
    pair_first_call = ir.Expr.pair_first(iternext_var, loc)
    pair_first_assign = ir.Assign(pair_first_call, pair_first_var, loc)
    # pair_second_var = pair_second(iternext_var)
    pair_second_var = ir.Var(scope, mk_unique_var("$pair_second_var"), loc)
    typemap[pair_second_var.name] = types.boolean
    pair_second_call = ir.Expr.pair_second(iternext_var, loc)
    pair_second_assign = ir.Assign(pair_second_call, pair_second_var, loc)
    # phi_b_var = pair_first_var
    phi_b_var = ir.Var(scope, mk_unique_var("$phi"), loc)
    typemap[phi_b_var.name] = types.intp
    phi_b_assign = ir.Assign(pair_first_var, phi_b_var, loc)
    # branch pair_second_var body_block out_block
    # (-1/-1 are placeholder labels the caller must replace)
    branch = ir.Branch(pair_second_var, -1, -1, loc)
    header_block = ir.Block(scope, loc)
    header_block.body = [
        iternext_assign,
        pair_first_assign,
        pair_second_assign,
        phi_b_assign,
        branch,
    ]
    return header_block
363
+
364
+
365
def legalize_names(varnames):
    """returns a dictionary for conversion of variable names to legal
    parameter names.

    Each name is rewritten by doubling every ``_`` and then turning ``$`` and
    ``.`` into ``_`` (e.g. ``"$x.1"`` -> ``"_x_1"``, ``"c_d"`` -> ``"c__d"``).
    Repeating an input name is allowed when the name contains ``_``, ``$`` or
    ``.``.

    Raises:
        AssertionError: if a produced name equals an original name seen
            earlier, or if two distinct input names would be mapped onto the
            same legal name (e.g. ``"a$b"`` and ``"a.b"`` both -> ``"a_b"``).
    """
    var_map = {}
    # Legal names handed out so far, for O(1) collision checks.
    used_names = set()
    for var in varnames:
        new_name = var.replace("_", "__").replace("$", "_").replace(".", "_")
        assert new_name not in var_map
        # The mapping must be injective: two variables sharing one parameter
        # name would be silently merged. A repeated input maps to the same
        # name as before, which is not a collision.
        assert new_name not in used_names or var_map.get(var) == new_name
        var_map[var] = new_name
        used_names.add(new_name)
    return var_map
375
+
376
+
377
def get_name_var_table(blocks):
    """Map every variable name found in ``blocks`` to an ``ir.Var`` for it.

    When a name is visited several times, the last ``ir.Var`` seen wins.
    """
    table = {}

    def record(var, name_to_var):
        # visit_vars callback: remember the variable and leave it unchanged.
        name_to_var[var.name] = var
        return var

    visit_vars(blocks, record, table)
    return table
387
+
388
+
389
def replace_var_names(blocks, namedict):
    """Rename variables in ``blocks`` in place using ``namedict``
    (old name -> new name). Renames are followed transitively, so
    ``a -> b`` plus ``b -> c`` turns ``a`` into ``c``.
    """
    # Identity entries would make the rename loop below spin forever.
    renames = {old: new for old, new in namedict.items() if old != new}

    def rename(var, mapping):
        assert isinstance(var, ir.Var)
        while var.name in mapping:
            var = ir.Var(var.scope, mapping[var.name], var.loc)
        return var

    visit_vars(blocks, rename, renames)
404
+
405
+
406
def replace_var_callback(var, vardict):
    """visit_vars callback that follows ``vardict`` (name -> ``ir.Var``).

    Starting at ``var``, replacements are chained until reaching a name with
    no entry. A fresh ``ir.Var`` copy of the final target is returned, or
    ``var`` itself when it has no entry. A self-mapping fails the assertion,
    since it would never terminate.
    """
    assert isinstance(var, ir.Var)
    while var.name in vardict:
        replacement = vardict[var.name]
        assert replacement.name != var.name
        var = ir.Var(replacement.scope, replacement.name, replacement.loc)
    return var
413
+
414
+
415
def replace_vars(blocks, vardict):
    """Replace variables throughout ``blocks`` in place, using ``vardict``
    (old name -> replacement ``ir.Var``).
    """
    # Drop entries mapping a name to itself: replace_var_callback asserts
    # against them because they would loop forever.
    effective = {
        name: new_var
        for name, new_var in vardict.items()
        if name != new_var.name
    }
    visit_vars(blocks, replace_var_callback, effective)
423
+
424
+
425
def replace_vars_stmt(stmt, vardict):
    """Apply replace_var_callback with ``vardict`` to one statement, in place.

    Unlike replace_vars, identity entries in ``vardict`` are not filtered
    out here.
    """
    visit_vars_stmt(stmt, callback=replace_var_callback, cbdata=vardict)
427
+
428
+
429
def replace_vars_inner(node, vardict):
    """Apply replace_var_callback with ``vardict`` to the variables in
    ``node`` and return the resulting node.
    """
    return visit_vars_inner(node, callback=replace_var_callback, cbdata=vardict)
431
+
432
+
433
# other packages that define new nodes add calls to visit variables in them
# format: {type:function}
visit_vars_extensions = {}


def visit_vars(blocks, callback, cbdata):
    """Run ``callback(var, cbdata)`` over the variables of every statement in
    every block of ``blocks`` (a label -> block mapping).

    For the statement kinds handled by visit_vars_stmt, the callback's return
    value is stored back into the statement. Returns None.
    """
    every_stmt = (stmt for block in blocks.values() for stmt in block.body)
    for stmt in every_stmt:
        visit_vars_stmt(stmt, callback, cbdata)
446
+
447
+
448
def visit_vars_stmt(stmt, callback, cbdata):
    """Apply ``callback(var, cbdata)`` to the variables of a single statement.

    Each field that can hold variables is passed through visit_vars_inner and
    the result is assigned back, so the statement is updated in place.

    If the statement is an instance of a type registered in
    ``visit_vars_extensions``, the first matching handler (in dict order) is
    called instead, and the built-in handling below is skipped. Statement
    types not listed here are left untouched.
    """
    # let external calls handle stmt if type matches
    for t, f in visit_vars_extensions.items():
        if isinstance(stmt, t):
            f(stmt, callback, cbdata)
            return
    if isinstance(stmt, ir.Assign):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.Arg):
        # NOTE(review): ir.Arg normally appears as an Assign *value*, and
        # stmt.name looks like a plain string (visit_vars_inner returns it
        # unchanged); this branch may be effectively a no-op -- confirm.
        stmt.name = visit_vars_inner(stmt.name, callback, cbdata)
    elif isinstance(stmt, ir.Return):
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.Raise):
        stmt.exception = visit_vars_inner(stmt.exception, callback, cbdata)
    elif isinstance(stmt, ir.Branch):
        stmt.cond = visit_vars_inner(stmt.cond, callback, cbdata)
    elif isinstance(stmt, ir.Jump):
        # A jump target is a block label rather than an ir.Var, so
        # visit_vars_inner hands it back unchanged.
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
    elif isinstance(stmt, ir.Del):
        # Because Del takes only a var name, we make up by
        # constructing a temporary variable.
        var = ir.Var(None, stmt.value, stmt.loc)
        var = visit_vars_inner(var, callback, cbdata)
        stmt.value = var.name
    elif isinstance(stmt, ir.DelAttr):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.attr = visit_vars_inner(stmt.attr, callback, cbdata)
    elif isinstance(stmt, ir.SetAttr):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.attr = visit_vars_inner(stmt.attr, callback, cbdata)
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.DelItem):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.index = visit_vars_inner(stmt.index, callback, cbdata)
    elif isinstance(stmt, ir.StaticSetItem):
        # Only index_var (not the static index itself) can be a variable.
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.index_var = visit_vars_inner(stmt.index_var, callback, cbdata)
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.SetItem):
        stmt.target = visit_vars_inner(stmt.target, callback, cbdata)
        stmt.index = visit_vars_inner(stmt.index, callback, cbdata)
        stmt.value = visit_vars_inner(stmt.value, callback, cbdata)
    elif isinstance(stmt, ir.Print):
        stmt.args = [visit_vars_inner(x, callback, cbdata) for x in stmt.args]
    else:
        # TODO: raise NotImplementedError("no replacement for IR node: ", stmt)
        pass
    return
497
+
498
+
499
def visit_vars_inner(node, callback, cbdata):
    """Apply ``callback(var, cbdata)`` to the variables inside ``node`` and
    return the resulting node.

    * ``ir.Var``: replaced by whatever the callback returns.
    * ``list`` / ``tuple``: rebuilt element-wise (as a plain list / tuple).
    * ``ir.Expr``: every keyword field in ``node._kws`` is updated in place.
    * ``ir.Yield``: its ``value`` is updated in place.

    Any other node is returned unchanged.
    """
    if isinstance(node, ir.Var):
        return callback(node, cbdata)
    if isinstance(node, list):
        return [visit_vars_inner(item, callback, cbdata) for item in node]
    if isinstance(node, tuple):
        return tuple(visit_vars_inner(item, callback, cbdata) for item in node)
    if isinstance(node, ir.Expr):
        fields = node._kws
        for key in fields:
            fields[key] = visit_vars_inner(fields[key], callback, cbdata)
    elif isinstance(node, ir.Yield):
        node.value = visit_vars_inner(node.value, callback, cbdata)
    return node
517
+
518
+
519
add_offset_to_labels_extensions = {}


def add_offset_to_labels(blocks, offset):
    """add an offset to all block labels and jump/branch targets

    The block objects are reused: their jump/branch terminators are replaced
    in place, and a new ``{label + offset: block}`` dict is returned.
    Statements whose type is registered in
    ``add_offset_to_labels_extensions`` are also passed, with ``offset``, to
    every matching handler.
    """
    shifted = {}
    for label, block in blocks.items():
        # Capture the terminator first; some parfor last blocks might be empty.
        last = block.body[-1] if block.body else None
        for inst in block.body:
            for inst_type, handler in add_offset_to_labels_extensions.items():
                if isinstance(inst, inst_type):
                    handler(inst, offset)
        if isinstance(last, ir.Jump):
            block.body[-1] = ir.Jump(last.target + offset, last.loc)
        if isinstance(last, ir.Branch):
            block.body[-1] = ir.Branch(
                last.cond,
                last.truebr + offset,
                last.falsebr + offset,
                last.loc,
            )
        shifted[label + offset] = block
    return shifted
542
+
543
+
544
find_max_label_extensions = {}


def find_max_label(blocks):
    """Return the largest block label in ``blocks``, never less than 0.

    A statement whose type is registered in ``find_max_label_extensions``
    also contributes whatever label its handler returns for it.
    """
    highest = 0
    for label, block in blocks.items():
        # A falsy (empty) body contributes no statements.
        for inst in block.body or ():
            for inst_type, label_finder in find_max_label_extensions.items():
                if not isinstance(inst, inst_type):
                    continue
                candidate = label_finder(inst)
                if candidate > highest:
                    highest = candidate
        if label > highest:
            highest = label
    return highest
560
+
561
+
562
def flatten_labels(blocks):
    """makes the labels in range(0, len(blocks)), useful to compare CFGs

    Blocks are renumbered 0, 1, 2, ... in topological order, and their
    jump/branch targets are rewritten to match. The block objects are
    modified in place and returned in a new dict.
    """
    # First bulk-move every label past the current maximum, out of the range
    # that is about to be reassigned.
    blocks = add_offset_to_labels(blocks, find_max_label(blocks) + 1)
    # Topological order makes the renumbered CFG easier to read.
    topo_order = find_topo_order(blocks)
    new_label = {old: idx for idx, old in enumerate(topo_order)}

    flattened = {}
    for old in topo_order:
        block = blocks[old]
        # Some parfor last blocks might be empty, i.e. have no terminator.
        last = block.body[-1] if block.body else None
        if isinstance(last, ir.Jump):
            block.body[-1] = ir.Jump(new_label[last.target], last.loc)
        if isinstance(last, ir.Branch):
            block.body[-1] = ir.Branch(
                last.cond,
                new_label[last.truebr],
                new_label[last.falsebr],
                last.loc,
            )
        flattened[new_label[old]] = block
    return flattened
589
+
590
+
591
def remove_dels(blocks):
    """Strip every ``ir.Del`` statement from each block of ``blocks``, in
    place. Returns None.
    """
    for block in blocks.values():
        block.body = [
            stmt for stmt in block.body if not isinstance(stmt, ir.Del)
        ]
600
+
601
+
602
def remove_args(blocks):
    """Strip every assignment whose value is an ``ir.Arg`` from each block of
    ``blocks``, in place. Returns None.
    """

    def is_arg_assign(stmt):
        return isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Arg)

    for block in blocks.values():
        block.body = [stmt for stmt in block.body if not is_arg_assign(stmt)]
612
+
613
+
614
def dead_code_elimination(
    func_ir, typemap=None, alias_map=None, arg_aliases=None
):
    """Performs dead code elimination and leaves the IR in a valid state on
    exit

    remove_dead is re-run on ``func_ir`` until a pass removes nothing. If any
    pass removed something, the post-processor is run once afterwards.
    """
    changed = False
    while True:
        removed_any = remove_dead(
            func_ir.blocks,
            func_ir.arg_names,
            func_ir,
            typemap,
            alias_map,
            arg_aliases,
        )
        if not removed_any:
            break
        changed = True

    if changed:
        postproc.PostProcessor(func_ir).run()
634
+
635
+
636
def remove_dead(
    blocks, args, func_ir, typemap=None, alias_map=None, arg_aliases=None
):
    """dead code elimination using liveness and CFG info.
    Returns True if something has been removed, or False if nothing is removed.

    The variables live at the end of each block are the ones its terminator
    uses plus the live-map entries of its CFG successors. If either
    ``alias_map`` or ``arg_aliases`` is None, both are recomputed with
    find_potential_aliases.
    """
    cfg = compute_cfg_from_blocks(blocks)
    usedefs = compute_use_defs(blocks)
    live_map = compute_live_map(cfg, blocks, usedefs.usemap, usedefs.defmap)
    call_table, _ = get_call_table(blocks)
    if alias_map is None or arg_aliases is None:
        alias_map, arg_aliases = find_potential_aliases(
            blocks, args, typemap, func_ir
        )
    if config.DEBUG_ARRAY_OPT >= 1:
        print("args:", args)
        print("alias map:", alias_map)
        print("arg_aliases:", arg_aliases)
        print("live_map:", live_map)
        print("usemap:", usedefs.usemap)
        print("defmap:", usedefs.defmap)
    # Set of aliased names, for fast membership tests in remove_dead_block.
    alias_set = set(alias_map.keys())

    any_removed = False
    for label, block in blocks.items():
        # Start from what the terminator reads ...
        lives = {var.name for var in block.terminator.list_vars()}
        if config.DEBUG_ARRAY_OPT >= 2:
            print("remove_dead processing block", label, lives)
        # ... then add everything live on entry to a successor block.
        for succ, _edge_data in cfg.successors(label):
            if config.DEBUG_ARRAY_OPT >= 2:
                print("succ live_map", succ, live_map[succ])
            lives |= live_map[succ]
        block_removed = remove_dead_block(
            block,
            lives,
            call_table,
            arg_aliases,
            alias_map,
            alias_set,
            func_ir,
            typemap,
        )
        any_removed |= block_removed

    return any_removed
683
+
684
+
685
# other packages that define new nodes add calls to remove dead code in them
# format: {type:function}
remove_dead_extensions = {}


def remove_dead_block(
    block,
    lives,
    call_table,
    arg_aliases,
    alias_map,
    alias_set,
    func_ir,
    typemap,
):
    """remove dead code using liveness info.
    Mutable arguments (e.g. arrays) that are not definitely assigned are live
    after return of function.

    ``lives`` holds the variables live at the end of ``block``. It is updated
    in place while the body is walked backwards from the terminator, and
    ``block.body`` is replaced by the surviving statements.

    Returns True if any statement was removed, else False.
    """
    # TODO: find mutable args that are not definitely assigned instead of
    # assuming all args are live after return
    removed = False

    # add statements in reverse order
    new_body = [block.terminator]
    # for each statement in reverse order, excluding terminator
    for stmt in reversed(block.body[:-1]):
        if config.DEBUG_ARRAY_OPT >= 2:
            print("remove_dead_block", stmt)
        # aliases of lives are also live
        alias_lives = set()
        init_alias_lives = lives & alias_set
        for v in init_alias_lives:
            alias_lives |= alias_map[v]
        lives_n_aliases = lives | alias_lives | arg_aliases

        # let external calls handle stmt if type matches
        if type(stmt) in remove_dead_extensions:
            f = remove_dead_extensions[type(stmt)]
            stmt = f(
                stmt,
                lives,
                lives_n_aliases,
                arg_aliases,
                alias_map,
                func_ir,
                typemap,
            )
            if stmt is None:
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                removed = True
                continue

        # ignore assignments that their lhs is not live or lhs==rhs
        if isinstance(stmt, ir.Assign):
            lhs = stmt.target
            rhs = stmt.value
            if lhs.name not in lives and has_no_side_effect(
                rhs, lives_n_aliases, call_table
            ):
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                removed = True
                continue
            if isinstance(rhs, ir.Var) and lhs.name == rhs.name:
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                removed = True
                continue
            # TODO: remove other nodes like SetItem etc.

        if isinstance(stmt, ir.Del):
            if stmt.value not in lives:
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                removed = True
                continue

        if isinstance(stmt, ir.SetItem):
            name = stmt.target.name
            if name not in lives_n_aliases:
                if config.DEBUG_ARRAY_OPT >= 2:
                    print("Statement was removed.")
                # Dropping a dead store is a removal too. Without this flag
                # the caller (and dead_code_elimination's fixed-point loop)
                # would not learn that the block changed.
                removed = True
                continue

        if type(stmt) in analysis.ir_extension_usedefs:
            def_func = analysis.ir_extension_usedefs[type(stmt)]
            uses, defs = def_func(stmt)
            lives -= defs
            lives |= uses
        else:
            lives |= {v.name for v in stmt.list_vars()}
            if isinstance(stmt, ir.Assign):
                # make sure lhs is not used in rhs, e.g. a = g(a)
                if isinstance(stmt.value, ir.Expr):
                    rhs_vars = {v.name for v in stmt.value.list_vars()}
                    if lhs.name not in rhs_vars:
                        lives.remove(lhs.name)
                else:
                    lives.remove(lhs.name)

        new_body.append(stmt)
    new_body.reverse()
    block.body = new_body
    return removed
791
+
792
+
793
# list of functions
remove_call_handlers = []


def remove_dead_random_call(rhs, lives, call_list):
    """Dead-call handler for ``numpy.random`` functions.

    Returns True (the call counts as side-effect free) for any three-element
    call path ``[<name>, "random", numpy]`` except ``seed`` and ``shuffle``
    (presumably excluded because they mutate state); returns False for every
    other call path.
    """
    is_np_random = len(call_list) == 3 and call_list[1:] == ["random", numpy]
    if not is_np_random:
        return False
    return call_list[0] not in {"seed", "shuffle"}


remove_call_handlers.append(remove_dead_random_call)
804
+
805
+
806
def has_no_side_effect(rhs, lives, call_table):
    """Returns True if this expression has no side effects that
    would prevent re-ordering.

    ``call_table`` maps the name of a called variable to a call-path list
    such as ``["empty", numpy]``. A call is side-effect free only if:

    * its call path is one of a fixed allow-list,
    * its callee is one of the ``empty_inferred`` intrinsics,
    * (with numba installed) its callee wraps ``dot_3_mv_check_args``, or
    * a ``remove_call_handlers`` entry accepts it.

    Calls missing from ``call_table`` are treated as having side effects. An
    ``inplace_binop`` is side-effect free only when its lhs is not in
    ``lives``. ``pair_first`` and ``ir.Yield`` are always kept.
    """
    from numba.cuda.extending import _Intrinsic

    if not isinstance(rhs, ir.Expr):
        # A Yield is never removable; any other non-expression value is.
        return not isinstance(rhs, ir.Yield)

    op = rhs.op
    if op == "call":
        callee = rhs.func.name
        if callee not in call_table or call_table[callee] == []:
            return False
        call_list = call_table[callee]
        safe_call_lists = (
            ["empty", numpy],
            [slice],
            ["log", numpy],
            ["dtype", numpy],
            ["ceil", math],
            [max],
            [int],
        )
        if any(call_list == safe for safe in safe_call_lists):
            return True
        head = call_list[0]
        if isinstance(head, _Intrinsic) and head._name in (
            "empty_inferred",
            "unsafe_empty_inferred",
        ):
            return True

        if HAS_NUMBA:
            from numba.core.registry import CPUDispatcher
            from numba.cuda.np.linalg import dot_3_mv_check_args

            if (
                isinstance(head, CPUDispatcher)
                and head.py_func == dot_3_mv_check_args
            ):
                return True

        return any(
            handler(rhs, lives, call_list) for handler in remove_call_handlers
        )

    if op == "inplace_binop":
        return rhs.lhs.name not in lives
    # don't remove pair_first since prange looks for it
    return op != "pair_first"
854
+
855
+
856
is_pure_extensions = []


def is_pure(rhs, lives, call_table):
    """Returns True if every time this expression is evaluated it
    returns the same result. This is not the case for things
    like calls to numpy.random.

    A call is pure only if its call path (looked up in ``call_table`` by the
    called variable's name) is on a fixed allow-list, or if an
    ``is_pure_extensions`` handler accepts it. ``getiter`` and ``iternext``
    expressions and ``ir.Yield`` are never pure; everything else is.
    """
    if not isinstance(rhs, ir.Expr):
        return not isinstance(rhs, ir.Yield)

    if rhs.op == "call":
        callee = rhs.func.name
        if callee not in call_table or call_table[callee] == []:
            return False
        call_list = call_table[callee]
        pure_call_lists = (
            [slice],
            ["log", numpy],
            ["empty", numpy],
            ["ceil", math],
            [max],
            [int],
        )
        if any(call_list == pure for pure in pure_call_lists):
            return True
        return any(check(rhs, lives, call_list) for check in is_pure_extensions)

    return rhs.op not in ("getiter", "iternext")
888
+
889
+
890
def is_const_call(module_name, func_name):
    """Return True if calling ``module_name.func_name`` changes no state in
    that module. Only ``numpy.empty`` is currently recognised.
    """
    known_const = module_name == "numpy" and func_name in ["empty"]
    return True if known_const else False
896
+
897
+
898
alias_analysis_extensions = {}
alias_func_extensions = {}


def get_canonical_alias(v, alias_map):
    """Return a stable representative for ``v``'s alias set.

    A name absent from ``alias_map`` is its own representative. Otherwise
    the smallest member (in sorted order) of ``alias_map[v]`` is returned.
    """
    if v in alias_map:
        return sorted(alias_map[v])[0]
    return v
908
+
909
+
910
def find_potential_aliases(
    blocks, args, typemap, func_ir, alias_map=None, arg_aliases=None
):
    """find all array aliases and argument aliases to avoid remove as dead

    Returns ``(alias_map, arg_aliases)``:

    * ``alias_map`` maps a variable name to the set of names it may alias.
    * ``arg_aliases`` is the set of names that may alias a function argument.
      It is seeded with the arguments that ``is_immutable_type`` rejects.

    Both may be passed in to be extended. Side effect: ``func_ir._definitions``
    is rebuilt from ``func_ir.blocks``.
    """
    if alias_map is None:
        alias_map = {}
    if arg_aliases is None:
        arg_aliases = set(a for a in args if not is_immutable_type(a, typemap))

    # update definitions since they are not guaranteed to be up-to-date
    # FIXME keep definitions up-to-date to avoid the need for rebuilding
    func_ir._definitions = build_definitions(func_ir.blocks)
    # NumPy functions whose result is treated as an alias of their input.
    np_alias_funcs = ["ravel", "transpose", "reshape"]

    for bl in blocks.values():
        for instr in bl.body:
            # Extension hook for custom statement types; note it runs in
            # addition to (not instead of) the Assign handling below.
            if type(instr) in alias_analysis_extensions:
                f = alias_analysis_extensions[type(instr)]
                f(instr, args, typemap, func_ir, alias_map, arg_aliases)
            if isinstance(instr, ir.Assign):
                expr = instr.value
                lhs = instr.target.name
                # only mutable types can alias
                if is_immutable_type(lhs, typemap):
                    continue
                # plain copy: a = b
                if isinstance(expr, ir.Var) and lhs != expr.name:
                    _add_alias(lhs, expr.name, alias_map, arg_aliases)
                # subarrays like A = B[0] for 2D B
                if isinstance(expr, ir.Expr) and (
                    expr.op == "cast"
                    or expr.op in ["getitem", "static_getitem"]
                ):
                    _add_alias(lhs, expr.value.name, alias_map, arg_aliases)
                # in-place ops: the result aliases the left operand
                if isinstance(expr, ir.Expr) and expr.op == "inplace_binop":
                    _add_alias(lhs, expr.lhs.name, alias_map, arg_aliases)
                # array attributes like A.T
                if (
                    isinstance(expr, ir.Expr)
                    and expr.op == "getattr"
                    and expr.attr in ["T", "ctypes", "flat"]
                ):
                    _add_alias(lhs, expr.value.name, alias_map, arg_aliases)
                # a = b.c.  a should alias b
                # (only when b already aliases an argument; arg_aliases grows
                # during this walk, so statement order can matter here)
                if (
                    isinstance(expr, ir.Expr)
                    and expr.op == "getattr"
                    and expr.attr not in ["shape"]
                    and expr.value.name in arg_aliases
                ):
                    _add_alias(lhs, expr.value.name, alias_map, arg_aliases)
                # calls that can create aliases such as B = A.ravel()
                if isinstance(expr, ir.Expr) and expr.op == "call":
                    fdef = guard(find_callname, func_ir, expr, typemap)
                    # TODO: sometimes gufunc backend creates duplicate code
                    # causing find_callname to fail. Example: test_argmax
                    # ignored here since those cases don't create aliases
                    # but should be fixed in general
                    if fdef is None:
                        continue
                    fname, fmod = fdef
                    if fdef in alias_func_extensions:
                        alias_func = alias_func_extensions[fdef]
                        alias_func(lhs, expr.args, alias_map, arg_aliases)
                    # np.ravel(A) and friends: result aliases first argument
                    if fmod == "numpy" and fname in np_alias_funcs:
                        _add_alias(
                            lhs, expr.args[0].name, alias_map, arg_aliases
                        )
                    # A.ravel() and friends: fmod is the receiver variable
                    if isinstance(fmod, ir.Var) and fname in np_alias_funcs:
                        _add_alias(lhs, fmod.name, alias_map, arg_aliases)

    # copy to avoid changing size during iteration
    old_alias_map = copy.deepcopy(alias_map)
    # combine all aliases transitively
    # NOTE(review): after this loop several keys share the *same* set object
    # (alias_map[w] = alias_map[v]), so later mutations of one entry affect
    # the others -- confirm callers rely on / tolerate that.
    for v in old_alias_map:
        for w in old_alias_map[v]:
            alias_map[v] |= alias_map[w]
        for w in old_alias_map[v]:
            alias_map[w] = alias_map[v]

    return alias_map, arg_aliases
990
+
991
+
992
+ def _add_alias(lhs, rhs, alias_map, arg_aliases):
993
+ if rhs in arg_aliases:
994
+ arg_aliases.add(lhs)
995
+ else:
996
+ if rhs not in alias_map:
997
+ alias_map[rhs] = set()
998
+ if lhs not in alias_map:
999
+ alias_map[lhs] = set()
1000
+ alias_map[rhs].add(lhs)
1001
+ alias_map[lhs].add(rhs)
1002
+ return
1003
+
1004
+
1005
def is_immutable_type(var, typemap):
    """Return True only when *var*'s recorded type is known to be immutable.

    Numbers, NumPy datetime-like scalars (``_NPDatetimeBase``), range types
    and strings count as immutable. A missing typemap, an untyped variable,
    or any other type is conservatively treated as mutable.
    """
    # TODO: add more immutable types
    if typemap is None or var not in typemap:
        return False
    var_type = typemap[var]
    immutable_classes = (
        types.Number,
        types.scalars._NPDatetimeBase,
        types.iterators.RangeType,
    )
    if isinstance(var_type, immutable_classes):
        return True
    return var_type == types.string
1024
+
1025
+
1026
def copy_propagate(blocks, typemap):
    """Compute copy propagation information for each block using fixed-point
    iteration on the data flow equations::

        in_b = intersect(out_p for p in predecessors(b))
        out_b = gen_b | (in_b - kill_b)

    A copy is a ``(lhs, rhs)`` pair of variable names.

    Returns:
        ``(in_copies, out_copies)``: dicts mapping each block label to the
        set of copies available on entry to, and on exit from, that block.

    A non-entry block with no CFG predecessors (e.g. an unreachable block)
    gets an empty ``in`` set, the conservative choice, instead of failing on
    an empty predecessor list.
    """
    cfg = compute_cfg_from_blocks(blocks)
    entry = cfg.entry_point()

    # format: dict of block labels to copies as tuples
    # label -> (l,r)
    c_data = init_copy_propagate_data(blocks, entry, typemap)
    (gen_copies, _all_copies, kill_copies, in_copies, out_copies) = c_data

    old_point = None
    new_point = copy.deepcopy(out_copies)
    # comparison works since dictionary of built-in types
    while old_point != new_point:
        for label in blocks.keys():
            if label == entry:
                continue
            predecs = [i for i, _d in cfg.predecessors(label)]
            if predecs:
                # in_b = intersect(predec(B))
                in_set = out_copies[predecs[0]].copy()
                for p in predecs[1:]:
                    in_set &= out_copies[p]
            else:
                # no predecessor: assume no copies are available
                in_set = set()
            in_copies[label] = in_set

            # out_b = gen_b | (in_b - kill_b)
            out_copies[label] = gen_copies[label] | (
                in_copies[label] - kill_copies[label]
            )
        old_point = new_point
        new_point = copy.deepcopy(out_copies)
    if config.DEBUG_ARRAY_OPT >= 1:
        print("copy propagate out_copies:", out_copies)
    return in_copies, out_copies
1062
+
1063
+
1064
def init_copy_propagate_data(blocks, entry, typemap):
    """Get the initial condition of the copy propagation data flow for each
    block.

    Returns:
        ``(gen_copies, all_copies, kill_copies, in_copies, out_copies)``:

        * ``gen_copies``: label -> set of ``(lhs, rhs)`` copies produced by
          the block (from ``get_block_copies``).
        * ``all_copies``: union of every block's ``gen`` set.
        * ``kill_copies``: label -> set of copies invalidated by the block.
        * ``in_copies``: label -> initial ``in`` set. This is every copy,
          except for the entry block, which starts empty.
        * ``out_copies``: label -> initial ``out`` set,
          ``gen | (in - kill)``. For the entry block this is its ``gen`` set
          itself.
    """
    # gen is all definite copies, extra_kill is additional ones that may hit
    # for example, parfors can have control flow so they may hit extra copies
    gen_copies, extra_kill = get_block_copies(blocks, typemap)
    # set of all program copies
    all_copies = set()
    for gen_set in gen_copies.values():
        all_copies |= gen_set
    kill_copies = {}
    for label, gen_set in gen_copies.items():
        block_kill = set()
        block_extra_kill = extra_kill[label]
        # names assigned by a copy in this block; loop invariant, so build once
        assigned = {lhs for lhs, _ in gen_set}
        for lhs, rhs in all_copies:
            if lhs in block_extra_kill or rhs in block_extra_kill:
                block_kill.add((lhs, rhs))
            # a copy is killed if it is not in this block and lhs or rhs are
            # assigned in this block
            if (lhs, rhs) not in gen_set and (
                lhs in assigned or rhs in assigned
            ):
                block_kill.add((lhs, rhs))
        kill_copies[label] = block_kill
    # set initial values
    # all copies are in for all blocks except entry
    in_copies = {l: all_copies.copy() for l in blocks.keys()}
    in_copies[entry] = set()
    out_copies = {}
    for label in blocks.keys():
        # out_b = gen_b | (in_b - kill_b)
        out_copies[label] = gen_copies[label] | (
            in_copies[label] - kill_copies[label]
        )
    out_copies[entry] = gen_copies[entry]
    return (gen_copies, all_copies, kill_copies, in_copies, out_copies)
1098
+
1099
+
1100
# Hook registry for copy propagation. Packages that define new IR node types
# register a function here, keyed by the node type. Each hook is called as
# f(stmt, typemap) and returns (gen_set, kill_set).
# format: {type: function}
copy_propagate_extensions = dict()
1103
+
1104
+
1105
def get_block_copies(blocks, typemap):
    """Get the copies generated, and the names killed, by each block.

    A "copy" is an assignment ``lhs = rhs`` between two distinct variables of
    the same type, recorded as the pair ``(lhs, rhs)``. Registered
    ``copy_propagate_extensions`` hooks may contribute further gen/kill
    information for their statement types.

    Returns:
        ``(block_copies, extra_kill)``:

        * ``block_copies`` maps each label to the set of ``(lhs, rhs)``
          copies still recorded at the end of the block.
        * ``extra_kill`` maps each label to the names whose copies must be
          treated as invalidated by the block. These are targets of non-copy
          assignments, mutable ``inplace_binop`` operands, extension kill
          sets, and copy targets whose source was killed.

    NOTE(review): an assignment to a name adds it to ``extra_kill`` (or
    overwrites its own ``assign_dict`` entry), but it does not drop existing
    ``assign_dict`` entries whose *source* is that name. This appears to
    assume each name is assigned once (e.g. SSA-form IR). Confirm.
    """
    block_copies = {}
    extra_kill = {}
    for label, block in blocks.items():
        # lhs name -> rhs name for copies currently valid in this block
        assign_dict = {}
        extra_kill[label] = set()
        # assignments as dict to replace with latest value
        for stmt in block.body:
            for T, f in copy_propagate_extensions.items():
                if isinstance(stmt, T):
                    gen_set, kill_set = f(stmt, typemap)
                    for lhs, rhs in gen_set:
                        assign_dict[lhs] = rhs
                    # if a=b is in dict and b is killed, a is also killed
                    new_assign_dict = {}
                    for l, r in assign_dict.items():
                        if l not in kill_set and r not in kill_set:
                            new_assign_dict[l] = r
                        if r in kill_set:
                            extra_kill[label].add(l)
                    assign_dict = new_assign_dict
                    extra_kill[label] |= kill_set
            if isinstance(stmt, ir.Assign):
                lhs = stmt.target.name
                if isinstance(stmt.value, ir.Var):
                    rhs = stmt.value.name
                    # copy is valid only if same type (see
                    # TestCFunc.test_locals)
                    # Some transformations can produce assignments of the
                    # form A = A.  We don't put these mapping in the
                    # copy propagation set because then you get cycles and
                    # infinite loops in the replacement phase.
                    if typemap[lhs] == typemap[rhs] and lhs != rhs:
                        assign_dict[lhs] = rhs
                        # a recorded copy does not kill lhs
                        continue
                if (
                    isinstance(stmt.value, ir.Expr)
                    and stmt.value.op == "inplace_binop"
                ):
                    in1_var = stmt.value.lhs.name
                    in1_typ = typemap[in1_var]
                    # inplace_binop assigns first operand if mutable
                    if not (
                        isinstance(in1_typ, types.Number)
                        or in1_typ == types.string
                    ):
                        extra_kill[label].add(in1_var)
                        # if a=b is in dict and b is killed, a is also killed
                        new_assign_dict = {}
                        for l, r in assign_dict.items():
                            if l != in1_var and r != in1_var:
                                new_assign_dict[l] = r
                            if r == in1_var:
                                extra_kill[label].add(l)
                        assign_dict = new_assign_dict
                # any other assignment (non-copy) kills its target
                extra_kill[label].add(lhs)
        block_cps = set(assign_dict.items())
        block_copies[label] = block_cps
    return block_copies, extra_kill
1165
+
1166
+
1167
# Hook registry for applying copy propagation. Packages that define new IR
# node types register a function here, keyed by the node type. Each hook is
# called as f(stmt, var_dict, name_var_table, typemap, calltypes, save_copies).
# format: {type: function}
apply_copy_propagate_extensions = dict()
1170
+
1171
+
1172
def apply_copy_propagate(
    blocks, in_copies, name_var_table, typemap, calltypes, save_copies=None
):
    """Apply copy propagation to the IR: replace variables when copies are
    available.

    For each block, ``in_copies[label]`` (pairs of names) seeds ``var_dict``,
    a mapping from a variable name to the ``ir.Var`` that may replace it
    (looked up in ``name_var_table``). Statements are rewritten in place:

    * statements of a type in ``apply_copy_propagate_extensions`` are handed
      to that hook;
    * for ``ir.Assign`` only the right-hand side is replaced;
    * every other statement goes through ``replace_vars_stmt``.

    ``var_dict`` is then updated with the copies each statement creates and
    purged of entries it invalidates. ``fix_setitem_type`` is applied to
    every statement.

    Returns:
        ``save_copies`` (a new list when None is passed), extended with the
        ``(name, ir.Var)`` pairs still in ``var_dict`` at the end of each
        block.
    """
    # save_copies keeps an approximation of the copies that were applied, so
    # that the variable names of removed user variables can be recovered to some
    # extent.
    if save_copies is None:
        save_copies = []

    for label, block in blocks.items():
        # name -> replacement ir.Var, for copies valid at block entry
        var_dict = {l: name_var_table[r] for l, r in in_copies[label]}
        # assignments as dict to replace with latest value
        for stmt in block.body:
            if type(stmt) in apply_copy_propagate_extensions:
                f = apply_copy_propagate_extensions[type(stmt)]
                f(
                    stmt,
                    var_dict,
                    name_var_table,
                    typemap,
                    calltypes,
                    save_copies,
                )
            # only rhs of assignments should be replaced
            # e.g. if x=y is available, x in x=z shouldn't be replaced
            elif isinstance(stmt, ir.Assign):
                stmt.value = replace_vars_inner(stmt.value, var_dict)
            else:
                replace_vars_stmt(stmt, var_dict)
            fix_setitem_type(stmt, typemap, calltypes)
            # extension statements may generate and kill copies
            for T, f in copy_propagate_extensions.items():
                if isinstance(stmt, T):
                    gen_set, kill_set = f(stmt, typemap)
                    for lhs, rhs in gen_set:
                        if rhs in name_var_table:
                            var_dict[lhs] = name_var_table[rhs]
                    for l, r in var_dict.copy().items():
                        if l in kill_set or r.name in kill_set:
                            var_dict.pop(l)
            if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Var):
                lhs = stmt.target.name
                rhs = stmt.value.name
                # rhs could be replaced with lhs from previous copies
                if lhs != rhs:
                    # copy is valid only if same type (see
                    # TestCFunc.test_locals)
                    if typemap[lhs] == typemap[rhs] and rhs in name_var_table:
                        var_dict[lhs] = name_var_table[rhs]
                    else:
                        var_dict.pop(lhs, None)
                    # a=b kills previous t=a
                    lhs_kill = []
                    for k, v in var_dict.items():
                        if v.name == lhs:
                            lhs_kill.append(k)
                    for k in lhs_kill:
                        var_dict.pop(k, None)
            if isinstance(stmt, ir.Assign) and not isinstance(
                stmt.value, ir.Var
            ):
                # a non-copy assignment kills any copy into lhs
                lhs = stmt.target.name
                var_dict.pop(lhs, None)
                # previous t=a is killed if a is killed
                lhs_kill = []
                for k, v in var_dict.items():
                    if v.name == lhs:
                        lhs_kill.append(k)
                for k in lhs_kill:
                    var_dict.pop(k, None)
        save_copies.extend(var_dict.items())

    return save_copies
1245
+
1246
+
1247
def fix_setitem_type(stmt, typemap, calltypes):
    """Refresh a setitem call signature after copy propagation.

    Copy propagation can replace the target variable of a ``SetItem`` or
    ``StaticSetItem`` with a variable whose array layout is 'C' or 'F', while
    the recorded signature still says 'A' (seen in the matrix power test).
    In that case, the first argument type of ``calltypes[stmt]`` is replaced
    with a copy that uses the target's layout. Any other statement, or any
    non-array target or signature type, is left untouched.
    """
    if not isinstance(stmt, (ir.SetItem, ir.StaticSetItem)):
        return
    target_type = typemap[stmt.target.name]
    sig = calltypes[stmt]
    sig_array_type = sig.args[0]
    # test_optional: the target type can be an Optional wrapping an array
    array_cls = types.npytypes.Array
    if not (
        isinstance(sig_array_type, array_cls)
        and isinstance(target_type, array_cls)
    ):
        return
    if sig_array_type.layout != "A" or target_type.layout == "A":
        return
    sig.args = (
        sig_array_type.copy(layout=target_type.layout),
        sig.args[1],
        sig.args[2],
    )
1269
+
1270
+
1271
def dprint_func_ir(func_ir, title, blocks=None):
    """Debug-print function IR when ``config.DEBUG_ARRAY_OPT >= 1``.

    If *blocks* is given, it is temporarily installed as ``func_ir.blocks``
    so that blocks differing from the IR's own can be dumped under the IR's
    name. The original blocks are restored afterwards, even if dumping
    raises.
    """
    if config.DEBUG_ARRAY_OPT >= 1:
        ir_blocks = func_ir.blocks
        func_ir.blocks = ir_blocks if blocks is None else blocks
        try:
            name = func_ir.func_id.func_qualname
            print(("IR %s: %s" % (title, name)).center(80, "-"))
            func_ir.dump()
            print("-" * 40)
        finally:
            # never leave the temporary blocks installed on func_ir
            func_ir.blocks = ir_blocks
1283
+
1284
+
1285
def find_topo_order(blocks, cfg=None):
    """find topological order of blocks such that true branches are visited
    first (e.g. for_break test in test_dataflow). This is written as an iterative
    implementation of post order traversal to avoid recursion limit issues.

    Args:
        blocks: mapping of label -> block. Each block's last statement is
            inspected to order the successors of an ``ir.Branch``.
        cfg: optional precomputed CFG for ``blocks``. It is computed with
            ``compute_cfg_from_blocks`` when omitted.

    Returns:
        A list of labels: the reversed post-order of a depth-first walk from
        ``cfg.entry_point()`` that ignores the CFG's back edges. Labels not
        reachable that way are not included.
    """
    if cfg is None:
        cfg = compute_cfg_from_blocks(blocks)

    post_order = []
    # Has the node already added its children?
    seen = set()
    # Has the node already been pushed to post order?
    visited = set()
    stack = [cfg.entry_point()]

    while len(stack) > 0:
        node = stack[-1]
        if node not in visited and node not in seen:
            # We haven't added a node or its children.
            seen.add(node)
            succs = cfg._succs[node]
            last_inst = blocks[node].body[-1]
            # for branches, use the branch's own (true, false) target order
            # rather than the CFG's successor order
            if isinstance(last_inst, ir.Branch):
                succs = [last_inst.truebr, last_inst.falsebr]
            for dest in succs:
                # back edges are skipped so loops do not revisit headers
                if (node, dest) not in cfg._back_edges:
                    if dest not in seen:
                        stack.append(dest)
        else:
            # This node has already added its children. We either need
            # to visit the node or it has been added multiple times in
            # which case we should just skip the node.
            node = stack.pop()
            if node not in visited:
                post_order.append(node)
                visited.add(node)
            if node in seen:
                # Remove the node from seen if it exists to limit the memory
                # usage to 1 entry per node. Otherwise the memory requirement
                # can double the recursive version.
                seen.remove(node)

    post_order.reverse()
    return post_order
1329
+
1330
+
1331
# Hook registry for get_call_table. Packages that define new IR node types
# register a function here, keyed by the node type. Each hook is called as
# f(inst, call_table, reverse_call_table).
# format: {type: function}
call_table_extensions = dict()
1334
+
1335
+
1336
def get_call_table(
    blocks, call_table=None, reverse_call_table=None, topological_ordering=True
):
    """Return ``(call_table, reverse_call_table)`` describing call targets.

    ``call_table`` maps the name of each variable used as a call's callee to
    the list of attribute names, global/free-variable values and variable
    names met while walking back along its definition chain. For example,
    ``c = np.zeros`` makes ``c: ["zeros", np]``. ``reverse_call_table`` maps
    an intermediate variable on such a chain back to the callee variable it
    feeds, e.g. ``np_var: c``.

    Blocks, and the statements inside them, are visited in reverse. The block
    order is the topological order unless *topological_ordering* is False,
    in which case the dict order of *blocks* is used. Statement types
    registered in ``call_table_extensions`` are passed to their hook. Tables
    that are passed in are updated in place.
    """
    if call_table is None:
        call_table = {}
    if reverse_call_table is None:
        reverse_call_table = {}

    labels = (
        find_topo_order(blocks)
        if topological_ordering
        else list(blocks.keys())
    )

    for label in reversed(labels):
        for inst in reversed(blocks[label].body):
            if isinstance(inst, ir.Assign):
                target = inst.target.name
                value = inst.value
                if isinstance(value, ir.Expr):
                    if value.op == "call":
                        # start a fresh chain for this callee variable
                        call_table[value.func.name] = []
                    elif value.op == "getattr":
                        if target in call_table:
                            call_table[target].append(value.attr)
                            reverse_call_table[value.value.name] = target
                        if target in reverse_call_table:
                            owner = reverse_call_table[target]
                            call_table[owner].append(value.attr)
                            reverse_call_table[value.value.name] = owner
                elif isinstance(value, (ir.Global, ir.FreeVar)):
                    if target in call_table:
                        call_table[target].append(value.value)
                    if target in reverse_call_table:
                        owner = reverse_call_table[target]
                        call_table[owner].append(value.value)
                elif isinstance(value, ir.Var):
                    if target in call_table:
                        call_table[target].append(value.name)
                        reverse_call_table[value.name] = target
                    if target in reverse_call_table:
                        owner = reverse_call_table[target]
                        call_table[owner].append(value.name)
            for node_type, hook in call_table_extensions.items():
                if isinstance(inst, node_type):
                    hook(inst, call_table, reverse_call_table)
    return call_table, reverse_call_table
1390
+
1391
+
1392
# Hook registry for get_tuple_table. Packages that define new IR node types
# register a function here, keyed by the node type. Each hook is called as
# f(inst, tuple_table).
# format: {type: function}
tuple_table_extensions = dict()
1395
+
1396
+
1397
def get_tuple_table(blocks, tuple_table=None):
    """Map tuple-valued variable names to their element values.

    A ``build_tuple`` expression contributes its ``items``, and a constant
    tuple contributes its ``value``. Statement types registered in
    ``tuple_table_extensions`` may add further entries. If *tuple_table* is
    given, it is updated in place and returned.
    """
    table = {} if tuple_table is None else tuple_table

    for block in blocks.values():
        for inst in block.body:
            if isinstance(inst, ir.Assign):
                target = inst.target.name
                value = inst.value
                if isinstance(value, ir.Expr) and value.op == "build_tuple":
                    table[target] = value.items
                if isinstance(value, ir.Const) and isinstance(
                    value.value, tuple
                ):
                    table[target] = value.value
            for node_type, hook in tuple_table_extensions.items():
                if isinstance(inst, node_type):
                    hook(inst, table)
    return table
1415
+
1416
+
1417
def get_stmt_writes(stmt):
    """Return the set of variable names written by *stmt*.

    ``Assign``, ``SetItem`` and ``StaticSetItem`` write their target's name.
    Any other statement yields an empty set.
    """
    if not isinstance(stmt, (ir.Assign, ir.SetItem, ir.StaticSetItem)):
        return set()
    return {stmt.target.name}
1422
+
1423
+
1424
def rename_labels(blocks):
    """Rename labels of function body blocks according to topological sort.

    The set of labels stays the same: the i-th block in topological order
    receives the i-th smallest existing label. The last block found that
    ends in ``Return`` is moved to the end of the order for readability.
    Jump/Branch terminators are replaced with new nodes that target the
    renamed labels, and a new dict keyed by the new labels is returned.
    """
    order = find_topo_order(blocks)

    # make a block with return last if available (just for readability)
    ret_label = -1
    for lbl, blk in blocks.items():
        if isinstance(blk.body[-1], ir.Return):
            ret_label = lbl
    # some cases like generators can have no return blocks
    if ret_label != -1:
        order.remove(ret_label)
        order.append(ret_label)

    # pair each label, in order, with the labels sorted ascending
    renamed = dict(zip(order, sorted(order)))

    # update target labels in jumps/branches
    for blk in blocks.values():
        term = blk.terminator
        # create new IR nodes instead of mutating the existing one as copies of
        # the IR may also refer to the same nodes!
        if isinstance(term, ir.Jump):
            blk.body[-1] = ir.Jump(renamed[term.target], term.loc)
        elif isinstance(term, ir.Branch):
            blk.body[-1] = ir.Branch(
                term.cond,
                renamed[term.truebr],
                renamed[term.falsebr],
                term.loc,
            )

    # rebuild the blocks dictionary under the new keys
    return {renamed[lbl]: blk for lbl, blk in blocks.items()}
1466
+
1467
+
1468
def simplify_CFG(blocks):
    """Transform chains of blocks that have no loop into a single block.

    First, every block whose body is a single ``ir.Branch`` is inlined into
    those predecessors that end in a ``Jump`` (each gets its own shallow copy
    of the branch). The block is deleted only if every predecessor could
    take the branch. Adjacent blocks are then merged, and the result is
    relabelled with ``rename_labels``.
    """
    cfg = compute_cfg_from_blocks(blocks)

    def _is_lone_branch(label):
        body = blocks[label].body
        return len(body) == 1 and isinstance(body[0], ir.Branch)

    lone_branch_labels = [
        lbl for lbl in blocks.keys() if _is_lone_branch(lbl)
    ]
    to_delete = set()
    for lbl in lone_branch_labels:
        branch = blocks[lbl].body[0]
        removable = True
        for pred_label, _ in cfg.predecessors(lbl):
            pred_body = blocks[pred_label].body
            if isinstance(pred_body[-1], ir.Jump):
                pred_body[-1] = copy.copy(branch)
            else:
                # this predecessor still needs the block
                removable = False
        if removable:
            to_delete.add(lbl)

    for lbl in to_delete:
        del blocks[lbl]
    merge_adjacent_blocks(blocks)
    return rename_labels(blocks)
1496
+
1497
+
1498
# Array method names that canonicalize_array_math rewrites from
# ``A.func(...)`` into ``numpy.func(A, ...)``.
arr_math = (
    "min max sum prod mean var std cumsum cumprod "
    "argmax argmin argsort nonzero ravel"
).split()
1514
+
1515
+
1516
def canonicalize_array_math(func_ir, typemap, calltypes, typingctx):
    """Rewrite array method calls ``A.func(...)`` into ``numpy.func(A, ...)``.

    For each ``getattr`` of a name in ``arr_math`` on an array-typed value:

    * a new ``$np_g_var`` variable bound to the ``numpy`` global is assigned
      just before the statement;
    * the getattr is retargeted to it;
    * the attribute variable's type is replaced with
      ``get_np_ufunc_typ(numpy.<attr>)``.

    A ``call`` through that attribute variable, visited afterwards, gets the
    original array prepended to its arguments, and its ``calltypes`` entry
    is recomputed. Blocks are visited in ``find_topo_order`` order and
    rewritten in place. ``typemap``, ``calltypes`` and
    ``func_ir._definitions`` are updated in place.
    """
    # save array arg to call
    # call_varname -> array
    blocks = func_ir.blocks
    saved_arr_arg = {}
    topo_order = find_topo_order(blocks)
    for label in topo_order:
        block = blocks[label]
        new_body = []
        for stmt in block.body:
            if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr):
                lhs = stmt.target.name
                rhs = stmt.value
                # replace A.func with np.func, and save A in saved_arr_arg
                if (
                    rhs.op == "getattr"
                    and rhs.attr in arr_math
                    and isinstance(
                        typemap[rhs.value.name], types.npytypes.Array
                    )
                ):
                    rhs = stmt.value
                    arr = rhs.value
                    saved_arr_arg[lhs] = arr
                    scope = arr.scope
                    loc = arr.loc
                    # g_np_var = Global(numpy)
                    g_np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
                    typemap[g_np_var.name] = types.misc.Module(numpy)
                    g_np = ir.Global("np", numpy, loc)
                    g_np_assign = ir.Assign(g_np, g_np_var, loc)
                    rhs.value = g_np_var
                    # the numpy global must be defined before this getattr
                    new_body.append(g_np_assign)
                    func_ir._definitions[g_np_var.name] = [g_np]
                    # update func var type
                    func = getattr(numpy, rhs.attr)
                    func_typ = get_np_ufunc_typ(func)
                    typemap.pop(lhs)
                    typemap[lhs] = func_typ
                if rhs.op == "call" and rhs.func.name in saved_arr_arg:
                    # add array as first arg
                    arr = saved_arr_arg[rhs.func.name]
                    # update call type signature to include array arg
                    old_sig = calltypes.pop(rhs)
                    # argsort requires kws for typing so sig.args can't be used
                    # reusing sig.args since some types become Const in sig
                    argtyps = old_sig.args[: len(rhs.args)]
                    kwtyps = {name: typemap[v.name] for name, v in rhs.kws}
                    calltypes[rhs] = typemap[rhs.func.name].get_call_type(
                        typingctx, [typemap[arr.name]] + list(argtyps), kwtyps
                    )
                    rhs.args = [arr] + rhs.args

            new_body.append(stmt)
        block.body = new_body
    return
1572
+
1573
+
1574
# Hook registry for get_array_accesses, keyed by IR node type. Each hook is
# called as f(inst, accesses).
# format: {type: function}
array_accesses_extensions = dict()
1576
+
1577
+
1578
def get_array_accesses(blocks, accesses=None):
    """Collect ``(array_name, index)`` pairs for array accesses in *blocks*.

    * ``SetItem`` and ``getitem`` record the index variable's name.
    * ``StaticSetItem`` records ``index_var.name``.
    * ``static_getitem`` records its constant index, unless that index is
      ``None`` or contains a slice (slices are unhashable). In that case the
      index variable's name is used instead.

    Statement types registered in ``array_accesses_extensions`` may add more
    entries. If *accesses* is given, it is updated in place and returned.
    """
    found = set() if accesses is None else accesses

    for block in blocks.values():
        for inst in block.body:
            if isinstance(inst, ir.SetItem):
                found.add((inst.target.name, inst.index.name))
            if isinstance(inst, ir.StaticSetItem):
                found.add((inst.target.name, inst.index_var.name))
            if isinstance(inst, ir.Assign):
                value = inst.value
                if isinstance(value, ir.Expr) and value.op == "getitem":
                    found.add((value.value.name, value.index.name))
                elif isinstance(value, ir.Expr) and value.op == "static_getitem":
                    idx = value.index
                    if idx is None or is_slice_index(idx):
                        idx = value.index_var.name
                    found.add((value.value.name, idx))
            for node_type, hook in array_accesses_extensions.items():
                if isinstance(inst, node_type):
                    hook(inst, found)
    return found
1603
+
1604
+
1605
def is_slice_index(index):
    """Return True if *index* is a slice, or a tuple with a slice element."""
    if isinstance(index, tuple):
        return any(isinstance(part, slice) for part in index)
    return isinstance(index, slice)
1614
+
1615
+
1616
def merge_adjacent_blocks(blocks):
    """Merge straight-line chains of blocks in *blocks*, in place.

    A successor is folded into the current chain when the chain's tail has
    exactly one CFG successor and that successor has exactly one predecessor
    (the tail). The head block's last statement (its terminator) is dropped,
    the successor's body is appended, and the successor is deleted from
    *blocks*. The chain is followed until one of these conditions fails, or
    until it would lead back to its own head block.

    The CFG is computed once up front, so successor/predecessor queries
    refer to the original graph.
    """
    cfg = compute_cfg_from_blocks(blocks)
    # merge adjacent blocks
    removed = set()
    for label in list(blocks.keys()):
        if label in removed:
            continue
        head_label = label
        block = blocks[label]
        succs = list(cfg.successors(label))
        while True:
            if len(succs) != 1:
                break
            next_label = succs[0][0]
            # Never fold the chain back into its own head (e.g. a self-loop,
            # or a cycle entered only through the head). That would append
            # the head's body to itself and then delete the head block.
            if next_label == head_label or next_label in removed:
                break
            preds = list(cfg.predecessors(next_label))
            succs = list(cfg.successors(next_label))
            if len(preds) != 1 or preds[0][0] != label:
                break
            next_block = blocks[next_label]
            # XXX: commented out since scope objects are not consistent
            # throughout the compiler. for example, pieces of code are compiled
            # and inlined on the fly without proper scope merge.
            # if block.scope != next_block.scope:
            #     break
            # merge
            block.body.pop()  # remove Jump
            block.body += next_block.body
            del blocks[next_label]
            removed.add(next_label)
            # the merged block is now the tail of the chain
            label = next_label
1647
+
1648
+
1649
def restore_copy_var_names(blocks, save_copies, typemap):
    """
    Restore the names of user variables after copy propagation.

    *save_copies* holds ``(user_name, ir.Var)`` pairs. When a user variable
    (no ``$`` prefix) was replaced by a compiler temporary (``$`` prefix),
    that temporary is renamed to a fresh ``$<user_name>``-based unique name
    the first time it is seen. Its ``typemap`` entry moves to the new name.
    All renames are applied to *blocks* in place.

    Returns a dict mapping each new name to the user name it stands for
    (empty if *save_copies* is empty).
    """
    if not save_copies:
        return {}

    temp_to_new = {}
    new_to_user = {}
    for user_name, temp_var in save_copies:
        if user_name.startswith("$"):
            continue
        temp_name = temp_var.name
        if not temp_name.startswith("$") or temp_name in temp_to_new:
            continue
        fresh_name = mk_unique_var("${}".format(user_name))
        temp_to_new[temp_name] = fresh_name
        new_to_user[fresh_name] = user_name
        typemap[fresh_name] = typemap.pop(temp_name)

    replace_var_names(blocks, temp_to_new)
    return new_to_user
1675
+
1676
+
1677
def simplify(func_ir, typemap, calltypes, metadata):
    """Run copy propagation, dead-code removal and CFG simplification.

    Copy propagation is applied to ``func_ir.blocks``. User variable names
    lost to it are restored where possible, and the resulting rename map is
    merged into ``metadata["var_rename_map"]``. Dead code is then removed,
    and ``func_ir.blocks`` is replaced by the simplified CFG.
    """
    # copies flowing into each block
    in_copies, _ = copy_propagate(func_ir.blocks, typemap)
    # table mapping variable names to ir.Var objects to help replacement
    vars_by_name = get_name_var_table(func_ir.blocks)
    applied_copies = apply_copy_propagate(
        func_ir.blocks, in_copies, vars_by_name, typemap, calltypes
    )
    renames = restore_copy_var_names(func_ir.blocks, applied_copies, typemap)
    metadata.setdefault("var_rename_map", {}).update(renames)
    # remove dead code to enable fusion
    if config.DEBUG_ARRAY_OPT >= 1:
        dprint_func_ir(func_ir, "after copy prop")
    remove_dead(func_ir.blocks, func_ir.arg_names, func_ir, typemap)
    func_ir.blocks = simplify_CFG(func_ir.blocks)
    if config.DEBUG_ARRAY_OPT >= 1:
        dprint_func_ir(func_ir, "after simplify")
1698
+
1699
+
1700
class GuardException(Exception):
    """Raised by ``require`` when a precondition fails; absorbed by ``guard``."""
1702
+
1703
+
1704
def require(cond):
    """
    Raise GuardException unless the given condition is truthy.
    """
    if cond:
        return
    raise GuardException
1710
+
1711
+
1712
def guard(func, *args, **kwargs):
    """
    Call ``func(*args, **kwargs)`` and return its result.

    A GuardException raised by the call is swallowed, and None is returned
    instead. Any other exception propagates unchanged.
    """
    try:
        result = func(*args, **kwargs)
    except GuardException:
        result = None
    return result
1722
+
1723
+
1724
def get_definition(func_ir, name, **kwargs):
    """
    Return ``func_ir.get_definition(name, **kwargs)``, converting a KeyError
    raised by the lookup into GuardException.
    """
    try:
        definition = func_ir.get_definition(name, **kwargs)
    except KeyError:
        raise GuardException
    return definition
1733
+
1734
+
1735
def build_definitions(blocks, definitions=None):
    """Build the definitions table of the given blocks.

    Every ``ir.Assign`` appends its value to the list stored under its
    target's name. Statement types registered in ``build_defs_extensions``
    are passed to their hook. This is useful when the definitions table is
    out of sync. The table passed in is updated in place; if none is passed,
    a new ``collections.defaultdict(list)`` is created and returned.
    """
    table = (
        collections.defaultdict(list) if definitions is None else definitions
    )

    for block in blocks.values():
        for inst in block.body:
            if isinstance(inst, ir.Assign):
                table.setdefault(inst.target.name, []).append(inst.value)
            inst_type = type(inst)
            if inst_type in build_defs_extensions:
                build_defs_extensions[inst_type](inst, table)

    return table
1757
+
1758
+
1759
# Hook registry for build_definitions: statement type -> function called as
# f(inst, definitions).
build_defs_extensions = dict()
1760
+
1761
+
1762
def find_callname(
    func_ir, expr, typemap=None, definition_finder=get_definition
):
    """Try to find a call expression's function and module names and return
    them as strings for unbounded calls. If the call is a bounded call, return
    the self object instead of module name. Raise GuardException if failed.

    Providing typemap can make the call matching more accurate in corner cases
    such as bounded call on an object which is inside another object.

    Returns one of:
        * ``(func_name, module_path)``: both strings, e.g.
          ``("empty", "numpy")``. Used when the callee resolves, possibly
          through attribute lookups, to an ``ir.Global`` or ``ir.FreeVar``.
        * ``(attr, obj)``: used when ``typemap`` shows that the getattr base
          ``obj`` (an ``ir.Var``) is not a module.
        * ``(dotted_attrs, obj)``: used when the definition chain ends at
          something that is neither a global/free variable nor a getattr,
          after at least one getattr.

    ``definition_finder(func_ir, var)`` is used to follow definitions.
    """
    from numba.cuda.extending import _Intrinsic

    require(isinstance(expr, ir.Expr) and expr.op == "call")
    callee = expr.func
    callee_def = definition_finder(func_ir, callee)
    # attribute/function/module names collected innermost-first
    attrs = []
    obj = None
    while True:
        if isinstance(callee_def, (ir.Global, ir.FreeVar)):
            # require(callee_def.value == numpy)
            # these checks support modules like numpy, numpy.random as well as
            # calls like len() and intrinsics like assertEquiv
            keys = ["name", "_name", "__name__"]
            value = None
            for key in keys:
                if hasattr(callee_def.value, key):
                    value = getattr(callee_def.value, key)
                    break
            if not value or not isinstance(value, str):
                raise GuardException
            attrs.append(value)
            def_val = callee_def.value
            # get the underlying definition of Intrinsic object to be able to
            # find the module effectively.
            # Otherwise, it will return numba.cuda.extending
            if isinstance(def_val, _Intrinsic):
                def_val = def_val._defn
            if hasattr(def_val, "__module__"):
                mod_name = def_val.__module__
                # The reason for first checking if the function is in NumPy's
                # top level name space by module is that some functions are
                # deprecated in NumPy but the functions' names are aliased with
                # other common names. This prevents deprecation warnings on
                # e.g. getattr(numpy, 'bool') were a bool the target.
                # For context see #6175, impacts NumPy>=1.20.
                mod_not_none = mod_name is not None
                numpy_toplevel = mod_not_none and (
                    mod_name == "numpy" or mod_name.startswith("numpy.")
                )
                # it might be a numpy function imported directly
                if (
                    numpy_toplevel
                    and hasattr(numpy, value)
                    and def_val == getattr(numpy, value)
                ):
                    attrs += ["numpy"]
                # it might be a np.random function imported directly
                elif hasattr(numpy.random, value) and def_val == getattr(
                    numpy.random, value
                ):
                    attrs += ["random", "numpy"]
                elif mod_not_none:
                    attrs.append(mod_name)
            else:
                # no __module__: fall back to the value's class name, except
                # for modules (their name is already in attrs)
                class_name = def_val.__class__.__name__
                if class_name == "builtin_function_or_method":
                    class_name = "builtin"
                if class_name != "module":
                    attrs.append(class_name)
            break
        elif isinstance(callee_def, ir.Expr) and callee_def.op == "getattr":
            obj = callee_def.value
            attrs.append(callee_def.attr)
            # a getattr on a non-module object is a bound call on that object
            if typemap and obj.name in typemap:
                typ = typemap[obj.name]
                if not isinstance(typ, types.Module):
                    return attrs[0], obj
            callee_def = definition_finder(func_ir, obj)
        else:
            # obj.func calls where obj is not np array
            if obj is not None:
                return ".".join(reversed(attrs)), obj
            raise GuardException
    return attrs[0], ".".join(reversed(attrs[1:]))
1846
+
1847
+
1848
def find_build_sequence(func_ir, var):
    """Return ``(items, op)`` for a variable built with ``build_tuple``,
    ``build_list`` or ``build_set``. Raise GuardException if *var* is not an
    ``ir.Var`` or is defined any other way.

    Note: only build_tuple is immutable, so use with care.
    """
    require(isinstance(var, ir.Var))
    definition = get_definition(func_ir, var)
    sequence_ops = ("build_tuple", "build_list", "build_set")
    require(isinstance(definition, ir.Expr) and definition.op in sequence_ops)
    return definition.items, definition.op
1860
+
1861
+
1862
def find_const(func_ir, var):
    """Return the constant value bound to *var*.

    *var* must be an ``ir.Var`` whose definition is an ``ir.Const``,
    ``ir.Global`` or ``ir.FreeVar``; otherwise a GuardException is raised
    (via ``require``).
    """
    require(isinstance(var, ir.Var))
    definition = get_definition(func_ir, var)
    constant_nodes = (ir.Const, ir.Global, ir.FreeVar)
    require(isinstance(definition, constant_nodes))
    return definition.value
1870
+
1871
+
1872
def compile_to_numba_ir(
    mk_func,
    glbls,
    typingctx=None,
    targetctx=None,
    arg_typs=None,
    typemap=None,
    calltypes=None,
):
    """Compile a function (or a ``make_function`` node) to Numba IR.

    Block labels are offset past the global maximum label and every
    variable is renamed to a unique name, so the IR can be inlined
    elsewhere without clashes. When *typingctx* is truthy, type inference
    runs with *targetctx* and *arg_typs*, and the inferred types are merged
    into the caller-supplied *typemap* and *calltypes* dicts (which must
    then not be None).

    Raises NotImplementedError if *mk_func* has neither a ``code`` nor a
    ``__code__`` attribute.
    """
    from numba.cuda.core import typed_passes

    # make_function IR nodes expose ``code``; Python/jitted functions
    # expose ``__code__``.
    if hasattr(mk_func, "code"):
        code = mk_func.code
    elif hasattr(mk_func, "__code__"):
        code = mk_func.__code__
    else:
        raise NotImplementedError(
            "function type not recognized {}".format(mk_func)
        )

    f_ir = get_ir_of_code(glbls, code)
    remove_dels(f_ir.blocks)

    # Relabel by adding an offset so labels cannot collide with those of
    # the IR this function may be inlined into.
    f_ir.blocks = add_offset_to_labels(f_ir.blocks, _the_max_label.next())
    _the_max_label.update(max(f_ir.blocks.keys()))

    # Rename every variable to a fresh unique name for the same reason.
    renames = {
        name: mk_unique_var(name) for name in get_name_var_table(f_ir.blocks)
    }
    replace_var_names(f_ir.blocks, renames)

    if typingctx:
        f_typemap, _, f_calltypes, _ = typed_passes.type_inference_stage(
            typingctx, targetctx, f_ir, arg_typs, None
        )
        # Drop argument entries such as "arg.a" before merging.
        for arg_key in [key for key in f_typemap if key.startswith("arg.")]:
            f_typemap.pop(arg_key)
        typemap.update(f_typemap)
        calltypes.update(f_calltypes)
    return f_ir
1930
+
1931
+
1932
+ def _create_function_from_code_obj(fcode, func_env, func_arg, func_clo, glbls):
1933
+ """
1934
+ Creates a function from a code object. Args:
1935
+ * fcode - the code object
1936
+ * func_env - string for the freevar placeholders
1937
+ * func_arg - string for the function args (e.g. "a, b, c, d=None")
1938
+ * func_clo - string for the closure args
1939
+ * glbls - the function globals
1940
+ """
1941
+ sanitized_co_name = fcode.co_name.replace("<", "_").replace(">", "_")
1942
+ func_text = (
1943
+ f"def closure():\n{func_env}\n"
1944
+ f"\tdef {sanitized_co_name}({func_arg}):\n"
1945
+ f"\t\treturn ({func_clo})\n"
1946
+ f"\treturn {sanitized_co_name}"
1947
+ )
1948
+ loc = {}
1949
+ exec(func_text, glbls, loc)
1950
+
1951
+ f = loc["closure"]()
1952
+ # replace the code body
1953
+ f.__code__ = fcode
1954
+ f.__name__ = fcode.co_name
1955
+ return f
1956
+
1957
+
1958
def get_ir_of_code(glbls, fcode):
    """Compile the code object *fcode* to Numba IR; ir.Del nodes are emitted.

    The code is wrapped in a function whose free variables are ``None``
    placeholders and whose arguments are named ``x_0, x_1, ...``. It is then
    run through the frontend, the "before-inference" rewrites, closure
    inlining, SSA reconstruction and phi stripping, and finally
    post-processed.
    """
    placeholders = ["c_%d" % i for i in range(len(fcode.co_freevars))]
    func_env = "\n".join("\t%s = None" % name for name in placeholders)
    func_clo = ",".join(placeholders)
    func_arg = ",".join("x_%d" % i for i in range(fcode.co_argcount))

    func = _create_function_from_code_obj(
        fcode, func_env, func_arg, func_clo, glbls
    )

    from numba.cuda import compiler
    from numba.cuda.core.compiler import StateDict

    func_ir = compiler.run_frontend(func)

    # The before-inference rewrite pass normalizes the IR; for example,
    # Raise nodes need to become StaticRaise before type inference.
    # XXX: check rewrite pass flag?
    class DummyPipeline(object):
        def __init__(self, f_ir):
            self.state = StateDict()
            self.state.typingctx = None
            self.state.targetctx = None
            self.state.args = None
            self.state.func_ir = f_ir
            self.state.typemap = None
            self.state.return_type = None
            self.state.calltypes = None

    state = DummyPipeline(func_ir).state
    rewrites.rewrite_registry.apply("before-inference", state)

    # Inline closures to handle cases like comprehensions.
    swapped = {}  # TODO: get this from diagnostics store
    from numba.cuda.core.inline_closurecall import InlineClosureCallPass

    InlineClosureCallPass(
        func_ir, numba.cuda.core.options.ParallelOptions(False), swapped
    ).run()

    # TODO: DO NOT ADD MORE THINGS HERE!
    # If adding more things here is being contemplated, it really is time to
    # retire this function and work on getting the InlineWorker class from
    # numba.core.inline_closurecall into sufficient shape as a replacement.
    # The issue with `get_ir_of_code` is that it doesn't run a full compilation
    # pipeline and as a result various additional things keep needing to be
    # added to create valid IR.

    # Rebuild the IR in SSA form, then strip the phi nodes again.
    from numba.cuda.core.untyped_passes import ReconstructSSA
    from numba.cuda.core.typed_passes import PreLowerStripPhis

    ssa_pass = ReconstructSSA()
    strip_pass = PreLowerStripPhis()
    ssa_pass.run_pass(state)
    strip_pass.run_pass(state)

    postproc.PostProcessor(func_ir).run(True)
    return func_ir
2021
+
2022
+
2023
def replace_arg_nodes(block, args):
    """Substitute each ``ir.Arg`` assignment value in *block* with its argument.

    Every assignment in ``block.body`` whose value is ``ir.Arg`` with index
    ``i`` gets its value replaced in place by ``args[i]``.
    """
    for stmt in block.body:
        if not (isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Arg)):
            continue
        arg_index = stmt.value.index
        assert arg_index < len(args)
        stmt.value = args[arg_index]
2033
+
2034
+
2035
def replace_returns(blocks, target, return_label):
    """Replace return statements with an assignment to *target* plus a jump.

    In every non-empty block terminated by ``ir.Return``, the return and the
    ``cast`` assignment preceding it are removed. The cast's source value is
    assigned to *target* and a jump to *return_label* is appended. Blocks
    are mutated in place; an AssertionError is raised if the statement
    before the return is not a cast assignment.
    """
    for block in blocks.values():
        # Some blocks may be empty during transformations.
        if not block.body:
            continue
        ret = block.terminator
        if not isinstance(ret, ir.Return):
            continue
        block.body.pop()  # drop the return itself
        cast_stmt = block.body.pop()
        assert (
            isinstance(cast_stmt, ir.Assign)
            and isinstance(cast_stmt.value, ir.Expr)
            and cast_stmt.value.op == "cast"
        ), "invalid return cast"
        block.body.append(ir.Assign(cast_stmt.value.value, target, ret.loc))
        block.body.append(ir.Jump(return_label, ret.loc))
2056
+
2057
+
2058
def gen_np_call(func_as_str, func, lhs, args, typingctx, typemap, calltypes):
    """Generate IR computing ``lhs = numpy.<func_as_str>(*args)``.

    The types of the new temporaries are recorded in *typemap*, and the
    call's signature (from ``get_np_ufunc_typ(func)``) is recorded in
    *calltypes*. Scope and location come from ``args[0]``, so *args* must be
    non-empty.

    Returns ``[numpy global load, attribute lookup, call assignment]``.
    """
    first_arg = args[0]
    scope, loc = first_arg.scope, first_arg.loc

    # $np_g_var = global(np: numpy)
    np_var = ir.Var(scope, mk_unique_var("$np_g_var"), loc)
    typemap[np_var.name] = types.misc.Module(numpy)
    load_np = ir.Assign(ir.Global("np", numpy, loc), np_var, loc)

    # $np_attr_attr = getattr($np_g_var, func_as_str)
    attr_expr = ir.Expr.getattr(np_var, func_as_str, loc)
    func_var = ir.Var(scope, mk_unique_var("$np_attr_attr"), loc)
    func_typ = get_np_ufunc_typ(func)
    typemap[func_var.name] = func_typ
    load_attr = ir.Assign(attr_expr, func_var, loc)

    # lhs = $np_attr_attr(*args)
    call = ir.Expr.call(func_var, args, (), loc)
    arg_types = [typemap[a.name] for a in args]
    calltypes[call] = func_typ.get_call_type(typingctx, arg_types, {})
    return [load_np, load_attr, ir.Assign(call, lhs, loc)]
2080
+
2081
+
2082
def dump_block(label, block):
    """Print *label* as a header line, then each statement of *block* indented."""
    print("%s :" % (label,))
    for stmt in block.body:
        print("  %s" % (stmt,))
2086
+
2087
+
2088
def dump_blocks(blocks):
    """Print every ``label -> block`` entry of *blocks* using ``dump_block``."""
    for entry in blocks.items():
        dump_block(*entry)
2091
+
2092
+
2093
def is_operator_or_getitem(expr):
    """Return True if *expr* is a unary/binary operator or getitem ``ir.Expr``.

    Recognized ops are ``unary``, ``binop``, ``inplace_binop``, ``getitem``
    and ``static_getitem``. The result is always a bool; the previous
    ``and``-chain could leak a falsy ``op`` value instead of ``False``.
    """
    if not isinstance(expr, ir.Expr):
        return False
    return getattr(expr, "op", None) in (
        "unary",
        "binop",
        "inplace_binop",
        "getitem",
        "static_getitem",
    )
2101
+
2102
+
2103
def is_get_setitem(stmt):
    """Return True for a (static) getitem assignment or a (static) setitem."""
    if is_getitem(stmt):
        return True
    return is_setitem(stmt)
2106
+
2107
+
2108
def is_getitem(stmt):
    """Return True if *stmt* assigns a ``getitem``/``static_getitem`` expr."""
    if not isinstance(stmt, ir.Assign):
        return False
    value = stmt.value
    return isinstance(value, ir.Expr) and value.op in (
        "getitem",
        "static_getitem",
    )
2115
+
2116
+
2117
def is_setitem(stmt):
    """Return True if *stmt* is an ``ir.SetItem`` or ``ir.StaticSetItem``."""
    setitem_nodes = (ir.SetItem, ir.StaticSetItem)
    return isinstance(stmt, setitem_nodes)
2120
+
2121
+
2122
def index_var_of_get_setitem(stmt):
    """Return the index variable of a getitem/setitem node, or None.

    ``getitem`` expressions and ``ir.SetItem`` keep it in ``index``; the
    static variants keep it in ``index_var``.
    """
    if is_getitem(stmt):
        expr = stmt.value
        return expr.index if expr.op == "getitem" else expr.index_var
    if is_setitem(stmt):
        return stmt.index if isinstance(stmt, ir.SetItem) else stmt.index_var
    return None
2137
+
2138
+
2139
def set_index_var_of_get_setitem(stmt, new_index):
    """Set the index variable of a getitem/setitem node to *new_index*.

    ``getitem`` expressions and ``ir.SetItem`` store it in ``index``; the
    static variants store it in ``index_var``. Raises ValueError for any
    other statement.
    """
    if is_getitem(stmt):
        expr = stmt.value
        field = "index" if expr.op == "getitem" else "index_var"
        setattr(expr, field, new_index)
    elif is_setitem(stmt):
        field = "index" if isinstance(stmt, ir.SetItem) else "index_var"
        setattr(stmt, field, new_index)
    else:
        raise ValueError(
            "getitem or setitem node expected but received {}".format(stmt)
        )
2154
+
2155
+
2156
def is_namedtuple_class(c):
    """Return True if the class *c* has the shape of a namedtuple class.

    The class must derive directly and solely from ``tuple``, have a
    ``_make`` attribute, and have a ``_fields`` tuple made only of strings.
    """
    if not isinstance(c, type):
        return False
    bases = c.__bases__
    if not (len(bases) == 1 and bases[0] is tuple):
        return False
    if not hasattr(c, "_make"):
        return False
    fields = getattr(c, "_fields", None)
    return isinstance(fields, tuple) and all(
        isinstance(field, str) for field in fields
    )
2172
+
2173
+
2174
def fill_block_with_call(newblock, callee, label_next, inputs, outputs):
    """Populate *newblock* with a call to *callee*, then jump to *label_next*.

    The variables named in *inputs* (looked up in the block's scope) are
    passed positionally. The call result is indexed with ``static_getitem``
    0, 1, ... and each item is assigned to the matching variable named in
    *outputs*. Returns *newblock*.
    """
    scope = newblock.scope
    loc = newblock.loc

    def emit(target, value):
        newblock.append(ir.Assign(target=target, value=value, loc=loc))

    callee_const = ir.Const(value=callee, loc=loc)
    callee_var = scope.make_temp(loc=loc)
    emit(callee_var, callee_const)

    call_args = [scope.get_exact(name) for name in inputs]
    call_expr = ir.Expr.call(func=callee_var, args=call_args, kws=(), loc=loc)
    result_var = scope.make_temp(loc=loc)
    emit(result_var, call_expr)

    # Unpack the returned tuple into the output variables.
    for position, out_name in enumerate(outputs):
        out_var = scope.get_exact(out_name)
        item = ir.Expr.static_getitem(
            value=result_var, index=position, index_var=None, loc=loc
        )
        emit(out_var, item)

    newblock.append(ir.Jump(target=label_next, loc=loc))
    return newblock
2200
+
2201
+
2202
def fill_callee_prologue(block, inputs, label_next):
    """Fill *block* so it binds its arguments, then jumps to *label_next*.

    For the i-th name in *inputs* an assignment ``name = arg(i)`` is
    appended. This is the callee-side counterpart of
    ``fill_block_with_call()``. Returns *block*.
    """
    scope, loc = block.scope, block.loc
    arg_nodes = [
        ir.Arg(name=arg_name, index=position, loc=loc)
        for position, arg_name in enumerate(inputs)
    ]
    for arg_name, arg_node in zip(inputs, arg_nodes):
        target = ir.Var(scope=scope, name=arg_name, loc=loc)
        block.append(ir.Assign(target=target, value=arg_node, loc=loc))
    # jump to the loop entry
    block.append(ir.Jump(target=label_next, loc=loc))
    return block
2219
+
2220
+
2221
def fill_callee_epilogue(block, outputs):
    """Fill *block* to return the variables named in *outputs* as a tuple.

    Intended as the last block of the callee, matching the tuple unpacking
    done by ``fill_block_with_call()``. Returns *block*.
    """
    scope, loc = block.scope, block.loc
    result_items = [scope.get_exact(name=out_name) for out_name in outputs]
    tuple_expr = ir.Expr.build_tuple(items=result_items, loc=loc)
    result_var = scope.make_temp(loc=loc)
    block.append(ir.Assign(target=result_var, value=tuple_expr, loc=loc))
    block.append(ir.Return(value=result_var, loc=loc))
    return block
2238
+
2239
+
2240
def find_outer_value(func_ir, var):
    """Resolve *var* to the global/free-variable value it refers to.

    Globals and free variables give their value directly. A ``getattr``
    expression is resolved recursively on its base and then the attribute
    is read. Anything else, or a missing attribute, raises GuardException.
    """
    definition = get_definition(func_ir, var)
    if isinstance(definition, (ir.Global, ir.FreeVar)):
        return definition.value

    is_getattr = isinstance(definition, ir.Expr) and definition.op == "getattr"
    if not is_getattr:
        raise GuardException

    base = find_outer_value(func_ir, definition.value)
    try:
        return getattr(base, definition.attr)
    except AttributeError:
        raise GuardException
2257
+
2258
+
2259
def raise_on_unsupported_feature(func_ir, typemap):
    """Walk *func_ir* and raise early on constructs known to be unsupported.

    Intended to run as a pipeline stage just before lowering, so known
    unsupported features are reported clearly instead of as LoweringErrors.
    The following are detected:

    * arguments typed as a UniTuple with more than 1000 elements;
    * ``make_function`` expressions;
    * ``.view`` on a value that is neither an Array nor produced by calling
      a NumPy integer/floating type (as identified by ``find_callname``);
    * globals whose type is reflected, a DictType or a ListType;
    * ``yield`` in a function not identified as a generator;
    * more than one reference to numba.cuda.gdb / gdb_init / the
      ``gdb_internal`` intrinsic.

    Raises UnsupportedError or TypingError. This could be extended to cover
    IR sequences as well as single op codes.
    """
    # issue 2195: excessively large tuples. The limit (1000) is arbitrary.
    for arg_name in func_ir.arg_names:
        if arg_name not in typemap:
            continue
        arg_ty = typemap[arg_name]
        if isinstance(arg_ty, types.containers.UniTuple) and arg_ty.count > 1000:
            msg = (
                "Tuple '{}' length must be smaller than 1000.\n"
                "Large tuples lead to the generation of a prohibitively large "
                "LLVM IR which causes excessive memory pressure "
                "and large compile times.\n"
                "As an alternative, the use of a 'list' is recommended in "
                "place of a 'tuple' as lists do not suffer from this problem."
            ).format(arg_name)
            raise UnsupportedError(msg, func_ir.loc)

    gdb_locs = []  # locations of every gdb/gdb_init reference

    for block in func_ir.blocks.values():
        for stmt in block.find_insts(ir.Assign):
            value = stmt.value

            # make_function is never supported here.
            if isinstance(value, ir.Expr) and value.op == "make_function":
                code = getattr(value, "code", None)
                if code is None:
                    use = "<could not ascertain use case>"
                    expr = ""
                elif getattr(value, "closure", None) is not None:
                    # A closure's co_name is the captured function's name,
                    # which is not useful, so be explicit instead.
                    use = "<creating a function from a closure>"
                    expr = ""
                else:
                    use = code.co_name
                    expr = "(%s) " % use
                msg = (
                    "Numba encountered the use of a language "
                    "feature it does not support in this context: "
                    "%s (op code: make_function not supported). If "
                    "the feature is explicitly supported it is "
                    "likely that the result of the expression %s"
                    "is being used in an unsupported manner."
                ) % (use, expr)
                raise UnsupportedError(msg, value.loc)

            # gdb initialization calls: only one is permitted.
            if isinstance(value, (ir.Global, ir.FreeVar)):
                obj = getattr(value, "value", None)
                if obj is None:
                    continue
                is_gdb = isinstance(obj, pytypes.FunctionType) and obj in {
                    numba.cuda.gdb,
                    numba.cuda.gdb_init,
                }
                if not is_gdb:  # a freevar bound to the intrinsic
                    is_gdb = getattr(obj, "_name", "") == "gdb_internal"
                if is_gdb:
                    gdb_locs.append(stmt.loc)  # report last seen location

            # .view must be taken on an array or a NumPy dtype value.
            if (
                isinstance(value, ir.Expr)
                and value.op == "getattr"
                and value.attr == "view"
            ):
                var_name = value.value.name
                if isinstance(typemap[var_name], types.Array):
                    continue
                var_def = func_ir.get_definition(var_name)
                callname = guard(find_callname, func_ir, var_def)
                if callname and callname[1] == "numpy":
                    np_obj = getattr(numpy, callname[0])
                    if numpy.issubdtype(
                        np_obj, numpy.integer
                    ) or numpy.issubdtype(np_obj, numpy.floating):
                        continue
                described = (
                    "" if var_name.startswith("$") else "'{}' ".format(var_name)
                )
                raise TypingError(
                    "'view' can only be called on NumPy dtypes, "
                    "try wrapping the variable {}with 'np.<dtype>()'".format(
                        described
                    ),
                    loc=stmt.loc,
                )

            # Globals are compile-time constants; reflected/typed containers
            # cannot be compiled as constants.
            if isinstance(value, ir.Global):
                global_ty = typemap[stmt.target.name]
                if getattr(global_ty, "reflected", False) or isinstance(
                    global_ty, (types.DictType, types.ListType)
                ):
                    msg = (
                        "The use of a %s type, assigned to variable '%s' in "
                        "globals, is not supported as globals are considered "
                        "compile-time constants and there is no known way to "
                        "compile a %s type as a constant."
                    )
                    raise TypingError(
                        msg % (global_ty, value.name, global_ty), loc=stmt.loc
                    )

            # yield in use although func_ir is not marked as a generator.
            if isinstance(value, ir.Yield) and not func_ir.is_generator:
                raise UnsupportedError(
                    "The use of generator expressions is unsupported.",
                    loc=stmt.loc,
                )

    if len(gdb_locs) > 1:
        msg = (
            "Calling either numba.gdb() or numba.gdb_init() more than once "
            "in a function is unsupported (strange things happen!), use "
            "numba.gdb_breakpoint() to create additional breakpoints "
            "instead.\n\nRelevant documentation is available here:\n"
            "https://numba.readthedocs.io/en/stable/user/troubleshoot.html"
            "#using-numba-s-direct-gdb-bindings-in-nopython-mode\n\n"
            "Conflicting calls found at:\n %s"
        )
        buf = "\n".join(where.strformat() for where in gdb_locs)
        raise UnsupportedError(msg % buf)
2400
+
2401
+
2402
def warn_deprecated(func_ir, typemap):
    """Warn about reflected list/set types used as function arguments.

    For each ``arg.<name>`` entry in *typemap* whose type is reflected, a
    NumbaPendingDeprecationWarning located at ``func_ir.loc`` is issued,
    naming the argument and the function. Other entries are ignored.
    """
    url = (
        "https://numba.readthedocs.io/en/stable/reference/"
        "deprecation.html#deprecation-of-reflection-for-list-and"
        "-set-types"
    )
    for name, ty in typemap.items():
        # The Type metaclass provides a ``reflected`` member.
        if not ty.reflected or not name.startswith("arg."):
            continue
        arg = name.split(".")[1]
        tyname = "list" if isinstance(ty, types.List) else "set"
        msg = (
            "\nEncountered the use of a type that is scheduled for "
            "deprecation: type 'reflected %s' found for argument "
            "'%s' of function '%s'.\n\nFor more information visit "
            "%s" % (tyname, arg, func_ir.func_id.func_qualname, url)
        )
        warnings.warn(NumbaPendingDeprecationWarning(msg, loc=func_ir.loc))
2425
+
2426
+
2427
def resolve_func_from_module(func_ir, node):
    """Return the Python object reached through a chain of module getattrs.

    Starting from *node*, the ``getattr`` chain is followed back through the
    definitions in *func_ir* until an ``ir.Global``/``ir.FreeVar`` holding a
    module is reached. The attribute chain is then replayed on that module.
    None is returned if a value has multiple definitions, the root is not a
    module, or any attribute along the chain is missing or falsy.

    func_ir - the FunctionIR object
    node - the IR node from which to start resolving (should be a `getattr`).
    """
    attrs_innermost_last = []
    current = node
    while getattr(current, "op", False) == "getattr":
        attrs_innermost_last.append(current.attr)
        try:
            current = func_ir.get_definition(current.value)
        except KeyError:  # multiple definitions
            return None

    is_module_ref = isinstance(
        current, (ir.Global, ir.FreeVar)
    ) and isinstance(current.value, pytypes.ModuleType)
    if not is_module_ref:
        return None

    resolved = current.value
    for attr in reversed(attrs_innermost_last):
        resolved = getattr(resolved, attr, False)
        if not resolved:
            return None
    return resolved
2462
+
2463
+
2464
def enforce_no_dels(func_ir):
    """Raise CompilerError if any block of *func_ir* contains an ``ir.Del``."""
    for block in func_ir.blocks.values():
        first_del = next(iter(block.find_insts(ir.Del)), None)
        if first_del is not None:
            raise CompilerError(
                "Illegal IR, del found at: %s" % first_del, loc=first_del.loc
            )
2473
+
2474
+
2475
def enforce_no_phis(func_ir):
    """Raise CompilerError if any block of *func_ir* has an ``ir.Expr.phi``."""
    for block in func_ir.blocks.values():
        first_phi = next(iter(block.find_exprs(op="phi")), None)
        if first_phi is not None:
            raise CompilerError(
                "Illegal IR, phi found at: %s" % first_phi, loc=first_phi.loc
            )
2484
+
2485
+
2486
def legalize_single_scope(blocks):
    """Return True if all ir.Block values in *blocks* share exactly one scope."""
    distinct_scopes = set()
    for block in blocks.values():
        distinct_scopes.add(block.scope)
    return len(distinct_scopes) == 1
2489
+
2490
+
2491
def check_and_legalize_ir(func_ir, flags: "numba.cuda.flags.Flags"):
    """Check that *func_ir* is legal, then post-process it.

    The IR must contain no phi nodes and no ir.Del nodes (CompilerError
    otherwise). The post-processor then runs, emitting ir.Dels and honouring
    ``flags.dbg_extend_lifetimes``.
    """
    for check in (enforce_no_phis, enforce_no_dels):
        check(func_ir)
    postproc.PostProcessor(func_ir).run(
        True, extend_lifetimes=flags.dbg_extend_lifetimes
    )
2500
+
2501
+
2502
def convert_code_obj_to_function(code_obj, caller_ir):
    """Convert the code of a ``make_function`` IR node into a Python function.

    *code_obj* is the ``make_function`` expression and *caller_ir* is the
    FunctionIR of the caller, used to resolve free variables and default
    values. Each free variable must have a single ``ir.Const`` definition
    in the caller; its value is baked into the generated closure. The new
    function uses the caller's globals.

    Raises TypingError if a free variable has multiple definitions or is
    not a constant.
    """
    fcode = code_obj.code
    nfree = len(fcode.co_freevars)

    # Resolve free variables to constants in the caller's IR. guard() is not
    # used so that "multiple definitions" and "non-constant" can be told
    # apart in the error message.
    freevars = []
    for x in fcode.co_freevars:
        try:
            freevar_def = caller_ir.get_definition(x)
        except KeyError:
            msg = (
                "Cannot capture a constant value for variable '%s' as there "
                "are multiple definitions present." % x
            )
            raise TypingError(msg, loc=code_obj.loc)
        if not isinstance(freevar_def, ir.Const):
            msg = (
                "Cannot capture the non-constant value associated with "
                "variable '%s' in a function that may escape." % x
            )
            raise TypingError(msg, loc=code_obj.loc)
        freevars.append(freevar_def.value)

    # NOTE(review): constants are interpolated with %s, so e.g. a str value
    # is emitted unquoted; confirm which constant types can reach here.
    func_env = "\n".join(
        ["\tc_%d = %s" % (i, x) for i, x in enumerate(freevars)]
    )
    func_clo = ",".join(["c_%d" % i for i in range(nfree)])
    co_varnames = list(fcode.co_varnames)

    # The code object only knows the total number of positional arguments.
    # The make_function node holds the default values (as IR consts), so the
    # number of defaulted arguments is taken from those, and the number of
    # plain arguments is inferred from that.
    n_kwargs = 0
    n_allargs = fcode.co_argcount
    kwarg_defaults = caller_ir.get_definition(code_obj.defaults)
    if kwarg_defaults is not None:
        if isinstance(kwarg_defaults, tuple):
            default_vars = kwarg_defaults
        else:
            default_vars = kwarg_defaults.items
        kwarg_defaults_tup = tuple(
            [caller_ir.get_definition(x).value for x in default_vars]
        )
        n_kwargs = len(kwarg_defaults_tup)
    nargs = n_allargs - n_kwargs

    params = [co_varnames[i] for i in range(nargs)]
    params += [
        "%s = %s" % (co_varnames[i + nargs], kwarg_defaults_tup[i])
        for i in range(n_kwargs)
    ]
    # Joining the full list avoids a leading ", " (a SyntaxError in the
    # generated source) when every argument has a default.
    func_arg = ", ".join(params)

    # globals are the same as those in the caller
    glbls = caller_ir.func_id.func.__globals__

    # create the function and return it
    return _create_function_from_code_obj(
        fcode, func_env, func_arg, func_clo, glbls
    )
2578
+
2579
+
2580
def fixup_var_define_in_scope(blocks):
    """Define every ir.Var referenced in *blocks* in each block's scope.

    After this runs, looking up any variable used by the function succeeds
    from any of the scopes its blocks refer to.

    Note: This is a workaround. Ideally, all the blocks should refer to the
    same ir.Scope, but that property is not maintained by all the passes.
    """
    # Collect every used variable once, in first-seen order (a dict is used
    # as an insertion-ordered set).
    used_vars = {}
    for block in blocks.values():
        for inst in block.body:
            for var in inst.list_vars():
                used_vars.setdefault(var, None)

    # Not all blocks share a single scope even though they should, so make
    # each block's scope define all used variables.
    for block in blocks.values():
        localvars = block.scope.localvars
        for var in used_vars:
            if var.name not in localvars:
                # Define through the variable table directly so the very
                # same ir.Var object is registered.
                localvars.define(var.name, var)
2604
+
2605
+
2606
def transfer_scope(block, scope):
    """Rebind *block* to the ir.Scope *scope* and return the block.

    Variables of the block's current scope that *scope* does not yet define
    are defined in *scope* first; names already in *scope* are left alone.
    A block already using *scope* is returned untouched.
    """
    if block.scope is not scope:
        target_vars = scope.localvars
        for var in block.scope.localvars._con.values():
            if var.name not in target_vars:
                target_vars.define(var.name, var)
        block.scope = scope
    return block
2619
+
2620
+
2621
def is_setup_with(stmt):
    """Return True if *stmt* is an ``ir.EnterWith`` node."""
    return isinstance(stmt, (ir.EnterWith,))
2623
+
2624
+
2625
def is_terminator(stmt):
    """Return True if *stmt* is an ``ir.Terminator`` node."""
    return isinstance(stmt, (ir.Terminator,))
2627
+
2628
+
2629
def is_raise(stmt):
    """Return True if *stmt* is an ``ir.Raise`` node."""
    return isinstance(stmt, (ir.Raise,))
2631
+
2632
+
2633
def is_return(stmt):
    """Return True if *stmt* is an ``ir.Return`` node."""
    return isinstance(stmt, (ir.Return,))
2635
+
2636
+
2637
def is_pop_block(stmt):
    """Return True if *stmt* is an ``ir.PopBlock`` node."""
    return isinstance(stmt, (ir.PopBlock,))