numba-cuda 0.22.0__cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (487)
  1. _numba_cuda_redirector.pth +4 -0
  2. _numba_cuda_redirector.py +89 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +6 -0
  5. numba_cuda/_version.py +11 -0
  6. numba_cuda/numba/cuda/__init__.py +70 -0
  7. numba_cuda/numba/cuda/_internal/cuda_bf16.py +16394 -0
  8. numba_cuda/numba/cuda/_internal/cuda_fp16.py +8112 -0
  9. numba_cuda/numba/cuda/api.py +580 -0
  10. numba_cuda/numba/cuda/api_util.py +76 -0
  11. numba_cuda/numba/cuda/args.py +72 -0
  12. numba_cuda/numba/cuda/bf16.py +397 -0
  13. numba_cuda/numba/cuda/cache_hints.py +287 -0
  14. numba_cuda/numba/cuda/cext/__init__.py +2 -0
  15. numba_cuda/numba/cuda/cext/_devicearray.cpp +159 -0
  16. numba_cuda/numba/cuda/cext/_devicearray.cpython-313-aarch64-linux-gnu.so +0 -0
  17. numba_cuda/numba/cuda/cext/_devicearray.h +29 -0
  18. numba_cuda/numba/cuda/cext/_dispatcher.cpp +1098 -0
  19. numba_cuda/numba/cuda/cext/_dispatcher.cpython-313-aarch64-linux-gnu.so +0 -0
  20. numba_cuda/numba/cuda/cext/_hashtable.cpp +532 -0
  21. numba_cuda/numba/cuda/cext/_hashtable.h +135 -0
  22. numba_cuda/numba/cuda/cext/_helperlib.c +71 -0
  23. numba_cuda/numba/cuda/cext/_helperlib.cpython-313-aarch64-linux-gnu.so +0 -0
  24. numba_cuda/numba/cuda/cext/_helpermod.c +82 -0
  25. numba_cuda/numba/cuda/cext/_pymodule.h +38 -0
  26. numba_cuda/numba/cuda/cext/_typeconv.cpp +206 -0
  27. numba_cuda/numba/cuda/cext/_typeconv.cpython-313-aarch64-linux-gnu.so +0 -0
  28. numba_cuda/numba/cuda/cext/_typeof.cpp +1159 -0
  29. numba_cuda/numba/cuda/cext/_typeof.h +19 -0
  30. numba_cuda/numba/cuda/cext/capsulethunk.h +111 -0
  31. numba_cuda/numba/cuda/cext/mviewbuf.c +385 -0
  32. numba_cuda/numba/cuda/cext/mviewbuf.cpython-313-aarch64-linux-gnu.so +0 -0
  33. numba_cuda/numba/cuda/cext/typeconv.cpp +212 -0
  34. numba_cuda/numba/cuda/cext/typeconv.hpp +101 -0
  35. numba_cuda/numba/cuda/cg.py +67 -0
  36. numba_cuda/numba/cuda/cgutils.py +1294 -0
  37. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  38. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  39. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  40. numba_cuda/numba/cuda/codegen.py +541 -0
  41. numba_cuda/numba/cuda/compiler.py +1396 -0
  42. numba_cuda/numba/cuda/core/analysis.py +758 -0
  43. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  44. numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +288 -0
  45. numba_cuda/numba/cuda/core/annotations/type_annotations.py +305 -0
  46. numba_cuda/numba/cuda/core/base.py +1332 -0
  47. numba_cuda/numba/cuda/core/boxing.py +1411 -0
  48. numba_cuda/numba/cuda/core/bytecode.py +728 -0
  49. numba_cuda/numba/cuda/core/byteflow.py +2346 -0
  50. numba_cuda/numba/cuda/core/caching.py +744 -0
  51. numba_cuda/numba/cuda/core/callconv.py +392 -0
  52. numba_cuda/numba/cuda/core/codegen.py +171 -0
  53. numba_cuda/numba/cuda/core/compiler.py +199 -0
  54. numba_cuda/numba/cuda/core/compiler_lock.py +85 -0
  55. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  56. numba_cuda/numba/cuda/core/config.py +650 -0
  57. numba_cuda/numba/cuda/core/consts.py +124 -0
  58. numba_cuda/numba/cuda/core/controlflow.py +989 -0
  59. numba_cuda/numba/cuda/core/entrypoints.py +57 -0
  60. numba_cuda/numba/cuda/core/environment.py +66 -0
  61. numba_cuda/numba/cuda/core/errors.py +917 -0
  62. numba_cuda/numba/cuda/core/event.py +511 -0
  63. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  64. numba_cuda/numba/cuda/core/generators.py +387 -0
  65. numba_cuda/numba/cuda/core/imputils.py +509 -0
  66. numba_cuda/numba/cuda/core/inline_closurecall.py +1787 -0
  67. numba_cuda/numba/cuda/core/interpreter.py +3617 -0
  68. numba_cuda/numba/cuda/core/ir.py +1812 -0
  69. numba_cuda/numba/cuda/core/ir_utils.py +2638 -0
  70. numba_cuda/numba/cuda/core/optional.py +129 -0
  71. numba_cuda/numba/cuda/core/options.py +262 -0
  72. numba_cuda/numba/cuda/core/postproc.py +249 -0
  73. numba_cuda/numba/cuda/core/pythonapi.py +1859 -0
  74. numba_cuda/numba/cuda/core/registry.py +46 -0
  75. numba_cuda/numba/cuda/core/removerefctpass.py +123 -0
  76. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  77. numba_cuda/numba/cuda/core/rewrites/ir_print.py +91 -0
  78. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  79. numba_cuda/numba/cuda/core/rewrites/static_binop.py +41 -0
  80. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +189 -0
  81. numba_cuda/numba/cuda/core/rewrites/static_raise.py +100 -0
  82. numba_cuda/numba/cuda/core/sigutils.py +68 -0
  83. numba_cuda/numba/cuda/core/ssa.py +498 -0
  84. numba_cuda/numba/cuda/core/targetconfig.py +330 -0
  85. numba_cuda/numba/cuda/core/tracing.py +231 -0
  86. numba_cuda/numba/cuda/core/transforms.py +956 -0
  87. numba_cuda/numba/cuda/core/typed_passes.py +867 -0
  88. numba_cuda/numba/cuda/core/typeinfer.py +1950 -0
  89. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  90. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  91. numba_cuda/numba/cuda/core/unsafe/eh.py +67 -0
  92. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  93. numba_cuda/numba/cuda/core/untyped_passes.py +1979 -0
  94. numba_cuda/numba/cuda/cpython/builtins.py +1153 -0
  95. numba_cuda/numba/cuda/cpython/charseq.py +1218 -0
  96. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  97. numba_cuda/numba/cuda/cpython/enumimpl.py +103 -0
  98. numba_cuda/numba/cuda/cpython/iterators.py +167 -0
  99. numba_cuda/numba/cuda/cpython/listobj.py +1326 -0
  100. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  101. numba_cuda/numba/cuda/cpython/numbers.py +1475 -0
  102. numba_cuda/numba/cuda/cpython/rangeobj.py +289 -0
  103. numba_cuda/numba/cuda/cpython/slicing.py +322 -0
  104. numba_cuda/numba/cuda/cpython/tupleobj.py +456 -0
  105. numba_cuda/numba/cuda/cpython/unicode.py +2865 -0
  106. numba_cuda/numba/cuda/cpython/unicode_support.py +1597 -0
  107. numba_cuda/numba/cuda/cpython/unsafe/__init__.py +0 -0
  108. numba_cuda/numba/cuda/cpython/unsafe/numbers.py +64 -0
  109. numba_cuda/numba/cuda/cpython/unsafe/tuple.py +92 -0
  110. numba_cuda/numba/cuda/cuda_paths.py +691 -0
  111. numba_cuda/numba/cuda/cudadecl.py +543 -0
  112. numba_cuda/numba/cuda/cudadrv/__init__.py +14 -0
  113. numba_cuda/numba/cuda/cudadrv/devicearray.py +954 -0
  114. numba_cuda/numba/cuda/cudadrv/devices.py +249 -0
  115. numba_cuda/numba/cuda/cudadrv/driver.py +3238 -0
  116. numba_cuda/numba/cuda/cudadrv/drvapi.py +435 -0
  117. numba_cuda/numba/cuda/cudadrv/dummyarray.py +562 -0
  118. numba_cuda/numba/cuda/cudadrv/enums.py +613 -0
  119. numba_cuda/numba/cuda/cudadrv/error.py +48 -0
  120. numba_cuda/numba/cuda/cudadrv/libs.py +220 -0
  121. numba_cuda/numba/cuda/cudadrv/linkable_code.py +184 -0
  122. numba_cuda/numba/cuda/cudadrv/mappings.py +14 -0
  123. numba_cuda/numba/cuda/cudadrv/ndarray.py +26 -0
  124. numba_cuda/numba/cuda/cudadrv/nvrtc.py +193 -0
  125. numba_cuda/numba/cuda/cudadrv/nvvm.py +756 -0
  126. numba_cuda/numba/cuda/cudadrv/rtapi.py +13 -0
  127. numba_cuda/numba/cuda/cudadrv/runtime.py +34 -0
  128. numba_cuda/numba/cuda/cudaimpl.py +983 -0
  129. numba_cuda/numba/cuda/cudamath.py +149 -0
  130. numba_cuda/numba/cuda/datamodel/__init__.py +7 -0
  131. numba_cuda/numba/cuda/datamodel/cuda_manager.py +66 -0
  132. numba_cuda/numba/cuda/datamodel/cuda_models.py +1446 -0
  133. numba_cuda/numba/cuda/datamodel/cuda_packer.py +224 -0
  134. numba_cuda/numba/cuda/datamodel/cuda_registry.py +22 -0
  135. numba_cuda/numba/cuda/datamodel/cuda_testing.py +153 -0
  136. numba_cuda/numba/cuda/datamodel/manager.py +11 -0
  137. numba_cuda/numba/cuda/datamodel/models.py +9 -0
  138. numba_cuda/numba/cuda/datamodel/packer.py +9 -0
  139. numba_cuda/numba/cuda/datamodel/registry.py +11 -0
  140. numba_cuda/numba/cuda/datamodel/testing.py +11 -0
  141. numba_cuda/numba/cuda/debuginfo.py +997 -0
  142. numba_cuda/numba/cuda/decorators.py +294 -0
  143. numba_cuda/numba/cuda/descriptor.py +35 -0
  144. numba_cuda/numba/cuda/device_init.py +155 -0
  145. numba_cuda/numba/cuda/deviceufunc.py +1021 -0
  146. numba_cuda/numba/cuda/dispatcher.py +2463 -0
  147. numba_cuda/numba/cuda/errors.py +72 -0
  148. numba_cuda/numba/cuda/extending.py +697 -0
  149. numba_cuda/numba/cuda/flags.py +178 -0
  150. numba_cuda/numba/cuda/fp16.py +357 -0
  151. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  152. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  153. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  154. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  155. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  156. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  157. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  158. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  159. numba_cuda/numba/cuda/initialize.py +24 -0
  160. numba_cuda/numba/cuda/intrinsics.py +531 -0
  161. numba_cuda/numba/cuda/itanium_mangler.py +214 -0
  162. numba_cuda/numba/cuda/kernels/__init__.py +2 -0
  163. numba_cuda/numba/cuda/kernels/reduction.py +265 -0
  164. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  165. numba_cuda/numba/cuda/libdevice.py +3386 -0
  166. numba_cuda/numba/cuda/libdevicedecl.py +20 -0
  167. numba_cuda/numba/cuda/libdevicefuncs.py +1060 -0
  168. numba_cuda/numba/cuda/libdeviceimpl.py +88 -0
  169. numba_cuda/numba/cuda/locks.py +19 -0
  170. numba_cuda/numba/cuda/lowering.py +1980 -0
  171. numba_cuda/numba/cuda/mathimpl.py +374 -0
  172. numba_cuda/numba/cuda/memory_management/__init__.py +4 -0
  173. numba_cuda/numba/cuda/memory_management/memsys.cu +99 -0
  174. numba_cuda/numba/cuda/memory_management/memsys.cuh +22 -0
  175. numba_cuda/numba/cuda/memory_management/nrt.cu +212 -0
  176. numba_cuda/numba/cuda/memory_management/nrt.cuh +48 -0
  177. numba_cuda/numba/cuda/memory_management/nrt.py +390 -0
  178. numba_cuda/numba/cuda/memory_management/nrt_context.py +438 -0
  179. numba_cuda/numba/cuda/misc/appdirs.py +594 -0
  180. numba_cuda/numba/cuda/misc/cffiimpl.py +24 -0
  181. numba_cuda/numba/cuda/misc/coverage_support.py +43 -0
  182. numba_cuda/numba/cuda/misc/dump_style.py +41 -0
  183. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  184. numba_cuda/numba/cuda/misc/firstlinefinder.py +96 -0
  185. numba_cuda/numba/cuda/misc/gdb_hook.py +240 -0
  186. numba_cuda/numba/cuda/misc/literal.py +28 -0
  187. numba_cuda/numba/cuda/misc/llvm_pass_timings.py +412 -0
  188. numba_cuda/numba/cuda/misc/special.py +94 -0
  189. numba_cuda/numba/cuda/models.py +56 -0
  190. numba_cuda/numba/cuda/np/arraymath.py +5130 -0
  191. numba_cuda/numba/cuda/np/arrayobj.py +7635 -0
  192. numba_cuda/numba/cuda/np/extensions.py +11 -0
  193. numba_cuda/numba/cuda/np/linalg.py +3087 -0
  194. numba_cuda/numba/cuda/np/math/__init__.py +0 -0
  195. numba_cuda/numba/cuda/np/math/cmathimpl.py +558 -0
  196. numba_cuda/numba/cuda/np/math/mathimpl.py +487 -0
  197. numba_cuda/numba/cuda/np/math/numbers.py +1461 -0
  198. numba_cuda/numba/cuda/np/npdatetime.py +969 -0
  199. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  200. numba_cuda/numba/cuda/np/npyfuncs.py +1808 -0
  201. numba_cuda/numba/cuda/np/npyimpl.py +1027 -0
  202. numba_cuda/numba/cuda/np/numpy_support.py +798 -0
  203. numba_cuda/numba/cuda/np/polynomial/__init__.py +4 -0
  204. numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +242 -0
  205. numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +380 -0
  206. numba_cuda/numba/cuda/np/ufunc/__init__.py +4 -0
  207. numba_cuda/numba/cuda/np/ufunc/decorators.py +203 -0
  208. numba_cuda/numba/cuda/np/ufunc/sigparse.py +68 -0
  209. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +65 -0
  210. numba_cuda/numba/cuda/np/ufunc_db.py +1282 -0
  211. numba_cuda/numba/cuda/np/unsafe/__init__.py +0 -0
  212. numba_cuda/numba/cuda/np/unsafe/ndarray.py +84 -0
  213. numba_cuda/numba/cuda/nvvmutils.py +254 -0
  214. numba_cuda/numba/cuda/printimpl.py +126 -0
  215. numba_cuda/numba/cuda/random.py +308 -0
  216. numba_cuda/numba/cuda/reshape_funcs.cu +156 -0
  217. numba_cuda/numba/cuda/serialize.py +267 -0
  218. numba_cuda/numba/cuda/simulator/__init__.py +63 -0
  219. numba_cuda/numba/cuda/simulator/_internal/__init__.py +4 -0
  220. numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
  221. numba_cuda/numba/cuda/simulator/api.py +179 -0
  222. numba_cuda/numba/cuda/simulator/bf16.py +4 -0
  223. numba_cuda/numba/cuda/simulator/compiler.py +38 -0
  224. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +11 -0
  225. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +462 -0
  226. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +122 -0
  227. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +66 -0
  228. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +7 -0
  229. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +7 -0
  230. numba_cuda/numba/cuda/simulator/cudadrv/error.py +10 -0
  231. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +10 -0
  232. numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +61 -0
  233. numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +11 -0
  234. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +32 -0
  235. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +22 -0
  236. numba_cuda/numba/cuda/simulator/dispatcher.py +11 -0
  237. numba_cuda/numba/cuda/simulator/kernel.py +320 -0
  238. numba_cuda/numba/cuda/simulator/kernelapi.py +509 -0
  239. numba_cuda/numba/cuda/simulator/memory_management/__init__.py +4 -0
  240. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +21 -0
  241. numba_cuda/numba/cuda/simulator/reduction.py +19 -0
  242. numba_cuda/numba/cuda/simulator/tests/support.py +4 -0
  243. numba_cuda/numba/cuda/simulator/vector_types.py +65 -0
  244. numba_cuda/numba/cuda/simulator_init.py +18 -0
  245. numba_cuda/numba/cuda/stubs.py +624 -0
  246. numba_cuda/numba/cuda/target.py +505 -0
  247. numba_cuda/numba/cuda/testing.py +347 -0
  248. numba_cuda/numba/cuda/tests/__init__.py +62 -0
  249. numba_cuda/numba/cuda/tests/benchmarks/__init__.py +0 -0
  250. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +119 -0
  251. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  252. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +113 -0
  253. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
  254. numba_cuda/numba/cuda/tests/core/test_serialize.py +371 -0
  255. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +9 -0
  256. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +147 -0
  257. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +161 -0
  258. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +397 -0
  259. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +24 -0
  260. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +180 -0
  261. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +313 -0
  262. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +191 -0
  263. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +621 -0
  264. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +247 -0
  265. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +100 -0
  266. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +200 -0
  267. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +53 -0
  268. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +72 -0
  269. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +138 -0
  270. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +43 -0
  271. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +15 -0
  272. numba_cuda/numba/cuda/tests/cudadrv/test_linkable_code.py +58 -0
  273. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +348 -0
  274. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +128 -0
  275. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +301 -0
  276. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +174 -0
  277. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +28 -0
  278. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +185 -0
  279. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +39 -0
  280. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +23 -0
  281. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +38 -0
  282. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +48 -0
  283. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +44 -0
  284. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +127 -0
  285. numba_cuda/numba/cuda/tests/cudapy/__init__.py +9 -0
  286. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +231 -0
  287. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +50 -0
  288. numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +36 -0
  289. numba_cuda/numba/cuda/tests/cudapy/complex_usecases.py +116 -0
  290. numba_cuda/numba/cuda/tests/cudapy/enum_usecases.py +59 -0
  291. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +62 -0
  292. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +28 -0
  293. numba_cuda/numba/cuda/tests/cudapy/overload_usecases.py +33 -0
  294. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +104 -0
  295. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +47 -0
  296. numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +1122 -0
  297. numba_cuda/numba/cuda/tests/cudapy/test_array.py +344 -0
  298. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +268 -0
  299. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +203 -0
  300. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +63 -0
  301. numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +360 -0
  302. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1815 -0
  303. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +599 -0
  304. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +377 -0
  305. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +160 -0
  306. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +27 -0
  307. numba_cuda/numba/cuda/tests/cudapy/test_byteflow.py +98 -0
  308. numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py +210 -0
  309. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +683 -0
  310. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +265 -0
  311. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +42 -0
  312. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +718 -0
  313. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +370 -0
  314. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +23 -0
  315. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +142 -0
  316. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +178 -0
  317. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +193 -0
  318. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +131 -0
  319. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +438 -0
  320. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +94 -0
  321. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +101 -0
  322. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +105 -0
  323. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +978 -0
  324. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
  325. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +500 -0
  326. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +820 -0
  327. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +152 -0
  328. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +111 -0
  329. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +170 -0
  330. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1088 -0
  331. numba_cuda/numba/cuda/tests/cudapy/test_extending_types.py +71 -0
  332. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +265 -0
  333. numba_cuda/numba/cuda/tests/cudapy/test_flow_control.py +1433 -0
  334. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +57 -0
  335. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +34 -0
  336. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +69 -0
  337. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +62 -0
  338. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +474 -0
  339. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +167 -0
  340. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +92 -0
  341. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +39 -0
  342. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +170 -0
  343. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +255 -0
  344. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1219 -0
  345. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +263 -0
  346. numba_cuda/numba/cuda/tests/cudapy/test_ir.py +598 -0
  347. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +276 -0
  348. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +101 -0
  349. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +68 -0
  350. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +123 -0
  351. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +194 -0
  352. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +220 -0
  353. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +173 -0
  354. numba_cuda/numba/cuda/tests/cudapy/test_make_function_to_jit_function.py +364 -0
  355. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +47 -0
  356. numba_cuda/numba/cuda/tests/cudapy/test_math.py +842 -0
  357. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +76 -0
  358. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +78 -0
  359. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +25 -0
  360. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +145 -0
  361. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +39 -0
  362. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +82 -0
  363. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +53 -0
  364. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +504 -0
  365. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +93 -0
  366. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +402 -0
  367. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +128 -0
  368. numba_cuda/numba/cuda/tests/cudapy/test_print.py +193 -0
  369. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +37 -0
  370. numba_cuda/numba/cuda/tests/cudapy/test_random.py +117 -0
  371. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +614 -0
  372. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +130 -0
  373. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +94 -0
  374. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  375. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +86 -0
  376. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +40 -0
  377. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +457 -0
  378. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +233 -0
  379. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +454 -0
  380. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +56 -0
  381. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +277 -0
  382. numba_cuda/numba/cuda/tests/cudapy/test_tracing.py +200 -0
  383. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +90 -0
  384. numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +333 -0
  385. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  386. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +585 -0
  387. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +42 -0
  388. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +485 -0
  389. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +312 -0
  390. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +23 -0
  391. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +183 -0
  392. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +40 -0
  393. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +40 -0
  394. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +206 -0
  395. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +446 -0
  396. numba_cuda/numba/cuda/tests/cudasim/__init__.py +9 -0
  397. numba_cuda/numba/cuda/tests/cudasim/support.py +9 -0
  398. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +111 -0
  399. numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
  400. numba_cuda/numba/cuda/tests/data/cta_barrier.cu +28 -0
  401. numba_cuda/numba/cuda/tests/data/cuda_include.cu +10 -0
  402. numba_cuda/numba/cuda/tests/data/error.cu +12 -0
  403. numba_cuda/numba/cuda/tests/data/include/add.cuh +8 -0
  404. numba_cuda/numba/cuda/tests/data/jitlink.cu +28 -0
  405. numba_cuda/numba/cuda/tests/data/jitlink.ptx +49 -0
  406. numba_cuda/numba/cuda/tests/data/warn.cu +12 -0
  407. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +9 -0
  408. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
  409. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +54 -0
  410. numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +8 -0
  411. numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +14 -0
  412. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +86 -0
  413. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +68 -0
  414. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +81 -0
  415. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +141 -0
  416. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +160 -0
  417. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +180 -0
  418. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +119 -0
  419. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +66 -0
  420. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +80 -0
  421. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +206 -0
  422. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +53 -0
  423. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +76 -0
  424. numba_cuda/numba/cuda/tests/nocuda/__init__.py +9 -0
  425. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +452 -0
  426. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +48 -0
  427. numba_cuda/numba/cuda/tests/nocuda/test_import.py +63 -0
  428. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +252 -0
  429. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +59 -0
  430. numba_cuda/numba/cuda/tests/nrt/__init__.py +9 -0
  431. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +387 -0
  432. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +124 -0
  433. numba_cuda/numba/cuda/tests/support.py +900 -0
  434. numba_cuda/numba/cuda/typeconv/__init__.py +4 -0
  435. numba_cuda/numba/cuda/typeconv/castgraph.py +137 -0
  436. numba_cuda/numba/cuda/typeconv/rules.py +63 -0
  437. numba_cuda/numba/cuda/typeconv/typeconv.py +121 -0
  438. numba_cuda/numba/cuda/types/__init__.py +233 -0
  439. numba_cuda/numba/cuda/types/__init__.pyi +167 -0
  440. numba_cuda/numba/cuda/types/abstract.py +9 -0
  441. numba_cuda/numba/cuda/types/common.py +9 -0
  442. numba_cuda/numba/cuda/types/containers.py +9 -0
  443. numba_cuda/numba/cuda/types/cuda_abstract.py +533 -0
  444. numba_cuda/numba/cuda/types/cuda_common.py +110 -0
  445. numba_cuda/numba/cuda/types/cuda_containers.py +971 -0
  446. numba_cuda/numba/cuda/types/cuda_function_type.py +230 -0
  447. numba_cuda/numba/cuda/types/cuda_functions.py +798 -0
  448. numba_cuda/numba/cuda/types/cuda_iterators.py +120 -0
  449. numba_cuda/numba/cuda/types/cuda_misc.py +569 -0
  450. numba_cuda/numba/cuda/types/cuda_npytypes.py +690 -0
  451. numba_cuda/numba/cuda/types/cuda_scalars.py +280 -0
  452. numba_cuda/numba/cuda/types/ext_types.py +101 -0
  453. numba_cuda/numba/cuda/types/function_type.py +11 -0
  454. numba_cuda/numba/cuda/types/functions.py +9 -0
  455. numba_cuda/numba/cuda/types/iterators.py +9 -0
  456. numba_cuda/numba/cuda/types/misc.py +9 -0
  457. numba_cuda/numba/cuda/types/npytypes.py +9 -0
  458. numba_cuda/numba/cuda/types/scalars.py +9 -0
  459. numba_cuda/numba/cuda/typing/__init__.py +19 -0
  460. numba_cuda/numba/cuda/typing/arraydecl.py +939 -0
  461. numba_cuda/numba/cuda/typing/asnumbatype.py +130 -0
  462. numba_cuda/numba/cuda/typing/bufproto.py +70 -0
  463. numba_cuda/numba/cuda/typing/builtins.py +1209 -0
  464. numba_cuda/numba/cuda/typing/cffi_utils.py +219 -0
  465. numba_cuda/numba/cuda/typing/cmathdecl.py +47 -0
  466. numba_cuda/numba/cuda/typing/collections.py +138 -0
  467. numba_cuda/numba/cuda/typing/context.py +782 -0
  468. numba_cuda/numba/cuda/typing/ctypes_utils.py +125 -0
  469. numba_cuda/numba/cuda/typing/dictdecl.py +63 -0
  470. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  471. numba_cuda/numba/cuda/typing/listdecl.py +147 -0
  472. numba_cuda/numba/cuda/typing/mathdecl.py +158 -0
  473. numba_cuda/numba/cuda/typing/npdatetime.py +322 -0
  474. numba_cuda/numba/cuda/typing/npydecl.py +749 -0
  475. numba_cuda/numba/cuda/typing/setdecl.py +115 -0
  476. numba_cuda/numba/cuda/typing/templates.py +1446 -0
  477. numba_cuda/numba/cuda/typing/typeof.py +301 -0
  478. numba_cuda/numba/cuda/ufuncs.py +746 -0
  479. numba_cuda/numba/cuda/utils.py +724 -0
  480. numba_cuda/numba/cuda/vector_types.py +214 -0
  481. numba_cuda/numba/cuda/vectorizers.py +260 -0
  482. numba_cuda-0.22.0.dist-info/METADATA +109 -0
  483. numba_cuda-0.22.0.dist-info/RECORD +487 -0
  484. numba_cuda-0.22.0.dist-info/WHEEL +6 -0
  485. numba_cuda-0.22.0.dist-info/licenses/LICENSE +26 -0
  486. numba_cuda-0.22.0.dist-info/licenses/LICENSE.numba +24 -0
  487. numba_cuda-0.22.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1021 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ """
5
+ Implements custom ufunc dispatch mechanism for non-CPU devices.
6
+ """
7
+
8
+ from abc import ABCMeta, abstractmethod
9
+ from collections import OrderedDict
10
+ import operator
11
+ import warnings
12
+ from functools import reduce
13
+ import tokenize
14
+ import string
15
+
16
+ import numpy as np
17
+
18
+ from numba.cuda.np.ufunc.ufuncbuilder import _BaseUFuncBuilder, parse_identity
19
+ from numba.cuda import types
20
+ from numba.cuda.typing import signature
21
+ from numba.cuda.core import sigutils
22
+
23
+
24
def parse_signature(sig):
    """Parse a generalized ufunc signature string.

    The signature has the form ``"(m,n),(n,p)->(m,p)"``. All whitespace is
    removed before parsing.

    NOTE: ',' (COMMA) is a delimiter; not separator.
          This means trailing comma is legal.

    Returns
    -------
    (inputs, outputs) : two lists of tuples of symbol names, one tuple per
    parenthesised group on each side of ``->``.

    Raises
    ------
    ValueError if an unexpected token is found.
    NameError if an output symbol does not appear in any input group.
    """
    bad_token = 'bad token in signature "%s"'

    def _tokens(text):
        # Hand the whole text to the tokenizer as its only line; once the
        # iterator is exhausted the tokenizer sees the end of input.
        lines = iter((text,))
        return tokenize.generate_tokens(lambda: next(lines))

    def _groups(text):
        stream = _tokens(text)
        while True:
            tok = next(stream)
            if tokenize.ISEOF(tok.type):
                return
            if tok.string != "(":
                raise ValueError(bad_token % tok.string)

            # Collect the names up to the closing parenthesis.
            names = []
            while True:
                tok = next(stream)
                if tok.string == ")":
                    break
                if tok.type == tokenize.NAME:
                    names.append(tok.string)
                elif tok.string != ",":
                    raise ValueError(bad_token % tok.string)
            yield tuple(names)

            # The token following a group is consumed whatever it is; only
            # the end of input stops the scan at this point.
            tok = next(stream)
            if tokenize.ISEOF(tok.type):
                return

    compact = sig.translate(str.maketrans("", "", string.whitespace))
    lhs, _, rhs = compact.partition("->")
    inputs = list(_groups(lhs))
    outputs = list(_groups(rhs))

    # Every symbol used by an output must be introduced by some input.
    defined = set().union(*inputs)
    produced = set().union(*outputs)
    missing = produced - defined
    if missing:
        raise NameError(
            "undefined output symbols: %s" % ",".join(sorted(missing))
        )

    return inputs, outputs
