numba-cuda 0.21.1__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (488) hide show
  1. _numba_cuda_redirector.pth +4 -0
  2. _numba_cuda_redirector.py +89 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +6 -0
  5. numba_cuda/_version.py +11 -0
  6. numba_cuda/numba/cuda/__init__.py +70 -0
  7. numba_cuda/numba/cuda/_internal/cuda_bf16.py +16394 -0
  8. numba_cuda/numba/cuda/_internal/cuda_fp16.py +8112 -0
  9. numba_cuda/numba/cuda/api.py +577 -0
  10. numba_cuda/numba/cuda/api_util.py +76 -0
  11. numba_cuda/numba/cuda/args.py +72 -0
  12. numba_cuda/numba/cuda/bf16.py +397 -0
  13. numba_cuda/numba/cuda/cache_hints.py +287 -0
  14. numba_cuda/numba/cuda/cext/__init__.py +2 -0
  15. numba_cuda/numba/cuda/cext/_devicearray.cp313-win_amd64.pyd +0 -0
  16. numba_cuda/numba/cuda/cext/_devicearray.cpp +159 -0
  17. numba_cuda/numba/cuda/cext/_devicearray.h +29 -0
  18. numba_cuda/numba/cuda/cext/_dispatcher.cp313-win_amd64.pyd +0 -0
  19. numba_cuda/numba/cuda/cext/_dispatcher.cpp +1098 -0
  20. numba_cuda/numba/cuda/cext/_hashtable.cpp +532 -0
  21. numba_cuda/numba/cuda/cext/_hashtable.h +135 -0
  22. numba_cuda/numba/cuda/cext/_helperlib.c +71 -0
  23. numba_cuda/numba/cuda/cext/_helperlib.cp313-win_amd64.pyd +0 -0
  24. numba_cuda/numba/cuda/cext/_helpermod.c +82 -0
  25. numba_cuda/numba/cuda/cext/_pymodule.h +38 -0
  26. numba_cuda/numba/cuda/cext/_typeconv.cp313-win_amd64.pyd +0 -0
  27. numba_cuda/numba/cuda/cext/_typeconv.cpp +206 -0
  28. numba_cuda/numba/cuda/cext/_typeof.cpp +1159 -0
  29. numba_cuda/numba/cuda/cext/_typeof.h +19 -0
  30. numba_cuda/numba/cuda/cext/capsulethunk.h +111 -0
  31. numba_cuda/numba/cuda/cext/mviewbuf.c +385 -0
  32. numba_cuda/numba/cuda/cext/mviewbuf.cp313-win_amd64.pyd +0 -0
  33. numba_cuda/numba/cuda/cext/typeconv.cpp +212 -0
  34. numba_cuda/numba/cuda/cext/typeconv.hpp +101 -0
  35. numba_cuda/numba/cuda/cg.py +67 -0
  36. numba_cuda/numba/cuda/cgutils.py +1294 -0
  37. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  38. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  39. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  40. numba_cuda/numba/cuda/codegen.py +541 -0
  41. numba_cuda/numba/cuda/compiler.py +1396 -0
  42. numba_cuda/numba/cuda/core/analysis.py +758 -0
  43. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  44. numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +288 -0
  45. numba_cuda/numba/cuda/core/annotations/type_annotations.py +305 -0
  46. numba_cuda/numba/cuda/core/base.py +1332 -0
  47. numba_cuda/numba/cuda/core/boxing.py +1411 -0
  48. numba_cuda/numba/cuda/core/bytecode.py +728 -0
  49. numba_cuda/numba/cuda/core/byteflow.py +2346 -0
  50. numba_cuda/numba/cuda/core/caching.py +744 -0
  51. numba_cuda/numba/cuda/core/callconv.py +392 -0
  52. numba_cuda/numba/cuda/core/codegen.py +171 -0
  53. numba_cuda/numba/cuda/core/compiler.py +199 -0
  54. numba_cuda/numba/cuda/core/compiler_lock.py +85 -0
  55. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  56. numba_cuda/numba/cuda/core/config.py +650 -0
  57. numba_cuda/numba/cuda/core/consts.py +124 -0
  58. numba_cuda/numba/cuda/core/controlflow.py +989 -0
  59. numba_cuda/numba/cuda/core/entrypoints.py +57 -0
  60. numba_cuda/numba/cuda/core/environment.py +66 -0
  61. numba_cuda/numba/cuda/core/errors.py +917 -0
  62. numba_cuda/numba/cuda/core/event.py +511 -0
  63. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  64. numba_cuda/numba/cuda/core/generators.py +387 -0
  65. numba_cuda/numba/cuda/core/imputils.py +509 -0
  66. numba_cuda/numba/cuda/core/inline_closurecall.py +1787 -0
  67. numba_cuda/numba/cuda/core/interpreter.py +3617 -0
  68. numba_cuda/numba/cuda/core/ir.py +1812 -0
  69. numba_cuda/numba/cuda/core/ir_utils.py +2638 -0
  70. numba_cuda/numba/cuda/core/optional.py +129 -0
  71. numba_cuda/numba/cuda/core/options.py +262 -0
  72. numba_cuda/numba/cuda/core/postproc.py +249 -0
  73. numba_cuda/numba/cuda/core/pythonapi.py +1859 -0
  74. numba_cuda/numba/cuda/core/registry.py +46 -0
  75. numba_cuda/numba/cuda/core/removerefctpass.py +123 -0
  76. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  77. numba_cuda/numba/cuda/core/rewrites/ir_print.py +91 -0
  78. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  79. numba_cuda/numba/cuda/core/rewrites/static_binop.py +41 -0
  80. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +189 -0
  81. numba_cuda/numba/cuda/core/rewrites/static_raise.py +100 -0
  82. numba_cuda/numba/cuda/core/sigutils.py +68 -0
  83. numba_cuda/numba/cuda/core/ssa.py +498 -0
  84. numba_cuda/numba/cuda/core/targetconfig.py +330 -0
  85. numba_cuda/numba/cuda/core/tracing.py +231 -0
  86. numba_cuda/numba/cuda/core/transforms.py +956 -0
  87. numba_cuda/numba/cuda/core/typed_passes.py +867 -0
  88. numba_cuda/numba/cuda/core/typeinfer.py +1950 -0
  89. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  90. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  91. numba_cuda/numba/cuda/core/unsafe/eh.py +67 -0
  92. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  93. numba_cuda/numba/cuda/core/untyped_passes.py +1979 -0
  94. numba_cuda/numba/cuda/cpython/builtins.py +1153 -0
  95. numba_cuda/numba/cuda/cpython/charseq.py +1218 -0
  96. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  97. numba_cuda/numba/cuda/cpython/enumimpl.py +103 -0
  98. numba_cuda/numba/cuda/cpython/iterators.py +167 -0
  99. numba_cuda/numba/cuda/cpython/listobj.py +1326 -0
  100. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  101. numba_cuda/numba/cuda/cpython/numbers.py +1475 -0
  102. numba_cuda/numba/cuda/cpython/rangeobj.py +289 -0
  103. numba_cuda/numba/cuda/cpython/slicing.py +322 -0
  104. numba_cuda/numba/cuda/cpython/tupleobj.py +456 -0
  105. numba_cuda/numba/cuda/cpython/unicode.py +2865 -0
  106. numba_cuda/numba/cuda/cpython/unicode_support.py +1597 -0
  107. numba_cuda/numba/cuda/cpython/unsafe/__init__.py +0 -0
  108. numba_cuda/numba/cuda/cpython/unsafe/numbers.py +64 -0
  109. numba_cuda/numba/cuda/cpython/unsafe/tuple.py +92 -0
  110. numba_cuda/numba/cuda/cuda_paths.py +691 -0
  111. numba_cuda/numba/cuda/cudadecl.py +556 -0
  112. numba_cuda/numba/cuda/cudadrv/__init__.py +14 -0
  113. numba_cuda/numba/cuda/cudadrv/devicearray.py +951 -0
  114. numba_cuda/numba/cuda/cudadrv/devices.py +249 -0
  115. numba_cuda/numba/cuda/cudadrv/driver.py +3222 -0
  116. numba_cuda/numba/cuda/cudadrv/drvapi.py +435 -0
  117. numba_cuda/numba/cuda/cudadrv/dummyarray.py +558 -0
  118. numba_cuda/numba/cuda/cudadrv/enums.py +613 -0
  119. numba_cuda/numba/cuda/cudadrv/error.py +48 -0
  120. numba_cuda/numba/cuda/cudadrv/libs.py +220 -0
  121. numba_cuda/numba/cuda/cudadrv/linkable_code.py +184 -0
  122. numba_cuda/numba/cuda/cudadrv/mappings.py +14 -0
  123. numba_cuda/numba/cuda/cudadrv/ndarray.py +26 -0
  124. numba_cuda/numba/cuda/cudadrv/nvrtc.py +193 -0
  125. numba_cuda/numba/cuda/cudadrv/nvvm.py +756 -0
  126. numba_cuda/numba/cuda/cudadrv/rtapi.py +13 -0
  127. numba_cuda/numba/cuda/cudadrv/runtime.py +34 -0
  128. numba_cuda/numba/cuda/cudaimpl.py +995 -0
  129. numba_cuda/numba/cuda/cudamath.py +149 -0
  130. numba_cuda/numba/cuda/datamodel/__init__.py +7 -0
  131. numba_cuda/numba/cuda/datamodel/cuda_manager.py +66 -0
  132. numba_cuda/numba/cuda/datamodel/cuda_models.py +1446 -0
  133. numba_cuda/numba/cuda/datamodel/cuda_packer.py +224 -0
  134. numba_cuda/numba/cuda/datamodel/cuda_registry.py +22 -0
  135. numba_cuda/numba/cuda/datamodel/cuda_testing.py +153 -0
  136. numba_cuda/numba/cuda/datamodel/manager.py +11 -0
  137. numba_cuda/numba/cuda/datamodel/models.py +9 -0
  138. numba_cuda/numba/cuda/datamodel/packer.py +9 -0
  139. numba_cuda/numba/cuda/datamodel/registry.py +11 -0
  140. numba_cuda/numba/cuda/datamodel/testing.py +11 -0
  141. numba_cuda/numba/cuda/debuginfo.py +903 -0
  142. numba_cuda/numba/cuda/decorators.py +294 -0
  143. numba_cuda/numba/cuda/descriptor.py +35 -0
  144. numba_cuda/numba/cuda/device_init.py +158 -0
  145. numba_cuda/numba/cuda/deviceufunc.py +1021 -0
  146. numba_cuda/numba/cuda/dispatcher.py +2463 -0
  147. numba_cuda/numba/cuda/errors.py +72 -0
  148. numba_cuda/numba/cuda/extending.py +697 -0
  149. numba_cuda/numba/cuda/flags.py +178 -0
  150. numba_cuda/numba/cuda/fp16.py +357 -0
  151. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  152. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  153. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  154. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  155. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  156. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  157. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  158. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  159. numba_cuda/numba/cuda/initialize.py +24 -0
  160. numba_cuda/numba/cuda/intrinsic_wrapper.py +41 -0
  161. numba_cuda/numba/cuda/intrinsics.py +382 -0
  162. numba_cuda/numba/cuda/itanium_mangler.py +214 -0
  163. numba_cuda/numba/cuda/kernels/__init__.py +2 -0
  164. numba_cuda/numba/cuda/kernels/reduction.py +265 -0
  165. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  166. numba_cuda/numba/cuda/libdevice.py +3386 -0
  167. numba_cuda/numba/cuda/libdevicedecl.py +20 -0
  168. numba_cuda/numba/cuda/libdevicefuncs.py +1060 -0
  169. numba_cuda/numba/cuda/libdeviceimpl.py +88 -0
  170. numba_cuda/numba/cuda/locks.py +19 -0
  171. numba_cuda/numba/cuda/lowering.py +1951 -0
  172. numba_cuda/numba/cuda/mathimpl.py +374 -0
  173. numba_cuda/numba/cuda/memory_management/__init__.py +4 -0
  174. numba_cuda/numba/cuda/memory_management/memsys.cu +99 -0
  175. numba_cuda/numba/cuda/memory_management/memsys.cuh +22 -0
  176. numba_cuda/numba/cuda/memory_management/nrt.cu +212 -0
  177. numba_cuda/numba/cuda/memory_management/nrt.cuh +48 -0
  178. numba_cuda/numba/cuda/memory_management/nrt.py +390 -0
  179. numba_cuda/numba/cuda/memory_management/nrt_context.py +438 -0
  180. numba_cuda/numba/cuda/misc/appdirs.py +594 -0
  181. numba_cuda/numba/cuda/misc/cffiimpl.py +24 -0
  182. numba_cuda/numba/cuda/misc/coverage_support.py +43 -0
  183. numba_cuda/numba/cuda/misc/dump_style.py +41 -0
  184. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  185. numba_cuda/numba/cuda/misc/firstlinefinder.py +96 -0
  186. numba_cuda/numba/cuda/misc/gdb_hook.py +240 -0
  187. numba_cuda/numba/cuda/misc/literal.py +28 -0
  188. numba_cuda/numba/cuda/misc/llvm_pass_timings.py +412 -0
  189. numba_cuda/numba/cuda/misc/special.py +94 -0
  190. numba_cuda/numba/cuda/models.py +56 -0
  191. numba_cuda/numba/cuda/np/arraymath.py +5130 -0
  192. numba_cuda/numba/cuda/np/arrayobj.py +7635 -0
  193. numba_cuda/numba/cuda/np/extensions.py +11 -0
  194. numba_cuda/numba/cuda/np/linalg.py +3087 -0
  195. numba_cuda/numba/cuda/np/math/__init__.py +0 -0
  196. numba_cuda/numba/cuda/np/math/cmathimpl.py +558 -0
  197. numba_cuda/numba/cuda/np/math/mathimpl.py +487 -0
  198. numba_cuda/numba/cuda/np/math/numbers.py +1461 -0
  199. numba_cuda/numba/cuda/np/npdatetime.py +969 -0
  200. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  201. numba_cuda/numba/cuda/np/npyfuncs.py +1808 -0
  202. numba_cuda/numba/cuda/np/npyimpl.py +1027 -0
  203. numba_cuda/numba/cuda/np/numpy_support.py +798 -0
  204. numba_cuda/numba/cuda/np/polynomial/__init__.py +4 -0
  205. numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +242 -0
  206. numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +380 -0
  207. numba_cuda/numba/cuda/np/ufunc/__init__.py +4 -0
  208. numba_cuda/numba/cuda/np/ufunc/decorators.py +203 -0
  209. numba_cuda/numba/cuda/np/ufunc/sigparse.py +68 -0
  210. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +65 -0
  211. numba_cuda/numba/cuda/np/ufunc_db.py +1282 -0
  212. numba_cuda/numba/cuda/np/unsafe/__init__.py +0 -0
  213. numba_cuda/numba/cuda/np/unsafe/ndarray.py +84 -0
  214. numba_cuda/numba/cuda/nvvmutils.py +254 -0
  215. numba_cuda/numba/cuda/printimpl.py +126 -0
  216. numba_cuda/numba/cuda/random.py +308 -0
  217. numba_cuda/numba/cuda/reshape_funcs.cu +156 -0
  218. numba_cuda/numba/cuda/serialize.py +267 -0
  219. numba_cuda/numba/cuda/simulator/__init__.py +63 -0
  220. numba_cuda/numba/cuda/simulator/_internal/__init__.py +4 -0
  221. numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
  222. numba_cuda/numba/cuda/simulator/api.py +179 -0
  223. numba_cuda/numba/cuda/simulator/bf16.py +4 -0
  224. numba_cuda/numba/cuda/simulator/compiler.py +38 -0
  225. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +11 -0
  226. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +462 -0
  227. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +122 -0
  228. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +66 -0
  229. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +7 -0
  230. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +7 -0
  231. numba_cuda/numba/cuda/simulator/cudadrv/error.py +10 -0
  232. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +10 -0
  233. numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +61 -0
  234. numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +11 -0
  235. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +32 -0
  236. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +22 -0
  237. numba_cuda/numba/cuda/simulator/dispatcher.py +11 -0
  238. numba_cuda/numba/cuda/simulator/kernel.py +320 -0
  239. numba_cuda/numba/cuda/simulator/kernelapi.py +509 -0
  240. numba_cuda/numba/cuda/simulator/memory_management/__init__.py +4 -0
  241. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +21 -0
  242. numba_cuda/numba/cuda/simulator/reduction.py +19 -0
  243. numba_cuda/numba/cuda/simulator/tests/support.py +4 -0
  244. numba_cuda/numba/cuda/simulator/vector_types.py +65 -0
  245. numba_cuda/numba/cuda/simulator_init.py +18 -0
  246. numba_cuda/numba/cuda/stubs.py +635 -0
  247. numba_cuda/numba/cuda/target.py +505 -0
  248. numba_cuda/numba/cuda/testing.py +347 -0
  249. numba_cuda/numba/cuda/tests/__init__.py +62 -0
  250. numba_cuda/numba/cuda/tests/benchmarks/__init__.py +0 -0
  251. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +119 -0
  252. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  253. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +113 -0
  254. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
  255. numba_cuda/numba/cuda/tests/core/test_serialize.py +371 -0
  256. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +9 -0
  257. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +147 -0
  258. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +161 -0
  259. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +397 -0
  260. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +24 -0
  261. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +180 -0
  262. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +313 -0
  263. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +187 -0
  264. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +621 -0
  265. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +247 -0
  266. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +100 -0
  267. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +198 -0
  268. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +53 -0
  269. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +72 -0
  270. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +138 -0
  271. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +43 -0
  272. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +15 -0
  273. numba_cuda/numba/cuda/tests/cudadrv/test_linkable_code.py +58 -0
  274. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +348 -0
  275. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +128 -0
  276. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +301 -0
  277. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +174 -0
  278. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +28 -0
  279. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +185 -0
  280. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +39 -0
  281. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +23 -0
  282. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +38 -0
  283. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +48 -0
  284. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +44 -0
  285. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +127 -0
  286. numba_cuda/numba/cuda/tests/cudapy/__init__.py +9 -0
  287. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +231 -0
  288. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +50 -0
  289. numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +36 -0
  290. numba_cuda/numba/cuda/tests/cudapy/complex_usecases.py +116 -0
  291. numba_cuda/numba/cuda/tests/cudapy/enum_usecases.py +59 -0
  292. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +62 -0
  293. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +28 -0
  294. numba_cuda/numba/cuda/tests/cudapy/overload_usecases.py +33 -0
  295. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +104 -0
  296. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +47 -0
  297. numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +1122 -0
  298. numba_cuda/numba/cuda/tests/cudapy/test_array.py +344 -0
  299. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +268 -0
  300. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +203 -0
  301. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +63 -0
  302. numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +360 -0
  303. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1815 -0
  304. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +599 -0
  305. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +377 -0
  306. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +160 -0
  307. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +27 -0
  308. numba_cuda/numba/cuda/tests/cudapy/test_byteflow.py +98 -0
  309. numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py +210 -0
  310. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +683 -0
  311. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +265 -0
  312. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +42 -0
  313. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +718 -0
  314. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +370 -0
  315. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +23 -0
  316. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +142 -0
  317. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +178 -0
  318. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +193 -0
  319. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +131 -0
  320. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +438 -0
  321. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +94 -0
  322. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +101 -0
  323. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +105 -0
  324. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +889 -0
  325. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
  326. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +500 -0
  327. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +820 -0
  328. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +152 -0
  329. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +111 -0
  330. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +170 -0
  331. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1088 -0
  332. numba_cuda/numba/cuda/tests/cudapy/test_extending_types.py +71 -0
  333. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +265 -0
  334. numba_cuda/numba/cuda/tests/cudapy/test_flow_control.py +1433 -0
  335. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +57 -0
  336. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +34 -0
  337. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +69 -0
  338. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +62 -0
  339. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +474 -0
  340. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +167 -0
  341. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +92 -0
  342. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +39 -0
  343. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +170 -0
  344. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +255 -0
  345. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1219 -0
  346. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +263 -0
  347. numba_cuda/numba/cuda/tests/cudapy/test_ir.py +598 -0
  348. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +276 -0
  349. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +101 -0
  350. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +68 -0
  351. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +123 -0
  352. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +194 -0
  353. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +220 -0
  354. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +173 -0
  355. numba_cuda/numba/cuda/tests/cudapy/test_make_function_to_jit_function.py +364 -0
  356. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +47 -0
  357. numba_cuda/numba/cuda/tests/cudapy/test_math.py +842 -0
  358. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +76 -0
  359. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +78 -0
  360. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +25 -0
  361. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +145 -0
  362. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +39 -0
  363. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +82 -0
  364. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +53 -0
  365. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +504 -0
  366. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +93 -0
  367. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +402 -0
  368. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +128 -0
  369. numba_cuda/numba/cuda/tests/cudapy/test_print.py +193 -0
  370. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +37 -0
  371. numba_cuda/numba/cuda/tests/cudapy/test_random.py +117 -0
  372. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +614 -0
  373. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +130 -0
  374. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +94 -0
  375. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  376. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +86 -0
  377. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +40 -0
  378. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +457 -0
  379. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +233 -0
  380. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +454 -0
  381. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +56 -0
  382. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +277 -0
  383. numba_cuda/numba/cuda/tests/cudapy/test_tracing.py +200 -0
  384. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +90 -0
  385. numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +333 -0
  386. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  387. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +585 -0
  388. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +42 -0
  389. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +485 -0
  390. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +312 -0
  391. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +23 -0
  392. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +183 -0
  393. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +40 -0
  394. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +40 -0
  395. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +206 -0
  396. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +331 -0
  397. numba_cuda/numba/cuda/tests/cudasim/__init__.py +9 -0
  398. numba_cuda/numba/cuda/tests/cudasim/support.py +9 -0
  399. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +111 -0
  400. numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
  401. numba_cuda/numba/cuda/tests/data/cta_barrier.cu +28 -0
  402. numba_cuda/numba/cuda/tests/data/cuda_include.cu +10 -0
  403. numba_cuda/numba/cuda/tests/data/error.cu +12 -0
  404. numba_cuda/numba/cuda/tests/data/include/add.cuh +8 -0
  405. numba_cuda/numba/cuda/tests/data/jitlink.cu +28 -0
  406. numba_cuda/numba/cuda/tests/data/jitlink.ptx +49 -0
  407. numba_cuda/numba/cuda/tests/data/warn.cu +12 -0
  408. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +9 -0
  409. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
  410. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +54 -0
  411. numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +8 -0
  412. numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +14 -0
  413. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +86 -0
  414. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +68 -0
  415. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +81 -0
  416. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +141 -0
  417. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +160 -0
  418. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +180 -0
  419. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +119 -0
  420. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +66 -0
  421. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +80 -0
  422. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +206 -0
  423. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +53 -0
  424. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +76 -0
  425. numba_cuda/numba/cuda/tests/nocuda/__init__.py +9 -0
  426. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +391 -0
  427. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +48 -0
  428. numba_cuda/numba/cuda/tests/nocuda/test_import.py +63 -0
  429. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +252 -0
  430. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +59 -0
  431. numba_cuda/numba/cuda/tests/nrt/__init__.py +9 -0
  432. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +387 -0
  433. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +124 -0
  434. numba_cuda/numba/cuda/tests/support.py +900 -0
  435. numba_cuda/numba/cuda/typeconv/__init__.py +4 -0
  436. numba_cuda/numba/cuda/typeconv/castgraph.py +137 -0
  437. numba_cuda/numba/cuda/typeconv/rules.py +63 -0
  438. numba_cuda/numba/cuda/typeconv/typeconv.py +121 -0
  439. numba_cuda/numba/cuda/types/__init__.py +233 -0
  440. numba_cuda/numba/cuda/types/__init__.pyi +167 -0
  441. numba_cuda/numba/cuda/types/abstract.py +9 -0
  442. numba_cuda/numba/cuda/types/common.py +9 -0
  443. numba_cuda/numba/cuda/types/containers.py +9 -0
  444. numba_cuda/numba/cuda/types/cuda_abstract.py +533 -0
  445. numba_cuda/numba/cuda/types/cuda_common.py +110 -0
  446. numba_cuda/numba/cuda/types/cuda_containers.py +971 -0
  447. numba_cuda/numba/cuda/types/cuda_function_type.py +230 -0
  448. numba_cuda/numba/cuda/types/cuda_functions.py +798 -0
  449. numba_cuda/numba/cuda/types/cuda_iterators.py +120 -0
  450. numba_cuda/numba/cuda/types/cuda_misc.py +569 -0
  451. numba_cuda/numba/cuda/types/cuda_npytypes.py +690 -0
  452. numba_cuda/numba/cuda/types/cuda_scalars.py +280 -0
  453. numba_cuda/numba/cuda/types/ext_types.py +101 -0
  454. numba_cuda/numba/cuda/types/function_type.py +11 -0
  455. numba_cuda/numba/cuda/types/functions.py +9 -0
  456. numba_cuda/numba/cuda/types/iterators.py +9 -0
  457. numba_cuda/numba/cuda/types/misc.py +9 -0
  458. numba_cuda/numba/cuda/types/npytypes.py +9 -0
  459. numba_cuda/numba/cuda/types/scalars.py +9 -0
  460. numba_cuda/numba/cuda/typing/__init__.py +19 -0
  461. numba_cuda/numba/cuda/typing/arraydecl.py +939 -0
  462. numba_cuda/numba/cuda/typing/asnumbatype.py +130 -0
  463. numba_cuda/numba/cuda/typing/bufproto.py +70 -0
  464. numba_cuda/numba/cuda/typing/builtins.py +1209 -0
  465. numba_cuda/numba/cuda/typing/cffi_utils.py +219 -0
  466. numba_cuda/numba/cuda/typing/cmathdecl.py +47 -0
  467. numba_cuda/numba/cuda/typing/collections.py +138 -0
  468. numba_cuda/numba/cuda/typing/context.py +782 -0
  469. numba_cuda/numba/cuda/typing/ctypes_utils.py +125 -0
  470. numba_cuda/numba/cuda/typing/dictdecl.py +63 -0
  471. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  472. numba_cuda/numba/cuda/typing/listdecl.py +147 -0
  473. numba_cuda/numba/cuda/typing/mathdecl.py +158 -0
  474. numba_cuda/numba/cuda/typing/npdatetime.py +322 -0
  475. numba_cuda/numba/cuda/typing/npydecl.py +749 -0
  476. numba_cuda/numba/cuda/typing/setdecl.py +115 -0
  477. numba_cuda/numba/cuda/typing/templates.py +1446 -0
  478. numba_cuda/numba/cuda/typing/typeof.py +301 -0
  479. numba_cuda/numba/cuda/ufuncs.py +746 -0
  480. numba_cuda/numba/cuda/utils.py +724 -0
  481. numba_cuda/numba/cuda/vector_types.py +214 -0
  482. numba_cuda/numba/cuda/vectorizers.py +260 -0
  483. numba_cuda-0.21.1.dist-info/METADATA +109 -0
  484. numba_cuda-0.21.1.dist-info/RECORD +488 -0
  485. numba_cuda-0.21.1.dist-info/WHEEL +5 -0
  486. numba_cuda-0.21.1.dist-info/licenses/LICENSE +26 -0
  487. numba_cuda-0.21.1.dist-info/licenses/LICENSE.numba +24 -0
  488. numba_cuda-0.21.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1088 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from numba.cuda.testing import unittest, CUDATestCase
