numba-cuda 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227) hide show
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,11 @@
1
1
  from math import sqrt
2
2
  from numba import cuda, float32, int16, int32, int64, uint32, void
3
- from numba.cuda import (compile, compile_for_current_device, compile_ptx,
4
- compile_ptx_for_current_device)
3
+ from numba.cuda import (
4
+ compile,
5
+ compile_for_current_device,
6
+ compile_ptx,
7
+ compile_ptx_for_current_device,
8
+ )
5
9
  from numba.cuda.cudadrv import runtime
6
10
  from numba.cuda.testing import skip_on_cudasim, unittest, CUDATestCase
7
11
 
@@ -12,7 +16,7 @@ def f_module(x, y):
12
16
  return x + y
13
17
 
14
18
 
15
- @skip_on_cudasim('Compilation unsupported in the simulator')
19
+ @skip_on_cudasim("Compilation unsupported in the simulator")
16
20
  class TestCompile(unittest.TestCase):
17
21
  def test_global_kernel(self):
18
22
  def f(r, x, y):
@@ -24,11 +28,11 @@ class TestCompile(unittest.TestCase):
24
28
  ptx, resty = compile_ptx(f, args)
25
29
 
26
30
  # Kernels should not have a func_retval parameter
27
- self.assertNotIn('func_retval', ptx)
31
+ self.assertNotIn("func_retval", ptx)
28
32
  # .visible .func is used to denote a device function
29
- self.assertNotIn('.visible .func', ptx)
33
+ self.assertNotIn(".visible .func", ptx)
30
34
  # .visible .entry would denote the presence of a global function
31
- self.assertIn('.visible .entry', ptx)
35
+ self.assertIn(".visible .entry", ptx)
32
36
  # Return type for kernels should always be void
33
37
  self.assertEqual(resty, void)
34
38
 
@@ -41,11 +45,11 @@ class TestCompile(unittest.TestCase):
41
45
 
42
46
  # Device functions take a func_retval parameter for storing the
43
47
  # returned value in by reference
44
- self.assertIn('func_retval', ptx)
48
+ self.assertIn("func_retval", ptx)
45
49
  # .visible .func is used to denote a device function
46
- self.assertIn('.visible .func', ptx)
50
+ self.assertIn(".visible .func", ptx)
47
51
  # .visible .entry would denote the presence of a global function
48
- self.assertNotIn('.visible .entry', ptx)
52
+ self.assertNotIn(".visible .entry", ptx)
49
53
  # Inferred return type as expected?
50
54
  self.assertEqual(resty, float32)
51
55
 
@@ -71,21 +75,21 @@ class TestCompile(unittest.TestCase):
71
75
 
72
76
  # Without fastmath, fma contraction is enabled by default, but ftz and
73
77
  # approximate div / sqrt is not.
74
- self.assertIn('fma.rn.f32', ptx)
75
- self.assertIn('div.rn.f32', ptx)
76
- self.assertIn('sqrt.rn.f32', ptx)
78
+ self.assertIn("fma.rn.f32", ptx)
79
+ self.assertIn("div.rn.f32", ptx)
80
+ self.assertIn("sqrt.rn.f32", ptx)
77
81
 
78
82
  ptx, resty = compile_ptx(f, args, device=True, fastmath=True)
79
83
 
80
84
  # With fastmath, ftz and approximate div / sqrt are enabled
81
- self.assertIn('fma.rn.ftz.f32', ptx)
82
- self.assertIn('div.approx.ftz.f32', ptx)
83
- self.assertIn('sqrt.approx.ftz.f32', ptx)
85
+ self.assertIn("fma.rn.ftz.f32", ptx)
86
+ self.assertIn("div.approx.ftz.f32", ptx)
87
+ self.assertIn("sqrt.approx.ftz.f32", ptx)
84
88
 
85
89
  def check_debug_info(self, ptx):
86
90
  # A debug_info section should exist in the PTX. Whitespace varies
87
91
  # between CUDA toolkit versions.
88
- self.assertRegex(ptx, '\\.section\\s+\\.debug_info')
92
+ self.assertRegex(ptx, "\\.section\\s+\\.debug_info")
89
93
  # A .file directive should be produced and include the name of the