85
+
86
+
87
+ def _broadcast_axis(a, b):
88
+ """
89
+ Raises
90
+ ------
91
+ ValueError if broadcast fails
92
+ """
93
+ if a == b:
94
+ return a
95
+ elif a == 1:
96
+ return b
97
+ elif b == 1:
98
+ return a
99
+ else:
100
+ raise ValueError("failed to broadcast {0} and {1}".format(a, b))
101
+
102
+
103
def _pairwise_broadcast(shape1, shape2):
    """
    Broadcast two shapes against each other and return the resulting tuple.

    The shorter shape is padded on the left with 1s before the axes are
    combined pairwise.

    Raises
    ------
    ValueError if broadcast fails
    """
    left = tuple(shape1)
    right = tuple(shape2)

    rank = max(len(left), len(right))
    left = (1,) * (rank - len(left)) + left
    right = (1,) * (rank - len(right)) + right

    return tuple(_broadcast_axis(a, b) for a, b in zip(left, right))
118
+
119
+
120
def _multi_broadcast(*shapelist):
    """
    Broadcast any number of shapes together, left to right.

    With a single shape, that shape is returned as given.

    Raises
    ------
    ValueError if broadcast fails; the message names the 1-based position
    (counting after the first shape) of the shape that could not be combined.
    """
    assert shapelist

    combined, *remaining = shapelist
    for position, shape in enumerate(remaining, start=1):
        try:
            combined = _pairwise_broadcast(combined, shape)
        except ValueError:
            raise ValueError(
                "failed to broadcast argument #{0}".format(position)
            )
    return combined
