numba-cuda 0.0.1__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (233):
  1. _numba_cuda_redirector.pth +1 -0
  2. _numba_cuda_redirector.py +74 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +5 -0
  5. numba_cuda/_version.py +19 -0
  6. numba_cuda/numba/cuda/__init__.py +22 -0
  7. numba_cuda/numba/cuda/api.py +526 -0
  8. numba_cuda/numba/cuda/api_util.py +30 -0
  9. numba_cuda/numba/cuda/args.py +77 -0
  10. numba_cuda/numba/cuda/cg.py +62 -0
  11. numba_cuda/numba/cuda/codegen.py +378 -0
  12. numba_cuda/numba/cuda/compiler.py +422 -0
  13. numba_cuda/numba/cuda/cpp_function_wrappers.cu +47 -0
  14. numba_cuda/numba/cuda/cuda_fp16.h +3631 -0
  15. numba_cuda/numba/cuda/cuda_fp16.hpp +2465 -0
  16. numba_cuda/numba/cuda/cuda_paths.py +258 -0
  17. numba_cuda/numba/cuda/cudadecl.py +806 -0
  18. numba_cuda/numba/cuda/cudadrv/__init__.py +9 -0
  19. numba_cuda/numba/cuda/cudadrv/devicearray.py +904 -0
  20. numba_cuda/numba/cuda/cudadrv/devices.py +248 -0
  21. numba_cuda/numba/cuda/cudadrv/driver.py +3201 -0
  22. numba_cuda/numba/cuda/cudadrv/drvapi.py +398 -0
  23. numba_cuda/numba/cuda/cudadrv/dummyarray.py +452 -0
  24. numba_cuda/numba/cuda/cudadrv/enums.py +607 -0
  25. numba_cuda/numba/cuda/cudadrv/error.py +36 -0
  26. numba_cuda/numba/cuda/cudadrv/libs.py +176 -0
  27. numba_cuda/numba/cuda/cudadrv/ndarray.py +20 -0
  28. numba_cuda/numba/cuda/cudadrv/nvrtc.py +260 -0
  29. numba_cuda/numba/cuda/cudadrv/nvvm.py +707 -0
  30. numba_cuda/numba/cuda/cudadrv/rtapi.py +10 -0
  31. numba_cuda/numba/cuda/cudadrv/runtime.py +142 -0
  32. numba_cuda/numba/cuda/cudaimpl.py +1055 -0
  33. numba_cuda/numba/cuda/cudamath.py +140 -0
  34. numba_cuda/numba/cuda/decorators.py +189 -0
  35. numba_cuda/numba/cuda/descriptor.py +33 -0
  36. numba_cuda/numba/cuda/device_init.py +89 -0
  37. numba_cuda/numba/cuda/deviceufunc.py +908 -0
  38. numba_cuda/numba/cuda/dispatcher.py +1057 -0
  39. numba_cuda/numba/cuda/errors.py +59 -0
  40. numba_cuda/numba/cuda/extending.py +7 -0
  41. numba_cuda/numba/cuda/initialize.py +13 -0
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +77 -0
  43. numba_cuda/numba/cuda/intrinsics.py +198 -0
  44. numba_cuda/numba/cuda/kernels/__init__.py +0 -0
  45. numba_cuda/numba/cuda/kernels/reduction.py +262 -0
  46. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  47. numba_cuda/numba/cuda/libdevice.py +3382 -0
  48. numba_cuda/numba/cuda/libdevicedecl.py +17 -0
  49. numba_cuda/numba/cuda/libdevicefuncs.py +1057 -0
  50. numba_cuda/numba/cuda/libdeviceimpl.py +83 -0
  51. numba_cuda/numba/cuda/mathimpl.py +448 -0
  52. numba_cuda/numba/cuda/models.py +48 -0
  53. numba_cuda/numba/cuda/nvvmutils.py +235 -0
  54. numba_cuda/numba/cuda/printimpl.py +86 -0
  55. numba_cuda/numba/cuda/random.py +292 -0
  56. numba_cuda/numba/cuda/simulator/__init__.py +38 -0
  57. numba_cuda/numba/cuda/simulator/api.py +110 -0
  58. numba_cuda/numba/cuda/simulator/compiler.py +9 -0
  59. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +2 -0
  60. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +432 -0
  61. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +117 -0
  62. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +62 -0
  63. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +4 -0
  64. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +4 -0
  65. numba_cuda/numba/cuda/simulator/cudadrv/error.py +6 -0
  66. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +2 -0
  67. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +29 -0
  68. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +19 -0
  69. numba_cuda/numba/cuda/simulator/kernel.py +308 -0
  70. numba_cuda/numba/cuda/simulator/kernelapi.py +495 -0
  71. numba_cuda/numba/cuda/simulator/reduction.py +15 -0
  72. numba_cuda/numba/cuda/simulator/vector_types.py +58 -0
  73. numba_cuda/numba/cuda/simulator_init.py +17 -0
  74. numba_cuda/numba/cuda/stubs.py +902 -0
  75. numba_cuda/numba/cuda/target.py +440 -0
  76. numba_cuda/numba/cuda/testing.py +202 -0
  77. numba_cuda/numba/cuda/tests/__init__.py +58 -0
  78. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +8 -0
  79. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +145 -0
  80. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +145 -0
  81. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +375 -0
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +21 -0
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +179 -0
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +235 -0
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +22 -0
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +193 -0
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +547 -0
  88. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +249 -0
  89. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +81 -0
  90. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +192 -0
  91. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +38 -0
  92. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +65 -0
  93. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +139 -0
  94. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +37 -0
  95. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +12 -0
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +317 -0
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +127 -0
  98. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +54 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +199 -0
  100. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +37 -0
  101. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +20 -0
  102. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +149 -0
  103. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +36 -0
  104. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +85 -0
  105. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +41 -0
  106. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +122 -0
  107. numba_cuda/numba/cuda/tests/cudapy/__init__.py +8 -0
  108. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +234 -0
  109. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +41 -0
  110. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +58 -0
  111. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +30 -0
  112. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +100 -0
  113. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +42 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_array.py +260 -0
  115. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +201 -0
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +35 -0
  117. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1620 -0
  118. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +120 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +24 -0
  120. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +545 -0
  121. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +257 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +33 -0
  123. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +276 -0
  124. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +296 -0
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +20 -0
  126. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +129 -0
  127. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +176 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +147 -0
  129. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +435 -0
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +90 -0
  131. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +94 -0
  132. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +101 -0
  133. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +221 -0
  134. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +222 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +700 -0
  136. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +121 -0
  137. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +79 -0
  138. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +174 -0
  139. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +155 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +244 -0
  141. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +52 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +29 -0
  143. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +66 -0
  144. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +60 -0
  145. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +456 -0
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +159 -0
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +95 -0
  148. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +37 -0
  149. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +165 -0
  150. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1106 -0
  151. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +318 -0
  152. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +99 -0
  153. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +64 -0
  154. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +119 -0
  155. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +187 -0
  156. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +199 -0
  157. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +164 -0
  158. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +37 -0
  159. numba_cuda/numba/cuda/tests/cudapy/test_math.py +786 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +74 -0
  161. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +113 -0
  162. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +22 -0
  163. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +140 -0
  164. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +46 -0
  165. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +101 -0
  166. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +49 -0
  167. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +401 -0
  168. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +86 -0
  169. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +335 -0
  170. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +124 -0
  171. numba_cuda/numba/cuda/tests/cudapy/test_print.py +128 -0
  172. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +33 -0
  173. numba_cuda/numba/cuda/tests/cudapy/test_random.py +104 -0
  174. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +610 -0
  175. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +125 -0
  176. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +76 -0
  177. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  178. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +85 -0
  179. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +37 -0
  180. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +444 -0
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +205 -0
  182. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +271 -0
  183. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +80 -0
  184. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +277 -0
  185. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +47 -0
  186. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +307 -0
  187. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +283 -0
  188. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +20 -0
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +69 -0
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +36 -0
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +37 -0
  192. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +139 -0
  193. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +276 -0
  194. numba_cuda/numba/cuda/tests/cudasim/__init__.py +6 -0
  195. numba_cuda/numba/cuda/tests/cudasim/support.py +6 -0
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +102 -0
  197. numba_cuda/numba/cuda/tests/data/__init__.py +0 -0
  198. numba_cuda/numba/cuda/tests/data/cuda_include.cu +5 -0
  199. numba_cuda/numba/cuda/tests/data/error.cu +7 -0
  200. numba_cuda/numba/cuda/tests/data/jitlink.cu +23 -0
  201. numba_cuda/numba/cuda/tests/data/jitlink.ptx +51 -0
  202. numba_cuda/numba/cuda/tests/data/warn.cu +7 -0
  203. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +6 -0
  204. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +0 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +49 -0
  206. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +77 -0
  207. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +76 -0
  208. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +82 -0
  209. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +155 -0
  210. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +173 -0
  211. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +109 -0
  212. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +59 -0
  213. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +76 -0
  214. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +130 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +50 -0
  216. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +73 -0
  217. numba_cuda/numba/cuda/tests/nocuda/__init__.py +8 -0
  218. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +359 -0
  219. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +36 -0
  220. numba_cuda/numba/cuda/tests/nocuda/test_import.py +49 -0
  221. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +238 -0
  222. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +54 -0
  223. numba_cuda/numba/cuda/types.py +37 -0
  224. numba_cuda/numba/cuda/ufuncs.py +662 -0
  225. numba_cuda/numba/cuda/vector_types.py +209 -0
  226. numba_cuda/numba/cuda/vectorizers.py +252 -0
  227. numba_cuda-0.0.13.dist-info/LICENSE +25 -0
  228. numba_cuda-0.0.13.dist-info/METADATA +69 -0
  229. numba_cuda-0.0.13.dist-info/RECORD +231 -0
  230. {numba_cuda-0.0.1.dist-info → numba_cuda-0.0.13.dist-info}/WHEEL +1 -1
  231. numba_cuda-0.0.1.dist-info/METADATA +0 -10
  232. numba_cuda-0.0.1.dist-info/RECORD +0 -5
  233. {numba_cuda-0.0.1.dist-info → numba_cuda-0.0.13.dist-info}/top_level.txt +0 -0
numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py (new file)
@@ -0,0 +1,700 @@
1
+ import numpy as np
2
+ import threading
3
+
4
+ from numba import boolean, config, cuda, float32, float64, int32, int64, void
5
+ from numba.core.errors import TypingError
6
+ from numba.cuda.testing import skip_on_cudasim, unittest, CUDATestCase
7
+ import math
8
+
9
+
10
def add(x, y):
    """Return the sum of ``x`` and ``y``.

    Serves as the host-side reference result in the tests below and is also
    compiled as a CUDA device function by ``add_device_usecase``.
    """
    total = x + y
    return total
12
+
13
+
14
def add_kernel(r, x, y):
    """Kernel body: write ``x + y`` into the first element of ``r``."""
    total = x + y
    r[0] = total
16
+
17
+
18
@skip_on_cudasim('Specialization not implemented in the simulator')
class TestDispatcherSpecialization(CUDATestCase):
    """Tests for ``Dispatcher.specialize`` and its specialization cache."""

    def _test_no_double_specialize(self, dispatcher, ty):
        # Asking an already-specialized dispatcher to specialize again must
        # raise, whatever types are requested.
        with self.assertRaises(RuntimeError) as raised:
            dispatcher.specialize(ty)

        message = str(raised.exception)
        self.assertIn('Dispatcher already specialized', message)

    def test_no_double_specialize_sig_same_types(self):
        # A kernel jitted with an explicit signature cannot be specialized,
        # not even for exactly the types of that signature.
        @cuda.jit('void(float32[::1])')
        def kernel(x):
            pass

        self._test_no_double_specialize(kernel, float32[::1])

    def test_no_double_specialize_no_sig_same_types(self):
        # The dispatcher returned by specialize() cannot be specialized again,
        # not even for the types it already covers.
        @cuda.jit
        def kernel(x):
            pass

        specialized = kernel.specialize(float32[::1])
        self._test_no_double_specialize(specialized, float32[::1])

    def test_no_double_specialize_sig_diff_types(self):
        # A kernel jitted with an explicit signature cannot be specialized
        # for other types either.
        @cuda.jit('void(int32[::1])')
        def kernel(x):
            pass

        self._test_no_double_specialize(kernel, float32[::1])

    def test_no_double_specialize_no_sig_diff_types(self):
        # Re-specializing a specialized dispatcher for different types is
        # rejected as well.
        @cuda.jit
        def kernel(x):
            pass

        specialized = kernel.specialize(int32[::1])
        self._test_no_double_specialize(specialized, float32[::1])

    def test_specialize_cache_same(self):
        # specialize() hands back the cached dispatcher for argument types it
        # has already seen and a distinct one for new argument types.
        @cuda.jit
        def kernel(x):
            pass

        def n_specializations():
            return len(kernel.specializations)

        self.assertEqual(n_specializations(), 0)

        first_f32 = kernel.specialize(float32[::1])
        self.assertEqual(n_specializations(), 1)

        second_f32 = kernel.specialize(float32[::1])
        self.assertEqual(n_specializations(), 1)
        self.assertIs(first_f32, second_f32)

        as_i32 = kernel.specialize(int32[::1])
        self.assertEqual(n_specializations(), 2)
        self.assertIsNot(as_i32, first_f32)

    def test_specialize_cache_same_with_ordering(self):
        # Same idea as test_specialize_cache_same, but with two arguments and
        # with the array layout ('A' vs 'C') distinguishing specializations.
        @cuda.jit
        def kernel(x, y):
            pass

        def n_specializations():
            return len(kernel.specializations)

        self.assertEqual(n_specializations(), 0)

        # Any-layout ('A') arrays
        any_layout = kernel.specialize(float32[:], float32[:])
        self.assertEqual(n_specializations(), 1)

        # C-contiguous arrays get their own specialization
        c_layout = kernel.specialize(float32[::1], float32[::1])
        self.assertEqual(n_specializations(), 2)
        self.assertIsNot(any_layout, c_layout)

        # Requesting C-contiguous arrays again reuses the cached dispatcher
        c_layout_again = kernel.specialize(float32[::1], float32[::1])
        self.assertEqual(n_specializations(), 2)
        self.assertIs(c_layout, c_layout_again)