90
94
  # source. The path and whitespace may vary, so we accept anything
91
95
  # ending in the filename of this module.
@@ -136,23 +140,25 @@ class TestCompile(unittest.TestCase):
136
140
  def f(x, y):
137
141
  return x[0] + y[0]
138
142
 
139
- with self.assertRaisesRegex(TypeError, 'must have void return type'):
143
+ with self.assertRaisesRegex(TypeError, "must have void return type"):
140
144
  compile_ptx(f, (uint32[::1], uint32[::1]))
141
145
 
142
146
  def test_c_abi_disallowed_for_kernel(self):
143
147
  def f(x, y):
144
148
  return x + y
145
149
 
146
- with self.assertRaisesRegex(NotImplementedError,
147
- "The C ABI is not supported for kernels"):
150
+ with self.assertRaisesRegex(
151
+ NotImplementedError, "The C ABI is not supported for kernels"
152
+ ):
148
153
  compile_ptx(f, (int32, int32), abi="c")
149
154
 
150
155
  def test_unsupported_abi(self):
151
156
  def f(x, y):
152
157
  return x + y
153
158
 
154
- with self.assertRaisesRegex(NotImplementedError,
155
- "Unsupported ABI: fastcall"):
159
+ with self.assertRaisesRegex(
160
+ NotImplementedError, "Unsupported ABI: fastcall"
161
+ ):
156
162
  compile_ptx(f, (int32, int32), abi="fastcall")
157
163
 
158
164
  def test_c_abi_device_function(self):
@@ -166,8 +172,11 @@ class TestCompile(unittest.TestCase):
166
172
  # The function name should match the Python function name (not the
167
173
  # qualname, which includes additional info), and its return value
168
174
  # should be 32 bits
169
- self.assertRegex(ptx, r"\.visible\s+\.func\s+\(\.param\s+\.b32\s+"
170
- r"func_retval0\)\s+f\(")
175
+ self.assertRegex(
176
+ ptx,
177
+ r"\.visible\s+\.func\s+\(\.param\s+\.b32\s+"
178
+ r"func_retval0\)\s+f\(",
179
+ )
171
180
 
172
181
  # If we compile for 64-bit integers, the return type should be 64 bits
173
182
  # wide
@@ -175,44 +184,60 @@ class TestCompile(unittest.TestCase):
175
184
  self.assertRegex(ptx, r"\.visible\s+\.func\s+\(\.param\s+\.b64")
176
185
 
177
186
  def test_c_abi_device_function_module_scope(self):
178
- ptx, resty = compile_ptx(f_module, int32(int32, int32), device=True,
179
- abi="c")
187
+ ptx, resty = compile_ptx(
188
+ f_module, int32(int32, int32), device=True, abi="c"
189
+ )
180
190
 
181
191
  # The function name should match the Python function name, and its
182
192
  # return value should be 32 bits
183
- self.assertRegex(ptx, r"\.visible\s+\.func\s+\(\.param\s+\.b32\s+"
184
- r"func_retval0\)\s+f_module\(")
193
+ self.assertRegex(
194
+ ptx,
195
+ r"\.visible\s+\.func\s+\(\.param\s+\.b32\s+"
196
+ r"func_retval0\)\s+f_module\(",
197
+ )
185
198
 
186
199
  def test_c_abi_with_abi_name(self):
187
- abi_info = {'abi_name': '_Z4funcii'}
188
- ptx, resty = compile_ptx(f_module, int32(int32, int32), device=True,
189
- abi="c", abi_info=abi_info)
200
+ abi_info = {"abi_name": "_Z4funcii"}
201
+ ptx, resty = compile_ptx(
202
+ f_module,
203
+ int32(int32, int32),
204
+ device=True,
205
+ abi="c",
206
+ abi_info=abi_info,
207
+ )
190
208
 
191
209
  # The function name should match the one given in the ABI info, and its
192
210
  # return value should be 32 bits
