numba-cuda 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227) hide show
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +463 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.0.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,23 @@
1
1
  import operator
2
2
  from numba.core import types
3
- from numba.core.typing.npydecl import (parse_dtype, parse_shape,
4
- register_number_classes,
5
- register_numpy_ufunc,
6
- trigonometric_functions,
7
- comparison_functions,
8
- math_operations,
9
- bit_twiddling_functions)
10
- from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate,
11
- AbstractTemplate, CallableTemplate,
12
- signature, Registry)
3
+ from numba.core.typing.npydecl import (
4
+ parse_dtype,
5
+ parse_shape,
6
+ register_number_classes,
7
+ register_numpy_ufunc,
8
+ trigonometric_functions,
9
+ comparison_functions,
10
+ math_operations,
11
+ bit_twiddling_functions,
12
+ )
13
+ from numba.core.typing.templates import (
14
+ AttributeTemplate,
15
+ ConcreteTemplate,
16
+ AbstractTemplate,
17
+ CallableTemplate,
18
+ signature,
19
+ Registry,
20
+ )
13
21
  from numba.cuda.types import dim3
14
22
  from numba.core.typeconv import Conversion
15
23
  from numba import cuda
@@ -26,15 +34,15 @@ register_number_classes(register_global)
26
34
  class Cuda_array_decl(CallableTemplate):
27
35
  def generic(self):
28
36
  def typer(shape, dtype):
29
-
30
37
  # Only integer literals and tuples of integer literals are valid
31
38
  # shapes
32
39
  if isinstance(shape, types.Integer):
33
40
  if not isinstance(shape, types.IntegerLiteral):
34
41
  return None
35
42
  elif isinstance(shape, (types.Tuple, types.UniTuple)):
36
- if any([not isinstance(s, types.IntegerLiteral)
37
- for s in shape]):
43
+ if any(
44
+ [not isinstance(s, types.IntegerLiteral) for s in shape]
45
+ ):
38
46
  return None
39
47
  else:
40
48
  return None
@@ -42,7 +50,7 @@ class Cuda_array_decl(CallableTemplate):
42
50
  ndim = parse_shape(shape)
43
51
  nb_dtype = parse_dtype(dtype)
44
52
  if nb_dtype is not None and ndim is not None:
45
- return types.Array(dtype=nb_dtype, ndim=ndim, layout='C')
53
+ return types.Array(dtype=nb_dtype, ndim=ndim, layout="C")
46
54
 
47
55
  return typer
48
56
 
@@ -64,6 +72,7 @@ class Cuda_const_array_like(CallableTemplate):
64
72
  def generic(self):
65
73
  def typer(ndarray):
66
74
  return ndarray
75
+
67
76
  return typer
68
77
 
69
78
 
@@ -95,22 +104,49 @@ class Cuda_syncwarp(ConcreteTemplate):
95
104
  class Cuda_shfl_sync_intrinsic(ConcreteTemplate):
96
105
  key = cuda.shfl_sync_intrinsic
97
106
  cases = [
98
- signature(types.Tuple((types.i4, types.b1)),
99
- types.i4, types.i4, types.i4, types.i4, types.i4),
100
- signature(types.Tuple((types.i8, types.b1)),
101
- types.i4, types.i4, types.i8, types.i4, types.i4),
102
- signature(types.Tuple((types.f4, types.b1)),
103
- types.i4, types.i4, types.f4, types.i4, types.i4),
104
- signature(types.Tuple((types.f8, types.b1)),
105
- types.i4, types.i4, types.f8, types.i4, types.i4),
107
+ signature(
108
+ types.Tuple((types.i4, types.b1)),
109
+ types.i4,
110
+ types.i4,
111
+ types.i4,
112
+ types.i4,
113
+ types.i4,
114
+ ),
115
+ signature(
116
+ types.Tuple((types.i8, types.b1)),
117
+ types.i4,
118
+ types.i4,
119
+ types.i8,
120
+ types.i4,
121
+ types.i4,
122
+ ),
123
+ signature(
124
+ types.Tuple((types.f4, types.b1)),
125
+ types.i4,
126
+ types.i4,
127
+ types.f4,
128
+ types.i4,
129
+ types.i4,
130
+ ),
131
+ signature(
132
+ types.Tuple((types.f8, types.b1)),
133
+ types.i4,
134
+ types.i4,
135
+ types.f8,
136
+ types.i4,
137
+ types.i4,
138
+ ),
106
139
  ]
107
140
 
108
141
 
109
142
  @register
110
143
  class Cuda_vote_sync_intrinsic(ConcreteTemplate):
111
144
  key = cuda.vote_sync_intrinsic
112
- cases = [signature(types.Tuple((types.i4, types.b1)),
113
- types.i4, types.i4, types.b1)]
145
+ cases = [
146
+ signature(
147
+ types.Tuple((types.i4, types.b1)), types.i4, types.i4, types.b1
148
+ )
149
+ ]
114
150
 
