numba-cuda 0.18.1__py3-none-any.whl → 0.19.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic. Click here for more details.

Files changed (301) hide show
  1. _numba_cuda_redirector.pth +3 -0
  2. _numba_cuda_redirector.py +3 -0
  3. numba_cuda/VERSION +1 -1
  4. numba_cuda/__init__.py +2 -1
  5. numba_cuda/_version.py +2 -13
  6. numba_cuda/numba/cuda/__init__.py +4 -1
  7. numba_cuda/numba/cuda/_internal/cuda_bf16.py +5 -2
  8. numba_cuda/numba/cuda/_internal/cuda_fp16.py +4 -1
  9. numba_cuda/numba/cuda/api.py +5 -7
  10. numba_cuda/numba/cuda/api_util.py +3 -0
  11. numba_cuda/numba/cuda/args.py +3 -0
  12. numba_cuda/numba/cuda/bf16.py +3 -0
  13. numba_cuda/numba/cuda/cg.py +3 -0
  14. numba_cuda/numba/cuda/cgutils.py +3 -0
  15. numba_cuda/numba/cuda/codegen.py +3 -0
  16. numba_cuda/numba/cuda/compiler.py +10 -4
  17. numba_cuda/numba/cuda/core/caching.py +3 -0
  18. numba_cuda/numba/cuda/core/callconv.py +3 -0
  19. numba_cuda/numba/cuda/core/codegen.py +3 -0
  20. numba_cuda/numba/cuda/core/compiler.py +3 -0
  21. numba_cuda/numba/cuda/core/interpreter.py +3595 -0
  22. numba_cuda/numba/cuda/core/ir_utils.py +2644 -0
  23. numba_cuda/numba/cuda/core/sigutils.py +58 -0
  24. numba_cuda/numba/cuda/core/typed_passes.py +3 -0
  25. numba_cuda/numba/cuda/cuda_paths.py +12 -17
  26. numba_cuda/numba/cuda/cudadecl.py +4 -1
  27. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -0
  28. numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
  29. numba_cuda/numba/cuda/cudadrv/devices.py +3 -0
  30. numba_cuda/numba/cuda/cudadrv/driver.py +7 -19
  31. numba_cuda/numba/cuda/cudadrv/drvapi.py +3 -0
  32. numba_cuda/numba/cuda/cudadrv/dummyarray.py +3 -0
  33. numba_cuda/numba/cuda/cudadrv/enums.py +3 -0
  34. numba_cuda/numba/cuda/cudadrv/error.py +4 -0
  35. numba_cuda/numba/cuda/cudadrv/libs.py +4 -2
  36. numba_cuda/numba/cuda/cudadrv/linkable_code.py +3 -0
  37. numba_cuda/numba/cuda/cudadrv/mappings.py +3 -0
  38. numba_cuda/numba/cuda/cudadrv/ndarray.py +3 -0
  39. numba_cuda/numba/cuda/cudadrv/nvrtc.py +47 -44
  40. numba_cuda/numba/cuda/cudadrv/nvvm.py +6 -18
  41. numba_cuda/numba/cuda/cudadrv/rtapi.py +3 -0
  42. numba_cuda/numba/cuda/cudadrv/runtime.py +15 -1
  43. numba_cuda/numba/cuda/cudaimpl.py +3 -0
  44. numba_cuda/numba/cuda/cudamath.py +4 -1
  45. numba_cuda/numba/cuda/debuginfo.py +3 -0
  46. numba_cuda/numba/cuda/decorators.py +7 -3
  47. numba_cuda/numba/cuda/descriptor.py +3 -0
  48. numba_cuda/numba/cuda/device_init.py +3 -0
  49. numba_cuda/numba/cuda/deviceufunc.py +5 -1
  50. numba_cuda/numba/cuda/dispatcher.py +6 -2
  51. numba_cuda/numba/cuda/errors.py +10 -0
  52. numba_cuda/numba/cuda/extending.py +4 -1
  53. numba_cuda/numba/cuda/flags.py +2 -0
  54. numba_cuda/numba/cuda/fp16.py +3 -0
  55. numba_cuda/numba/cuda/initialize.py +4 -0
  56. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -0
  57. numba_cuda/numba/cuda/intrinsics.py +3 -0
  58. numba_cuda/numba/cuda/itanium_mangler.py +214 -0
  59. numba_cuda/numba/cuda/kernels/__init__.py +2 -0
  60. numba_cuda/numba/cuda/kernels/reduction.py +3 -0
  61. numba_cuda/numba/cuda/kernels/transpose.py +3 -0
  62. numba_cuda/numba/cuda/libdevice.py +4 -0
  63. numba_cuda/numba/cuda/libdevicedecl.py +4 -1
  64. numba_cuda/numba/cuda/libdevicefuncs.py +4 -1
  65. numba_cuda/numba/cuda/libdeviceimpl.py +3 -0
  66. numba_cuda/numba/cuda/locks.py +3 -0
  67. numba_cuda/numba/cuda/lowering.py +53 -16
  68. numba_cuda/numba/cuda/mathimpl.py +3 -0
  69. numba_cuda/numba/cuda/memory_management/__init__.py +3 -0
  70. numba_cuda/numba/cuda/memory_management/memsys.cu +5 -0
  71. numba_cuda/numba/cuda/memory_management/memsys.cuh +5 -0
  72. numba_cuda/numba/cuda/memory_management/nrt.cu +5 -0
  73. numba_cuda/numba/cuda/memory_management/nrt.cuh +5 -0
  74. numba_cuda/numba/cuda/memory_management/nrt.py +5 -1
  75. numba_cuda/numba/cuda/models.py +3 -0
  76. numba_cuda/numba/cuda/nvvmutils.py +3 -0
  77. numba_cuda/numba/cuda/printimpl.py +3 -0
  78. numba_cuda/numba/cuda/random.py +3 -0
  79. numba_cuda/numba/cuda/reshape_funcs.cu +5 -0
  80. numba_cuda/numba/cuda/serialize.py +3 -0
  81. numba_cuda/numba/cuda/simulator/__init__.py +3 -0
  82. numba_cuda/numba/cuda/simulator/_internal/__init__.py +3 -0
  83. numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
  84. numba_cuda/numba/cuda/simulator/api.py +4 -1
  85. numba_cuda/numba/cuda/simulator/bf16.py +3 -0
  86. numba_cuda/numba/cuda/simulator/compiler.py +3 -0
  87. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +3 -0
  88. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +3 -0
  89. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +3 -0
  90. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +3 -7
  91. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +3 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +3 -0
  93. numba_cuda/numba/cuda/simulator/cudadrv/error.py +4 -0
  94. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +4 -0
  95. numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +4 -0
  96. numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +3 -0
  97. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -0
  98. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -0
  99. numba_cuda/numba/cuda/simulator/dispatcher.py +4 -0
  100. numba_cuda/numba/cuda/simulator/kernel.py +3 -0
  101. numba_cuda/numba/cuda/simulator/kernelapi.py +3 -0
  102. numba_cuda/numba/cuda/simulator/memory_management/__init__.py +3 -0
  103. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +3 -0
  104. numba_cuda/numba/cuda/simulator/reduction.py +3 -0
  105. numba_cuda/numba/cuda/simulator/vector_types.py +3 -0
  106. numba_cuda/numba/cuda/simulator_init.py +3 -0
  107. numba_cuda/numba/cuda/stubs.py +3 -0
  108. numba_cuda/numba/cuda/target.py +4 -2
  109. numba_cuda/numba/cuda/testing.py +7 -6
  110. numba_cuda/numba/cuda/tests/__init__.py +3 -0
  111. numba_cuda/numba/cuda/tests/complex_usecases.py +3 -0
  112. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +3 -0
  113. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
  114. numba_cuda/numba/cuda/tests/core/test_serialize.py +3 -0
  115. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +3 -0
  116. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +3 -0
  117. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +3 -0
  118. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +3 -0
  119. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +3 -0
  120. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +3 -0
  121. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +3 -0
  122. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -0
  123. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +3 -0
  124. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +4 -1
  125. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +4 -1
  126. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +4 -1
  127. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +4 -1
  128. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +3 -0
  129. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +3 -0
  130. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +3 -0
  131. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +3 -0
  132. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +3 -0
  133. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +4 -1
  134. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +4 -1
  135. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +3 -0
  136. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +4 -1
  137. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +3 -0
  138. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +7 -6
  139. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +3 -4
  140. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +3 -0
  141. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +3 -0
  142. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +4 -1
  143. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +3 -0
  144. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +3 -0
  145. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +3 -0
  146. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +3 -0
  147. numba_cuda/numba/cuda/tests/cudapy/__init__.py +3 -0
  148. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +3 -0
  149. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +3 -0
  150. numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +3 -0
  151. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +3 -0
  152. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +3 -0
  153. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +3 -0
  154. numba_cuda/numba/cuda/tests/cudapy/test_array.py +3 -0
  155. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +3 -0
  156. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +3 -0
  157. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +3 -0
  158. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +3 -0
  159. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +4 -3
  160. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +4 -3
  161. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +3 -0
  162. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -0
  163. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +149 -3
  164. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +3 -0
  165. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +4 -1
  166. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +3 -4
  167. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +3 -0
  168. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +3 -0
  169. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +3 -0
  170. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +3 -0
  171. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +3 -0
  172. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +4 -1
  173. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +4 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +3 -0
  175. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +4 -1
  176. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +23 -284
  177. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
  178. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +4 -1
  179. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +3 -0
  180. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +3 -0
  181. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +4 -1
  182. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +3 -0
  183. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +4 -6
  184. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +3 -0
  185. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +3 -0
  186. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +3 -0
  187. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +3 -0
  188. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -0
  189. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +4 -1
  190. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +3 -0
  191. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +3 -0
  192. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +3 -0
  193. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +3 -0
  194. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +3 -0
  195. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +3 -0
  196. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +4 -1
  197. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +298 -0
  198. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +3 -0
  199. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +3 -0
  200. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +3 -0
  201. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +3 -0
  202. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +4 -1
  203. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +3 -0
  204. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +3 -0
  205. numba_cuda/numba/cuda/tests/cudapy/test_math.py +3 -0
  206. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +3 -0
  207. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +3 -0
  208. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +3 -0
  209. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +3 -0
  210. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +3 -0
  211. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +3 -0
  212. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +3 -0
  213. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +4 -1
  214. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +3 -0
  215. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +3 -0
  216. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +3 -0
  217. numba_cuda/numba/cuda/tests/cudapy/test_print.py +3 -0
  218. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +3 -0
  219. numba_cuda/numba/cuda/tests/cudapy/test_random.py +3 -0
  220. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +3 -0
  221. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +3 -0
  222. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +3 -0
  223. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +3 -0
  224. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +3 -0
  225. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +3 -0
  226. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +3 -0
  227. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +3 -0
  228. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +3 -0
  229. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +3 -0
  230. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +3 -0
  231. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +4 -1
  232. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +3 -0
  233. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +3 -0
  234. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +3 -0
  235. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +3 -0
  236. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +3 -0
  237. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +3 -0
  238. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +3 -0
  239. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +8 -1
  240. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +3 -0
  241. numba_cuda/numba/cuda/tests/cudasim/__init__.py +3 -0
  242. numba_cuda/numba/cuda/tests/cudasim/support.py +3 -0
  243. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +3 -0
  244. numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
  245. numba_cuda/numba/cuda/tests/data/cta_barrier.cu +5 -0
  246. numba_cuda/numba/cuda/tests/data/cuda_include.cu +5 -0
  247. numba_cuda/numba/cuda/tests/data/error.cu +5 -0
  248. numba_cuda/numba/cuda/tests/data/include/add.cuh +5 -0
  249. numba_cuda/numba/cuda/tests/data/jitlink.cu +5 -0
  250. numba_cuda/numba/cuda/tests/data/warn.cu +5 -0
  251. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +3 -0
  252. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
  253. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +5 -0
  254. numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +5 -0
  255. numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +5 -0
  256. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +3 -0
  257. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +4 -1
  258. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -1
  259. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +4 -1
  260. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +4 -1
  261. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +4 -1
  262. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +4 -1
  263. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +3 -0
  264. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +4 -1
  265. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +4 -1
  266. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +4 -1
  267. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +4 -1
  268. numba_cuda/numba/cuda/tests/enum_usecases.py +3 -0
  269. numba_cuda/numba/cuda/tests/nocuda/__init__.py +3 -0
  270. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +3 -0
  271. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +3 -0
  272. numba_cuda/numba/cuda/tests/nocuda/test_import.py +4 -1
  273. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +3 -0
  274. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +3 -0
  275. numba_cuda/numba/cuda/tests/nrt/__init__.py +3 -0
  276. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +5 -2
  277. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +4 -1
  278. numba_cuda/numba/cuda/tests/support.py +755 -0
  279. numba_cuda/numba/cuda/tests/test_binary_generation/Makefile +6 -3
  280. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +6 -2
  281. numba_cuda/numba/cuda/tests/test_binary_generation/nrt_extern.cu +5 -0
  282. numba_cuda/numba/cuda/tests/test_binary_generation/test_device_functions.cu +5 -0
  283. numba_cuda/numba/cuda/tests/test_binary_generation/undefined_extern.cu +5 -0
  284. numba_cuda/numba/cuda/types.py +3 -0
  285. numba_cuda/numba/cuda/typing/__init__.py +11 -0
  286. numba_cuda/numba/cuda/typing/templates.py +1448 -0
  287. numba_cuda/numba/cuda/ufuncs.py +3 -0
  288. numba_cuda/numba/cuda/utils.py +3 -0
  289. numba_cuda/numba/cuda/vector_types.py +6 -3
  290. numba_cuda/numba/cuda/vectorizers.py +3 -0
  291. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.1.dist-info}/METADATA +25 -29
  292. numba_cuda-0.19.1.dist-info/RECORD +302 -0
  293. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.1.dist-info}/licenses/LICENSE +1 -0
  294. numba_cuda-0.19.1.dist-info/licenses/LICENSE.numba +24 -0
  295. numba_cuda/numba/cuda/include/11/cuda_bf16.h +0 -3749
  296. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +0 -2683
  297. numba_cuda/numba/cuda/include/11/cuda_fp16.h +0 -3794
  298. numba_cuda/numba/cuda/include/11/cuda_fp16.hpp +0 -2614
  299. numba_cuda-0.18.1.dist-info/RECORD +0 -296
  300. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.1.dist-info}/WHEEL +0 -0
  301. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1448 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ """
5
+ Define typing templates
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ import functools
10
+ import sys
11
+ import inspect
12
+ import os.path
13
+ from collections import namedtuple
14
+ from collections.abc import Sequence
15
+ from types import MethodType, FunctionType, MappingProxyType
16
+
17
+ import numba
18
+ from numba.core import types, utils, targetconfig
19
+ from numba.core.errors import (
20
+ TypingError,
21
+ InternalError,
22
+ )
23
+ from numba.core.cpu_options import InlineOptions
24
+ from numba.core.typing.templates import Signature as CoreSignature
25
+ from numba.cuda.core import ir_utils
26
+
27
# info store for inliner callback functions e.g. cost model
# Each record carries four fields: func_ir, typemap, calltypes and signature.
_inline_info = namedtuple("inline_info", "func_ir typemap calltypes signature")
29
+
30
+
31
+ # HACK: Remove this inheritance once all references to CoreSignature are removed
32
class Signature(CoreSignature):
    """
    The signature of a function call or operation, i.e. its argument types
    and return type.

    A signature exposes four read-only parts:

    - ``return_type``: the type produced by the call.
    - ``args``: the argument types (a list given to the constructor is
      converted to a tuple).
    - ``recvr``: the receiver type when the signature describes a bound
      method, ``None`` for a regular function.
    - ``pysig``: an optional Python-level signature object, ``None`` when
      none was supplied.
    """

    # XXX Perhaps the signature should be a BoundArguments, instead
    # of separate args and pysig...
    __slots__ = "_return_type", "_args", "_recvr", "_pysig"

    def __init__(self, return_type, args, recvr, pysig=None):
        if isinstance(args, list):
            args = tuple(args)
        self._return_type = return_type
        self._args = args
        self._recvr = recvr
        self._pysig = pysig

    @property
    def return_type(self):
        return self._return_type

    @property
    def args(self):
        return self._args

    @property
    def recvr(self):
        return self._recvr

    @property
    def pysig(self):
        return self._pysig

    def replace(self, **kwargs):
        """Copy and replace the given attributes provided as keyword arguments.
        Returns an updated copy.
        """
        curstate = dict(
            return_type=self.return_type,
            args=self.args,
            recvr=self.recvr,
            pysig=self.pysig,
        )
        curstate.update(kwargs)
        return Signature(**curstate)

    def __getstate__(self):
        """
        Needed because of __slots__.
        """
        return self._return_type, self._args, self._recvr, self._pysig

    def __setstate__(self, state):
        """
        Needed because of __slots__.
        """
        self._return_type, self._args, self._recvr, self._pysig = state

    def __hash__(self):
        # Only a subset of the fields compared by __eq__ is hashed; equal
        # signatures therefore still hash equally.
        return hash((self.args, self.return_type))

    def __eq__(self, other):
        """
        Two signatures are equal when their argument types, return type,
        receiver and Python signature are all equal. Anything that is not
        a ``Signature`` compares unequal.
        """
        if not isinstance(other, Signature):
            # Return an explicit bool instead of falling through to None.
            return False
        return (
            self.args == other.args
            and self.return_type == other.return_type
            and self.recvr == other.recvr
            and self.pysig == other.pysig
        )

    def __ne__(self, other):
        return not (self == other)

    def __repr__(self):
        return "%s -> %s" % (self.args, self.return_type)

    @property
    def is_method(self):
        """
        Whether this signature represents a bound method or a regular
        function.
        """
        return self.recvr is not None

    def as_method(self):
        """
        Convert this signature to a bound method signature.

        The first argument type becomes the receiver. When a Python
        signature is attached, its first parameter is dropped as well;
        when there is none, the result simply carries no Python signature.
        A signature that already has a receiver is returned unchanged.
        """
        if self.recvr is not None:
            return self
        sig = signature(self.return_type, *self.args[1:], recvr=self.args[0])

        if self.pysig is None:
            # Nothing to adjust: signatures built by ``signature()`` carry
            # no Python signature.
            return sig

        # Adjust the python signature
        params = list(self.pysig.parameters.values())[1:]
        sig = sig.replace(
            pysig=utils.pySignature(
                parameters=params,
                return_annotation=self.pysig.return_annotation,
            ),
        )
        return sig

    def as_function(self):
        """
        Convert this signature to a regular function signature.

        The receiver is prepended to the argument types. A signature
        without a receiver is returned unchanged.
        """
        if self.recvr is None:
            return self
        sig = signature(self.return_type, *((self.recvr,) + self.args))
        return sig

    def as_type(self):
        """
        Convert this signature to a first-class function type.
        """
        return types.FunctionType(self)

    def __unliteral__(self):
        # Build a new signature with ``types.unliteral`` applied to the
        # return type and to every argument type.
        return signature(
            types.unliteral(self.return_type), *map(types.unliteral, self.args)
        )

    def dump(self, tab=""):
        """
        Print a description of this signature, delegating to the ``dump``
        method of each argument type and of the return type. *tab* is
        prefixed to every printed line.
        """
        c = self.as_type()._code
        print(f"{tab}DUMP {type(self).__name__} [type code: {c}]")
        print(f"{tab} Argument types:")
        for a in self.args:
            a.dump(tab=tab + " | ")
        print(f"{tab} Return type:")
        self.return_type.dump(tab=tab + " | ")
        print(f"{tab}END DUMP")

    def is_precise(self):
        """
        Return False as soon as one argument type is not precise, otherwise
        whatever ``return_type.is_precise()`` reports.
        """
        for atype in self.args:
            if not atype.is_precise():
                return False
        return self.return_type.is_precise()
170
+
171
+
172
def make_concrete_template(name, key, signatures):
    """
    Create a ``ConcreteTemplate`` subclass named *name* whose ``key``
    attribute is *key* and whose ``cases`` attribute is a list copy of
    *signatures*.
    """
    namespace = {"key": key, "cases": list(signatures)}
    return type(name, (ConcreteTemplate,), namespace)
176
+
177
+
178
def make_callable_template(key, typer, recvr=None):
    """
    Create a callable template with the given key and typer function.

    The generated ``CallableTemplate`` subclass is named
    ``<key>_CallableTemplate``; its ``generic()`` method hands back *typer*
    and its ``recvr`` attribute is set to *recvr*.
    """

    def generic(self):
        # The typer is fixed at template-creation time.
        return typer

    return type(
        "%s_CallableTemplate" % (key,),
        (CallableTemplate,),
        {"key": key, "generic": generic, "recvr": recvr},
    )
190
+
191
+
192
def signature(return_type, *args, **kws):
    """
    Build a ``Signature`` from a return type and positional argument types.

    The only keyword accepted is ``recvr`` (the receiver type, default
    ``None``); any other keyword trips the assertion below. The resulting
    signature carries no Python signature.
    """
    receiver = kws.pop("recvr", None)
    assert not kws
    return Signature(return_type, args, recvr=receiver)
196
+
197
+
198
def fold_arguments(
    pysig, args, kws, normal_handler, default_handler, stararg_handler
):
    """
    Given the signature *pysig*, explicit *args* and *kws*, resolve
    omitted arguments and keyword arguments. A tuple of positional
    arguments is returned.

    *kws* may be a dict or a sequence of ``(name, value)`` pairs. When
    *pysig* declares keyword-only parameters, their values are taken from
    the tail of *args*, in declaration order, and bound by name.

    Various handlers allow to process arguments:
    - normal_handler(index, param, value) is called for normal arguments
    - default_handler(index, param, default) is called for omitted arguments
    - stararg_handler(index, param, values) is called for a "*args" argument

    Raises TypingError if the arguments cannot be bound to *pysig*.
    """
    if isinstance(kws, Sequence):
        # Normalize dict kws
        kws = dict(kws)

    # deal with kwonly args
    params = pysig.parameters
    kwonly = [name for name, p in params.items() if p.kind == p.KEYWORD_ONLY]

    bind_args = args
    bind_kws = kws.copy()
    if kwonly:
        bind_args = args[: -len(kwonly)]
        # The keyword-only values sit right after the positional ones, so
        # index from the end of the positional slice (not from len(kwonly),
        # which only lines up when both groups have the same size).
        first_kwonly = len(bind_args)
        for idx, n in enumerate(kwonly):
            bind_kws[n] = args[first_kwonly + idx]

    # now bind
    try:
        ba = pysig.bind(*bind_args, **bind_kws)
    except TypeError as e:
        # The binding attempt can raise if the args don't match up, this needs
        # to be converted to a TypingError so that e.g. partial type inference
        # doesn't just halt.
        msg = (
            f"Cannot bind 'args={bind_args} kws={bind_kws}' to "
            f"signature '{pysig}' due to \"{type(e).__name__}: {e}\"."
        )
        raise TypingError(msg) from e
    for i, param in enumerate(pysig.parameters.values()):
        name = param.name
        default = param.default
        if param.kind == param.VAR_POSITIONAL:
            # stararg may be omitted, in which case its "default" value
            # is simply the empty tuple
            if name in ba.arguments:
                argval = ba.arguments[name]
                # NOTE: avoid wrapping the tuple type for stararg in another
                # tuple.
                if len(argval) == 1 and isinstance(
                    argval[0], (types.StarArgTuple, types.StarArgUniTuple)
                ):
                    argval = tuple(argval[0])
            else:
                argval = ()
            out = stararg_handler(i, param, argval)

            ba.arguments[name] = out
        elif name in ba.arguments:
            # Non-stararg, present
            ba.arguments[name] = normal_handler(i, param, ba.arguments[name])
        else:
            # Non-stararg, omitted
            assert default is not param.empty
            ba.arguments[name] = default_handler(i, param, default)
    # Collect args in the right order
    args = tuple(
        ba.arguments[param.name] for param in pysig.parameters.values()
    )
    return args
273
+
274
+
275
class FunctionTemplate(ABC):
    """
    Abstract base for typing templates. An instance is bound to a typing
    *context*; subclasses provide ``key`` and ``get_template_info()``.
    """

    # Forwarded to ``context.resolve_overload`` as the ``unsafe_casting``
    # option; subclasses may override.
    unsafe_casting = True
    # Forwarded to ``context.resolve_overload`` as the
    # ``exact_match_required`` option (exact match without casting);
    # subclasses may override.
    exact_match_required = False
    # Set to true to prefer literal arguments. Useful for definitions that
    # specialize on literal but also support non-literals; subclasses may
    # override.
    prefer_literal = False
    # metadata
    metadata = {}

    def __init__(self, context):
        self.context = context

    def _select(self, cases, args, kws):
        """
        Ask the context to resolve an overload among *cases* for the given
        *args* and *kws*, using this template's casting options.
        """
        return self.context.resolve_overload(
            self.key,
            cases,
            args,
            kws,
            unsafe_casting=self.unsafe_casting,
            exact_match_required=self.exact_match_required,
        )

    def get_impl_key(self, sig):
        """
        Return the key for looking up the implementation for the given
        signature on the target context.
        """
        # Read the key off the class so that it is not bound to `self`.
        key = type(self).key
        if not isinstance(key, MethodType):
            return key
        # On Python 2, we must also take care about unbound methods
        assert key.im_self is None
        return key.im_func

    @classmethod
    def get_source_code_info(cls, impl):
        """
        Gets the source information about function impl.
        Returns:

        code - str: source code as a string
        firstlineno - int: the first line number of the function impl
        path - str: the path to file containing impl

        if any of the above are not available something generic is returned
        """
        try:
            source, first_line = inspect.getsourcelines(impl)
        except OSError:  # missing source, probably a string
            source, first_line = "None available (built from string?)", 0
        location = inspect.getsourcefile(impl)
        if location is None:
            location = "<unknown> (built from string?)"
        return source, first_line, location

    @abstractmethod
    def get_template_info(self):
        """
        Returns a dictionary with information specific to the template that will
        govern how error messages are displayed to users. The dictionary must
        be of the form:
        info = {
            'kind': "unknown", # str: The kind of template, e.g. "Overload"
            'name': "unknown", # str: The name of the source function
            'sig': "unknown", # str: The signature(s) of the source function
            'filename': "unknown", # str: The filename of the source function
            'lines': ("start", "end"), # tuple(int, int): The start and
                                         end line of the source function.
            'docstring': "unknown" # str: The docstring of the source function
        }
        """
        pass

    def __str__(self):
        # Rendered as "<ClassName filename:first_line>".
        details = self.get_template_info()
        first_line = details["lines"][0]
        return f"<{self.__class__.__name__} {details['filename']}:{first_line}>"

    __repr__ = __str__
362
+
363
+
364
class AbstractTemplate(FunctionTemplate):
    """
    Defines method ``generic(self, args, kws)`` which compute a possible
    signature base on input types. The signature does not have to match the
    input types. It is compared against the input types afterwards.
    """

    def apply(self, args, kws):
        """
        Call ``generic(args, kws)`` and return its result.

        If no signature comes back and some argument is a
        ``types.Optional``, the optionals are replaced by their underlying
        type and ``generic`` is tried once more (keyword arguments are not
        supported on that path). Raises AssertionError when ``generic``
        returns something that is neither None nor a Signature.
        """
        generic = getattr(self, "generic")
        sig = generic(args, kws)

        # Enforce that *generic()* must return None or Signature
        # HACK: Remove this inheritance once all references to CoreSignature are removed
        if sig is not None and not isinstance(
            sig, (Signature, numba.core.typing.templates.Signature)
        ):
            raise AssertionError(
                "generic() must return a Signature or None. "
                "{} returned {}".format(generic, type(sig)),
            )

        # Unpack optional type if no matching signature
        if not sig and any(isinstance(x, types.Optional) for x in args):
            args = [
                x.type if isinstance(x, types.Optional) else x for x in args
            ]
            assert not kws  # Not supported yet
            sig = generic(args, kws)

        return sig

    def get_template_info(self):
        """
        Describe the ``generic`` implementation: its kind, name, signature,
        file name relative to the directory containing the ``numba``
        package, line span and docstring.
        """
        impl = getattr(self, "generic")
        root = os.path.dirname(os.path.dirname(numba.__file__))

        source, first_line, location = self.get_source_code_info(impl)
        rendered_sig = str(utils.pysignature(impl))
        return {
            "kind": "overload",
            "name": getattr(impl, "__qualname__", impl.__name__),
            "sig": rendered_sig,
            "filename": utils.safe_relpath(location, start=root),
            "lines": (first_line, first_line + len(source) - 1),
            "docstring": impl.__doc__,
        }
415
+
416
+
417
class CallableTemplate(FunctionTemplate):
    """
    Base class for a template defining a ``generic(self)`` method that
    returns a callable (the "typer"). The typer is invoked with the actual
    ``*args`` and ``**kwargs`` of the call and has to return a return type,
    a full signature, or None. The signature does not have to match the
    input types; it is compared against them afterwards.
    """

    recvr = None

    def apply(self, args, kws):
        """
        Run the typer on ``args``/``kws`` and select the resulting signature.

        Returns None when the typer yields None (also after retrying with
        ``types.Optional`` arguments unpacked). Raises ``TypingError`` when
        the arguments cannot be bound to the typer or keyword arguments are
        left over after folding, and ``TypeError`` when the typer returns
        something that is neither a Signature nor a ``types.Type``.
        """
        typer = self.generic()
        try:
            inspect.signature(typer).bind(*args, **kws)
        except TypeError as bind_error:
            # bind failed, raise, if there's a
            # ValueError then there's likely unrecoverable
            # problems
            raise TypingError(str(bind_error)) from bind_error

        result = typer(*args, **kws)

        # Retry with Optionals unpacked if nothing matched
        if result is None and any(
            isinstance(ty, types.Optional) for ty in args
        ):
            args = [
                ty.type if isinstance(ty, types.Optional) else ty
                for ty in args
            ]
            result = typer(*args, **kws)
        if result is None:
            return None

        # Prefer a pysig advertised by the typer itself
        try:
            pysig = typer.pysig
        except AttributeError:
            pysig = utils.pysignature(typer)

        # Fold any keyword arguments
        bound = pysig.bind(*args, **kws)
        if bound.kwargs:
            raise TypingError("unsupported call signature")

        if isinstance(result, Signature):
            sig = result
        else:
            # If not a signature, the result is assumed to be the return type
            if not isinstance(result, types.Type):
                raise TypeError(
                    "invalid return type for callable template: got %r"
                    % (result,)
                )
            sig = signature(result, *bound.args)

        if self.recvr is not None:
            sig = sig.replace(recvr=self.recvr)

        # Hack any omitted parameters out of the typer's pysig,
        # as lowering expects an exact match between formal signature
        # and actual args.
        n_bound = len(bound.args)
        if n_bound < len(pysig.parameters):
            kept = list(pysig.parameters.values())[:n_bound]
            pysig = pysig.replace(parameters=kept)
        sig = sig.replace(pysig=pysig)

        return self._select([sig], bound.args, bound.kwargs)

    def get_template_info(self):
        """
        Return a dict describing the ``generic`` implementation. The name is
        taken from ``self.key.__name__`` when available, otherwise from the
        implementation itself.
        """
        impl = self.generic
        basepath = os.path.dirname(os.path.dirname(numba.__file__))
        code, firstlineno, path = self.get_source_code_info(impl)
        sig_text = str(utils.pysignature(impl))
        impl_name = getattr(impl, "__qualname__", impl.__name__)
        return {
            "kind": "overload",
            "name": getattr(self.key, "__name__", impl_name),
            "sig": sig_text,
            "filename": utils.safe_relpath(path, start=basepath),
            "lines": (firstlineno, firstlineno + len(code) - 1),
            "docstring": impl.__doc__,
        }
505
+
506
+
507
class ConcreteTemplate(FunctionTemplate):
    """
    Template driven by a ``cases`` attribute: a list of signatures that the
    given input types are matched against.
    """

    def apply(self, args, kws):
        """Select among ``self.cases`` for the given argument types."""
        return self._select(self.cases, args, kws)

    def get_template_info(self):
        """
        Return a dict describing this template. Only the kind and the name
        are known; every other field is the string ``"unknown"``.
        """
        import operator

        name = getattr(self.key, "__name__", "unknown")
        op_func = getattr(operator, name, None)

        # The key counts as an operator overload only when it *is* the
        # function of that name in the ``operator`` module.
        if op_func is not None and self.key is op_func:
            kind = "operator overload"
        else:
            kind = "Type restricted function"

        return {
            "kind": kind,
            "name": name,
            "sig": "unknown",
            "filename": "unknown",
            "lines": ("unknown", "unknown"),
            "docstring": "unknown",
        }
536
+
537
+
538
class _EmptyImplementationEntry(InternalError):
    """
    Error object whose message records *reason* in the form
    ``_EmptyImplementationEntry(<repr of reason>)``.
    """

    def __init__(self, reason):
        message = "_EmptyImplementationEntry({!r})".format(reason)
        super().__init__(message)
543
+
544
+
545
+ class _OverloadFunctionTemplate(AbstractTemplate):
546
+ """
547
+ A base class of templates for overload functions.
548
+ """
549
+
550
def _validate_sigs(self, typing_func, impl_func):
    """
    Check that *impl_func* has a signature compatible with *typing_func*.

    The typing signature is considered golden and must be adhered to by the
    implementation.

    Things that are valid:
      1. args match exactly
      2. kwargs match exactly in name and default value
      3. Use of *args in the same location by the same name in both typing
         and implementation signature
      4. Use of *args in the implementation signature to consume any number
         of arguments in the typing signature.
    Things that are invalid:
      5. Use of *args in the typing signature that is not replicated
         in the implementing signature
      6. Use of **kwargs

    Returns None when the signatures are compatible and raises
    ``InternalError`` describing the first mismatch otherwise.
    """
    typing_sig = utils.pysignature(typing_func)
    impl_sig = utils.pysignature(impl_func)

    def get_args_kwargs(sig):
        # Split parameters into those without a default ("args"), those
        # with a default ("kws"), and remember the *args parameter if any.
        kws = []
        args = []
        pos_arg = None
        for x in sig.parameters.values():
            if x.default == utils.pyParameter.empty:
                args.append(x)
                if x.kind == utils.pyParameter.VAR_POSITIONAL:
                    pos_arg = x
                elif x.kind == utils.pyParameter.VAR_KEYWORD:
                    msg = (
                        "The use of VAR_KEYWORD (e.g. **kwargs) is "
                        "unsupported. (offending argument name is '%s')"
                    )
                    raise InternalError(msg % x)
            else:
                kws.append(x)
        return args, kws, pos_arg

    ty_args, ty_kws, ty_pos = get_args_kwargs(typing_sig)
    im_args, im_kws, im_pos = get_args_kwargs(impl_sig)

    sig_fmt = "Typing signature: %s\nImplementation signature: %s"
    sig_str = sig_fmt % (typing_sig, impl_sig)

    err_prefix = "Typing and implementation arguments differ in "

    a = ty_args
    b = im_args
    if ty_pos:
        if not im_pos:
            # case 5. described above
            msg = (
                "VAR_POSITIONAL (e.g. *args) argument kind (offending "
                "argument name is '%s') found in the typing function "
                "signature, but is not in the implementing function "
                "signature.\n%s"
            ) % (ty_pos, sig_str)
            raise InternalError(msg)
    else:
        if im_pos:
            # no *args in typing but there's a *args in the implementation
            # this is case 4. described above
            b = im_args[: im_args.index(im_pos)]
            if not b:
                # The implementation starts with *args, so it consumes every
                # positional argument of the typing signature and there is
                # no named prefix to compare. (Previously ``b[-1]`` raised
                # an uncaught IndexError here.)
                a = []
            else:
                try:
                    a = ty_args[: ty_args.index(b[-1]) + 1]
                except ValueError:
                    # there's no b[-1] arg name in the ty_args, something is
                    # very wrong, we can't work out a diff (*args consumes
                    # unknown quantity of args) so just report first error
                    specialized = (
                        "argument names.\n%s\nFirst difference: '%s'"
                    )
                    msg = err_prefix + specialized % (sig_str, b[-1])
                    raise InternalError(msg)

    def gen_diff(typing, implementing):
        # Symmetric difference: parameters present on one side only.
        diff = set(typing) ^ set(implementing)
        return "Difference: %s" % diff

    if a != b:
        specialized = "argument names.\n%s\n%s" % (sig_str, gen_diff(a, b))
        raise InternalError(err_prefix + specialized)

    # ensure kwargs are the same
    ty = [x.name for x in ty_kws]
    im = [x.name for x in im_kws]
    if ty != im:
        specialized = "keyword argument names.\n%s\n%s"
        msg = err_prefix + specialized % (sig_str, gen_diff(ty_kws, im_kws))
        raise InternalError(msg)
    same = [x.default for x in ty_kws] == [x.default for x in im_kws]
    if not same:
        specialized = "keyword argument default values.\n%s\n%s"
        msg = err_prefix + specialized % (sig_str, gen_diff(ty_kws, im_kws))
        raise InternalError(msg)
642
+
643
def generic(self, args, kws):
    """
    Type the overloaded function by compiling the appropriate
    implementation for the given args.

    Returns the resulting signature, or None when no implementation is
    available or the call cannot be resolved. As a side effect the entry
    for ``sig.args`` in ``self._compiled_overloads`` (and, when inlining is
    possible, in ``self._inline_overloads``) is populated.
    """
    from numba.core.typed_passes import PreLowerStripPhis

    dispatcher, resolved_args = self._get_impl(args, kws)
    if dispatcher is None:
        return None

    # Compile and type it for the given types
    disp_type = types.Dispatcher(dispatcher)

    if self._inline.is_never_inline:
        # No inlining required: store the compiled overload for use in the
        # lowering phase.
        sig = disp_type.get_call_type(self.context, resolved_args, kws)
        if sig is None:  # can't resolve for this target
            return None
        self._compiled_overloads[sig.args] = disp_type.get_overload(sig)
        return sig

    # Inlining may happen (otherwise functions would be compiled which are
    # never used as they are inlined): run the compiler front end up to
    # type inference to compute a signature.
    from numba.core import typed_passes, compiler
    from numba.core.inline_closurecall import InlineWorker

    fcomp = dispatcher._compiler
    flags = compiler.Flags()

    # Updating these causes problems?!
    # fcomp.targetdescr.options.parse_as_flags(flags,
    #                                          fcomp.targetoptions)
    # flags = fcomp._customize_flags(flags)

    # spoof a compiler pipeline like the one that will be in use
    tyctx = fcomp.targetdescr.typing_context
    tgctx = fcomp.targetdescr.target_context
    pipeline = fcomp.pipeline_class(
        tyctx, tgctx, None, None, None, flags, None
    )
    worker = InlineWorker(tyctx, tgctx, fcomp.locals, pipeline, flags, None)

    # If the inlinee contains something to trigger literal arg dispatch
    # then the pipeline call will unconditionally fail due to a raised
    # ForceLiteralArg exception. Therefore `resolve` is run first, as
    # type resolution must occur at some point, this will hit any
    # `literally` calls and because it's going via the dispatcher will
    # handle them correctly i.e. ForceLiteralArg propagates. This having
    # the desired effect of ensuring the pipeline call is only made in
    # situations that will succeed. For context see #5887.
    resolve = disp_type.dispatcher.get_call_template
    # NOTE: ``kws`` is deliberately rebound to what ``resolve`` returns.
    template, pysig, folded_args, kws = resolve(resolved_args, kws)
    ir = worker.run_untyped_passes(
        disp_type.dispatcher.py_func, enable_ssa=True
    )

    typemap, return_type, calltypes, _ = typed_passes.type_inference_stage(
        self.context, tgctx, ir, folded_args, None
    )
    ir = PreLowerStripPhis()._strip_phi_nodes(ir)
    ir._definitions = ir_utils.build_definitions(ir.blocks)

    sig = Signature(return_type, folded_args, None)
    # this stores a load of info for the cost model function if supplied
    # it by default is None
    self._inline_overloads[sig.args] = {"folded_args": folded_args}

    # this stores the compiled overloads, if there's no compiled
    # overload available i.e. function is always inlined, the key still
    # needs to exist for type resolution

    # NOTE: If lowering is failing on a `_EmptyImplementationEntry`,
    # the inliner has failed to inline this entry correctly.
    self._compiled_overloads[sig.args] = _EmptyImplementationEntry(
        "always inlined"
    )

    if not self._inline.is_always_inline:
        # this branch is here because a user has supplied a function to
        # determine whether to inline or not. As a result both compiled
        # function and inliner info needed, delaying the computation of
        # this leads to an internal state mess at present. TODO: Fix!
        sig = disp_type.get_call_type(self.context, resolved_args, kws)
        self._compiled_overloads[sig.args] = disp_type.get_overload(sig)

    # store the inliner information, it's used later in the cost
    # model function call
    self._inline_overloads[sig.args] = {
        "folded_args": folded_args,
        "iinfo": _inline_info(ir, typemap, calltypes, sig),
    }
    return sig
747
+
748
def _get_impl(self, args, kws):
    """Get implementation given the argument types.

    Returns a ``(dispatcher, args)`` pair. Results are looked up in
    ``self._impl_cache`` first, keyed on the context, the argument types,
    the keyword items and the top of the config stack; on a miss the pair
    is produced by ``self._build_impl``.
    """
    flags = targetconfig.ConfigStack.top_or_none()
    key = (self.context, tuple(args), tuple(kws.items()), flags)

    try:
        entry = self._impl_cache[key]
    except KeyError:
        entry = None

    if entry is None:
        # Built outside the ``except`` block on purpose, so that a failure
        # in _build_impl is not reported as nested inside the KeyError.
        entry = self._build_impl(key, args, kws)

    impl, impl_args = entry
    return impl, impl_args
765
+
766
def _get_jit_decorator(self):
    """Gets a jit decorator suitable for the current target.

    The target name is read from ``self.metadata["target"]`` (default
    ``"generic"``) and looked up in ``jit_registry``. If no decorator is
    registered under that name, the decorator registered for the local
    target of ``self.context`` is used instead.

    Raises
    ------
    ValueError
        If the target name is unknown to ``target_registry``, or no jit
        decorator could be found.
    """

    from numba.core.target_extension import (
        target_registry,
        get_local_target,
        jit_registry,
    )

    jitter_str = self.metadata.get("target", "generic")
    jitter = jit_registry.get(jitter_str, None)

    if jitter is None:
        # No JIT known for target string, see if something is
        # registered for the string and report if not.
        target_class = target_registry.get(jitter_str, None)
        if target_class is None:
            # The message used to be a tuple of two strings, so calling
            # ``.format`` on it raised AttributeError instead of the
            # intended ValueError.
            msg = "Unknown target '{}', has it been registered?"
            raise ValueError(msg.format(jitter_str))

        target_hw = get_local_target(self.context)

        # check that the requested target is in the hierarchy for the
        # current frame's target.
        if not issubclass(target_hw, target_class):
            # NOTE(review): the original code only built the message
            # "No overloads exist for the requested target: {}." here and
            # never raised it. That behaviour is kept (no error) because
            # raising could break existing callers -- confirm the intent.
            pass

        jitter = jit_registry[target_hw]

    if jitter is None:
        raise ValueError("Cannot find a suitable jit decorator")

    return jitter
799
+
800
def _build_impl(self, cache_key, args, kws):
    """Build and cache the implementation.

    Calls the ``overload`` function with the positional (`args`) and
    keyword (`kws`) argument types, wraps the implementation it returns in
    a Dispatcher (via the target's jit decorator) and type-checks it.
    The argument types handed back are the ones type inference should use;
    they differ from the given ones only when the overload function itself
    returns a ``(signature, function)`` tuple.

    Parameters
    ----------
    cache_key : hashable
        The key used for caching the implementation.
    args : Tuple[Type]
        Types of positional argument.
    kws : Dict[Type]
        Types of keyword argument.

    Returns
    -------
    disp, args :
        On success, returns `(Dispatcher, Tuple[Type])`.
        On failure, returns `(None, None)`.

    Raises
    ------
    TypingError
        If the arguments cannot be bound to the overload function.
    AssertionError
        If the overload function returns something that is not a function.
    """
    jitter = self._get_jit_decorator()

    # Make sure the call shape fits the overload function before calling it
    overload_sig = inspect.signature(self._overload_func)
    try:
        overload_sig.bind(*args, **kws)
    except TypeError as bind_error:
        # bind failed, raise, if there's a
        # ValueError then there's likely unrecoverable
        # problems
        raise TypingError(str(bind_error)) from bind_error

    outcome = self._overload_func(*args, **kws)

    if outcome is None:
        # No implementation => fail typing
        self._impl_cache[cache_key] = None, None
        return None, None

    if isinstance(outcome, tuple):
        # The implementation returned a signature that the type-inferencer
        # should be using.
        explicit_sig, pyfunc = outcome
        args = explicit_sig.args
        kws = {}
        cache_key = None  # don't cache
    else:
        # Regular case
        pyfunc = outcome

    # Check type of pyfunc
    if not isinstance(pyfunc, FunctionType):
        template = (
            "Implementation function returned by `@overload` "
            "has an unexpected type. Got {}"
        )
        raise AssertionError(template.format(pyfunc))

    # check that the typing and impl sigs match up
    if self._strict:
        self._validate_sigs(self._overload_func, pyfunc)

    # Make dispatcher
    disp = jitter(**self._jit_options)(pyfunc)

    # Make sure that the implementation can be fully compiled
    types.Dispatcher(disp).get_call_type(self.context, args, kws)

    if cache_key is not None:
        self._impl_cache[cache_key] = disp, args
    return disp, args
874
+
875
def get_impl_key(self, sig):
    """
    Return the key for looking up the implementation for the given
    signature on the target context, i.e. the entry stored in
    ``self._compiled_overloads`` under ``sig.args``.
    """
    overloads = self._compiled_overloads
    return overloads[sig.args]
881
+
882
@classmethod
def get_source_info(cls):
    """Return a dictionary with information about the source code of the
    implementation.

    Returns
    -------
    info : dict
        - "kind" : str
            The implementation kind.
        - "name" : str
            The name of the function that provided the definition.
        - "sig" : str
            The formatted signature of the function.
        - "filename" : str
            The name of the source file.
        - "lines": tuple (int, int)
            First and last line number.
        - "docstring": str
            The docstring of the definition.
    """
    # Paths are reported relative to the directory containing ``numba``
    package_parent = os.path.dirname(os.path.dirname(numba.__file__))
    overload_fn = cls._overload_func
    source, first_line, location = cls.get_source_code_info(overload_fn)
    sig_text = str(utils.pysignature(overload_fn))
    return {
        "kind": "overload",
        "name": getattr(overload_fn, "__qualname__", overload_fn.__name__),
        "sig": sig_text,
        "filename": utils.safe_relpath(location, start=package_parent),
        "lines": (first_line, first_line + len(source) - 1),
        "docstring": overload_fn.__doc__,
    }
916
+
917
def get_template_info(self):
    """
    Return a dict describing the overload function behind this template:
    kind, name, signature text, source file (relative to the directory
    containing the ``numba`` package), first/last line and docstring.
    """
    package_parent = os.path.dirname(os.path.dirname(numba.__file__))
    overload_fn = self._overload_func
    source, first_line, location = self.get_source_code_info(overload_fn)
    sig_text = str(utils.pysignature(overload_fn))
    return {
        "kind": "overload",
        "name": getattr(overload_fn, "__qualname__", overload_fn.__name__),
        "sig": sig_text,
        "filename": utils.safe_relpath(location, start=package_parent),
        "lines": (first_line, first_line + len(source) - 1),
        "docstring": overload_fn.__doc__,
    }
931
+
932
+
933
def make_overload_template(
    func,
    overload_func,
    jit_options,
    strict,
    inline,
    prefer_literal=False,
    **kwargs,
):
    """
    Make a template class for function *func* overloaded by *overload_func*.
    Compiler options are passed as a dictionary to *jit_options*.

    The new class derives from ``_OverloadFunctionTemplate`` and gets fresh,
    empty caches; any extra keyword arguments are stored as its
    ``metadata``.
    """
    base = _OverloadFunctionTemplate
    class_name = "OverloadTemplate_%s" % (
        getattr(func, "__name__", str(func)),
    )
    namespace = {
        "key": func,
        "_overload_func": staticmethod(overload_func),
        "_impl_cache": {},
        "_compiled_overloads": {},
        "_jit_options": jit_options,
        "_strict": strict,
        "_inline": staticmethod(InlineOptions(inline)),
        "_inline_overloads": {},
        "prefer_literal": prefer_literal,
        "metadata": kwargs,
    }
    # Build through the base's own metaclass
    metaclass = type(base)
    return metaclass(class_name, (base,), namespace)
962
+
963
+
964
class _TemplateTargetHelperMixin(object):
    """Mixin for helper methods that assist with target/registry resolution"""

    def _get_target_registry(self, reason):
        """Returns the registry for the current target.

        Parameters
        ----------
        reason: str
            Reason for the resolution. Expects a noun.
        Returns
        -------
        reg : a registry suitable for the current target.
        """
        from numba.core.target_extension import (
            _get_local_target_checked,
            dispatcher_registry,
        )

        requested = self.metadata.get("target", "generic")
        local_target = _get_local_target_checked(
            self.context, requested, reason
        )
        # Get the target context of the dispatcher for the current hardware
        dispatcher = dispatcher_registry[local_target]
        target_context = dispatcher.targetdescr.target_context

        # ---------------------------------------------------------------------
        # XXX: In upstream Numba, this function would prefer the builtin
        # registry if it was installed in the target (as it is for the CUDA
        # target). The builtin registry has been removed from this file (it
        # was initialized as `builtin_registry = Registry()`) as it would
        # duplicate the builtin registry in upstream Numba, which would be
        # likely to lead to confusion / mixing things up between two builtin
        # registries. The comment that accompanied this behaviour is left
        # here, even though the code that would pick the builtin registry has
        # been removed, for the benefit of future understanding.
        #
        # ---------------------------------------------------------------------
        #
        # Comment left in from upstream:
        #
        # This is all workarounds...
        # The issue is that whilst targets shouldn't care about which registry
        # in which to register lowering implementations, the CUDA target
        # "borrows" implementations from the CPU from specific registries.
        # This means that if some impl is defined via @intrinsic, e.g.
        # numba.*unsafe modules, _AND_ CUDA also makes use of the same impl,
        # then it's required that the registry in use is one that CUDA borrows
        # from. This leads to the following expression where by the CPU
        # builtin_registry is used if it is in the target context as a known
        # registry (i.e. the target installed it) and if it is not then it is
        # assumed that the registries for the target are unbound to any other
        # target and so it's fine to use any of them as a place to put
        # lowering impls.
        #
        # NOTE: This will need subsequently fixing again when targets use
        # solely the extension APIs to describe their implementation. The
        # issue will be that the builtin_registry should contain _just_ the
        # stack allocated implementations and low level target invariant
        # things and should not be modified further. It should be acceptable
        # to remove the `then` branch and just keep the `else`.
        # =====================================================================

        # =====================================================================
        # XXX: This ought not to be necessary in the long term, but is left in
        # for now. When there are fewer registries (or just one) for a target,
        # it may be safe to remove this. Or, it may always require a refresh
        # in case there are pending registrations - this remains to be seen
        # ---------------------------------------------------------------------
        #
        # Comment / code left in from upstream:
        #
        # In case the target has swapped, e.g. cuda borrowing cpu, refresh to
        # populate.
        target_context.refresh()
        # =====================================================================

        # Pick a registry in which to install intrinsics: the first one the
        # target context knows about.
        return next(iter(target_context._registries))
1043
+
1044
+
1045
class _IntrinsicTemplate(_TemplateTargetHelperMixin, AbstractTemplate):
    """
    A base class of templates for intrinsic definition
    """

    def generic(self, args, kws):
        """
        Type the intrinsic by the arguments.

        Calls ``self._definition_func(self.context, *args, **kws)``, which
        must return None or a ``(signature, implementation)`` pair. On
        success the signature is cached, the implementation is remembered
        under ``sig.args`` and registered for lowering, and the signature
        is returned. Returns None when the definition function does.
        """
        lower_builtin = self._get_target_registry("intrinsic").lower
        # ``args`` is normalised to a tuple so the key is hashable even when
        # a list is passed in (AbstractTemplate.apply passes a list when it
        # retries with Optional types unpacked); a raw list here used to
        # fail with "unhashable type: 'list'".
        cache_key = self.context, tuple(args), tuple(kws.items())
        try:
            return self._impl_cache[cache_key]
        except KeyError:
            pass
        result = self._definition_func(self.context, *args, **kws)
        if result is None:
            return
        [sig, imp] = result
        pysig = utils.pysignature(self._definition_func)
        # omit context argument from user function
        parameters = list(pysig.parameters.values())[1:]
        sig = sig.replace(pysig=pysig.replace(parameters=parameters))
        self._impl_cache[cache_key] = sig
        self._overload_cache[sig.args] = imp
        # register the lowering
        lower_builtin(imp, *sig.args)(imp)
        return sig

    def get_impl_key(self, sig):
        """
        Return the key for looking up the implementation for the given
        signature on the target context.
        """
        return self._overload_cache[sig.args]

    def get_template_info(self):
        """
        Return a dict describing the intrinsic's definition function: kind,
        name, signature text, source file (relative to the directory
        containing the ``numba`` package), first/last line and docstring.
        """
        basepath = os.path.dirname(os.path.dirname(numba.__file__))
        impl = self._definition_func
        code, firstlineno, path = self.get_source_code_info(impl)
        sig = str(utils.pysignature(impl))
        info = {
            "kind": "intrinsic",
            "name": getattr(impl, "__qualname__", impl.__name__),
            "sig": sig,
            "filename": utils.safe_relpath(path, start=basepath),
            "lines": (firstlineno, firstlineno + len(code) - 1),
            "docstring": impl.__doc__,
        }
        return info
1095
+
1096
+
1097
def make_intrinsic_template(
    handle, defn, name, *, prefer_literal=False, kwargs=None
):
    """
    Make a template class for a intrinsic handle *handle* defined by the
    function *defn*. The *name* is used for naming the new template class.

    *kwargs* (an empty mapping when None) is exposed read-only as the
    class's ``metadata``.
    """
    metadata = MappingProxyType(kwargs if kwargs is not None else {})
    base = _IntrinsicTemplate
    class_name = "_IntrinsicTemplate_%s" % (name)
    namespace = {
        "key": handle,
        "_definition_func": staticmethod(defn),
        "_impl_cache": {},
        "_overload_cache": {},
        "prefer_literal": prefer_literal,
        "metadata": metadata,
    }
    # Build through the base's own metaclass
    metaclass = type(base)
    return metaclass(class_name, (base,), namespace)
1116
+
1117
+
1118
class AttributeTemplate(object):
    """
    Base class for attribute typing templates.

    Subclasses provide ``resolve_<attr>(self, value)`` methods for specific
    attributes and/or a catch-all ``generic_resolve(self, value, attr)``.
    """

    # Sentinel meaning "no catch-all resolver defined"
    generic_resolve = NotImplemented

    def __init__(self, context):
        self.context = context

    def resolve(self, value, attr):
        """Resolve attribute *attr* of *value*; delegates to ``_resolve``."""
        return self._resolve(value, attr)

    def _resolve(self, value, attr):
        """
        Dispatch to ``resolve_<attr>`` if it exists, else to
        ``generic_resolve``. Without either, module values are resolved
        through the context's module constants and anything else yields
        None.
        """
        specific = getattr(self, "resolve_%s" % attr, None)
        if specific is not None:
            return specific(value)

        fallback = self.generic_resolve
        if fallback is not NotImplemented:
            return fallback(value, attr)

        if isinstance(value, types.Module):
            return self.context.resolve_module_constants(value, attr)
        return None
1140
+
1141
+
1142
class _OverloadAttributeTemplate(_TemplateTargetHelperMixin, AttributeTemplate):
    """
    A base class of templates for @overload_attribute functions.
    """

    is_method = False

    def __init__(self, context):
        super().__init__(context)
        self.context = context
        self._init_once()

    def _init_once(self):
        """
        Register the getattr lowering for ``(cls.key, cls._attr)`` in the
        registry of the current target.
        """
        template_cls = type(self)
        attr_name = template_cls._attr

        target_registry = self._get_target_registry("attribute")
        lower_getattr = target_registry.lower_getattr

        @lower_getattr(template_cls.key, attr_name)
        def getattr_impl(context, builder, typ, value):
            # Type the overload for ``typ`` and call it on the value
            typingctx = context.typing_context
            fnty = template_cls._get_function_type(typingctx, typ)
            sig = template_cls._get_signature(typingctx, fnty, (typ,), {})
            impl = context.get_function(fnty, sig)
            return impl(builder, (value,))

    def _resolve(self, typ, attr):
        """
        Return the type of attribute *attr* on *typ*, or None when *attr*
        is not the attribute this template handles.
        """
        if attr != self._attr:
            return None
        fnty = self._get_function_type(self.context, typ)
        sig = self._get_signature(self.context, fnty, (typ,), {})
        # There should only be one template
        for tmpl in fnty.templates:
            self._inline_overloads.update(tmpl._inline_overloads)
        return sig.return_type

    @classmethod
    def _get_signature(cls, typingctx, fnty, args, kws):
        """
        Type the call of *fnty* and attach the overload function's Python
        signature to the result.
        """
        call_sig = fnty.get_call_type(typingctx, args, kws)
        overload_pysig = utils.pysignature(cls._overload_func)
        return call_sig.replace(pysig=overload_pysig)

    @classmethod
    def _get_function_type(cls, typingctx, typ):
        """Return the typing context's type for the overload function."""
        return typingctx.resolve_value_type(cls._overload_func)
1187
+
1188
+
1189
class _OverloadMethodTemplate(_OverloadAttributeTemplate):
    """
    A base class of templates for @overload_method functions.
    """

    is_method = True

    def _init_once(self):
        """
        Overriding parent definition: register a lowering for calls of the
        method ``(self.key, attr)`` in the registry of the current target.
        """
        attr = self._attr
        target_registry = self._get_target_registry("method")
        lower = target_registry.lower

        @lower((self.key, attr), self.key, types.VarArg(types.Any))
        def method_impl(context, builder, sig, args):
            recvr_type = sig.args[0]
            typing_context = context.typing_context
            fnty = self._get_function_type(typing_context, recvr_type)
            sig = self._get_signature(typing_context, fnty, sig.args, {})
            impl = context.get_function(fnty, sig)
            # Link dependent library
            context.add_linking_libs(getattr(impl, "libs", ()))
            return impl(builder, args)

    def _resolve(self, typ, attr):
        """
        Return a bound function type for method *attr* of *typ*, or None
        when *attr* is not the method this template handles.
        """
        if attr != self._attr:
            return None

        # TypeRef and Callable receivers must equal the key; anything else
        # must be an instance of it.
        if isinstance(typ, (types.TypeRef, types.Callable)):
            assert typ == self.key
        else:
            assert isinstance(typ, self.key)

        class MethodTemplate(AbstractTemplate):
            key = (self.key, attr)
            _inline = self._inline
            _overload_func = staticmethod(self._overload_func)
            _inline_overloads = self._inline_overloads
            prefer_literal = self.prefer_literal

            def generic(_, args, kws):
                # Prepend the receiver type, then type the overload
                full_args = (typ,) + tuple(args)
                fnty = self._get_function_type(self.context, typ)
                sig = self._get_signature(self.context, fnty, full_args, kws)
                sig = sig.replace(pysig=utils.pysignature(self._overload_func))
                for tmpl in fnty.templates:
                    self._inline_overloads.update(tmpl._inline_overloads)
                if sig is not None:
                    return sig.as_method()

            def get_template_info(self):
                package_parent = os.path.dirname(
                    os.path.dirname(numba.__file__)
                )
                overload_fn = self._overload_func
                source, first_line, location = self.get_source_code_info(
                    overload_fn
                )
                sig_text = str(utils.pysignature(overload_fn))
                return {
                    "kind": "overload_method",
                    "name": getattr(
                        overload_fn, "__qualname__", overload_fn.__name__
                    ),
                    "sig": sig_text,
                    "filename": utils.safe_relpath(
                        location, start=package_parent
                    ),
                    "lines": (first_line, first_line + len(source) - 1),
                    "docstring": overload_fn.__doc__,
                }

        return types.BoundFunction(MethodTemplate, typ)
1260
+
1261
+
1262
def make_overload_attribute_template(
    typ,
    attr,
    overload_func,
    inline="never",
    prefer_literal=False,
    base=_OverloadAttributeTemplate,
    **kwargs,
):
    """
    Make a template class for attribute *attr* of *typ* overloaded by
    *overload_func*.

    The class derives from *base*; extra keyword arguments are stored as
    its ``metadata``.
    """
    assert isinstance(typ, types.Type) or issubclass(typ, types.Type)
    class_name = "OverloadAttributeTemplate_%s_%s" % (typ, attr)
    # Note the implementation cache is subclass-specific
    namespace = {
        "key": typ,
        "_attr": attr,
        "_impl_cache": {},
        "_inline": staticmethod(InlineOptions(inline)),
        "_inline_overloads": {},
        "_overload_func": staticmethod(overload_func),
        "prefer_literal": prefer_literal,
        "metadata": kwargs,
    }
    # Build through the base's own metaclass
    metaclass = type(base)
    return metaclass(class_name, (base,), namespace)
1290
+
1291
+
1292
def make_overload_method_template(
    typ, attr, overload_func, inline, prefer_literal=False, **kwargs
):
    """
    Make a template class for method *attr* of *typ* overloaded by
    *overload_func*.

    Delegates to ``make_overload_attribute_template`` with
    ``_OverloadMethodTemplate`` as the base class.
    """
    method_template = make_overload_attribute_template(
        typ,
        attr,
        overload_func,
        base=_OverloadMethodTemplate,
        inline=inline,
        prefer_literal=prefer_literal,
        **kwargs,
    )
    return method_template
1308
+
1309
+
1310
def bound_function(template_key):
    """
    Wrap an AttributeTemplate resolve_* method to allow it to
    resolve an instance method's signature rather than a instance attribute.
    The wrapped method must return the resolved method's signature
    according to the given self type, args, and keywords.

    It is used thusly:

        class ComplexAttributes(AttributeTemplate):
            @bound_function("complex.conjugate")
            def resolve_conjugate(self, ty, args, kwds):
                return ty

    *template_key* (e.g. "complex.conjugate" above) will be used by the
    target to look up the method's implementation, as a regular function.
    """

    def wrapper(method_resolver):
        def attribute_resolver(self, ty):
            class MethodTemplate(AbstractTemplate):
                key = template_key

                def generic(_, args, kws):
                    resolved = method_resolver(self, ty, args, kws)
                    # Default the receiver to the type being resolved
                    if resolved is not None and resolved.recvr is None:
                        resolved = resolved.replace(recvr=ty)
                    return resolved

            return types.BoundFunction(MethodTemplate, ty)

        # Same effect as decorating with @functools.wraps(method_resolver)
        return functools.wraps(method_resolver)(attribute_resolver)

    return wrapper
1345
+
1346
+
1347
+ # -----------------------------
1348
+
1349
+
1350
class Registry(object):
    """
    A registry of typing declarations.  The registry stores such declarations
    for functions, attributes and globals.
    """

    def __init__(self):
        # Plain lists, appended to by the register* methods below.
        self.functions = []
        self.attributes = []
        self.globals = []

    def register(self, item):
        """
        Append *item* to the function declarations and return it unchanged,
        so this method can be used as a decorator.
        """
        self.functions.append(item)
        return item

    def register_attr(self, item):
        """
        Append *item* to the attribute declarations and return it unchanged,
        so this method can be used as a decorator.
        """
        self.attributes.append(item)
        return item

    def register_global(self, val=None, typ=None, **kwargs):
        """
        Register the typing of a global value.
        Functional usage with a Numba type::
            register_global(value, typ)

        Decorator usage with a template class::
            @register_global(value, typing_key=None)
            class Template: ...

        In the functional form ``(val, typ)`` is recorded immediately and
        None is returned.  In the decorator form a decorator is returned;
        applying it to a template class records *val* together with a
        ``types.Function`` built from a subclass of that template whose
        ``key`` is *typing_key*, and returns the template class unchanged.
        *typing_key* is the only accepted keyword argument and defaults
        to *val* itself.

        Raises TypeError (when the decorator is applied) if *val* is not
        callable, and ValueError if *val* is used as its own key but is not
        found under ``val.__name__`` in the module named by
        ``val.__module__``.
        """
        if typ is not None:
            # register_global(val, typ)
            assert val is not None
            assert not kwargs
            self.globals.append((val, typ))
        else:

            def decorate(cls, typing_key):
                class Template(cls):
                    key = typing_key

                if callable(val):
                    typ = types.Function(Template)
                else:
                    # Interpolate the value; a 1-tuple keeps tuple values
                    # from being unpacked by the % operator.
                    raise TypeError(
                        "cannot infer type for global value %r" % (val,)
                    )
                self.globals.append((val, typ))
                return cls

            # register_global(val, typing_key=None)(<template class>)
            assert val is not None
            typing_key = kwargs.pop("typing_key", val)
            assert not kwargs
            if typing_key is val:
                # Check the value is globally reachable, as it is going
                # to be used as the key.
                mod = sys.modules[val.__module__]
                if getattr(mod, val.__name__) is not val:
                    # Report the offending value, not the module object.
                    raise ValueError(
                        "%r is not globally reachable as '%s.%s'"
                        % (val, val.__module__, val.__name__)
                    )

            def decorator(cls):
                return decorate(cls, typing_key)

            return decorator
1415
+
1416
+
1417
class BaseRegistryLoader(object):
    """
    Incrementally loads the contents of a registry: every call to
    new_registrations() iterates over the registrations not seen so far.

    Why this object exists:
    - several contexts can coexist
    - every context wants to install all registrations
    - registrations may be added after a context's first installation, so
      a context must be able to fetch only the "new" ones

    Each context therefore keeps its own loader per existing registry,
    while the registries themselves are never duplicated.

    ``registry_items`` is not defined here; it is read from the instance
    and is expected to list the registry attribute names to track.
    """

    def __init__(self, registry):
        # One stream per tracked registry attribute, keyed by its name.
        self._registrations = {
            name: utils.stream_list(getattr(registry, name))
            for name in self.registry_items
        }

    def new_registrations(self, name):
        # Pull the next batch from the stream for *name* and yield its items.
        yield from next(self._registrations[name])
1441
+
1442
+
1443
+ class RegistryLoader(BaseRegistryLoader):
1444
+ """
1445
+ An incremental loader for a typing registry.
1446
+ """
1447
+
1448
+ registry_items = ("functions", "attributes", "globals")