193
- self.assertRegex(ptx, r"\.visible\s+\.func\s+\(\.param\s+\.b32\s+"
194
- r"func_retval0\)\s+_Z4funcii\(")
211
+ self.assertRegex(
212
+ ptx,
213
+ r"\.visible\s+\.func\s+\(\.param\s+\.b32\s+"
214
+ r"func_retval0\)\s+_Z4funcii\(",
215
+ )
195
216
 
196
217
  def test_compile_defaults_to_c_abi(self):
197
218
  ptx, resty = compile(f_module, int32(int32, int32), device=True)
198
219
 
199
220
  # The function name should match the Python function name, and its
200
221
  # return value should be 32 bits
201
- self.assertRegex(ptx, r"\.visible\s+\.func\s+\(\.param\s+\.b32\s+"
202
- r"func_retval0\)\s+f_module\(")
222
+ self.assertRegex(
223
+ ptx,
224
+ r"\.visible\s+\.func\s+\(\.param\s+\.b32\s+"
225
+ r"func_retval0\)\s+f_module\(",
226
+ )
203
227
 
204
228
  def test_compile_to_ltoir(self):
205
229
  if runtime.get_version() < (11, 5):
206
230
  self.skipTest("-gen-lto unavailable in this toolkit version")
207
231
 
208
- ltoir, resty = compile(f_module, int32(int32, int32), device=True,
209
- output="ltoir")
232
+ ltoir, resty = compile(
233
+ f_module, int32(int32, int32), device=True, output="ltoir"
234
+ )
210
235
 
211
236
  # There are no tools to interpret the LTOIR output, but we can check
212
237
  # that we appear to have obtained an LTOIR file. This magic number is
213
238
  # not documented, but is expected to remain consistent.
214
239
  LTOIR_MAGIC = 0x7F4E43ED
215
- header = int.from_bytes(ltoir[:4], byteorder='little')
240
+ header = int.from_bytes(ltoir[:4], byteorder="little")
216
241
  self.assertEqual(header, LTOIR_MAGIC)
217
242
  self.assertEqual(resty, int32)
218
243
 
@@ -220,11 +245,15 @@ class TestCompile(unittest.TestCase):
220
245
  illegal_output = "illegal"
221
246
  msg = f"Unsupported output type: {illegal_output}"
222
247
  with self.assertRaisesRegex(NotImplementedError, msg):
223
- compile(f_module, int32(int32, int32), device=True,
224
- output=illegal_output)
248
+ compile(
249
+ f_module,
250
+ int32(int32, int32),
251
+ device=True,
252
+ output=illegal_output,
253
+ )
225
254
 
226
255
 
227
- @skip_on_cudasim('Compilation unsupported in the simulator')
256
+ @skip_on_cudasim("Compilation unsupported in the simulator")
228
257
  class TestCompileForCurrentDevice(CUDATestCase):
229
258
  def _check_ptx_for_current_device(self, compile_function):
230
259
  def add(x, y):
@@ -237,7 +266,7 @@ class TestCompileForCurrentDevice(CUDATestCase):
237
266
  # closest compute capability supported by the current toolkit.
238
267
  device_cc = cuda.get_current_device().compute_capability
239
268
  cc = cuda.cudadrv.nvvm.find_closest_arch(device_cc)
240
- target = f'.target sm_{cc[0]}{cc[1]}'
269
+ target = f".target sm_{cc[0]}{cc[1]}"
241
270
  self.assertIn(target, ptx)
242
271
 
243
272
  def test_compile_ptx_for_current_device(self):
@@ -247,10 +276,10 @@ class TestCompileForCurrentDevice(CUDATestCase):
247
276
  self._check_ptx_for_current_device(compile_for_current_device)
248
277
 
249
278
 
250
- @skip_on_cudasim('Compilation unsupported in the simulator')
279
+ @skip_on_cudasim("Compilation unsupported in the simulator")
251
280
  class TestCompileOnlyTests(unittest.TestCase):
252
- '''For tests where we can only check correctness by examining the compiler
253
- output rather than observing the effects of execution.'''
281
+ """For tests where we can only check correctness by examining the compiler
282
+ output rather than observing the effects of execution."""
254
283
 
255
284
  def test_nanosleep(self):
256
285
  def use_nanosleep(x):
