numba-cuda 0.21.1__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (488)
  1. _numba_cuda_redirector.pth +4 -0
  2. _numba_cuda_redirector.py +89 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +6 -0
  5. numba_cuda/_version.py +11 -0
  6. numba_cuda/numba/cuda/__init__.py +70 -0
  7. numba_cuda/numba/cuda/_internal/cuda_bf16.py +16394 -0
  8. numba_cuda/numba/cuda/_internal/cuda_fp16.py +8112 -0
  9. numba_cuda/numba/cuda/api.py +577 -0
  10. numba_cuda/numba/cuda/api_util.py +76 -0
  11. numba_cuda/numba/cuda/args.py +72 -0
  12. numba_cuda/numba/cuda/bf16.py +397 -0
  13. numba_cuda/numba/cuda/cache_hints.py +287 -0
  14. numba_cuda/numba/cuda/cext/__init__.py +2 -0
  15. numba_cuda/numba/cuda/cext/_devicearray.cp313-win_amd64.pyd +0 -0
  16. numba_cuda/numba/cuda/cext/_devicearray.cpp +159 -0
  17. numba_cuda/numba/cuda/cext/_devicearray.h +29 -0
  18. numba_cuda/numba/cuda/cext/_dispatcher.cp313-win_amd64.pyd +0 -0
  19. numba_cuda/numba/cuda/cext/_dispatcher.cpp +1098 -0
  20. numba_cuda/numba/cuda/cext/_hashtable.cpp +532 -0
  21. numba_cuda/numba/cuda/cext/_hashtable.h +135 -0
  22. numba_cuda/numba/cuda/cext/_helperlib.c +71 -0
  23. numba_cuda/numba/cuda/cext/_helperlib.cp313-win_amd64.pyd +0 -0
  24. numba_cuda/numba/cuda/cext/_helpermod.c +82 -0
  25. numba_cuda/numba/cuda/cext/_pymodule.h +38 -0
  26. numba_cuda/numba/cuda/cext/_typeconv.cp313-win_amd64.pyd +0 -0
  27. numba_cuda/numba/cuda/cext/_typeconv.cpp +206 -0
  28. numba_cuda/numba/cuda/cext/_typeof.cpp +1159 -0
  29. numba_cuda/numba/cuda/cext/_typeof.h +19 -0
  30. numba_cuda/numba/cuda/cext/capsulethunk.h +111 -0
  31. numba_cuda/numba/cuda/cext/mviewbuf.c +385 -0
  32. numba_cuda/numba/cuda/cext/mviewbuf.cp313-win_amd64.pyd +0 -0
  33. numba_cuda/numba/cuda/cext/typeconv.cpp +212 -0
  34. numba_cuda/numba/cuda/cext/typeconv.hpp +101 -0
  35. numba_cuda/numba/cuda/cg.py +67 -0
  36. numba_cuda/numba/cuda/cgutils.py +1294 -0
  37. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  38. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  39. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  40. numba_cuda/numba/cuda/codegen.py +541 -0
  41. numba_cuda/numba/cuda/compiler.py +1396 -0
  42. numba_cuda/numba/cuda/core/analysis.py +758 -0
  43. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  44. numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +288 -0
  45. numba_cuda/numba/cuda/core/annotations/type_annotations.py +305 -0
  46. numba_cuda/numba/cuda/core/base.py +1332 -0
  47. numba_cuda/numba/cuda/core/boxing.py +1411 -0
  48. numba_cuda/numba/cuda/core/bytecode.py +728 -0
  49. numba_cuda/numba/cuda/core/byteflow.py +2346 -0
  50. numba_cuda/numba/cuda/core/caching.py +744 -0
  51. numba_cuda/numba/cuda/core/callconv.py +392 -0
  52. numba_cuda/numba/cuda/core/codegen.py +171 -0
  53. numba_cuda/numba/cuda/core/compiler.py +199 -0
  54. numba_cuda/numba/cuda/core/compiler_lock.py +85 -0
  55. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  56. numba_cuda/numba/cuda/core/config.py +650 -0
  57. numba_cuda/numba/cuda/core/consts.py +124 -0
  58. numba_cuda/numba/cuda/core/controlflow.py +989 -0
  59. numba_cuda/numba/cuda/core/entrypoints.py +57 -0
  60. numba_cuda/numba/cuda/core/environment.py +66 -0
  61. numba_cuda/numba/cuda/core/errors.py +917 -0
  62. numba_cuda/numba/cuda/core/event.py +511 -0
  63. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  64. numba_cuda/numba/cuda/core/generators.py +387 -0
  65. numba_cuda/numba/cuda/core/imputils.py +509 -0
  66. numba_cuda/numba/cuda/core/inline_closurecall.py +1787 -0
  67. numba_cuda/numba/cuda/core/interpreter.py +3617 -0
  68. numba_cuda/numba/cuda/core/ir.py +1812 -0
  69. numba_cuda/numba/cuda/core/ir_utils.py +2638 -0
  70. numba_cuda/numba/cuda/core/optional.py +129 -0
  71. numba_cuda/numba/cuda/core/options.py +262 -0
  72. numba_cuda/numba/cuda/core/postproc.py +249 -0
  73. numba_cuda/numba/cuda/core/pythonapi.py +1859 -0
  74. numba_cuda/numba/cuda/core/registry.py +46 -0
  75. numba_cuda/numba/cuda/core/removerefctpass.py +123 -0
  76. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  77. numba_cuda/numba/cuda/core/rewrites/ir_print.py +91 -0
  78. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  79. numba_cuda/numba/cuda/core/rewrites/static_binop.py +41 -0
  80. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +189 -0
  81. numba_cuda/numba/cuda/core/rewrites/static_raise.py +100 -0
  82. numba_cuda/numba/cuda/core/sigutils.py +68 -0
  83. numba_cuda/numba/cuda/core/ssa.py +498 -0
  84. numba_cuda/numba/cuda/core/targetconfig.py +330 -0
  85. numba_cuda/numba/cuda/core/tracing.py +231 -0
  86. numba_cuda/numba/cuda/core/transforms.py +956 -0
  87. numba_cuda/numba/cuda/core/typed_passes.py +867 -0
  88. numba_cuda/numba/cuda/core/typeinfer.py +1950 -0
  89. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  90. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  91. numba_cuda/numba/cuda/core/unsafe/eh.py +67 -0
  92. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  93. numba_cuda/numba/cuda/core/untyped_passes.py +1979 -0
  94. numba_cuda/numba/cuda/cpython/builtins.py +1153 -0
  95. numba_cuda/numba/cuda/cpython/charseq.py +1218 -0
  96. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  97. numba_cuda/numba/cuda/cpython/enumimpl.py +103 -0
  98. numba_cuda/numba/cuda/cpython/iterators.py +167 -0
  99. numba_cuda/numba/cuda/cpython/listobj.py +1326 -0
  100. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  101. numba_cuda/numba/cuda/cpython/numbers.py +1475 -0
  102. numba_cuda/numba/cuda/cpython/rangeobj.py +289 -0
  103. numba_cuda/numba/cuda/cpython/slicing.py +322 -0
  104. numba_cuda/numba/cuda/cpython/tupleobj.py +456 -0
  105. numba_cuda/numba/cuda/cpython/unicode.py +2865 -0
  106. numba_cuda/numba/cuda/cpython/unicode_support.py +1597 -0
  107. numba_cuda/numba/cuda/cpython/unsafe/__init__.py +0 -0
  108. numba_cuda/numba/cuda/cpython/unsafe/numbers.py +64 -0
  109. numba_cuda/numba/cuda/cpython/unsafe/tuple.py +92 -0
  110. numba_cuda/numba/cuda/cuda_paths.py +691 -0
  111. numba_cuda/numba/cuda/cudadecl.py +556 -0
  112. numba_cuda/numba/cuda/cudadrv/__init__.py +14 -0
  113. numba_cuda/numba/cuda/cudadrv/devicearray.py +951 -0
  114. numba_cuda/numba/cuda/cudadrv/devices.py +249 -0
  115. numba_cuda/numba/cuda/cudadrv/driver.py +3222 -0
  116. numba_cuda/numba/cuda/cudadrv/drvapi.py +435 -0
  117. numba_cuda/numba/cuda/cudadrv/dummyarray.py +558 -0
  118. numba_cuda/numba/cuda/cudadrv/enums.py +613 -0
  119. numba_cuda/numba/cuda/cudadrv/error.py +48 -0
  120. numba_cuda/numba/cuda/cudadrv/libs.py +220 -0
  121. numba_cuda/numba/cuda/cudadrv/linkable_code.py +184 -0
  122. numba_cuda/numba/cuda/cudadrv/mappings.py +14 -0
  123. numba_cuda/numba/cuda/cudadrv/ndarray.py +26 -0
  124. numba_cuda/numba/cuda/cudadrv/nvrtc.py +193 -0
  125. numba_cuda/numba/cuda/cudadrv/nvvm.py +756 -0
  126. numba_cuda/numba/cuda/cudadrv/rtapi.py +13 -0
  127. numba_cuda/numba/cuda/cudadrv/runtime.py +34 -0
  128. numba_cuda/numba/cuda/cudaimpl.py +995 -0
  129. numba_cuda/numba/cuda/cudamath.py +149 -0
  130. numba_cuda/numba/cuda/datamodel/__init__.py +7 -0
  131. numba_cuda/numba/cuda/datamodel/cuda_manager.py +66 -0
  132. numba_cuda/numba/cuda/datamodel/cuda_models.py +1446 -0
  133. numba_cuda/numba/cuda/datamodel/cuda_packer.py +224 -0
  134. numba_cuda/numba/cuda/datamodel/cuda_registry.py +22 -0
  135. numba_cuda/numba/cuda/datamodel/cuda_testing.py +153 -0
  136. numba_cuda/numba/cuda/datamodel/manager.py +11 -0
  137. numba_cuda/numba/cuda/datamodel/models.py +9 -0
  138. numba_cuda/numba/cuda/datamodel/packer.py +9 -0
  139. numba_cuda/numba/cuda/datamodel/registry.py +11 -0
  140. numba_cuda/numba/cuda/datamodel/testing.py +11 -0
  141. numba_cuda/numba/cuda/debuginfo.py +903 -0
  142. numba_cuda/numba/cuda/decorators.py +294 -0
  143. numba_cuda/numba/cuda/descriptor.py +35 -0
  144. numba_cuda/numba/cuda/device_init.py +158 -0
  145. numba_cuda/numba/cuda/deviceufunc.py +1021 -0
  146. numba_cuda/numba/cuda/dispatcher.py +2463 -0
  147. numba_cuda/numba/cuda/errors.py +72 -0
  148. numba_cuda/numba/cuda/extending.py +697 -0
  149. numba_cuda/numba/cuda/flags.py +178 -0
  150. numba_cuda/numba/cuda/fp16.py +357 -0
  151. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  152. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  153. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  154. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  155. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  156. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  157. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  158. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  159. numba_cuda/numba/cuda/initialize.py +24 -0
  160. numba_cuda/numba/cuda/intrinsic_wrapper.py +41 -0
  161. numba_cuda/numba/cuda/intrinsics.py +382 -0
  162. numba_cuda/numba/cuda/itanium_mangler.py +214 -0
  163. numba_cuda/numba/cuda/kernels/__init__.py +2 -0
  164. numba_cuda/numba/cuda/kernels/reduction.py +265 -0
  165. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  166. numba_cuda/numba/cuda/libdevice.py +3386 -0
  167. numba_cuda/numba/cuda/libdevicedecl.py +20 -0
  168. numba_cuda/numba/cuda/libdevicefuncs.py +1060 -0
  169. numba_cuda/numba/cuda/libdeviceimpl.py +88 -0
  170. numba_cuda/numba/cuda/locks.py +19 -0
  171. numba_cuda/numba/cuda/lowering.py +1951 -0
  172. numba_cuda/numba/cuda/mathimpl.py +374 -0
  173. numba_cuda/numba/cuda/memory_management/__init__.py +4 -0
  174. numba_cuda/numba/cuda/memory_management/memsys.cu +99 -0
  175. numba_cuda/numba/cuda/memory_management/memsys.cuh +22 -0
  176. numba_cuda/numba/cuda/memory_management/nrt.cu +212 -0
  177. numba_cuda/numba/cuda/memory_management/nrt.cuh +48 -0
  178. numba_cuda/numba/cuda/memory_management/nrt.py +390 -0
  179. numba_cuda/numba/cuda/memory_management/nrt_context.py +438 -0
  180. numba_cuda/numba/cuda/misc/appdirs.py +594 -0
  181. numba_cuda/numba/cuda/misc/cffiimpl.py +24 -0
  182. numba_cuda/numba/cuda/misc/coverage_support.py +43 -0
  183. numba_cuda/numba/cuda/misc/dump_style.py +41 -0
  184. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  185. numba_cuda/numba/cuda/misc/firstlinefinder.py +96 -0
  186. numba_cuda/numba/cuda/misc/gdb_hook.py +240 -0
  187. numba_cuda/numba/cuda/misc/literal.py +28 -0
  188. numba_cuda/numba/cuda/misc/llvm_pass_timings.py +412 -0
  189. numba_cuda/numba/cuda/misc/special.py +94 -0
  190. numba_cuda/numba/cuda/models.py +56 -0
  191. numba_cuda/numba/cuda/np/arraymath.py +5130 -0
  192. numba_cuda/numba/cuda/np/arrayobj.py +7635 -0
  193. numba_cuda/numba/cuda/np/extensions.py +11 -0
  194. numba_cuda/numba/cuda/np/linalg.py +3087 -0
  195. numba_cuda/numba/cuda/np/math/__init__.py +0 -0
  196. numba_cuda/numba/cuda/np/math/cmathimpl.py +558 -0
  197. numba_cuda/numba/cuda/np/math/mathimpl.py +487 -0
  198. numba_cuda/numba/cuda/np/math/numbers.py +1461 -0
  199. numba_cuda/numba/cuda/np/npdatetime.py +969 -0
  200. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  201. numba_cuda/numba/cuda/np/npyfuncs.py +1808 -0
  202. numba_cuda/numba/cuda/np/npyimpl.py +1027 -0
  203. numba_cuda/numba/cuda/np/numpy_support.py +798 -0
  204. numba_cuda/numba/cuda/np/polynomial/__init__.py +4 -0
  205. numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +242 -0
  206. numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +380 -0
  207. numba_cuda/numba/cuda/np/ufunc/__init__.py +4 -0
  208. numba_cuda/numba/cuda/np/ufunc/decorators.py +203 -0
  209. numba_cuda/numba/cuda/np/ufunc/sigparse.py +68 -0
  210. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +65 -0
  211. numba_cuda/numba/cuda/np/ufunc_db.py +1282 -0
  212. numba_cuda/numba/cuda/np/unsafe/__init__.py +0 -0
  213. numba_cuda/numba/cuda/np/unsafe/ndarray.py +84 -0
  214. numba_cuda/numba/cuda/nvvmutils.py +254 -0
  215. numba_cuda/numba/cuda/printimpl.py +126 -0
  216. numba_cuda/numba/cuda/random.py +308 -0
  217. numba_cuda/numba/cuda/reshape_funcs.cu +156 -0
  218. numba_cuda/numba/cuda/serialize.py +267 -0
  219. numba_cuda/numba/cuda/simulator/__init__.py +63 -0
  220. numba_cuda/numba/cuda/simulator/_internal/__init__.py +4 -0
  221. numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
  222. numba_cuda/numba/cuda/simulator/api.py +179 -0
  223. numba_cuda/numba/cuda/simulator/bf16.py +4 -0
  224. numba_cuda/numba/cuda/simulator/compiler.py +38 -0
  225. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +11 -0
  226. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +462 -0
  227. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +122 -0
  228. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +66 -0
  229. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +7 -0
  230. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +7 -0
  231. numba_cuda/numba/cuda/simulator/cudadrv/error.py +10 -0
  232. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +10 -0
  233. numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +61 -0
  234. numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +11 -0
  235. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +32 -0
  236. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +22 -0
  237. numba_cuda/numba/cuda/simulator/dispatcher.py +11 -0
  238. numba_cuda/numba/cuda/simulator/kernel.py +320 -0
  239. numba_cuda/numba/cuda/simulator/kernelapi.py +509 -0
  240. numba_cuda/numba/cuda/simulator/memory_management/__init__.py +4 -0
  241. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +21 -0
  242. numba_cuda/numba/cuda/simulator/reduction.py +19 -0
  243. numba_cuda/numba/cuda/simulator/tests/support.py +4 -0
  244. numba_cuda/numba/cuda/simulator/vector_types.py +65 -0
  245. numba_cuda/numba/cuda/simulator_init.py +18 -0
  246. numba_cuda/numba/cuda/stubs.py +635 -0
  247. numba_cuda/numba/cuda/target.py +505 -0
  248. numba_cuda/numba/cuda/testing.py +347 -0
  249. numba_cuda/numba/cuda/tests/__init__.py +62 -0
  250. numba_cuda/numba/cuda/tests/benchmarks/__init__.py +0 -0
  251. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +119 -0
  252. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  253. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +113 -0
  254. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
  255. numba_cuda/numba/cuda/tests/core/test_serialize.py +371 -0
  256. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +9 -0
  257. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +147 -0
  258. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +161 -0
  259. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +397 -0
  260. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +24 -0
  261. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +180 -0
  262. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +313 -0
  263. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +187 -0
  264. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +621 -0
  265. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +247 -0
  266. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +100 -0
  267. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +198 -0
  268. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +53 -0
  269. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +72 -0
  270. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +138 -0
  271. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +43 -0
  272. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +15 -0
  273. numba_cuda/numba/cuda/tests/cudadrv/test_linkable_code.py +58 -0
  274. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +348 -0
  275. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +128 -0
  276. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +301 -0
  277. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +174 -0
  278. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +28 -0
  279. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +185 -0
  280. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +39 -0
  281. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +23 -0
  282. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +38 -0
  283. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +48 -0
  284. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +44 -0
  285. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +127 -0
  286. numba_cuda/numba/cuda/tests/cudapy/__init__.py +9 -0
  287. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +231 -0
  288. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +50 -0
  289. numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +36 -0
  290. numba_cuda/numba/cuda/tests/cudapy/complex_usecases.py +116 -0
  291. numba_cuda/numba/cuda/tests/cudapy/enum_usecases.py +59 -0
  292. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +62 -0
  293. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +28 -0
  294. numba_cuda/numba/cuda/tests/cudapy/overload_usecases.py +33 -0
  295. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +104 -0
  296. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +47 -0
  297. numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +1122 -0
  298. numba_cuda/numba/cuda/tests/cudapy/test_array.py +344 -0
  299. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +268 -0
  300. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +203 -0
  301. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +63 -0
  302. numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +360 -0
  303. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1815 -0
  304. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +599 -0
  305. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +377 -0
  306. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +160 -0
  307. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +27 -0
  308. numba_cuda/numba/cuda/tests/cudapy/test_byteflow.py +98 -0
  309. numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py +210 -0
  310. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +683 -0
  311. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +265 -0
  312. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +42 -0
  313. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +718 -0
  314. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +370 -0
  315. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +23 -0
  316. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +142 -0
  317. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +178 -0
  318. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +193 -0
  319. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +131 -0
  320. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +438 -0
  321. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +94 -0
  322. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +101 -0
  323. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +105 -0
  324. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +889 -0
  325. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
  326. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +500 -0
  327. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +820 -0
  328. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +152 -0
  329. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +111 -0
  330. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +170 -0
  331. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1088 -0
  332. numba_cuda/numba/cuda/tests/cudapy/test_extending_types.py +71 -0
  333. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +265 -0
  334. numba_cuda/numba/cuda/tests/cudapy/test_flow_control.py +1433 -0
  335. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +57 -0
  336. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +34 -0
  337. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +69 -0
  338. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +62 -0
  339. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +474 -0
  340. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +167 -0
  341. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +92 -0
  342. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +39 -0
  343. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +170 -0
  344. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +255 -0
  345. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1219 -0
  346. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +263 -0
  347. numba_cuda/numba/cuda/tests/cudapy/test_ir.py +598 -0
  348. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +276 -0
  349. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +101 -0
  350. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +68 -0
  351. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +123 -0
  352. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +194 -0
  353. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +220 -0
  354. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +173 -0
  355. numba_cuda/numba/cuda/tests/cudapy/test_make_function_to_jit_function.py +364 -0
  356. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +47 -0
  357. numba_cuda/numba/cuda/tests/cudapy/test_math.py +842 -0
  358. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +76 -0
  359. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +78 -0
  360. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +25 -0
  361. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +145 -0
  362. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +39 -0
  363. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +82 -0
  364. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +53 -0
  365. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +504 -0
  366. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +93 -0
  367. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +402 -0
  368. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +128 -0
  369. numba_cuda/numba/cuda/tests/cudapy/test_print.py +193 -0
  370. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +37 -0
  371. numba_cuda/numba/cuda/tests/cudapy/test_random.py +117 -0
  372. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +614 -0
  373. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +130 -0
  374. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +94 -0
  375. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  376. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +86 -0
  377. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +40 -0
  378. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +457 -0
  379. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +233 -0
  380. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +454 -0
  381. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +56 -0
  382. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +277 -0
  383. numba_cuda/numba/cuda/tests/cudapy/test_tracing.py +200 -0
  384. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +90 -0
  385. numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +333 -0
  386. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  387. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +585 -0
  388. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +42 -0
  389. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +485 -0
  390. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +312 -0
  391. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +23 -0
  392. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +183 -0
  393. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +40 -0
  394. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +40 -0
  395. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +206 -0
  396. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +331 -0
  397. numba_cuda/numba/cuda/tests/cudasim/__init__.py +9 -0
  398. numba_cuda/numba/cuda/tests/cudasim/support.py +9 -0
  399. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +111 -0
  400. numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
  401. numba_cuda/numba/cuda/tests/data/cta_barrier.cu +28 -0
  402. numba_cuda/numba/cuda/tests/data/cuda_include.cu +10 -0
  403. numba_cuda/numba/cuda/tests/data/error.cu +12 -0
  404. numba_cuda/numba/cuda/tests/data/include/add.cuh +8 -0
  405. numba_cuda/numba/cuda/tests/data/jitlink.cu +28 -0
  406. numba_cuda/numba/cuda/tests/data/jitlink.ptx +49 -0
  407. numba_cuda/numba/cuda/tests/data/warn.cu +12 -0
  408. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +9 -0
  409. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
  410. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +54 -0
  411. numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +8 -0
  412. numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +14 -0
  413. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +86 -0
  414. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +68 -0
  415. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +81 -0
  416. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +141 -0
  417. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +160 -0
  418. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +180 -0
  419. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +119 -0
  420. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +66 -0
  421. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +80 -0
  422. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +206 -0
  423. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +53 -0
  424. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +76 -0
  425. numba_cuda/numba/cuda/tests/nocuda/__init__.py +9 -0
  426. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +391 -0
  427. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +48 -0
  428. numba_cuda/numba/cuda/tests/nocuda/test_import.py +63 -0
  429. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +252 -0
  430. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +59 -0
  431. numba_cuda/numba/cuda/tests/nrt/__init__.py +9 -0
  432. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +387 -0
  433. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +124 -0
  434. numba_cuda/numba/cuda/tests/support.py +900 -0
  435. numba_cuda/numba/cuda/typeconv/__init__.py +4 -0
  436. numba_cuda/numba/cuda/typeconv/castgraph.py +137 -0
  437. numba_cuda/numba/cuda/typeconv/rules.py +63 -0
  438. numba_cuda/numba/cuda/typeconv/typeconv.py +121 -0
  439. numba_cuda/numba/cuda/types/__init__.py +233 -0
  440. numba_cuda/numba/cuda/types/__init__.pyi +167 -0
  441. numba_cuda/numba/cuda/types/abstract.py +9 -0
  442. numba_cuda/numba/cuda/types/common.py +9 -0
  443. numba_cuda/numba/cuda/types/containers.py +9 -0
  444. numba_cuda/numba/cuda/types/cuda_abstract.py +533 -0
  445. numba_cuda/numba/cuda/types/cuda_common.py +110 -0
  446. numba_cuda/numba/cuda/types/cuda_containers.py +971 -0
  447. numba_cuda/numba/cuda/types/cuda_function_type.py +230 -0
  448. numba_cuda/numba/cuda/types/cuda_functions.py +798 -0
  449. numba_cuda/numba/cuda/types/cuda_iterators.py +120 -0
  450. numba_cuda/numba/cuda/types/cuda_misc.py +569 -0
  451. numba_cuda/numba/cuda/types/cuda_npytypes.py +690 -0
  452. numba_cuda/numba/cuda/types/cuda_scalars.py +280 -0
  453. numba_cuda/numba/cuda/types/ext_types.py +101 -0
  454. numba_cuda/numba/cuda/types/function_type.py +11 -0
  455. numba_cuda/numba/cuda/types/functions.py +9 -0
  456. numba_cuda/numba/cuda/types/iterators.py +9 -0
  457. numba_cuda/numba/cuda/types/misc.py +9 -0
  458. numba_cuda/numba/cuda/types/npytypes.py +9 -0
  459. numba_cuda/numba/cuda/types/scalars.py +9 -0
  460. numba_cuda/numba/cuda/typing/__init__.py +19 -0
  461. numba_cuda/numba/cuda/typing/arraydecl.py +939 -0
  462. numba_cuda/numba/cuda/typing/asnumbatype.py +130 -0
  463. numba_cuda/numba/cuda/typing/bufproto.py +70 -0
  464. numba_cuda/numba/cuda/typing/builtins.py +1209 -0
  465. numba_cuda/numba/cuda/typing/cffi_utils.py +219 -0
  466. numba_cuda/numba/cuda/typing/cmathdecl.py +47 -0
  467. numba_cuda/numba/cuda/typing/collections.py +138 -0
  468. numba_cuda/numba/cuda/typing/context.py +782 -0
  469. numba_cuda/numba/cuda/typing/ctypes_utils.py +125 -0
  470. numba_cuda/numba/cuda/typing/dictdecl.py +63 -0
  471. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  472. numba_cuda/numba/cuda/typing/listdecl.py +147 -0
  473. numba_cuda/numba/cuda/typing/mathdecl.py +158 -0
  474. numba_cuda/numba/cuda/typing/npdatetime.py +322 -0
  475. numba_cuda/numba/cuda/typing/npydecl.py +749 -0
  476. numba_cuda/numba/cuda/typing/setdecl.py +115 -0
  477. numba_cuda/numba/cuda/typing/templates.py +1446 -0
  478. numba_cuda/numba/cuda/typing/typeof.py +301 -0
  479. numba_cuda/numba/cuda/ufuncs.py +746 -0
  480. numba_cuda/numba/cuda/utils.py +724 -0
  481. numba_cuda/numba/cuda/vector_types.py +214 -0
  482. numba_cuda/numba/cuda/vectorizers.py +260 -0
  483. numba_cuda-0.21.1.dist-info/METADATA +109 -0
  484. numba_cuda-0.21.1.dist-info/RECORD +488 -0
  485. numba_cuda-0.21.1.dist-info/WHEEL +5 -0
  486. numba_cuda-0.21.1.dist-info/licenses/LICENSE +26 -0
  487. numba_cuda-0.21.1.dist-info/licenses/LICENSE.numba +24 -0
  488. numba_cuda-0.21.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,46 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from numba.cuda import utils