137
+
138
+
139
class UFuncMechanism(object):
    """
    Prepare ufunc arguments for vectorize.

    The class resolves which compiled version in ``typemap`` matches the call
    arguments, broadcasts the arguments to a common shape and (in ``call``)
    moves them to the device, launches the function and returns the result.

    Subclasses supply the device-specific pieces: ``is_device_array``,
    ``as_device_array``, ``broadcast_device``, ``force_array_layout``,
    ``to_device``, ``to_host``, ``allocate_device_array`` and ``launch``.
    """

    # Stream used by ``call`` when the caller passes no ``stream`` keyword.
    DEFAULT_STREAM = None
    # When true, ``call`` refuses to ravel multi-dimensional arguments
    # (``attempt_ravel`` raises NotImplementedError).
    SUPPORT_DEVICE_SLICING = False

    def __init__(self, typemap, args):
        """Never used directly by user. Invoke by UFuncMechanism.call()."""
        self.typemap = typemap
        self.args = args
        nargs = len(self.args)
        self.argtypes = [None] * nargs
        self.scalarpos = []
        self.signature = None
        self.arrays = [None] * nargs

    def _fill_arrays(self):
        """
        Get all arguments in array form

        Device arrays are converted with ``as_device_array``; Python and
        NumPy numeric scalars are only recorded in ``scalarpos`` (their slot
        in ``arrays`` stays None); anything else goes through ``np.asarray``.
        """
        for i, arg in enumerate(self.args):
            if self.is_device_array(arg):
                self.arrays[i] = self.as_device_array(arg)
            elif isinstance(arg, (int, float, complex, np.number)):
                # Is scalar
                self.scalarpos.append(i)
            else:
                self.arrays[i] = np.asarray(arg)

    def _fill_argtypes(self):
        """
        Get dtypes

        Scalar positions have no array yet, so their entry stays None.
        """
        for i, ary in enumerate(self.arrays):
            if ary is not None:
                # Default to None so the fallback below is reachable for
                # array-likes that expose no ``dtype`` attribute.
                dtype = getattr(ary, "dtype", None)
                if dtype is None:
                    dtype = np.asarray(ary).dtype
                self.argtypes[i] = dtype

    def _resolve_signature(self):
        """Resolve signature.
        May have ambiguous case.

        On success ``self.argtypes`` is replaced by the matching key of
        ``typemap``.

        Raises
        ------
        TypeError if no version matches or more than one does.
        """
        matches = []
        # Resolve scalar args exact match first
        if self.scalarpos:
            # Try resolve scalar arguments
            for formaltys in self.typemap:
                match_map = []
                for i, (formal, actual) in enumerate(
                    zip(formaltys, self.argtypes)
                ):
                    if actual is None:
                        # Untyped scalar: use the dtype NumPy infers for it.
                        actual = np.asarray(self.args[i]).dtype

                    match_map.append(actual == formal)

                if all(match_map):
                    matches.append(formaltys)

        # No matching with exact match; try coercing the scalar arguments
        if not matches:
            for formaltys in self.typemap:
                all_matches = all(
                    actual is None or formal == actual
                    for formal, actual in zip(formaltys, self.argtypes)
                )
                if all_matches:
                    matches.append(formaltys)

        if not matches:
            raise TypeError(
                "No matching version.  GPU ufunc requires array "
                "arguments to have the exact types.  This behaves "
                "like regular ufunc with casting='no'."
            )

        if len(matches) > 1:
            raise TypeError(
                "Failed to resolve ufunc due to ambiguous "
                "signature. Too many untyped scalars. "
                "Use numpy dtype object to type tag."
            )

        # Try scalar arguments
        self.argtypes = matches[0]

    def _get_actual_args(self):
        """Return the actual arguments
        Casts scalar arguments to np.array.

        Each scalar becomes a one-element array of its resolved dtype.
        """
        for i in self.scalarpos:
            self.arrays[i] = np.array([self.args[i]], dtype=self.argtypes[i])

        return self.arrays

    def _broadcast(self, arys):
        """Perform numpy ufunc broadcasting

        ``arys`` is updated in place and returned. Device arrays are
        delegated to ``broadcast_device``; host arrays are expanded with a
        strided view and then passed through ``force_array_layout``.
        """
        shapelist = [a.shape for a in arys]
        shape = _multi_broadcast(*shapelist)

        for i, ary in enumerate(arys):
            if ary.shape == shape:
                continue

            if self.is_device_array(ary):
                arys[i] = self.broadcast_device(ary, shape)
            else:
                strided = self._broadcast_host_array(ary, shape)
                arys[i] = self.force_array_layout(strided)

        return arys

    def _broadcast_host_array(self, ary, shape):
        """Return a strided view of host array ``ary`` expanded to ``shape``.

        Missing leading dimensions are added, and every axis whose length
        differs from the target gets stride 0. The comparison is made on the
        left-padded shape so that it lines up with the left-padded strides;
        comparing the unpadded shape would zero the wrong axes for arrays of
        lower rank than ``shape``.
        """
        missingdim = len(shape) - ary.ndim
        padded_shape = (1,) * missingdim + tuple(ary.shape)
        strides = [0] * missingdim + list(ary.strides)

        for ax, (have, want) in enumerate(zip(padded_shape, shape)):
            if have != want:
                strides[ax] = 0

        return np.lib.stride_tricks.as_strided(
            ary, shape=shape, strides=strides
        )

    def get_arguments(self):
        """Prepare and return the arguments for the ufunc.
        Does not call to_device().
        """
        self._fill_arrays()
        self._fill_argtypes()
        self._resolve_signature()
        arys = self._get_actual_args()
        return self._broadcast(arys)

    def get_function(self):
        """Returns (result_dtype, function)"""
        return self.typemap[self.argtypes]

    def is_device_array(self, obj):
        """Is the `obj` a device array?
        Override in subclass
        """
        return False

    def as_device_array(self, obj):
        """Convert the `obj` to a device array
        Override in subclass

        Default implementation is an identity function
        """
        return obj

    def broadcast_device(self, ary, shape):
        """Handles ondevice broadcasting

        Override in subclass to add support.
        """
        raise NotImplementedError("broadcasting on device is not supported")

    def force_array_layout(self, ary):
        """Ensures array layout met device requirement.

        Override in subclass
        """
        return ary

    @classmethod
    def call(cls, typemap, args, kws):
        """Perform the entire ufunc call mechanism.

        ``kws`` may contain ``stream`` and ``out``; both are popped from the
        dict, and any remaining key only triggers a warning.

        The result stays on the device when any argument is a device array
        or when ``out`` is a device array; otherwise it is copied back to the
        host. The returned array has the broadcast shape of the arguments.
        """
        # Handle keywords
        stream = kws.pop("stream", cls.DEFAULT_STREAM)
        out = kws.pop("out", None)

        if kws:
            warnings.warn("unrecognized keywords: %s" % ", ".join(kws))

        # Begin call resolution
        cr = cls(typemap, args)
        args = cr.get_arguments()
        resty, func = cr.get_function()

        # Shape of the result as seen by the caller (before flattening).
        outshape = args[0].shape

        # Adjust output value
        if out is not None and cr.is_device_array(out):
            out = cr.as_device_array(out)

        def attempt_ravel(a):
            if cr.SUPPORT_DEVICE_SLICING:
                raise NotImplementedError

            try:
                # Call the `.ravel()` method
                return a.ravel()
            except NotImplementedError:
                # If it is not a device array
                if not cr.is_device_array(a):
                    raise
                # For device array, retry ravel on the host by first
                # copying it back.
                else:
                    hostary = cr.to_host(a, stream).ravel()
                    return cr.to_device(hostary, stream)

        # The function is launched over a flat index range, so
        # multi-dimensional arguments are flattened first.
        if args[0].ndim > 1:
            args = [attempt_ravel(a) for a in args]

        # Prepare argument on the device
        devarys = []
        any_device = False
        for a in args:
            if cr.is_device_array(a):
                devarys.append(a)
                any_device = True
            else:
                dev_a = cr.to_device(a, stream=stream)
                devarys.append(dev_a)

        # Launch
        shape = args[0].shape
        if out is None:
            # No output is provided
            devout = cr.allocate_device_array(shape, resty, stream=stream)

            devarys.extend([devout])
            cr.launch(func, shape[0], stream, devarys)

            if any_device:
                # If any of the arguments are on device,
                # Keep output on the device
                return devout.reshape(outshape)
            else:
                # Otherwise, transfer output back to host
                return devout.copy_to_host().reshape(outshape)

        elif cr.is_device_array(out):
            # If output is provided and it is a device array,
            # Return device array
            if out.ndim > 1:
                out = attempt_ravel(out)
            devout = out
            devarys.extend([devout])
            cr.launch(func, shape[0], stream, devarys)
            return devout.reshape(outshape)

        else:
            # If output is provided and it is a host array,
            # Return host array
            # The caller's array is compared with the unflattened result
            # shape; `shape` is already flat when the inputs were
            # multi-dimensional and would reject every such output.
            assert out.shape == outshape
            assert out.dtype == resty
            devout = cr.allocate_device_array(shape, resty, stream=stream)
            devarys.extend([devout])
            cr.launch(func, shape[0], stream, devarys)
            # Restore the caller-visible shape on the device side before
            # copying so source and destination shapes agree.
            shaped = devout.reshape(outshape)
            return shaped.copy_to_host(out, stream=stream).reshape(outshape)

    def to_device(self, hostary, stream):
        """Implement to device transfer
        Override in subclass
        """
        raise NotImplementedError

    def to_host(self, devary, stream):
        """Implement to host transfer
        Override in subclass
        """
        raise NotImplementedError

    def allocate_device_array(self, shape, dtype, stream):
        """Implements device allocation
        Override in subclass
        """
        raise NotImplementedError

    def launch(self, func, count, stream, args):
        """Implements device function invocation
        Override in subclass
        """
        raise NotImplementedError