@@ -262,15 +291,20 @@ class TestCompileOnlyTests(unittest.TestCase):
262
291
  ptx, resty = compile_ptx(use_nanosleep, (uint32,), cc=(7, 0))
263
292
 
264
293
  nanosleep_count = 0
265
- for line in ptx.split('\n'):
266
- if 'nanosleep.u32' in line:
294
+ for line in ptx.split("\n"):
295
+ if "nanosleep.u32" in line:
267
296
  nanosleep_count += 1
268
297
 
269
298
  expected = 2
270
- self.assertEqual(expected, nanosleep_count,
271
- (f'Got {nanosleep_count} nanosleep instructions, '
272
- f'expected {expected}'))
299
+ self.assertEqual(
300
+ expected,
301
+ nanosleep_count,
302
+ (
303
+ f"Got {nanosleep_count} nanosleep instructions, "
304
+ f"expected {expected}"
305
+ ),
306
+ )
273
307
 
274
308
 
275
- if __name__ == '__main__':
309
+ if __name__ == "__main__":
276
310
  unittest.main()
@@ -6,20 +6,34 @@ import numpy as np
6
6
  from numba.cuda.testing import unittest, CUDATestCase
7
7
  from numba.core import types
8
8
  from numba import cuda
9
- from numba.tests.complex_usecases import (real_usecase, imag_usecase,
10
- conjugate_usecase, phase_usecase,
11
- polar_as_complex_usecase,
12
- rect_usecase, isnan_usecase,
13
- isinf_usecase, isfinite_usecase,
14
- exp_usecase, log_usecase,
15
- log_base_usecase, log10_usecase,
16
- sqrt_usecase, asin_usecase,
17
- acos_usecase, atan_usecase,
18
- cos_usecase, sin_usecase,
19
- tan_usecase, acosh_usecase,
20
- asinh_usecase, atanh_usecase,
21
- cosh_usecase, sinh_usecase,
22
- tanh_usecase)
9
+ from numba.tests.complex_usecases import (
10
+ real_usecase,
11
+ imag_usecase,
12
+ conjugate_usecase,
13
+ phase_usecase,
14
+ polar_as_complex_usecase,
15
+ rect_usecase,
16
+ isnan_usecase,
17
+ isinf_usecase,
18
+ isfinite_usecase,
19
+ exp_usecase,
20
+ log_usecase,
21
+ log_base_usecase,
22
+ log10_usecase,
23
+ sqrt_usecase,
24
+ asin_usecase,
25
+ acos_usecase,
26
+ atan_usecase,
27
+ cos_usecase,
28
+ sin_usecase,
29
+ tan_usecase,
30
+ acosh_usecase,
31
+ asinh_usecase,
32
+ atanh_usecase,
33
+ cosh_usecase,
34
+ sinh_usecase,
35
+ tanh_usecase,
36
+ )
23
37
  from numba.np import numpy_support
24
38
 
25
39
 
@@ -29,15 +43,18 @@ def compile_scalar_func(pyfunc, argtypes, restype):
29
43
  assert not isinstance(restype, types.Array)
30
44
  device_func = cuda.jit(restype(*argtypes), device=True)(pyfunc)
31
45
 
32
- kernel_types = [types.Array(tp, 1, "C")
33
- for tp in [restype] + list(argtypes)]
46
+ kernel_types = [
47
+ types.Array(tp, 1, "C") for tp in [restype] + list(argtypes)
48
+ ]
34
49
 
35
50
  if len(argtypes) == 1:
51
+
36
52
  def kernel_func(out, a):
37
53
  i = cuda.grid(1)
38
54
  if i < out.shape[0]:
39
55
  out[i] = device_func(a[i])
40
56
  elif len(argtypes) == 2:
57
+
41
58
  def kernel_func(out, a, b):
42
59
  i = cuda.grid(1)
43
60
  if i < out.shape[0]:
@@ -49,8 +66,9 @@ def compile_scalar_func(pyfunc, argtypes, restype):
49
66
 
50
67
  def kernel_wrapper(values):
51
68
  n = len(values)