5
+
6
+
7
class DelayedRegistry(utils.UniqueDict):
    """
    A unique dictionary whose values may be supplied lazily.

    Optional keyword arguments ``key_type`` and ``value_type`` turn on
    assertion-based checking of keys and values on assignment:

    * if the expected type is itself a class, the checked object must be a
      class whose ``__mro__`` contains the expected type;
    * otherwise (e.g. a tuple of classes) the object must satisfy
      ``isinstance(obj, expected)``.

    Attributes
    ----------
    ondemand:
        A unique dictionary of key -> zero-argument callable. The first
        time a key is looked up with ``registry[key]``, its callable is
        invoked, the result is stored in this registry, and the key is
        removed from ``ondemand``. This is part of a deferred-initialization
        strategy. Only ``[]`` lookups trigger this; no other dict method is
        overridden here.
    """

    def __init__(self, *args, **kws):
        self.ondemand = utils.UniqueDict()
        self.key_type = kws.pop("key_type", None)
        self.value_type = kws.pop("value_type", None)
        # Truthy when at least one of the two checks is configured.
        self._type_check = self.key_type or self.value_type
        super().__init__(*args, **kws)

    def __getitem__(self, item):
        pending = self.ondemand
        if item in pending:
            # Materialize the deferred value. Assigning through self[...]
            # routes it through __setitem__, so type checks still apply.
            self[item] = pending[item]()
            del pending[item]
        return super().__getitem__(item)

    def __setitem__(self, key, value):
        if self._type_check:
            # Key is checked before value, matching assignment order.
            for obj, expected in ((key, self.key_type),
                                  (value, self.value_type)):
                if expected is None:
                    continue
                if isinstance(expected, type):
                    assert expected in obj.__mro__, (obj, expected)
                else:
                    assert isinstance(obj, expected), (obj, expected)
        return super().__setitem__(key, value)
