numba-cuda 0.0.1__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (233)
  1. _numba_cuda_redirector.pth +1 -0
  2. _numba_cuda_redirector.py +74 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +5 -0
  5. numba_cuda/_version.py +19 -0
  6. numba_cuda/numba/cuda/__init__.py +22 -0
  7. numba_cuda/numba/cuda/api.py +526 -0
  8. numba_cuda/numba/cuda/api_util.py +30 -0
  9. numba_cuda/numba/cuda/args.py +77 -0
  10. numba_cuda/numba/cuda/cg.py +62 -0
  11. numba_cuda/numba/cuda/codegen.py +378 -0
  12. numba_cuda/numba/cuda/compiler.py +422 -0
  13. numba_cuda/numba/cuda/cpp_function_wrappers.cu +47 -0
  14. numba_cuda/numba/cuda/cuda_fp16.h +3631 -0
  15. numba_cuda/numba/cuda/cuda_fp16.hpp +2465 -0
  16. numba_cuda/numba/cuda/cuda_paths.py +258 -0
  17. numba_cuda/numba/cuda/cudadecl.py +806 -0
  18. numba_cuda/numba/cuda/cudadrv/__init__.py +9 -0
  19. numba_cuda/numba/cuda/cudadrv/devicearray.py +904 -0
  20. numba_cuda/numba/cuda/cudadrv/devices.py +248 -0
  21. numba_cuda/numba/cuda/cudadrv/driver.py +3201 -0
  22. numba_cuda/numba/cuda/cudadrv/drvapi.py +398 -0
  23. numba_cuda/numba/cuda/cudadrv/dummyarray.py +452 -0
  24. numba_cuda/numba/cuda/cudadrv/enums.py +607 -0
  25. numba_cuda/numba/cuda/cudadrv/error.py +36 -0
  26. numba_cuda/numba/cuda/cudadrv/libs.py +176 -0
  27. numba_cuda/numba/cuda/cudadrv/ndarray.py +20 -0
  28. numba_cuda/numba/cuda/cudadrv/nvrtc.py +260 -0
  29. numba_cuda/numba/cuda/cudadrv/nvvm.py +707 -0
  30. numba_cuda/numba/cuda/cudadrv/rtapi.py +10 -0
  31. numba_cuda/numba/cuda/cudadrv/runtime.py +142 -0
  32. numba_cuda/numba/cuda/cudaimpl.py +1055 -0
  33. numba_cuda/numba/cuda/cudamath.py +140 -0
  34. numba_cuda/numba/cuda/decorators.py +189 -0
  35. numba_cuda/numba/cuda/descriptor.py +33 -0
  36. numba_cuda/numba/cuda/device_init.py +89 -0
  37. numba_cuda/numba/cuda/deviceufunc.py +908 -0
  38. numba_cuda/numba/cuda/dispatcher.py +1057 -0
  39. numba_cuda/numba/cuda/errors.py +59 -0
  40. numba_cuda/numba/cuda/extending.py +7 -0
  41. numba_cuda/numba/cuda/initialize.py +13 -0
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +77 -0
  43. numba_cuda/numba/cuda/intrinsics.py +198 -0
  44. numba_cuda/numba/cuda/kernels/__init__.py +0 -0
  45. numba_cuda/numba/cuda/kernels/reduction.py +262 -0
  46. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  47. numba_cuda/numba/cuda/libdevice.py +3382 -0
  48. numba_cuda/numba/cuda/libdevicedecl.py +17 -0
  49. numba_cuda/numba/cuda/libdevicefuncs.py +1057 -0
  50. numba_cuda/numba/cuda/libdeviceimpl.py +83 -0
  51. numba_cuda/numba/cuda/mathimpl.py +448 -0
  52. numba_cuda/numba/cuda/models.py +48 -0
  53. numba_cuda/numba/cuda/nvvmutils.py +235 -0
  54. numba_cuda/numba/cuda/printimpl.py +86 -0
  55. numba_cuda/numba/cuda/random.py +292 -0
  56. numba_cuda/numba/cuda/simulator/__init__.py +38 -0
  57. numba_cuda/numba/cuda/simulator/api.py +110 -0
  58. numba_cuda/numba/cuda/simulator/compiler.py +9 -0
  59. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +2 -0
  60. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +432 -0
  61. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +117 -0
  62. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +62 -0
  63. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +4 -0
  64. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +4 -0
  65. numba_cuda/numba/cuda/simulator/cudadrv/error.py +6 -0
  66. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +2 -0
  67. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +29 -0
  68. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +19 -0
  69. numba_cuda/numba/cuda/simulator/kernel.py +308 -0
  70. numba_cuda/numba/cuda/simulator/kernelapi.py +495 -0
  71. numba_cuda/numba/cuda/simulator/reduction.py +15 -0
  72. numba_cuda/numba/cuda/simulator/vector_types.py +58 -0
  73. numba_cuda/numba/cuda/simulator_init.py +17 -0
  74. numba_cuda/numba/cuda/stubs.py +902 -0
  75. numba_cuda/numba/cuda/target.py +440 -0
  76. numba_cuda/numba/cuda/testing.py +202 -0
  77. numba_cuda/numba/cuda/tests/__init__.py +58 -0
  78. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +8 -0
  79. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +145 -0
  80. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +145 -0
  81. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +375 -0
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +21 -0
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +179 -0
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +235 -0
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +22 -0
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +193 -0
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +547 -0
  88. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +249 -0
  89. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +81 -0
  90. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +192 -0
  91. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +38 -0
  92. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +65 -0
  93. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +139 -0
  94. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +37 -0
  95. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +12 -0
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +317 -0
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +127 -0
  98. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +54 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +199 -0
  100. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +37 -0
  101. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +20 -0
  102. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +149 -0
  103. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +36 -0
  104. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +85 -0
  105. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +41 -0
  106. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +122 -0
  107. numba_cuda/numba/cuda/tests/cudapy/__init__.py +8 -0
  108. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +234 -0
  109. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +41 -0
  110. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +58 -0
  111. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +30 -0
  112. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +100 -0
  113. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +42 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_array.py +260 -0
  115. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +201 -0
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +35 -0
  117. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1620 -0
  118. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +120 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +24 -0
  120. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +545 -0
  121. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +257 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +33 -0
  123. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +276 -0
  124. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +296 -0
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +20 -0
  126. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +129 -0
  127. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +176 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +147 -0
  129. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +435 -0
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +90 -0
  131. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +94 -0
  132. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +101 -0
  133. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +221 -0
  134. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +222 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +700 -0
  136. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +121 -0
  137. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +79 -0
  138. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +174 -0
  139. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +155 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +244 -0
  141. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +52 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +29 -0
  143. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +66 -0
  144. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +60 -0
  145. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +456 -0
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +159 -0
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +95 -0
  148. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +37 -0
  149. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +165 -0
  150. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1106 -0
  151. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +318 -0
  152. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +99 -0
  153. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +64 -0
  154. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +119 -0
  155. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +187 -0
  156. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +199 -0
  157. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +164 -0
  158. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +37 -0
  159. numba_cuda/numba/cuda/tests/cudapy/test_math.py +786 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +74 -0
  161. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +113 -0
  162. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +22 -0
  163. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +140 -0
  164. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +46 -0
  165. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +101 -0
  166. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +49 -0
  167. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +401 -0
  168. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +86 -0
  169. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +335 -0
  170. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +124 -0
  171. numba_cuda/numba/cuda/tests/cudapy/test_print.py +128 -0
  172. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +33 -0
  173. numba_cuda/numba/cuda/tests/cudapy/test_random.py +104 -0
  174. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +610 -0
  175. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +125 -0
  176. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +76 -0
  177. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  178. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +85 -0
  179. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +37 -0
  180. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +444 -0
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +205 -0
  182. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +271 -0
  183. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +80 -0
  184. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +277 -0
  185. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +47 -0
  186. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +307 -0
  187. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +283 -0
  188. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +20 -0
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +69 -0
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +36 -0
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +37 -0
  192. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +139 -0
  193. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +276 -0
  194. numba_cuda/numba/cuda/tests/cudasim/__init__.py +6 -0
  195. numba_cuda/numba/cuda/tests/cudasim/support.py +6 -0
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +102 -0
  197. numba_cuda/numba/cuda/tests/data/__init__.py +0 -0
  198. numba_cuda/numba/cuda/tests/data/cuda_include.cu +5 -0
  199. numba_cuda/numba/cuda/tests/data/error.cu +7 -0
  200. numba_cuda/numba/cuda/tests/data/jitlink.cu +23 -0
  201. numba_cuda/numba/cuda/tests/data/jitlink.ptx +51 -0
  202. numba_cuda/numba/cuda/tests/data/warn.cu +7 -0
  203. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +6 -0
  204. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +0 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +49 -0
  206. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +77 -0
  207. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +76 -0
  208. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +82 -0
  209. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +155 -0
  210. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +173 -0
  211. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +109 -0
  212. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +59 -0
  213. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +76 -0
  214. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +130 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +50 -0
  216. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +73 -0
  217. numba_cuda/numba/cuda/tests/nocuda/__init__.py +8 -0
  218. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +359 -0
  219. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +36 -0
  220. numba_cuda/numba/cuda/tests/nocuda/test_import.py +49 -0
  221. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +238 -0
  222. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +54 -0
  223. numba_cuda/numba/cuda/types.py +37 -0
  224. numba_cuda/numba/cuda/ufuncs.py +662 -0
  225. numba_cuda/numba/cuda/vector_types.py +209 -0
  226. numba_cuda/numba/cuda/vectorizers.py +252 -0
  227. numba_cuda-0.0.13.dist-info/LICENSE +25 -0
  228. numba_cuda-0.0.13.dist-info/METADATA +69 -0
  229. numba_cuda-0.0.13.dist-info/RECORD +231 -0
  230. {numba_cuda-0.0.1.dist-info → numba_cuda-0.0.13.dist-info}/WHEEL +1 -1
  231. numba_cuda-0.0.1.dist-info/METADATA +0 -10
  232. numba_cuda-0.0.1.dist-info/RECORD +0 -5
  233. {numba_cuda-0.0.1.dist-info → numba_cuda-0.0.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,176 @@
1
+ """CUDA Toolkit libraries lookup utilities.
2
+
3
+ CUDA Toolkit libraries can be available via either:
4
+
5
+ - the `cuda-nvcc` and `cuda-nvrtc` conda packages for CUDA 12,
6
+ - the `cudatoolkit` conda package for CUDA 11,
7
+ - a user supplied location from CUDA_HOME,
8
+ - a system wide location,
9
+ - package-specific locations (e.g. the Debian NVIDIA packages),
10
+ - or can be discovered by the system loader.
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import ctypes
16
+
17
+ from numba.misc.findlib import find_lib
18
+ from numba.cuda.cuda_paths import get_cuda_paths
19
+ from numba.cuda.cudadrv.driver import locate_driver_and_loader, load_driver
20
+ from numba.cuda.cudadrv.error import CudaSupportError
21
+
22
+
23
# printf-style templates that turn a bare library name (e.g. 'nvvm') into the
# platform's file name for a shared library and for a static library.
if sys.platform == 'win32':
    _dllnamepattern, _staticnamepattern = '%s.dll', '%s.lib'
elif sys.platform == 'darwin':
    _dllnamepattern, _staticnamepattern = 'lib%s.dylib', 'lib%s.a'
else:
    _dllnamepattern, _staticnamepattern = 'lib%s.so', 'lib%s.a'
32
+
33
+
34
def get_libdevice():
    """
    Return the libdevice location found by the CUDA path search.

    This is the ``info`` field of the ``'libdevice'`` entry returned by
    ``get_cuda_paths()``. It may be ``None`` if nothing was found
    (NOTE(review): depends on ``get_cuda_paths``, not visible here).
    """
    return get_cuda_paths()['libdevice'].info
38
+
39
+
40
def open_libdevice():
    """Read the libdevice file located by ``get_libdevice()`` and return its
    raw contents as bytes."""
    libdevice_path = get_libdevice()
    with open(libdevice_path, 'rb') as bitcode:
        contents = bitcode.read()
    return contents
43
+
44
+
45
def get_cudalib(lib, static=False):
    """
    Find the path of a CUDA library based on a search of known locations.

    If the search fails, return a generic filename for the library (e.g.
    'libnvvm.so' for 'nvvm') so that we may attempt to load it using the system
    loader's search mechanism.

    :param lib: Bare library name, e.g. ``'nvvm'``, ``'nvrtc'``, ``'cudadevrt'``.
    :param static: Look for a static library instead of a shared one. This is
                   ignored for ``'nvvm'``, which is always resolved as a
                   shared library.
    :return: A path, or a bare platform-specific file name as a fallback.
    """
    cuda_paths = get_cuda_paths()

    # NVVM has its own dedicated entry in the path search results.
    if lib == 'nvvm':
        return cuda_paths['nvvm'].info or _dllnamepattern % 'nvvm'

    key = 'static_cudalib_dir' if static else 'cudalib_dir'
    found = find_lib(lib, cuda_paths[key].info, static=static)
    if found:
        # Several matches may exist; pick the lexicographically greatest one.
        return max(found)

    fallback = _staticnamepattern if static else _dllnamepattern
    return fallback % lib
61
+
62
+
63
def open_cudalib(lib):
    """Load the shared CUDA library *lib* with ``ctypes.CDLL``.

    The path comes from ``get_cudalib(lib)``. ``OSError`` is raised by ctypes
    if the library cannot be loaded.
    """
    return ctypes.CDLL(get_cudalib(lib))
66
+
67
+
68
def check_static_lib(path):
    """
    Verify that *path* names an existing regular file.

    :param path: Filesystem path to check. ``None`` is accepted and treated as
                 "not found"; ``get_libdevice()`` can return ``None`` when no
                 libdevice file was located.
    :raises FileNotFoundError: If *path* is ``None`` or is not an existing
                               regular file.
    """
    # os.path.isfile(None) raises TypeError rather than returning False, which
    # would escape callers that only handle FileNotFoundError.
    if path is None or not os.path.isfile(path):
        raise FileNotFoundError(f'{path} not found')
71
+
72
+
73
def _get_source_variable(lib, static=False):
    """
    Return the ``by`` field of the ``get_cuda_paths()`` entry used to locate
    *lib*.

    ``'nvvm'`` and ``'libdevice'`` have their own entries. Any other library
    uses the static or shared library directory entry, depending on *static*.
    """
    if lib in ('nvvm', 'libdevice'):
        entry = lib
    elif static:
        entry = 'static_cudalib_dir'
    else:
        entry = 'cudalib_dir'
    return get_cuda_paths()[entry].by
81
+
82
+
83
def test():
    """Test library lookup. Path info is printed to stdout.

    The following are checked in order:

    - the CUDA driver (locate and load);
    - on Linux, the libcuda.so paths mapped into this process;
    - the shared libraries nvvm, nvrtc and cudart;
    - the static library cudadevrt;
    - the libdevice file.

    :return: ``True`` if every check passed, ``False`` otherwise. Failing to
             read ``/proc/<pid>/maps`` is reported but not counted as a failure.
    """
    failed = False

    # Check for the driver
    try:
        dlloader, candidates = locate_driver_and_loader()
        print('Finding driver from candidates:')
        for location in candidates:
            print(f'\t{location}')
        print(f'Using loader {dlloader}')
        print('\tTrying to load driver', end='...')
        _dll, path = load_driver(dlloader, candidates)
        print('\tok')
        print(f'\t\tLoaded from {path}')
    except CudaSupportError as e:
        print(f'\tERROR: failed to open driver: {e}')
        failed = True

    # Find the absolute location of the driver on Linux. Various driver-related
    # issues have been reported by WSL2 users, and it is almost always due to a
    # Linux (i.e. not- WSL2) driver being installed in a WSL2 system.
    # Providing the absolute location of the driver indicates its version
    # number in the soname (e.g. "libcuda.so.530.30.02"), which can be used to
    # look up whether the driver was intended for "native" Linux.
    if sys.platform == 'linux' and not failed:
        pid = os.getpid()
        mapsfile = os.path.join(os.path.sep, 'proc', f'{pid}', 'maps')
        try:
            with open(mapsfile) as f:
                maps = f.read()
        # It's difficult to predict all that might go wrong reading the maps
        # file - in case various error conditions ensue (the file is not found,
        # not readable, etc.) we use OSError to hopefully catch any of them.
        except OSError:
            # It's helpful to report that this went wrong to the user, but we
            # don't set failed to True because this doesn't have any connection
            # to actual CUDA functionality.
            print(f'\tERROR: Could not open {mapsfile} to determine absolute '
                  'path to libcuda.so')
        else:
            # In this case we could read the maps, so we can report the
            # relevant ones to the user
            locations = {s for s in maps.split() if 'libcuda.so' in s}
            print('\tMapped libcuda.so paths:')
            for location in locations:
                print(f'\t\t{location}')

    # Checks for dynamic libraries
    libs = 'nvvm nvrtc cudart'.split()
    for lib in libs:
        path = get_cudalib(lib)
        print('Finding {} from {}'.format(lib, _get_source_variable(lib)))
        print('\tLocated at', path)

        try:
            print('\tTrying to open library', end='...')
            open_cudalib(lib)
            print('\tok')
        except OSError as e:
            print('\tERROR: failed to open %s:\n%s' % (lib, e))
            failed = True

    # Check for cudadevrt (the only static library)
    lib = 'cudadevrt'
    path = get_cudalib(lib, static=True)
    print('Finding {} from {}'.format(lib, _get_source_variable(lib,
                                                                static=True)))
    print('\tLocated at', path)

    try:
        print('\tChecking library', end='...')
        check_static_lib(path)
        print('\tok')
    except FileNotFoundError as e:
        print('\tERROR: failed to find %s:\n%s' % (lib, e))
        failed = True

    # Check for libdevice. Rebind ``lib`` here; otherwise the error message
    # below would report the previous check's library ('cudadevrt').
    lib = 'libdevice'
    where = _get_source_variable(lib)
    print(f'Finding libdevice from {where}')
    path = get_libdevice()
    print('\tLocated at', path)

    try:
        print('\tChecking library', end='...')
        check_static_lib(path)
        print('\tok')
    except FileNotFoundError as e:
        print('\tERROR: failed to find %s:\n%s' % (lib, e))
        failed = True

    return not failed
@@ -0,0 +1,20 @@
1
+ from numba.cuda.cudadrv import devices, driver
2
+ from numba.core.registry import cpu_target
3
+
4
+
5
def _calc_array_sizeof(ndim):
    """
    Return the array size for ``ndim`` dimensions as computed by the CPU
    target context, i.e. using the ABI size in the CPU target.
    """
    return cpu_target.target_context.calc_array_sizeof(ndim)
11
+
12
+
13
def ndarray_device_allocate_data(ary):
    """
    Allocate a GPU data buffer sized for the host array ``ary``.

    The byte count comes from ``driver.host_memory_size(ary)``. The allocation
    is made with ``memalloc`` on the context returned by
    ``devices.get_context()``, and whatever ``memalloc`` returns is returned.
    """
    nbytes = driver.host_memory_size(ary)
    return devices.get_context().memalloc(nbytes)
@@ -0,0 +1,260 @@
1
+ from ctypes import byref, c_char, c_char_p, c_int, c_size_t, c_void_p, POINTER
2
+ from enum import IntEnum
3
+ from numba.core import config
4
+ from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
5
+ NvrtcSupportError)
6
+
7
+ import functools
8
+ import os
9
+ import threading
10
+ import warnings
11
+
12
# ctypes stand-ins for NVRTC C types used in the NVRTC._PROTOTYPES table.
# nvrtcProgram: an opaque handle to a compilation unit.
nvrtc_program = c_void_p

# nvrtcResult: the status code type used as the return type of every
# prototype in NVRTC._PROTOTYPES.
nvrtc_result = c_int


class NvrtcResult(IntEnum):
    """Named values of the ``nvrtcResult`` status codes checked by this
    module (0 is success; every other value is an error)."""

    NVRTC_SUCCESS = 0
    NVRTC_ERROR_OUT_OF_MEMORY = 1
    NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2
    NVRTC_ERROR_INVALID_INPUT = 3
    NVRTC_ERROR_INVALID_PROGRAM = 4
    NVRTC_ERROR_INVALID_OPTION = 5
    # Raised to callers as NvrtcCompilationError rather than NvrtcError.
    NVRTC_ERROR_COMPILATION = 6
    NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7
    NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8
    NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9
    NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10
    NVRTC_ERROR_INTERNAL_ERROR = 11


# Serializes construction of the NVRTC singleton in NVRTC.__new__.
_nvrtc_lock = threading.Lock()
35
+
36
+
37
class NvrtcProgram:
    """
    Owner of one nvrtcProgram handle.

    When an instance is garbage collected, a non-NULL (truthy) handle is
    released by calling ``destroy_program`` on the ``nvrtc`` object it was
    constructed with.
    """

    def __init__(self, nvrtc, handle):
        # ``nvrtc`` provides destroy_program(); ``handle`` is the raw program
        # handle, which is only released if truthy.
        self._nvrtc, self._handle = nvrtc, handle

    @property
    def handle(self):
        """The owned nvrtcProgram handle."""
        return self._handle

    def __del__(self):
        # A NULL/empty handle owns nothing, so there is nothing to release.
        if not self._handle:
            return
        self._nvrtc.destroy_program(self)
54
+
55
+
56
class NVRTC:
    """
    Provides a Pythonic interface to the NVRTC APIs, abstracting away the C API
    calls.

    The sole instance of this class is a process-wide singleton, similar to the
    NVVM interface. Initialization is protected by a lock and uses the standard
    (for Numba) open_cudalib function to load the NVRTC library.

    Each entry in ``_PROTOTYPES`` becomes an instance attribute of the same
    name. That attribute calls the C function and raises
    ``NvrtcCompilationError`` for ``NVRTC_ERROR_COMPILATION``, and
    ``NvrtcError`` for any other non-success result.
    """
    _PROTOTYPES = {
        # nvrtcResult nvrtcVersion(int *major, int *minor)
        'nvrtcVersion': (nvrtc_result, POINTER(c_int), POINTER(c_int)),
        # nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
        #                                const char *src,
        #                                const char *name,
        #                                int numHeaders,
        #                                const char * const *headers,
        #                                const char * const *includeNames)
        # ``prog`` is an out-pointer, so it is declared as a pointer to the
        # handle type (consistent with nvrtcDestroyProgram below).
        'nvrtcCreateProgram': (nvrtc_result, POINTER(nvrtc_program), c_char_p,
                               c_char_p, c_int, POINTER(c_char_p),
                               POINTER(c_char_p)),
        # nvrtcResult nvrtcDestroyProgram(nvrtcProgram *prog);
        'nvrtcDestroyProgram': (nvrtc_result, POINTER(nvrtc_program)),
        # nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
        #                                 int numOptions,
        #                                 const char * const *options)
        'nvrtcCompileProgram': (nvrtc_result, nvrtc_program, c_int,
                                POINTER(c_char_p)),
        # nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet);
        'nvrtcGetPTXSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
        # nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
        'nvrtcGetPTX': (nvrtc_result, nvrtc_program, c_char_p),
        # nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog,
        #                               size_t *cubinSizeRet);
        'nvrtcGetCUBINSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
        # nvrtcResult nvrtcGetCUBIN(nvrtcProgram prog, char *cubin);
        'nvrtcGetCUBIN': (nvrtc_result, nvrtc_program, c_char_p),
        # nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog,
        #                                    size_t *logSizeRet);
        'nvrtcGetProgramLogSize': (nvrtc_result, nvrtc_program,
                                   POINTER(c_size_t)),
        # nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
        'nvrtcGetProgramLog': (nvrtc_result, nvrtc_program, c_char_p),
    }

    # Singleton reference
    __INSTANCE = None

    def __new__(cls):
        with _nvrtc_lock:
            if cls.__INSTANCE is None:
                from numba.cuda.cudadrv.libs import open_cudalib
                try:
                    lib = open_cudalib('nvrtc')
                except OSError as e:
                    raise NvrtcSupportError("NVRTC cannot be loaded") from e

                inst = object.__new__(cls)

                # Find & populate functions
                for name, proto in inst._PROTOTYPES.items():
                    func = getattr(lib, name)
                    func.restype = proto[0]
                    func.argtypes = proto[1:]

                    @functools.wraps(func)
                    def checked_call(*args, func=func, name=name):
                        error = func(*args)
                        if error == NvrtcResult.NVRTC_ERROR_COMPILATION:
                            raise NvrtcCompilationError()
                        elif error != NvrtcResult.NVRTC_SUCCESS:
                            try:
                                error_name = NvrtcResult(error).name
                            except ValueError:
                                error_name = ('Unknown nvrtc_result '
                                              f'(error code: {error})')
                            msg = f'Failed to call {name}: {error_name}'
                            raise NvrtcError(msg)

                    setattr(inst, name, checked_call)

                # Publish the singleton only once it is fully populated. If
                # any symbol lookup above raised, nothing is cached, and the
                # next call retries instead of returning a broken instance.
                cls.__INSTANCE = inst

        return cls.__INSTANCE

    def get_version(self):
        """
        Get the NVRTC version as a tuple (major, minor).
        """
        major = c_int()
        minor = c_int()
        self.nvrtcVersion(byref(major), byref(minor))
        return major.value, minor.value

    def create_program(self, src, name):
        """
        Create an NVRTC program with managed lifetime.

        ``src`` and ``name`` may be ``str`` (encoded with the default codec)
        or ``bytes``. Returns an ``NvrtcProgram`` owning the new handle.
        """
        if isinstance(src, str):
            src = src.encode()
        if isinstance(name, str):
            name = name.encode()

        handle = nvrtc_program()

        # The final three arguments are for passing the contents of headers -
        # this is not supported, so there are 0 headers and the header names
        # and contents are null.
        self.nvrtcCreateProgram(byref(handle), src, name, 0, None, None)
        return NvrtcProgram(self, handle)

    def compile_program(self, program, options):
        """
        Compile an NVRTC program. Compilation may fail due to a user error in
        the source; this function returns ``True`` if there is a compilation
        error and ``False`` on success.
        """
        # We hold a list of encoded options to ensure they can't be collected
        # prior to the call to nvrtcCompileProgram
        encoded_options = [opt.encode() for opt in options]
        option_pointers = [c_char_p(opt) for opt in encoded_options]
        c_options_type = (c_char_p * len(encoded_options))
        c_options = c_options_type(*option_pointers)
        try:
            self.nvrtcCompileProgram(program.handle, len(encoded_options),
                                     c_options)
            return False
        except NvrtcCompilationError:
            return True

    def destroy_program(self, program):
        """
        Destroy an NVRTC program.
        """
        self.nvrtcDestroyProgram(byref(program.handle))

    def get_compile_log(self, program):
        """
        Get the compile log as a Python string.
        """
        log_size = c_size_t()
        self.nvrtcGetProgramLogSize(program.handle, byref(log_size))

        log = (c_char * log_size.value)()
        self.nvrtcGetProgramLog(program.handle, log)

        # ``.value`` stops at the first NUL byte.
        return log.value.decode()

    def get_ptx(self, program):
        """
        Get the compiled PTX as a Python string.
        """
        ptx_size = c_size_t()
        self.nvrtcGetPTXSize(program.handle, byref(ptx_size))

        ptx = (c_char * ptx_size.value)()
        self.nvrtcGetPTX(program.handle, ptx)

        # ``.value`` stops at the first NUL byte.
        return ptx.value.decode()
211
+
212
+
213
def compile(src, name, cc):
    """
    Compile a CUDA C/C++ source to PTX for a given compute capability.

    :param src: The source code to compile
    :type src: str
    :param name: The filename of the source (for information only)
    :type name: str
    :param cc: A tuple ``(major, minor)`` of the compute capability
    :type cc: tuple
    :return: The compiled PTX and compilation log
    :rtype: tuple
    :raises NvrtcError: If compilation fails; the message includes the log.
    """
    nvrtc = NVRTC()
    program = nvrtc.create_program(src, name)

    # Compilation options:
    # - Target the virtual architecture for the requested compute capability.
    # - The CUDA include path is added.
    # - The numba/cuda package directory (the parent of this module's
    #   directory) is added as an include path.
    # - Relocatable Device Code (rdc) is needed to prevent device functions
    #   being optimized away.
    major, minor = cc
    this_dir = os.path.dirname(os.path.abspath(__file__))
    options = [
        f'--gpu-architecture=compute_{major}{minor}',
        f'-I{config.CUDA_INCLUDE_PATH}',
        f'-I{os.path.dirname(this_dir)}',
        '-rdc',
        'true',
    ]

    had_error = nvrtc.compile_program(program, options)
    log = nvrtc.get_compile_log(program)

    # A failed compile is fatal; the log explains why.
    if had_error:
        raise NvrtcError(
            f'NVRTC Compilation failure whilst compiling {name}:\n\n{log}')

    # A successful compile with a non-empty log is surfaced as a warning.
    if log:
        warnings.warn(f"NVRTC log messages whilst compiling {name}:\n\n{log}")

    return nvrtc.get_ptx(program), log