numba-cuda 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,7 @@ subsequent deallocation could further corrupt the CUDA context and causes the
 system to freeze in some cases.
 
 """
+
 import sys
 import os
 import ctypes
@@ -25,8 +26,17 @@ import tempfile
 import re
 from itertools import product
 from abc import ABCMeta, abstractmethod
-from ctypes import (c_int, byref, c_size_t, c_char, c_char_p, addressof,
-                    c_void_p, c_float, c_uint)
+from ctypes import (
+    c_int,
+    byref,
+    c_size_t,
+    c_char,
+    c_char_p,
+    addressof,
+    c_void_p,
+    c_float,
+    c_uint,
+)
 import contextlib
 import importlib
 import numpy as np
@@ -51,13 +61,14 @@ USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
 
 if USE_NV_BINDING:
     from cuda import cuda as binding
+
     # There is no definition of the default stream in the Nvidia bindings (nor
    # is there at the C/C++ level), so we define it here so we don't need to
    # use a magic number 0 in places where we want the default stream.
     CU_STREAM_DEFAULT = 0
 
 MIN_REQUIRED_CC = (3, 5)
-SUPPORTS_IPC = sys.platform.startswith('linux')
+SUPPORTS_IPC = sys.platform.startswith("linux")
 
 
 _py_decref = ctypes.pythonapi.Py_DecRef
@@ -71,10 +82,9 @@ _MVC_ERROR_MESSAGE = (
     "to be available"
 )
 
-ENABLE_PYNVJITLINK = (
-    _readenv("NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, False)
-    or getattr(config, "CUDA_ENABLE_PYNVJITLINK", False)
-)
+ENABLE_PYNVJITLINK = _readenv(
+    "NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, False
+) or getattr(config, "CUDA_ENABLE_PYNVJITLINK", False)
 if not hasattr(config, "CUDA_ENABLE_PYNVJITLINK"):
     config.CUDA_ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK
 
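The reflowed expression keeps the same precedence as before: the NUMBA_CUDA_ENABLE_PYNVJITLINK environment variable is consulted first, then any already-set config.CUDA_ENABLE_PYNVJITLINK attribute. A minimal standalone sketch of that precedence (the _read_bool_env helper and _Config class are illustrative stand-ins, not numba-cuda APIs):

    import os

    def _read_bool_env(name, default=False):
        # Stand-in for numba-cuda's _readenv called with a bool type.
        raw = os.environ.get(name)
        if raw is None:
            return default
        return raw.strip().lower() in ("1", "true", "yes", "on")

    class _Config:
        # Stand-in for numba.core.config; the attribute may be absent.
        pass

    ENABLE_PYNVJITLINK = _read_bool_env(
        "NUMBA_CUDA_ENABLE_PYNVJITLINK"
    ) or getattr(_Config, "CUDA_ENABLE_PYNVJITLINK", False)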
@@ -94,7 +104,7 @@ def make_logger():
     if config.CUDA_LOG_LEVEL:
         # create a simple handler that prints to stderr
         handler = logging.StreamHandler(sys.stderr)
-        fmt = '== CUDA [%(relativeCreated)d] %(levelname)5s -- %(message)s'
+        fmt = "== CUDA [%(relativeCreated)d] %(levelname)5s -- %(message)s"
         handler.setFormatter(logging.Formatter(fmt=fmt))
         logger.addHandler(handler)
     else:
@@ -122,50 +132,52 @@ class CudaAPIError(CudaDriverError):
 
 
 def locate_driver_and_loader():
-
     envpath = config.CUDA_DRIVER
 
-    if envpath == '0':
+    if envpath == "0":
         # Force fail
         _raise_driver_not_found()
 
     # Determine DLL type
-    if sys.platform == 'win32':
+    if sys.platform == "win32":
         dlloader = ctypes.WinDLL
-        dldir = ['\\windows\\system32']
-        dlnames = ['nvcuda.dll']
-    elif sys.platform == 'darwin':
+        dldir = ["\\windows\\system32"]
+        dlnames = ["nvcuda.dll"]
+    elif sys.platform == "darwin":
         dlloader = ctypes.CDLL
-        dldir = ['/usr/local/cuda/lib']
-        dlnames = ['libcuda.dylib']
+        dldir = ["/usr/local/cuda/lib"]
+        dlnames = ["libcuda.dylib"]
     else:
         # Assume to be *nix like
         dlloader = ctypes.CDLL
-        dldir = ['/usr/lib', '/usr/lib64']
-        dlnames = ['libcuda.so', 'libcuda.so.1']
+        dldir = ["/usr/lib", "/usr/lib64"]
+        dlnames = ["libcuda.so", "libcuda.so.1"]
 
     if envpath:
         try:
             envpath = os.path.abspath(envpath)
         except ValueError:
-            raise ValueError("NUMBA_CUDA_DRIVER %s is not a valid path" %
-                             envpath)
+            raise ValueError(
+                "NUMBA_CUDA_DRIVER %s is not a valid path" % envpath
+            )
         if not os.path.isfile(envpath):
-            raise ValueError("NUMBA_CUDA_DRIVER %s is not a valid file "
-                             "path. Note it must be a filepath of the .so/"
-                             ".dll/.dylib or the driver" % envpath)
+            raise ValueError(
+                "NUMBA_CUDA_DRIVER %s is not a valid file "
+                "path. Note it must be a filepath of the .so/"
+                ".dll/.dylib or the driver" % envpath
+            )
         candidates = [envpath]
     else:
         # First search for the name in the default library path.
         # If that is not found, try the specific path.
-        candidates = dlnames + [os.path.join(x, y)
-                                for x, y in product(dldir, dlnames)]
+        candidates = dlnames + [
+            os.path.join(x, y) for x, y in product(dldir, dlnames)
+        ]
 
     return dlloader, candidates
 
 
 def load_driver(dlloader, candidates):
-
     # Load the driver; Collect driver error information
     path_not_exist = []
     driver_load_error = []
@@ -184,7 +196,7 @@ def load_driver(dlloader, candidates):
     if all(path_not_exist):
         _raise_driver_not_found()
     else:
-        errmsg = '\n'.join(str(e) for e in driver_load_error)
+        errmsg = "\n".join(str(e) for e in driver_load_error)
         _raise_driver_error(errmsg)
 
 
@@ -216,7 +228,7 @@ def _raise_driver_error(e):
 
 
 def _build_reverse_error_map():
-    prefix = 'CUDA_ERROR'
+    prefix = "CUDA_ERROR"
     map = utils.UniqueDict()
     for name in dir(enums):
         if name.startswith(prefix):
@@ -236,6 +248,7 @@ class Driver(object):
     """
     Driver API functions are lazily bound.
     """
+
     _singleton = None
 
     def __new__(cls):
@@ -254,9 +267,11 @@
         self.pid = None
         try:
             if config.DISABLE_CUDA:
-                msg = ("CUDA is disabled due to setting NUMBA_DISABLE_CUDA=1 "
-                       "in the environment, or because CUDA is unsupported on "
-                       "32-bit systems.")
+                msg = (
+                    "CUDA is disabled due to setting NUMBA_DISABLE_CUDA=1 "
+                    "in the environment, or because CUDA is unsupported on "
+                    "32-bit systems."
+                )
                 raise CudaSupportError(msg)
             self.lib = find_driver()
         except CudaSupportError as e:
@@ -273,7 +288,7 @@ class Driver(object):
 
         self.is_initialized = True
         try:
-            _logger.info('init')
+            _logger.info("init")
             self.cuInit(0)
         except CudaAPIError as e:
             description = f"{e.msg} ({e.code})"
@@ -292,8 +307,9 @@
         self.ensure_initialized()
 
         if self.initialization_error is not None:
-            raise CudaSupportError("Error at driver init: \n%s:" %
-                                   self.initialization_error)
+            raise CudaSupportError(
+                "Error at driver init: \n%s:" % self.initialization_error
+            )
 
         if USE_NV_BINDING:
             return self._cuda_python_wrap_fn(fname)
@@ -317,12 +333,12 @@
 
         def verbose_cuda_api_call(*args):
             argstr = ", ".join([str(arg) for arg in args])
-            _logger.debug('call driver api: %s(%s)', libfn.__name__, argstr)
+            _logger.debug("call driver api: %s(%s)", libfn.__name__, argstr)
            retcode = libfn(*args)
            self._check_ctypes_error(fname, retcode)
 
         def safe_cuda_api_call(*args):
-            _logger.debug('call driver api: %s', libfn.__name__)
+            _logger.debug("call driver api: %s", libfn.__name__)
             retcode = libfn(*args)
             self._check_ctypes_error(fname, retcode)
 
@@ -340,11 +356,11 @@
 
         def verbose_cuda_api_call(*args):
             argstr = ", ".join([str(arg) for arg in args])
-            _logger.debug('call driver api: %s(%s)', libfn.__name__, argstr)
+            _logger.debug("call driver api: %s(%s)", libfn.__name__, argstr)
             return self._check_cuda_python_error(fname, libfn(*args))
 
         def safe_cuda_api_call(*args):
-            _logger.debug('call driver api: %s', libfn.__name__)
+            _logger.debug("call driver api: %s", libfn.__name__)
             return self._check_cuda_python_error(fname, libfn(*args))
 
         if config.CUDA_LOG_API_ARGS:
@@ -361,30 +377,30 @@
         # binding. For the NVidia binding, it handles linking to the correct
         # variant.
         if config.CUDA_PER_THREAD_DEFAULT_STREAM and not USE_NV_BINDING:
-            variants = ('_v2_ptds', '_v2_ptsz', '_ptds', '_ptsz', '_v2', '')
+            variants = ("_v2_ptds", "_v2_ptsz", "_ptds", "_ptsz", "_v2", "")
         else:
-            variants = ('_v2', '')
+            variants = ("_v2", "")
 
         if fname in ("cuCtxGetDevice", "cuCtxSynchronize"):
             return getattr(self.lib, fname)
 
         for variant in variants:
             try:
-                return getattr(self.lib, f'{fname}{variant}')
+                return getattr(self.lib, f"{fname}{variant}")
             except AttributeError:
                 pass
 
         # Not found.
         # Delay missing function error to use
         def absent_function(*args, **kws):
-            raise CudaDriverError(f'Driver missing function: {fname}')
+            raise CudaDriverError(f"Driver missing function: {fname}")
 
         setattr(self, fname, absent_function)
         return absent_function
 
     def _detect_fork(self):
         if self.pid is not None and _getpid() != self.pid:
-            msg = 'pid %s forked from pid %s after CUDA driver init'
+            msg = "pid %s forked from pid %s after CUDA driver init"
             _logger.critical(msg, _getpid(), self.pid)
             raise CudaDriverError("CUDA initialized before forking")
 
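The lookup above resolves a driver symbol by trying suffixed variants in order and falls back to a stub that only errors when called. A pure-Python sketch of the same pattern, with a hypothetical stand-in for the loaded driver library:

    class _FakeDriverLib:
        # Stand-in: only the _v2 variant of cuMemAlloc exists here.
        @staticmethod
        def cuMemAlloc_v2():
            return "called cuMemAlloc_v2"

    def find_api(lib, fname, variants=("_v2", "")):
        for variant in variants:
            try:
                return getattr(lib, f"{fname}{variant}")
            except AttributeError:
                pass

        def absent_function(*args, **kws):
            raise RuntimeError(f"Driver missing function: {fname}")

        return absent_function

    lib = _FakeDriverLib()
    print(find_api(lib, "cuMemAlloc")())  # resolves cuMemAlloc_v2
    missing = find_api(lib, "cuNoSuchCall")  # lookup itself succeeds...
    # ...and calling missing() raises RuntimeError only at use time.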
@@ -428,13 +444,11 @@ class Driver(object):
         return count.value
 
     def list_devices(self):
-        """Returns a list of active devices
-        """
+        """Returns a list of active devices"""
         return list(self.devices.values())
 
     def reset(self):
-        """Reset all devices
-        """
+        """Reset all devices"""
         for dev in self.devices.values():
             dev.reset()
 
@@ -452,8 +466,7 @@
         return popped
 
     def get_active_context(self):
-        """Returns an instance of ``_ActiveContext``.
-        """
+        """Returns an instance of ``_ActiveContext``."""
         return _ActiveContext()
 
     def get_version(self):
@@ -480,12 +493,13 @@ class _ActiveContext(object):
     Once entering the context, it is assumed that the active CUDA context is
     not changed until the context is exited.
     """
+
     _tls_cache = threading.local()
 
     def __enter__(self):
         is_top = False
         # check TLS cache
-        if hasattr(self._tls_cache, 'ctx_devnum'):
+        if hasattr(self._tls_cache, "ctx_devnum"):
             hctx, devnum = self._tls_cache.ctx_devnum
         # Not cached. Query the driver API.
         else:
@@ -518,11 +532,10 @@
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if self._is_top:
-            delattr(self._tls_cache, 'ctx_devnum')
+            delattr(self._tls_cache, "ctx_devnum")
 
     def __bool__(self):
-        """Returns True is there's a valid and active CUDA context.
-        """
+        """Returns True is there's a valid and active CUDA context."""
         return self.context_handle is not None
 
     __nonzero__ = __bool__
@@ -536,7 +549,7 @@ def _build_reverse_device_attrs():
     map = utils.UniqueDict()
     for name in dir(enums):
         if name.startswith(prefix):
-            map[name[len(prefix):]] = getattr(enums, name)
+            map[name[len(prefix) :]] = getattr(enums, name)
     return map
 
 
@@ -548,6 +561,7 @@ class Device(object):
     The device object owns the CUDA contexts. This is owned by the driver
     object. User should not construct devices directly.
     """
+
    @classmethod
    def from_identity(self, identity):
        """Create Device object from device identity created by
@@ -582,15 +596,17 @@ class Device(object):
         self.attributes = {}
 
         # Read compute capability
-        self.compute_capability = (self.COMPUTE_CAPABILITY_MAJOR,
-                                   self.COMPUTE_CAPABILITY_MINOR)
+        self.compute_capability = (
+            self.COMPUTE_CAPABILITY_MAJOR,
+            self.COMPUTE_CAPABILITY_MINOR,
+        )
 
         # Read name
         bufsz = 128
 
         if USE_NV_BINDING:
             buf = driver.cuDeviceGetName(bufsz, self.id)
-            name = buf.decode('utf-8').rstrip('\0')
+            name = buf.decode("utf-8").rstrip("\0")
         else:
             buf = (c_char * bufsz)()
             driver.cuDeviceGetName(buf, bufsz, self.id)
@@ -607,31 +623,31 @@ class Device(object):
             driver.cuDeviceGetUuid(byref(uuid), self.id)
             uuid_vals = tuple(bytes(uuid))
 
-        b = '%02x'
+        b = "%02x"
         b2 = b * 2
         b4 = b * 4
         b6 = b * 6
-        fmt = f'GPU-{b4}-{b2}-{b2}-{b2}-{b6}'
+        fmt = f"GPU-{b4}-{b2}-{b2}-{b2}-{b6}"
         self.uuid = fmt % uuid_vals
 
         self.primary_context = None
 
     def get_device_identity(self):
         return {
-            'pci_domain_id': self.PCI_DOMAIN_ID,
-            'pci_bus_id': self.PCI_BUS_ID,
-            'pci_device_id': self.PCI_DEVICE_ID,
+            "pci_domain_id": self.PCI_DOMAIN_ID,
+            "pci_bus_id": self.PCI_BUS_ID,
+            "pci_device_id": self.PCI_DEVICE_ID,
         }
 
     def __repr__(self):
         return "<CUDA device %d '%s'>" % (self.id, self.name)
 
     def __getattr__(self, attr):
-        """Read attributes lazily
-        """
+        """Read attributes lazily"""
         if USE_NV_BINDING:
-            code = getattr(binding.CUdevice_attribute,
-                           f'CU_DEVICE_ATTRIBUTE_{attr}')
+            code = getattr(
+                binding.CUdevice_attribute, f"CU_DEVICE_ATTRIBUTE_{attr}"
+            )
             value = driver.cuDeviceGetAttribute(code, self.id)
         else:
             try:
@@ -701,17 +717,18 @@ class Device(object):
 
 def met_requirement_for_device(device):
     if device.compute_capability < MIN_REQUIRED_CC:
-        raise CudaSupportError("%s has compute capability < %s" %
-                               (device, MIN_REQUIRED_CC))
+        raise CudaSupportError(
+            "%s has compute capability < %s" % (device, MIN_REQUIRED_CC)
+        )
 
 
 class BaseCUDAMemoryManager(object, metaclass=ABCMeta):
     """Abstract base class for External Memory Management (EMM) Plugins."""
 
     def __init__(self, *args, **kwargs):
-        if 'context' not in kwargs:
+        if "context" not in kwargs:
             raise RuntimeError("Memory manager requires a context")
-        self.context = kwargs.pop('context')
+        self.context = kwargs.pop("context")
 
     @abstractmethod
     def memalloc(self, size):
@@ -867,8 +884,7 @@ class HostOnlyCUDAMemoryManager(BaseCUDAMemoryManager):
         else:
             raise
 
-    def memhostalloc(self, size, mapped=False, portable=False,
-                     wc=False):
+    def memhostalloc(self, size, mapped=False, portable=False, wc=False):
         """Implements the allocation of pinned host memory.
 
         It is recommended that this method is not overridden by EMM Plugin
@@ -883,6 +899,7 @@ class HostOnlyCUDAMemoryManager(BaseCUDAMemoryManager):
             flags |= enums.CU_MEMHOSTALLOC_WRITECOMBINED
 
         if USE_NV_BINDING:
+
             def allocator():
                 return driver.cuMemHostAlloc(size, flags)
 
@@ -949,16 +966,19 @@ class HostOnlyCUDAMemoryManager(BaseCUDAMemoryManager):
         ctx = weakref.proxy(self.context)
 
         if mapped:
-            mem = MappedMemory(ctx, pointer, size, owner=owner,
-                               finalizer=finalizer)
+            mem = MappedMemory(
+                ctx, pointer, size, owner=owner, finalizer=finalizer
+            )
             self.allocations[alloc_key] = mem
             return mem.own()
         else:
-            return PinnedMemory(ctx, pointer, size, owner=owner,
-                                finalizer=finalizer)
+            return PinnedMemory(
+                ctx, pointer, size, owner=owner, finalizer=finalizer
+            )
 
     def memallocmanaged(self, size, attach_global):
         if USE_NV_BINDING:
+
             def allocator():
                 ma_flags = binding.CUmemAttach_flags
 
@@ -1017,8 +1037,7 @@ class HostOnlyCUDAMemoryManager(BaseCUDAMemoryManager):
 
 
 class GetIpcHandleMixin:
-    """A class that provides a default implementation of ``get_ipc_handle()``.
-    """
+    """A class that provides a default implementation of ``get_ipc_handle()``."""
 
     def get_ipc_handle(self, memory):
         """Open an IPC memory handle by using ``cuMemGetAddressRange`` to
@@ -1037,8 +1056,9 @@ class GetIpcHandleMixin:
         offset = memory.handle.value - base
         source_info = self.context.device.get_device_identity()
 
-        return IpcHandle(memory, ipchandle, memory.size, source_info,
-                         offset=offset)
+        return IpcHandle(
+            memory, ipchandle, memory.size, source_info, offset=offset
+        )
 
 
 class NumbaCUDAMemoryManager(GetIpcHandleMixin, HostOnlyCUDAMemoryManager):
@@ -1053,6 +1073,7 @@ class NumbaCUDAMemoryManager(GetIpcHandleMixin, HostOnlyCUDAMemoryManager):
 
     def memalloc(self, size):
         if USE_NV_BINDING:
+
             def allocator():
                 return driver.cuMemAlloc(size)
 
@@ -1101,7 +1122,7 @@ def _ensure_memory_manager():
     if _memory_manager:
         return
 
-    if config.CUDA_MEMORY_MANAGER == 'default':
+    if config.CUDA_MEMORY_MANAGER == "default":
         _memory_manager = NumbaCUDAMemoryManager
         return
 
@@ -1109,8 +1130,9 @@ def _ensure_memory_manager():
         mgr_module = importlib.import_module(config.CUDA_MEMORY_MANAGER)
         set_memory_manager(mgr_module._numba_memory_manager)
     except Exception:
-        raise RuntimeError("Failed to use memory manager from %s" %
-                           config.CUDA_MEMORY_MANAGER)
+        raise RuntimeError(
+            "Failed to use memory manager from %s" % config.CUDA_MEMORY_MANAGER
+        )
 
 
 def set_memory_manager(mm_plugin):
@@ -1127,8 +1149,10 @@ def set_memory_manager(mm_plugin):
     dummy = mm_plugin(context=None)
     iv = dummy.interface_version
     if iv != _SUPPORTED_EMM_INTERFACE_VERSION:
-        err = "EMM Plugin interface has version %d - version %d required" \
-            % (iv, _SUPPORTED_EMM_INTERFACE_VERSION)
+        err = "EMM Plugin interface has version %d - version %d required" % (
+            iv,
+            _SUPPORTED_EMM_INTERFACE_VERSION,
+        )
         raise RuntimeError(err)
 
     _memory_manager = mm_plugin
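Note that set_memory_manager instantiates the plugin once with context=None purely to read interface_version before accepting it. A hedged sketch of an EMM plugin that would pass this check (MyMemoryManager is illustrative; a real plugin must implement the full interface, and the supported interface version is 1 in released numba versions):

    from numba import cuda

    class MyMemoryManager(cuda.HostOnlyCUDAMemoryManager):
        @property
        def interface_version(self):
            # Must equal _SUPPORTED_EMM_INTERFACE_VERSION, or
            # set_memory_manager() raises RuntimeError.
            return 1

        def memalloc(self, size):
            # Device allocation strategy goes here.
            raise NotImplementedError

    # cuda.set_memory_manager(MyMemoryManager)  # register before CUDA init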
@@ -1143,7 +1167,7 @@ class _SizeNotSet(int):
         return super().__new__(cls, 0)
 
     def __str__(self):
-        return '?'
+        return "?"
 
 
 _SizeNotSet = _SizeNotSet()
@@ -1156,6 +1180,7 @@ class _PendingDeallocs(object):
     modified later once the driver is initialized and the total memory capacity
     known.
     """
+
     def __init__(self, capacity=_SizeNotSet):
         self._cons = deque()
         self._disable_count = 0
@@ -1175,11 +1200,13 @@ class _PendingDeallocs(object):
         byte size of the resource added. It is an optional argument. Some
         resources (e.g. CUModule) has an unknown memory footprint on the device.
         """
-        _logger.info('add pending dealloc: %s %s bytes', dtor.__name__, size)
+        _logger.info("add pending dealloc: %s %s bytes", dtor.__name__, size)
         self._cons.append((dtor, handle, size))
         self._size += int(size)
-        if (len(self._cons) > config.CUDA_DEALLOCS_COUNT or
-                self._size > self._max_pending_bytes):
+        if (
+            len(self._cons) > config.CUDA_DEALLOCS_COUNT
+            or self._size > self._max_pending_bytes
+        ):
             self.clear()
 
     def clear(self):
@@ -1190,7 +1217,7 @@ class _PendingDeallocs(object):
         if not self.is_disabled:
             while self._cons:
                 [dtor, handle, size] = self._cons.popleft()
-                _logger.info('dealloc: %s %s bytes', dtor.__name__, size)
+                _logger.info("dealloc: %s %s bytes", dtor.__name__, size)
                 dtor(handle)
             self._size = 0
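As the reflowed condition makes plain, queued destructors are flushed as soon as either the item count exceeds config.CUDA_DEALLOCS_COUNT or the queued bytes exceed the capacity-derived limit. A self-contained sketch of the same batched-cleanup pattern, with made-up thresholds in place of the config values:

    from collections import deque

    class PendingDeallocs:
        def __init__(self, max_count=10, max_pending_bytes=1024):
            self._cons = deque()
            self._size = 0
            self.max_count = max_count
            self.max_pending_bytes = max_pending_bytes

        def add_item(self, dtor, handle, size=0):
            # Queue the destructor instead of running it immediately.
            self._cons.append((dtor, handle, size))
            self._size += int(size)
            if (
                len(self._cons) > self.max_count
                or self._size > self.max_pending_bytes
            ):
                self.clear()

        def clear(self):
            # Flush the whole batch in FIFO order.
            while self._cons:
                dtor, handle, size = self._cons.popleft()
                dtor(handle)
            self._size = 0

    deallocs = PendingDeallocs(max_count=2)
    deallocs.add_item(print, "free A", size=16)
    deallocs.add_item(print, "free B", size=16)
    deallocs.add_item(print, "free C", size=16)  # third add triggers a flush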
@@ -1254,19 +1281,19 @@ class Context(object):
         Clean up all owned resources in this context.
         """
         # Free owned resources
-        _logger.info('reset context of device %s', self.device.id)
+        _logger.info("reset context of device %s", self.device.id)
         self.memory_manager.reset()
         self.modules.clear()
         # Clear trash
         self.deallocations.clear()
 
     def get_memory_info(self):
-        """Returns (free, total) memory in bytes in the context.
-        """
+        """Returns (free, total) memory in bytes in the context."""
         return self.memory_manager.get_memory_info()
 
-    def get_active_blocks_per_multiprocessor(self, func, blocksize, memsize,
-                                             flags=None):
+    def get_active_blocks_per_multiprocessor(
+        self, func, blocksize, memsize, flags=None
+    ):
         """Return occupancy of a function.
         :param func: kernel for which occupancy is calculated
         :param blocksize: block size the kernel is intended to be launched with
@@ -1278,8 +1305,9 @@ class Context(object):
         else:
             return self._ctypes_active_blocks_per_multiprocessor(*args)
 
-    def _cuda_python_active_blocks_per_multiprocessor(self, func, blocksize,
-                                                      memsize, flags):
+    def _cuda_python_active_blocks_per_multiprocessor(
+        self, func, blocksize, memsize, flags
+    ):
         ps = [func.handle, blocksize, memsize]
 
         if not flags:
@@ -1288,8 +1316,9 @@ class Context(object):
             ps.append(flags)
         return driver.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(*ps)
 
-    def _ctypes_active_blocks_per_multiprocessor(self, func, blocksize,
-                                                 memsize, flags):
+    def _ctypes_active_blocks_per_multiprocessor(
+        self, func, blocksize, memsize, flags
+    ):
         retval = c_int()
         args = (byref(retval), func.handle, blocksize, memsize)
 
@@ -1300,8 +1329,9 @@ class Context(object):
 
         return retval.value
 
-    def get_max_potential_block_size(self, func, b2d_func, memsize,
-                                     blocksizelimit, flags=None):
+    def get_max_potential_block_size(
+        self, func, b2d_func, memsize, blocksizelimit, flags=None
+    ):
         """Suggest a launch configuration with reasonable occupancy.
         :param func: kernel for which occupancy is calculated
         :param b2d_func: function that calculates how much per-block dynamic
@@ -1318,13 +1348,20 @@ class Context(object):
         else:
             return self._ctypes_max_potential_block_size(*args)
 
-    def _ctypes_max_potential_block_size(self, func, b2d_func, memsize,
-                                         blocksizelimit, flags):
+    def _ctypes_max_potential_block_size(
+        self, func, b2d_func, memsize, blocksizelimit, flags
+    ):
         gridsize = c_int()
         blocksize = c_int()
         b2d_cb = cu_occupancy_b2d_size(b2d_func)
-        args = [byref(gridsize), byref(blocksize), func.handle, b2d_cb,
-                memsize, blocksizelimit]
+        args = [
+            byref(gridsize),
+            byref(blocksize),
+            func.handle,
+            b2d_cb,
+            memsize,
+            blocksizelimit,
+        ]
 
         if not flags:
             driver.cuOccupancyMaxPotentialBlockSize(*args)
@@ -1334,10 +1371,11 @@ class Context(object):
 
         return (gridsize.value, blocksize.value)
 
-    def _cuda_python_max_potential_block_size(self, func, b2d_func, memsize,
-                                              blocksizelimit, flags):
+    def _cuda_python_max_potential_block_size(
+        self, func, b2d_func, memsize, blocksizelimit, flags
+    ):
         b2d_cb = ctypes.CFUNCTYPE(c_size_t, c_int)(b2d_func)
-        ptr = int.from_bytes(b2d_cb, byteorder='little')
+        ptr = int.from_bytes(b2d_cb, byteorder="little")
         driver_b2d_cb = binding.CUoccupancyB2DSize(ptr)
         args = [func.handle, driver_b2d_cb, memsize, blocksizelimit]
 
@@ -1390,7 +1428,7 @@ class Context(object):
         Returns an *IpcHandle* from a GPU allocation.
         """
         if not SUPPORTS_IPC:
-            raise OSError('OS does not support CUDA IPC')
+            raise OSError("OS does not support CUDA IPC")
         return self.memory_manager.get_ipc_handle(memory)
 
     def open_ipc_handle(self, handle, size):
@@ -1403,13 +1441,13 @@ class Context(object):
         driver.cuIpcOpenMemHandle(byref(dptr), handle, flags)
 
         # wrap it
-        return MemoryPointer(context=weakref.proxy(self), pointer=dptr,
-                             size=size)
+        return MemoryPointer(
+            context=weakref.proxy(self), pointer=dptr, size=size
+        )
 
     def enable_peer_access(self, peer_context, flags=0):
-        """Enable peer access between the current context and the peer context
-        """
-        assert flags == 0, '*flags* is reserved and MUST be zero'
+        """Enable peer access between the current context and the peer context"""
+        assert flags == 0, "*flags* is reserved and MUST be zero"
         driver.cuCtxEnablePeerAccess(peer_context, flags)
 
     def can_access_peer(self, peer_device):
@@ -1418,26 +1456,34 @@ class Context(object):
         """
         if USE_NV_BINDING:
             peer_device = binding.CUdevice(peer_device)
-            can_access_peer = driver.cuDeviceCanAccessPeer(self.device.id,
-                                                           peer_device)
+            can_access_peer = driver.cuDeviceCanAccessPeer(
+                self.device.id, peer_device
+            )
         else:
             can_access_peer = c_int()
-            driver.cuDeviceCanAccessPeer(byref(can_access_peer),
-                                         self.device.id, peer_device,)
+            driver.cuDeviceCanAccessPeer(
+                byref(can_access_peer),
+                self.device.id,
+                peer_device,
+            )
 
         return bool(can_access_peer)
 
     def create_module_ptx(self, ptx):
         if isinstance(ptx, str):
-            ptx = ptx.encode('utf8')
+            ptx = ptx.encode("utf8")
         if USE_NV_BINDING:
             image = ptx
         else:
             image = c_char_p(ptx)
         return self.create_module_image(image)
 
-    def create_module_image(self, image):
-        module = load_module_image(self, image)
+    def create_module_image(
+        self, image, setup_callbacks=None, teardown_callbacks=None
+    ):
+        module = load_module_image(
+            self, image, setup_callbacks, teardown_callbacks
+        )
         if USE_NV_BINDING:
             key = module.handle
         else:
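This hunk carries the most visible functional change in the file: create_module_image (and the load_module_image family below) now accept optional lists of setup and teardown callbacks that are forwarded to the resulting Module. A hedged usage sketch, assuming the NVIDIA binding (where the image may be PTX bytes) and a CUDA-capable machine; the callbacks receive the raw module handle, and in normal use numba-cuda internals drive these hooks rather than user code:

    from numba import cuda, float32

    def axpy(r, a, x, y):
        i = cuda.grid(1)
        if i < r.size:
            r[i] = a * x[i] + y[i]

    sig = (float32[:], float32, float32[:], float32[:])
    ptx, resty = cuda.compile_ptx(axpy, sig)

    def on_setup(handle):
        print("module loaded:", handle)

    def on_teardown(handle):
        print("module torn down:", handle)

    ctx = cuda.current_context()
    module = ctx.create_module_image(
        ptx.encode(),
        setup_callbacks=[on_setup],
        teardown_callbacks=[on_teardown],
    )
    module.setup()  # runs the setup callbacks once (see Module.setup below)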
@@ -1484,8 +1530,11 @@ class Context(object):
         else:
             handle = drvapi.cu_stream()
             driver.cuStreamCreate(byref(handle), 0)
-        return Stream(weakref.proxy(self), handle,
-                      _stream_finalizer(self.deallocations, handle))
+        return Stream(
+            weakref.proxy(self),
+            handle,
+            _stream_finalizer(self.deallocations, handle),
+        )
 
     def create_external_stream(self, ptr):
         if not isinstance(ptr, int):
@@ -1494,8 +1543,7 @@ class Context(object):
             handle = binding.CUstream(ptr)
         else:
             handle = drvapi.cu_stream(ptr)
-        return Stream(weakref.proxy(self), handle, None,
-                      external=True)
+        return Stream(weakref.proxy(self), handle, None, external=True)
 
     def create_event(self, timing=True):
         flags = 0
@@ -1506,8 +1554,11 @@ class Context(object):
         else:
             handle = drvapi.cu_event()
             driver.cuEventCreate(byref(handle), flags)
-        return Event(weakref.proxy(self), handle,
-                     finalizer=_event_finalizer(self.deallocations, handle))
+        return Event(
+            weakref.proxy(self),
+            handle,
+            finalizer=_event_finalizer(self.deallocations, handle),
+        )
 
     def synchronize(self):
         driver.cuCtxSynchronize()
@@ -1531,17 +1582,25 @@ class Context(object):
         return not self.__eq__(other)
 
 
-def load_module_image(context, image):
+def load_module_image(
+    context, image, setup_callbacks=None, teardown_callbacks=None
+):
     """
     image must be a pointer
     """
     if USE_NV_BINDING:
-        return load_module_image_cuda_python(context, image)
+        return load_module_image_cuda_python(
+            context, image, setup_callbacks, teardown_callbacks
+        )
     else:
-        return load_module_image_ctypes(context, image)
+        return load_module_image_ctypes(
+            context, image, setup_callbacks, teardown_callbacks
+        )
 
 
-def load_module_image_ctypes(context, image):
+def load_module_image_ctypes(
+    context, image, setup_callbacks, teardown_callbacks
+):
     logsz = config.CUDA_LOG_SIZE
 
     jitinfo = (c_char * logsz)()
@@ -1560,19 +1619,28 @@ def load_module_image_ctypes(context, image):
 
     handle = drvapi.cu_module()
     try:
-        driver.cuModuleLoadDataEx(byref(handle), image, len(options),
-                                  option_keys, option_vals)
+        driver.cuModuleLoadDataEx(
+            byref(handle), image, len(options), option_keys, option_vals
+        )
     except CudaAPIError as e:
         msg = "cuModuleLoadDataEx error:\n%s" % jiterrors.value.decode("utf8")
         raise CudaAPIError(e.code, msg)
 
     info_log = jitinfo.value
 
-    return CtypesModule(weakref.proxy(context), handle, info_log,
-                        _module_finalizer(context, handle))
+    return CtypesModule(
+        weakref.proxy(context),
+        handle,
+        info_log,
+        _module_finalizer(context, handle),
+        setup_callbacks,
+        teardown_callbacks,
+    )
 
 
-def load_module_image_cuda_python(context, image):
+def load_module_image_cuda_python(
+    context, image, setup_callbacks, teardown_callbacks
+):
     """
     image must be a pointer
     """
@@ -1594,17 +1662,24 @@ def load_module_image_cuda_python(context, image):
     option_vals = [v for v in options.values()]
 
     try:
-        handle = driver.cuModuleLoadDataEx(image, len(options), option_keys,
-                                           option_vals)
+        handle = driver.cuModuleLoadDataEx(
+            image, len(options), option_keys, option_vals
+        )
     except CudaAPIError as e:
-        err_string = jiterrors.decode('utf-8')
+        err_string = jiterrors.decode("utf-8")
         msg = "cuModuleLoadDataEx error:\n%s" % err_string
         raise CudaAPIError(e.code, msg)
 
-    info_log = jitinfo.decode('utf-8')
+    info_log = jitinfo.decode("utf-8")
 
-    return CudaPythonModule(weakref.proxy(context), handle, info_log,
-                            _module_finalizer(context, handle))
+    return CudaPythonModule(
+        weakref.proxy(context),
+        handle,
+        info_log,
+        _module_finalizer(context, handle),
+        setup_callbacks,
+        teardown_callbacks,
+    )
 
 
 def _alloc_finalizer(memory_manager, ptr, alloc_key, size):
@@ -1707,6 +1782,7 @@ class _CudaIpcImpl(object):
     """Implementation of GPU IPC using CUDA driver API.
     This requires the devices to be peer accessible.
     """
+
     def __init__(self, parent):
         self.base = parent.base
         self.handle = parent.handle
@@ -1720,10 +1796,10 @@ class _CudaIpcImpl(object):
         Import the IPC memory and returns a raw CUDA memory pointer object
         """
         if self.base is not None:
-            raise ValueError('opening IpcHandle from original process')
+            raise ValueError("opening IpcHandle from original process")
 
         if self._opened_mem is not None:
-            raise ValueError('IpcHandle is already opened')
+            raise ValueError("IpcHandle is already opened")
 
         mem = context.open_ipc_handle(self.handle, self.offset + self.size)
         # this object owns the opened allocation
@@ -1734,7 +1810,7 @@ class _CudaIpcImpl(object):
 
     def close(self):
         if self._opened_mem is None:
-            raise ValueError('IpcHandle not opened')
+            raise ValueError("IpcHandle not opened")
         driver.cuIpcCloseMemHandle(self._opened_mem.handle)
         self._opened_mem = None
 
@@ -1743,6 +1819,7 @@ class _StagedIpcImpl(object):
     """Implementation of GPU IPC using custom staging logic to workaround
     CUDA IPC limitation on peer accessibility between devices.
     """
+
     def __init__(self, parent, source_info):
         self.parent = parent
         self.base = parent.base
@@ -1798,6 +1875,7 @@ class IpcHandle(object):
     referred to by this IPC handle.
     :type offset: int
     """
+
     def __init__(self, base, handle, size, source_info=None, offset=0):
         self.base = base
         self.handle = handle
@@ -1821,12 +1899,11 @@ class IpcHandle(object):
         return context.can_access_peer(source_device.id)
 
     def open_staged(self, context):
-        """Open the IPC by allowing staging on the host memory first.
-        """
+        """Open the IPC by allowing staging on the host memory first."""
         self._sentry_source_info()
 
         if self._impl is not None:
-            raise ValueError('IpcHandle is already opened')
+            raise ValueError("IpcHandle is already opened")
 
         self._impl = _StagedIpcImpl(self, self.source_info)
         return self._impl.open(context)
@@ -1836,7 +1913,7 @@ class IpcHandle(object):
         Import the IPC memory and returns a raw CUDA memory pointer object
         """
         if self._impl is not None:
-            raise ValueError('IpcHandle is already opened')
+            raise ValueError("IpcHandle is already opened")
 
         self._impl = _CudaIpcImpl(self)
         return self._impl.open(context)
@@ -1867,12 +1944,13 @@ class IpcHandle(object):
         strides = dtype.itemsize
         dptr = self.open(context)
         # read the device pointer as an array
-        return devicearray.DeviceNDArray(shape=shape, strides=strides,
-                                         dtype=dtype, gpu_data=dptr)
+        return devicearray.DeviceNDArray(
+            shape=shape, strides=strides, dtype=dtype, gpu_data=dptr
+        )
 
     def close(self):
         if self._impl is None:
-            raise ValueError('IpcHandle not opened')
+            raise ValueError("IpcHandle not opened")
         self._impl.close()
         self._impl = None
 
@@ -1898,8 +1976,13 @@ class IpcHandle(object):
         else:
             handle = drvapi.cu_ipc_mem_handle()
             handle.reserved = handle_ary
-        return cls(base=None, handle=handle, size=size,
-                   source_info=source_info, offset=offset)
+        return cls(
+            base=None,
+            handle=handle,
+            size=size,
+            source_info=source_info,
+            offset=offset,
+        )
 
 
 class MemoryPointer(object):
@@ -1933,6 +2016,7 @@ class MemoryPointer(object):
     :param finalizer: A function that is called when the buffer is to be freed.
     :type finalizer: function
    """
+
     __cuda_memory__ = True
 
     def __init__(self, context, pointer, size, owner=None, finalizer=None):
@@ -1968,8 +2052,9 @@ class MemoryPointer(object):
     def memset(self, byte, count=None, stream=0):
         count = self.size if count is None else count
         if stream:
-            driver.cuMemsetD8Async(self.device_pointer, byte, count,
-                                   stream.handle)
+            driver.cuMemsetD8Async(
+                self.device_pointer, byte, count, stream.handle
+            )
         else:
             driver.cuMemsetD8(self.device_pointer, byte, count)
 
@@ -1983,12 +2068,12 @@ class MemoryPointer(object):
         if not self.device_pointer_value:
             if size != 0:
                 raise RuntimeError("non-empty slice into empty slice")
-            view = self # new view is just a reference to self
+            view = self  # new view is just a reference to self
         # Handle normal case
         else:
             base = self.device_pointer_value + start
             if size < 0:
-                raise RuntimeError('size cannot be negative')
+                raise RuntimeError("size cannot be negative")
             if USE_NV_BINDING:
                 pointer = binding.CUdeviceptr()
                 ctypes_ptr = drvapi.cu_device_ptr.from_address(pointer.getPtr())
@@ -2024,6 +2109,7 @@ class AutoFreePointer(MemoryPointer):
 
     Constructor arguments are the same as for :class:`MemoryPointer`.
     """
+
     def __init__(self, *args, **kwargs):
         super(AutoFreePointer, self).__init__(*args, **kwargs)
         # Releease the self reference to the buffer, so that the finalizer
@@ -2066,8 +2152,9 @@ class MappedMemory(AutoFreePointer):
         self._bufptr_ = self.host_pointer.value
 
         self.device_pointer = devptr
-        super(MappedMemory, self).__init__(context, devptr, size,
-                                           finalizer=finalizer)
+        super(MappedMemory, self).__init__(
+            context, devptr, size, finalizer=finalizer
+        )
         self.handle = self.host_pointer
 
         # For buffer interface
@@ -2182,8 +2269,7 @@ class OwnedPointer(object):
         weakref.finalize(self, deref)
 
     def __getattr__(self, fname):
-        """Proxy MemoryPointer methods
-        """
+        """Proxy MemoryPointer methods"""
         return getattr(self._view, fname)
 
 
@@ -2214,18 +2300,15 @@ class Stream(object):
         if USE_NV_BINDING:
             default_streams = {
                 CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
-                binding.CU_STREAM_LEGACY:
-                    "<Legacy default CUDA stream on %s>",
-                binding.CU_STREAM_PER_THREAD:
-                    "<Per-thread default CUDA stream on %s>",
+                binding.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
+                binding.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
             }
             ptr = int(self.handle) or 0
         else:
             default_streams = {
                 drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
                 drvapi.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
-                drvapi.CU_STREAM_PER_THREAD:
-                    "<Per-thread default CUDA stream on %s>",
+                drvapi.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
             }
             ptr = self.handle.value or drvapi.CU_STREAM_DEFAULT
 
@@ -2237,18 +2320,18 @@ class Stream(object):
         return "<CUDA stream %d on %s>" % (ptr, self.context)
 
     def synchronize(self):
-        '''
+        """
         Wait for all commands in this stream to execute. This will commit any
         pending memory transfers.
-        '''
+        """
         driver.cuStreamSynchronize(self.handle)
 
     @contextlib.contextmanager
     def auto_synchronize(self):
-        '''
+        """
         A context manager that waits for all commands in this stream to execute
         and commits any pending memory transfers upon exiting the context.
-        '''
+        """
         yield self
         self.synchronize()
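Both methods are part of the public stream API; auto_synchronize simply yields the stream and calls synchronize() on exit. A short usage sketch (assumes a CUDA-capable machine):

    import numpy as np
    from numba import cuda

    stream = cuda.stream()
    arr = np.arange(16, dtype=np.float32)

    with stream.auto_synchronize():
        d_arr = cuda.to_device(arr, stream=stream)  # async H2D copy
        d_arr.copy_to_host(arr, stream=stream)      # async D2H copy
    # On exit the stream has been synchronized and transfers are committed.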
2254
2337
 
@@ -2275,7 +2358,7 @@ class Stream(object):
2275
2358
  data = (self, callback, arg)
2276
2359
  _py_incref(data)
2277
2360
  if USE_NV_BINDING:
2278
- ptr = int.from_bytes(self._stream_callback, byteorder='little')
2361
+ ptr = int.from_bytes(self._stream_callback, byteorder="little")
2279
2362
  stream_callback = binding.CUstreamCallback(ptr)
2280
2363
  # The callback needs to receive a pointer to the data PyObject
2281
2364
  data = id(data)
@@ -2376,9 +2459,9 @@ class Event(object):
2376
2459
 
2377
2460
 
2378
2461
  def event_elapsed_time(evtstart, evtend):
2379
- '''
2462
+ """
2380
2463
  Compute the elapsed time between two events in milliseconds.
2381
- '''
2464
+ """
2382
2465
  if USE_NV_BINDING:
2383
2466
  return driver.cuEventElapsedTime(evtstart.handle, evtend.handle)
2384
2467
  else:
@@ -2390,13 +2473,27 @@ def event_elapsed_time(evtstart, evtend):
2390
2473
  class Module(metaclass=ABCMeta):
2391
2474
  """Abstract base class for modules"""
2392
2475
 
2393
- def __init__(self, context, handle, info_log, finalizer=None):
2476
+ def __init__(
2477
+ self,
2478
+ context,
2479
+ handle,
2480
+ info_log,
2481
+ finalizer=None,
2482
+ setup_callbacks=None,
2483
+ teardown_callbacks=None,
2484
+ ):
2394
2485
  self.context = context
2395
2486
  self.handle = handle
2396
2487
  self.info_log = info_log
2397
2488
  if finalizer is not None:
2398
2489
  self._finalizer = weakref.finalize(self, finalizer)
2399
2490
 
2491
+ self.initialized = False
2492
+ self.setup_functions = setup_callbacks
2493
+ self.teardown_functions = teardown_callbacks
2494
+
2495
+ self._set_finalizers()
2496
+
2400
2497
  def unload(self):
2401
2498
  """Unload this module from the context"""
2402
2499
  self.context.unload_module(self)
@@ -2409,36 +2506,66 @@ class Module(metaclass=ABCMeta):
  def get_global_symbol(self, name):
  """Return a MemoryPointer referring to the named symbol"""

+ def setup(self):
+ """Call the setup functions for the module"""
+ if self.initialized:
+ raise RuntimeError("The module has already been initialized.")

- class CtypesModule(Module):
+ if self.setup_functions is None:
+ return
+
+ for f in self.setup_functions:
+ f(self.handle)
+
+ self.initialized = True

+ def _set_finalizers(self):
+ """Create finalizers that tear down the module."""
+ if self.teardown_functions is None:
+ return
+
+ def _teardown(teardowns, handle):
+ for f in teardowns:
+ f(handle)
+
+ weakref.finalize(
+ self,
+ _teardown,
+ self.teardown_functions,
+ self.handle,
+ )
+
+
+ class CtypesModule(Module):
  def get_function(self, name):
  handle = drvapi.cu_function()
- driver.cuModuleGetFunction(byref(handle), self.handle,
- name.encode('utf8'))
+ driver.cuModuleGetFunction(
+ byref(handle), self.handle, name.encode("utf8")
+ )
  return CtypesFunction(weakref.proxy(self), handle, name)

  def get_global_symbol(self, name):
  ptr = drvapi.cu_device_ptr()
  size = drvapi.c_size_t()
- driver.cuModuleGetGlobal(byref(ptr), byref(size), self.handle,
- name.encode('utf8'))
+ driver.cuModuleGetGlobal(
+ byref(ptr), byref(size), self.handle, name.encode("utf8")
+ )
  return MemoryPointer(self.context, ptr, size), size.value


  class CudaPythonModule(Module):
-
  def get_function(self, name):
- handle = driver.cuModuleGetFunction(self.handle, name.encode('utf8'))
+ handle = driver.cuModuleGetFunction(self.handle, name.encode("utf8"))
  return CudaPythonFunction(weakref.proxy(self), handle, name)

  def get_global_symbol(self, name):
- ptr, size = driver.cuModuleGetGlobal(self.handle, name.encode('utf8'))
+ ptr, size = driver.cuModuleGetGlobal(self.handle, name.encode("utf8"))
  return MemoryPointer(self.context, ptr, size), size


- FuncAttr = namedtuple("FuncAttr", ["regs", "shared", "local", "const",
- "maxthreads"])
+ FuncAttr = namedtuple(
+ "FuncAttr", ["regs", "shared", "local", "const", "maxthreads"]
+ )


  class Function(metaclass=ABCMeta):
@@ -2461,8 +2588,9 @@ class Function(metaclass=ABCMeta):
  return self.module.context.device

  @abstractmethod
- def cache_config(self, prefer_equal=False, prefer_cache=False,
- prefer_shared=False):
+ def cache_config(
+ self, prefer_equal=False, prefer_cache=False, prefer_shared=False
+ ):
  """Set the cache configuration for this function."""

  @abstractmethod
@@ -2476,9 +2604,9 @@ class Function(metaclass=ABCMeta):


  class CtypesFunction(Function):
-
- def cache_config(self, prefer_equal=False, prefer_cache=False,
- prefer_shared=False):
+ def cache_config(
+ self, prefer_equal=False, prefer_cache=False, prefer_shared=False
+ ):
  prefer_equal = prefer_equal or (prefer_cache and prefer_shared)
  if prefer_equal:
  flag = enums.CU_FUNC_CACHE_PREFER_EQUAL
@@ -2501,15 +2629,17 @@ class CtypesFunction(Function):
  lmem = self.read_func_attr(enums.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
  smem = self.read_func_attr(enums.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
  maxtpb = self.read_func_attr(
- enums.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
- return FuncAttr(regs=nregs, const=cmem, local=lmem, shared=smem,
- maxthreads=maxtpb)
+ enums.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
+ )
+ return FuncAttr(
+ regs=nregs, const=cmem, local=lmem, shared=smem, maxthreads=maxtpb
+ )


  class CudaPythonFunction(Function):
-
- def cache_config(self, prefer_equal=False, prefer_cache=False,
- prefer_shared=False):
+ def cache_config(
+ self, prefer_equal=False, prefer_cache=False, prefer_shared=False
+ ):
  prefer_equal = prefer_equal or (prefer_cache and prefer_shared)
  attr = binding.CUfunction_attribute
  if prefer_equal:
@@ -2532,19 +2662,26 @@ class CudaPythonFunction(Function):
  lmem = self.read_func_attr(attr.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
  smem = self.read_func_attr(attr.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
  maxtpb = self.read_func_attr(
- attr.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
- return FuncAttr(regs=nregs, const=cmem, local=lmem, shared=smem,
- maxthreads=maxtpb)
-
+ attr.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
+ )
+ return FuncAttr(
+ regs=nregs, const=cmem, local=lmem, shared=smem, maxthreads=maxtpb
+ )

- def launch_kernel(cufunc_handle,
- gx, gy, gz,
- bx, by, bz,
- sharedmem,
- hstream,
- args,
- cooperative=False):

+ def launch_kernel(
+ cufunc_handle,
+ gx,
+ gy,
+ gz,
+ bx,
+ by,
+ bz,
+ sharedmem,
+ hstream,
+ args,
+ cooperative=False,
+ ):
  param_ptrs = [addressof(arg) for arg in args]
  params = (c_void_p * len(param_ptrs))(*param_ptrs)
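The two context lines above build the `kernelParams` array that `cuLaunchKernel` expects: one `void*` per kernel argument, each pointing at that argument's storage. The same ctypes idiom in isolation (toy integer arguments, no GPU required):

    from ctypes import addressof, c_int, c_void_p

    args = [c_int(7), c_int(42)]            # each argument in its own buffer
    param_ptrs = [addressof(a) for a in args]
    params = (c_void_p * len(param_ptrs))(*param_ptrs)  # a void*[2] array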
 
@@ -2556,46 +2693,54 @@ def launch_kernel(cufunc_handle,
  extra = None

  if cooperative:
- driver.cuLaunchCooperativeKernel(cufunc_handle,
- gx, gy, gz,
- bx, by, bz,
- sharedmem,
- hstream,
- params_for_launch)
+ driver.cuLaunchCooperativeKernel(
+ cufunc_handle,
+ gx,
+ gy,
+ gz,
+ bx,
+ by,
+ bz,
+ sharedmem,
+ hstream,
+ params_for_launch,
+ )
  else:
- driver.cuLaunchKernel(cufunc_handle,
- gx, gy, gz,
- bx, by, bz,
- sharedmem,
- hstream,
- params_for_launch,
- extra)
+ driver.cuLaunchKernel(
+ cufunc_handle,
+ gx,
+ gy,
+ gz,
+ bx,
+ by,
+ bz,
+ sharedmem,
+ hstream,
+ params_for_launch,
+ extra,
+ )


  class Linker(metaclass=ABCMeta):
  """Abstract base class for linkers"""

  @classmethod
- def new(cls,
- max_registers=0,
- lineinfo=False,
- cc=None,
- lto=None,
- additional_flags=None
- ):
-
+ def new(
+ cls,
+ max_registers=0,
+ lineinfo=False,
+ cc=None,
+ lto=None,
+ additional_flags=None,
+ ):
  driver_ver = driver.get_version()
- if (
- config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY
- and driver_ver >= (12, 0)
+ if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and driver_ver >= (
+ 12,
+ 0,
  ):
- raise ValueError(
- "Use CUDA_ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC"
- )
+ raise ValueError("Use CUDA_ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC")
  if config.CUDA_ENABLE_PYNVJITLINK and driver_ver < (12, 0):
- raise ValueError(
- "Enabling pynvjitlink requires CUDA 12."
- )
+ raise ValueError("Enabling pynvjitlink requires CUDA 12.")
  if config.CUDA_ENABLE_PYNVJITLINK:
  linker = PyNvJitLinker
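`Linker.new` is a factory: the config flags plus the driver version select `PyNvJitLinker`, the MVC linker, or a plain ctypes/cuda-python linker. A hedged sketch of a call site (the import path follows this module's location and the compute capability is illustrative):

    from numba.cuda.cudadrv.driver import Linker

    # Returns whichever concrete Linker subclass the config selects.
    linker = Linker.new(max_registers=0, lineinfo=False, cc=(7, 5))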
 
@@ -2644,9 +2789,9 @@ class Linker(metaclass=ABCMeta):
  ptx, log = nvrtc.compile(cu, name, cc)

  if config.DUMP_ASSEMBLY:
- print(("ASSEMBLY %s" % name).center(80, '-'))
+ print(("ASSEMBLY %s" % name).center(80, "-"))
  print(ptx)
- print('=' * 80)
+ print("=" * 80)

  # Link the program's PTX using the normal linker mechanism
  ptx_name = os.path.splitext(name)[0] + ".ptx"
@@ -2657,7 +2802,7 @@ class Linker(metaclass=ABCMeta):
  """Add code from a file to the link"""

  def add_cu_file(self, path):
- with open(path, 'rb') as f:
+ with open(path, "rb") as f:
  cu = f.read()
  self.add_cu(cu, os.path.basename(path))

@@ -2675,24 +2820,24 @@ class Linker(metaclass=ABCMeta):

  if isinstance(path_or_code, str):
  ext = pathlib.Path(path_or_code).suffix
- if ext == '':
+ if ext == "":
  raise RuntimeError(
  "Don't know how to link file with no extension"
  )
- elif ext == '.cu':
+ elif ext == ".cu":
  self.add_cu_file(path_or_code)
  else:
- kind = FILE_EXTENSION_MAP.get(ext.lstrip('.'), None)
+ kind = FILE_EXTENSION_MAP.get(ext.lstrip("."), None)
  if kind is None:
  raise RuntimeError(
- "Don't know how to link file with extension "
- f"{ext}"
+ f"Don't know how to link file with extension {ext}"
  )

  if ignore_nonlto:
  warn_and_return = False
  if kind in (
- FILE_EXTENSION_MAP["fatbin"], FILE_EXTENSION_MAP["o"]
+ FILE_EXTENSION_MAP["fatbin"],
+ FILE_EXTENSION_MAP["o"],
  ):
  entry_types = inspect_obj_content(path_or_code)
  if "nvvm" not in entry_types:
@@ -2757,6 +2902,7 @@ class MVCLinker(Linker):
  Linker supporting Minor Version Compatibility, backed by the cubinlinker
  package.
  """
+
  def __init__(self, max_registers=None, lineinfo=False, cc=None):
  try:
  from cubinlinker import CubinLinker
@@ -2764,18 +2910,20 @@ class MVCLinker(Linker):
  raise ImportError(_MVC_ERROR_MESSAGE) from err

  if cc is None:
- raise RuntimeError("MVCLinker requires Compute Capability to be "
- "specified, but cc is None")
+ raise RuntimeError(
+ "MVCLinker requires Compute Capability to be "
+ "specified, but cc is None"
+ )

  super().__init__(max_registers, lineinfo, cc)

  arch = f"sm_{cc[0] * 10 + cc[1]}"
- ptx_compile_opts = ['--gpu-name', arch, '-c']
+ ptx_compile_opts = ["--gpu-name", arch, "-c"]
  if max_registers:
  arg = f"--maxrregcount={max_registers}"
  ptx_compile_opts.append(arg)
  if lineinfo:
- ptx_compile_opts.append('--generate-line-info')
+ ptx_compile_opts.append("--generate-line-info")
  self.ptx_compile_options = tuple(ptx_compile_opts)

  self._linker = CubinLinker(f"--arch={arch}")
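The arch string folds the compute-capability tuple into the `sm_XY` form the PTX toolchain expects; for example:

    cc = (7, 5)
    arch = f"sm_{cc[0] * 10 + cc[1]}"       # -> "sm_75"
    ptx_compile_opts = ["--gpu-name", arch, "-c"]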
@@ -2788,7 +2936,7 @@ class MVCLinker(Linker):
  def error_log(self):
  return self._linker.error_log

- def add_ptx(self, ptx, name='<cudapy-ptx>'):
+ def add_ptx(self, ptx, name="<cudapy-ptx>"):
  try:
  from ptxcompiler import compile_ptx
  from cubinlinker import CubinLinkerError
@@ -2807,19 +2955,19 @@ class MVCLinker(Linker):
  raise ImportError(_MVC_ERROR_MESSAGE) from err

  try:
- with open(path, 'rb') as f:
+ with open(path, "rb") as f:
  data = f.read()
  except FileNotFoundError:
- raise LinkerError(f'{path} not found')
+ raise LinkerError(f"{path} not found")

  name = pathlib.Path(path).name
- if kind == FILE_EXTENSION_MAP['cubin']:
+ if kind == FILE_EXTENSION_MAP["cubin"]:
  fn = self._linker.add_cubin
- elif kind == FILE_EXTENSION_MAP['fatbin']:
+ elif kind == FILE_EXTENSION_MAP["fatbin"]:
  fn = self._linker.add_fatbin
- elif kind == FILE_EXTENSION_MAP['a']:
+ elif kind == FILE_EXTENSION_MAP["a"]:
  raise LinkerError(f"Don't know how to link {kind}")
- elif kind == FILE_EXTENSION_MAP['ptx']:
+ elif kind == FILE_EXTENSION_MAP["ptx"]:
  return self.add_ptx(data, name)
  else:
  raise LinkerError(f"Don't know how to link {kind}")
@@ -2845,6 +2993,7 @@ class CtypesLinker(Linker):
  """
  Links for current device if no CC given
  """
+
  def __init__(self, max_registers=0, lineinfo=False, cc=None):
  super().__init__(max_registers, lineinfo, cc)

@@ -2878,8 +3027,9 @@ class CtypesLinker(Linker):
  option_vals = (c_void_p * len(raw_values))(*raw_values)

  self.handle = handle = drvapi.cu_link_state()
- driver.cuLinkCreate(len(raw_keys), option_keys, option_vals,
- byref(self.handle))
+ driver.cuLinkCreate(
+ len(raw_keys), option_keys, option_vals, byref(self.handle)
+ )

  weakref.finalize(self, driver.cuLinkDestroy, handle)
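Passing the raw `handle` (not `self`) to `weakref.finalize` lets the driver-side link state be destroyed when the linker object is collected, without the finalizer keeping the object alive. The idiom in miniature (the resource name is a stand-in):

    import weakref

    class Owner:
        pass

    def destroy(handle):
        print("destroying", handle)

    o = Owner()
    weakref.finalize(o, destroy, "link-state-0")
    del o  # prints "destroying link-state-0"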
 
@@ -2890,19 +3040,27 @@ class CtypesLinker(Linker):

  @property
  def info_log(self):
- return self.linker_info_buf.value.decode('utf8')
+ return self.linker_info_buf.value.decode("utf8")

  @property
  def error_log(self):
- return self.linker_errors_buf.value.decode('utf8')
+ return self.linker_errors_buf.value.decode("utf8")

- def add_ptx(self, ptx, name='<cudapy-ptx>'):
+ def add_ptx(self, ptx, name="<cudapy-ptx>"):
  ptxbuf = c_char_p(ptx)
- namebuf = c_char_p(name.encode('utf8'))
+ namebuf = c_char_p(name.encode("utf8"))
  self._keep_alive += [ptxbuf, namebuf]
  try:
- driver.cuLinkAddData(self.handle, enums.CU_JIT_INPUT_PTX,
- ptxbuf, len(ptx), namebuf, 0, None, None)
+ driver.cuLinkAddData(
+ self.handle,
+ enums.CU_JIT_INPUT_PTX,
+ ptxbuf,
+ len(ptx),
+ namebuf,
+ 0,
+ None,
+ None,
+ )
  except CudaAPIError as e:
  raise LinkerError("%s\n%s" % (e, self.error_log))

@@ -2914,7 +3072,7 @@ class CtypesLinker(Linker):
  driver.cuLinkAddFile(self.handle, kind, pathbuf, 0, None, None)
  except CudaAPIError as e:
  if e.code == enums.CUDA_ERROR_FILE_NOT_FOUND:
- msg = f'{path} not found'
+ msg = f"{path} not found"
  else:
  msg = "%s\n%s" % (e, self.error_log)
  raise LinkerError(msg)
@@ -2929,7 +3087,7 @@ class CtypesLinker(Linker):
  raise LinkerError("%s\n%s" % (e, self.error_log))

  size = size.value
- assert size > 0, 'linker returned a zero sized cubin'
+ assert size > 0, "linker returned a zero sized cubin"
  del self._keep_alive[:]

  # We return a copy of the cubin because it's owned by the linker
@@ -2941,6 +3099,7 @@ class CudaPythonLinker(Linker):
  """
  Links for current device if no CC given
  """
+
  def __init__(self, max_registers=0, lineinfo=False, cc=None):
  super().__init__(max_registers, lineinfo, cc)

@@ -2967,8 +3126,9 @@ class CudaPythonLinker(Linker):
  options[jit_option.CU_JIT_TARGET_FROM_CUCONTEXT] = 1
  else:
  cc_val = cc[0] * 10 + cc[1]
- cc_enum = getattr(binding.CUjit_target,
- f'CU_TARGET_COMPUTE_{cc_val}')
+ cc_enum = getattr(
+ binding.CUjit_target, f"CU_TARGET_COMPUTE_{cc_val}"
+ )
  options[jit_option.CU_JIT_TARGET] = cc_enum

  raw_keys = list(options.keys())
@@ -2985,19 +3145,20 @@ class CudaPythonLinker(Linker):

  @property
  def info_log(self):
- return self.linker_info_buf.decode('utf8')
+ return self.linker_info_buf.decode("utf8")

  @property
  def error_log(self):
- return self.linker_errors_buf.decode('utf8')
+ return self.linker_errors_buf.decode("utf8")

- def add_ptx(self, ptx, name='<cudapy-ptx>'):
- namebuf = name.encode('utf8')
+ def add_ptx(self, ptx, name="<cudapy-ptx>"):
+ namebuf = name.encode("utf8")
  self._keep_alive += [ptx, namebuf]
  try:
  input_ptx = binding.CUjitInputType.CU_JIT_INPUT_PTX
- driver.cuLinkAddData(self.handle, input_ptx, ptx, len(ptx),
- namebuf, 0, [], [])
+ driver.cuLinkAddData(
+ self.handle, input_ptx, ptx, len(ptx), namebuf, 0, [], []
+ )
  except CudaAPIError as e:
  raise LinkerError("%s\n%s" % (e, self.error_log))

@@ -3009,7 +3170,7 @@ class CudaPythonLinker(Linker):
  driver.cuLinkAddFile(self.handle, kind, pathbuf, 0, [], [])
  except CudaAPIError as e:
  if e.code == binding.CUresult.CUDA_ERROR_FILE_NOT_FOUND:
- msg = f'{path} not found'
+ msg = f"{path} not found"
  else:
  msg = "%s\n%s" % (e, self.error_log)
  raise LinkerError(msg)
@@ -3020,7 +3181,7 @@ class CudaPythonLinker(Linker):
  except CudaAPIError as e:
  raise LinkerError("%s\n%s" % (e, self.error_log))

- assert size > 0, 'linker returned a zero sized cubin'
+ assert size > 0, "linker returned a zero sized cubin"
  del self._keep_alive[:]
  # We return a copy of the cubin because it's owned by the linker
  cubin_ptr = ctypes.cast(cubin_buf, ctypes.POINTER(ctypes.c_char))
@@ -3154,6 +3315,7 @@ class PyNvJitLinker(Linker):
  except NvJitLinkError as e:
  raise LinkerError from e

+
  # -----------------------------------------------------------------------------


@@ -3203,7 +3365,7 @@ def device_memory_size(devmem):
  The result is cached in the device memory object.
  It may query the driver for the memory size of the device memory allocation.
  """
- sz = getattr(devmem, '_cuda_memsize_', None)
+ sz = getattr(devmem, "_cuda_memsize_", None)
  if sz is None:
  s, e = device_extents(devmem)
  if USE_NV_BINDING:
@@ -3216,10 +3378,9 @@ def device_memory_size(devmem):


  def _is_datetime_dtype(obj):
- """Returns True if the obj.dtype is datetime64 or timedelta64
- """
- dtype = getattr(obj, 'dtype', None)
- return dtype is not None and dtype.char in 'Mm'
+ """Returns True if the obj.dtype is datetime64 or timedelta64"""
+ dtype = getattr(obj, "dtype", None)
+ return dtype is not None and dtype.char in "Mm"
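NumPy's single-character type codes make this check compact: `'M'` is datetime64 and `'m'` is timedelta64, so the predicate matches exactly those two families:

    import numpy as np

    print(np.dtype("datetime64[ns]").char)   # 'M'
    print(np.dtype("timedelta64[s]").char)   # 'm'
    print(np.dtype("float64").char)          # 'd' (not matched)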
 
 
  def _workaround_for_datetime(obj):
@@ -3298,12 +3459,11 @@ def is_device_memory(obj):
  "device_pointer" which value is an int object carrying the pointer
  value of the device memory address. This is not tested in this method.
  """
- return getattr(obj, '__cuda_memory__', False)
+ return getattr(obj, "__cuda_memory__", False)


  def require_device_memory(obj):
- """A sentry for methods that accept CUDA memory object.
- """
+ """A sentry for methods that accept CUDA memory object."""
  if not is_device_memory(obj):
  raise Exception("Not a CUDA memory object.")
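`is_device_memory` duck-types rather than isinstance-checks: any object with a truthy `__cuda_memory__` attribute passes, with `device_pointer` carrying the address per the docstring. A minimal object satisfying the protocol (the pointer value is a placeholder):

    class FakeDeviceMemory:
        __cuda_memory__ = True
        device_pointer = 0x7000  # placeholder device address

    assert getattr(FakeDeviceMemory(), "__cuda_memory__", False)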
 
@@ -3394,16 +3554,16 @@ def device_memset(dst, val, size, stream=0):


  def profile_start():
- '''
+ """
  Enable profile collection in the current context.
- '''
+ """
  driver.cuProfilerStart()


  def profile_stop():
- '''
+ """
  Disable profile collection in the current context.
- '''
+ """
  driver.cuProfilerStop()
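numba re-exports these as `cuda.profile_start()`/`cuda.profile_stop()`, alongside a `cuda.profiling()` context manager; a sketch of bracketing a region of interest (kernel launches elided):

    from numba import cuda

    cuda.profile_start()
    # ... launch the kernels to be captured by the profiler ...
    cuda.profile_stop()

    # Equivalent scoped form:
    with cuda.profiling():
        pass  # profiled region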
 
 
@@ -3430,18 +3590,21 @@ def inspect_obj_content(objpath: str):
  Given path to a fatbin or object, use `cuobjdump` to examine its content
  Return the set of entries in the object.
  """
- code_types :set[str] = set()
+ code_types: set[str] = set()

  try:
- out = subprocess.run(["cuobjdump", objpath], check=True,
- capture_output=True)
+ out = subprocess.run(
+ ["cuobjdump", objpath], check=True, capture_output=True
+ )
  except FileNotFoundError as e:
- msg = ("cuobjdump has not been found. You may need "
- "to install the CUDA toolkit and ensure that "
- "it is available on your PATH.\n")
+ msg = (
+ "cuobjdump has not been found. You may need "
+ "to install the CUDA toolkit and ensure that "
+ "it is available on your PATH.\n"
+ )
  raise RuntimeError(msg) from e

- objtable = out.stdout.decode('utf-8')
+ objtable = out.stdout.decode("utf-8")
  entry_pattern = r"Fatbin (.*) code"
  for line in objtable.split("\n"):
  if match := re.match(entry_pattern, line):