5
+ from numba.cuda.cudadrv.driver import _have_nvjitlink
6
+ from llvmlite import ir
7
+
8
+ import numpy as np
9
+ import os
10
+ from numba import cuda
11
+ from numba.cuda import HAS_NUMBA
12
+ from numba.cuda.testing import skip_on_standalone_numba_cuda
13
+ from numba.cuda import types
14
+ from numba.cuda import config
15
+
16
+ if config.ENABLE_CUDASIM:
17
+ raise unittest.SkipTest("Simulator does not support extending types")
18
+
19
+ import inspect
20
+ import math
21
+ import pickle
22
+ import unittest
23
+
24
+
25
+ import numba
26
+ from numba import njit
27
+ from numba.cuda import cgutils, jit
28
+ from numba.cuda.tests.support import TestCase, override_config
29
+ from numba.cuda.typing.templates import AttributeTemplate
30
+ from numba.cuda.cudadecl import registry as cuda_registry
31
+ from numba.cuda.cudaimpl import lower_attr as cuda_lower_attr
32
+
33
+ from numba.core import errors
34
+ from numba.cuda.errors import LoweringError
35
+
36
+ from numba.cuda.extending import (
37
+ type_callable,
38
+ lower_builtin,
39
+ overload,
40
+ overload_method,
41
+ intrinsic,
42
+ _Intrinsic,
43
+ register_jitable,
44
+ core_models,
45
+ typeof_impl,
46
+ register_model,
47
+ make_attribute_wrapper,
48
+ )
49
+
50
# Directory holding the prebuilt device-function artifacts used by the linkage
# tests. The path globals below only exist when NUMBA_CUDA_TEST_BIN_DIR is set
# to a non-empty value.
TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR")
if TEST_BIN_DIR:
    # Every artifact shares the same stem and differs only in its suffix.
    _device_functions_stem = os.path.join(
        TEST_BIN_DIR, "test_device_functions"
    )
    test_device_functions_a = _device_functions_stem + ".a"
    test_device_functions_cubin = _device_functions_stem + ".cubin"
    test_device_functions_cu = _device_functions_stem + ".cu"
    test_device_functions_fatbin = _device_functions_stem + ".fatbin"
    test_device_functions_fatbin_multi = (
        _device_functions_stem + "_multi.fatbin"
    )
    test_device_functions_o = _device_functions_stem + ".o"
    test_device_functions_ptx = _device_functions_stem + ".ptx"
    test_device_functions_ltoir = _device_functions_stem + ".ltoir"