427
+
428
+
429
def to_dtype(ty):
    """Return the NumPy dtype named by ``str(ty)``.

    For an ``EnumMember`` type, the member's ``dtype`` is used instead of
    the type itself.
    """
    underlying = ty.dtype if isinstance(ty, types.EnumMember) else ty
    return np.dtype(str(underlying))
433
+
434
+
435
class DeviceVectorize(_BaseUFuncBuilder):
    """Builder that compiles element-wise kernels for a device target.

    Each call to ``add`` compiles one signature and records it in
    ``kernelmap``. Subclasses provide ``_kernel_template``,
    ``_compile_core``, ``_get_globals`` and ``_compile_kernel``.
    """

    def __init__(self, func, identity=None, cache=False, targetoptions={}):
        """Validate the options and remember the Python function.

        Raises TypeError when ``cache`` is requested and KeyError for any
        target option other than ``nopython`` (which only warns).
        """
        if cache:
            raise TypeError("caching is not supported")

        for option in targetoptions:
            if option != "nopython":
                message = (
                    "Unrecognized options. "
                    "cuda vectorize target does not support option: '%s'"
                )
                raise KeyError(message % option)
            warnings.warn(
                "nopython kwarg for cuda target is redundant",
                RuntimeWarning,
            )

        self.py_func = func
        self.identity = parse_identity(identity)
        # { arg_dtype: (return_dtype), cudakernel }
        self.kernelmap = OrderedDict()

    @property
    def pyfunc(self):
        """The wrapped Python function."""
        return self.py_func

    def add(self, sig=None):
        """Compile the function for ``sig`` and register the kernel.

        The kernel is stored in ``kernelmap`` under the tuple of argument
        dtypes, together with the result dtype.
        """
        # compile core as device function
        argtys, declared_restype = sigutils.normalize_signature(sig)
        core_sig = signature(declared_restype, *argtys)

        name = self.pyfunc.__name__
        source = self._get_kernel_source(self._kernel_template, core_sig, name)
        core, restype = self._compile_core(core_sig)
        namespace = self._get_globals(core)

        # The kernel takes an array of every argument plus an output array.
        array_tys = [ty[:] for ty in argtys]
        array_tys.append(restype[:])
        kernel_sig = signature(types.void, *array_tys)

        # Running the generated source defines the kernel stub in namespace.
        exec(source, namespace)
        stub = namespace["__vectorized_%s" % name]
        kernel = self._compile_kernel(stub, kernel_sig)

        key = tuple(to_dtype(ty) for ty in core_sig.args)
        self.kernelmap[key] = to_dtype(restype), kernel

    def build_ufunc(self):
        """Not implemented here; left to subclasses."""
        raise NotImplementedError

    def _get_kernel_source(self, template, sig, funcname):
        """Fill ``template`` with the name, argument list and indexed items."""
        names = ["a%d" % pos for pos in range(len(sig.args))]
        return template.format(
            name=funcname,
            args=", ".join(names),
            argitems=", ".join("%s[__tid__]" % n for n in names),
        )

    def _compile_core(self, sig):
        """Compile the scalar function; left to subclasses."""
        raise NotImplementedError

    def _get_globals(self, corefn):
        """Return the namespace for the generated source; left to subclasses."""
        raise NotImplementedError

    def _compile_kernel(self, fnobj, sig):
        """Compile the generated kernel stub; left to subclasses."""
        raise NotImplementedError