@@ -0,0 +1,123 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ """
5
+ Implement a rewrite pass on a LLVM module to remove unnecessary
6
+ refcount operations.
7
+ """
8
+
9
+ from llvmlite.ir.transforms import CallVisitor
10
+
11
+ from numba.cuda import types
12
+
13
+
14
class _MarkNrtCallVisitor(CallVisitor):
    """
    Collect call instructions whose callee is one of the accepted NRT
    refcount functions (see ``_accepted_nrtfns``).

    The collected instructions accumulate in the set ``self.marked``.
    """

    def __init__(self):
        self.marked = set()

    def visit_Call(self, instr):
        """Record *instr* if its callee's name is an accepted NRT function."""
        # A callee without a ``name`` attribute can never match.
        callee_name = getattr(instr.callee, "name", "")
        if callee_name not in _accepted_nrtfns:
            return
        self.marked.add(instr)
25
+
26
+
27
def _rewrite_function(function):
    """
    Remove the accepted NRT incref/decref calls from *function*, in place.

    Parameters
    ----------
    function:
        The LLVM IR function whose basic blocks are rewritten.
    """
    # Mark NRT usage
    markpass = _MarkNrtCallVisitor()
    markpass.visit_Function(function)
    marked = markpass.marked
    if not marked:
        # Nothing to strip; leave the instruction lists untouched.
        return
    # Remove NRT usage. Each instruction list is rebuilt in a single pass.
    # Slice assignment keeps the same list object, so anything holding a
    # reference to it sees the update. Calling ``list.remove`` once per
    # marked instruction would rescan the list every time.
    for bb in function.basic_blocks:
        bb.instructions[:] = [
            inst for inst in bb.instructions if inst not in marked
        ]