108
+
109
+
110
class TestDispatcher(CUDATestCase):
    """Dispatcher tests, mostly adapted from numba.tests.test_dispatcher."""

    def test_coerce_input_types(self):
        # While other specializations can still be compiled, arguments must
        # not be pushed through unsafe conversions.
        kernel = cuda.jit(add_kernel)

        # complex128 storage can represent every result produced here.
        out = np.zeros(1, dtype=np.complex128)

        for a, b in ((123, 456),
                     (12.3, 45.6),
                     (12.3, 45.6j),
                     (12300000000, 456)):
            kernel[1, 1](out, a, b)
            self.assertEqual(out[0], add(a, b))

        # With a single explicit signature only that one specialization
        # exists.
        kernel = cuda.jit('(i4[::1], i4, i4)')(add_kernel)
        out = np.zeros(1, dtype=np.int32)

        kernel[1, 1](out, 123, 456)
        self.assertPreciseEqual(out[0], add(123, 456))

    @skip_on_cudasim('Simulator ignores signature')
    @unittest.expectedFailure
    def test_coerce_input_types_unsafe(self):
        # Implicit (unsafe) float -> int conversion, split out of
        # test_coerce_input_types. The CUDA dispatcher currently rejects it
        # because argument preparation in _Kernel._prepare_args is inflexible
        # about the types it accepts, so this is an expected failure until
        # that changes.
        kernel = cuda.jit('(i4[::1], i4, i4)')(add_kernel)
        out = np.zeros(1, dtype=np.int32)

        kernel[1, 1](out, 12.3, 45.6)
        self.assertPreciseEqual(out[0], add(12, 45))

    @skip_on_cudasim('Simulator ignores signature')
    def test_coerce_input_types_unsafe_complex(self):
        # Complex arguments must never be implicitly converted to int.
        kernel = cuda.jit('(i4[::1], i4, i4)')(add_kernel)
        out = np.zeros(1, dtype=np.int32)

        with self.assertRaises(TypeError):
            kernel[1, 1](out, 12.3, 45.6j)

    @skip_on_cudasim('Simulator does not track overloads')
    def test_ambiguous_new_version(self):
        """A new overload is compiled when existing ones match ambiguously.
        """
        kernel = cuda.jit(add_kernel)
        out = np.zeros(1, dtype=np.float64)
        int_arg = 1
        flt_arg = 1.5

        # Each new combination of argument types adds one overload.
        cases = ((int_arg, flt_arg), (flt_arg, int_arg), (flt_arg, flt_arg))
        for n_expected, (a, b) in enumerate(cases, start=1):
            kernel[1, 1](out, a, b)
            self.assertAlmostEqual(out[0], a + b)
            self.assertEqual(len(kernel.overloads), n_expected)

        # (int, int) converts equally well to (float, int) and (int, float),
        # so a fourth overload must be compiled instead of picking one.
        kernel[1, 1](out, 1, 1)
        self.assertAlmostEqual(out[0], int_arg + int_arg)
        self.assertEqual(len(kernel.overloads), 4,
                         "didn't compile a new version")

    @skip_on_cudasim("Simulator doesn't support concurrent kernels")
    def test_lock(self):
        """
        Lazily compiling one kernel from many threads at once must not
        produce errors (see issue #908).
        """
        failures = []

        @cuda.jit
        def incr(r, x):
            r[0] = x + 1

        def worker():
            # Exceptions raised in a worker thread do not reach the test
            # runner, so collect them for the check below.
            try:
                out = np.zeros(1, dtype=np.int64)
                incr[1, 1](out, 1)
                self.assertEqual(out[0], 2)
            except Exception as exc:
                failures.append(exc)

        threads = [threading.Thread(target=worker) for _ in range(16)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        self.assertFalse(failures)

    def _test_explicit_signatures(self, sigs):
        # Shared checks for kernels compiled with an int64 and a float64
        # signature, supplied in whatever form ``sigs`` uses.
        kernel = cuda.jit(sigs)(add_kernel)

        # Arguments matching one of the signatures exactly
        for dtype, a, b, expected in ((np.int64, 1, 2, 3),
                                      (np.float64, 1.5, 2.5, 4.0)):
            out = np.zeros(1, dtype=dtype)
            kernel[1, 1](out, a, b)
            self.assertPreciseEqual(out[0], expected)

        if config.ENABLE_CUDASIM:
            # The simulator cannot detect that no conversion exists.
            return

        # No signature accepts complex arguments, and no new overload may be
        # compiled for them.
        out = np.zeros(1, dtype=np.complex128)
        with self.assertRaises(TypeError) as ctx:
            kernel[1, 1](out, 1j, 1j)
        self.assertIn("No matching definition", str(ctx.exception))
        self.assertEqual(len(kernel.overloads), 2, kernel.overloads)

    def test_explicit_signatures_strings(self):
        # Signatures given as strings
        self._test_explicit_signatures(["(int64[::1], int64, int64)",
                                        "(float64[::1], float64, float64)"])

    def test_explicit_signatures_tuples(self):
        # Signatures given as tuples of argument types
        self._test_explicit_signatures([(int64[::1], int64, int64),
                                        (float64[::1], float64, float64)])

    def test_explicit_signatures_signatures(self):
        # Signatures given as Signature objects
        self._test_explicit_signatures([void(int64[::1], int64, int64),
                                        void(float64[::1], float64, float64)])

    def test_explicit_signatures_mixed(self):
        # Lists mixing different kinds of signature specification
        int_tuple = (int64[::1], int64, int64)
        int_sig = void(int64[::1], int64, int64)
        float_str = "(float64[::1], float64, float64)"
        float_sig = void(float64[::1], float64, float64)

        for sigs in ([int_tuple, float_str],    # tuple + string
                     [int_tuple, float_sig],    # tuple + Signature
                     [int_sig, float_str]):     # Signature + string
            self._test_explicit_signatures(sigs)

    def test_explicit_signatures_same_type_class(self):
        # Overloads that differ only in float width. The output array is
        # float64 in both signatures on purpose, so dispatch is decided by x
        # and y alone (keeping close to the numba.tests.test_dispatcher
        # original).
        sigs = ["(float64[::1], float32, float32)",
                "(float64[::1], float64, float64)"]
        kernel = cuda.jit(sigs)(add_kernel)

        # float32 overload: 2**-25 is below half a float32 ulp at 1.0, so the
        # sum rounds back to 1.0.
        out = np.zeros(1, dtype=np.float64)
        kernel[1, 1](out, np.float32(1), np.float32(2**-25))
        self.assertPreciseEqual(out[0], 1.0)

        # (int, float) resolves to the float64 overload, which keeps 2**-25.
        out = np.zeros(1, dtype=np.float64)
        kernel[1, 1](out, 1, 2**-25)
        self.assertPreciseEqual(out[0], 1.0000000298023224)

    @skip_on_cudasim('No overload resolution in the simulator')
    def test_explicit_signatures_ambiguous_resolution(self):
        # (float64, float64) arguments fit the first two signatures equally
        # well, so dispatch must fail. The first argument is float64[::1] in
        # every signature on purpose.
        kernel = cuda.jit(["(float64[::1], float32, float64)",
                           "(float64[::1], float64, float32)",
                           "(float64[::1], int64, int64)"])(add_kernel)
        out = np.zeros(1, dtype=np.float64)
        with self.assertRaises(TypeError) as ctx:
            kernel[1, 1](out, 1.0, 2.0)

        message = str(ctx.exception)
        # The message lists the actual argument types, then the two best
        # candidate signatures...
        self.assertRegex(
            message,
            r"Ambiguous overloading for <function add_kernel [^>]*> "
            r"\(Array\(float64, 1, 'C', False, aligned=True\), float64,"
            r" float64\):\n"
            r"\(Array\(float64, 1, 'C', False, aligned=True\), float32,"
            r" float64\) -> none\n"
            r"\(Array\(float64, 1, 'C', False, aligned=True\), float64,"
            r" float32\) -> none"
        )
        # ...and leaves out the worse int64 signature.
        self.assertNotIn("int64", message)

    @skip_on_cudasim('Simulator does not use _prepare_args')
    @unittest.expectedFailure
    def test_explicit_signatures_unsafe(self):
        # Cases split out of the explicit-signature tests that need unsafe
        # argument conversion. _prepare_args in the CUDA target cannot do
        # that yet, hence the expected failure.
        kernel = cuda.jit("(int64[::1], int64, int64)")(add_kernel)
        out = np.zeros(1, dtype=np.int64)

        # Approximate match (unsafe conversion)
        kernel[1, 1](out, 1.5, 2.5)
        self.assertPreciseEqual(out[0], 3)
        self.assertEqual(len(kernel.overloads), 1, kernel.overloads)

        sigs = ["(int64[::1], int64, int64)",
                "(float64[::1], float64, float64)"]
        kernel = cuda.jit(sigs)(add_kernel)
        out = np.zeros(1, dtype=np.float64)
        # Approximate match (int32 -> float64 is a safe conversion)
        kernel[1, 1](out, np.int32(1), 2.5)
        self.assertPreciseEqual(out[0], 3.5)

    def add_device_usecase(self, sigs):
        # Build a kernel whose body calls ``add`` compiled as a device
        # function for the given signatures.
        device_add = cuda.jit(sigs, device=True)(add)

        @cuda.jit
        def kernel(r, x, y):
            r[0] = device_add(x, y)

        return kernel

    def test_explicit_signatures_device(self):
        # Device-function counterpart of _test_explicit_signatures
        kernel = self.add_device_usecase(["(int64, int64)",
                                          "(float64, float64)"])

        # Arguments matching one of the signatures exactly
        for dtype, a, b, expected in ((np.int64, 1, 2, 3),
                                      (np.float64, 1.5, 2.5, 4.0)):
            out = np.zeros(1, dtype=dtype)
            kernel[1, 1](out, a, b)
            self.assertPreciseEqual(out[0], expected)

        if config.ENABLE_CUDASIM:
            # The simulator cannot detect that no conversion exists.
            return

        # Complex arguments cannot be passed to either device overload, which
        # surfaces as a typing error while compiling the kernel.
        out = np.zeros(1, dtype=np.complex128)
        with self.assertRaises(TypingError) as ctx:
            kernel[1, 1](out, 1j, 1j)

        message = str(ctx.exception)
        self.assertIn("Invalid use of type", message)
        self.assertIn("with parameters (complex128, complex128)", message)
        self.assertEqual(len(kernel.overloads), 2, kernel.overloads)

    def test_explicit_signatures_device_same_type_class(self):
        # Device-function version of test_explicit_signatures_same_type_class.
        # The output array is float64 in both calls on purpose, so dispatch
        # is decided by x and y alone.
        kernel = self.add_device_usecase(["(float32, float32)",
                                          "(float64, float64)"])

        out = np.zeros(1, dtype=np.float64)
        kernel[1, 1](out, np.float32(1), np.float32(2**-25))
        self.assertPreciseEqual(out[0], 1.0)

        out = np.zeros(1, dtype=np.float64)
        kernel[1, 1](out, 1, 2**-25)
        self.assertPreciseEqual(out[0], 1.0000000298023224)

    def test_explicit_signatures_device_ambiguous(self):
        # For device functions a tie between the two best overloads is
        # resolved rather than reported. This is surprising, since ambiguity
        # is rejected when dispatching kernel launches, but it seems to be
        # Numba's general behaviour (see Issue #8307:
        # https://github.com/numba/numba/issues/8307).
        kernel = self.add_device_usecase(
            ["(float32, float64)", "(float64, float32)", "(int64, int64)"])

        out = np.zeros(1, dtype=np.float64)
        kernel[1, 1](out, 1.5, 2.5)
        self.assertPreciseEqual(out[0], 4.0)

    @skip_on_cudasim('CUDA Simulator does not force casting')
    def test_explicit_signatures_device_unsafe(self):
        # Device-function variants of the cases in
        # test_explicit_signatures_unsafe. They succeed on CUDA because
        # compiling the call to the device function can insert unsafe casts,
        # whereas _prepare_args at kernel launch cannot.
        kernel = self.add_device_usecase(["(int64, int64)"])

        # Approximate match (unsafe conversion)
        out = np.zeros(1, dtype=np.int64)
        kernel[1, 1](out, 1.5, 2.5)
        self.assertPreciseEqual(out[0], 3)
        self.assertEqual(len(kernel.overloads), 1, kernel.overloads)

        kernel = self.add_device_usecase(["(int64, int64)",
                                          "(float64, float64)"])

        # Approximate match (int32 -> float64 is a safe conversion)
        out = np.zeros(1, dtype=np.float64)
        kernel[1, 1](out, np.int32(1), 2.5)
        self.assertPreciseEqual(out[0], 3.5)

    def test_dispatcher_docstring(self):
        # CUDA-jitting keeps the Python function's docstring, for kernels and
        # device functions alike. See Issue #5902:
        # https://github.com/numba/numba/issues/5902

        @cuda.jit
        def kernel(a, b):
            """Add two integers, kernel version"""

        @cuda.jit(device=True)
        def device_fn(a, b):
            """Add two integers, device version"""

        self.assertEqual("Add two integers, kernel version", kernel.__doc__)
        self.assertEqual("Add two integers, device version", device_fn.__doc__)
454
+
455
+
456
+ @skip_on_cudasim("CUDA simulator doesn't implement kernel properties")
457
+ class TestDispatcherKernelProperties(CUDATestCase):
458
def test_get_regs_per_thread_unspecialized(self):
    # Register usage per thread is likely to differ between the float32 and
    # float64 specializations of this kernel.
    @cuda.jit
    def pi_sin_array(x, n):
        tid = cuda.grid(1)
        if tid < n:
            x[tid] = 3.14 * math.sin(x[tid])

    # Launch with float32 and float64 data so that the dispatcher ends up
    # holding two compiled definitions.
    N = 10
    for dtype in (np.float32, np.float64):
        pi_sin_array[1, N](np.zeros(N, dtype=dtype), N)

    sig_f32 = void(float32[::1], int64)
    sig_f64 = void(float64[::1], int64)
    regs_f32 = pi_sin_array.get_regs_per_thread(sig_f32)
    regs_f64 = pi_sin_array.get_regs_per_thread(sig_f64)

    # Each variant reports a positive integer register count.
    for regs in (regs_f32, regs_f64):
        self.assertIsInstance(regs, int)
        self.assertGreater(regs, 0)

    # Without a signature the result is indexable by argument types, and it
    # must agree with the per-signature queries above.
    all_regs = pi_sin_array.get_regs_per_thread()
    self.assertEqual(all_regs[sig_f32.args], regs_f32)
    self.assertEqual(all_regs[sig_f64.args], regs_f64)

    if regs_f32 == regs_f64:
        # Equal usage might point to a bug, but could also be an artifact of
        # the compiler / driver / device combination, so only report it.
        print('f32 and f64 variant thread usages are equal.')
        print('This may warrant some investigation. Devices:')
        cuda.detect()
504
+
505
+ def test_get_regs_per_thread_specialized(self):
506
+ @cuda.jit(void(float32[::1], int64))
507
+ def pi_sin_array(x, n):
508
+ i = cuda.grid(1)
509
+ if i < n:
510
+ x[i] = 3.14 * math.sin(x[i])
511
+
512
+ # Check we get a positive integer for the specialized variation
513
+ regs_per_thread = pi_sin_array.get_regs_per_thread()
514
+ self.assertIsInstance(regs_per_thread, int)
515
+ self.assertGreater(regs_per_thread, 0)
516
+
517
+ def test_get_const_mem_unspecialized(self):
518
+ @cuda.jit
519
+ def const_fmt_string(val, to_print):
520
+ # We guard the print with a conditional to prevent noise from the
521
+ # test suite
522
+ if to_print:
523
+ print(val)
524
+
525
+ # Call the kernel with different arguments to create two different
526
+ # definitions within the Dispatcher object
527
+ const_fmt_string[1, 1](1, False)
528
+ const_fmt_string[1, 1](1.0, False)
529
+
530
+ # Check we get a positive integer for the two different variations
531
+ sig_i64 = void(int64, boolean)
532
+ sig_f64 = void(float64, boolean)
533
+ const_mem_size_i64 = const_fmt_string.get_const_mem_size(sig_i64)
534
+ const_mem_size_f64 = const_fmt_string.get_const_mem_size(sig_f64)
535
+
536
+ self.assertIsInstance(const_mem_size_i64, int)
537
+ self.assertIsInstance(const_mem_size_f64, int)
538
+
539
+ # 6 bytes for the equivalent of b'%lld\n\0'
540
+ self.assertGreaterEqual(const_mem_size_i64, 6)
541
+ # 4 bytes for the equivalent of b'%f\n\0'
542
+ self.assertGreaterEqual(const_mem_size_f64, 4)
543
+
544
+ # Check that getting the const memory size for all signatures
545
+ # provides the same values as getting the const memory size for
546
+ # individual signatures.
547
+
548
+ const_mem_size_all = const_fmt_string.get_const_mem_size()
549
+ self.assertEqual(const_mem_size_all[sig_i64.args], const_mem_size_i64)
550
+ self.assertEqual(const_mem_size_all[sig_f64.args], const_mem_size_f64)
551
+
552
+ def test_get_const_mem_specialized(self):
553
+ arr = np.arange(32, dtype=np.int64)
554
+ sig = void(int64[::1])
555
+
556
+ @cuda.jit(sig)
557
+ def const_array_use(x):
558
+ C = cuda.const.array_like(arr)
559
+ i = cuda.grid(1)
560
+ x[i] = C[i]
561
+
562
+ const_mem_size = const_array_use.get_const_mem_size(sig)
563
+ self.assertIsInstance(const_mem_size, int)
564
+ self.assertGreaterEqual(const_mem_size, arr.nbytes)
565
+
566
+ def test_get_shared_mem_per_block_unspecialized(self):
567
+ N = 10
568
+
569
+ # A kernel where the shared memory per block is likely to differ
570
+ # between different specializations
571
+ @cuda.jit
572
+ def simple_smem(ary):
573
+ sm = cuda.shared.array(N, dtype=ary.dtype)
574
+ for j in range(N):
575
+ sm[j] = j
576
+ for j in range(N):
577
+ ary[j] = sm[j]
578
+
579
+ # Call the kernel with different arguments to create two different
580
+ # definitions within the Dispatcher object
581
+ arr_f32 = np.zeros(N, dtype=np.float32)
582
+ arr_f64 = np.zeros(N, dtype=np.float64)
583
+
584
+ simple_smem[1, 1](arr_f32)
585
+ simple_smem[1, 1](arr_f64)
586
+
587
+ sig_f32 = void(float32[::1])
588
+ sig_f64 = void(float64[::1])
589
+
590
+ sh_mem_f32 = simple_smem.get_shared_mem_per_block(sig_f32)
591
+ sh_mem_f64 = simple_smem.get_shared_mem_per_block(sig_f64)
592
+
593
+ self.assertIsInstance(sh_mem_f32, int)
594
+ self.assertIsInstance(sh_mem_f64, int)
595
+
596
+ self.assertEqual(sh_mem_f32, N * 4)
597
+ self.assertEqual(sh_mem_f64, N * 8)
598
+
599
+ # Check that getting the shared memory per block for all signatures
600
+ # provides the same values as getting the shared mem per block for
601
+ # individual signatures.
602
+ sh_mem_f32_all = simple_smem.get_shared_mem_per_block()
603
+ sh_mem_f64_all = simple_smem.get_shared_mem_per_block()
604
+ self.assertEqual(sh_mem_f32_all[sig_f32.args], sh_mem_f32)
605
+ self.assertEqual(sh_mem_f64_all[sig_f64.args], sh_mem_f64)
606
+
607
+ def test_get_shared_mem_per_block_specialized(self):
608
+ @cuda.jit(void(float32[::1]))
609
+ def simple_smem(ary):
610
+ sm = cuda.shared.array(100, dtype=float32)
611
+ i = cuda.grid(1)
612
+ if i == 0:
613
+ for j in range(100):
614
+ sm[j] = j
615
+ cuda.syncthreads()
616
+ ary[i] = sm[i]
617
+
618
+ shared_mem_per_block = simple_smem.get_shared_mem_per_block()
619
+ self.assertIsInstance(shared_mem_per_block, int)
620
+ self.assertEqual(shared_mem_per_block, 400)
621
+
622
+ def test_get_max_threads_per_block_unspecialized(self):
623
+ N = 10
624
+
625
+ @cuda.jit
626
+ def simple_maxthreads(ary):
627
+ i = cuda.grid(1)
628
+ ary[i] = i
629
+
630
+ arr_f32 = np.zeros(N, dtype=np.float32)
631
+ simple_maxthreads[1, 1](arr_f32)
632
+ sig_f32 = void(float32[::1])
633
+ max_threads_f32 = simple_maxthreads.get_max_threads_per_block(sig_f32)
634
+
635
+ self.assertIsInstance(max_threads_f32, int)
636
+ self.assertGreater(max_threads_f32, 0)
637
+
638
+ max_threads_f32_all = simple_maxthreads.get_max_threads_per_block()
639
+ self.assertEqual(max_threads_f32_all[sig_f32.args], max_threads_f32)
640
+
641
+ def test_get_local_mem_per_thread_unspecialized(self):
642
+ # NOTE: A large amount of local memory must be allocated
643
+ # otherwise the compiler will optimize out the call to
644
+ # cuda.local.array and use local registers instead
645
+ N = 1000
646
+
647
+ @cuda.jit
648
+ def simple_lmem(ary):
649
+ lm = cuda.local.array(N, dtype=ary.dtype)
650
+ for j in range(N):
651
+ lm[j] = j
652
+ for j in range(N):
653
+ ary[j] = lm[j]
654
+
655
+ # Call the kernel with different arguments to create two different
656
+ # definitions within the Dispatcher object
657
+ arr_f32 = np.zeros(N, dtype=np.float32)
658
+ arr_f64 = np.zeros(N, dtype=np.float64)
659
+
660
+ simple_lmem[1, 1](arr_f32)
661
+ simple_lmem[1, 1](arr_f64)
662
+
663
+ sig_f32 = void(float32[::1])
664
+ sig_f64 = void(float64[::1])
665
+ local_mem_f32 = simple_lmem.get_local_mem_per_thread(sig_f32)
666
+ local_mem_f64 = simple_lmem.get_local_mem_per_thread(sig_f64)
667
+ self.assertIsInstance(local_mem_f32, int)
668
+ self.assertIsInstance(local_mem_f64, int)
669
+
670
+ self.assertGreaterEqual(local_mem_f32, N * 4)
671
+ self.assertGreaterEqual(local_mem_f64, N * 8)
672
+
673
+ # Check that getting the local memory per thread for all signatures
674
+ # provides the same values as getting the shared mem per block for
675
+ # individual signatures.
676
+ local_mem_all = simple_lmem.get_local_mem_per_thread()
677
+ self.assertEqual(local_mem_all[sig_f32.args], local_mem_f32)
678
+ self.assertEqual(local_mem_all[sig_f64.args], local_mem_f64)
679
+
680
+ def test_get_local_mem_per_thread_specialized(self):
681
+ # NOTE: A large amount of local memory must be allocated
682
+ # otherwise the compiler will optimize out the call to
683
+ # cuda.local.array and use local registers instead
684
+ N = 1000
685
+
686
+ @cuda.jit(void(float32[::1]))
687
+ def simple_lmem(ary):
688
+ lm = cuda.local.array(N, dtype=ary.dtype)
689
+ for j in range(N):
690
+ lm[j] = j
691
+ for j in range(N):
692
+ ary[j] = lm[j]
693
+
694
+ local_mem_per_thread = simple_lmem.get_local_mem_per_thread()
695
+ self.assertIsInstance(local_mem_per_thread, int)
696
+ self.assertGreaterEqual(local_mem_per_thread, N * 4)
697
+
698
+
699
# Run this module's test cases with the unittest runner when the file is
# executed directly as a script.
if __name__ == '__main__':
    unittest.main()