499
+
500
+
501
class DeviceGUFuncVectorize(_BaseUFuncBuilder):
    """Builder that compiles generalized-ufunc kernels for a device target.

    Each call to ``add`` compiles one type signature and records it in
    ``kernelmap``. Subclasses provide ``_kernel_template``,
    ``_get_globals`` and ``_compile_kernel``.
    """

    def __init__(
        self,
        func,
        sig,
        identity=None,
        cache=False,
        targetoptions={},
        writable_args=(),
    ):
        """Validate the options and parse the layout signature ``sig``.

        Raises
        ------
        TypeError if ``cache`` or ``writable_args`` is requested, if
        ``nopython`` is false, or if any other target option is given.
        """
        if cache:
            raise TypeError("caching is not supported")
        if writable_args:
            raise TypeError("writable_args are not supported")

        # Validate on a private copy: removing "nopython" below must not
        # alter the caller's dict (or the shared default argument).
        targetoptions = dict(targetoptions)

        # Allow nopython flag to be set.
        if not targetoptions.pop("nopython", True):
            raise TypeError("nopython flag must be True")
        # Are there any more target options?
        if targetoptions:
            opts = ", ".join([repr(k) for k in targetoptions.keys()])
            fmt = "The following target options are not supported: {0}"
            raise TypeError(fmt.format(opts))

        self.py_func = func
        self.identity = parse_identity(identity)
        self.signature = sig
        self.inputsig, self.outputsig = parse_signature(self.signature)

        # Maps from a tuple of input_dtypes to (output_dtypes, kernel)
        self.kernelmap = OrderedDict()

    @property
    def pyfunc(self):
        """The wrapped Python function."""
        return self.py_func

    def add(self, sig=None):
        """Compile the function for type signature ``sig`` and register it.

        Raises
        ------
        TypeError if ``sig`` declares a return type other than none.
        """
        indims = [len(x) for x in self.inputsig]
        outdims = [len(x) for x in self.outputsig]
        args, return_type = sigutils.normalize_signature(sig)

        # It is only valid to specify types.none as a return type, or to not
        # specify the return type (where the "Python None" is the return type)
        valid_return_type = return_type in (types.none, None)
        if not valid_return_type:
            raise TypeError(
                "guvectorized functions cannot return values: "
                f"signature {sig} specifies {return_type} return "
                "type"
            )

        funcname = self.py_func.__name__
        src = expand_gufunc_template(
            self._kernel_template, indims, outdims, funcname, args
        )

        glbls = self._get_globals(sig)

        # Running the generated source defines the kernel stub in glbls.
        exec(src, glbls)
        fnobj = glbls["__gufunc_{name}".format(name=funcname)]

        outertys = list(_determine_gufunc_outer_types(args, indims + outdims))
        kernel = self._compile_kernel(fnobj, sig=tuple(outertys))

        # The trailing `nout` entries are the outputs; the rest are inputs.
        nout = len(outdims)
        dtypes = [np.dtype(str(t.dtype)) for t in outertys]
        indtypes = tuple(dtypes[:-nout])
        outdtypes = tuple(dtypes[-nout:])

        self.kernelmap[indtypes] = outdtypes, kernel

    def _compile_kernel(self, fnobj, sig):
        """Compile the generated kernel stub; left to subclasses."""
        raise NotImplementedError

    def _get_globals(self, sig):
        """Return the namespace for the generated source; left to subclasses."""
        raise NotImplementedError