36
+
37
+
38
# Names of the only NRT runtime functions this pass knows how to remove.
_accepted_nrtfns = ("NRT_incref", "NRT_decref")
39
+
40
+
41
def _legalize(module, dmm, fndesc):
    """
    Decide whether the NRT-removal rewrite may be applied.

    Parameters
    ----------
    module:
        The LLVM module containing the function being compiled.
    dmm:
        Data model manager, indexed by type to obtain its data model.
    fndesc:
        Function descriptor providing ``argtypes``, ``restype`` and
        ``calltypes``.

    Returns
    -------
    bool
        True if the module is legal for the rewrite pass that removes
        unnecessary refcounts, False otherwise.
    """

    def needs_no_refcount(ty):
        # A type is acceptable as an output only if its data model
        # contains no NRT meminfo.
        return not dmm[ty].contains_nrt_meminfo()

    def acceptable_input(ty):
        # Inputs follow the output rule, but arrays are also allowed.
        return needs_no_refcount(ty) or isinstance(ty, types.Array)

    # Any function flagged "numba_args_may_always_need_nrt" makes the whole
    # compilation unit ineligible.
    try:
        flagged = module.get_named_metadata("numba_args_may_always_need_nrt")
    except KeyError:
        pass  # metadata absent: nothing is flagged
    else:
        if len(flagged.operands) > 0:
            return False

    argtypes = fndesc.argtypes
    restype = fndesc.restype
    calltypes = fndesc.calltypes

    if not all(acceptable_input(argty) for argty in argtypes):
        return False

    if not needs_no_refcount(restype):
        return False

    # Every called function must return a value that needs no refcounting.
    if any(
        callty is not None and not needs_no_refcount(callty.return_type)
        for callty in calltypes.values()
    ):
        return False

    # Any other NRT_* function in the module is treated as allocation.
    return not any(
        fn.name.startswith("NRT_") and fn.name not in _accepted_nrtfns
        for fn in module.functions
    )