76
+
77
+
78
class Interval:
    """
    A half-open interval on the real number line.

    ``lo`` and ``hi`` are stored exactly as passed to the constructor.
    """

    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def __repr__(self):
        bounds = (self.lo, self.hi)
        return "Interval(%f, %f)" % bounds

    @property
    def width(self):
        """Distance from ``lo`` to ``hi`` (``hi - lo``)."""
        upper, lower = self.hi, self.lo
        return upper - lower
93
+
94
+
95
if HAS_NUMBA:
    from numba import njit
else:

    def njit(*args, **kwargs):
        """
        Pass-through stand-in for ``numba.njit`` used when ``HAS_NUMBA`` is
        false.

        Binding ``njit`` to ``None`` would make every module-level ``@njit``
        decorator raise ``TypeError`` at import time. This fallback instead
        returns the decorated function unchanged, for both the bare
        ``@njit`` form and the ``@njit(...)`` form with options.
        """
        if len(args) == 1 and not kwargs and callable(args[0]):
            # Bare ``@njit``: the sole positional argument is the function.
            return args[0]

        def passthrough(func):
            return func

        return passthrough
99
+
100
+
101
@njit
def interval_width(interval):
    # Thin wrapper so kernels can obtain the width through a function call.
    width = interval.width
    return width