577
+
578
+
579
+ def _determine_gufunc_outer_types(argtys, dims):
580
+ for at, nd in zip(argtys, dims):
581
+ if isinstance(at, types.Array):
582
+ yield at.copy(ndim=nd + 1)
583
+ else:
584
+ if nd > 0:
585
+ raise ValueError("gufunc signature mismatch: ndim>0 for scalar")
586
+ yield types.Array(dtype=at, ndim=1, layout="A")
587
+
588
+
589
def expand_gufunc_template(template, indims, outdims, funcname, argtypes):
    """Expand gufunc source template

    Fills the ``name``, ``args``, ``checkedarg`` and ``argitems`` fields of
    ``template``. Arguments are named ``arg0``, ``arg1``, ... with inputs
    first and outputs after them.
    """
    dims = indims + outdims
    names = ["arg{0}".format(pos) for pos in range(len(dims))]

    # Smallest leading extent across all arguments.
    leading = ", ".join("{0}.shape[0]".format(n) for n in names)

    # Inputs and outputs are indexed the same way, so one pass over the
    # concatenated dimensions covers both.
    items = [
        _gen_src_for_indexing(name, ndim, argty)
        for name, ndim, argty in zip(names, dims, argtypes)
    ]

    return template.format(
        name=funcname,
        args=", ".join(names),
        checkedarg="min({0})".format(leading),
        argitems=", ".join(items),
    )
