numba-cuda 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227) hide show
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +463 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.0.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -72,12 +72,12 @@ class UFuncMechanism(object):
72
72
  """
73
73
  Prepare ufunc arguments for vectorize.
74
74
  """
75
+
75
76
  DEFAULT_STREAM = None
76
77
  SUPPORT_DEVICE_SLICING = False
77
78
 
78
79
  def __init__(self, typemap, args):
79
- """Never used directly by user. Invoke by UFuncMechanism.call().
80
- """
80
+ """Never used directly by user. Invoke by UFuncMechanism.call()."""
81
81
  self.typemap = typemap
82
82
  self.args = args
83
83
  nargs = len(self.args)
@@ -105,7 +105,7 @@ class UFuncMechanism(object):
105
105
  """
106
106
  for i, ary in enumerate(self.arrays):
107
107
  if ary is not None:
108
- dtype = getattr(ary, 'dtype')
108
+ dtype = getattr(ary, "dtype")
109
109
  if dtype is None:
110
110
  dtype = np.asarray(ary).dtype
111
111
  self.argtypes[i] = dtype
@@ -120,8 +120,9 @@ class UFuncMechanism(object):
120
120
  # Try resolve scalar arguments
121
121
  for formaltys in self.typemap:
122
122
  match_map = []
123
- for i, (formal, actual) in enumerate(zip(formaltys,
124
- self.argtypes)):
123
+ for i, (formal, actual) in enumerate(
124
+ zip(formaltys, self.argtypes)
125
+ ):
125
126
  if actual is None:
126
127
  actual = np.asarray(self.args[i]).dtype
127
128
 
@@ -134,21 +135,26 @@ class UFuncMechanism(object):
134
135
  if not matches:
135
136
  matches = []
136
137
  for formaltys in self.typemap:
137
- all_matches = all(actual is None or formal == actual
138
- for formal, actual in
139
- zip(formaltys, self.argtypes))
138
+ all_matches = all(
139
+ actual is None or formal == actual
140
+ for formal, actual in zip(formaltys, self.argtypes)
141
+ )
140
142
  if all_matches:
141
143
  matches.append(formaltys)
142
144
 
143
145
  if not matches:
144
- raise TypeError("No matching version. GPU ufunc requires array "
145
- "arguments to have the exact types. This behaves "
146
- "like regular ufunc with casting='no'.")
146
+ raise TypeError(
147
+ "No matching version. GPU ufunc requires array "
148
+ "arguments to have the exact types. This behaves "
149
+ "like regular ufunc with casting='no'."
150
+ )
147
151
 
148
152
  if len(matches) > 1:
149
- raise TypeError("Failed to resolve ufunc due to ambiguous "
150
- "signature. Too many untyped scalars. "
151
- "Use numpy dtype object to type tag.")
153
+ raise TypeError(
154
+ "Failed to resolve ufunc due to ambiguous "
155
+ "signature. Too many untyped scalars. "
156
+ "Use numpy dtype object to type tag."
157
+ )
152
158
 
153
159
  # Try scalar arguments
154
160
  self.argtypes = matches[0]
@@ -163,8 +169,7 @@ class UFuncMechanism(object):
163
169
  return self.arrays
164
170
 
165
171
  def _broadcast(self, arys):
166
- """Perform numpy ufunc broadcasting
167
- """
172
+ """Perform numpy ufunc broadcasting"""
168
173
  shapelist = [a.shape for a in arys]
169
174
  shape = _multi_broadcast(*shapelist)
170
175
 
@@ -177,9 +182,11 @@ class UFuncMechanism(object):
177
182
  arys[i] = self.broadcast_device(ary, shape)
178
183
 
179
184
  else:
180
- ax_differs = [ax for ax in range(len(shape))
181
- if ax >= ary.ndim
182
- or ary.shape[ax] != shape[ax]]
185
+ ax_differs = [
186
+ ax
187
+ for ax in range(len(shape))
188
+ if ax >= ary.ndim or ary.shape[ax] != shape[ax]
189
+ ]
183
190
 
184
191
  missingdim = len(shape) - len(ary.shape)
185
192
  strides = [0] * missingdim + list(ary.strides)
@@ -187,9 +194,9 @@ class UFuncMechanism(object):
187
194
  for ax in ax_differs:
188
195
  strides[ax] = 0
189
196
 
190
- strided = np.lib.stride_tricks.as_strided(ary,
191
- shape=shape,
192
- strides=strides)
197
+ strided = np.lib.stride_tricks.as_strided(
198
+ ary, shape=shape, strides=strides
199
+ )
193
200
 
194
201
  arys[i] = self.force_array_layout(strided)
195
202
 
@@ -206,8 +213,7 @@ class UFuncMechanism(object):
206
213
  return self._broadcast(arys)
207
214
 
208
215
  def get_function(self):
209
- """Returns (result_dtype, function)
210
- """
216
+ """Returns (result_dtype, function)"""
211
217
  return self.typemap[self.argtypes]
212
218
 
213
219
  def is_device_array(self, obj):
@@ -240,14 +246,13 @@ class UFuncMechanism(object):
240
246
 
241
247
  @classmethod
242
248
  def call(cls, typemap, args, kws):
243
- """Perform the entire ufunc call mechanism.
244
- """
249
+ """Perform the entire ufunc call mechanism."""
245
250
  # Handle keywords
246
- stream = kws.pop('stream', cls.DEFAULT_STREAM)
247
- out = kws.pop('out', None)
251
+ stream = kws.pop("stream", cls.DEFAULT_STREAM)
252
+ out = kws.pop("out", None)
248
253
 
249
254
  if kws:
250
- warnings.warn("unrecognized keywords: %s" % ', '.join(kws))
255
+ warnings.warn("unrecognized keywords: %s" % ", ".join(kws))
251
256
 
252
257
  # Begin call resolution
253
258
  cr = cls(typemap, args)
@@ -364,9 +369,11 @@ class DeviceVectorize(_BaseUFuncBuilder):
364
369
  if cache:
365
370
  raise TypeError("caching is not supported")
366
371
  for opt in targetoptions:
367
- if opt == 'nopython':
368
- warnings.warn("nopython kwarg for cuda target is redundant",
369
- RuntimeWarning)
372
+ if opt == "nopython":
373
+ warnings.warn(
374
+ "nopython kwarg for cuda target is redundant",
375
+ RuntimeWarning,
376
+ )
370
377
  else:
371
378
  fmt = "Unrecognized options. "
372
379
  fmt += "cuda vectorize target does not support option: '%s'"
@@ -386,14 +393,15 @@ class DeviceVectorize(_BaseUFuncBuilder):
386
393
  devfnsig = signature(return_type, *args)
387
394
 
388
395
  funcname = self.pyfunc.__name__
389
- kernelsource = self._get_kernel_source(self._kernel_template,
390
- devfnsig, funcname)
396
+ kernelsource = self._get_kernel_source(
397
+ self._kernel_template, devfnsig, funcname
398
+ )
391
399
  corefn, return_type = self._compile_core(devfnsig)
392
400
  glbl = self._get_globals(corefn)
393
401
  sig = signature(types.void, *([a[:] for a in args] + [return_type[:]]))
394
402
  exec(kernelsource, glbl)
395
403
 
396
- stager = glbl['__vectorized_%s' % funcname]
404
+ stager = glbl["__vectorized_%s" % funcname]
397
405
  kernel = self._compile_kernel(stager, sig)
398
406
 
399
407
  argdtypes = tuple(to_dtype(t) for t in devfnsig.args)
@@ -404,10 +412,12 @@ class DeviceVectorize(_BaseUFuncBuilder):
404
412
  raise NotImplementedError
405
413
 
406
414
  def _get_kernel_source(self, template, sig, funcname):
407
- args = ['a%d' % i for i in range(len(sig.args))]
408
- fmts = dict(name=funcname,
409
- args=', '.join(args),
410
- argitems=', '.join('%s[__tid__]' % i for i in args))
415
+ args = ["a%d" % i for i in range(len(sig.args))]
416
+ fmts = dict(
417
+ name=funcname,
418
+ args=", ".join(args),
419
+ argitems=", ".join("%s[__tid__]" % i for i in args),
420
+ )
411
421
  return template.format(**fmts)
412
422
 
413
423
  def _compile_core(self, sig):
@@ -421,19 +431,26 @@ class DeviceVectorize(_BaseUFuncBuilder):
421
431
 
422
432
 
423
433
  class DeviceGUFuncVectorize(_BaseUFuncBuilder):
424
- def __init__(self, func, sig, identity=None, cache=False, targetoptions={},
425
- writable_args=()):
434
+ def __init__(
435
+ self,
436
+ func,
437
+ sig,
438
+ identity=None,
439
+ cache=False,
440
+ targetoptions={},
441
+ writable_args=(),
442
+ ):
426
443
  if cache:
427
444
  raise TypeError("caching is not supported")
428
445
  if writable_args:
429
446
  raise TypeError("writable_args are not supported")
430
447
 
431
448
  # Allow nopython flag to be set.
432
- if not targetoptions.pop('nopython', True):
449
+ if not targetoptions.pop("nopython", True):
433
450
  raise TypeError("nopython flag must be True")
434
451
  # Are there any more target options?
435
452
  if targetoptions:
436
- opts = ', '.join([repr(k) for k in targetoptions.keys()])
453
+ opts = ", ".join([repr(k) for k in targetoptions.keys()])
437
454
  fmt = "The following target options are not supported: {0}"
438
455
  raise TypeError(fmt.format(opts))
439
456
 
@@ -458,18 +475,21 @@ class DeviceGUFuncVectorize(_BaseUFuncBuilder):
458
475
  # specify the return type (where the "Python None" is the return type)
459
476
  valid_return_type = return_type in (types.none, None)
460
477
  if not valid_return_type:
461
- raise TypeError('guvectorized functions cannot return values: '
462
- f'signature {sig} specifies {return_type} return '
463
- 'type')
478
+ raise TypeError(
479
+ "guvectorized functions cannot return values: "
480
+ f"signature {sig} specifies {return_type} return "
481
+ "type"
482
+ )
464
483
 
465
484
  funcname = self.py_func.__name__
466
- src = expand_gufunc_template(self._kernel_template, indims,
467
- outdims, funcname, args)
485
+ src = expand_gufunc_template(
486
+ self._kernel_template, indims, outdims, funcname, args
487
+ )
468
488
 
469
489
  glbls = self._get_globals(sig)
470
490
 
471
491
  exec(src, glbls)
472
- fnobj = glbls['__gufunc_{name}'.format(name=funcname)]
492
+ fnobj = glbls["__gufunc_{name}".format(name=funcname)]
473
493
 
474
494
  outertys = list(_determine_gufunc_outer_types(args, indims + outdims))
475
495
  kernel = self._compile_kernel(fnobj, sig=tuple(outertys))
@@ -495,49 +515,58 @@ def _determine_gufunc_outer_types(argtys, dims):
495
515
  else:
496
516
  if nd > 0:
497
517
  raise ValueError("gufunc signature mismatch: ndim>0 for scalar")
498
- yield types.Array(dtype=at, ndim=1, layout='A')
518
+ yield types.Array(dtype=at, ndim=1, layout="A")
499
519
 
500
520
 
501
521
  def expand_gufunc_template(template, indims, outdims, funcname, argtypes):
502
- """Expand gufunc source template
503
- """
522
+ """Expand gufunc source template"""
504
523
  argdims = indims + outdims
505
524
  argnames = ["arg{0}".format(i) for i in range(len(argdims))]
506
- checkedarg = "min({0})".format(', '.join(["{0}.shape[0]".format(a)
507
- for a in argnames]))
508
- inputs = [_gen_src_for_indexing(aref, adims, atype)
509
- for aref, adims, atype in zip(argnames, indims, argtypes)]
510
- outputs = [_gen_src_for_indexing(aref, adims, atype)
511
- for aref, adims, atype in zip(argnames[len(indims):], outdims,
512
- argtypes[len(indims):])]
525
+ checkedarg = "min({0})".format(
526
+ ", ".join(["{0}.shape[0]".format(a) for a in argnames])
527
+ )
528
+ inputs = [
529
+ _gen_src_for_indexing(aref, adims, atype)
530
+ for aref, adims, atype in zip(argnames, indims, argtypes)
531
+ ]
532
+ outputs = [
533
+ _gen_src_for_indexing(aref, adims, atype)
534
+ for aref, adims, atype in zip(
535
+ argnames[len(indims) :], outdims, argtypes[len(indims) :]
536
+ )
537
+ ]
513
538
  argitems = inputs + outputs
514
- src = template.format(name=funcname, args=', '.join(argnames),
515
- checkedarg=checkedarg,
516
- argitems=', '.join(argitems))
539
+ src = template.format(
540
+ name=funcname,
541
+ args=", ".join(argnames),
542
+ checkedarg=checkedarg,
543
+ argitems=", ".join(argitems),
544
+ )
517
545
  return src
518
546
 
519
547
 
520
548
  def _gen_src_for_indexing(aref, adims, atype):
521
- return "{aref}[{sliced}]".format(aref=aref,
522
- sliced=_gen_src_index(adims, atype))
549
+ return "{aref}[{sliced}]".format(
550
+ aref=aref, sliced=_gen_src_index(adims, atype)
551
+ )
523
552
 
524
553
 
525
554
  def _gen_src_index(adims, atype):
526
555
  if adims > 0:
527
- return ','.join(['__tid__'] + [':'] * adims)
556
+ return ",".join(["__tid__"] + [":"] * adims)
528
557
  elif isinstance(atype, types.Array) and atype.ndim - 1 == adims:
529
558
  # Special case for 0-nd in shape-signature but
530
559
  # 1d array in type signature.
531
560
  # Slice it so that the result has the same dimension.
532
- return '__tid__:(__tid__ + 1)'
561
+ return "__tid__:(__tid__ + 1)"
533
562
  else:
534
- return '__tid__'
563
+ return "__tid__"
535
564
 
536
565
 
537
566
  class GUFuncEngine(object):
538
- '''Determine how to broadcast and execute a gufunc
567
+ """Determine how to broadcast and execute a gufunc
539
568
  base on input shape and signature