100
+
101
+
102
def remove_unnecessary_nrt_usage(function, context, fndesc):
    """
    Remove NRT incref/decref calls from *function* when high-level type
    information shows they are unnecessary.

    The rewrite is attempted only when the function does not:

    - take arguments that need refcounting (arrays are allowed);
    - return a value that needs refcounting;
    - call function(s) that return a refcounted value;

    and the module contains no NRT_* functions other than NRT_incref and
    NRT_decref, and has no non-empty "numba_args_may_always_need_nrt"
    named metadata.

    In effect, the function will not capture or create references that extend
    the lifetime of any refcounted objects beyond the lifetime of the function.

    The rewrite is performed in place.

    Returns
    -------
    bool
        True if the rewrite happened, False otherwise.
    """
    dmm = context.data_model_manager
    legal = _legalize(function.module, dmm, fndesc)
    if legal:
        _rewrite_function(function)
    return legal
@@ -0,0 +1,26 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ """
5
+ A subpackage hosting Numba IR rewrite passes.
6
+ """
7
+
8
+ from .registry import register_rewrite, rewrite_registry, Rewrite
9
+
10
+ # Register various built-in rewrite passes
11
+ from numba.cuda.core.rewrites import (
12
+ static_getitem,
13
+ static_raise,
14
+ static_binop,
15
+ ir_print,
16
+ )
17
+
18
# Public names: the built-in rewrite modules (imported above so that their
# passes register themselves) plus the registry API re-exported from
# .registry.
__all__ = (
    # built-in rewrite pass modules
    "static_getitem",
    "static_raise",
    "static_binop",
    "ir_print",
    # registry API
    "register_rewrite",
    "rewrite_registry",
    "Rewrite",
)
@@ -0,0 +1,91 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from numba.cuda.core import errors
5
+ from numba.cuda.core import ir
6
+ from numba.cuda.core.rewrites import register_rewrite, Rewrite
7
+
8
+
9
@register_rewrite("before-inference")
class RewritePrintCalls(Rewrite):
    """
    Replace calls to the builtin print() with dedicated IR Print nodes.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Record every ``var = call <print>(...)`` assignment in *block*.

        Raises UnsupportedError if a matched print() call passes keyword
        arguments. Returns True if at least one print() call was found.
        """
        found = {}
        self.prints = found
        self.block = block
        for inst in block.find_insts(ir.Assign):
            expr = inst.value
            if not (isinstance(expr, ir.Expr) and expr.op == "call"):
                continue
            try:
                callee = func_ir.infer_constant(expr.func)
            except errors.ConstantInferenceError:
                continue
            if callee is not print:
                continue
            if expr.kws:
                # Only positional args are supported
                msg = (
                    "Numba's print() function implementation does not "
                    "support keyword arguments."
                )
                raise errors.UnsupportedError(msg, inst.loc)
            found[inst] = expr
        return bool(found)

    def apply(self):
        """
        Replace each recorded ``var = call <print function>(...)`` with a
        ``print(...)`` node followed by ``var = const(None)``, in a new block.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for inst in self.block.body:
            if inst not in self.prints:
                rewritten.append(inst)
                continue
            call = self.prints[inst]
            rewritten.append(
                ir.Print(args=call.args, vararg=call.vararg, loc=call.loc)
            )
            rewritten.append(
                ir.Assign(
                    value=ir.Const(None, loc=call.loc),
                    target=inst.target,
                    loc=inst.loc,
                )
            )
        return rewritten
60
+
61
+
62
@register_rewrite("before-inference")
class DetectConstPrintArguments(Rewrite):
    """
    Find print() node arguments that are compile-time constants and attach
    them to the node.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Map each not-yet-processed Print node in *block* to a dict of
        {argument position: inferred constant}.

        Returns True if any constant argument was found.
        """
        self.consts = found = {}
        self.block = block
        for inst in block.find_insts(ir.Print):
            if inst.consts:
                # Already rewritten
                continue
            known = {}
            for position, arg in enumerate(inst.args):
                try:
                    known[position] = func_ir.infer_constant(arg)
                except errors.ConstantInferenceError:
                    pass
            if known:
                found[inst] = known
        return bool(found)

    def apply(self):
        """
        Store the detected constants on their Print nodes (in place) and
        return the same block.
        """
        detected = self.consts
        for inst in self.block.body:
            per_arg = detected.get(inst)
            if per_arg is not None:
                inst.consts = per_arg
        return self.block
@@ -0,0 +1,104 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from collections import defaultdict
5
+
6
+ from numba.cuda import config
7
+
8
+
9
class Rewrite:
    """
    Abstract base class for Numba IR rewrites.

    Subclasses override ``match`` to inspect a block and ``apply`` to
    produce the rewritten block.
    """

    def __init__(self, state=None):
        """Create a rewrite. The base class ignores *state*."""

    def match(self, func_ir, block, typemap, calltypes) -> bool:
        """
        Return True if *block* contains terms this rewrite can transform.

        The base implementation never matches.
        """
        return False

    def apply(self):
        """
        Return the rewritten IR basic block after a successful match.

        Subclasses must override this.
        """
        raise NotImplementedError("Abstract Rewrite.apply() called!")
27
+
28
+
29
class RewriteRegistry:
    """
    Registry of Numba rewrite classes, grouped by the pipeline stage
    ("kind") at which they run.
    """

    _kinds = frozenset(["before-inference", "after-inference"])

    def __init__(self):
        """Start with no rewrites registered for any kind."""
        # kind -> list of Rewrite subclasses, in registration order
        self.rewrites = defaultdict(list)

    def register(self, kind):
        """
        Return a class decorator that records a Rewrite subclass under
        *kind*.

        Raises KeyError immediately for an unknown *kind*. The returned
        decorator raises TypeError for a class that does not subclass
        Rewrite.
        """
        if kind not in self._kinds:
            raise KeyError("invalid kind %r" % (kind,))

        def add(rewrite_cls):
            if not issubclass(rewrite_cls, Rewrite):
                raise TypeError(
                    "{0} is not a subclass of Rewrite".format(rewrite_cls)
                )
            self.rewrites[kind].append(rewrite_cls)
            return rewrite_cls

        return add

    def apply(self, kind, state):
        """
        Exhaustively apply every rewrite registered for *kind* to each basic
        block of ``state.func_ir``. Afterwards, verify the blocks that were
        replaced and run the IR post-processor.
        """
        assert kind in self._kinds
        blocks = state.func_ir.blocks
        # Snapshot of the original block objects, used to detect replacements.
        snapshot = blocks.copy()
        for rewrite_cls in self.rewrites[kind]:
            rewrite = rewrite_cls(state)
            # Blocks are processed LIFO. A rewritten block is pushed back so
            # the same rewrite is retried until it stops matching.
            pending = list(blocks.items())
            while pending:
                label, current = pending.pop()
                if not rewrite.match(
                    state.func_ir, current, state.typemap, state.calltypes
                ):
                    continue
                if config.DEBUG or config.DUMP_IR:
                    print("_" * 70)
                    print("REWRITING (%s):" % rewrite_cls.__name__)
                    current.dump()
                    print("_" * 60)
                replacement = rewrite.apply()
                blocks[label] = replacement
                pending.append((label, replacement))
                if config.DEBUG or config.DUMP_IR:
                    replacement.dump()
                    print("_" * 70)
        # If any blocks were changed, perform a sanity check.
        for label, current in blocks.items():
            if current != snapshot[label]:
                current.verify()

        # Some passes, e.g. _inline_const_arraycall are known to occasionally
        # do invalid things WRT ir.Del, others, e.g. RewriteArrayExprs do valid
        # things with ir.Del, but the placement is not optimal. The lines below
        # fix-up the IR so that ref counts are valid and optimally placed,
        # see #4093 for context. This has to be run here opposed to in
        # apply() as the CFG needs computing so full IR is needed.
        from numba.cuda.core import postproc

        postproc.PostProcessor(state.func_ir).run()
101
+
102
+
103
# Module-wide registry shared by the built-in rewrite passes.
rewrite_registry = RewriteRegistry()
# Decorator form, e.g. ``@register_rewrite("before-inference")``, which adds
# a Rewrite subclass to ``rewrite_registry``.
register_rewrite = rewrite_registry.register
@@ -0,0 +1,41 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from numba.cuda.core import errors
5
+ from numba.cuda.core import ir
6
+ from numba.cuda.core.rewrites import register_rewrite, Rewrite
7
+
8
+
9
@register_rewrite("before-inference")
class DetectStaticBinops(Rewrite):
    """
    Record constant right-hand operands of selected binary operations on
    the expressions themselves (``expr.static_rhs``).
    """

    # Those operators can benefit from a constant-inferred argument.
    # NOTE(review): this set is tested against ``expr.fn``. Confirm that
    # binop expressions carry the operator as the string "**" at this stage;
    # if they carry an ``operator`` function instead, this never matches.
    rhs_operators = {"**"}

    def match(self, func_ir, block, typemap, calltypes):
        """
        Collect the constant-inferable rhs operands of matching binops that
        do not yet have a ``static_rhs``.

        Returns True if any were found.
        """
        # static_lhs is never populated here, but it is still reset and
        # included in the result check.
        self.static_lhs = {}
        self.static_rhs = {}
        self.block = block
        for expr in block.find_exprs(op="binop"):
            if expr.fn not in self.rhs_operators:
                continue
            if expr.static_rhs is not ir.UNDEFINED:
                continue  # a static rhs was already recorded
            try:
                rhs_value = func_ir.infer_constant(expr.rhs)
            except errors.ConstantInferenceError:
                continue
            self.static_rhs[expr] = rhs_value
        return bool(self.static_lhs) or bool(self.static_rhs)

    def apply(self):
        """
        Store the constants detected by match() on their expressions, in
        place, and return the same block.
        """
        for expr in self.static_rhs:
            expr.static_rhs = self.static_rhs[expr]
        return self.block
@@ -0,0 +1,189 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from numba.cuda.core import errors
5
+ from numba.cuda.core import ir
6
+ from numba.cuda import types
7
+ from numba.cuda.core.rewrites import register_rewrite, Rewrite
8
+
9
+
10
@register_rewrite("before-inference")
class RewriteConstGetitems(Rewrite):
    """
    Turn ``getitem(value=arr, index=$constXX)`` expressions, where
    ``$constXX`` can be inferred as a constant, into
    ``static_getitem(value=arr, index=<constant value>)``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Map each getitem expression in *block* that has a constant-inferable
        index to that constant.

        Returns True if any were found.
        """
        self.getitems = candidates = {}
        self.block = block
        for expr in block.find_exprs(op="getitem"):
            if expr.op != "getitem":
                continue
            try:
                candidates[expr] = func_ir.infer_constant(expr.index)
            except errors.ConstantInferenceError:
                pass
        return bool(candidates)

    def apply(self):
        """
        Build a new block in which each matched getitem assignment uses a
        static_getitem expression. All other statements are kept as they are.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            if isinstance(stmt, ir.Assign) and stmt.value in self.getitems:
                old = stmt.value
                stmt = ir.Assign(
                    value=ir.Expr.static_getitem(
                        value=old.value,
                        index=self.getitems[old],
                        index_var=old.index,
                        loc=old.loc,
                    ),
                    target=stmt.target,
                    loc=stmt.loc,
                )
            rewritten.append(stmt)
        return rewritten