614
+
615
+
616
def _gen_src_for_indexing(aref, adims, atype):
    """Return source text that indexes ``aref`` for the current thread."""
    index = _gen_src_index(adims, atype)
    return f"{aref}[{index}]"
620
+
621
+
622
+ def _gen_src_index(adims, atype):
623
+ if adims > 0:
624
+ return ",".join(["__tid__"] + [":"] * adims)
625
+ elif isinstance(atype, types.Array) and atype.ndim - 1 == adims:
626
+ # Special case for 0-nd in shape-signature but
627
+ # 1d array in type signature.
628
+ # Slice it so that the result has the same dimension.
629
+ return "__tid__:(__tid__ + 1)"
630
+ else:
631
+ return "__tid__"
632
+
633
+
634
class GUFuncEngine(object):
    """Determine how to broadcast and execute a gufunc
    base on input shape and signature
    """

    @classmethod
    def from_signature(cls, signature):
        """Build an engine from a layout signature string."""
        inputsig, outputsig = parse_signature(signature)
        return cls(inputsig, outputsig)

    def __init__(self, inputsig, outputsig):
        # signatures
        self.sin = inputsig
        self.sout = outputsig
        # argument count
        self.nin = len(inputsig)
        self.nout = len(outputsig)

    def schedule(self, ishapes):
        """Work out core shapes, output shapes and the loop dimensions.

        Raises
        ------
        TypeError if the number of shapes differs from the number of inputs.
        ValueError if a shape has too few dimensions, if a symbol is bound
        to two different lengths, or if the outer dimensions disagree.
        """
        if len(ishapes) != self.nin:
            raise TypeError("invalid number of input argument")

        # associate symbol values for input signature
        bound = {}
        outer_shapes = []
        inner_shapes = []

        # Argument positions are reported starting from 1.
        for argn, (shape, symbols) in enumerate(zip(ishapes, self.sin), 1):
            core_ndim = len(symbols)
            if len(shape) < core_ndim:
                raise ValueError(
                    "arg #%d: insufficient inner dimension" % (argn,)
                )

            # The trailing `core_ndim` axes are the core; the rest loop.
            inner = shape[-core_ndim:] if core_ndim else ()
            outer = shape[:-core_ndim] if core_ndim else shape

            for offset, (dim, sym) in enumerate(zip(inner, symbols)):
                if bound.setdefault(sym, dim) != dim:
                    axis = offset + len(outer)
                    raise ValueError(
                        "arg #%d: shape[%d] mismatch argument" % (argn, axis)
                    )

            outer_shapes.append(outer)
            inner_shapes.append(inner)

        # solve output shape
        oshapes = [
            tuple([bound[sym] for sym in outsig]) for outsig in self.sout
        ]

        # find the biggest outershape as looping dimension
        volumes = [reduce(operator.mul, outer, 1) for outer in outer_shapes]
        loopdims = outer_shapes[np.argmax(volumes)]

        # A pinned argument is reused unchanged for each iteration.
        pinned = [False] * self.nin
        for pos, outer in enumerate(outer_shapes):
            if outer == loopdims:
                continue
            if outer != (1,) and outer != ():
                raise ValueError(
                    "arg #%d: outer dimension mismatch" % (pos + 1,)
                )
            pinned[pos] = True

        return GUFuncSchedule(self, inner_shapes, oshapes, loopdims, pinned)
707
+
708
+
709
class GUFuncSchedule(object):
    """Result of scheduling a gufunc call: core shapes and loop layout."""

    def __init__(self, parent, ishapes, oshapes, loopdims, pinned):
        self.parent = parent
        # core shapes
        self.ishapes = ishapes
        self.oshapes = oshapes
        # looping dimension
        self.loopdims = loopdims
        # total number of iterations (1 when there are no loop dimensions)
        self.loopn = reduce(operator.mul, loopdims, 1)
        # flags
        self.pinned = pinned

        # Full output shapes: loop dimensions followed by each core shape.
        self.output_shapes = [loopdims + core for core in oshapes]

    def __str__(self):
        import pprint

        shown = ("ishapes", "oshapes", "loopdims", "loopn", "pinned")
        return pprint.pformat({name: getattr(self, name) for name in shown})