104
+
105
+
106
@njit
def sum_intervals(i, j):
    # Add the two intervals bound by bound.
    lo_sum = i.lo + j.lo
    hi_sum = i.hi + j.hi
    return Interval(lo_sum, hi_sum)
109
+
110
+
111
class IntervalType(types.Type):
    """Numba type used for ``Interval`` values; always named "Interval"."""

    def __init__(self):
        super(IntervalType, self).__init__(name="Interval")


# Single shared instance returned by the typing hooks in this module.
interval_type = IntervalType()
117
+
118
+
119
@typeof_impl.register(Interval)
def typeof_interval(val, c):
    """Map any Python ``Interval`` value to the shared ``interval_type``."""
    # Neither the value nor the typeof context influences the result.
    return interval_type
122
+
123
+
124
@type_callable(Interval)
def type_interval(context):
    """Typing for ``Interval(lo, hi)`` calls inside jitted code."""

    def typer(lo, hi):
        # Only float/float calls are typed; anything else yields None.
        both_float = isinstance(lo, types.Float) and isinstance(
            hi, types.Float
        )
        if both_float:
            return interval_type

    return typer
131
+
132
+
133
@register_model(IntervalType)
class IntervalModel(core_models.StructModel):
    """Data model for ``IntervalType``: a struct of two float64 members."""

    def __init__(self, dmm, fe_type):
        # Member order is significant: "lo" is field 0 and "hi" is field 1.
        members = [("lo", types.float64), ("hi", types.float64)]
        super().__init__(dmm, fe_type, members)


