numba-cuda 0.22.0__cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic.

Files changed (487)
  1. _numba_cuda_redirector.pth +4 -0
  2. _numba_cuda_redirector.py +89 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +6 -0
  5. numba_cuda/_version.py +11 -0
  6. numba_cuda/numba/cuda/__init__.py +70 -0
  7. numba_cuda/numba/cuda/_internal/cuda_bf16.py +16394 -0
  8. numba_cuda/numba/cuda/_internal/cuda_fp16.py +8112 -0
  9. numba_cuda/numba/cuda/api.py +580 -0
  10. numba_cuda/numba/cuda/api_util.py +76 -0
  11. numba_cuda/numba/cuda/args.py +72 -0
  12. numba_cuda/numba/cuda/bf16.py +397 -0
  13. numba_cuda/numba/cuda/cache_hints.py +287 -0
  14. numba_cuda/numba/cuda/cext/__init__.py +2 -0
  15. numba_cuda/numba/cuda/cext/_devicearray.cpp +159 -0
  16. numba_cuda/numba/cuda/cext/_devicearray.cpython-312-aarch64-linux-gnu.so +0 -0
  17. numba_cuda/numba/cuda/cext/_devicearray.h +29 -0
  18. numba_cuda/numba/cuda/cext/_dispatcher.cpp +1098 -0
  19. numba_cuda/numba/cuda/cext/_dispatcher.cpython-312-aarch64-linux-gnu.so +0 -0
  20. numba_cuda/numba/cuda/cext/_hashtable.cpp +532 -0
  21. numba_cuda/numba/cuda/cext/_hashtable.h +135 -0
  22. numba_cuda/numba/cuda/cext/_helperlib.c +71 -0
  23. numba_cuda/numba/cuda/cext/_helperlib.cpython-312-aarch64-linux-gnu.so +0 -0
  24. numba_cuda/numba/cuda/cext/_helpermod.c +82 -0
  25. numba_cuda/numba/cuda/cext/_pymodule.h +38 -0
  26. numba_cuda/numba/cuda/cext/_typeconv.cpp +206 -0
  27. numba_cuda/numba/cuda/cext/_typeconv.cpython-312-aarch64-linux-gnu.so +0 -0
  28. numba_cuda/numba/cuda/cext/_typeof.cpp +1159 -0
  29. numba_cuda/numba/cuda/cext/_typeof.h +19 -0
  30. numba_cuda/numba/cuda/cext/capsulethunk.h +111 -0
  31. numba_cuda/numba/cuda/cext/mviewbuf.c +385 -0
  32. numba_cuda/numba/cuda/cext/mviewbuf.cpython-312-aarch64-linux-gnu.so +0 -0
  33. numba_cuda/numba/cuda/cext/typeconv.cpp +212 -0
  34. numba_cuda/numba/cuda/cext/typeconv.hpp +101 -0
  35. numba_cuda/numba/cuda/cg.py +67 -0
  36. numba_cuda/numba/cuda/cgutils.py +1294 -0
  37. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  38. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  39. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  40. numba_cuda/numba/cuda/codegen.py +541 -0
  41. numba_cuda/numba/cuda/compiler.py +1396 -0
  42. numba_cuda/numba/cuda/core/analysis.py +758 -0
  43. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  44. numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +288 -0
  45. numba_cuda/numba/cuda/core/annotations/type_annotations.py +305 -0
  46. numba_cuda/numba/cuda/core/base.py +1332 -0
  47. numba_cuda/numba/cuda/core/boxing.py +1411 -0
  48. numba_cuda/numba/cuda/core/bytecode.py +728 -0
  49. numba_cuda/numba/cuda/core/byteflow.py +2346 -0
  50. numba_cuda/numba/cuda/core/caching.py +744 -0
  51. numba_cuda/numba/cuda/core/callconv.py +392 -0
  52. numba_cuda/numba/cuda/core/codegen.py +171 -0
  53. numba_cuda/numba/cuda/core/compiler.py +199 -0
  54. numba_cuda/numba/cuda/core/compiler_lock.py +85 -0
  55. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  56. numba_cuda/numba/cuda/core/config.py +650 -0
  57. numba_cuda/numba/cuda/core/consts.py +124 -0
  58. numba_cuda/numba/cuda/core/controlflow.py +989 -0
  59. numba_cuda/numba/cuda/core/entrypoints.py +57 -0
  60. numba_cuda/numba/cuda/core/environment.py +66 -0
  61. numba_cuda/numba/cuda/core/errors.py +917 -0
  62. numba_cuda/numba/cuda/core/event.py +511 -0
  63. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  64. numba_cuda/numba/cuda/core/generators.py +387 -0
  65. numba_cuda/numba/cuda/core/imputils.py +509 -0
  66. numba_cuda/numba/cuda/core/inline_closurecall.py +1787 -0
  67. numba_cuda/numba/cuda/core/interpreter.py +3617 -0
  68. numba_cuda/numba/cuda/core/ir.py +1812 -0
  69. numba_cuda/numba/cuda/core/ir_utils.py +2638 -0
  70. numba_cuda/numba/cuda/core/optional.py +129 -0
  71. numba_cuda/numba/cuda/core/options.py +262 -0
  72. numba_cuda/numba/cuda/core/postproc.py +249 -0
  73. numba_cuda/numba/cuda/core/pythonapi.py +1859 -0
  74. numba_cuda/numba/cuda/core/registry.py +46 -0
  75. numba_cuda/numba/cuda/core/removerefctpass.py +123 -0
  76. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  77. numba_cuda/numba/cuda/core/rewrites/ir_print.py +91 -0
  78. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  79. numba_cuda/numba/cuda/core/rewrites/static_binop.py +41 -0
  80. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +189 -0
  81. numba_cuda/numba/cuda/core/rewrites/static_raise.py +100 -0
  82. numba_cuda/numba/cuda/core/sigutils.py +68 -0
  83. numba_cuda/numba/cuda/core/ssa.py +498 -0
  84. numba_cuda/numba/cuda/core/targetconfig.py +330 -0
  85. numba_cuda/numba/cuda/core/tracing.py +231 -0
  86. numba_cuda/numba/cuda/core/transforms.py +956 -0
  87. numba_cuda/numba/cuda/core/typed_passes.py +867 -0
  88. numba_cuda/numba/cuda/core/typeinfer.py +1950 -0
  89. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  90. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  91. numba_cuda/numba/cuda/core/unsafe/eh.py +67 -0
  92. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  93. numba_cuda/numba/cuda/core/untyped_passes.py +1979 -0
  94. numba_cuda/numba/cuda/cpython/builtins.py +1153 -0
  95. numba_cuda/numba/cuda/cpython/charseq.py +1218 -0
  96. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  97. numba_cuda/numba/cuda/cpython/enumimpl.py +103 -0
  98. numba_cuda/numba/cuda/cpython/iterators.py +167 -0
  99. numba_cuda/numba/cuda/cpython/listobj.py +1326 -0
  100. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  101. numba_cuda/numba/cuda/cpython/numbers.py +1475 -0
  102. numba_cuda/numba/cuda/cpython/rangeobj.py +289 -0
  103. numba_cuda/numba/cuda/cpython/slicing.py +322 -0
  104. numba_cuda/numba/cuda/cpython/tupleobj.py +456 -0
  105. numba_cuda/numba/cuda/cpython/unicode.py +2865 -0
  106. numba_cuda/numba/cuda/cpython/unicode_support.py +1597 -0
  107. numba_cuda/numba/cuda/cpython/unsafe/__init__.py +0 -0
  108. numba_cuda/numba/cuda/cpython/unsafe/numbers.py +64 -0
  109. numba_cuda/numba/cuda/cpython/unsafe/tuple.py +92 -0
  110. numba_cuda/numba/cuda/cuda_paths.py +691 -0
  111. numba_cuda/numba/cuda/cudadecl.py +543 -0
  112. numba_cuda/numba/cuda/cudadrv/__init__.py +14 -0
  113. numba_cuda/numba/cuda/cudadrv/devicearray.py +954 -0
  114. numba_cuda/numba/cuda/cudadrv/devices.py +249 -0
  115. numba_cuda/numba/cuda/cudadrv/driver.py +3238 -0
  116. numba_cuda/numba/cuda/cudadrv/drvapi.py +435 -0
  117. numba_cuda/numba/cuda/cudadrv/dummyarray.py +562 -0
  118. numba_cuda/numba/cuda/cudadrv/enums.py +613 -0
  119. numba_cuda/numba/cuda/cudadrv/error.py +48 -0
  120. numba_cuda/numba/cuda/cudadrv/libs.py +220 -0
  121. numba_cuda/numba/cuda/cudadrv/linkable_code.py +184 -0
  122. numba_cuda/numba/cuda/cudadrv/mappings.py +14 -0
  123. numba_cuda/numba/cuda/cudadrv/ndarray.py +26 -0
  124. numba_cuda/numba/cuda/cudadrv/nvrtc.py +193 -0
  125. numba_cuda/numba/cuda/cudadrv/nvvm.py +756 -0
  126. numba_cuda/numba/cuda/cudadrv/rtapi.py +13 -0
  127. numba_cuda/numba/cuda/cudadrv/runtime.py +34 -0
  128. numba_cuda/numba/cuda/cudaimpl.py +983 -0
  129. numba_cuda/numba/cuda/cudamath.py +149 -0
  130. numba_cuda/numba/cuda/datamodel/__init__.py +7 -0
  131. numba_cuda/numba/cuda/datamodel/cuda_manager.py +66 -0
  132. numba_cuda/numba/cuda/datamodel/cuda_models.py +1446 -0
  133. numba_cuda/numba/cuda/datamodel/cuda_packer.py +224 -0
  134. numba_cuda/numba/cuda/datamodel/cuda_registry.py +22 -0
  135. numba_cuda/numba/cuda/datamodel/cuda_testing.py +153 -0
  136. numba_cuda/numba/cuda/datamodel/manager.py +11 -0
  137. numba_cuda/numba/cuda/datamodel/models.py +9 -0
  138. numba_cuda/numba/cuda/datamodel/packer.py +9 -0
  139. numba_cuda/numba/cuda/datamodel/registry.py +11 -0
  140. numba_cuda/numba/cuda/datamodel/testing.py +11 -0
  141. numba_cuda/numba/cuda/debuginfo.py +997 -0
  142. numba_cuda/numba/cuda/decorators.py +294 -0
  143. numba_cuda/numba/cuda/descriptor.py +35 -0
  144. numba_cuda/numba/cuda/device_init.py +155 -0
  145. numba_cuda/numba/cuda/deviceufunc.py +1021 -0
  146. numba_cuda/numba/cuda/dispatcher.py +2463 -0
  147. numba_cuda/numba/cuda/errors.py +72 -0
  148. numba_cuda/numba/cuda/extending.py +697 -0
  149. numba_cuda/numba/cuda/flags.py +178 -0
  150. numba_cuda/numba/cuda/fp16.py +357 -0
  151. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  152. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  153. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  154. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  155. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  156. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  157. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  158. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  159. numba_cuda/numba/cuda/initialize.py +24 -0
  160. numba_cuda/numba/cuda/intrinsics.py +531 -0
  161. numba_cuda/numba/cuda/itanium_mangler.py +214 -0
  162. numba_cuda/numba/cuda/kernels/__init__.py +2 -0
  163. numba_cuda/numba/cuda/kernels/reduction.py +265 -0
  164. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  165. numba_cuda/numba/cuda/libdevice.py +3386 -0
  166. numba_cuda/numba/cuda/libdevicedecl.py +20 -0
  167. numba_cuda/numba/cuda/libdevicefuncs.py +1060 -0
  168. numba_cuda/numba/cuda/libdeviceimpl.py +88 -0
  169. numba_cuda/numba/cuda/locks.py +19 -0
  170. numba_cuda/numba/cuda/lowering.py +1980 -0
  171. numba_cuda/numba/cuda/mathimpl.py +374 -0
  172. numba_cuda/numba/cuda/memory_management/__init__.py +4 -0
  173. numba_cuda/numba/cuda/memory_management/memsys.cu +99 -0
  174. numba_cuda/numba/cuda/memory_management/memsys.cuh +22 -0
  175. numba_cuda/numba/cuda/memory_management/nrt.cu +212 -0
  176. numba_cuda/numba/cuda/memory_management/nrt.cuh +48 -0
  177. numba_cuda/numba/cuda/memory_management/nrt.py +390 -0
  178. numba_cuda/numba/cuda/memory_management/nrt_context.py +438 -0
  179. numba_cuda/numba/cuda/misc/appdirs.py +594 -0
  180. numba_cuda/numba/cuda/misc/cffiimpl.py +24 -0
  181. numba_cuda/numba/cuda/misc/coverage_support.py +43 -0
  182. numba_cuda/numba/cuda/misc/dump_style.py +41 -0
  183. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  184. numba_cuda/numba/cuda/misc/firstlinefinder.py +96 -0
  185. numba_cuda/numba/cuda/misc/gdb_hook.py +240 -0
  186. numba_cuda/numba/cuda/misc/literal.py +28 -0
  187. numba_cuda/numba/cuda/misc/llvm_pass_timings.py +412 -0
  188. numba_cuda/numba/cuda/misc/special.py +94 -0
  189. numba_cuda/numba/cuda/models.py +56 -0
  190. numba_cuda/numba/cuda/np/arraymath.py +5130 -0
  191. numba_cuda/numba/cuda/np/arrayobj.py +7635 -0
  192. numba_cuda/numba/cuda/np/extensions.py +11 -0
  193. numba_cuda/numba/cuda/np/linalg.py +3087 -0
  194. numba_cuda/numba/cuda/np/math/__init__.py +0 -0
  195. numba_cuda/numba/cuda/np/math/cmathimpl.py +558 -0
  196. numba_cuda/numba/cuda/np/math/mathimpl.py +487 -0
  197. numba_cuda/numba/cuda/np/math/numbers.py +1461 -0
  198. numba_cuda/numba/cuda/np/npdatetime.py +969 -0
  199. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  200. numba_cuda/numba/cuda/np/npyfuncs.py +1808 -0
  201. numba_cuda/numba/cuda/np/npyimpl.py +1027 -0
  202. numba_cuda/numba/cuda/np/numpy_support.py +798 -0
  203. numba_cuda/numba/cuda/np/polynomial/__init__.py +4 -0
  204. numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +242 -0
  205. numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +380 -0
  206. numba_cuda/numba/cuda/np/ufunc/__init__.py +4 -0
  207. numba_cuda/numba/cuda/np/ufunc/decorators.py +203 -0
  208. numba_cuda/numba/cuda/np/ufunc/sigparse.py +68 -0
  209. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +65 -0
  210. numba_cuda/numba/cuda/np/ufunc_db.py +1282 -0
  211. numba_cuda/numba/cuda/np/unsafe/__init__.py +0 -0
  212. numba_cuda/numba/cuda/np/unsafe/ndarray.py +84 -0
  213. numba_cuda/numba/cuda/nvvmutils.py +254 -0
  214. numba_cuda/numba/cuda/printimpl.py +126 -0
  215. numba_cuda/numba/cuda/random.py +308 -0
  216. numba_cuda/numba/cuda/reshape_funcs.cu +156 -0
  217. numba_cuda/numba/cuda/serialize.py +267 -0
  218. numba_cuda/numba/cuda/simulator/__init__.py +63 -0
  219. numba_cuda/numba/cuda/simulator/_internal/__init__.py +4 -0
  220. numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
  221. numba_cuda/numba/cuda/simulator/api.py +179 -0
  222. numba_cuda/numba/cuda/simulator/bf16.py +4 -0
  223. numba_cuda/numba/cuda/simulator/compiler.py +38 -0
  224. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +11 -0
  225. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +462 -0
  226. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +122 -0
  227. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +66 -0
  228. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +7 -0
  229. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +7 -0
  230. numba_cuda/numba/cuda/simulator/cudadrv/error.py +10 -0
  231. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +10 -0
  232. numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +61 -0
  233. numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +11 -0
  234. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +32 -0
  235. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +22 -0
  236. numba_cuda/numba/cuda/simulator/dispatcher.py +11 -0
  237. numba_cuda/numba/cuda/simulator/kernel.py +320 -0
  238. numba_cuda/numba/cuda/simulator/kernelapi.py +509 -0
  239. numba_cuda/numba/cuda/simulator/memory_management/__init__.py +4 -0
  240. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +21 -0
  241. numba_cuda/numba/cuda/simulator/reduction.py +19 -0
  242. numba_cuda/numba/cuda/simulator/tests/support.py +4 -0
  243. numba_cuda/numba/cuda/simulator/vector_types.py +65 -0
  244. numba_cuda/numba/cuda/simulator_init.py +18 -0
  245. numba_cuda/numba/cuda/stubs.py +624 -0
  246. numba_cuda/numba/cuda/target.py +505 -0
  247. numba_cuda/numba/cuda/testing.py +347 -0
  248. numba_cuda/numba/cuda/tests/__init__.py +62 -0
  249. numba_cuda/numba/cuda/tests/benchmarks/__init__.py +0 -0
  250. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +119 -0
  251. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  252. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +113 -0
  253. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
  254. numba_cuda/numba/cuda/tests/core/test_serialize.py +371 -0
  255. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +9 -0
  256. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +147 -0
  257. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +161 -0
  258. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +397 -0
  259. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +24 -0
  260. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +180 -0
  261. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +313 -0
  262. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +191 -0
  263. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +621 -0
  264. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +247 -0
  265. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +100 -0
  266. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +200 -0
  267. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +53 -0
  268. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +72 -0
  269. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +138 -0
  270. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +43 -0
  271. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +15 -0
  272. numba_cuda/numba/cuda/tests/cudadrv/test_linkable_code.py +58 -0
  273. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +348 -0
  274. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +128 -0
  275. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +301 -0
  276. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +174 -0
  277. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +28 -0
  278. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +185 -0
  279. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +39 -0
  280. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +23 -0
  281. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +38 -0
  282. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +48 -0
  283. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +44 -0
  284. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +127 -0
  285. numba_cuda/numba/cuda/tests/cudapy/__init__.py +9 -0
  286. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +231 -0
  287. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +50 -0
  288. numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +36 -0
  289. numba_cuda/numba/cuda/tests/cudapy/complex_usecases.py +116 -0
  290. numba_cuda/numba/cuda/tests/cudapy/enum_usecases.py +59 -0
  291. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +62 -0
  292. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +28 -0
  293. numba_cuda/numba/cuda/tests/cudapy/overload_usecases.py +33 -0
  294. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +104 -0
  295. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +47 -0
  296. numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +1122 -0
  297. numba_cuda/numba/cuda/tests/cudapy/test_array.py +344 -0
  298. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +268 -0
  299. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +203 -0
  300. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +63 -0
  301. numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +360 -0
  302. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1815 -0
  303. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +599 -0
  304. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +377 -0
  305. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +160 -0
  306. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +27 -0
  307. numba_cuda/numba/cuda/tests/cudapy/test_byteflow.py +98 -0
  308. numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py +210 -0
  309. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +683 -0
  310. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +265 -0
  311. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +42 -0
  312. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +718 -0
  313. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +370 -0
  314. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +23 -0
  315. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +142 -0
  316. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +178 -0
  317. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +193 -0
  318. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +131 -0
  319. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +438 -0
  320. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +94 -0
  321. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +101 -0
  322. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +105 -0
  323. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +978 -0
  324. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
  325. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +500 -0
  326. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +820 -0
  327. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +152 -0
  328. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +111 -0
  329. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +170 -0
  330. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1088 -0
  331. numba_cuda/numba/cuda/tests/cudapy/test_extending_types.py +71 -0
  332. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +265 -0
  333. numba_cuda/numba/cuda/tests/cudapy/test_flow_control.py +1433 -0
  334. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +57 -0
  335. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +34 -0
  336. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +69 -0
  337. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +62 -0
  338. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +474 -0
  339. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +167 -0
  340. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +92 -0
  341. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +39 -0
  342. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +170 -0
  343. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +255 -0
  344. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1219 -0
  345. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +263 -0
  346. numba_cuda/numba/cuda/tests/cudapy/test_ir.py +598 -0
  347. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +276 -0
  348. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +101 -0
  349. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +68 -0
  350. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +123 -0
  351. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +194 -0
  352. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +220 -0
  353. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +173 -0
  354. numba_cuda/numba/cuda/tests/cudapy/test_make_function_to_jit_function.py +364 -0
  355. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +47 -0
  356. numba_cuda/numba/cuda/tests/cudapy/test_math.py +842 -0
  357. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +76 -0
  358. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +78 -0
  359. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +25 -0
  360. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +145 -0
  361. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +39 -0
  362. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +82 -0
  363. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +53 -0
  364. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +504 -0
  365. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +93 -0
  366. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +402 -0
  367. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +128 -0
  368. numba_cuda/numba/cuda/tests/cudapy/test_print.py +193 -0
  369. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +37 -0
  370. numba_cuda/numba/cuda/tests/cudapy/test_random.py +117 -0
  371. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +614 -0
  372. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +130 -0
  373. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +94 -0
  374. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  375. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +86 -0
  376. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +40 -0
  377. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +457 -0
  378. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +233 -0
  379. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +454 -0
  380. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +56 -0
  381. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +277 -0
  382. numba_cuda/numba/cuda/tests/cudapy/test_tracing.py +200 -0
  383. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +90 -0
  384. numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +333 -0
  385. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  386. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +585 -0
  387. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +42 -0
  388. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +485 -0
  389. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +312 -0
  390. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +23 -0
  391. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +183 -0
  392. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +40 -0
  393. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +40 -0
  394. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +206 -0
  395. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +446 -0
  396. numba_cuda/numba/cuda/tests/cudasim/__init__.py +9 -0
  397. numba_cuda/numba/cuda/tests/cudasim/support.py +9 -0
  398. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +111 -0
  399. numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
  400. numba_cuda/numba/cuda/tests/data/cta_barrier.cu +28 -0
  401. numba_cuda/numba/cuda/tests/data/cuda_include.cu +10 -0
  402. numba_cuda/numba/cuda/tests/data/error.cu +12 -0
  403. numba_cuda/numba/cuda/tests/data/include/add.cuh +8 -0
  404. numba_cuda/numba/cuda/tests/data/jitlink.cu +28 -0
  405. numba_cuda/numba/cuda/tests/data/jitlink.ptx +49 -0
  406. numba_cuda/numba/cuda/tests/data/warn.cu +12 -0
  407. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +9 -0
  408. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
  409. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +54 -0
  410. numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +8 -0
  411. numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +14 -0
  412. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +86 -0
  413. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +68 -0
  414. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +81 -0
  415. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +141 -0
  416. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +160 -0
  417. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +180 -0
  418. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +119 -0
  419. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +66 -0
  420. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +80 -0
  421. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +206 -0
  422. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +53 -0
  423. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +76 -0
  424. numba_cuda/numba/cuda/tests/nocuda/__init__.py +9 -0
  425. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +452 -0
  426. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +48 -0
  427. numba_cuda/numba/cuda/tests/nocuda/test_import.py +63 -0
  428. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +252 -0
  429. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +59 -0
  430. numba_cuda/numba/cuda/tests/nrt/__init__.py +9 -0
  431. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +387 -0
  432. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +124 -0
  433. numba_cuda/numba/cuda/tests/support.py +900 -0
  434. numba_cuda/numba/cuda/typeconv/__init__.py +4 -0
  435. numba_cuda/numba/cuda/typeconv/castgraph.py +137 -0
  436. numba_cuda/numba/cuda/typeconv/rules.py +63 -0
  437. numba_cuda/numba/cuda/typeconv/typeconv.py +121 -0
  438. numba_cuda/numba/cuda/types/__init__.py +233 -0
  439. numba_cuda/numba/cuda/types/__init__.pyi +167 -0
  440. numba_cuda/numba/cuda/types/abstract.py +9 -0
  441. numba_cuda/numba/cuda/types/common.py +9 -0
  442. numba_cuda/numba/cuda/types/containers.py +9 -0
  443. numba_cuda/numba/cuda/types/cuda_abstract.py +533 -0
  444. numba_cuda/numba/cuda/types/cuda_common.py +110 -0
  445. numba_cuda/numba/cuda/types/cuda_containers.py +971 -0
  446. numba_cuda/numba/cuda/types/cuda_function_type.py +230 -0
  447. numba_cuda/numba/cuda/types/cuda_functions.py +798 -0
  448. numba_cuda/numba/cuda/types/cuda_iterators.py +120 -0
  449. numba_cuda/numba/cuda/types/cuda_misc.py +569 -0
  450. numba_cuda/numba/cuda/types/cuda_npytypes.py +690 -0
  451. numba_cuda/numba/cuda/types/cuda_scalars.py +280 -0
  452. numba_cuda/numba/cuda/types/ext_types.py +101 -0
  453. numba_cuda/numba/cuda/types/function_type.py +11 -0
  454. numba_cuda/numba/cuda/types/functions.py +9 -0
  455. numba_cuda/numba/cuda/types/iterators.py +9 -0
  456. numba_cuda/numba/cuda/types/misc.py +9 -0
  457. numba_cuda/numba/cuda/types/npytypes.py +9 -0
  458. numba_cuda/numba/cuda/types/scalars.py +9 -0
  459. numba_cuda/numba/cuda/typing/__init__.py +19 -0
  460. numba_cuda/numba/cuda/typing/arraydecl.py +939 -0
  461. numba_cuda/numba/cuda/typing/asnumbatype.py +130 -0
  462. numba_cuda/numba/cuda/typing/bufproto.py +70 -0
  463. numba_cuda/numba/cuda/typing/builtins.py +1209 -0
  464. numba_cuda/numba/cuda/typing/cffi_utils.py +219 -0
  465. numba_cuda/numba/cuda/typing/cmathdecl.py +47 -0
  466. numba_cuda/numba/cuda/typing/collections.py +138 -0
  467. numba_cuda/numba/cuda/typing/context.py +782 -0
  468. numba_cuda/numba/cuda/typing/ctypes_utils.py +125 -0
  469. numba_cuda/numba/cuda/typing/dictdecl.py +63 -0
  470. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  471. numba_cuda/numba/cuda/typing/listdecl.py +147 -0
  472. numba_cuda/numba/cuda/typing/mathdecl.py +158 -0
  473. numba_cuda/numba/cuda/typing/npdatetime.py +322 -0
  474. numba_cuda/numba/cuda/typing/npydecl.py +749 -0
  475. numba_cuda/numba/cuda/typing/setdecl.py +115 -0
  476. numba_cuda/numba/cuda/typing/templates.py +1446 -0
  477. numba_cuda/numba/cuda/typing/typeof.py +301 -0
  478. numba_cuda/numba/cuda/ufuncs.py +746 -0
  479. numba_cuda/numba/cuda/utils.py +724 -0
  480. numba_cuda/numba/cuda/vector_types.py +214 -0
  481. numba_cuda/numba/cuda/vectorizers.py +260 -0
  482. numba_cuda-0.22.0.dist-info/METADATA +109 -0
  483. numba_cuda-0.22.0.dist-info/RECORD +487 -0
  484. numba_cuda-0.22.0.dist-info/WHEEL +6 -0
  485. numba_cuda-0.22.0.dist-info/licenses/LICENSE +26 -0
  486. numba_cuda-0.22.0.dist-info/licenses/LICENSE.numba +24 -0
  487. numba_cuda-0.22.0.dist-info/top_level.txt +1 -0
numba_cuda/numba/cuda/dispatcher.py
@@ -0,0 +1,2463 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ import numpy as np
5
+ import os
6
+ import sys
7
+ import ctypes
8
+ import collections
9
+ import functools
10
+ import types as pytypes
11
+ import weakref
12
+ from contextlib import ExitStack
13
+ from abc import abstractmethod
14
+ import uuid
15
+ import re
16
+ from warnings import warn
17
+
18
+ from numba.cuda.core import errors
19
+ from numba.cuda import serialize, utils
20
+ from numba import cuda
21
+
22
+ from numba.cuda.core.compiler_lock import global_compiler_lock
23
+ from numba.cuda.typeconv.rules import default_type_manager
24
+ from numba.cuda.typing.templates import fold_arguments
25
+ from numba.cuda.typing.typeof import Purpose, typeof
26
+
27
+ from numba.cuda import typing, types
28
+ from numba.cuda.types import ext_types
29
+ from numba.cuda.api import get_current_device
30
+ from numba.cuda.args import wrap_arg
31
+ from numba.cuda.core.bytecode import get_code_object
32
+ from numba.cuda.compiler import (
33
+ compile_cuda,
34
+ CUDACompiler,
35
+ kernel_fixup,
36
+ compile_extra,
37
+ compile_ir,
38
+ )
39
+ from numba.cuda.core import sigutils, config, entrypoints
40
+ from numba.cuda.flags import Flags
41
+ from numba.cuda.cudadrv import driver, nvvm
42
+ from numba.cuda.locks import module_init_lock
43
+ from numba.cuda.core.caching import Cache, CacheImpl, NullCache
44
+ from numba.cuda.descriptor import cuda_target
45
+ from numba.cuda.errors import (
46
+ missing_launch_config_msg,
47
+ normalize_kernel_dimensions,
48
+ )
49
+ from numba.cuda.cudadrv.linkable_code import LinkableCode
50
+ from numba.cuda.cudadrv.devices import get_context
51
+ from numba.cuda.memory_management.nrt import rtsys, NRT_LIBRARY
52
+ import numba.cuda.core.event as ev
53
+ from numba.cuda.cext import _dispatcher
54
+
55
+
56
+ cuda_fp16_math_funcs = [
57
+ "hsin",
58
+ "hcos",
59
+ "hlog",
60
+ "hlog10",
61
+ "hlog2",
62
+ "hexp",
63
+ "hexp10",
64
+ "hexp2",
65
+ "hsqrt",
66
+ "hrsqrt",
67
+ "hfloor",
68
+ "hceil",
69
+ "hrcp",
70
+ "hrint",
71
+ "htrunc",
72
+ "hdiv",
73
+ ]
74
+
75
+ reshape_funcs = ["nocopy_empty_reshape", "numba_attempt_nocopy_reshape"]
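The two module-level lists above name device functions that may need to be linked into a kernel: the half-precision math intrinsics and the array-reshape helpers. A minimal sketch of a kernel that would pull in one of the fp16 intrinsics, assuming a CUDA-capable device and the cuda.fp16 bindings shipped with this package:

    from numba import cuda
    import numpy as np

    @cuda.jit
    def half_sin(x, out):
        i = cuda.grid(1)
        if i < x.size:
            # hsin is one of the cuda_fp16_math_funcs listed above
            out[i] = cuda.fp16.hsin(x[i])

    x = np.linspace(0, 1, 32).astype(np.float16)
    out = np.zeros_like(x)
    half_sin[1, 32](x, out)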
76
+
77
+
78
+ class _Kernel(serialize.ReduceMixin):
79
+ """
80
+ CUDA Kernel specialized for a given set of argument types. When called, this
81
+ object launches the kernel on the device.
82
+ """
83
+
84
+ NRT_functions = [
85
+ "NRT_Allocate",
86
+ "NRT_MemInfo_init",
87
+ "NRT_MemInfo_new",
88
+ "NRT_Free",
89
+ "NRT_dealloc",
90
+ "NRT_MemInfo_destroy",
91
+ "NRT_MemInfo_call_dtor",
92
+ "NRT_MemInfo_data_fast",
93
+ "NRT_MemInfo_alloc_aligned",
94
+ "NRT_Allocate_External",
95
+ "NRT_decref",
96
+ "NRT_incref",
97
+ ]
98
+
99
+ @global_compiler_lock
100
+ def __init__(
101
+ self,
102
+ py_func,
103
+ argtypes,
104
+ link=None,
105
+ debug=False,
106
+ lineinfo=False,
107
+ inline=False,
108
+ forceinline=False,
109
+ fastmath=False,
110
+ extensions=None,
111
+ max_registers=None,
112
+ lto=False,
113
+ opt=True,
114
+ device=False,
115
+ launch_bounds=None,
116
+ ):
117
+ if device:
118
+ raise RuntimeError("Cannot compile a device function as a kernel")
119
+
120
+ super().__init__()
121
+
122
+ # _DispatcherBase.nopython_signatures() expects this attribute to be
123
+ # present, because it assumes an overload is a CompileResult. In the
124
+ # CUDA target, _Kernel instances are stored instead, so we provide this
125
+ # attribute here to avoid duplicating nopython_signatures() in the CUDA
126
+ # target with slight modifications.
127
+ self.objectmode = False
128
+
129
+ # The finalizer constructed by _DispatcherBase._make_finalizer also
130
+ # expects overloads to be a CompileResult. It uses the entry_point to
131
+ # remove a CompileResult from a target context. However, since we never
132
+ # insert kernels into a target context (there is no need because they
133
+ # cannot be called by other functions, only through the dispatcher) it
134
+ # suffices to pretend we have an entry point of None.
135
+ self.entry_point = None
136
+
137
+ self.py_func = py_func
138
+ self.argtypes = argtypes
139
+ self.debug = debug
140
+ self.lineinfo = lineinfo
141
+ self.extensions = extensions or []
142
+ self.launch_bounds = launch_bounds
143
+
144
+ nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0}
145
+
146
+ if debug:
147
+ nvvm_options["g"] = None
148
+
149
+ cc = get_current_device().compute_capability
150
+
151
+ cres = compile_cuda(
152
+ self.py_func,
153
+ types.void,
154
+ self.argtypes,
155
+ debug=self.debug,
156
+ lineinfo=lineinfo,
157
+ forceinline=forceinline,
158
+ fastmath=fastmath,
159
+ nvvm_options=nvvm_options,
160
+ cc=cc,
161
+ max_registers=max_registers,
162
+ lto=lto,
163
+ )
164
+ tgt_ctx = cres.target_context
165
+ lib = cres.library
166
+ kernel = lib.get_function(cres.fndesc.llvm_func_name)
167
+ lib._entry_name = cres.fndesc.llvm_func_name
168
+ kernel_fixup(kernel, self.debug)
169
+ nvvm.set_launch_bounds(kernel, launch_bounds)
170
+
171
+ if not link:
172
+ link = []
173
+
174
+ asm = lib.get_asm_str()
175
+
176
+ # The code library contains functions that require cooperative launch.
177
+ self.cooperative = lib.use_cooperative
178
+ # We need to link against cudadevrt if grid sync is being used.
179
+ if self.cooperative:
180
+ lib.needs_cudadevrt = True
181
+
182
+ def link_to_library_functions(
183
+ library_functions, library_path, prefix=None
184
+ ):
185
+ """
186
+ Dynamically links to library functions by searching for their names
187
+ in the specified library and linking to the corresponding source
188
+ file.
189
+ """
190
+ if prefix is not None:
191
+ library_functions = [
192
+ f"{prefix}{fn}" for fn in library_functions
193
+ ]
194
+
195
+ found_functions = [fn for fn in library_functions if f"{fn}" in asm]
196
+
197
+ if found_functions:
198
+ basedir = os.path.dirname(os.path.abspath(__file__))
199
+ source_file_path = os.path.join(basedir, library_path)
200
+ link.append(source_file_path)
201
+
202
+ return found_functions
203
+
204
+ # Link to the helper library functions if needed
205
+ link_to_library_functions(reshape_funcs, "reshape_funcs.cu")
206
+
207
+ self.maybe_link_nrt(link, tgt_ctx, asm)
208
+
209
+ for filepath in link:
210
+ lib.add_linking_file(filepath)
211
+
212
+ # populate members
213
+ self.entry_name = kernel.name
214
+ self.signature = cres.signature
215
+ self._type_annotation = cres.type_annotation
216
+ self._codelibrary = lib
217
+ self.call_helper = cres.call_helper
218
+
219
+ # The following are referred to by the cache implementation. Note:
220
+ # - There are no referenced environments in CUDA.
221
+ # - Kernels don't have lifted code.
222
+ self.target_context = tgt_ctx
223
+ self.fndesc = cres.fndesc
224
+ self.environment = cres.environment
225
+ self._referenced_environments = []
226
+ self.lifted = []
227
+
228
+ def maybe_link_nrt(self, link, tgt_ctx, asm):
229
+ """
230
+ Add the NRT source code to the link if the necessary conditions are met.
231
+ NRT must be enabled for the CUDATargetContext, and either NRT functions
232
+ must be detected in the kernel asm or an NRT enabled LinkableCode object
233
+ must be passed.
234
+ """
235
+
236
+ if not tgt_ctx.enable_nrt:
237
+ return
238
+
239
+ all_nrt = "|".join(self.NRT_functions)
240
+ pattern = (
241
+ r"\.extern\s+\.func\s+(?:\s*\(.+\)\s*)?("
242
+ + all_nrt
243
+ + r")\s*\([^)]*\)\s*;"
244
+ )
245
+ link_nrt = False
246
+ nrt_in_asm = re.findall(pattern, asm)
247
+ if len(nrt_in_asm) > 0:
248
+ link_nrt = True
249
+ if not link_nrt:
250
+ for file in link:
251
+ if isinstance(file, LinkableCode):
252
+ if file.nrt:
253
+ link_nrt = True
254
+ break
255
+
256
+ if link_nrt:
257
+ link.append(NRT_LIBRARY)
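The regular expression built in maybe_link_nrt() looks for .extern .func declarations of NRT entry points in the generated PTX. A self-contained illustration of what it matches; the PTX line below is an invented example, not taken from this package:

    import re

    NRT_functions = ["NRT_Allocate", "NRT_incref", "NRT_decref"]
    all_nrt = "|".join(NRT_functions)
    pattern = (
        r"\.extern\s+\.func\s+(?:\s*\(.+\)\s*)?(" + all_nrt + r")\s*\([^)]*\)\s*;"
    )

    # A declaration like this in a kernel's PTX means the NRT helper library
    # has to be added to the link.
    asm = ".extern .func (.param .b64 retval0) NRT_Allocate (.param .b64 p0);"
    print(re.findall(pattern, asm))   # ['NRT_Allocate']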
258
+
259
+ @property
260
+ def library(self):
261
+ return self._codelibrary
262
+
263
+ @property
264
+ def type_annotation(self):
265
+ return self._type_annotation
266
+
267
+ def _find_referenced_environments(self):
268
+ return self._referenced_environments
269
+
270
+ @property
271
+ def codegen(self):
272
+ return self.target_context.codegen()
273
+
274
+ @property
275
+ def argument_types(self):
276
+ return tuple(self.signature.args)
277
+
278
+ @classmethod
279
+ def _rebuild(
280
+ cls,
281
+ cooperative,
282
+ name,
283
+ signature,
284
+ codelibrary,
285
+ debug,
286
+ lineinfo,
287
+ call_helper,
288
+ extensions,
289
+ ):
290
+ """
291
+ Rebuild an instance.
292
+ """
293
+ instance = cls.__new__(cls)
294
+ # invoke parent constructor
295
+ super(cls, instance).__init__()
296
+ # populate members
297
+ instance.entry_point = None
298
+ instance.cooperative = cooperative
299
+ instance.entry_name = name
300
+ instance.signature = signature
301
+ instance._type_annotation = None
302
+ instance._codelibrary = codelibrary
303
+ instance.debug = debug
304
+ instance.lineinfo = lineinfo
305
+ instance.call_helper = call_helper
306
+ instance.extensions = extensions
307
+ return instance
308
+
309
+ def _reduce_states(self):
310
+ """
311
+ Reduce the instance for serialization.
312
+ Compiled definitions are serialized in PTX form.
313
+ Type annotations are discarded.
314
+ Thread, block and shared memory configuration are serialized.
315
+ Stream information is discarded.
316
+ """
317
+ return dict(
318
+ cooperative=self.cooperative,
319
+ name=self.entry_name,
320
+ signature=self.signature,
321
+ codelibrary=self._codelibrary,
322
+ debug=self.debug,
323
+ lineinfo=self.lineinfo,
324
+ call_helper=self.call_helper,
325
+ extensions=self.extensions,
326
+ )
327
+
328
+ @module_init_lock
329
+ def initialize_once(self, mod):
330
+ if not mod.initialized:
331
+ mod.setup()
332
+
333
+ def bind(self):
334
+ """
335
+ Force binding to current CUDA context
336
+ """
337
+ cufunc = self._codelibrary.get_cufunc()
338
+
339
+ self.initialize_once(cufunc.module)
340
+
341
+ if (
342
+ hasattr(self, "target_context")
343
+ and self.target_context.enable_nrt
344
+ and config.CUDA_NRT_STATS
345
+ ):
346
+ rtsys.ensure_initialized()
347
+ rtsys.set_memsys_to_module(cufunc.module)
348
+ # We don't know which stream the kernel will be launched on, so
349
+ # we force synchronize here.
350
+ cuda.synchronize()
351
+
352
+ @property
353
+ def regs_per_thread(self):
354
+ """
355
+ The number of registers used by each thread for this kernel.
356
+ """
357
+ return self._codelibrary.get_cufunc().attrs.regs
358
+
359
+ @property
360
+ def const_mem_size(self):
361
+ """
362
+ The amount of constant memory used by this kernel.
363
+ """
364
+ return self._codelibrary.get_cufunc().attrs.const
365
+
366
+ @property
367
+ def shared_mem_per_block(self):
368
+ """
369
+ The amount of shared memory used per block for this kernel.
370
+ """
371
+ return self._codelibrary.get_cufunc().attrs.shared
372
+
373
+ @property
374
+ def max_threads_per_block(self):
375
+ """
376
+ The maximum allowable threads per block.
377
+ """
378
+ return self._codelibrary.get_cufunc().attrs.maxthreads
379
+
380
+ @property
381
+ def local_mem_per_thread(self):
382
+ """
383
+ The amount of local memory used per thread for this kernel.
384
+ """
385
+ return self._codelibrary.get_cufunc().attrs.local
386
+
387
+ def inspect_llvm(self):
388
+ """
389
+ Returns the LLVM IR for this kernel.
390
+ """
391
+ return self._codelibrary.get_llvm_str()
392
+
393
+ def inspect_asm(self, cc):
394
+ """
395
+ Returns the PTX code for this kernel.
396
+ """
397
+ return self._codelibrary.get_asm_str(cc=cc)
398
+
399
+ def inspect_lto_ptx(self, cc):
400
+ """
401
+ Returns the PTX code for the external functions linked to this kernel.
402
+ """
403
+ return self._codelibrary.get_lto_ptx(cc=cc)
404
+
405
+ def inspect_sass_cfg(self):
406
+ """
407
+ Returns the CFG of the SASS for this kernel.
408
+
409
+ Requires nvdisasm to be available on the PATH.
410
+ """
411
+ return self._codelibrary.get_sass_cfg()
412
+
413
+ def inspect_sass(self):
414
+ """
415
+ Returns the SASS code for this kernel.
416
+
417
+ Requires nvdisasm to be available on the PATH.
418
+ """
419
+ return self._codelibrary.get_sass()
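These inspect_* hooks are usually reached through the dispatcher rather than by holding a _Kernel directly; in numba.cuda the dispatcher-level calls return per-signature results. A hedged sketch (SASS output additionally requires nvdisasm on the PATH):

    from numba import cuda
    import numpy as np

    @cuda.jit
    def axpy(r, a, x, y):
        i = cuda.grid(1)
        if i < r.size:
            r[i] = a * x[i] + y[i]

    n = 64
    r = cuda.device_array(n, dtype=np.float32)
    x = cuda.to_device(np.ones(n, dtype=np.float32))
    y = cuda.to_device(np.ones(n, dtype=np.float32))
    axpy[1, n](r, 2.0, x, y)

    ptx_per_sig = axpy.inspect_asm()     # {signature: PTX source}
    llvm_per_sig = axpy.inspect_llvm()   # {signature: LLVM IR}
    # sass = axpy.inspect_sass()         # needs nvdisasm on the PATH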
420
+
421
+ def inspect_types(self, file=None):
422
+ """
423
+ Produce a dump of the Python source of this function annotated with the
424
+ corresponding Numba IR and type information. The dump is written to
425
+ *file*, or *sys.stdout* if *file* is *None*.
426
+ """
427
+ if self._type_annotation is None:
428
+ raise ValueError("Type annotation is not available")
429
+
430
+ if file is None:
431
+ file = sys.stdout
432
+
433
+ print("%s %s" % (self.entry_name, self.argument_types), file=file)
434
+ print("-" * 80, file=file)
435
+ print(self._type_annotation, file=file)
436
+ print("=" * 80, file=file)
437
+
438
+ def max_cooperative_grid_blocks(self, blockdim, dynsmemsize=0):
439
+ """
440
+ Calculates the maximum number of blocks that can be launched for this
441
+ kernel in a cooperative grid in the current context, for the given block
442
+ and dynamic shared memory sizes.
443
+
444
+ :param blockdim: Block dimensions, either as a scalar for a 1D block, or
445
+ a tuple for 2D or 3D blocks.
446
+ :param dynsmemsize: Dynamic shared memory size in bytes.
447
+ :return: The maximum number of blocks in the grid.
448
+ """
449
+ ctx = get_context()
450
+ cufunc = self._codelibrary.get_cufunc()
451
+
452
+ if isinstance(blockdim, tuple):
453
+ blockdim = functools.reduce(lambda x, y: x * y, blockdim)
454
+ active_per_sm = ctx.get_active_blocks_per_multiprocessor(
455
+ cufunc, blockdim, dynsmemsize
456
+ )
457
+ sm_count = ctx.device.MULTIPROCESSOR_COUNT
458
+ return active_per_sm * sm_count
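max_cooperative_grid_blocks() bounds the grid size for kernels that synchronize across the whole grid. A hedged sketch of a cooperative kernel, assuming cuda.cg.this_grid() and a device/toolkit combination that supports cooperative launches:

    from numba import cuda
    import numpy as np

    @cuda.jit
    def two_phase(data):
        grid = cuda.cg.this_grid()
        i = cuda.grid(1)
        if i < data.size:
            data[i] += 1.0
        grid.sync()              # grid-wide barrier: forces a cooperative launch
        if i < data.size:
            data[i] *= 2.0

    blockdim = 128
    data = cuda.to_device(np.zeros(1024, dtype=np.float32))
    two_phase[8, blockdim](data)

    # The compiled specialization can report how many blocks may be launched.
    kernel = next(iter(two_phase.overloads.values()))
    print(kernel.max_cooperative_grid_blocks(blockdim))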
459
+
460
+ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0):
461
+ # Prepare kernel
462
+ cufunc = self._codelibrary.get_cufunc()
463
+
464
+ if self.debug:
465
+ excname = cufunc.name + "__errcode__"
466
+ excmem, excsz = cufunc.module.get_global_symbol(excname)
467
+ assert excsz == ctypes.sizeof(ctypes.c_int)
468
+ excval = ctypes.c_int()
469
+ excmem.memset(0, stream=stream)
470
+
471
+ # Prepare arguments
472
+ retr = [] # hold functors for writeback
473
+
474
+ kernelargs = []
475
+ for t, v in zip(self.argument_types, args):
476
+ self._prepare_args(t, v, stream, retr, kernelargs)
477
+
478
+ stream_handle = driver._stream_handle(stream)
479
+
480
+ # Invoke kernel
481
+ driver.launch_kernel(
482
+ cufunc.handle,
483
+ *griddim,
484
+ *blockdim,
485
+ sharedmem,
486
+ stream_handle,
487
+ kernelargs,
488
+ cooperative=self.cooperative,
489
+ )
490
+
491
+ if self.debug:
492
+ driver.device_to_host(ctypes.addressof(excval), excmem, excsz)
493
+ if excval.value != 0:
494
+ # An error occurred
495
+ def load_symbol(name):
496
+ mem, sz = cufunc.module.get_global_symbol(
497
+ "%s__%s__" % (cufunc.name, name)
498
+ )
499
+ val = ctypes.c_int()
500
+ driver.device_to_host(ctypes.addressof(val), mem, sz)
501
+ return val.value
502
+
503
+ tid = [load_symbol("tid" + i) for i in "zyx"]
504
+ ctaid = [load_symbol("ctaid" + i) for i in "zyx"]
505
+ code = excval.value
506
+ exccls, exc_args, loc = self.call_helper.get_exception(code)
507
+ # Prefix the exception message with the source location
508
+ if loc is None:
509
+ locinfo = ""
510
+ else:
511
+ sym, filepath, lineno = loc
512
+ filepath = os.path.abspath(filepath)
513
+ locinfo = "In function %r, file %s, line %s, " % (
514
+ sym,
515
+ filepath,
516
+ lineno,
517
+ )
518
+ # Prefix the exception message with the thread position
519
+ prefix = "%stid=%s ctaid=%s" % (locinfo, tid, ctaid)
520
+ if exc_args:
521
+ exc_args = ("%s: %s" % (prefix, exc_args[0]),) + exc_args[
522
+ 1:
523
+ ]
524
+ else:
525
+ exc_args = (prefix,)
526
+ raise exccls(*exc_args)
527
+
528
+ # retrieve auto converted arrays
529
+ for wb in retr:
530
+ wb()
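When a kernel is compiled with debug=True, the launch path above reads back an error-code global after the kernel finishes and re-raises the recorded exception on the host, prefixing the message with the failing thread and block indices. A hedged sketch of what that looks like from user code:

    from numba import cuda
    import numpy as np

    @cuda.jit(debug=True, opt=False)
    def checked(arr):
        i = cuda.grid(1)
        if i >= arr.size:
            raise IndexError("thread index out of range")
        arr[i] = i

    arr = np.zeros(8, dtype=np.float32)
    try:
        checked[1, 16](arr)   # more threads than elements: some threads raise
    except IndexError as e:
        # The message is prefixed with "tid=[...] ctaid=[...]" by the launcher.
        print(e)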
531
+
532
+ def _prepare_args(self, ty, val, stream, retr, kernelargs):
533
+ """
534
+ Convert arguments to ctypes and append to kernelargs
535
+ """
536
+
537
+ # map the arguments using any extension you've registered
538
+ for extension in reversed(self.extensions):
539
+ ty, val = extension.prepare_args(ty, val, stream=stream, retr=retr)
540
+
541
+ if isinstance(ty, types.Array):
542
+ devary = wrap_arg(val).to_device(retr, stream)
543
+ c_intp = ctypes.c_ssize_t
544
+
545
+ meminfo = ctypes.c_void_p(0)
546
+ parent = ctypes.c_void_p(0)
547
+ nitems = c_intp(devary.size)
548
+ itemsize = c_intp(devary.dtype.itemsize)
549
+
550
+ ptr = driver.device_pointer(devary)
551
+
552
+ ptr = int(ptr)
553
+
554
+ data = ctypes.c_void_p(ptr)
555
+
556
+ kernelargs.append(meminfo)
557
+ kernelargs.append(parent)
558
+ kernelargs.append(nitems)
559
+ kernelargs.append(itemsize)
560
+ kernelargs.append(data)
561
+ kernelargs.extend(map(c_intp, devary.shape))
562
+ kernelargs.extend(map(c_intp, devary.strides))
563
+
564
+ elif isinstance(ty, types.CPointer):
565
+ # Pointer arguments should be a pointer-sized integer
566
+ kernelargs.append(ctypes.c_uint64(val))
567
+
568
+ elif isinstance(ty, types.Integer):
569
+ cval = getattr(ctypes, "c_%s" % ty)(val)
570
+ kernelargs.append(cval)
571
+
572
+ elif ty == types.float16:
573
+ cval = ctypes.c_uint16(np.float16(val).view(np.uint16))
574
+ kernelargs.append(cval)
575
+
576
+ elif ty == types.float64:
577
+ cval = ctypes.c_double(val)
578
+ kernelargs.append(cval)
579
+
580
+ elif ty == types.float32:
581
+ cval = ctypes.c_float(val)
582
+ kernelargs.append(cval)
583
+
584
+ elif ty == types.boolean:
585
+ cval = ctypes.c_uint8(int(val))
586
+ kernelargs.append(cval)
587
+
588
+ elif ty == types.complex64:
589
+ kernelargs.append(ctypes.c_float(val.real))
590
+ kernelargs.append(ctypes.c_float(val.imag))
591
+
592
+ elif ty == types.complex128:
593
+ kernelargs.append(ctypes.c_double(val.real))
594
+ kernelargs.append(ctypes.c_double(val.imag))
595
+
596
+ elif isinstance(ty, (types.NPDatetime, types.NPTimedelta)):
597
+ kernelargs.append(ctypes.c_int64(val.view(np.int64)))
598
+
599
+ elif isinstance(ty, types.Record):
600
+ devrec = wrap_arg(val).to_device(retr, stream)
601
+ ptr = devrec.device_ctypes_pointer
602
+ kernelargs.append(ptr)
603
+
604
+ elif isinstance(ty, types.BaseTuple):
605
+ assert len(ty) == len(val)
606
+ for t, v in zip(ty, val):
607
+ self._prepare_args(t, v, stream, retr, kernelargs)
608
+
609
+ elif isinstance(ty, types.EnumMember):
610
+ try:
611
+ self._prepare_args(
612
+ ty.dtype, val.value, stream, retr, kernelargs
613
+ )
614
+ except NotImplementedError:
615
+ raise NotImplementedError(ty, val)
616
+
617
+ else:
618
+ raise NotImplementedError(ty, val)
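The branches above define exactly how each scalar argument is flattened into ctypes values for the driver. A small illustration of the float16 and complex cases, mirroring (not importing) the logic above:

    import ctypes
    import numpy as np

    # float16 scalars are passed as their raw 16-bit pattern
    h = np.float16(1.5)
    arg_fp16 = ctypes.c_uint16(np.float16(h).view(np.uint16))

    # complex scalars are passed as two separate floating-point components
    z = complex(1.0, -2.0)
    arg_re = ctypes.c_double(z.real)
    arg_im = ctypes.c_double(z.imag)

    print(hex(arg_fp16.value), arg_re.value, arg_im.value)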
619
+
620
+
621
+ class ForAll(object):
622
+ def __init__(self, dispatcher, ntasks, tpb, stream, sharedmem):
623
+ if ntasks < 0:
624
+ raise ValueError(
625
+ "Can't create ForAll with negative task count: %s" % ntasks
626
+ )
627
+ self.dispatcher = dispatcher
628
+ self.ntasks = ntasks
629
+ self.thread_per_block = tpb
630
+ self.stream = stream
631
+ self.sharedmem = sharedmem
632
+
633
+ def __call__(self, *args):
634
+ if self.ntasks == 0:
635
+ return
636
+
637
+ if self.dispatcher.specialized:
638
+ specialized = self.dispatcher
639
+ else:
640
+ specialized = self.dispatcher.specialize(*args)
641
+ blockdim = self._compute_thread_per_block(specialized)
642
+ griddim = (self.ntasks + blockdim - 1) // blockdim
643
+
644
+ return specialized[griddim, blockdim, self.stream, self.sharedmem](
645
+ *args
646
+ )
647
+
648
+ def _compute_thread_per_block(self, dispatcher):
649
+ tpb = self.thread_per_block
650
+ # Prefer user-specified config
651
+ if tpb != 0:
652
+ return tpb
653
+ # Else, ask the driver to give a good config
654
+ else:
655
+ ctx = get_context()
656
+ # Dispatcher is specialized, so there's only one definition - get
657
+ # it so we can get the cufunc from the code library
658
+ kernel = next(iter(dispatcher.overloads.values()))
659
+ kwargs = dict(
660
+ func=kernel._codelibrary.get_cufunc(),
661
+ b2d_func=0, # dynamic-shared memory is constant to blksz
662
+ memsize=self.sharedmem,
663
+ blocksizelimit=1024,
664
+ )
665
+ _, tpb = ctx.get_max_potential_block_size(**kwargs)
666
+ return tpb
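ForAll objects are normally obtained from a dispatcher's forall() method; when no threads-per-block value is given, the occupancy query above picks one. A hedged usage sketch:

    from numba import cuda
    import numpy as np

    @cuda.jit
    def increment(arr):
        i = cuda.grid(1)
        if i < arr.size:
            arr[i] += 1

    arr = cuda.to_device(np.zeros(1_000_000, dtype=np.float32))

    # One task per element; the block size is chosen by the driver heuristic.
    increment.forall(arr.size)(arr)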
667
+
668
+
669
+ class _LaunchConfiguration:
670
+ def __init__(self, dispatcher, griddim, blockdim, stream, sharedmem):
671
+ self.dispatcher = dispatcher
672
+ self.griddim = griddim
673
+ self.blockdim = blockdim
674
+ self.stream = stream
675
+ self.sharedmem = sharedmem
676
+
677
+ if (
678
+ config.CUDA_LOW_OCCUPANCY_WARNINGS
679
+ and not config.DISABLE_PERFORMANCE_WARNINGS
680
+ ):
681
+ # Warn when the grid has fewer than 128 blocks. This number is
682
+ # chosen somewhat heuristically - ideally the minimum is 2 times
683
+ # the number of SMs, but the number of SMs varies between devices -
684
+ # some very small GPUs might only have 4 SMs, but an H100-SXM5 has
685
+ # 132. In general kernels should be launched with large grids
686
+ # (hundreds or thousands of blocks), so warning when fewer than 128
687
+ # blocks are used will likely catch most beginner errors, where the
688
+ # grid tends to be very small (single-digit or low tens of blocks).
689
+ min_grid_size = 128
690
+ grid_size = griddim[0] * griddim[1] * griddim[2]
691
+ if grid_size < min_grid_size:
692
+ msg = (
693
+ f"Grid size {grid_size} will likely result in GPU "
694
+ "under-utilization due to low occupancy."
695
+ )
696
+ warn(errors.NumbaPerformanceWarning(msg))
697
+
698
+ def __call__(self, *args):
699
+ return self.dispatcher.call(
700
+ args, self.griddim, self.blockdim, self.stream, self.sharedmem
701
+ )
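Indexing a dispatcher with [griddim, blockdim, stream, sharedmem] produces one of these configuration objects, and calling it launches the kernel; grids smaller than the 128-block threshold described above emit the low-occupancy warning. A short sketch:

    from numba import cuda
    import numpy as np

    @cuda.jit
    def fill(arr, value):
        i = cuda.grid(1)
        if i < arr.size:
            arr[i] = value

    arr = cuda.device_array(1 << 20, dtype=np.float32)

    threads = 256
    blocks = (arr.size + threads - 1) // threads   # 4096 blocks: no warning
    fill[blocks, threads](arr, 3.0)

    # A single-block grid is below the 128-block threshold and triggers a
    # NumbaPerformanceWarning about likely GPU under-utilization.
    fill[1, threads](arr, 3.0)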
702
+
703
+
704
+ class CUDACacheImpl(CacheImpl):
705
+ def reduce(self, kernel):
706
+ return kernel._reduce_states()
707
+
708
+ def rebuild(self, target_context, payload):
709
+ return _Kernel._rebuild(**payload)
710
+
711
+ def check_cachable(self, cres):
712
+ # CUDA Kernels are always cachable - the reasons for an entity not to
713
+ # be cachable are:
714
+ #
715
+ # - The presence of lifted loops, or
716
+ # - The presence of dynamic globals.
717
+ #
718
+ # neither of which apply to CUDA kernels.
719
+ return True
720
+
721
+
722
+ class CUDACache(Cache):
723
+ """
724
+ Implements a cache that saves and loads CUDA kernels and compile results.
725
+ """
726
+
727
+ _impl_class = CUDACacheImpl
728
+
729
+ def load_overload(self, sig, target_context):
730
+ # Loading an overload refreshes the context to ensure it is initialized.
731
+ with utils.numba_target_override():
732
+ return super().load_overload(sig, target_context)
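CUDACache comes into play when a kernel is declared with cache=True: compiled kernels are reduced to their PTX form via _reduce_states() and written to an on-disk cache so later processes can load them instead of recompiling. A hedged sketch:

    from numba import cuda
    import numpy as np

    @cuda.jit(cache=True)
    def double(arr):
        i = cuda.grid(1)
        if i < arr.size:
            arr[i] *= 2

    arr = np.arange(32, dtype=np.float32)
    double[1, 32](arr)   # first run compiles and populates the file cache
    # A new interpreter session calling the same signature loads the cached
    # kernel rather than recompiling it.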
733
+
734
+
735
+ class OmittedArg(object):
736
+ """
737
+ A placeholder for omitted arguments with a default value.
738
+ """
739
+
740
+ def __init__(self, value):
741
+ self.value = value
742
+
743
+ def __repr__(self):
744
+ return "omitted arg(%r)" % (self.value,)
745
+
746
+ @property
747
+ def _numba_type_(self):
748
+ return types.Omitted(self.value)
749
+
750
+
751
+ class CompilingCounter(object):
752
+ """
753
+ A simple counter that increments in __enter__ and decrements in __exit__.
754
+ """
755
+
756
+ def __init__(self):
757
+ self.counter = 0
758
+
759
+ def __enter__(self):
760
+ assert self.counter >= 0
761
+ self.counter += 1
762
+
763
+ def __exit__(self, *args, **kwargs):
764
+ self.counter -= 1
765
+ assert self.counter >= 0
766
+
767
+ def __bool__(self):
768
+ return self.counter > 0
769
+
770
+ __nonzero__ = __bool__
771
+
772
+
773
+ class _DispatcherBase(_dispatcher.Dispatcher):
774
+ """
775
+ Common base class for dispatcher implementations.
776
+ """
777
+
778
+ __numba__ = "py_func"
779
+
780
+ def __init__(
781
+ self, arg_count, py_func, pysig, can_fallback, exact_match_required
782
+ ):
783
+ self._tm = default_type_manager
784
+
785
+ # A mapping of signatures to compile results
786
+ self.overloads = collections.OrderedDict()
787
+
788
+ self.py_func = py_func
789
+ # other parts of Numba assume the old Python 2 name for code object
790
+ self.func_code = get_code_object(py_func)
791
+ # but newer python uses a different name
792
+ self.__code__ = self.func_code
793
+ # a place to keep an active reference to the types of the active call
794
+ self._types_active_call = set()
795
+ # Default argument values match the py_func
796
+ self.__defaults__ = py_func.__defaults__
797
+
798
+ argnames = tuple(pysig.parameters)
799
+ default_values = self.py_func.__defaults__ or ()
800
+ defargs = tuple(OmittedArg(val) for val in default_values)
801
+ try:
802
+ lastarg = list(pysig.parameters.values())[-1]
803
+ except IndexError:
804
+ has_stararg = False
805
+ else:
806
+ has_stararg = lastarg.kind == lastarg.VAR_POSITIONAL
807
+ _dispatcher.Dispatcher.__init__(
808
+ self,
809
+ self._tm.get_pointer(),
810
+ arg_count,
811
+ self._fold_args,
812
+ argnames,
813
+ defargs,
814
+ can_fallback,
815
+ has_stararg,
816
+ exact_match_required,
817
+ )
818
+
819
+ self.doc = py_func.__doc__
820
+ self._compiling_counter = CompilingCounter()
821
+ weakref.finalize(self, self._make_finalizer())
822
+
823
+ def _compilation_chain_init_hook(self):
824
+ """
825
+ This will be called ahead of any part of compilation taking place (this
826
+ even includes being ahead of working out the types of the arguments).
827
+ This permits activities such as initialising extension entry points so
828
+ that the compiler knows about additional externally defined types etc
829
+ before it does anything.
830
+ """
831
+ entrypoints.init_all()
832
+
833
+ def _reset_overloads(self):
834
+ self._clear()
835
+ self.overloads.clear()
836
+
837
+ def _make_finalizer(self):
838
+ """
839
+ Return a finalizer function that will release references to
840
+ related compiled functions.
841
+ """
842
+ overloads = self.overloads
843
+ targetctx = self.targetctx
844
+
845
+ # Early-bind utils.shutting_down() into the function's local namespace
846
+ # (see issue #689)
847
+ def finalizer(shutting_down=utils.shutting_down):
848
+ # The finalizer may crash at shutdown, skip it (resources
849
+ # will be cleared by the process exiting, anyway).
850
+ if shutting_down():
851
+ return
852
+ # This function must *not* hold any reference to self:
853
+ # we take care to bind the necessary objects in the closure.
854
+ for cres in overloads.values():
855
+ try:
856
+ targetctx.remove_user_function(cres.entry_point)
857
+ except KeyError:
858
+ pass
859
+
860
+ return finalizer
861
+
862
+ @property
863
+ def signatures(self):
864
+ """
865
+ Returns a list of compiled function signatures.
866
+ """
867
+ return list(self.overloads)
868
+
869
+ @property
870
+ def nopython_signatures(self):
871
+ return [
872
+ cres.signature
873
+ for cres in self.overloads.values()
874
+ if not cres.objectmode
875
+ ]
876
+
877
+ def disable_compile(self, val=True):
878
+ """Disable the compilation of new signatures at call time."""
879
+ # If disabling compilation then there must be at least one signature
880
+ assert (not val) or len(self.signatures) > 0
881
+ self._can_compile = not val
882
+
883
+ def add_overload(self, cres):
884
+ args = tuple(cres.signature.args)
885
+ sig = [a._code for a in args]
886
+ self._insert(sig, cres.entry_point, cres.objectmode)
887
+ self.overloads[args] = cres
888
+
889
+ def fold_argument_types(self, args, kws):
890
+ return self._compiler.fold_argument_types(args, kws)
891
+
892
+ def get_call_template(self, args, kws):
893
+ """
894
+ Get a typing.ConcreteTemplate for this dispatcher and the given
895
+ *args* and *kws* types. This allows the return type to be resolved.
896
+
897
+ A (template, pysig, args, kws) tuple is returned.
898
+ """
899
+ # XXX how about a dispatcher template class automating the
900
+ # following?
901
+
902
+ # Fold keyword arguments and resolve default values
903
+ pysig, args = self._compiler.fold_argument_types(args, kws)
904
+ kws = {}
905
+ # Ensure an overload is available
906
+ if self._can_compile:
907
+ self.compile(tuple(args))
908
+
909
+ # Create function type for typing
910
+ func_name = self.py_func.__name__
911
+ name = "CallTemplate({0})".format(func_name)
912
+ # The `key` isn't really used except for diagnosis here,
913
+ # so avoid keeping a reference to `cfunc`.
914
+ call_template = typing.make_concrete_template(
915
+ name, key=func_name, signatures=self.nopython_signatures
916
+ )
917
+ return call_template, pysig, args, kws
918
+
919
+ def get_overload(self, sig):
920
+ """
921
+ Return the compiled function for the given signature.
922
+ """
923
+ args, return_type = sigutils.normalize_signature(sig)
924
+ return self.overloads[tuple(args)].entry_point
925
+
926
+ @property
927
+ def is_compiling(self):
928
+ """
929
+ Whether a specialization is currently being compiled.
930
+ """
931
+ return self._compiling_counter
932
+
933
+ def _compile_for_args(self, *args, **kws):
934
+ """
935
+ For internal use. Compile a specialized version of the function
936
+ for the given *args* and *kws*, and return the resulting callable.
937
+ """
938
+ assert not kws
939
+ # call any initialisation required for the compilation chain (e.g.
940
+ # extension point registration).
941
+ self._compilation_chain_init_hook()
942
+
943
+ def error_rewrite(e, issue_type):
944
+ """
945
+ Rewrite and raise Exception `e` with help supplied based on the
946
+ specified issue_type.
947
+ """
948
+ if config.SHOW_HELP:
949
+ help_msg = errors.error_extras[issue_type]
950
+ e.patch_message("\n".join((str(e).rstrip(), help_msg)))
951
+ if config.FULL_TRACEBACKS:
952
+ raise e
953
+ else:
954
+ raise e.with_traceback(None)
955
+
956
+ argtypes = []
957
+ for a in args:
958
+ if isinstance(a, OmittedArg):
959
+ argtypes.append(types.Omitted(a.value))
960
+ else:
961
+ argtypes.append(self.typeof_pyval(a))
962
+
963
+ return_val = None
964
+ try:
965
+ return_val = self.compile(tuple(argtypes))
966
+ except errors.ForceLiteralArg as e:
967
+ # Received request for compiler re-entry with the list of arguments
968
+ # indicated by e.requested_args.
969
+ # First, check if any of these args are already Literal-ized
970
+ already_lit_pos = [
971
+ i
972
+ for i in e.requested_args
973
+ if isinstance(args[i], types.Literal)
974
+ ]
975
+ if already_lit_pos:
976
+ # Abort compilation if any argument is already a Literal.
977
+ # Letting this continue will cause an infinite compilation loop.
978
+ m = (
979
+ "Repeated literal typing request.\n"
980
+ "{}.\n"
981
+ "This is likely caused by an error in typing. "
982
+ "Please see nested and suppressed exceptions."
983
+ )
984
+ info = ", ".join(
985
+ "Arg #{} is {}".format(i, args[i])
986
+ for i in sorted(already_lit_pos)
987
+ )
988
+ raise errors.CompilerError(m.format(info))
989
+ # Convert requested arguments into a Literal.
990
+ args = [
991
+ (types.literal if i in e.requested_args else lambda x: x)(
992
+ args[i]
993
+ )
994
+ for i, v in enumerate(args)
995
+ ]
996
+ # Re-enter compilation with the Literal-ized arguments
997
+ return_val = self._compile_for_args(*args)
998
+
999
+ except errors.TypingError as e:
1000
+ # Intercept typing error that may be due to an argument
1001
+ # that could not be inferred as a Numba type
1002
+ failed_args = []
1003
+ for i, arg in enumerate(args):
1004
+ val = arg.value if isinstance(arg, OmittedArg) else arg
1005
+ try:
1006
+ tp = typeof(val, Purpose.argument)
1007
+ except (errors.NumbaValueError, ValueError) as typeof_exc:
1008
+ failed_args.append((i, str(typeof_exc)))
1009
+ else:
1010
+ if tp is None:
1011
+ failed_args.append(
1012
+ (i, f"cannot determine Numba type of value {val}")
1013
+ )
1014
+ if failed_args:
1015
+ # Patch error message to ease debugging
1016
+ args_str = "\n".join(
1017
+ f"- argument {i}: {err}" for i, err in failed_args
1018
+ )
1019
+ msg = (
1020
+ f"{str(e).rstrip()} \n\nThis error may have been caused "
1021
+ f"by the following argument(s):\n{args_str}\n"
1022
+ )
1023
+ e.patch_message(msg)
1024
+
1025
+ error_rewrite(e, "typing")
1026
+ except errors.UnsupportedError as e:
1027
+ # Something unsupported is present in the user code, add help info
1028
+ error_rewrite(e, "unsupported_error")
1029
+ except (
1030
+ errors.NotDefinedError,
1031
+ errors.RedefinedError,
1032
+ errors.VerificationError,
1033
+ ) as e:
1034
+ # These errors are probably from an issue with either the code
1035
+ # supplied being syntactically or otherwise invalid
1036
+ error_rewrite(e, "interpreter")
1037
+ except errors.ConstantInferenceError as e:
1038
+ # this is from trying to infer something as constant when it isn't
1039
+ # or isn't supported as a constant
1040
+ error_rewrite(e, "constant_inference")
1041
+ except Exception as e:
1042
+ if config.SHOW_HELP:
1043
+ if hasattr(e, "patch_message"):
1044
+ help_msg = errors.error_extras["reportable"]
1045
+ e.patch_message("\n".join((str(e).rstrip(), help_msg)))
1046
+ # ignore the FULL_TRACEBACKS config, this needs reporting!
1047
+ raise e
1048
+ finally:
1049
+ self._types_active_call.clear()
1050
+ return return_val
1051
+
1052
+ def inspect_llvm(self, signature=None):
1053
+ """Get the LLVM intermediate representation generated by compilation.
1054
+
1055
+ Parameters
1056
+ ----------
1057
+ signature : tuple of numba types, optional
1058
+ Specify a signature for which to obtain the LLVM IR. If None, the
1059
+ IR is returned for all available signatures.
1060
+
1061
+ Returns
1062
+ -------
1063
+ llvm : dict[signature, str] or str
1064
+ Either the LLVM IR string for the specified signature, or, if no
1065
+ signature was given, a dictionary mapping signatures to LLVM IR
1066
+ strings.
1067
+ """
1068
+ if signature is not None:
1069
+ lib = self.overloads[signature].library
1070
+ return lib.get_llvm_str()
1071
+
1072
+ return dict((sig, self.inspect_llvm(sig)) for sig in self.signatures)
1073
+
1074
+ def inspect_asm(self, signature=None):
1075
+ """Get the generated assembly code.
1076
+
1077
+ Parameters
1078
+ ----------
1079
+ signature : tuple of numba types, optional
1080
+ Specify a signature for which to obtain the assembly code. If
1081
+ None, the assembly code is returned for all available signatures.
1082
+
1083
+ Returns
1084
+ -------
1085
+ asm : dict[signature, str] or str
1086
+ Either the assembly code for the specified signature, or, if no
1087
+ signature was given, a dictionary mapping signatures to assembly
1088
+ code.
1089
+ """
1090
+ if signature is not None:
1091
+ lib = self.overloads[signature].library
1092
+ return lib.get_asm_str()
1093
+
1094
+ return dict((sig, self.inspect_asm(sig)) for sig in self.signatures)
1095
+
1096
+ def inspect_types(
1097
+ self, file=None, signature=None, pretty=False, style="default", **kwargs
1098
+ ):
1099
+ """Print/return Numba intermediate representation (IR)-annotated code.
1100
+
1101
+ Parameters
1102
+ ----------
1103
+ file : file-like object, optional
1104
+ File to which to print. Defaults to sys.stdout if None. Must be
1105
+ None if ``pretty=True``.
1106
+ signature : tuple of numba types, optional
1107
+ Print/return the intermediate representation for only the given
1108
+ signature. If None, the IR is printed for all available signatures.
1109
+ pretty : bool, optional
1110
+ If True, an Annotate object will be returned that can render the
1111
+ IR with color highlighting in Jupyter and IPython. ``file`` must
1112
+ be None if ``pretty`` is True. Additionally, the ``pygments``
1113
+ library must be installed for ``pretty=True``.
1114
+ style : str, optional
1115
+ Choose a style for rendering. Ignored if ``pretty`` is ``False``.
1116
+ This is directly consumed by ``pygments`` formatters. To see a
1117
+ list of available styles, import ``pygments`` and run
1118
+ ``list(pygments.styles.get_all_styles())``.
1119
+
1120
+ Returns
1121
+ -------
1122
+ annotated : Annotate object, optional
1123
+ Only returned if ``pretty=True``, otherwise this function is only
1124
+ used for its printing side effect. If ``pretty=True``, an Annotate
1125
+ object is returned that can render itself in Jupyter and IPython.
1126
+ """
1127
+ overloads = self.overloads
1128
+ if signature is not None:
1129
+ overloads = {signature: self.overloads[signature]}
1130
+
1131
+ if not pretty:
1132
+ if file is None:
1133
+ file = sys.stdout
1134
+
1135
+ for ver, res in overloads.items():
1136
+ print("%s %s" % (self.py_func.__name__, ver), file=file)
1137
+ print("-" * 80, file=file)
1138
+ print(res.type_annotation, file=file)
1139
+ print("=" * 80, file=file)
1140
+ else:
1141
+ if file is not None:
1142
+ raise ValueError("`file` must be None if `pretty=True`")
1143
+ from numba.cuda.core.annotations.pretty_annotate import Annotate
1144
+
1145
+ return Annotate(self, signature=signature, style=style)
1146
+
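# --- Editor's note: illustrative usage sketch, not part of the package source.
# `fn` stands for a hypothetical dispatcher that already has at least one
# compiled signature; the pretty form additionally assumes pygments is
# installed.
#
#     fn.inspect_types()                   # annotated dump to sys.stdout
#     ann = fn.inspect_types(pretty=True)  # Annotate object; renders in Jupyter
# ---------------------------------------------------------------------------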
1147
+ def inspect_cfg(self, signature=None, show_wrapper=None, **kwargs):
1148
+ """
1149
+ For inspecting the CFG of the function.
1150
+
1151
+ By default the CFG of the user function is shown. The *show_wrapper*
1152
+ option can be set to "python" or "cfunc" to show the python wrapper
1153
+ function or the *cfunc* wrapper function, respectively.
1154
+
1155
+ Parameters accepted in kwargs
1156
+ -----------------------------
1157
+ filename : string, optional
1158
+ the name of the output file, if given this will write the output to
1159
+ filename
1160
+ view : bool, optional
1161
+ whether to immediately view the optional output file
1162
+ highlight : bool, set, dict, optional
1163
+ what, if anything, to highlight, options are:
1164
+ { incref : bool, # highlight NRT_incref calls
1165
+ decref : bool, # highlight NRT_decref calls
1166
+ returns : bool, # highlight exits which are normal returns
1167
+ raises : bool, # highlight exits which are from raise
1168
+ meminfo : bool, # highlight calls to NRT*meminfo
1169
+ branches : bool, # highlight true/false branches
1170
+ }
1171
+ Default is True which sets all of the above to True. Supplying a set
1172
+ of strings is also accepted; these are interpreted as key:True with
1173
+ respect to the above dictionary. e.g. {'incref', 'decref'} would
1174
+ switch on highlighting of increfs and decrefs.
1175
+ interleave: bool, set, dict, optional
1176
+ what, if anything, to interleave in the LLVM IR, options are:
1177
+ { python: bool # interleave python source code with the LLVM IR
1178
+ lineinfo: bool # interleave line information markers with the LLVM
1179
+ # IR
1180
+ }
1181
+ Default is True which sets all of the above to True. Supplying a set
1182
+ of strings is also accepted; these are interpreted as key:True with
1183
+ respect to the above dictionary. e.g. {'python',} would
1184
+ switch on interleaving of python source code in the LLVM IR.
1185
+ strip_ir : bool, optional
1186
+ Default is False. If set to True all LLVM IR that is superfluous to
1187
+ that requested in kwarg `highlight` will be removed.
1188
+ show_key : bool, optional
1189
+ Default is True. Create a "key" for the highlighting in the rendered
1190
+ CFG.
1191
+ fontsize : int, optional
1192
+ Default is 8. Set the fontsize in the output to this value.
1193
+ """
1194
+ if signature is not None:
1195
+ cres = self.overloads[signature]
1196
+ lib = cres.library
1197
+ if show_wrapper == "python":
1198
+ fname = cres.fndesc.llvm_cpython_wrapper_name
1199
+ elif show_wrapper == "cfunc":
1200
+ fname = cres.fndesc.llvm_cfunc_wrapper_name
1201
+ else:
1202
+ fname = cres.fndesc.mangled_name
1203
+ return lib.get_function_cfg(fname, py_func=self.py_func, **kwargs)
1204
+
1205
+ return dict(
1206
+ (sig, self.inspect_cfg(sig, show_wrapper=show_wrapper))
1207
+ for sig in self.signatures
1208
+ )
1209
+
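# --- Editor's note: illustrative usage sketch, not part of the package source.
# `fn` is a hypothetical compiled dispatcher; rendering the CFG also assumes
# graphviz is available.
#
#     sig = fn.signatures[0]
#     cfg = fn.inspect_cfg(sig, highlight={'incref', 'decref'},
#                          interleave={'python'})
#     cfg  # the returned object renders itself in a Jupyter notebook
# ---------------------------------------------------------------------------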
1210
+ def inspect_disasm_cfg(self, signature=None):
1211
+ """
1212
+ For inspecting the CFG of the disassembly of the function.
1213
+
1214
+ Requires python package: r2pipe
1215
+ Requires radare2 binary on $PATH.
1216
+ Notebook rendering requires python package: graphviz
1217
+
1218
+ signature : tuple of Numba types, optional
1219
+ Print/return the disassembly CFG for only the given signature.
1220
+ If None, the CFG is returned for all available signatures.
1221
+ """
1222
+ if signature is not None:
1223
+ cres = self.overloads[signature]
1224
+ lib = cres.library
1225
+ return lib.get_disasm_cfg(cres.fndesc.mangled_name)
1226
+
1227
+ return dict(
1228
+ (sig, self.inspect_disasm_cfg(sig)) for sig in self.signatures
1229
+ )
1230
+
1231
+ def get_annotation_info(self, signature=None):
1232
+ """
1233
+ Gets the annotation information for the function specified by
1234
+ signature. If no signature is supplied, a dictionary of signature to
1235
+ annotation information is returned.
1236
+ """
1237
+ signatures = self.signatures if signature is None else [signature]
1238
+ out = collections.OrderedDict()
1239
+ for sig in signatures:
1240
+ cres = self.overloads[sig]
1241
+ ta = cres.type_annotation
1242
+ key = (
1243
+ ta.func_id.filename + ":" + str(ta.func_id.firstlineno + 1),
1244
+ ta.signature,
1245
+ )
1246
+ out[key] = ta.annotate_raw()[key]
1247
+ return out
1248
+
1249
+ def _explain_ambiguous(self, *args, **kws):
1250
+ """
1251
+ Callback for the C _Dispatcher object.
1252
+ """
1253
+ assert not kws, "kwargs not handled"
1254
+ args = tuple([self.typeof_pyval(a) for a in args])
1255
+ # The order here must be deterministic for testing purposes, which
1256
+ # is ensured by the OrderedDict.
1257
+ sigs = self.nopython_signatures
1258
+ # This will raise
1259
+ self.typingctx.resolve_overload(
1260
+ self.py_func, sigs, args, kws, allow_ambiguous=False
1261
+ )
1262
+
1263
+ def _explain_matching_error(self, *args, **kws):
1264
+ """
1265
+ Callback for the C _Dispatcher object.
1266
+ """
1267
+ assert not kws, "kwargs not handled"
1268
+ args = [self.typeof_pyval(a) for a in args]
1269
+ msg = "No matching definition for argument type(s) %s" % ", ".join(
1270
+ map(str, args)
1271
+ )
1272
+ raise TypeError(msg)
1273
+
1274
+ def _search_new_conversions(self, *args, **kws):
1275
+ """
1276
+ Callback for the C _Dispatcher object.
1277
+ Search for approximately matching signatures for the given arguments,
1278
+ and ensure the corresponding conversions are registered in the C++
1279
+ type manager.
1280
+ """
1281
+ assert not kws, "kwargs not handled"
1282
+ args = [self.typeof_pyval(a) for a in args]
1283
+ found = False
1284
+ for sig in self.nopython_signatures:
1285
+ conv = self.typingctx.install_possible_conversions(args, sig.args)
1286
+ if conv:
1287
+ found = True
1288
+ return found
1289
+
1290
+ def __repr__(self):
1291
+ return "%s(%s)" % (type(self).__name__, self.py_func)
1292
+
1293
+ def typeof_pyval(self, val):
1294
+ """
1295
+ Resolve the Numba type of Python value *val*.
1296
+ This is called from numba._dispatcher as a fallback if the native code
1297
+ cannot decide the type.
1298
+ """
1299
+ try:
1300
+ tp = typeof(val, Purpose.argument)
1301
+ except (errors.NumbaValueError, ValueError):
1302
+ tp = types.pyobject
1303
+ else:
1304
+ if tp is None:
1305
+ tp = types.pyobject
1306
+ self._types_active_call.add(tp)
1307
+ return tp
1308
+
1309
+ def _callback_add_timer(self, duration, cres, lock_name):
1310
+ md = cres.metadata
1311
+ # md can be None when code is loaded from cache
1312
+ if md is not None:
1313
+ timers = md.setdefault("timers", {})
1314
+ if lock_name not in timers:
1315
+ # Only write if the metadata does not exist
1316
+ timers[lock_name] = duration
1317
+ else:
1318
+ msg = f"'{lock_name} metadata is already defined."
1319
+ raise AssertionError(msg)
1320
+
1321
+ def _callback_add_compiler_timer(self, duration, cres):
1322
+ return self._callback_add_timer(
1323
+ duration, cres, lock_name="compiler_lock"
1324
+ )
1325
+
1326
+ def _callback_add_llvm_timer(self, duration, cres):
1327
+ return self._callback_add_timer(duration, cres, lock_name="llvm_lock")
1328
+
1329
+
1330
+ class _MemoMixin:
1331
+ __uuid = None
1332
+ # A {uuid -> instance} mapping, for deserialization
1333
+ _memo = weakref.WeakValueDictionary()
1334
+ # hold refs to last N functions deserialized, retaining them in _memo
1335
+ # regardless of whether there is another reference
1336
+ _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE)
1337
+
1338
+ @property
1339
+ def _uuid(self):
1340
+ """
1341
+ An instance-specific UUID, to avoid multiple deserializations of
1342
+ a given instance.
1343
+
1344
+ Note: this is lazily-generated, for performance reasons.
1345
+ """
1346
+ u = self.__uuid
1347
+ if u is None:
1348
+ u = str(uuid.uuid4())
1349
+ self._set_uuid(u)
1350
+ return u
1351
+
1352
+ def _set_uuid(self, u):
1353
+ assert self.__uuid is None
1354
+ self.__uuid = u
1355
+ self._memo[u] = self
1356
+ self._recent.append(self)
1357
+
1358
+
1359
+ _CompileStats = collections.namedtuple(
1360
+ "_CompileStats", ("cache_path", "cache_hits", "cache_misses")
1361
+ )
1362
+
1363
+
1364
+ class _FunctionCompiler(object):
1365
+ def __init__(self, py_func, targetdescr, targetoptions, pipeline_class):
1366
+ self.py_func = py_func
1367
+ self.targetdescr = targetdescr
1368
+ self.targetoptions = targetoptions
1369
+ self.locals = {}
1370
+ self.pysig = utils.pysignature(self.py_func)
1371
+ self.pipeline_class = pipeline_class
1372
+ # Remember key=(args, return_type) combinations that will fail
1373
+ # compilation to avoid compilation attempt on them. The values are
1374
+ # the exceptions.
1375
+ self._failed_cache = {}
1376
+
1377
+ def fold_argument_types(self, args, kws):
1378
+ """
1379
+ Given positional and named argument types, fold keyword arguments
1380
+ and resolve defaults by inserting types.Omitted() instances.
1381
+
1382
+ A (pysig, argument types) tuple is returned.
1383
+ """
1384
+
1385
+ def normal_handler(index, param, value):
1386
+ return value
1387
+
1388
+ def default_handler(index, param, default):
1389
+ return types.Omitted(default)
1390
+
1391
+ def stararg_handler(index, param, values):
1392
+ return types.StarArgTuple(values)
1393
+
1394
+ # For now, we take argument values from the @jit function
1395
+ args = fold_arguments(
1396
+ self.pysig,
1397
+ args,
1398
+ kws,
1399
+ normal_handler,
1400
+ default_handler,
1401
+ stararg_handler,
1402
+ )
1403
+ return self.pysig, args
1404
+
1405
+ def compile(self, args, return_type):
1406
+ status, retval = self._compile_cached(args, return_type)
1407
+ if status:
1408
+ return retval
1409
+ else:
1410
+ raise retval
1411
+
1412
+ def _compile_cached(self, args, return_type):
1413
+ key = tuple(args), return_type
1414
+ try:
1415
+ return False, self._failed_cache[key]
1416
+ except KeyError:
1417
+ pass
1418
+
1419
+ try:
1420
+ retval = self._compile_core(args, return_type)
1421
+ except errors.TypingError as e:
1422
+ self._failed_cache[key] = e
1423
+ return False, e
1424
+ else:
1425
+ return True, retval
1426
+
1427
+ def _compile_core(self, args, return_type):
1428
+ flags = Flags()
1429
+ self.targetdescr.options.parse_as_flags(flags, self.targetoptions)
1430
+ flags = self._customize_flags(flags)
1431
+
1432
+ impl = self._get_implementation(args, {})
1433
+ cres = compile_extra(
1434
+ self.targetdescr.typing_context,
1435
+ self.targetdescr.target_context,
1436
+ impl,
1437
+ args=args,
1438
+ return_type=return_type,
1439
+ flags=flags,
1440
+ locals=self.locals,
1441
+ pipeline_class=self.pipeline_class,
1442
+ )
1443
+ # Check typing error if object mode is used
1444
+ if cres.typing_error is not None and not flags.enable_pyobject:
1445
+ raise cres.typing_error
1446
+ return cres
1447
+
1448
+ def get_globals_for_reduction(self):
1449
+ return serialize._get_function_globals_for_reduction(self.py_func)
1450
+
1451
+ def _get_implementation(self, args, kws):
1452
+ return self.py_func
1453
+
1454
+ def _customize_flags(self, flags):
1455
+ return flags
1456
+
1457
+
1458
+ class CUDADispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
1459
+ """
1460
+ CUDA Dispatcher object. When configured and called, the dispatcher will
1461
+ specialize itself for the given arguments (if no suitable specialized
1462
+ version already exists) & compute capability, and launch on the device
1463
+ associated with the current context.
1464
+
1465
+ Dispatcher objects are not to be constructed by the user, but instead are
1466
+ created using the :func:`numba.cuda.jit` decorator.
1467
+ """
1468
+
1469
+ # Whether to fold named arguments and default values. Default values are
1470
+ # presently unsupported on CUDA, so we can leave this as False in all
1471
+ # cases.
1472
+ _fold_args = False
1473
+
1474
+ targetdescr = cuda_target
1475
+
1476
+ def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler):
1477
+ """
1478
+ Parameters
1479
+ ----------
1480
+ py_func: function object to be compiled
1481
+ targetoptions: dict, optional
1482
+ Target-specific config options.
1483
+ pipeline_class: type numba.compiler.CompilerBase
1484
+ The compiler pipeline type.
1485
+ """
1486
+ self.typingctx = self.targetdescr.typing_context
1487
+ self.targetctx = self.targetdescr.target_context
1488
+
1489
+ pysig = utils.pysignature(py_func)
1490
+ arg_count = len(pysig.parameters)
1491
+ can_fallback = not targetoptions.get("nopython", False)
1492
+
1493
+ _DispatcherBase.__init__(
1494
+ self,
1495
+ arg_count,
1496
+ py_func,
1497
+ pysig,
1498
+ can_fallback,
1499
+ exact_match_required=False,
1500
+ )
1501
+
1502
+ functools.update_wrapper(self, py_func)
1503
+
1504
+ self.targetoptions = targetoptions
1505
+ self._cache = NullCache()
1506
+ compiler_class = _FunctionCompiler
1507
+ self._compiler = compiler_class(
1508
+ py_func, self.targetdescr, targetoptions, pipeline_class
1509
+ )
1510
+ self._cache_hits = collections.Counter()
1511
+ self._cache_misses = collections.Counter()
1512
+
1513
+ # The following properties are for specialization of CUDADispatchers. A
1514
+ # specialized CUDADispatcher is one that is compiled for exactly one
1515
+ # set of argument types, and bypasses some argument type checking for
1516
+ # faster kernel launches.
1517
+
1518
+ # Is this a specialized dispatcher?
1519
+ self._specialized = False
1520
+
1521
+ # If we produced specialized dispatchers, we cache them for each set of
1522
+ # argument types
1523
+ self.specializations = {}
1524
+
1525
+ def dump(self, tab=""):
1526
+ print(
1527
+ f"{tab}DUMP {type(self).__name__}[{self.py_func.__name__}"
1528
+ f", type code={self._type._code}]"
1529
+ )
1530
+ for cres in self.overloads.values():
1531
+ cres.dump(tab=tab + " ")
1532
+ print(f"{tab}END DUMP {type(self).__name__}[{self.py_func.__name__}]")
1533
+
1534
+ @property
1535
+ def _numba_type_(self):
1536
+ return ext_types.CUDADispatcher(self)
1537
+
1538
+ def enable_caching(self):
1539
+ self._cache = CUDACache(self.py_func)
1540
+
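# --- Editor's note: a minimal sketch, not part of the package source. Users
# normally reach enable_caching() through the decorator's cache option
# (assuming the usual numba.cuda decorator behaviour); requires a CUDA device.
from numba import cuda

@cuda.jit(cache=True)  # compiled kernels are written to an on-disk cache
def fill(x, value):
    i = cuda.grid(1)
    if i < x.size:
        x[i] = value
# ---------------------------------------------------------------------------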
1541
+ def __get__(self, obj, objtype=None):
1542
+ """Allow a JIT function to be bound as a method to an object"""
1543
+ if obj is None: # Unbound method
1544
+ return self
1545
+ else: # Bound method
1546
+ return pytypes.MethodType(self, obj)
1547
+
1548
+ @functools.lru_cache(maxsize=128)
1549
+ def configure(self, griddim, blockdim, stream=0, sharedmem=0):
1550
+ griddim, blockdim = normalize_kernel_dimensions(griddim, blockdim)
1551
+ return _LaunchConfiguration(self, griddim, blockdim, stream, sharedmem)
1552
+
1553
+ def __getitem__(self, args):
1554
+ if len(args) not in [2, 3, 4]:
1555
+ raise ValueError("must specify at least the griddim and blockdim")
1556
+ return self.configure(*args)
1557
+
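# --- Editor's note: illustrative usage sketch, not part of the package source.
# Kernel and data below are hypothetical; requires a CUDA device.
from numba import cuda
import numpy as np

@cuda.jit
def add_one(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] += 1

arr = np.zeros(256)
add_one[4, 64](arr)        # __getitem__ -> configure(griddim=4, blockdim=64)
add_one[4, 64, 0, 0](arr)  # optionally also pass stream and sharedmem
# ---------------------------------------------------------------------------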
1558
+ def forall(self, ntasks, tpb=0, stream=0, sharedmem=0):
1559
+ """Returns a 1D-configured dispatcher for a given number of tasks.
1560
+
1561
+ This assumes that:
1562
+
1563
+ - the kernel maps the Global Thread ID ``cuda.grid(1)`` to tasks on a
1564
+ 1-1 basis.
1565
+ - the kernel checks that the Global Thread ID is upper-bounded by
1566
+ ``ntasks``, and does nothing if it is not.
1567
+
1568
+ :param ntasks: The number of tasks.
1569
+ :param tpb: The size of a block. An appropriate value is chosen if this
1570
+ parameter is not supplied.
1571
+ :param stream: The stream on which the configured dispatcher will be
1572
+ launched.
1573
+ :param sharedmem: The number of bytes of dynamic shared memory required
1574
+ by the kernel.
1575
+ :return: A configured dispatcher, ready to launch on a set of
1576
+ arguments."""
1577
+
1578
+ return ForAll(self, ntasks, tpb=tpb, stream=stream, sharedmem=sharedmem)
1579
+
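# --- Editor's note: illustrative usage sketch, not part of the package source.
# The kernel is hypothetical; it follows the two assumptions listed above
# (grid(1) maps 1-1 to tasks and is bounded by ntasks). Requires a CUDA device.
from numba import cuda
import numpy as np

@cuda.jit
def scale(x, factor):
    i = cuda.grid(1)
    if i < x.size:       # the upper-bound check that forall() relies on
        x[i] *= factor

data = np.arange(1_000_000, dtype=np.float64)
scale.forall(data.size)(data, 2.0)   # block size chosen automatically
# ---------------------------------------------------------------------------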
1580
+ @property
1581
+ def extensions(self):
1582
+ """
1583
+ A list of objects that must have a `prepare_args` function. When a
1584
+ specialized kernel is called, each argument will be passed through
1585
+ to the `prepare_args` (from the last object in this list to the
1586
+ first). The arguments to `prepare_args` are:
1587
+
1588
+ - `ty` the numba type of the argument
1589
+ - `val` the argument value itself
1590
+ - `stream` the CUDA stream used for the current call to the kernel
1591
+ - `retr` a list of zero-arg functions that you may want to append
1592
+ post-call cleanup work to.
1593
+
1594
+ The `prepare_args` function must return a tuple `(ty, val)`, which
1595
+ will be passed in turn to the next right-most `extension`. After all
1596
+ the extensions have been called, the resulting `(ty, val)` will be
1597
+ passed into Numba's default argument marshalling logic.
1598
+ """
1599
+ return self.targetoptions.get("extensions")
1600
+
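# --- Editor's note: a minimal sketch of the prepare_args protocol described
# above, not part of the package source. ArgLogger is hypothetical; it leaves
# every argument untouched and only reports what is about to be marshalled.
class ArgLogger:
    def prepare_args(self, ty, val, stream, retr):
        print(f"marshalling a {ty} argument on stream {stream}")
        retr.append(lambda: print("kernel call finished"))  # post-call cleanup
        return ty, val

# Extensions are passed via the decorator's target options, e.g.
# @cuda.jit(extensions=[ArgLogger()]) (assuming the usual decorator API).
# ---------------------------------------------------------------------------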
1601
+ def __call__(self, *args, **kwargs):
1602
+ # An attempt to launch an unconfigured kernel
1603
+ raise ValueError(missing_launch_config_msg)
1604
+
1605
+ def call(self, args, griddim, blockdim, stream, sharedmem):
1606
+ """
1607
+ Compile if necessary and invoke this kernel with *args*.
1608
+ """
1609
+ if self.specialized:
1610
+ kernel = next(iter(self.overloads.values()))
1611
+ else:
1612
+ kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
1613
+
1614
+ kernel.launch(args, griddim, blockdim, stream, sharedmem)
1615
+
1616
+ def _compile_for_args(self, *args, **kws):
1617
+ # Based on _DispatcherBase._compile_for_args.
1618
+ assert not kws
1619
+ argtypes = [self.typeof_pyval(a) for a in args]
1620
+ return self.compile(tuple(argtypes))
1621
+
1622
+ def typeof_pyval(self, val):
1623
+ # Based on _DispatcherBase.typeof_pyval, but differs from it to support
1624
+ # the CUDA Array Interface.
1625
+ try:
1626
+ return typeof(val, Purpose.argument)
1627
+ except ValueError:
1628
+ if (
1629
+ interface := getattr(val, "__cuda_array_interface__", None)
1630
+ ) is not None:
1631
+ # When typing, we don't need to synchronize on the array's
1632
+ # stream - this is done when the kernel is launched.
1633
+
1634
+ return typeof(
1635
+ cuda.from_cuda_array_interface(interface, sync=False),
1636
+ Purpose.argument,
1637
+ )
1638
+ else:
1639
+ raise
1640
+
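# --- Editor's note: illustrative sketch, not part of the package source. The
# fallback above is what lets objects from other libraries that expose
# __cuda_array_interface__ be passed to kernels directly; CuPy is shown here
# purely as an example and is assumed to be installed.
import cupy as cp
from numba import cuda

@cuda.jit
def double(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] *= 2

buf = cp.ones(128, dtype=cp.float32)
double[1, 128](buf)   # typed through the CUDA Array Interface fallback
# ---------------------------------------------------------------------------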
1641
+ def specialize(self, *args):
1642
+ """
1643
+ Create a new instance of this dispatcher specialized for the given
1644
+ *args*.
1645
+ """
1646
+ cc = get_current_device().compute_capability
1647
+ argtypes = tuple(self.typeof_pyval(a) for a in args)
1648
+ if self.specialized:
1649
+ raise RuntimeError("Dispatcher already specialized")
1650
+
1651
+ specialization = self.specializations.get((cc, argtypes))
1652
+ if specialization:
1653
+ return specialization
1654
+
1655
+ targetoptions = self.targetoptions
1656
+ specialization = CUDADispatcher(
1657
+ self.py_func, targetoptions=targetoptions
1658
+ )
1659
+ specialization.compile(argtypes)
1660
+ specialization.disable_compile()
1661
+ specialization._specialized = True
1662
+ self.specializations[cc, argtypes] = specialization
1663
+ return specialization
1664
+
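# --- Editor's note: illustrative usage sketch, not part of the package source.
# Kernel and data are hypothetical; requires a CUDA device.
from numba import cuda
import numpy as np

@cuda.jit
def inc(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] += 1

d_arr = cuda.to_device(np.zeros(64, dtype=np.float32))
inc_f32 = inc.specialize(d_arr)   # compiled for exactly this argument type
assert inc_f32.specialized
inc_f32[1, 64](d_arr)             # launches skip argument-type dispatch
# ---------------------------------------------------------------------------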
1665
+ @property
1666
+ def specialized(self):
1667
+ """
1668
+ True if the Dispatcher has been specialized.
1669
+ """
1670
+ return self._specialized
1671
+
1672
+ def get_regs_per_thread(self, signature=None):
1673
+ """
1674
+ Returns the number of registers used by each thread in this kernel for
1675
+ the device in the current context.
1676
+
1677
+ :param signature: The signature of the compiled kernel to get register
1678
+ usage for. This may be omitted for a specialized
1679
+ kernel.
1680
+ :return: The number of registers used by the compiled variant of the
1681
+ kernel for the given signature and current device.
1682
+ """
1683
+ if signature is not None:
1684
+ return self.overloads[signature.args].regs_per_thread
1685
+ if self.specialized:
1686
+ return next(iter(self.overloads.values())).regs_per_thread
1687
+ else:
1688
+ return {
1689
+ sig: overload.regs_per_thread
1690
+ for sig, overload in self.overloads.items()
1691
+ }
1692
+
1693
+ def get_const_mem_size(self, signature=None):
1694
+ """
1695
+ Returns the size in bytes of constant memory used by this kernel for
1696
+ the device in the current context.
1697
+
1698
+ :param signature: The signature of the compiled kernel to get constant
1699
+ memory usage for. This may be omitted for a
1700
+ specialized kernel.
1701
+ :return: The size in bytes of constant memory allocated by the
1702
+ compiled variant of the kernel for the given signature and
1703
+ current device.
1704
+ """
1705
+ if signature is not None:
1706
+ return self.overloads[signature.args].const_mem_size
1707
+ if self.specialized:
1708
+ return next(iter(self.overloads.values())).const_mem_size
1709
+ else:
1710
+ return {
1711
+ sig: overload.const_mem_size
1712
+ for sig, overload in self.overloads.items()
1713
+ }
1714
+
1715
+ def get_shared_mem_per_block(self, signature=None):
1716
+ """
1717
+ Returns the size in bytes of statically allocated shared memory
1718
+ for this kernel.
1719
+
1720
+ :param signature: The signature of the compiled kernel to get shared
1721
+ memory usage for. This may be omitted for a
1722
+ specialized kernel.
1723
+ :return: The amount of shared memory allocated by the compiled variant
1724
+ of the kernel for the given signature and current device.
1725
+ """
1726
+ if signature is not None:
1727
+ return self.overloads[signature.args].shared_mem_per_block
1728
+ if self.specialized:
1729
+ return next(iter(self.overloads.values())).shared_mem_per_block
1730
+ else:
1731
+ return {
1732
+ sig: overload.shared_mem_per_block
1733
+ for sig, overload in self.overloads.items()
1734
+ }
1735
+
1736
+ def get_max_threads_per_block(self, signature=None):
1737
+ """
1738
+ Returns the maximum allowable number of threads per block
1739
+ for this kernel. Exceeding this threshold will result in
1740
+ the kernel failing to launch.
1741
+
1742
+ :param signature: The signature of the compiled kernel to get the max
1743
+ threads per block for. This may be omitted for a
1744
+ specialized kernel.
1745
+ :return: The maximum allowable threads per block for the compiled
1746
+ variant of the kernel for the given signature and current
1747
+ device.
1748
+ """
1749
+ if signature is not None:
1750
+ return self.overloads[signature.args].max_threads_per_block
1751
+ if self.specialized:
1752
+ return next(iter(self.overloads.values())).max_threads_per_block
1753
+ else:
1754
+ return {
1755
+ sig: overload.max_threads_per_block
1756
+ for sig, overload in self.overloads.items()
1757
+ }
1758
+
1759
+ def get_local_mem_per_thread(self, signature=None):
1760
+ """
1761
+ Returns the size in bytes of local memory per thread
1762
+ for this kernel.
1763
+
1764
+ :param signature: The signature of the compiled kernel to get local
1765
+ memory usage for. This may be omitted for a
1766
+ specialized kernel.
1767
+ :return: The amount of local memory allocated by the compiled variant
1768
+ of the kernel for the given signature and current device.
1769
+ """
1770
+ if signature is not None:
1771
+ return self.overloads[signature.args].local_mem_per_thread
1772
+ if self.specialized:
1773
+ return next(iter(self.overloads.values())).local_mem_per_thread
1774
+ else:
1775
+ return {
1776
+ sig: overload.local_mem_per_thread
1777
+ for sig, overload in self.overloads.items()
1778
+ }
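# --- Editor's note: illustrative sketch of the resource-query helpers above,
# not part of the package source. Kernel and data are hypothetical; requires a
# CUDA device.
from numba import cuda
import numpy as np

@cuda.jit
def axpy(r, a, x, y):
    i = cuda.grid(1)
    if i < r.size:
        r[i] = a * x[i] + y[i]

n = 1024
x = cuda.to_device(np.ones(n))
y = cuda.to_device(np.ones(n))
r = cuda.device_array(n)
axpy[8, 128](r, 2.0, x, y)

# With signatures compiled, each helper returns a {signature: value} mapping
# (or a single value when a signature is given / the kernel is specialized).
print(axpy.get_regs_per_thread())
print(axpy.get_shared_mem_per_block())
print(axpy.get_max_threads_per_block())
print(axpy.get_local_mem_per_thread())
# ---------------------------------------------------------------------------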
1779
+
1780
+ def get_call_template(self, args, kws):
1781
+ # Originally copied from _DispatcherBase.get_call_template. This
1782
+ # version deviates slightly from the _DispatcherBase version in order
1783
+ # to force casts when calling device functions. See e.g.
1784
+ # TestDeviceFunc.test_device_casting, added in PR #7496.
1785
+ """
1786
+ Get a typing.ConcreteTemplate for this dispatcher and the given
1787
+ *args* and *kws* types. This allows resolution of the return type.
1788
+
1789
+ A (template, pysig, args, kws) tuple is returned.
1790
+ """
1791
+ # Fold keyword arguments and resolve default values
1792
+ pysig, args = self.fold_argument_types(args, kws)
1793
+ kws = {}
1794
+
1795
+ # Ensure an exactly-matching overload is available if we can
1796
+ # compile. We proceed with the typing even if we can't compile
1797
+ # because we may be able to force a cast on the caller side.
1798
+ if self._can_compile:
1799
+ self.compile_device(tuple(args))
1800
+
1801
+ # Create function type for typing
1802
+ func_name = self.py_func.__name__
1803
+ name = "CallTemplate({0})".format(func_name)
1804
+
1805
+ call_template = typing.make_concrete_template(
1806
+ name, key=func_name, signatures=self.nopython_signatures
1807
+ )
1808
+ pysig = utils.pysignature(self.py_func)
1809
+
1810
+ return call_template, pysig, args, kws
1811
+
1812
+ def compile_device(self, args, return_type=None):
1813
+ """Compile the device function for the given argument types.
1814
+
1815
+ Each signature is compiled once by caching the compiled function inside
1816
+ this object.
1817
+
1818
+ Returns the `CompileResult`.
1819
+ """
1820
+ if args not in self.overloads:
1821
+ with self._compiling_counter:
1822
+ debug = self.targetoptions.get("debug")
1823
+ lineinfo = self.targetoptions.get("lineinfo")
1824
+ forceinline = self.targetoptions.get("forceinline")
1825
+ fastmath = self.targetoptions.get("fastmath")
1826
+
1827
+ nvvm_options = {
1828
+ "opt": 3 if self.targetoptions.get("opt") else 0,
1829
+ "fastmath": fastmath,
1830
+ }
1831
+
1832
+ if debug:
1833
+ nvvm_options["g"] = None
1834
+
1835
+ cc = get_current_device().compute_capability
1836
+ cres = compile_cuda(
1837
+ self.py_func,
1838
+ return_type,
1839
+ args,
1840
+ debug=debug,
1841
+ lineinfo=lineinfo,
1842
+ forceinline=forceinline,
1843
+ fastmath=fastmath,
1844
+ nvvm_options=nvvm_options,
1845
+ cc=cc,
1846
+ )
1847
+ self.overloads[args] = cres
1848
+
1849
+ cres.target_context.insert_user_function(
1850
+ cres.entry_point, cres.fndesc, [cres.library]
1851
+ )
1852
+ else:
1853
+ cres = self.overloads[args]
1854
+
1855
+ return cres
1856
+
1857
+ def add_overload(self, kernel, argtypes):
1858
+ c_sig = [a._code for a in argtypes]
1859
+ self._insert(c_sig, kernel, cuda=True)
1860
+ self.overloads[argtypes] = kernel
1861
+
1862
+ @global_compiler_lock
1863
+ def compile(self, sig):
1864
+ """
1865
+ Compile and bind to the current context a version of this kernel
1866
+ specialized for the given signature.
1867
+ """
1868
+ argtypes, return_type = sigutils.normalize_signature(sig)
1869
+ assert return_type is None or return_type == types.none
1870
+
1871
+ # Do we already have an in-memory compiled kernel?
1872
+ if self.specialized:
1873
+ return next(iter(self.overloads.values()))
1874
+ else:
1875
+ kernel = self.overloads.get(argtypes)
1876
+ if kernel is not None:
1877
+ return kernel
1878
+
1879
+ # Can we load from the disk cache?
1880
+ kernel = self._cache.load_overload(sig, self.targetctx)
1881
+
1882
+ if kernel is not None:
1883
+ self._cache_hits[sig] += 1
1884
+ else:
1885
+ # We need to compile a new kernel
1886
+ self._cache_misses[sig] += 1
1887
+ if not self._can_compile:
1888
+ raise RuntimeError("Compilation disabled")
1889
+
1890
+ kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
1891
+ # We call bind to force codegen, so that there is a cubin to cache
1892
+ kernel.bind()
1893
+ self._cache.save_overload(sig, kernel)
1894
+
1895
+ self.add_overload(kernel, argtypes)
1896
+
1897
+ return kernel
1898
+
1899
+ def get_compile_result(self, sig):
1900
+ """Compile (if needed) and return the compilation result with the
1901
+ given signature.
1902
+
1903
+ Returns ``CompileResult``.
1904
+ Raises ``NumbaError`` if the signature is incompatible.
1905
+ """
1906
+ atypes = tuple(sig.args)
1907
+ if atypes not in self.overloads:
1908
+ if self._can_compile:
1909
+ # Compiling may raise any NumbaError
1910
+ self.compile(atypes)
1911
+ else:
1912
+ msg = f"{sig} not available and compilation disabled"
1913
+ raise errors.TypingError(msg)
1914
+ return self.overloads[atypes]
1915
+
1916
+ def recompile(self):
1917
+ """
1918
+ Recompile all signatures afresh.
1919
+ """
1920
+ sigs = list(self.overloads)
1921
+ old_can_compile = self._can_compile
1922
+ # Ensure the old overloads are disposed of,
1923
+ # including compiled functions.
1924
+ self._make_finalizer()()
1925
+ self._reset_overloads()
1926
+ self._cache.flush()
1927
+ self._can_compile = True
1928
+ try:
1929
+ for sig in sigs:
1930
+ self.compile(sig)
1931
+ finally:
1932
+ self._can_compile = old_can_compile
1933
+
1934
+ @property
1935
+ def stats(self):
1936
+ return _CompileStats(
1937
+ cache_path=self._cache.cache_path,
1938
+ cache_hits=self._cache_hits,
1939
+ cache_misses=self._cache_misses,
1940
+ )
1941
+
1942
+ def get_metadata(self, signature=None):
1943
+ """
1944
+ Obtain the compilation metadata for a given signature.
1945
+ """
1946
+ if signature is not None:
1947
+ return self.overloads[signature].metadata
1948
+ else:
1949
+ return dict(
1950
+ (sig, self.overloads[sig].metadata) for sig in self.signatures
1951
+ )
1952
+
1953
+ def get_function_type(self):
1954
+ """Return unique function type of dispatcher when possible, otherwise
1955
+ return None.
1956
+
1957
+ A Dispatcher instance has unique function type when it
1958
+ contains exactly one compilation result and its compilation
1959
+ has been disabled (via its disable_compile method).
1960
+ """
1961
+ if not self._can_compile and len(self.overloads) == 1:
1962
+ cres = tuple(self.overloads.values())[0]
1963
+ return types.FunctionType(cres.signature)
1964
+
1965
+ def inspect_llvm(self, signature=None):
1966
+ """
1967
+ Return the LLVM IR for this kernel.
1968
+
1969
+ :param signature: A tuple of argument types.
1970
+ :return: The LLVM IR for the given signature, or a dict of LLVM IR
1971
+ for all previously-encountered signatures.
1972
+
1973
+ """
1974
+ device = self.targetoptions.get("device")
1975
+ if signature is not None:
1976
+ if device:
1977
+ return self.overloads[signature].library.get_llvm_str()
1978
+ else:
1979
+ return self.overloads[signature].inspect_llvm()
1980
+ else:
1981
+ if device:
1982
+ return {
1983
+ sig: overload.library.get_llvm_str()
1984
+ for sig, overload in self.overloads.items()
1985
+ }
1986
+ else:
1987
+ return {
1988
+ sig: overload.inspect_llvm()
1989
+ for sig, overload in self.overloads.items()
1990
+ }
1991
+
1992
+ def inspect_asm(self, signature=None):
1993
+ """
1994
+ Return this kernel's PTX assembly code for the device in the
1995
+ current context.
1996
+
1997
+ :param signature: A tuple of argument types.
1998
+ :return: The PTX code for the given signature, or a dict of PTX codes
1999
+ for all previously-encountered signatures.
2000
+ """
2001
+ cc = get_current_device().compute_capability
2002
+ device = self.targetoptions.get("device")
2003
+ if signature is not None:
2004
+ if device:
2005
+ return self.overloads[signature].library.get_asm_str(cc)
2006
+ else:
2007
+ return self.overloads[signature].inspect_asm(cc)
2008
+ else:
2009
+ if device:
2010
+ return {
2011
+ sig: overload.library.get_asm_str(cc)
2012
+ for sig, overload in self.overloads.items()
2013
+ }
2014
+ else:
2015
+ return {
2016
+ sig: overload.inspect_asm(cc)
2017
+ for sig, overload in self.overloads.items()
2018
+ }
2019
+
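# --- Editor's note: illustrative usage sketch, not part of the package source.
# Eagerly compiling with a string signature gives exactly one overload to
# inspect; requires a CUDA device. The kernel is hypothetical.
from numba import cuda

@cuda.jit("void(float32[:])")
def zero(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] = 0.0

(sig,) = zero.signatures               # the single compiled argument-type tuple
print(zero.inspect_llvm(sig)[:200])    # LLVM (NVVM) IR text for that signature
print(zero.inspect_asm(sig)[:200])     # PTX text for that signature
# ---------------------------------------------------------------------------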
2020
+ def inspect_lto_ptx(self, signature=None):
2021
+ """
2022
+ Return link-time optimized PTX code for the given signature.
2023
+
2024
+ :param signature: A tuple of argument types.
2025
+ :return: The PTX code for the given signature, or a dict of PTX codes
2026
+ for all previously-encountered signatures.
2027
+ """
2028
+ cc = get_current_device().compute_capability
2029
+ device = self.targetoptions.get("device")
2030
+
2031
+ if signature is not None:
2032
+ if device:
2033
+ return self.overloads[signature].library.get_lto_ptx(cc)
2034
+ else:
2035
+ return self.overloads[signature].inspect_lto_ptx(cc)
2036
+ else:
2037
+ if device:
2038
+ return {
2039
+ sig: overload.library.get_lto_ptx(cc)
2040
+ for sig, overload in self.overloads.items()
2041
+ }
2042
+ else:
2043
+ return {
2044
+ sig: overload.inspect_lto_ptx(cc)
2045
+ for sig, overload in self.overloads.items()
2046
+ }
2047
+
2048
+ def inspect_sass_cfg(self, signature=None):
2049
+ """
2050
+ Return this kernel's CFG for the device in the current context.
2051
+
2052
+ :param signature: A tuple of argument types.
2053
+ :return: The CFG for the given signature, or a dict of CFGs
2054
+ for all previously-encountered signatures.
2055
+
2056
+ The CFG for the device in the current context is returned.
2057
+
2058
+ Requires nvdisasm to be available on the PATH.
2059
+ """
2060
+ if self.targetoptions.get("device"):
2061
+ raise RuntimeError("Cannot get the CFG of a device function")
2062
+
2063
+ if signature is not None:
2064
+ return self.overloads[signature].inspect_sass_cfg()
2065
+ else:
2066
+ return {
2067
+ sig: defn.inspect_sass_cfg()
2068
+ for sig, defn in self.overloads.items()
2069
+ }
2070
+
2071
+ def inspect_sass(self, signature=None):
2072
+ """
2073
+ Return this kernel's SASS assembly code for the device in the
2074
+ current context.
2075
+
2076
+ :param signature: A tuple of argument types.
2077
+ :return: The SASS code for the given signature, or a dict of SASS codes
2078
+ for all previously-encountered signatures.
2079
+
2080
+ SASS for the device in the current context is returned.
2081
+
2082
+ Requires nvdisasm to be available on the PATH.
2083
+ """
2084
+ if self.targetoptions.get("device"):
2085
+ raise RuntimeError("Cannot inspect SASS of a device function")
2086
+
2087
+ if signature is not None:
2088
+ return self.overloads[signature].inspect_sass()
2089
+ else:
2090
+ return {
2091
+ sig: defn.inspect_sass() for sig, defn in self.overloads.items()
2092
+ }
2093
+
2094
+ def inspect_types(self, file=None):
2095
+ """
2096
+ Produce a dump of the Python source of this function annotated with the
2097
+ corresponding Numba IR and type information. The dump is written to
2098
+ *file*, or *sys.stdout* if *file* is *None*.
2099
+ """
2100
+ if file is None:
2101
+ file = sys.stdout
2102
+
2103
+ for _, defn in self.overloads.items():
2104
+ defn.inspect_types(file=file)
2105
+
2106
+ @classmethod
2107
+ def _rebuild(cls, py_func, targetoptions):
2108
+ """
2109
+ Rebuild an instance.
2110
+ """
2111
+ instance = cls(py_func, targetoptions)
2112
+ return instance
2113
+
2114
+ def _reduce_states(self):
2115
+ """
2116
+ Reduce the instance for serialization.
2117
+ Compiled definitions are discarded.
2118
+ """
2119
+ return dict(py_func=self.py_func, targetoptions=self.targetoptions)
2120
+
2121
+
2122
+ class LiftedCode(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
2123
+ """
2124
+ Implementation of the hidden dispatcher objects used for lifted code
2125
+ (a lifted loop is really compiled as a separate function).
2126
+ """
2127
+
2128
+ _fold_args = False
2129
+ can_cache = False
2130
+
2131
+ def __init__(self, func_ir, typingctx, targetctx, flags, locals):
2132
+ self.func_ir = func_ir
2133
+ self.lifted_from = None
2134
+
2135
+ self.typingctx = typingctx
2136
+ self.targetctx = targetctx
2137
+ self.flags = flags
2138
+ self.locals = locals
2139
+
2140
+ _DispatcherBase.__init__(
2141
+ self,
2142
+ self.func_ir.arg_count,
2143
+ self.func_ir.func_id.func,
2144
+ self.func_ir.func_id.pysig,
2145
+ can_fallback=True,
2146
+ exact_match_required=False,
2147
+ )
2148
+
2149
+ def _reduce_states(self):
2150
+ """
2151
+ Reduce the instance for pickling. This will serialize
2152
+ the original function as well the compilation options and
2153
+ compiled signatures, but not the compiled code itself.
2154
+
2155
+ NOTE: part of ReduceMixin protocol
2156
+ """
2157
+ return dict(
2158
+ uuid=self._uuid,
2159
+ func_ir=self.func_ir,
2160
+ flags=self.flags,
2161
+ locals=self.locals,
2162
+ extras=self._reduce_extras(),
2163
+ )
2164
+
2165
+ def _reduce_extras(self):
2166
+ """
2167
+ NOTE: sub-class can override to add extra states
2168
+ """
2169
+ return {}
2170
+
2171
+ @classmethod
2172
+ def _rebuild(cls, uuid, func_ir, flags, locals, extras):
2173
+ """
2174
+ Rebuild a Dispatcher instance after it was __reduce__'d.
2175
+
2176
+ NOTE: part of ReduceMixin protocol
2177
+ """
2178
+ try:
2179
+ return cls._memo[uuid]
2180
+ except KeyError:
2181
+ pass
2182
+
2183
+ from numba.cuda.descriptor import cuda_target
2184
+
2185
+ typingctx = cuda_target.typing_context
2186
+ targetctx = cuda_target.target_context
2187
+
2188
+ self = cls(func_ir, typingctx, targetctx, flags, locals, **extras)
2189
+ self._set_uuid(uuid)
2190
+ return self
2191
+
2192
+ def get_source_location(self):
2193
+ """Return the starting line number of the loop."""
2194
+ return self.func_ir.loc.line
2195
+
2196
+ def _pre_compile(self, args, return_type, flags):
2197
+ """Pre-compile actions"""
2198
+ pass
2199
+
2200
+ @abstractmethod
2201
+ def compile(self, sig):
2202
+ """Lifted code should implement a compilation method that will return
2203
+ a CompileResult.entry_point for the given signature."""
2204
+ pass
2205
+
2206
+ def _get_dispatcher_for_current_target(self):
2207
+ # Lifted code does not honor the target switch currently.
2208
+ # No work has been done to check if this can be allowed.
2209
+ return self
2210
+
2211
+
2212
+ class LiftedLoop(LiftedCode):
2213
+ def _pre_compile(self, args, return_type, flags):
2214
+ assert not flags.enable_looplift, "Enable looplift flag is on"
2215
+
2216
+ def compile(self, sig):
2217
+ with ExitStack() as scope:
2218
+ cres = None
2219
+
2220
+ def cb_compiler(dur):
2221
+ if cres is not None:
2222
+ self._callback_add_compiler_timer(dur, cres)
2223
+
2224
+ def cb_llvm(dur):
2225
+ if cres is not None:
2226
+ self._callback_add_llvm_timer(dur, cres)
2227
+
2228
+ scope.enter_context(
2229
+ ev.install_timer("numba:compiler_lock", cb_compiler)
2230
+ )
2231
+ scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm))
2232
+ scope.enter_context(global_compiler_lock)
2233
+
2234
+ # Use counter to track recursion compilation depth
2235
+ with self._compiling_counter:
2236
+ # XXX this is mostly duplicated from Dispatcher.
2237
+ flags = self.flags
2238
+ args, return_type = sigutils.normalize_signature(sig)
2239
+
2240
+ # Don't recompile if signature already exists
2241
+ # (e.g. if another thread compiled it before we got the lock)
2242
+ existing = self.overloads.get(tuple(args))
2243
+ if existing is not None:
2244
+ return existing.entry_point
2245
+
2246
+ self._pre_compile(args, return_type, flags)
2247
+
2248
+ # copy the flags, use nopython first
2249
+ npm_loop_flags = flags.copy()
2250
+ npm_loop_flags.force_pyobject = False
2251
+
2252
+ pyobject_loop_flags = flags.copy()
2253
+ pyobject_loop_flags.force_pyobject = True
2254
+
2255
+ # Clone IR to avoid (some of the) mutation in the rewrite pass
2256
+ cloned_func_ir_npm = self.func_ir.copy()
2257
+ cloned_func_ir_fbk = self.func_ir.copy()
2258
+
2259
+ ev_details = dict(
2260
+ dispatcher=self,
2261
+ args=args,
2262
+ return_type=return_type,
2263
+ )
2264
+ with ev.trigger_event("numba:compile", data=ev_details):
2265
+ # this emulates "object mode fall-back", try nopython, if it
2266
+ # fails, then try again in object mode.
2267
+ try:
2268
+ cres = compile_ir(
2269
+ typingctx=self.typingctx,
2270
+ targetctx=self.targetctx,
2271
+ func_ir=cloned_func_ir_npm,
2272
+ args=args,
2273
+ return_type=return_type,
2274
+ flags=npm_loop_flags,
2275
+ locals=self.locals,
2276
+ lifted=(),
2277
+ lifted_from=self.lifted_from,
2278
+ is_lifted_loop=True,
2279
+ )
2280
+ except errors.TypingError:
2281
+ cres = compile_ir(
2282
+ typingctx=self.typingctx,
2283
+ targetctx=self.targetctx,
2284
+ func_ir=cloned_func_ir_fbk,
2285
+ args=args,
2286
+ return_type=return_type,
2287
+ flags=pyobject_loop_flags,
2288
+ locals=self.locals,
2289
+ lifted=(),
2290
+ lifted_from=self.lifted_from,
2291
+ is_lifted_loop=True,
2292
+ )
2293
+ # Check typing error if object mode is used
2294
+ if cres.typing_error is not None:
2295
+ raise cres.typing_error
2296
+ self.add_overload(cres)
2297
+ return cres.entry_point
2298
+
2299
+
2300
+ class LiftedWith(LiftedCode):
2301
+ can_cache = True
2302
+
2303
+ def _reduce_extras(self):
2304
+ return dict(output_types=self.output_types)
2305
+
2306
+ @property
2307
+ def _numba_type_(self):
2308
+ return types.Dispatcher(self)
2309
+
2310
+ def get_call_template(self, args, kws):
2311
+ """
2312
+ Get a typing.ConcreteTemplate for this dispatcher and the given
2313
+ *args* and *kws* types. This enables the return type to be resolved.
2314
+
2315
+ A (template, pysig, args, kws) tuple is returned.
2316
+ """
2317
+ # Ensure an overload is available
2318
+ if self._can_compile:
2319
+ self.compile(tuple(args))
2320
+
2321
+ pysig = None
2322
+ # Create function type for typing
2323
+ func_name = self.py_func.__name__
2324
+ name = "CallTemplate({0})".format(func_name)
2325
+ # The `key` isn't really used except for diagnosis here,
2326
+ # so avoid keeping a reference to `cfunc`.
2327
+ call_template = typing.make_concrete_template(
2328
+ name, key=func_name, signatures=self.nopython_signatures
2329
+ )
2330
+ return call_template, pysig, args, kws
2331
+
2332
+ def compile(self, sig):
2333
+ # this is similar to LiftedLoop's compile but does not have the
2334
+ # "fallback" to object mode part.
2335
+ with ExitStack() as scope:
2336
+ cres = None
2337
+
2338
+ def cb_compiler(dur):
2339
+ if cres is not None:
2340
+ self._callback_add_compiler_timer(dur, cres)
2341
+
2342
+ def cb_llvm(dur):
2343
+ if cres is not None:
2344
+ self._callback_add_llvm_timer(dur, cres)
2345
+
2346
+ scope.enter_context(
2347
+ ev.install_timer("numba:compiler_lock", cb_compiler)
2348
+ )
2349
+ scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm))
2350
+ scope.enter_context(global_compiler_lock)
2351
+
2352
+ # Use counter to track recursion compilation depth
2353
+ with self._compiling_counter:
2354
+ # XXX this is mostly duplicated from Dispatcher.
2355
+ flags = self.flags
2356
+ args, return_type = sigutils.normalize_signature(sig)
2357
+
2358
+ # Don't recompile if signature already exists
2359
+ # (e.g. if another thread compiled it before we got the lock)
2360
+ existing = self.overloads.get(tuple(args))
2361
+ if existing is not None:
2362
+ return existing.entry_point
2363
+
2364
+ self._pre_compile(args, return_type, flags)
2365
+
2366
+ # Clone IR to avoid (some of the) mutation in the rewrite pass
2367
+ cloned_func_ir = self.func_ir.copy()
2368
+
2369
+ ev_details = dict(
2370
+ dispatcher=self,
2371
+ args=args,
2372
+ return_type=return_type,
2373
+ )
2374
+ with ev.trigger_event("numba:compile", data=ev_details):
2375
+ cres = compile_ir(
2376
+ typingctx=self.typingctx,
2377
+ targetctx=self.targetctx,
2378
+ func_ir=cloned_func_ir,
2379
+ args=args,
2380
+ return_type=return_type,
2381
+ flags=flags,
2382
+ locals=self.locals,
2383
+ lifted=(),
2384
+ lifted_from=self.lifted_from,
2385
+ is_lifted_loop=True,
2386
+ )
2387
+
2388
+ # Check typing error if object mode is used
2389
+ if (
2390
+ cres.typing_error is not None
2391
+ and not flags.enable_pyobject
2392
+ ):
2393
+ raise cres.typing_error
2394
+ self.add_overload(cres)
2395
+ return cres.entry_point
2396
+
2397
+
2398
+ class ObjModeLiftedWith(LiftedWith):
2399
+ def __init__(self, *args, **kwargs):
2400
+ self.output_types = kwargs.pop("output_types", None)
2401
+ super(LiftedWith, self).__init__(*args, **kwargs)
2402
+ if not self.flags.force_pyobject:
2403
+ raise ValueError("expecting `flags.force_pyobject`")
2404
+ if self.output_types is None:
2405
+ raise TypeError("`output_types` must be provided")
2406
+ # switch off rewrites; they have no effect
2407
+ self.flags.no_rewrites = True
2408
+
2409
+ @property
2410
+ def _numba_type_(self):
2411
+ return types.ObjModeDispatcher(self)
2412
+
2413
+ def get_call_template(self, args, kws):
2414
+ """
2415
+ Get a typing.ConcreteTemplate for this dispatcher and the given
2416
+ *args* and *kws* types. This enables the return type to be resolved.
2417
+
2418
+ A (template, pysig, args, kws) tuple is returned.
2419
+ """
2420
+ assert not kws
2421
+ self._legalize_arg_types(args)
2422
+ # Coerce to object mode
2423
+ args = [types.ffi_forced_object] * len(args)
2424
+
2425
+ if self._can_compile:
2426
+ self.compile(tuple(args))
2427
+
2428
+ signatures = [typing.signature(self.output_types, *args)]
2429
+ pysig = None
2430
+ func_name = self.py_func.__name__
2431
+ name = "CallTemplate({0})".format(func_name)
2432
+ call_template = typing.make_concrete_template(
2433
+ name, key=func_name, signatures=signatures
2434
+ )
2435
+
2436
+ return call_template, pysig, args, kws
2437
+
2438
+ def _legalize_arg_types(self, args):
2439
+ for i, a in enumerate(args, start=1):
2440
+ if isinstance(a, types.List):
2441
+ msg = (
2442
+ "Does not support list type inputs into "
2443
+ "with-context for arg {}"
2444
+ )
2445
+ raise errors.TypingError(msg.format(i))
2446
+ elif isinstance(a, types.Dispatcher):
2447
+ msg = (
2448
+ "Does not support function type inputs into "
2449
+ "with-context for arg {}"
2450
+ )
2451
+ raise errors.TypingError(msg.format(i))
2452
+
2453
+ @global_compiler_lock
2454
+ def compile(self, sig):
2455
+ args, _ = sigutils.normalize_signature(sig)
2456
+ sig = (types.ffi_forced_object,) * len(args)
2457
+ return super().compile(sig)
2458
+
2459
+
2460
+ # Initialize typeof machinery
2461
+ _dispatcher.typeof_init(
2462
+ OmittedArg, dict((str(t), t._code) for t in types.number_domain)
2463
+ )