52
- inputs = [np.empty(n, dtype=numpy_support.as_dtype(tp))
53
- for tp in argtypes]
69
+ inputs = [
70
+ np.empty(n, dtype=numpy_support.as_dtype(tp)) for tp in argtypes
71
+ ]
54
72
  output = np.empty(n, dtype=numpy_support.as_dtype(restype))
55
73
  for i, vs in enumerate(values):
56
74
  for v, inp in zip(vs, inputs):
@@ -58,42 +76,70 @@ def compile_scalar_func(pyfunc, argtypes, restype):
58
76
  args = [output] + inputs
59
77
  kernel[int(math.ceil(n / 256)), 256](*args)
60
78
  return list(output)
79
+
61
80
  return kernel_wrapper
62
81
 
63
82
 
64
83
  class BaseComplexTest(CUDATestCase):
65
-
66
84
  def basic_values(self):
67
- reals = [-0.0, +0.0, 1, -1, +1.5, -3.5,
68
- float('-inf'), float('+inf'), float('nan')]
85
+ reals = [
86
+ -0.0,
87
+ +0.0,
88
+ 1,
89
+ -1,
90
+ +1.5,
91
+ -3.5,
92
+ float("-inf"),
93
+ float("+inf"),
94
+ float("nan"),
95
+ ]
69
96
  return [complex(x, y) for x, y in itertools.product(reals, reals)]
70
97
 
71
98
  def more_values(self):
72
- reals = [0.0, +0.0, 1, -1, -math.pi, +math.pi,
73
- float('-inf'), float('+inf'), float('nan')]
99
+ reals = [
100
+ 0.0,
101
+ +0.0,
102
+ 1,
103
+ -1,
104
+ -math.pi,
105
+ +math.pi,
106
+ float("-inf"),
107
+ float("+inf"),
108
+ float("nan"),
109
+ ]
74
110
  return [complex(x, y) for x, y in itertools.product(reals, reals)]
75
111
 
76
112
  def non_nan_values(self):
77
- reals = [-0.0, +0.0, 1, -1, -math.pi, +math.pi,
78
- float('inf'), float('-inf')]
113
+ reals = [
114
+ -0.0,
115
+ +0.0,
116
+ 1,
117
+ -1,
118
+ -math.pi,
119
+ +math.pi,
120
+ float("inf"),
121
+ float("-inf"),
122
+ ]
79
123
  return [complex(x, y) for x, y in itertools.product(reals, reals)]
80
124
 
81
125
  def run_func(self, pyfunc, sigs, values, ulps=1, ignore_sign_on_zero=False):
82
126
  for sig in sigs:
83
127
  if isinstance(sig, types.Type):
84
- sig = sig,
128
+ sig = (sig,)
85
129
  if isinstance(sig, tuple):
86
130
  # Assume return type is the type of first argument
87
131
  sig = sig[0](*sig)
88
- prec = ('single'
89
- if sig.args[0] in (types.float32, types.complex64)
90
- else 'double')
132
+ prec = (
133
+ "single"
134
+ if sig.args[0] in (types.float32, types.complex64)
135
+ else "double"
136
+ )
91
137
  cudafunc = compile_scalar_func(pyfunc, sig.args, sig.return_type)
92
138
  ok_values = []
93
139
  expected_list = []
94
140
  for args in values:
95
141
  if not isinstance(args, (list, tuple)):
96
- args = args,
142
+ args = (args,)
97
143
  try:
98
144
  expected_list.append(pyfunc(*args))
99
145
  ok_values.append(args)
@@ -102,24 +148,31 @@ class BaseComplexTest(CUDATestCase):
102
148
  continue
103
149
  got_list = cudafunc(ok_values)
104
150
  for got, expected, args in zip(got_list, expected_list, ok_values):
105
- msg = 'for input %r with prec %r' % (args, prec)
106
- self.assertPreciseEqual(got, expected, prec=prec,
107
- ulps=ulps,
108
- ignore_sign_on_zero=ignore_sign_on_zero,
109
- msg=msg)
151
+ msg = "for input %r with prec %r" % (args, prec)
152
+ self.assertPreciseEqual(
153
+ got,
154
+ expected,
155
+ prec=prec,
156
+ ulps=ulps,
157
+ ignore_sign_on_zero=ignore_sign_on_zero,
158
+ msg=msg,
159
+ )
110
160
 