729
+
730
+
731
class GeneralizedUFunc(object):
    """Callable that dispatches a gufunc call to a compiled kernel.

    ``kernelmap`` maps a tuple of input dtypes to ``(output dtypes,
    kernel)``; ``engine`` computes the schedule from the input shapes.
    ``__call__`` uses ``self._call_steps``, which is not set in this class.
    """

    def __init__(self, kernelmap, engine):
        self.kernelmap = kernelmap
        self.engine = engine
        self.max_blocksize = 2**30

    def __call__(self, *args, **kws):
        steps = self._call_steps(self.engine.nin, self.engine.nout, args, kws)
        indtypes, schedule, outdtypes, kernel = self._schedule(
            steps.inputs, steps.outputs
        )
        steps.adjust_input_types(indtypes)

        outputs = steps.prepare_outputs(schedule, outdtypes)
        inputs = steps.prepare_inputs()
        parameters = self._broadcast(schedule, inputs, outputs)

        steps.launch_kernel(kernel, schedule.loopn, parameters)

        return steps.post_process_outputs(outputs)

    def _schedule(self, inputs, outs):
        """Return ``(input dtypes, schedule, output dtypes, kernel)``.

        Raises ValueError when a provided output has the wrong shape.
        """
        schedule = self.engine.schedule([ary.shape for ary in inputs])

        # find kernel
        indtypes = tuple(ary.dtype for ary in inputs)
        try:
            selected = self.kernelmap[indtypes]
        except KeyError:
            # No exact match, then use the first compatible.
            # This does not match the numpy dispatching exactly.
            # Later, we may just jit a new version for the missing signature.
            indtypes = self._search_matching_signature(indtypes)
            # Select kernel
            selected = self.kernelmap[indtypes]
        outdtypes, kernel = selected

        # check output
        for expected, given in zip(schedule.output_shapes, outs):
            if given is None:
                continue
            if expected != given.shape:
                raise ValueError("output shape mismatch")

        return indtypes, schedule, outdtypes, kernel

    def _search_matching_signature(self, idtypes):
        """
        Given the input types in `idtypes`, return a compatible sequence of
        types that is defined in `kernelmap`.

        A registered signature is accepted when each of its types can be
        cast (per ``np.can_cast``) to the corresponding type in `idtypes`.

        Note: Ordering is guaranteed by `kernelmap` being a OrderedDict
        """
        for candidate in self.kernelmap.keys():
            compatible = all(
                np.can_cast(kernel_ty, given_ty)
                for kernel_ty, given_ty in zip(candidate, idtypes)
            )
            if compatible:
                return candidate
        raise TypeError("no matching signature")

    def _broadcast(self, schedule, params, retvals):
        """Reshape inputs and outputs to ``(loop count, *core shape)``.

        Returns the reshaped inputs followed by the reshaped outputs.
        """
        assert schedule.loopn > 0, "zero looping dimension"

        odim = schedule.loopn if schedule.loopdims else 1

        def widen(param, core_shape):
            if not core_shape and param.size == 1:
                # Broadcast scalar input
                return self._broadcast_scalar_input(param, odim)
            # Broadcast vector input
            return self._broadcast_array(param, odim, core_shape)

        new_inputs = tuple(
            widen(param, core) for param, core in zip(params, schedule.ishapes)
        )
        new_outputs = tuple(
            retval.reshape(odim, *core)
            for retval, core in zip(retvals, schedule.oshapes)
        )
        return new_inputs + new_outputs

    def _broadcast_array(self, ary, newdim, innerdim):
        """Return ``ary`` with shape ``(newdim,) + innerdim``."""
        target = (newdim,) + innerdim

        # No change in shape
        if ary.shape == target:
            return ary

        # Collapsing dimension
        if len(ary.shape) >= len(target):
            return ary.reshape(*target)

        # Creating new dimension
        assert target[-len(ary.shape) :] == ary.shape, (
            "cannot add dim and reshape at the same time"
        )
        return self._broadcast_add_axis(ary, target)

    def _broadcast_add_axis(self, ary, newshape):
        """Add leading axes to ``ary``; not supported in this class."""
        raise NotImplementedError("cannot add new axis")

    def _broadcast_scalar_input(self, ary, shape):
        """Expand a scalar input; not implemented in this class."""
        raise NotImplementedError
834
+
835
+
836
class GUFuncCallSteps(metaclass=ABCMeta):
    """
    Memory management and kernel launch operations for one GUFunc call.

    An instance is created per call and is tied to that call's arguments.

    The overall sequence lives in this base class; subclasses implement the
    target-specific operations declared abstract below.
    ``post_process_outputs`` also calls ``self.to_host``, which is not
    declared here.
    """

    # The base class uses these slots; subclasses may provide additional slots.
    __slots__ = [
        "outputs",
        "inputs",
        "_copy_result_to_host",
    ]

    @abstractmethod
    def launch_kernel(self, kernel, nelem, args):
        """Implement the kernel launch"""

    @abstractmethod
    def is_device_array(self, obj):
        """
        Tell whether `obj` is a device array for this target.
        """

    @abstractmethod
    def as_device_array(self, obj):
        """
        Give back `obj` as a device array on this target.

        `obj` itself may be returned when it already lives on the target.
        """

    @abstractmethod
    def to_device(self, hostary):
        """
        Transfer `hostary` to the device and return the device array.
        """

    @abstractmethod
    def allocate_device_array(self, shape, dtype):
        """
        Create a new uninitialized device array of the given shape and dtype.
        """

    def __init__(self, nin, nout, args, kwargs):
        """Sort the call arguments into inputs and outputs.

        Raises
        ------
        TypeError if no ``out`` keyword is given and the number of
        positional arguments is neither ``nin`` nor ``nin + nout``.
        ValueError if ``out`` is given and more than ``nin`` positional
        arguments are passed.
        """
        out_kw = kwargs.get("out")
        npos = len(args)

        # Ensure the user has passed a correct number of arguments
        if out_kw is None:
            if npos != nin and npos != nin + nout:

                def pos_argn(n):
                    plural = "" if n == 1 else "s"
                    return f"{n} positional argument{plural}"

                raise TypeError(
                    f"This gufunc accepts {pos_argn(nin)} (when providing "
                    f"input only) or {pos_argn(nin + nout)} (when providing "
                    f"input and output). Got {pos_argn(npos)}."
                )
        elif npos > nin:
            raise ValueError(
                "cannot specify argument 'out' as both positional and keyword"
            )

        # One slot per output, each holding the `out` keyword value; when no
        # `out` was given the slots hold None as a placeholder.
        requested = [out_kw] * nout

        # Device-array outputs are converted with `as_device_array`; all
        # others are kept as they are.
        all_user_outputs_are_host = True
        self.outputs = []
        for candidate in requested:
            if self.is_device_array(candidate):
                candidate = self.as_device_array(candidate)
                all_user_outputs_are_host = False
            self.outputs.append(candidate)

        device_flags = [self.is_device_array(a) for a in args]
        all_host_arrays = not any(device_flags)

        # The result is copied to the host only when neither the arguments
        # nor the keyword outputs include a device array.
        self._copy_result_to_host = (
            all_host_arrays and all_user_outputs_are_host
        )

        # Turn every argument into a device array or a NumPy array (lists,
        # tuples and the like go through np.asarray).
        normalized = []
        for a in args:
            if self.is_device_array(a):
                normalized.append(self.as_device_array(a))
            else:
                normalized.append(np.asarray(a))

        self.inputs = normalized[:nin]

        # Positional arguments beyond the inputs are the outputs.
        positional_outputs = normalized[nin:]
        if positional_outputs:
            self.outputs = positional_outputs

    def adjust_input_types(self, indtypes):
        """
        Cast inputs whose dtype differs from the required one.

        Only the entries of `inputs` that need a cast are replaced.

        Raises
        ------
        TypeError if an input needing a cast has no ``astype`` attribute.
        """
        for pos, (wanted, value) in enumerate(zip(indtypes, self.inputs)):
            if wanted == value.dtype:
                continue
            if not hasattr(value, "astype"):
                raise TypeError(
                    (
                        "compatible signature is possible by casting but "
                        "{0} does not support .astype()"
                    ).format(type(value))
                )
            # Cast types
            self.inputs[pos] = value.astype(wanted)

    def prepare_outputs(self, schedule, outdtypes):
        """
        Build the list of output arrays handed to the kernel.

        A new device array is allocated for an output that was not supplied,
        and for every output when the result will be copied back to the
        host; otherwise the supplied output is used.
        """
        prepared = []
        for shape, dtype, supplied in zip(
            schedule.output_shapes, outdtypes, self.outputs
        ):
            if supplied is None or self._copy_result_to_host:
                prepared.append(self.allocate_device_array(shape, dtype))
            else:
                prepared.append(supplied)
        return prepared

    def prepare_inputs(self):
        """
        Build the list of input arrays handed to the kernel.

        Device arrays go through ``as_device_array``; everything else is
        transferred with ``to_device``.
        """
        prepared = []
        for parameter in self.inputs:
            if self.is_device_array(parameter):
                prepared.append(self.as_device_array(parameter))
            else:
                prepared.append(self.to_device(parameter))
        return prepared

    def post_process_outputs(self, outputs):
        """
        Copy the output(s) back to the host when required.

        A single output is returned as is; several outputs are returned as a
        tuple.
        """
        if self._copy_result_to_host:
            outputs = [
                self.to_host(devary, hostary)
                for devary, hostary in zip(outputs, self.outputs)
            ]
        elif self.outputs[0] is not None:
            outputs = self.outputs

        return outputs[0] if len(outputs) == 1 else tuple(outputs)