55
+
56
+
57
@register_rewrite("after-inference")
class RewriteStringLiteralGetitems(Rewrite):
    """
    Turn ``getitem(value=arr, index=$XX)`` expressions, where ``$XX`` is
    typed as a StringLiteral, into
    ``static_getitem(value=arr, index=<literal value>)``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Find the getitem expressions in *block* whose index is typed as a
        StringLiteral.

        Each is mapped to ``(index variable, literal value)``. Returns True
        if any were found.
        """
        self.getitems = found = {}
        self.block = block
        self.calltypes = calltypes
        for expr in block.find_exprs(op="getitem"):
            if expr.op != "getitem":
                continue
            index_ty = typemap[expr.index.name]
            if not isinstance(index_ty, types.StringLiteral):
                continue
            found[expr] = (expr.index, index_ty.literal_value)
        return bool(found)

    def apply(self):
        """
        Build a new block in which each matched getitem becomes a
        static_getitem indexed by the string's literal value.
        """
        rewritten = ir.Block(self.block.scope, self.block.loc)
        for stmt in self.block.body:
            if isinstance(stmt, ir.Assign) and stmt.value in self.getitems:
                old = stmt.value
                _, literal = self.getitems[old]
                new_expr = ir.Expr.static_getitem(
                    value=old.value,
                    index=literal,
                    index_var=old.index,
                    loc=old.loc,
                )
                # Copy the original getitem's calltypes entry to the new
                # expression.
                self.calltypes[new_expr] = self.calltypes[old]
                stmt = ir.Assign(
                    value=new_expr, target=stmt.target, loc=stmt.loc
                )
            rewritten.append(stmt)
        return rewritten