111
161
  run_unary = run_func
112
162
  run_binary = run_func
113
163
 
114
164
 
115
165
  class TestComplex(BaseComplexTest):
116
-
117
166
  def check_real_image(self, pyfunc):
118
167
  values = self.basic_values()
119
- self.run_unary(pyfunc,
120
- [tp.underlying_float(tp)
121
- for tp in (types.complex64, types.complex128)],
122
- values)
168
+ self.run_unary(
169
+ pyfunc,
170
+ [
171
+ tp.underlying_float(tp)
172
+ for tp in (types.complex64, types.complex128)
173
+ ],
174
+ values,
175
+ )
123
176
 
124
177
  def test_real(self):
125
178
  self.check_real_image(real_usecase)
@@ -130,9 +183,7 @@ class TestComplex(BaseComplexTest):
130
183
  def test_conjugate(self):
131
184
  pyfunc = conjugate_usecase
132
185
  values = self.basic_values()
133
- self.run_unary(pyfunc,
134
- [types.complex64, types.complex128],
135
- values)
186
+ self.run_unary(pyfunc, [types.complex64, types.complex128], values)
136
187
 
137
188
 
138
189
  class TestCMath(BaseComplexTest):
@@ -141,26 +192,44 @@ class TestCMath(BaseComplexTest):
141
192
  """
142
193
 
143
194
  def check_predicate_func(self, pyfunc):
144
- self.run_unary(pyfunc,
145
- [types.boolean(tp)
146
- for tp in (types.complex128, types.complex64)],
147
- self.basic_values())
148
-
149
- def check_unary_func(self, pyfunc, ulps=1, values=None,
150
- returns_float=False, ignore_sign_on_zero=False):
195
+ self.run_unary(
196
+ pyfunc,
197
+ [types.boolean(tp) for tp in (types.complex128, types.complex64)],
198
+ self.basic_values(),
199
+ )
200
+
201
+ def check_unary_func(
202
+ self,
203
+ pyfunc,
204
+ ulps=1,
205
+ values=None,
206
+ returns_float=False,
207
+ ignore_sign_on_zero=False,
208
+ ):
151
209
  if returns_float:
210
+
152
211
  def sig(tp):
153
212
  return tp.underlying_float(tp)
154
213
  else:
214
+
155
215
  def sig(tp):
156
216
  return tp(tp)
157
- self.run_unary(pyfunc, [sig(types.complex128)],
158
- values or self.more_values(), ulps=ulps,
159
- ignore_sign_on_zero=ignore_sign_on_zero)
217
+
218
+ self.run_unary(
219
+ pyfunc,
220
+ [sig(types.complex128)],
221
+ values or self.more_values(),
222
+ ulps=ulps,
223
+ ignore_sign_on_zero=ignore_sign_on_zero,
224
+ )
160
225
  # Avoid discontinuities around pi when in single precision.
161
- self.run_unary(pyfunc, [sig(types.complex64)],
162
- values or self.basic_values(), ulps=ulps,
163
- ignore_sign_on_zero=ignore_sign_on_zero)
226
+ self.run_unary(
227
+ pyfunc,
228
+ [sig(types.complex64)],
229
+ values or self.basic_values(),
230
+ ulps=ulps,
231
+ ignore_sign_on_zero=ignore_sign_on_zero,
232
+ )
164
233
 
165
234
  # Conversions
166
235
 
@@ -172,11 +241,14 @@ class TestCMath(BaseComplexTest):
172
241
 
173
242
  def test_rect(self):
174
243
  def do_test(tp, seed_values):
175
- values = [(z.real, z.imag) for z in seed_values
176
- if not math.isinf(z.imag) or z.real == 0]
244
+ values = [
245
+ (z.real, z.imag)
246
+ for z in seed_values
247
+ if not math.isinf(z.imag) or z.real == 0
248
+ ]
177
249
  float_type = tp.underlying_float
178
- self.run_binary(rect_usecase, [tp(float_type, float_type)],
179
- values)
250
+ self.run_binary(rect_usecase, [tp(float_type, float_type)], values)
251
+
180
252
  do_test(types.complex128, self.more_values())
181
253
  # Avoid discontinuities around pi when in single precision.
182
254
  do_test(types.complex64, self.basic_values())
@@ -202,10 +274,11 @@ class TestCMath(BaseComplexTest):
202
274
 
203
275
  def test_log_base(self):
204
276
  values = list(itertools.product(self.more_values(), self.more_values()))
205
- value_types = [(types.complex128, types.complex128),
206
- (types.complex64, types.complex64)]
207
- self.run_binary(log_base_usecase, value_types, values,
208
- ulps=3)
277
+ value_types = [
278
+ (types.complex128, types.complex128),
279
+ (types.complex64, types.complex64),
280
+ ]
281
+ self.run_binary(log_base_usecase, value_types, values, ulps=3)
209
282
 
210
283
  def test_log10(self):
211
284
  self.check_unary_func(log10_usecase)
@@ -222,8 +295,9 @@ class TestCMath(BaseComplexTest):
222
295
  self.check_unary_func(asin_usecase, ulps=2)
223
296
 
224
297
  def test_atan(self):
225
- self.check_unary_func(atan_usecase, ulps=2,
226
- values=self.non_nan_values())
298
+ self.check_unary_func(
299
+ atan_usecase, ulps=2, values=self.non_nan_values()
300
+ )
227
301
 
228
302
  def test_cos(self):
229
303
  self.check_unary_func(cos_usecase, ulps=2)
@@ -233,8 +307,7 @@ class TestCMath(BaseComplexTest):
233
307
  self.check_unary_func(sin_usecase, ulps=2)
234
308
 
235
309
  def test_tan(self):
236
- self.check_unary_func(tan_usecase, ulps=2,
237
- ignore_sign_on_zero=True)
310
+ self.check_unary_func(tan_usecase, ulps=2, ignore_sign_on_zero=True)
238
311
 
239
312
  # Hyperbolic functions
240
313
 
@@ -245,8 +318,7 @@ class TestCMath(BaseComplexTest):
245
318
  self.check_unary_func(asinh_usecase, ulps=2)
246
319
 
247
320
  def test_atanh(self):
248
- self.check_unary_func(atanh_usecase, ulps=2,
249
- ignore_sign_on_zero=True)
321
+ self.check_unary_func(atanh_usecase, ulps=2, ignore_sign_on_zero=True)
250
322
 
251
323
  def test_cosh(self):
252
324
  self.check_unary_func(cosh_usecase, ulps=2)
@@ -255,8 +327,7 @@ class TestCMath(BaseComplexTest):
255
327
  self.check_unary_func(sinh_usecase, ulps=2)
256
328
 
257
329
  def test_tanh(self):
258
- self.check_unary_func(tanh_usecase, ulps=2,
259
- ignore_sign_on_zero=True)
330
+ self.check_unary_func(tanh_usecase, ulps=2, ignore_sign_on_zero=True)
260
331
 
261
332
 
262
333
  class TestAtomicOnComplexComponents(CUDATestCase):
@@ -292,5 +363,5 @@ class TestAtomicOnComplexComponents(CUDATestCase):
292
363
  np.testing.assert_equal(arr1 + 1j, arr2)
293
364
 
294
365
 
295
- if __name__ == '__main__':
366
+ if __name__ == "__main__":
296
367
  unittest.main()
@@ -5,7 +5,7 @@ from numba.cuda.testing import unittest, CUDATestCase
5
5
 
6
6
  class TestCudaComplex(CUDATestCase):
7
7
  def test_cuda_complex_arg(self):
8
- @cuda.jit('void(complex128[:], complex128)')
8
+ @cuda.jit("void(complex128[:], complex128)")
9
9
  def foo(a, b):
10
10
  i = cuda.grid(1)
11
11
  a[i] += b
@@ -16,5 +16,5 @@ class TestCudaComplex(CUDATestCase):
16
16
  self.assertTrue(np.allclose(a, a0 + 2j))
17
17
 
18
18
 
19
- if __name__ == '__main__':
19
+ if __name__ == "__main__":
20
20
  unittest.main()