115
151
 
116
152
  @register
@@ -153,6 +189,7 @@ class Cuda_popc(ConcreteTemplate):
153
189
  Supported types from `llvm.popc`
154
190
  [here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
155
191
  """
192
+
156
193
  key = cuda.popc
157
194
  cases = [
158
195
  signature(types.int8, types.int8),
@@ -172,6 +209,7 @@ class Cuda_fma(ConcreteTemplate):
172
209
  Supported types from `llvm.fma`
173
210
  [here](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#standard-c-library-intrinics)
174
211
  """
212
+
175
213
  key = cuda.fma
176
214
  cases = [
177
215
  signature(types.float32, types.float32, types.float32, types.float32),
@@ -189,7 +227,6 @@ class Cuda_hfma(ConcreteTemplate):
189
227
 
190
228
  @register
191
229
  class Cuda_cbrt(ConcreteTemplate):
192
-
193
230
  key = cuda.cbrt
194
231
  cases = [
195
232
  signature(types.float32, types.float32),
@@ -212,6 +249,7 @@ class Cuda_clz(ConcreteTemplate):
212
249
  Supported types from `llvm.ctlz`
213
250
  [here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
214
251
  """
252
+
215
253
  key = cuda.clz
216
254
  cases = [
217
255
  signature(types.int8, types.int8),
@@ -231,6 +269,7 @@ class Cuda_ffs(ConcreteTemplate):
231
269
  Supported types from `llvm.cttz`
232
270
  [here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
233
271
  """
272
+
234
273
  key = cuda.ffs
235
274
  cases = [
236
275
  signature(types.uint32, types.int8),
@@ -254,10 +293,16 @@ class Cuda_selp(AbstractTemplate):
254
293
 
255
294
  # per docs
256
295
  # http://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp
257
- supported_types = (types.float64, types.float32,
258
- types.int16, types.uint16,
259
- types.int32, types.uint32,
260
- types.int64, types.uint64)
296
+ supported_types = (
297
+ types.float64,
298
+ types.float32,
299
+ types.int16,
300
+ types.uint16,
301
+ types.int32,
302
+ types.uint32,
303
+ types.int64,
304
+ types.uint64,
305
+ )
261
306
 
262
307
  if a != b or a not in supported_types:
263
308
  return
@@ -298,7 +343,6 @@ def _genfp16_binary(l_key):
298
343
 
299
344
  @register_global(float)
300
345
  class Float(AbstractTemplate):
301
-
302
346
  def generic(self, args, kws):
303
347
  assert not kws
304
348
 
@@ -313,11 +357,11 @@ def _genfp16_binary_comparison(l_key):
313
357
  class Cuda_fp16_cmp(ConcreteTemplate):
314
358
  key = l_key
315
359
 
316
- cases = [
317
- signature(types.b1, types.float16, types.float16)
318
- ]
360
+ cases = [signature(types.b1, types.float16, types.float16)]
361
+
319
362
  return Cuda_fp16_cmp
320
363
 
364
+
321
365
  # If multiple ConcreteTemplates provide typing for a single function, then
322
366
  # function resolution will pick the first compatible typing it finds even if it
323
367
  # involves inserting a cast that would be considered undesirable (in this
@@ -340,9 +384,10 @@ def _fp16_binary_operator(l_key, retty):
340
384
  def generic(self, args, kws):
341
385
  assert not kws
342
386
 
343
- if len(args) == 2 and \
344
- (args[0] == types.float16 or args[1] == types.float16):
345
- if (args[0] == types.float16):
387
+ if len(args) == 2 and (
388
+ args[0] == types.float16 or args[1] == types.float16
389
+ ):
390
+ if args[0] == types.float16:
346
391
  convertible = self.context.can_convert(args[1], args[0])
347
392
  else:
348
393
  convertible = self.context.can_convert(args[0], args[1])
@@ -355,9 +400,11 @@ def _fp16_binary_operator(l_key, retty):
355
400
  # 3. fp16 to int8 (safe conversion) -
356
401
  # - Conversion.safe
357
402
 
358
- if (convertible == Conversion.exact) or \
359
- (convertible == Conversion.promote) or \
360
- (convertible == Conversion.safe):
403
+ if (
404
+ (convertible == Conversion.exact)
405
+ or (convertible == Conversion.promote)
406
+ or (convertible == Conversion.safe)
407
+ ):
361
408
  return signature(retty, types.float16, types.float16)
362
409
 
363
410
  return Cuda_fp16_operator
@@ -404,38 +451,42 @@ _genfp16_binary_operator(operator.itruediv)
404
451
 
405
452
  def _resolve_wrapped_unary(fname):
406
453
  link = tuple()
407
- decl = declare_device_function_template(f'__numba_wrapper_{fname}',
408
- types.float16,
409
- (types.float16,),
410
- link)
454
+ decl = declare_device_function_template(
455
+ f"__numba_wrapper_{fname}", types.float16, (types.float16,), link
456
+ )
411
457
  return types.Function(decl)
412
458
 
413
459
 
414
460
  def _resolve_wrapped_binary(fname):
415
461
  link = tuple()
416
- decl = declare_device_function_template(f'__numba_wrapper_{fname}',
417
- types.float16,
418
- (types.float16, types.float16,),
419
- link)
462
+ decl = declare_device_function_template(
463
+ f"__numba_wrapper_{fname}",
464
+ types.float16,
465
+ (
466
+ types.float16,
467
+ types.float16,
468
+ ),
469
+ link,
470
+ )
420
471
  return types.Function(decl)
421
472
 
422
473
 
423
- hsin_device = _resolve_wrapped_unary('hsin')
424
- hcos_device = _resolve_wrapped_unary('hcos')
425
- hlog_device = _resolve_wrapped_unary('hlog')
426
- hlog10_device = _resolve_wrapped_unary('hlog10')
427
- hlog2_device = _resolve_wrapped_unary('hlog2')
428
- hexp_device = _resolve_wrapped_unary('hexp')
429
- hexp10_device = _resolve_wrapped_unary('hexp10')
430
- hexp2_device = _resolve_wrapped_unary('hexp2')
431
- hsqrt_device = _resolve_wrapped_unary('hsqrt')
432
- hrsqrt_device = _resolve_wrapped_unary('hrsqrt')
433
- hfloor_device = _resolve_wrapped_unary('hfloor')
434
- hceil_device = _resolve_wrapped_unary('hceil')
435
- hrcp_device = _resolve_wrapped_unary('hrcp')
436
- hrint_device = _resolve_wrapped_unary('hrint')
437
- htrunc_device = _resolve_wrapped_unary('htrunc')
438
- hdiv_device = _resolve_wrapped_binary('hdiv')
474
+ hsin_device = _resolve_wrapped_unary("hsin")
475
+ hcos_device = _resolve_wrapped_unary("hcos")
476
+ hlog_device = _resolve_wrapped_unary("hlog")
477
+ hlog10_device = _resolve_wrapped_unary("hlog10")
478
+ hlog2_device = _resolve_wrapped_unary("hlog2")
479
+ hexp_device = _resolve_wrapped_unary("hexp")
480
+ hexp10_device = _resolve_wrapped_unary("hexp10")
481
+ hexp2_device = _resolve_wrapped_unary("hexp2")
482
+ hsqrt_device = _resolve_wrapped_unary("hsqrt")
483
+ hrsqrt_device = _resolve_wrapped_unary("hrsqrt")
484
+ hfloor_device = _resolve_wrapped_unary("hfloor")
485
+ hceil_device = _resolve_wrapped_unary("hceil")
486
+ hrcp_device = _resolve_wrapped_unary("hrcp")
487
+ hrint_device = _resolve_wrapped_unary("hrint")
488
+ htrunc_device = _resolve_wrapped_unary("htrunc")
489
+ hdiv_device = _resolve_wrapped_binary("hdiv")
439
490
 
440
491
 
441
492
  # generate atomic operations
@@ -455,15 +506,20 @@ def _gen(l_key, supported_types):
455
506
  return signature(ary.dtype, ary, types.intp, ary.dtype)
456
507
  elif ary.ndim > 1:
457
508
  return signature(ary.dtype, ary, idx, ary.dtype)
509
+
458
510
  return Cuda_atomic
459
511
 
460
512
 
461
- all_numba_types = (types.float64, types.float32,
462
- types.int32, types.uint32,
463
- types.int64, types.uint64)
513
+ all_numba_types = (
514
+ types.float64,
515
+ types.float32,
516
+ types.int32,
517
+ types.uint32,
518
+ types.int64,
519
+ types.uint64,
520
+ )
464
521
 
465
- integer_numba_types = (types.int32, types.uint32,
466
- types.int64, types.uint64)
522
+ integer_numba_types = (types.int32, types.uint32, types.int64, types.uint64)
467
523
 
468
524
  unsigned_int_numba_types = (types.uint32, types.uint64)
469
525
 
@@ -811,5 +867,5 @@ for func in bit_twiddling_functions:
811
867
  register_numpy_ufunc(func, register_global)
812
868
 
813
869
  for func in math_operations:
814
- if func in ('log', 'log2', 'log10'):
870
+ if func in ("log", "log2", "log10"):
815
871
  register_numpy_ufunc(func, register_global)
@@ -5,5 +5,7 @@
5
5
  - Device array implementation
6
6
 
7
7
  """
8
+
8
9
  from numba.core import config
9
- assert not config.ENABLE_CUDASIM, 'Cannot use real driver API with simulator'
10
+
11
+ assert not config.ENABLE_CUDASIM, "Cannot use real driver API with simulator"