# Expose both struct members as attributes of IntervalType values.
make_attribute_wrapper(IntervalType, "lo", "lo")
make_attribute_wrapper(IntervalType, "hi", "hi")
145
+
146
+
147
@lower_builtin(Interval, types.Float, types.Float)
def impl_interval(context, builder, sig, args):
    """Lower ``Interval(lo, hi)`` by filling a struct proxy of the return type."""
    return_type = sig.return_type
    lower_bound, upper_bound = args
    proxy_cls = cgutils.create_struct_proxy(return_type)
    proxy = proxy_cls(context, builder)
    proxy.lo = lower_bound
    proxy.hi = upper_bound
    return proxy._getvalue()
155
+
156
+
157
@cuda_registry.register_attr
class Interval_attrs(AttributeTemplate):
    """Attribute typing for ``IntervalType`` in the CUDA registry."""

    key = IntervalType

    def resolve_width(self, mod):
        # ``width`` is always typed as float64.
        return types.float64
163
+
164
+
165
@cuda_lower_attr(IntervalType, "width")
def cuda_Interval_width(context, builder, sig, arg):
    """Lower ``interval.width`` as field 1 minus field 0 of the struct value."""
    lower_bound = builder.extract_value(arg, 0)
    upper_bound = builder.extract_value(arg, 1)
    difference = builder.fsub(upper_bound, lower_bound)
    return difference
170
+
171
+
172
+ # -----------------------------------------------------------------------
173
+ # Define a function's typing and implementation using the classical
174
+ # two-step API
175
+
176
+
177
def func1(x=None):
    """Placeholder callable; typing and lowering are registered separately."""
    # Calling it from plain Python always fails.
    raise NotImplementedError()
179
+
180
+
181
def type_func1_(context):
    """Build the typer for ``func1``."""

    def typer(x=None):
        # 0-arg or 1-arg with None
        if x in (None, types.none):
            return types.int32
        # 1-arg with float: the result has the argument's own type
        if isinstance(x, types.Float):
            return x
        # Any other argument type is left untyped (implicit None).

    return typer


# Apply the decorator by hand so the returned object stays available.
type_func1 = type_callable(func1)(type_func1_)
194
+
195
+
196
@lower_builtin(func1)
@lower_builtin(func1, types.none)
def func1_nullary(context, builder, sig, args):
    """Lower ``func1()`` and ``func1(None)`` to the constant 42."""
    result_type = sig.return_type
    return context.get_constant(result_type, 42)
200
+
201
+
202
@lower_builtin(func1, types.Float)
def func1_unary(context, builder, sig, args):
    """Lower ``func1(float)`` by compiling a small Python implementation."""

    def func1_impl(x):
        doubled = 2 * x
        return math.sqrt(doubled)

    return context.compile_internal(builder, func1_impl, sig, args)
208
+
209
+
210
+ # -----------------------------------------------------------------------
211
+ # Overload an already defined built-in function, extending it for new types.
212
+
213
+
214
def call_func1_nullary(res):
    # Kernel body: store the no-argument result of func1 in the first slot.
    value = func1()
    res[0] = value
216
+
217
+
218
def call_func1_unary(x, res):
    # Kernel body: store func1(x) in the first slot of the output.
    value = func1(x)
    res[0] = value
220
+
221
+
222
class TestExtending(CUDATestCase):
    """Use the ``Interval`` extension type from inside CUDA kernels."""

    def test_attributes(self):
        # Both struct members must be readable through their wrappers.
        @cuda.jit
        def kernel(out, src):
            interval = Interval(src[0], src[1])
            out[0] = interval.lo
            out[1] = interval.hi

        bounds = np.asarray((1.5, 2.5))
        result = np.zeros_like(bounds)

        kernel[1, 1](result, bounds)

        np.testing.assert_equal(result, bounds)

    def test_property(self):
        # ``width`` is resolved by the attribute template and its lowering.
        @cuda.jit
        def kernel(out, src):
            interval = Interval(src[0], src[1])
            out[0] = interval.width

        bounds = np.asarray((1.5, 2.5))
        result = np.zeros(1)

        kernel[1, 1](result, bounds)

        np.testing.assert_allclose(result[0], bounds[1] - bounds[0])

    @skip_on_standalone_numba_cuda
    def test_extension_type_as_arg(self):
        # An Interval built in the kernel is passed to an @njit function.
        @cuda.jit
        def kernel(out, src):
            interval = Interval(src[0], src[1])
            out[0] = interval_width(interval)

        bounds = np.asarray((1.5, 2.5))
        result = np.zeros(1)

        kernel[1, 1](result, bounds)

        np.testing.assert_allclose(result[0], bounds[1] - bounds[0])

    @skip_on_standalone_numba_cuda
    def test_extension_type_as_retvalue(self):
        # An @njit function returns a new Interval to the kernel.
        @cuda.jit
        def kernel(out, src):
            first = Interval(src[0], src[1])
            second = Interval(src[2], src[3])
            total = sum_intervals(first, second)
            out[0] = total.lo
            out[1] = total.hi

        bounds = np.asarray((1.5, 2.5, 3.0, 4.0))
        result = np.zeros(2)

        kernel[1, 1](result, bounds)

        expected = np.asarray((bounds[0] + bounds[2], bounds[1] + bounds[3]))
        np.testing.assert_allclose(result, expected)
