numba-cuda 0.0.0__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (233) hide show
  1. _numba_cuda_redirector.pth +1 -0
  2. _numba_cuda_redirector.py +74 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +5 -0
  5. numba_cuda/_version.py +19 -0
  6. numba_cuda/numba/cuda/__init__.py +22 -0
  7. numba_cuda/numba/cuda/api.py +526 -0
  8. numba_cuda/numba/cuda/api_util.py +30 -0
  9. numba_cuda/numba/cuda/args.py +77 -0
  10. numba_cuda/numba/cuda/cg.py +62 -0
  11. numba_cuda/numba/cuda/codegen.py +378 -0
  12. numba_cuda/numba/cuda/compiler.py +422 -0
  13. numba_cuda/numba/cuda/cpp_function_wrappers.cu +47 -0
  14. numba_cuda/numba/cuda/cuda_fp16.h +3631 -0
  15. numba_cuda/numba/cuda/cuda_fp16.hpp +2465 -0
  16. numba_cuda/numba/cuda/cuda_paths.py +258 -0
  17. numba_cuda/numba/cuda/cudadecl.py +806 -0
  18. numba_cuda/numba/cuda/cudadrv/__init__.py +9 -0
  19. numba_cuda/numba/cuda/cudadrv/devicearray.py +904 -0
  20. numba_cuda/numba/cuda/cudadrv/devices.py +248 -0
  21. numba_cuda/numba/cuda/cudadrv/driver.py +3201 -0
  22. numba_cuda/numba/cuda/cudadrv/drvapi.py +398 -0
  23. numba_cuda/numba/cuda/cudadrv/dummyarray.py +452 -0
  24. numba_cuda/numba/cuda/cudadrv/enums.py +607 -0
  25. numba_cuda/numba/cuda/cudadrv/error.py +36 -0
  26. numba_cuda/numba/cuda/cudadrv/libs.py +176 -0
  27. numba_cuda/numba/cuda/cudadrv/ndarray.py +20 -0
  28. numba_cuda/numba/cuda/cudadrv/nvrtc.py +260 -0
  29. numba_cuda/numba/cuda/cudadrv/nvvm.py +707 -0
  30. numba_cuda/numba/cuda/cudadrv/rtapi.py +10 -0
  31. numba_cuda/numba/cuda/cudadrv/runtime.py +142 -0
  32. numba_cuda/numba/cuda/cudaimpl.py +1055 -0
  33. numba_cuda/numba/cuda/cudamath.py +140 -0
  34. numba_cuda/numba/cuda/decorators.py +189 -0
  35. numba_cuda/numba/cuda/descriptor.py +33 -0
  36. numba_cuda/numba/cuda/device_init.py +89 -0
  37. numba_cuda/numba/cuda/deviceufunc.py +908 -0
  38. numba_cuda/numba/cuda/dispatcher.py +1057 -0
  39. numba_cuda/numba/cuda/errors.py +59 -0
  40. numba_cuda/numba/cuda/extending.py +7 -0
  41. numba_cuda/numba/cuda/initialize.py +13 -0
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +77 -0
  43. numba_cuda/numba/cuda/intrinsics.py +198 -0
  44. numba_cuda/numba/cuda/kernels/__init__.py +0 -0
  45. numba_cuda/numba/cuda/kernels/reduction.py +262 -0
  46. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  47. numba_cuda/numba/cuda/libdevice.py +3382 -0
  48. numba_cuda/numba/cuda/libdevicedecl.py +17 -0
  49. numba_cuda/numba/cuda/libdevicefuncs.py +1057 -0
  50. numba_cuda/numba/cuda/libdeviceimpl.py +83 -0
  51. numba_cuda/numba/cuda/mathimpl.py +448 -0
  52. numba_cuda/numba/cuda/models.py +48 -0
  53. numba_cuda/numba/cuda/nvvmutils.py +235 -0
  54. numba_cuda/numba/cuda/printimpl.py +86 -0
  55. numba_cuda/numba/cuda/random.py +292 -0
  56. numba_cuda/numba/cuda/simulator/__init__.py +38 -0
  57. numba_cuda/numba/cuda/simulator/api.py +110 -0
  58. numba_cuda/numba/cuda/simulator/compiler.py +9 -0
  59. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +2 -0
  60. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +432 -0
  61. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +117 -0
  62. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +62 -0
  63. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +4 -0
  64. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +4 -0
  65. numba_cuda/numba/cuda/simulator/cudadrv/error.py +6 -0
  66. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +2 -0
  67. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +29 -0
  68. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +19 -0
  69. numba_cuda/numba/cuda/simulator/kernel.py +308 -0
  70. numba_cuda/numba/cuda/simulator/kernelapi.py +495 -0
  71. numba_cuda/numba/cuda/simulator/reduction.py +15 -0
  72. numba_cuda/numba/cuda/simulator/vector_types.py +58 -0
  73. numba_cuda/numba/cuda/simulator_init.py +17 -0
  74. numba_cuda/numba/cuda/stubs.py +902 -0
  75. numba_cuda/numba/cuda/target.py +440 -0
  76. numba_cuda/numba/cuda/testing.py +202 -0
  77. numba_cuda/numba/cuda/tests/__init__.py +58 -0
  78. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +8 -0
  79. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +145 -0
  80. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +145 -0
  81. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +375 -0
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +21 -0
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +179 -0
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +235 -0
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +22 -0
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +193 -0
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +547 -0
  88. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +249 -0
  89. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +81 -0
  90. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +192 -0
  91. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +38 -0
  92. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +65 -0
  93. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +139 -0
  94. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +37 -0
  95. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +12 -0
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +317 -0
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +127 -0
  98. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +54 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +199 -0
  100. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +37 -0
  101. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +20 -0
  102. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +149 -0
  103. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +36 -0
  104. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +85 -0
  105. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +41 -0
  106. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +122 -0
  107. numba_cuda/numba/cuda/tests/cudapy/__init__.py +8 -0
  108. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +234 -0
  109. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +41 -0
  110. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +58 -0
  111. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +30 -0
  112. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +100 -0
  113. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +42 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_array.py +260 -0
  115. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +201 -0
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +35 -0
  117. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1620 -0
  118. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +120 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +24 -0
  120. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +545 -0
  121. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +257 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +33 -0
  123. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +276 -0
  124. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +296 -0
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +20 -0
  126. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +129 -0
  127. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +176 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +147 -0
  129. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +435 -0
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +90 -0
  131. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +94 -0
  132. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +101 -0
  133. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +221 -0
  134. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +222 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +700 -0
  136. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +121 -0
  137. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +79 -0
  138. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +174 -0
  139. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +155 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +244 -0
  141. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +52 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +29 -0
  143. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +66 -0
  144. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +60 -0
  145. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +456 -0
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +159 -0
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +95 -0
  148. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +37 -0
  149. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +165 -0
  150. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1106 -0
  151. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +318 -0
  152. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +99 -0
  153. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +64 -0
  154. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +119 -0
  155. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +187 -0
  156. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +199 -0
  157. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +164 -0
  158. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +37 -0
  159. numba_cuda/numba/cuda/tests/cudapy/test_math.py +786 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +74 -0
  161. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +113 -0
  162. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +22 -0
  163. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +140 -0
  164. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +46 -0
  165. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +101 -0
  166. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +49 -0
  167. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +401 -0
  168. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +86 -0
  169. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +335 -0
  170. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +124 -0
  171. numba_cuda/numba/cuda/tests/cudapy/test_print.py +128 -0
  172. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +33 -0
  173. numba_cuda/numba/cuda/tests/cudapy/test_random.py +104 -0
  174. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +610 -0
  175. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +125 -0
  176. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +76 -0
  177. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  178. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +85 -0
  179. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +37 -0
  180. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +444 -0
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +205 -0
  182. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +271 -0
  183. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +80 -0
  184. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +277 -0
  185. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +47 -0
  186. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +307 -0
  187. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +283 -0
  188. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +20 -0
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +69 -0
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +36 -0
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +37 -0
  192. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +139 -0
  193. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +276 -0
  194. numba_cuda/numba/cuda/tests/cudasim/__init__.py +6 -0
  195. numba_cuda/numba/cuda/tests/cudasim/support.py +6 -0
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +102 -0
  197. numba_cuda/numba/cuda/tests/data/__init__.py +0 -0
  198. numba_cuda/numba/cuda/tests/data/cuda_include.cu +5 -0
  199. numba_cuda/numba/cuda/tests/data/error.cu +7 -0
  200. numba_cuda/numba/cuda/tests/data/jitlink.cu +23 -0
  201. numba_cuda/numba/cuda/tests/data/jitlink.ptx +51 -0
  202. numba_cuda/numba/cuda/tests/data/warn.cu +7 -0
  203. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +6 -0
  204. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +0 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +49 -0
  206. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +77 -0
  207. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +76 -0
  208. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +82 -0
  209. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +155 -0
  210. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +173 -0
  211. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +109 -0
  212. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +59 -0
  213. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +76 -0
  214. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +130 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +50 -0
  216. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +73 -0
  217. numba_cuda/numba/cuda/tests/nocuda/__init__.py +8 -0
  218. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +359 -0
  219. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +36 -0
  220. numba_cuda/numba/cuda/tests/nocuda/test_import.py +49 -0
  221. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +238 -0
  222. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +54 -0
  223. numba_cuda/numba/cuda/types.py +37 -0
  224. numba_cuda/numba/cuda/ufuncs.py +662 -0
  225. numba_cuda/numba/cuda/vector_types.py +209 -0
  226. numba_cuda/numba/cuda/vectorizers.py +252 -0
  227. numba_cuda-0.0.12.dist-info/LICENSE +25 -0
  228. numba_cuda-0.0.12.dist-info/METADATA +68 -0
  229. numba_cuda-0.0.12.dist-info/RECORD +231 -0
  230. {numba_cuda-0.0.0.dist-info → numba_cuda-0.0.12.dist-info}/WHEEL +1 -1
  231. numba_cuda-0.0.0.dist-info/METADATA +0 -6
  232. numba_cuda-0.0.0.dist-info/RECORD +0 -5
  233. {numba_cuda-0.0.0.dist-info → numba_cuda-0.0.12.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,908 @@
1
+ """
2
+ Implements custom ufunc dispatch mechanism for non-CPU devices.
3
+ """
4
+
5
+ from abc import ABCMeta, abstractmethod
6
+ from collections import OrderedDict
7
+ import operator
8
+ import warnings
9
+ from functools import reduce
10
+
11
+ import numpy as np
12
+
13
+ from numba.np.ufunc.ufuncbuilder import _BaseUFuncBuilder, parse_identity
14
+ from numba.core import types, sigutils
15
+ from numba.core.typing import signature
16
+ from numba.np.ufunc.sigparse import parse_signature
17
+
18
+
19
+ def _broadcast_axis(a, b):
20
+ """
21
+ Raises
22
+ ------
23
+ ValueError if broadcast fails
24
+ """
25
+ if a == b:
26
+ return a
27
+ elif a == 1:
28
+ return b
29
+ elif b == 1:
30
+ return a
31
+ else:
32
+ raise ValueError("failed to broadcast {0} and {1}".format(a, b))
33
+
34
+
35
def _pairwise_broadcast(shape1, shape2):
    """
    Broadcast two shapes against each other and return the result as a tuple.

    The shorter shape is left-padded with 1s until both have the same
    length, then every axis pair is combined with ``_broadcast_axis``.

    Raises
    ------
    ValueError if broadcast fails
    """
    lhs = tuple(shape1)
    rhs = tuple(shape2)

    # Left-pad whichever shape has fewer dimensions.
    rank = max(len(lhs), len(rhs))
    lhs = (1,) * (rank - len(lhs)) + lhs
    rhs = (1,) * (rank - len(rhs)) + rhs

    return tuple(_broadcast_axis(x, y) for x, y in zip(lhs, rhs))
50
+
51
+
52
+ def _multi_broadcast(*shapelist):
53
+ """
54
+ Raises
55
+ ------
56
+ ValueError if broadcast fails
57
+ """
58
+ assert shapelist
59
+
60
+ result = shapelist[0]
61
+ others = shapelist[1:]
62
+ try:
63
+ for i, each in enumerate(others, start=1):
64
+ result = _pairwise_broadcast(result, each)
65
+ except ValueError:
66
+ raise ValueError("failed to broadcast argument #{0}".format(i))
67
+ else:
68
+ return result
69
+
70
+
71
class UFuncMechanism(object):
    """
    Prepare ufunc arguments for vectorize.

    This base class handles scalar arguments, signature resolution against
    ``typemap`` and host-side broadcasting, and drives a complete call
    through :meth:`call`.  The device-specific hooks (``is_device_array``,
    ``as_device_array``, ``broadcast_device``, ``force_array_layout``,
    ``to_device``, ``to_host``, ``allocate_device_array`` and ``launch``)
    are meant to be overridden in subclasses.
    """
    DEFAULT_STREAM = None
    SUPPORT_DEVICE_SLICING = False

    def __init__(self, typemap, args):
        """Never used directly by user. Invoke by UFuncMechanism.call().
        """
        self.typemap = typemap
        self.args = args
        nargs = len(self.args)
        # One slot per positional argument; scalar slots stay None until
        # the signature has been resolved.
        self.argtypes = [None] * nargs
        self.scalarpos = []
        self.signature = None
        self.arrays = [None] * nargs

    def _fill_arrays(self):
        """
        Get all arguments in array form

        Device arrays go through ``as_device_array``; Python/NumPy scalars
        are only recorded in ``scalarpos`` (their array is built later, once
        the dtype is known); everything else is passed to ``np.asarray``.
        """
        for i, arg in enumerate(self.args):
            if self.is_device_array(arg):
                self.arrays[i] = self.as_device_array(arg)
            elif isinstance(arg, (int, float, complex, np.number)):
                # Is scalar
                self.scalarpos.append(i)
            else:
                self.arrays[i] = np.asarray(arg)

    def _fill_argtypes(self):
        """
        Get dtypes

        Records the dtype of every non-scalar argument in ``argtypes``.
        """
        for i, ary in enumerate(self.arrays):
            if ary is not None:
                # Use a default so that objects without a ``dtype``
                # attribute reach the ``np.asarray`` fallback instead of
                # raising AttributeError.
                dtype = getattr(ary, 'dtype', None)
                if dtype is None:
                    dtype = np.asarray(ary).dtype
                self.argtypes[i] = dtype

    def _resolve_signature(self):
        """Resolve signature.
        May have ambiguous case.

        On success ``self.argtypes`` is replaced by the single matching key
        of ``typemap``.

        Raises
        ------
        TypeError if no key, or more than one key, of ``typemap`` matches.
        """
        matches = []
        # Resolve scalar args exact match first
        if self.scalarpos:
            # Try resolve scalar arguments
            for formaltys in self.typemap:
                match_map = []
                for i, (formal, actual) in enumerate(zip(formaltys,
                                                         self.argtypes)):
                    if actual is None:
                        # Scalar slot: use the dtype NumPy infers for it.
                        actual = np.asarray(self.args[i]).dtype

                    match_map.append(actual == formal)

                if all(match_map):
                    matches.append(formaltys)

        # No matching with exact match; try coercing the scalar arguments
        if not matches:
            for formaltys in self.typemap:
                # A None (scalar) slot is compatible with any formal type.
                all_matches = all(actual is None or formal == actual
                                  for formal, actual in
                                  zip(formaltys, self.argtypes))
                if all_matches:
                    matches.append(formaltys)

        if not matches:
            raise TypeError("No matching version.  GPU ufunc requires array "
                            "arguments to have the exact types.  This behaves "
                            "like regular ufunc with casting='no'.")

        if len(matches) > 1:
            raise TypeError("Failed to resolve ufunc due to ambiguous "
                            "signature. Too many untyped scalars. "
                            "Use numpy dtype object to type tag.")

        # Try scalar arguments
        self.argtypes = matches[0]

    def _get_actual_args(self):
        """Return the actual arguments
        Casts scalar arguments to np.array.
        """
        for i in self.scalarpos:
            # Each scalar becomes a one-element array of its resolved dtype.
            self.arrays[i] = np.array([self.args[i]], dtype=self.argtypes[i])

        return self.arrays

    @staticmethod
    def _host_broadcast_strides(ary, shape):
        """Return the strides that view host array `ary` as `shape`.

        Missing leading dimensions, and axes whose length differs from the
        target (i.e. the axes being stretched), get a stride of 0.  Axes are
        aligned from the right, matching how the strides list is padded.
        """
        missingdim = len(shape) - len(ary.shape)
        strides = [0] * missingdim + list(ary.strides)

        for ax in range(missingdim, len(shape)):
            if ary.shape[ax - missingdim] != shape[ax]:
                strides[ax] = 0

        return strides

    def _broadcast(self, arys):
        """Perform numpy ufunc broadcasting

        `arys` is modified in place and also returned.  Device arrays are
        delegated to ``broadcast_device``; host arrays are turned into
        zero-stride views and passed through ``force_array_layout``.
        """
        shapelist = [a.shape for a in arys]
        shape = _multi_broadcast(*shapelist)

        for i, ary in enumerate(arys):
            if ary.shape == shape:
                # Already has the broadcast shape.
                continue

            if self.is_device_array(ary):
                arys[i] = self.broadcast_device(ary, shape)
            else:
                strides = self._host_broadcast_strides(ary, shape)
                strided = np.lib.stride_tricks.as_strided(ary,
                                                          shape=shape,
                                                          strides=strides)

                arys[i] = self.force_array_layout(strided)

        return arys

    def get_arguments(self):
        """Prepare and return the arguments for the ufunc.
        Does not call to_device().
        """
        self._fill_arrays()
        self._fill_argtypes()
        self._resolve_signature()
        arys = self._get_actual_args()
        return self._broadcast(arys)

    def get_function(self):
        """Returns (result_dtype, function)
        """
        return self.typemap[self.argtypes]

    def is_device_array(self, obj):
        """Is the `obj` a device array?
        Override in subclass
        """
        return False

    def as_device_array(self, obj):
        """Convert the `obj` to a device array
        Override in subclass

        Default implementation is an identity function
        """
        return obj

    def broadcast_device(self, ary, shape):
        """Handles ondevice broadcasting

        Override in subclass to add support.
        """
        raise NotImplementedError("broadcasting on device is not supported")

    def force_array_layout(self, ary):
        """Ensures array layout met device requirement.

        Override in subclass
        """
        return ary

    @classmethod
    def call(cls, typemap, args, kws):
        """Perform the entire ufunc call mechanism.

        `kws` may contain ``stream`` and ``out``; both are popped from the
        dict, and any remaining keys only trigger a warning.  Returns the
        result reshaped to the broadcast shape of the inputs.
        """
        # Handle keywords
        stream = kws.pop('stream', cls.DEFAULT_STREAM)
        out = kws.pop('out', None)

        if kws:
            warnings.warn("unrecognized keywords: %s" % ', '.join(kws))

        # Begin call resolution
        cr = cls(typemap, args)
        args = cr.get_arguments()
        resty, func = cr.get_function()

        # Remember the broadcast shape; inputs may be flattened below.
        outshape = args[0].shape

        # Adjust output value
        if out is not None and cr.is_device_array(out):
            out = cr.as_device_array(out)

        def attempt_ravel(a):
            if cr.SUPPORT_DEVICE_SLICING:
                raise NotImplementedError

            try:
                # Call the `.ravel()` method
                return a.ravel()
            except NotImplementedError:
                # If it is not a device array
                if not cr.is_device_array(a):
                    raise
                # For device array, retry ravel on the host by first
                # copying it back.
                else:
                    hostary = cr.to_host(a, stream).ravel()
                    return cr.to_device(hostary, stream)

        if args[0].ndim > 1:
            args = [attempt_ravel(a) for a in args]

        # Prepare argument on the device
        devarys = []
        any_device = False
        for a in args:
            if cr.is_device_array(a):
                devarys.append(a)
                any_device = True
            else:
                dev_a = cr.to_device(a, stream=stream)
                devarys.append(dev_a)

        # Launch
        shape = args[0].shape
        if out is None:
            # No output is provided
            devout = cr.allocate_device_array(shape, resty, stream=stream)

            devarys.extend([devout])
            cr.launch(func, shape[0], stream, devarys)

            if any_device:
                # If any of the arguments are on device,
                # Keep output on the device
                return devout.reshape(outshape)
            else:
                # Otherwise, transfer output back to host
                return devout.copy_to_host().reshape(outshape)

        elif cr.is_device_array(out):
            # If output is provided and it is a device array,
            # Return device array
            if out.ndim > 1:
                out = attempt_ravel(out)
            devout = out
            devarys.extend([devout])
            cr.launch(func, shape[0], stream, devarys)
            return devout.reshape(outshape)

        else:
            # If output is provided and it is a host array,
            # Return host array
            # NOTE(review): `shape` is taken after the inputs were raveled,
            # so a multi-dimensional host `out` appears unable to satisfy
            # this check -- confirm whether that is intended.
            assert out.shape == shape
            assert out.dtype == resty
            devout = cr.allocate_device_array(shape, resty, stream=stream)
            devarys.extend([devout])
            cr.launch(func, shape[0], stream, devarys)
            return devout.copy_to_host(out, stream=stream).reshape(outshape)

    def to_device(self, hostary, stream):
        """Implement to device transfer
        Override in subclass
        """
        raise NotImplementedError

    def to_host(self, devary, stream):
        """Implement to host transfer
        Override in subclass
        """
        raise NotImplementedError

    def allocate_device_array(self, shape, dtype, stream):
        """Implements device allocation
        Override in subclass
        """
        raise NotImplementedError

    def launch(self, func, count, stream, args):
        """Implements device function invocation
        Override in subclass
        """
        raise NotImplementedError
354
+
355
+
356
def to_dtype(ty):
    """Return the NumPy dtype named by ``str(ty)``.

    For a ``types.EnumMember`` the member's ``dtype`` attribute is used in
    place of the enum type itself.
    """
    base = ty.dtype if isinstance(ty, types.EnumMember) else ty
    return np.dtype(str(base))
360
+
361
+
362
class DeviceVectorize(_BaseUFuncBuilder):
    """Builder that compiles one kernel per signature for a scalar function.

    Compiled kernels are collected in ``kernelmap``.  The compilation hooks
    (``_compile_core``, ``_get_globals``, ``_compile_kernel``) and
    ``build_ufunc`` raise ``NotImplementedError`` here; ``add`` also reads
    ``self._kernel_template``, which this class does not define.
    """

    def __init__(self, func, identity=None, cache=False, targetoptions={}):
        """Validate the options and set up an empty kernel map.

        Raises
        ------
        TypeError if `cache` is true.
        KeyError for any target option other than ``'nopython'``.
        """
        if cache:
            raise TypeError("caching is not supported")

        # `targetoptions` is only iterated here, never modified.
        for opt in targetoptions:
            if opt != 'nopython':
                fmt = ("Unrecognized options. "
                       "cuda vectorize target does not support option: '%s'")
                raise KeyError(fmt % opt)
            warnings.warn("nopython kwarg for cuda target is redundant",
                          RuntimeWarning)

        self.py_func = func
        self.identity = parse_identity(identity)
        # { arg_dtype: (return_dtype), cudakernel }
        self.kernelmap = OrderedDict()

    @property
    def pyfunc(self):
        """The wrapped Python function."""
        return self.py_func

    def add(self, sig=None):
        """Compile a kernel for `sig` and register it in ``kernelmap``.

        The entry is keyed by the tuple of argument dtypes and holds
        ``(result dtype, kernel)``.
        """
        # compile core as device function
        argtys, declared_ret = sigutils.normalize_signature(sig)
        core_sig = signature(declared_ret, *argtys)

        fname = self.pyfunc.__name__
        source = self._get_kernel_source(self._kernel_template,
                                         core_sig, fname)
        corefn, ret_ty = self._compile_core(core_sig)
        namespace = self._get_globals(corefn)

        # The kernel signature slices every argument type and the return
        # type with [:], and returns void.
        kernel_argtys = [ty[:] for ty in argtys]
        kernel_argtys.append(ret_ty[:])
        kernel_sig = signature(types.void, *kernel_argtys)

        # Executing the generated source defines the stager function in
        # `namespace`.
        exec(source, namespace)
        stager = namespace['__vectorized_%s' % fname]
        kernel = self._compile_kernel(stager, kernel_sig)

        in_dtypes = tuple(to_dtype(ty) for ty in core_sig.args)
        out_dtype = to_dtype(ret_ty)
        self.kernelmap[in_dtypes] = out_dtype, kernel

    def build_ufunc(self):
        """Not implemented in this class."""
        raise NotImplementedError

    def _get_kernel_source(self, template, sig, funcname):
        """Fill `template` with the function name and generated arguments.

        One parameter ``a<i>`` is generated per entry of ``sig.args``; the
        ``argitems`` field indexes each of them with ``__tid__``.
        """
        params = ['a%d' % pos for pos in range(len(sig.args))]
        indexed = ['%s[__tid__]' % p for p in params]
        return template.format(name=funcname,
                               args=', '.join(params),
                               argitems=', '.join(indexed))

    def _compile_core(self, sig):
        """Not implemented in this class."""
        raise NotImplementedError

    def _get_globals(self, corefn):
        """Not implemented in this class."""
        raise NotImplementedError

    def _compile_kernel(self, fnobj, sig):
        """Not implemented in this class."""
        raise NotImplementedError
421
+
422
+
423
class DeviceGUFuncVectorize(_BaseUFuncBuilder):
    """Builder that compiles one kernel per type signature for a gufunc.

    Compiled kernels are collected in ``kernelmap``.  ``_compile_kernel``
    and ``_get_globals`` raise ``NotImplementedError`` here; ``add`` also
    reads ``self._kernel_template``, which this class does not define.
    """

    def __init__(self, func, sig, identity=None, cache=False, targetoptions={},
                 writable_args=()):
        """Validate the options and parse the layout signature `sig`.

        Raises
        ------
        TypeError if `cache` or `writable_args` is set, if the ``nopython``
        option is false, or if any other target option is given.
        """
        if cache:
            raise TypeError("caching is not supported")
        if writable_args:
            raise TypeError("writable_args are not supported")

        # Allow nopython flag to be set.
        # The options are only inspected, never popped, so neither the
        # caller's dict nor the shared default is modified.
        if not targetoptions.get('nopython', True):
            raise TypeError("nopython flag must be True")
        # Are there any more target options?
        unsupported = [k for k in targetoptions if k != 'nopython']
        if unsupported:
            opts = ', '.join([repr(k) for k in unsupported])
            fmt = "The following target options are not supported: {0}"
            raise TypeError(fmt.format(opts))

        self.py_func = func
        self.identity = parse_identity(identity)
        self.signature = sig
        self.inputsig, self.outputsig = parse_signature(self.signature)

        # Maps from a tuple of input_dtypes to (output_dtypes, kernel)
        self.kernelmap = OrderedDict()

    @property
    def pyfunc(self):
        """The wrapped Python function."""
        return self.py_func

    def add(self, sig=None):
        """Compile a kernel for type signature `sig` and register it.

        Raises
        ------
        TypeError if `sig` specifies a return type other than none.
        """
        # Number of core dimensions of each input and output.
        indims = [len(x) for x in self.inputsig]
        outdims = [len(x) for x in self.outputsig]
        args, return_type = sigutils.normalize_signature(sig)

        # It is only valid to specify types.none as a return type, or to not
        # specify the return type (where the "Python None" is the return type)
        valid_return_type = return_type in (types.none, None)
        if not valid_return_type:
            raise TypeError('guvectorized functions cannot return values: '
                            f'signature {sig} specifies {return_type} return '
                            'type')

        funcname = self.py_func.__name__
        src = expand_gufunc_template(self._kernel_template, indims,
                                     outdims, funcname, args)

        glbls = self._get_globals(sig)

        # Executing the generated source defines the wrapper in `glbls`.
        exec(src, glbls)
        fnobj = glbls['__gufunc_{name}'.format(name=funcname)]

        outertys = list(_determine_gufunc_outer_types(args, indims + outdims))
        kernel = self._compile_kernel(fnobj, sig=tuple(outertys))

        # The trailing `nout` entries are outputs; the rest are inputs.
        nout = len(outdims)
        dtypes = [np.dtype(str(t.dtype)) for t in outertys]
        indtypes = tuple(dtypes[:-nout])
        outdtypes = tuple(dtypes[-nout:])

        self.kernelmap[indtypes] = outdtypes, kernel

    def _compile_kernel(self, fnobj, sig):
        """Not implemented in this class."""
        raise NotImplementedError

    def _get_globals(self, sig):
        """Not implemented in this class."""
        raise NotImplementedError
489
+
490
+
491
+ def _determine_gufunc_outer_types(argtys, dims):
492
+ for at, nd in zip(argtys, dims):
493
+ if isinstance(at, types.Array):
494
+ yield at.copy(ndim=nd + 1)
495
+ else:
496
+ if nd > 0:
497
+ raise ValueError("gufunc signature mismatch: ndim>0 for scalar")
498
+ yield types.Array(dtype=at, ndim=1, layout='A')
499
+
500
+
501
def expand_gufunc_template(template, indims, outdims, funcname, argtypes):
    """Expand gufunc source template

    One parameter ``arg<i>`` is generated per input and output.  The
    ``checkedarg`` field is a ``min(...)`` over every parameter's
    ``shape[0]``, and ``argitems`` holds the indexing expression built for
    each parameter.
    """
    n_inputs = len(indims)
    argdims = indims + outdims
    argnames = ["arg{0}".format(pos) for pos in range(len(argdims))]

    extents = ["{0}.shape[0]".format(name) for name in argnames]
    checkedarg = "min({0})".format(', '.join(extents))

    # Inputs first, then outputs, each paired with its own dims and type.
    argitems = []
    for name, ndim, ty in zip(argnames, indims, argtypes):
        argitems.append(_gen_src_for_indexing(name, ndim, ty))
    for name, ndim, ty in zip(argnames[n_inputs:], outdims,
                              argtypes[n_inputs:]):
        argitems.append(_gen_src_for_indexing(name, ndim, ty))

    return template.format(name=funcname,
                           args=', '.join(argnames),
                           checkedarg=checkedarg,
                           argitems=', '.join(argitems))
518
+
519
+
520
def _gen_src_for_indexing(aref, adims, atype):
    """Return source text that subscripts `aref` with its generated index."""
    index_expr = _gen_src_index(adims, atype)
    return "{aref}[{sliced}]".format(sliced=index_expr, aref=aref)
523
+
524
+
525
+ def _gen_src_index(adims, atype):
526
+ if adims > 0:
527
+ return ','.join(['__tid__'] + [':'] * adims)
528
+ elif isinstance(atype, types.Array) and atype.ndim - 1 == adims:
529
+ # Special case for 0-nd in shape-signature but
530
+ # 1d array in type signature.
531
+ # Slice it so that the result has the same dimension.
532
+ return '__tid__:(__tid__ + 1)'
533
+ else:
534
+ return '__tid__'
535
+
536
+
537
class GUFuncEngine(object):
    '''Determine how to broadcast and execute a gufunc
    base on input shape and signature
    '''

    @classmethod
    def from_signature(cls, signature):
        '''Build an engine from the result of ``parse_signature``.'''
        parsed = parse_signature(signature)
        return cls(*parsed)

    def __init__(self, inputsig, outputsig):
        '''Store the input/output signatures and their argument counts.'''
        # signatures
        self.sin = inputsig
        self.sout = outputsig
        # argument count
        self.nin = len(inputsig)
        self.nout = len(outputsig)

    def schedule(self, ishapes):
        '''Match `ishapes` against the signature and return a GUFuncSchedule.

        Raises
        ------
        TypeError if the number of shapes differs from the number of inputs.
        ValueError if a shape has too few dimensions, if a symbol is bound
        to two different sizes, or if the outer dimensions disagree.
        '''
        if len(ishapes) != self.nin:
            raise TypeError('invalid number of input argument')

        # associate symbol values for input signature
        symbolmap = {}
        outer_shapes = []
        inner_shapes = []

        # argn starts from 1 for human
        for argn, (shape, symbols) in enumerate(zip(ishapes, self.sin),
                                                start=1):
            inner_ndim = len(symbols)
            if len(shape) < inner_ndim:
                raise ValueError(
                    "arg #%d: insufficient inner dimension" % (argn,))

            # The trailing `inner_ndim` axes are the core dimensions.
            if not inner_ndim:
                inner_shape = ()
                outer_shape = shape
            else:
                inner_shape = shape[-inner_ndim:]
                outer_shape = shape[:-inner_ndim]

            # Axis numbers in error messages refer to the full shape.
            first_inner_axis = len(outer_shape)
            for axis, (dim, sym) in enumerate(zip(inner_shape, symbols),
                                              start=first_inner_axis):
                if symbolmap.get(sym, dim) != dim:
                    raise ValueError(
                        "arg #%d: shape[%d] mismatch argument" % (argn, axis))
                symbolmap[sym] = dim

            outer_shapes.append(outer_shape)
            inner_shapes.append(inner_shape)

        # solve output shape
        oshapes = [tuple(symbolmap[sym] for sym in outsig)
                   for outsig in self.sout]

        # find the biggest outershape as looping dimension
        sizes = [reduce(operator.mul, s, 1) for s in outer_shapes]
        loopdims = outer_shapes[np.argmax(sizes)]

        # pinned: same argument for each iteration
        pinned = [False] * self.nin
        for i, d in enumerate(outer_shapes):
            if d == loopdims:
                continue
            if d == (1,) or d == ():
                pinned[i] = True
            else:
                raise ValueError(
                    "arg #%d: outer dimension mismatch" % (i + 1,))

        return GUFuncSchedule(self, inner_shapes, oshapes, loopdims, pinned)
610
+
611
+
612
class GUFuncSchedule(object):
    """Plain record of a scheduling result.

    Stores the object that created it (``parent``), the core input and
    output shapes, the looping dimensions and the per-input ``pinned`` flags,
    plus two derived values: ``loopn`` (product of ``loopdims``, 1 when
    empty) and ``output_shapes`` (``loopdims`` prepended to every core
    output shape).
    """

    def __init__(self, parent, ishapes, oshapes, loopdims, pinned):
        self.parent = parent

        # Core (inner) shapes of the inputs and outputs.
        self.ishapes = ishapes
        self.oshapes = oshapes

        # Looping dimension and the total number of loop iterations.
        self.loopdims = loopdims
        self.loopn = reduce(operator.mul, loopdims, 1)

        # Per-input flags.
        self.pinned = pinned

        # Full shape of each output: loop dimensions, then the core shape.
        self.output_shapes = [loopdims + core for core in oshapes]

    def __str__(self):
        import pprint

        shown = ('ishapes', 'oshapes', 'loopdims', 'loopn', 'pinned')
        summary = {name: getattr(self, name) for name in shown}
        return pprint.pformat(summary)
632
+
633
+
634
+ class GeneralizedUFunc(object):
635
+ def __init__(self, kernelmap, engine):
636
+ self.kernelmap = kernelmap
637
+ self.engine = engine
638
+ self.max_blocksize = 2 ** 30
639
+
640
+ def __call__(self, *args, **kws):
641
+ callsteps = self._call_steps(self.engine.nin, self.engine.nout,
642
+ args, kws)
643
+ indtypes, schedule, outdtypes, kernel = self._schedule(
644
+ callsteps.inputs, callsteps.outputs)
645
+ callsteps.adjust_input_types(indtypes)
646
+
647
+ outputs = callsteps.prepare_outputs(schedule, outdtypes)
648
+ inputs = callsteps.prepare_inputs()
649
+ parameters = self._broadcast(schedule, inputs, outputs)
650
+
651
+ callsteps.launch_kernel(kernel, schedule.loopn, parameters)
652
+
653
+ return callsteps.post_process_outputs(outputs)
654
+
655
+ def _schedule(self, inputs, outs):
656
+ input_shapes = [a.shape for a in inputs]
657
+ schedule = self.engine.schedule(input_shapes)
658
+
659
+ # find kernel
660
+ indtypes = tuple(i.dtype for i in inputs)
661
+ try:
662
+ outdtypes, kernel = self.kernelmap[indtypes]
663
+ except KeyError:
664
+ # No exact match, then use the first compatible.
665
+ # This does not match the numpy dispatching exactly.
666
+ # Later, we may just jit a new version for the missing signature.
667
+ indtypes = self._search_matching_signature(indtypes)
668
+ # Select kernel
669
+ outdtypes, kernel = self.kernelmap[indtypes]
670
+
671
+ # check output
672
+ for sched_shape, out in zip(schedule.output_shapes, outs):
673
+ if out is not None and sched_shape != out.shape:
674
+ raise ValueError('output shape mismatch')
675
+
676
+ return indtypes, schedule, outdtypes, kernel
677
+
678
+ def _search_matching_signature(self, idtypes):
679
+ """
680
+ Given the input types in `idtypes`, return a compatible sequence of
681
+ types that is defined in `kernelmap`.
682
+
683
+ Note: Ordering is guaranteed by `kernelmap` being a OrderedDict
684
+ """
685
+ for sig in self.kernelmap.keys():
686
+ if all(np.can_cast(actual, desired)
687
+ for actual, desired in zip(sig, idtypes)):
688
+ return sig
689
+ else:
690
+ raise TypeError("no matching signature")
691
+
692
+ def _broadcast(self, schedule, params, retvals):
693
+ assert schedule.loopn > 0, "zero looping dimension"
694
+
695
+ odim = 1 if not schedule.loopdims else schedule.loopn
696
+ newparams = []
697
+ for p, cs in zip(params, schedule.ishapes):
698
+ if not cs and p.size == 1:
699
+ # Broadcast scalar input
700
+ devary = self._broadcast_scalar_input(p, odim)
701
+ newparams.append(devary)
702
+ else:
703
+ # Broadcast vector input
704
+ newparams.append(self._broadcast_array(p, odim, cs))
705
+
706
+ newretvals = []
707
+ for retval, oshape in zip(retvals, schedule.oshapes):
708
+ newretvals.append(retval.reshape(odim, *oshape))
709
+ return tuple(newparams) + tuple(newretvals)
710
+
711
+ def _broadcast_array(self, ary, newdim, innerdim):
712
+ newshape = (newdim,) + innerdim
713
+ # No change in shape
714
+ if ary.shape == newshape:
715
+ return ary
716
+
717
+ # Creating new dimension
718
+ elif len(ary.shape) < len(newshape):
719
+ assert newshape[-len(ary.shape):] == ary.shape, \
720
+ "cannot add dim and reshape at the same time"
721
+ return self._broadcast_add_axis(ary, newshape)
722
+
723
+ # Collapsing dimension
724
+ else:
725
+ return ary.reshape(*newshape)
726
+
727
+ def _broadcast_add_axis(self, ary, newshape):
728
+ raise NotImplementedError("cannot add new axis")
729
+
730
+ def _broadcast_scalar_input(self, ary, shape):
731
+ raise NotImplementedError
732
+
733
+
734
class GUFuncCallSteps(metaclass=ABCMeta):
    """
    Implements memory management and kernel launch operations for GUFunc calls.

    One instance of this class is instantiated for each call, and the instance
    is specific to the arguments given to the GUFunc call.

    The base class implements the overall logic; subclasses provide
    target-specific implementations of individual functions.
    """

    # The base class uses these slots; subclasses may provide additional slots.
    __slots__ = [
        'outputs',
        'inputs',
        '_copy_result_to_host',
    ]

    @abstractmethod
    def launch_kernel(self, kernel, nelem, args):
        """Implement the kernel launch"""

    @abstractmethod
    def is_device_array(self, obj):
        """
        Return True if `obj` is a device array for this target, False
        otherwise.
        """

    @abstractmethod
    def as_device_array(self, obj):
        """
        Return `obj` as a device array on this target.

        May return `obj` directly if it is already on the target.
        """

    @abstractmethod
    def to_device(self, hostary):
        """
        Copy `hostary` to the device and return the device array.
        """

    def to_host(self, devary, hostary):
        """
        Copy the device array `devary` back to the host and return the
        resulting host array.

        `hostary` is the entry of `self.outputs` that corresponds to `devary`:
        the output supplied by the caller, or None when none was supplied.
        Presumably a non-None `hostary` is the destination of the copy --
        confirm against the target implementation.

        `post_process_outputs` calls this hook whenever results are copied
        back to the host, so targets must override it. It is deliberately not
        an abstract method, so that subclasses written before this hook was
        declared can still be instantiated.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement to_host() to copy '
            'results back to the host')

    @abstractmethod
    def allocate_device_array(self, shape, dtype):
        """
        Allocate a new uninitialized device array with the given shape and
        dtype.
        """

    def __init__(self, nin, nout, args, kwargs):
        """
        Validate and normalize the arguments of one GUFunc call.

        `nin` and `nout` are the number of inputs and outputs of the gufunc,
        `args` the positional arguments of the call and `kwargs` its keyword
        arguments (only 'out' is read here).

        Raises TypeError for a wrong number of positional arguments and
        ValueError when outputs are given both positionally and via 'out'.
        """
        outputs = kwargs.get('out')

        # Ensure the user has passed a correct number of arguments
        if outputs is None and len(args) not in (nin, (nin + nout)):
            def pos_argn(n):
                return f'{n} positional argument{"s" * (n != 1)}'

            msg = (f'This gufunc accepts {pos_argn(nin)} (when providing '
                   f'input only) or {pos_argn(nin + nout)} (when providing '
                   f'input and output). Got {pos_argn(len(args))}.')
            raise TypeError(msg)

        if outputs is not None and len(args) > nin:
            raise ValueError("cannot specify argument 'out' as both positional "
                             "and keyword")
        else:
            # If the user did not pass outputs either in the out kwarg or as
            # positional arguments, then we need to generate an initial list of
            # "placeholder" outputs using None as a sentry value. A value
            # given through 'out' is repeated once per output in the same way.
            outputs = [outputs] * nout

        # Ensure all output device arrays are Numba device arrays - for
        # example, any output passed in that supports the CUDA Array Interface
        # is converted to a Numba CUDA device array; others are left untouched.
        all_user_outputs_are_host = True
        self.outputs = []
        for output in outputs:
            if self.is_device_array(output):
                self.outputs.append(self.as_device_array(output))
                all_user_outputs_are_host = False
            else:
                self.outputs.append(output)

        # Every positional argument is considered here, including any
        # outputs that were passed positionally.
        all_host_arrays = not any(self.is_device_array(a) for a in args)

        # - If any of the arguments are device arrays, we leave the output on
        #   the device.
        self._copy_result_to_host = (all_host_arrays and
                                     all_user_outputs_are_host)

        # Normalize arguments - ensure they are either device- or host-side
        # arrays (as opposed to lists, tuples, etc).
        def normalize_arg(a):
            if self.is_device_array(a):
                convert = self.as_device_array
            else:
                convert = np.asarray

            return convert(a)

        normalized_args = [normalize_arg(a) for a in args]
        self.inputs = normalized_args[:nin]

        # Check if there are extra arguments for outputs; positional outputs
        # replace the placeholders created above.
        unused_inputs = normalized_args[nin:]
        if unused_inputs:
            self.outputs = unused_inputs

    def adjust_input_types(self, indtypes):
        """
        Attempt to cast the inputs to the required types if necessary
        and if they are not device arrays.

        Side effect: Only affects the elements of `inputs` that require
        a type cast.

        Raises TypeError when an input needs a cast but has no `.astype()`.
        """
        for i, (ity, val) in enumerate(zip(indtypes, self.inputs)):
            if ity != val.dtype:
                if not hasattr(val, 'astype'):
                    msg = ("compatible signature is possible by casting but "
                           "{0} does not support .astype()").format(type(val))
                    raise TypeError(msg)
                # Cast types
                self.inputs[i] = val.astype(ity)

    def prepare_outputs(self, schedule, outdtypes):
        """
        Returns a list of output parameters that all reside on the target
        device.

        Outputs that were passed-in to the GUFunc are used if they reside on the
        device; other outputs are allocated as necessary.
        """
        outputs = []
        for shape, dtype, output in zip(schedule.output_shapes, outdtypes,
                                        self.outputs):
            # A fresh device array is also allocated when the result will be
            # copied back to the host afterwards.
            if output is None or self._copy_result_to_host:
                output = self.allocate_device_array(shape, dtype)
            outputs.append(output)

        return outputs

    def prepare_inputs(self):
        """
        Returns a list of input parameters that all reside on the target device.
        """
        def ensure_device(parameter):
            if self.is_device_array(parameter):
                convert = self.as_device_array
            else:
                convert = self.to_device

            return convert(parameter)

        return [ensure_device(p) for p in self.inputs]

    def post_process_outputs(self, outputs):
        """
        Moves the given output(s) to the host if necessary.

        Returns a single value (e.g. an array) if there was one output, or a
        tuple of arrays if there were multiple. Although this feels a little
        jarring, it is consistent with the behavior of GUFuncs in general.
        """
        if self._copy_result_to_host:
            # Pair each device result with the corresponding user output
            # (or None placeholder).
            outputs = [self.to_host(output, self_output)
                       for output, self_output in zip(outputs, self.outputs)]
        elif self.outputs[0] is not None:
            # The caller supplied the outputs: hand those objects back.
            outputs = self.outputs

        if len(outputs) == 1:
            return outputs[0]
        else:
            return tuple(outputs)