540
- '''
569
+ """
541
570
 
542
571
  @classmethod
543
572
  def from_signature(cls, signature):
@@ -553,7 +582,7 @@ class GUFuncEngine(object):
553
582
 
554
583
  def schedule(self, ishapes):
555
584
  if len(ishapes) != self.nin:
556
- raise TypeError('invalid number of input argument')
585
+ raise TypeError("invalid number of input argument")
557
586
 
558
587
  # associate symbol values for input signature
559
588
  symbolmap = {}
@@ -626,7 +655,7 @@ class GUFuncSchedule(object):
626
655
  def __str__(self):
627
656
  import pprint
628
657
 
629
- attrs = 'ishapes', 'oshapes', 'loopdims', 'loopn', 'pinned'
658
+ attrs = "ishapes", "oshapes", "loopdims", "loopn", "pinned"
630
659
  values = [(k, getattr(self, k)) for k in attrs]
631
660
  return pprint.pformat(dict(values))
632
661
 
@@ -635,13 +664,15 @@ class GeneralizedUFunc(object):
635
664
  def __init__(self, kernelmap, engine):
636
665
  self.kernelmap = kernelmap
637
666
  self.engine = engine
638
- self.max_blocksize = 2 ** 30
667
+ self.max_blocksize = 2**30
639
668
 
640
669
  def __call__(self, *args, **kws):
641
- callsteps = self._call_steps(self.engine.nin, self.engine.nout,
642
- args, kws)
670
+ callsteps = self._call_steps(
671
+ self.engine.nin, self.engine.nout, args, kws
672
+ )
643
673
  indtypes, schedule, outdtypes, kernel = self._schedule(
644
- callsteps.inputs, callsteps.outputs)
674
+ callsteps.inputs, callsteps.outputs
675
+ )
645
676
  callsteps.adjust_input_types(indtypes)
646
677
 
647
678
  outputs = callsteps.prepare_outputs(schedule, outdtypes)
@@ -671,7 +702,7 @@ class GeneralizedUFunc(object):
671
702
  # check output
672
703
  for sched_shape, out in zip(schedule.output_shapes, outs):
673
704
  if out is not None and sched_shape != out.shape:
674
- raise ValueError('output shape mismatch')
705
+ raise ValueError("output shape mismatch")
675
706
 
676
707
  return indtypes, schedule, outdtypes, kernel
677
708
 
@@ -683,8 +714,10 @@ class GeneralizedUFunc(object):
683
714
  Note: Ordering is guaranteed by `kernelmap` being a OrderedDict
684
715
  """
685
716
  for sig in self.kernelmap.keys():
686
- if all(np.can_cast(actual, desired)
687
- for actual, desired in zip(sig, idtypes)):
717
+ if all(
718
+ np.can_cast(actual, desired)
719
+ for actual, desired in zip(sig, idtypes)
720
+ ):
688
721
  return sig
689
722
  else:
690
723
  raise TypeError("no matching signature")
@@ -716,8 +749,9 @@ class GeneralizedUFunc(object):
716
749
 
717
750
  # Creating new dimension
718
751
  elif len(ary.shape) < len(newshape):
719
- assert newshape[-len(ary.shape):] == ary.shape, \
752
+ assert newshape[-len(ary.shape) :] == ary.shape, (
720
753
  "cannot add dim and reshape at the same time"
754
+ )
721
755
  return self._broadcast_add_axis(ary, newshape)
722
756
 
723
757
  # Collapsing dimension
@@ -744,9 +778,9 @@ class GUFuncCallSteps(metaclass=ABCMeta):
744
778
 
745
779
  # The base class uses these slots; subclasses may provide additional slots.
746
780
  __slots__ = [
747
- 'outputs',
748
- 'inputs',
749
- '_copy_result_to_host',
781
+ "outputs",
782
+ "inputs",
783
+ "_copy_result_to_host",
750
784
  ]
751
785
 
752
786
  @abstractmethod
@@ -782,21 +816,25 @@ class GUFuncCallSteps(metaclass=ABCMeta):
782
816
  """
783
817
 
784
818
  def __init__(self, nin, nout, args, kwargs):
785
- outputs = kwargs.get('out')
819
+ outputs = kwargs.get("out")
786
820
 
787
821
  # Ensure the user has passed a correct number of arguments
788
822
  if outputs is None and len(args) not in (nin, (nin + nout)):
823
+
789
824
  def pos_argn(n):
790
- return f'{n} positional argument{"s" * (n != 1)}'
825
+ return f"{n} positional argument{'s' * (n != 1)}"
791
826
 
792
- msg = (f'This gufunc accepts {pos_argn(nin)} (when providing '
793
- f'input only) or {pos_argn(nin + nout)} (when providing '
794
- f'input and output). Got {pos_argn(len(args))}.')
827
+ msg = (
828
+ f"This gufunc accepts {pos_argn(nin)} (when providing "
829
+ f"input only) or {pos_argn(nin + nout)} (when providing "
830
+ f"input and output). Got {pos_argn(len(args))}."
831
+ )
795
832
  raise TypeError(msg)
796
833
 
797
834
  if outputs is not None and len(args) > nin:
798
- raise ValueError("cannot specify argument 'out' as both positional "
799
- "and keyword")
835
+ raise ValueError(
836
+ "cannot specify argument 'out' as both positional and keyword"
837
+ )
800
838
  else:
801
839
  # If the user did not pass outputs either in the out kwarg or as
802
840
  # positional arguments, then we need to generate an initial list of
@@ -819,8 +857,9 @@ class GUFuncCallSteps(metaclass=ABCMeta):
819
857
 
820
858
  # - If any of the arguments are device arrays, we leave the output on
821
859
  # the device.
822
- self._copy_result_to_host = (all_host_arrays and
823
- all_user_outputs_are_host)
860
+ self._copy_result_to_host = (
861
+ all_host_arrays and all_user_outputs_are_host
862
+ )
824
863
 
825
864
  # Normalize arguments - ensure they are either device- or host-side
826
865
  # arrays (as opposed to lists, tuples, etc).
@@ -850,9 +889,11 @@ class GUFuncCallSteps(metaclass=ABCMeta):
850
889
  """
851
890
  for i, (ity, val) in enumerate(zip(indtypes, self.inputs)):
852
891
  if ity != val.dtype:
853
- if not hasattr(val, 'astype'):
854
- msg = ("compatible signature is possible by casting but "
855
- "{0} does not support .astype()").format(type(val))
892
+ if not hasattr(val, "astype"):
893
+ msg = (
894
+ "compatible signature is possible by casting but "
895
+ "{0} does not support .astype()"
896
+ ).format(type(val))
856
897
  raise TypeError(msg)
857
898
  # Cast types
858
899
  self.inputs[i] = val.astype(ity)
@@ -866,8 +907,9 @@ class GUFuncCallSteps(metaclass=ABCMeta):
866
907
  device; other outputs are allocated as necessary.
867
908
  """
868
909
  outputs = []
869
- for shape, dtype, output in zip(schedule.output_shapes, outdtypes,
870
- self.outputs):
910
+ for shape, dtype, output in zip(
911
+ schedule.output_shapes, outdtypes, self.outputs
912
+ ):
871
913
  if output is None or self._copy_result_to_host:
872
914
  output = self.allocate_device_array(shape, dtype)
873
915
  outputs.append(output)
@@ -878,6 +920,7 @@ class GUFuncCallSteps(metaclass=ABCMeta):
878
920
  """
879
921
  Returns a list of input parameters that all reside on the target device.
880
922
  """
923
+
881
924
  def ensure_device(parameter):
882
925
  if self.is_device_array(parameter):
883
926
  convert = self.as_device_array
@@ -897,8 +940,10 @@ class GUFuncCallSteps(metaclass=ABCMeta):
897
940
  jarring, it is consistent with the behavior of GUFuncs in general.
898
941
  """
899
942
  if self._copy_result_to_host:
900
- outputs = [self.to_host(output, self_output)
901
- for output, self_output in zip(outputs, self.outputs)]
943
+ outputs = [
944
+ self.to_host(output, self_output)
945
+ for output, self_output in zip(outputs, self.outputs)
946
+ ]
902
947
  elif self.outputs[0] is not None:
903
948
  outputs = self.outputs
904
949