281
+
282
+
283
class TestExtendingLinkage(CUDATestCase):
    """Extensions whose lowering adds linkable code to the code library."""

    @unittest.skipUnless(TEST_BIN_DIR, "Necessary binaries are not available")
    def test_extension_adds_linkable_code(self):
        # (path, wrapper class) for each linkable-code format under test.
        candidates = (
            (test_device_functions_a, cuda.Archive),
            (test_device_functions_cubin, cuda.Cubin),
            (test_device_functions_cu, cuda.CUSource),
            (test_device_functions_fatbin, cuda.Fatbin),
            (test_device_functions_o, cuda.Object),
            (test_device_functions_ptx, cuda.PTXSource),
            (test_device_functions_ltoir, cuda.LTOIR),
        )

        lto = _have_nvjitlink()

        for path, ctor in candidates:
            if ctor == cuda.LTOIR and not lto:
                # Don't try to test with LTOIR if LTO is not enabled
                continue

            with open(path, "rb") as binary:
                code_object = ctor(binary.read())

            def external_add(x, y):
                return x + y

            @type_callable(external_add)
            def type_external_add(context):
                def typer(x, y):
                    if x == types.uint32 and y == types.uint32:
                        return types.uint32

                return typer

            @lower_builtin(external_add, types.uint32, types.uint32)
            def lower_external_add(context, builder, sig, args):
                # Register the external code, then call its "add_cabi" symbol.
                context.active_code_library.add_linking_file(code_object)
                i32 = ir.IntType(32)
                fnty = ir.FunctionType(i32, [i32, i32])
                fn = cgutils.get_or_insert_function(
                    builder.module, fnty, "add_cabi"
                )
                return builder.call(fn, args)

            # First: the extension is called directly from a kernel.
            @cuda.jit(lto=lto)
            def use_external_add(r, x, y):
                r[0] = external_add(x[0], y[0])

            out = np.zeros(1, dtype=np.uint32)
            ones = np.ones(1, dtype=np.uint32)
            twos = np.ones(1, dtype=np.uint32) * 2

            use_external_add[1, 1](out, ones, twos)

            np.testing.assert_equal(out[0], 3)

            # Second: the extension is reached through a device function.
            @cuda.jit(lto=lto)
            def use_external_add_device(x, y):
                return external_add(x, y)

            @cuda.jit(lto=lto)
            def use_external_add_kernel(r, x, y):
                r[0] = use_external_add_device(x[0], y[0])

            out = np.zeros(1, dtype=np.uint32)
            ones = np.ones(1, dtype=np.uint32)
            twos = np.ones(1, dtype=np.uint32) * 2

            use_external_add_kernel[1, 1](out, ones, twos)

            np.testing.assert_equal(out[0], 3)

    def test_linked_called_through_overload(self):
        cu_code = cuda.CUSource("""
            extern "C" __device__
            int bar(int *out, int a)
            {
                *out = a * 2;
                return 0;
            }
        """)

        bar = cuda.declare_device("bar", "int32(int32)", link=cu_code)

        def bar_call(val):
            pass

        @overload(bar_call, target="cuda")
        def ol_bar_call(a):
            def impl(a):
                return bar(a)

            return impl

        @cuda.jit("void(int32[::1], int32[::1])")
        def foo(r, x):
            i = cuda.grid(1)
            if i < len(r):
                r[i] = bar_call(x[i])

        values = np.arange(10, dtype=np.int32)
        doubled = np.empty_like(values)

        foo[1, 32](doubled, values)

        np.testing.assert_equal(doubled, values * 2)
386
+
387
+
388
class TestLowLevelExtending(TestCase):
    """
    Test the low-level two-tier extension API.
    """

    # Check with `@jit` from within the test process and also in a new test
    # process so as to check the registration mechanism.

    def test_func1(self):
        out = np.zeros(1)

        # Nullary call: the lowering stores the constant 42.
        nullary_kernel = jit(call_func1_nullary)
        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            nullary_kernel[1, 1](out)
        self.assertPreciseEqual(out[0], 42.0)

        # Creating the unary kernel must not touch the output array.
        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            unary_kernel = jit(call_func1_unary)
        self.assertPreciseEqual(out[0], 42.0)

        # Unary call with a float: sqrt(2 * 18.0) == 6.0
        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            unary_kernel[1, 1](18.0, out)
        self.assertPreciseEqual(out[0], 6.0)

    @TestCase.run_test_in_subprocess
    def test_func1_isolated(self):
        self.test_func1()

    def test_type_callable_keeps_function(self):
        # type_callable must hand back the very function it decorated.
        self.assertIs(type_func1, type_func1_)
        self.assertIsNotNone(type_func1)