104
+
105
+
106
@register_rewrite("after-inference")
class RewriteStringLiteralSetitems(Rewrite):
    """
    Turn ``setitem(value=arr, index=$XX, value=)`` statements, where ``$XX``
    is typed as a StringLiteral, into
    ``static_setitem(value=arr, index=<literal value>, value=)``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Find the SetItem statements in *block* whose index is typed as a
        StringLiteral.

        Each is mapped to ``(index variable, literal value)``. Returns True
        if any were found.
        """
        self.setitems = found = {}
        self.block = block
        self.calltypes = calltypes
        for stmt in block.find_insts(ir.SetItem):
            index_ty = typemap[stmt.index.name]
            if not isinstance(index_ty, types.StringLiteral):
                continue
            found[stmt] = (stmt.index, index_ty.literal_value)
        return bool(found)

    def apply(self):
        """
        Build a new block in which each matched SetItem becomes a
        StaticSetItem indexed by the string's literal value.
        """
        rewritten = ir.Block(self.block.scope, self.block.loc)
        for stmt in self.block.body:
            if isinstance(stmt, ir.SetItem) and stmt in self.setitems:
                _, literal = self.setitems[stmt]
                replacement = ir.StaticSetItem(
                    target=stmt.target,
                    index=literal,
                    index_var=stmt.index,
                    value=stmt.value,
                    loc=stmt.loc,
                )
                # Copy the original SetItem's calltypes entry to the new
                # statement.
                self.calltypes[replacement] = self.calltypes[stmt]
                stmt = replacement
            rewritten.append(stmt)
        return rewritten
150
+
151
+
152
@register_rewrite("before-inference")
class RewriteConstSetitems(Rewrite):
    """
    Turn ``setitem(target=arr, index=$constXX, ...)`` statements, where
    ``$constXX`` can be inferred as a constant, into
    ``static_setitem(target=arr, index=<constant value>, ...)``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Map each SetItem statement in *block* that has a constant-inferable
        index to that constant.

        Returns True if any were found.
        """
        self.setitems = candidates = {}
        self.block = block
        for stmt in block.find_insts(ir.SetItem):
            try:
                candidates[stmt] = func_ir.infer_constant(stmt.index)
            except errors.ConstantInferenceError:
                pass
        return bool(candidates)

    def apply(self):
        """
        Build a new block in which each matched SetItem is replaced by a
        StaticSetItem. All other statements are kept as they are.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            if stmt not in self.setitems:
                rewritten.append(stmt)
                continue
            rewritten.append(
                ir.StaticSetItem(
                    target=stmt.target,
                    index=self.setitems[stmt],
                    index_var=stmt.index,
                    value=stmt.value,
                    loc=stmt.loc,
                )
            )
        return rewritten