numba-cuda 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237) hide show
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +246 -114
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
  13. numba_cuda/numba/cuda/cuda_paths.py +293 -99
  14. numba_cuda/numba/cuda/cudadecl.py +93 -79
  15. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  16. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  17. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  18. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  19. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  20. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  21. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  22. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  23. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  24. numba_cuda/numba/cuda/cudadrv/linkable_code.py +27 -3
  25. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  26. numba_cuda/numba/cuda/cudadrv/nvrtc.py +146 -30
  27. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  28. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  29. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  30. numba_cuda/numba/cuda/cudaimpl.py +296 -275
  31. numba_cuda/numba/cuda/cudamath.py +1 -1
  32. numba_cuda/numba/cuda/debuginfo.py +99 -7
  33. numba_cuda/numba/cuda/decorators.py +87 -45
  34. numba_cuda/numba/cuda/descriptor.py +1 -1
  35. numba_cuda/numba/cuda/device_init.py +68 -18
  36. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  37. numba_cuda/numba/cuda/dispatcher.py +300 -213
  38. numba_cuda/numba/cuda/errors.py +13 -10
  39. numba_cuda/numba/cuda/extending.py +55 -1
  40. numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
  41. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
  42. numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +1090 -927
  43. numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +468 -319
  44. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  45. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  46. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  47. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  48. numba_cuda/numba/cuda/initialize.py +5 -3
  49. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
  50. numba_cuda/numba/cuda/intrinsics.py +203 -28
  51. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  52. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  53. numba_cuda/numba/cuda/libdevice.py +317 -317
  54. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  55. numba_cuda/numba/cuda/locks.py +16 -0
  56. numba_cuda/numba/cuda/lowering.py +43 -0
  57. numba_cuda/numba/cuda/mathimpl.py +62 -57
  58. numba_cuda/numba/cuda/models.py +1 -5
  59. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  60. numba_cuda/numba/cuda/printimpl.py +9 -5
  61. numba_cuda/numba/cuda/random.py +46 -36
  62. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  63. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  64. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  65. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  66. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  67. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  68. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  69. numba_cuda/numba/cuda/simulator/api.py +38 -22
  70. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  71. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  72. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  73. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  74. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  75. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  76. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  77. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  78. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  79. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  80. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  81. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  82. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  83. numba_cuda/numba/cuda/simulator_init.py +2 -4
  84. numba_cuda/numba/cuda/stubs.py +134 -108
  85. numba_cuda/numba/cuda/target.py +92 -47
  86. numba_cuda/numba/cuda/testing.py +24 -19
  87. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  88. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  89. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  90. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  91. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  92. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  93. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  94. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  95. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  96. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  97. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  98. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  99. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  100. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  102. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  103. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  104. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  105. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  107. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  108. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  109. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  110. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  111. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  112. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  113. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  114. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  115. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  117. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  118. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  119. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +10 -7
  120. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  121. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  123. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  124. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  125. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  127. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +59 -23
  129. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  130. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  131. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  132. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  133. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  134. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  135. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  136. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  137. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  138. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  139. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  140. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  141. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  142. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  143. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +77 -28
  144. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  145. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  146. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +24 -7
  147. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  148. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  149. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +21 -12
  150. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  151. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  152. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  153. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  154. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  155. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  156. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  157. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  158. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  159. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +59 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  161. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  162. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  163. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  164. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  165. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +7 -7
  166. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  167. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  168. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  169. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  170. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  171. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  172. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  173. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  174. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  175. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  176. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  178. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  179. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  180. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  181. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  182. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  183. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  184. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  185. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  186. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  187. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  188. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  189. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  190. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  191. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  192. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  193. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  194. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  195. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  196. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  197. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  198. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  199. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  200. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  201. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  202. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  203. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  204. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  205. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +81 -30
  206. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  207. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  208. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  209. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  210. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  211. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  212. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  213. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  214. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  216. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  217. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  218. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  219. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  220. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  221. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  222. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  223. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  224. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  225. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  226. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  227. numba_cuda/numba/cuda/types.py +5 -2
  228. numba_cuda/numba/cuda/ufuncs.py +382 -362
  229. numba_cuda/numba/cuda/utils.py +2 -2
  230. numba_cuda/numba/cuda/vector_types.py +5 -3
  231. numba_cuda/numba/cuda/vectorizers.py +38 -33
  232. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/METADATA +1 -1
  233. numba_cuda-0.10.0.dist-info/RECORD +263 -0
  234. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/WHEEL +1 -1
  235. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  236. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/licenses/LICENSE +0 -0
  237. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,23 @@
1
1
  import operator
2
2
  from numba.core import types