418
+
419
+
420
class TestHighLevelExtending(TestCase):
    """
    Test the high-level combined API.
    """

    def test_typing_vs_impl_signature_mismatch_handling(self):
        """
        Tests that an overload which has a differing typing and implementing
        signature raises an exception.
        """

        def gen_ol(impl=None):
            # Build a fresh overloaded stub whose typing function returns
            # ``impl``, plus a kernel that calls the stub.
            def myoverload(a, b, c, kw=None):
                pass

            @overload(myoverload)
            def _myoverload_impl(a, b, c, kw=None):
                return impl

            @jit
            def foo(a, b, c, d):
                myoverload(a, b, c, kw=d)

            return foo

        sentinel = "Typing and implementation arguments differ in"

        # kwarg value is different
        def impl1(a, b, c, kw=12):
            if a > 10:
                return 1
            else:
                return -1

        # kwarg name is different
        def impl2(a, b, c, kwarg=None):
            if a > 10:
                return 1
            else:
                return -1

        # arg name is different
        def impl3(z, b, c, kw=None):
            if a > 10:  # noqa: F821
                return 1
            else:
                return -1

        from .overload_usecases import impl4, impl5

        # too many args
        def impl6(a, b, c, d, e, kw=None):
            if a > 10:
                return 1
            else:
                return -1

        # too few args
        def impl7(a, b, kw=None):
            if a > 10:
                return 1
            else:
                return -1

        # too many kwargs
        def impl8(a, b, c, kw=None, extra_kwarg=None):
            if a > 10:
                return 1
            else:
                return -1

        # too few kwargs
        def impl9(a, b, c):
            if a > 10:
                return 1
            else:
                return -1

        # (implementation, fragments the error must contain, whether the
        # word "keyword" must be absent from the error)
        cases = (
            (
                impl1,
                (
                    "keyword argument default values",
                    '<Parameter "kw=12">',
                    '<Parameter "kw=None">',
                ),
                False,
            ),
            (
                impl2,
                (
                    "keyword argument names",
                    '<Parameter "kwarg=None">',
                    '<Parameter "kw=None">',
                ),
                False,
            ),
            (
                impl3,
                ("argument names", '<Parameter "a">', '<Parameter "z">'),
                True,
            ),
            (impl4, ("argument names", "First difference: 'z'"), True),
            (
                impl5,
                ("argument names", '<Parameter "a">', '<Parameter "z">'),
                True,
            ),
            (
                impl6,
                ("argument names", '<Parameter "d">', '<Parameter "e">'),
                True,
            ),
            (impl7, ("argument names", '<Parameter "c">'), True),
            (
                impl8,
                ("keyword argument names", '<Parameter "extra_kwarg=None">'),
                False,
            ),
            (
                impl9,
                ("keyword argument names", '<Parameter "kw=None">'),
                False,
            ),
        )

        for impl, fragments, keyword_forbidden in cases:
            # The implementation name is attached to every assertion so a
            # failure identifies the case that produced it.
            label = impl.__name__
            with self.assertRaises(errors.TypingError, msg=label) as raised:
                with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
                    gen_ol(impl)[1, 1](1, 2, 3, 4)
            msg = str(raised.exception)
            self.assertIn(sentinel, msg, label)
            for fragment in fragments:
                self.assertIn(fragment, msg, label)
            if keyword_forbidden:
                self.assertNotIn("keyword", msg, label)

    def test_typing_vs_impl_signature_mismatch_handling_var_positional(self):
        """
        Tests that an overload which has a differing typing and implementing
        signature raises an exception and uses VAR_POSITIONAL (*args) in typing
        """

        def myoverload(a, kw=None):
            pass

        from .overload_usecases import var_positional_impl

        overload(myoverload)(var_positional_impl)

        @jit
        def foo(a, b):
            myoverload(a, b, 9, kw=11)

        with self.assertRaises(errors.TypingError) as raised:
            with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
                foo[1, 1](1, 5)
        msg = str(raised.exception)
        self.assertIn("VAR_POSITIONAL (e.g. *args) argument kind", msg)
        self.assertIn("offending argument name is '*star_args_token'", msg)

    def test_typing_vs_impl_signature_mismatch_handling_var_keyword(self):
        """
        Tests that an overload which uses **kwargs (VAR_KEYWORD) in either its
        typing function or its implementation raises an exception when
        strictness is enforced.
        """

        def gen_ol(impl, strict=True):
            def myoverload(a, kw=None):
                pass

            overload(myoverload, strict=strict)(impl)

            @jit
            def foo(a, b):
                myoverload(a, kw=11)

            return foo

        unsupported = "use of VAR_KEYWORD (e.g. **kwargs) is unsupported"
        offending = "offending argument name is '**kws'"

        # **kwargs in typing
        def ol1(a, **kws):
            def impl(a, kw=10):
                return a

            return impl

        # no error if strictness not enforced
        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            gen_ol(ol1, False)[1, 1](1, 2)

        with self.assertRaises(errors.TypingError) as raised:
            with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
                gen_ol(ol1)[1, 1](1, 2)
        msg = str(raised.exception)
        self.assertIn(unsupported, msg)
        self.assertIn(offending, msg)

        # **kwargs in implementation
        def ol2(a, kw=0):
            def impl(a, **kws):
                return a

            return impl

        with self.assertRaises(errors.TypingError) as raised:
            with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
                gen_ol(ol2)[1, 1](1, 2)
        msg = str(raised.exception)
        self.assertIn(unsupported, msg)
        self.assertIn(offending, msg)

    def test_overload_method_kwargs(self):
        # Issue #3489
        @overload_method(types.Array, "foo")
        def fooimpl(arr, a_kwarg=10):
            def impl(arr, a_kwarg=10):
                return a_kwarg

            return impl

        @jit
        def bar(A, res):
            res[0] = A.foo()
            res[1] = A.foo(20)
            res[2] = A.foo(a_kwarg=30)

        Z = np.arange(5)
        res = np.zeros(3)
        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            bar[1, 1](Z, res)
        # default, positional and by-name forms respectively
        self.assertEqual(res[0], 10)
        self.assertEqual(res[1], 20)
        self.assertEqual(res[2], 30)

    def test_overload_method_literal_unpack(self):
        # Issue #3683
        @overload_method(types.Array, "litfoo")
        def litfoo(arr, val):
            # Must be an integer, and must not be literal
            if isinstance(val, types.Integer) and not isinstance(
                val, types.Literal
            ):

                def impl(arr, val):
                    return val

                return impl

        @jit
        def bar(A, res):
            res[0] = A.litfoo(0xCAFE)

        A = np.zeros(1)
        res = np.zeros(1)
        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            bar[1, 1](A, res)
        self.assertEqual(res[0], 0xCAFE)
699
+
700
+
701
+ def _assert_cache_stats(cfunc, expect_hit, expect_misses):
702
+ hit = cfunc._cache_hits[cfunc.signatures[0]]
703
+ if hit != expect_hit:
704
+ raise AssertionError("cache not used")
705
+ miss = cfunc._cache_misses[cfunc.signatures[0]]
706
+ if miss != expect_misses:
707
+ raise AssertionError("cache not used")
708
+
709
+
710
class TestIntrinsic(TestCase):
    """Behaviour of ``@intrinsic`` functions when used from kernels."""

    def test_void_return(self):
        """
        Verify that returning a None from codegen function is handled
        automatically for void functions, otherwise raise exception.
        """

        @intrinsic
        def void_func(typingctx, a):
            def codegen(context, builder, signature, args):
                pass  # do nothing, return None, should be turned into
                # dummy value

            sig = types.void(types.int32)
            return sig, codegen

        @intrinsic
        def non_void_func(typingctx, a):
            def codegen(context, builder, signature, args):
                pass  # oops, should be returning a value here, raise exception

            sig = types.int32(types.int32)
            return sig, codegen

        @jit
        def call_void_func():
            void_func(1)

        @jit
        def call_non_void_func():
            non_void_func(1)

        # void func should work
        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            self.assertEqual(call_void_func[1, 1](), None)
        # not void function should raise exception
        with self.assertRaises(LoweringError) as raised:
            with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
                call_non_void_func[1, 1]()
        self.assertIn("non-void function returns None", raised.exception.msg)

    def test_serialization(self):
        """
        Test serialization of intrinsic objects
        """

        # define an intrinsic
        @intrinsic
        def identity(context, x):
            def codegen(context, builder, signature, args):
                return args[0]

            sig = x(x)
            return sig, codegen

        # use in a jit function
        @jit
        def foo(x):
            identity(x)

        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            self.assertEqual(foo[1, 1](1), None)

        # get serialization memo
        memo = _Intrinsic._memo
        expected_size = len(memo)

        # pickle foo and check memo size
        serialized_foo = pickle.dumps(foo)
        # increases the memo size
        expected_size += 1
        self.assertEqual(expected_size, len(memo))
        # unpickle
        foo_rebuilt = pickle.loads(serialized_foo)
        self.assertEqual(expected_size, len(memo))
        # check rebuilt foo

        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            self.assertEqual(foo[1, 1](1), foo_rebuilt[1, 1](1))

        # pickle identity directly
        serialized_identity = pickle.dumps(identity)
        # memo size unchanged
        self.assertEqual(expected_size, len(memo))
        # unpickle
        identity_rebuilt = pickle.loads(serialized_identity)
        # must be the same object
        self.assertIs(identity, identity_rebuilt)
        # memo size unchanged
        self.assertEqual(expected_size, len(memo))

    def test_deserialization(self):
        """
        Test deserialization of intrinsic
        """

        def defn(context, x):
            def codegen(context, builder, signature, args):
                return args[0]

            return x(x), codegen

        memo = _Intrinsic._memo
        expected_size = len(memo)
        # invoke _Intrinsic indirectly to avoid registration which keeps an
        # internal reference inside the compiler
        original = _Intrinsic("foo", defn)
        self.assertIs(original._defn, defn)
        pickled = pickle.dumps(original)
        # by pickling, a new memo entry is created
        expected_size += 1
        self.assertEqual(expected_size, len(memo))
        del original  # remove original before unpickling

        # by deleting, the memo entry is NOT removed due to recent
        # function queue
        self.assertEqual(expected_size, len(memo))

        # Manually force clear of _recent queue
        _Intrinsic._recent.clear()
        expected_size -= 1
        self.assertEqual(expected_size, len(memo))

        rebuilt = pickle.loads(pickled)
        # verify that the rebuilt object is different
        self.assertIsNot(rebuilt._defn, defn)

        # the second rebuilt object is the same as the first
        second = pickle.loads(pickled)
        self.assertIs(rebuilt._defn, second._defn)

    def test_docstring(self):
        # The intrinsic wrapper must carry over the wrapped function's
        # module, name, qualified name, annotations and docstring.
        @intrinsic
        def void_func(typingctx, a: int):
            """void_func docstring"""
            sig = types.void(types.int32)

            def codegen(context, builder, signature, args):
                pass  # do nothing, return None, should be turned into
                # dummy value

            return sig, codegen

        self.assertEqual(
            "numba.cuda.tests.cudapy.test_extending", void_func.__module__
        )
        self.assertEqual("void_func", void_func.__name__)
        self.assertEqual(
            "TestIntrinsic.test_docstring.<locals>.void_func",
            void_func.__qualname__,
        )
        self.assertDictEqual({"a": int}, void_func.__annotations__)
        self.assertEqual("void_func docstring", void_func.__doc__)