3
- from numba.core.typing.npydecl import (parse_dtype, parse_shape,
4
- register_number_classes,
5
- register_numpy_ufunc,
6
- trigonometric_functions,
7
- comparison_functions,
8
- math_operations,
9
- bit_twiddling_functions)
10
- from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate,
11
- AbstractTemplate, CallableTemplate,
12
- signature, Registry)
3
+ from numba.core.typing.npydecl import (
4
+ parse_dtype,
5
+ parse_shape,
6
+ register_number_classes,
7
+ register_numpy_ufunc,
8
+ trigonometric_functions,
9
+ comparison_functions,
10
+ math_operations,
11
+ bit_twiddling_functions,
12
+ )
13
+ from numba.core.typing.templates import (
14
+ AttributeTemplate,
15
+ ConcreteTemplate,
16
+ AbstractTemplate,
17
+ CallableTemplate,
18
+ signature,
19
+ Registry,
20
+ )
13
21
  from numba.cuda.types import dim3
14
22
  from numba.core.typeconv import Conversion
15
23
  from numba import cuda
@@ -26,15 +34,15 @@ register_number_classes(register_global)
26
34
  class Cuda_array_decl(CallableTemplate):
27
35
  def generic(self):
28
36
  def typer(shape, dtype):
29
-
30
37
  # Only integer literals and tuples of integer literals are valid
31
38
  # shapes
32
39
  if isinstance(shape, types.Integer):
33
40
  if not isinstance(shape, types.IntegerLiteral):
34
41
  return None
35
42
  elif isinstance(shape, (types.Tuple, types.UniTuple)):
36
- if any([not isinstance(s, types.IntegerLiteral)
37
- for s in shape]):
43
+ if any(
44
+ [not isinstance(s, types.IntegerLiteral) for s in shape]
45
+ ):
38
46
  return None
39
47
  else:
40
48
  return None
@@ -42,7 +50,7 @@ class Cuda_array_decl(CallableTemplate):
42
50
  ndim = parse_shape(shape)
43
51
  nb_dtype = parse_dtype(dtype)
44
52
  if nb_dtype is not None and ndim is not None:
45
- return types.Array(dtype=nb_dtype, ndim=ndim, layout='C')
53
+ return types.Array(dtype=nb_dtype, ndim=ndim, layout="C")
46
54
 
47
55
  return typer
48
56
 
@@ -64,6 +72,7 @@ class Cuda_const_array_like(CallableTemplate):
64
72
  def generic(self):
65
73
  def typer(ndarray):
66
74
  return ndarray
75
+
67
76
  return typer
68
77
 
69
78
 
@@ -91,26 +100,14 @@ class Cuda_syncwarp(ConcreteTemplate):
91
100
  cases = [signature(types.none), signature(types.none, types.i4)]
92
101
 
93
102
 
94
- @register
95
- class Cuda_shfl_sync_intrinsic(ConcreteTemplate):
96
- key = cuda.shfl_sync_intrinsic
97
- cases = [
98
- signature(types.Tuple((types.i4, types.b1)),
99
- types.i4, types.i4, types.i4, types.i4, types.i4),
100
- signature(types.Tuple((types.i8, types.b1)),
101
- types.i4, types.i4, types.i8, types.i4, types.i4),
102
- signature(types.Tuple((types.f4, types.b1)),
103
- types.i4, types.i4, types.f4, types.i4, types.i4),
104
- signature(types.Tuple((types.f8, types.b1)),
105
- types.i4, types.i4, types.f8, types.i4, types.i4),
106
- ]
107
-
108
-
109
103
  @register
110
104
  class Cuda_vote_sync_intrinsic(ConcreteTemplate):
111
105
  key = cuda.vote_sync_intrinsic
112
- cases = [signature(types.Tuple((types.i4, types.b1)),
113
- types.i4, types.i4, types.b1)]
106
+ cases = [
107
+ signature(
108
+ types.Tuple((types.i4, types.b1)), types.i4, types.i4, types.b1
109
+ )
110
+ ]
114
111
 
115
112
 
116
113
  @register
@@ -153,6 +150,7 @@ class Cuda_popc(ConcreteTemplate):
153
150
  Supported types from `llvm.popc`