865
+
866
+
867
class TestRegisterJitable(unittest.TestCase):
    """``register_jitable`` functions work from both Python and kernels."""

    def test_no_flags(self):
        @register_jitable
        def foo(x, y):
            x[0] += y

        def bar(x, y):
            foo(x, y)
            x[0] += x[0]

        kernel = jit(bar)

        data = np.array([1, 2])

        # Pure-Python call: (1 + 2) doubled gives 6.
        bar(data, 2)
        self.assertEqual(data[0], 6)

        # Kernel launch on the same array: (6 + 2) doubled gives 16.
        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            kernel[1, 1](data, 2)
        self.assertEqual(data[0], 16)
885
+
886
+
887
class TestOverloadPreferLiteral(TestCase):
    """Effect of ``prefer_literal`` on which overload branch is selected."""

    def test_overload(self):
        def prefer_lit(x):
            pass

        def non_lit(x):
            pass

        def ov(x):
            if not isinstance(x, types.IntegerLiteral):

                def impl(x):
                    return x * 100

                return impl

            # With prefer_literal=False, this point will not be reached.
            if x.literal_value != 1:
                raise errors.TypingError("literal value")

            def impl(x):
                return 0xCAFE

            return impl

        # The same typing function is registered twice, with and without
        # literal preference.
        overload(prefer_lit, prefer_literal=True)(ov)
        overload(non_lit)(ov)

        @jit
        def check_prefer_lit(x, res):
            res[0] = prefer_lit(1)
            res[1] = prefer_lit(2)
            res[2] = prefer_lit(x)

        res = np.zeros(3)
        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            check_prefer_lit[1, 1](3, res)
        first, second, third = res
        self.assertEqual(first, 0xCAFE)
        self.assertEqual(second, 200)
        self.assertEqual(third, 300)

        @jit
        def check_non_lit(x, res):
            res[0] = non_lit(1)
            res[1] = non_lit(2)
            res[2] = non_lit(x)

        with override_config("DISABLE_PERFORMANCE_WARNINGS", 1):
            check_non_lit[1, 1](3, res)
        first, second, third = res
        self.assertEqual(first, 100)
        self.assertEqual(second, 200)
        self.assertEqual(third, 300)
942
+
943
+
944
class TestNumbaInternalOverloads(TestCase):
    def test_signatures_match_overloaded_api(self):
        # This is a "best-effort" test to try and ensure that Numba's internal
        # overload declarations have signatures with argument names that match
        # the API they are overloading. The purpose of ensuring there is a
        # match is so that users can use call-by-name for positional arguments.

        # Set this to:
        # 0 to make violations raise a ValueError (default).
        # 1 to get violations reported to STDOUT
        # 2 to get a verbose output of everything that was checked and its state
        # reported to STDOUT.
        DEBUG = 0

        # np.random.* does not have a signature exposed to `inspect`... so
        # custom parse the docstrings.
        def sig_from_np_random(x):
            if x.startswith("_"):
                return None
            thing = getattr(np.random, x)
            if not inspect.isbuiltin(thing):
                return None
            for raw_line in thing.__doc__.splitlines():
                if not raw_line:
                    continue
                sl = raw_line.strip()
                if not sl.startswith(x):
                    continue
                # it's the signature
                # special case np.random.seed, it has `self` in
                # the signature whereas all the other functions
                # do not!?
                if x == "seed":
                    sl = "seed(seed)"

                fake_impl = f"def {sl}:\n\tpass"
                namespace = {}
                try:
                    exec(fake_impl, {}, namespace)
                except SyntaxError:
                    # likely ellipsis, e.g. rand(d0, d1, ..., dn)
                    if DEBUG == 2:
                        print("... skipped as cannot parse signature")
                    return None
                return inspect.signature(namespace.get(x))
            return None

        def checker(func, overload_func):
            if DEBUG == 2:
                print(f"Checking: {func}")

            def create_message(func, overload_func, func_sig, ol_sig):
                header = (
                    f"{func} from module '{getattr(func, '__module__')}' "
                    "has mismatched sig."
                )
                parts = [
                    header,
                    f" - expected: {func_sig}",
                    f" - got: {ol_sig}",
                ]
                lineno = inspect.getsourcelines(overload_func)[1]
                # Report the path with the numba package root stripped.
                srcfile = inspect.getfile(overload_func).replace(
                    numba.__path__[0], ""
                )
                parts.append(f"from {srcfile}:{lineno}")
                return "\n" + "\n".join(parts)

            func_sig = None
            try:
                func_sig = inspect.signature(func)
            except ValueError:
                # probably a built-in/C code, see if it's a np.random function
                if fname := getattr(func, "__name__", False):
                    if maybe_func := getattr(np.random, fname, False):
                        if maybe_func == func:
                            # it's a built-in from np.random
                            func_sig = sig_from_np_random(fname)

            if func_sig is None:
                # No signature could be obtained, so nothing to compare.
                return

            ol_sig = inspect.signature(overload_func)
            api_names = list(func_sig.parameters.keys())
            overload_names = list(ol_sig.parameters.keys())
            paired = zip(api_names[: len(overload_names)], overload_names)
            for a, b in paired:
                if a == b:
                    continue

                p = func_sig.parameters[a]
                if p.kind == p.POSITIONAL_ONLY:
                    # probably a built-in/C code
                    if DEBUG == 2:
                        print("... skipped as positional only arguments found")
                    break

                if "*" in str(p):  # probably *args or similar
                    if DEBUG == 2:
                        print("... skipped as contains *args")
                    break

                # Only error/report on functions that have a module
                # or are from somewhere other than Numba.
                reportable = (
                    not func.__module__
                    or not func.__module__.startswith("numba")
                )
                if not reportable:
                    if DEBUG == 2:
                        if not func.__module__:
                            print("... skipped as no __module__ present")
                        else:
                            print("... skipped as Numba internal")
                    break

                msgstr = create_message(func, overload_func, func_sig, ol_sig)
                if DEBUG == 0:
                    raise ValueError(msgstr)
                if DEBUG == 2:
                    print("... INVALID")
                if msgstr:
                    print(msgstr)
                break
            else:
                if DEBUG == 2:
                    print("... OK")

        # Compile something to make sure that the typing context registries
        # are populated with everything from the CPU target.
        jit(lambda: None).compile(())
        tyctx = numba.cuda.target.CUDATypingContext()
        tyctx.refresh()

        # Walk the registries and check each function that is an overload
        regs = tyctx._registries
        for registry, _ in regs.items():
            for item in registry.functions:
                if getattr(item, "_overload_func", False):
                    checker(item.key, item._overload_func)
1085
+
1086
+
1087
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()