154
151
  [here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
155
152
  """
153
+
156
154
  key = cuda.popc
157
155
  cases = [
158
156
  signature(types.int8, types.int8),
@@ -172,6 +170,7 @@ class Cuda_fma(ConcreteTemplate):
172
170
  Supported types from `llvm.fma`
173
171
  [here](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#standard-c-library-intrinics)
174
172
  """
173
+
175
174
  key = cuda.fma
176
175
  cases = [
177
176
  signature(types.float32, types.float32, types.float32, types.float32),
@@ -189,7 +188,6 @@ class Cuda_hfma(ConcreteTemplate):
189
188
 
190
189
  @register
191
190
  class Cuda_cbrt(ConcreteTemplate):
192
-
193
191
  key = cuda.cbrt
194
192
  cases = [
195
193
  signature(types.float32, types.float32),
@@ -212,6 +210,7 @@ class Cuda_clz(ConcreteTemplate):
212
210
  Supported types from `llvm.ctlz`
213
211
  [here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
214
212
  """
213
+
215
214
  key = cuda.clz
216
215
  cases = [
217
216
  signature(types.int8, types.int8),
@@ -231,6 +230,7 @@ class Cuda_ffs(ConcreteTemplate):
231
230
  Supported types from `llvm.cttz`
232
231
  [here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
233
232
  """
233
+
234
234
  key = cuda.ffs
235
235
  cases = [
236
236
  signature(types.uint32, types.int8),
@@ -254,10 +254,16 @@ class Cuda_selp(AbstractTemplate):
254
254
 
255
255
  # per docs
256
256
  # http://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp
257
- supported_types = (types.float64, types.float32,
258
- types.int16, types.uint16,
259
- types.int32, types.uint32,
260
- types.int64, types.uint64)
257
+ supported_types = (
258
+ types.float64,
259
+ types.float32,
260
+ types.int16,
261
+ types.uint16,
262
+ types.int32,
263
+ types.uint32,
264
+ types.int64,
265
+ types.uint64,
266
+ )
261
267
 
262
268
  if a != b or a not in supported_types:
263
269
  return
@@ -298,7 +304,6 @@ def _genfp16_binary(l_key):
298
304
 
299
305
  @register_global(float)
300
306
  class Float(AbstractTemplate):
301
-
302
307
  def generic(self, args, kws):
303
308
  assert not kws
304
309
 
@@ -313,11 +318,11 @@ def _genfp16_binary_comparison(l_key):
313
318
  class Cuda_fp16_cmp(ConcreteTemplate):
314
319
  key = l_key
315
320
 
316
- cases = [
317
- signature(types.b1, types.float16, types.float16)
318
- ]
321
+ cases = [signature(types.b1, types.float16, types.float16)]
322
+
319
323
  return Cuda_fp16_cmp
320
324
 
325
+
321
326
  # If multiple ConcreteTemplates provide typing for a single function, then
322
327
  # function resolution will pick the first compatible typing it finds even if it
323
328
  # involves inserting a cast that would be considered undesirable (in this
@@ -340,9 +345,10 @@ def _fp16_binary_operator(l_key, retty):
340
345
  def generic(self, args, kws):
341
346
  assert not kws
342
347
 
343
- if len(args) == 2 and \
344
- (args[0] == types.float16 or args[1] == types.float16):
345
- if (args[0] == types.float16):
348
+ if len(args) == 2 and (
349
+ args[0] == types.float16 or args[1] == types.float16
350
+ ):
351
+ if args[0] == types.float16:
346
352
  convertible = self.context.can_convert(args[1], args[0])
347
353
  else:
348
354
  convertible = self.context.can_convert(args[0], args[1])
@@ -355,9 +361,11 @@ def _fp16_binary_operator(l_key, retty):
355
361
  # 3. fp16 to int8 (safe conversion) -
356
362
  # - Conversion.safe
357
363
 
358
- if (convertible == Conversion.exact) or \
359
- (convertible == Conversion.promote) or \
360
- (convertible == Conversion.safe):
364
+ if (
365
+ (convertible == Conversion.exact)
366
+ or (convertible == Conversion.promote)
367
+ or (convertible == Conversion.safe)
368
+ ):
361
369
  return signature(retty, types.float16, types.float16)
362
370
 
363
371
  return Cuda_fp16_operator
@@ -404,38 +412,42 @@ _genfp16_binary_operator(operator.itruediv)
404
412
 
405
413
  def _resolve_wrapped_unary(fname):
406
414
  link = tuple()
407
- decl = declare_device_function_template(f'__numba_wrapper_{fname}',
408
- types.float16,
409
- (types.float16,),
410
- link)
415
+ decl = declare_device_function_template(
416
+ f"__numba_wrapper_{fname}", types.float16, (types.float16,), link
417
+ )
411
418
  return types.Function(decl)
412
419
 
413
420
 
414
421
  def _resolve_wrapped_binary(fname):
415
422
  link = tuple()
416
- decl = declare_device_function_template(f'__numba_wrapper_{fname}',
417
- types.float16,
418
- (types.float16, types.float16,),
419
- link)
423
+ decl = declare_device_function_template(
424
+ f"__numba_wrapper_{fname}",
425
+ types.float16,
426
+ (
427
+ types.float16,
428
+ types.float16,
429
+ ),
430
+ link,
431
+ )
420
432
  return types.Function(decl)
421
433
 
422
434
 
423
- hsin_device = _resolve_wrapped_unary('hsin')
424
- hcos_device = _resolve_wrapped_unary('hcos')
425
- hlog_device = _resolve_wrapped_unary('hlog')
426
- hlog10_device = _resolve_wrapped_unary('hlog10')
427
- hlog2_device = _resolve_wrapped_unary('hlog2')
428
- hexp_device = _resolve_wrapped_unary('hexp')
429
- hexp10_device = _resolve_wrapped_unary('hexp10')
430
- hexp2_device = _resolve_wrapped_unary('hexp2')
431
- hsqrt_device = _resolve_wrapped_unary('hsqrt')
432
- hrsqrt_device = _resolve_wrapped_unary('hrsqrt')
433
- hfloor_device = _resolve_wrapped_unary('hfloor')
434
- hceil_device = _resolve_wrapped_unary('hceil')
435
- hrcp_device = _resolve_wrapped_unary('hrcp')
436
- hrint_device = _resolve_wrapped_unary('hrint')
437
- htrunc_device = _resolve_wrapped_unary('htrunc')
438
- hdiv_device = _resolve_wrapped_binary('hdiv')
435
+ hsin_device = _resolve_wrapped_unary("hsin")
436
+ hcos_device = _resolve_wrapped_unary("hcos")
437
+ hlog_device = _resolve_wrapped_unary("hlog")
438
+ hlog10_device = _resolve_wrapped_unary("hlog10")
439
+ hlog2_device = _resolve_wrapped_unary("hlog2")
440
+ hexp_device = _resolve_wrapped_unary("hexp")
441
+ hexp10_device = _resolve_wrapped_unary("hexp10")
442
+ hexp2_device = _resolve_wrapped_unary("hexp2")
443
+ hsqrt_device = _resolve_wrapped_unary("hsqrt")
444
+ hrsqrt_device = _resolve_wrapped_unary("hrsqrt")
445
+ hfloor_device = _resolve_wrapped_unary("hfloor")
446
+ hceil_device = _resolve_wrapped_unary("hceil")
447
+ hrcp_device = _resolve_wrapped_unary("hrcp")
448
+ hrint_device = _resolve_wrapped_unary("hrint")
449
+ htrunc_device = _resolve_wrapped_unary("htrunc")
450
+ hdiv_device = _resolve_wrapped_binary("hdiv")
439
451
 
440
452
 
441
453
  # generate atomic operations
@@ -455,15 +467,20 @@ def _gen(l_key, supported_types):
455
467
  return signature(ary.dtype, ary, types.intp, ary.dtype)
456
468
  elif ary.ndim > 1:
457
469
  return signature(ary.dtype, ary, idx, ary.dtype)
470
+
458
471
  return Cuda_atomic
459
472
 
460
473
 
461
- all_numba_types = (types.float64, types.float32,
462
- types.int32, types.uint32,
463
- types.int64, types.uint64)
474
+ all_numba_types = (
475
+ types.float64,
476
+ types.float32,
477
+ types.int32,
478
+ types.uint32,
479
+ types.int64,
480
+ types.uint64,
481
+ )
464
482
 
465
- integer_numba_types = (types.int32, types.uint32,
466
- types.int64, types.uint64)
483
+ integer_numba_types = (types.int32, types.uint32, types.int64, types.uint64)
467
484
 
468
485
  unsigned_int_numba_types = (types.uint32, types.uint64)
469
486
 
@@ -759,9 +776,6 @@ class CudaModuleTemplate(AttributeTemplate):
759
776
  def resolve_syncwarp(self, mod):
760
777
  return types.Function(Cuda_syncwarp)
761
778
 
762
- def resolve_shfl_sync_intrinsic(self, mod):
763
- return types.Function(Cuda_shfl_sync_intrinsic)
764
-
765
779
  def resolve_vote_sync_intrinsic(self, mod):
766
780
  return types.Function(Cuda_vote_sync_intrinsic)
767
781
 
@@ -811,5 +825,5 @@ for func in bit_twiddling_functions:
811
825
  register_numpy_ufunc(func, register_global)
812
826
 
813
827
  for func in math_operations:
814
- if func in ('log', 'log2', 'log10'):
828
+ if func in ("log", "log2", "log10"):
815
829
  register_numpy_ufunc(func, register_global)
@@ -5,5 +5,7 @@
5
5
  - Device array implementation
6
6
 
7
7
  """
8
+
8
9
  from numba.core import config
9
- assert not config.ENABLE_CUDASIM, 'Cannot use real driver API with simulator'
10
+
11
+ assert not config.ENABLE_CUDASIM, "Cannot use real driver API with simulator"