numba-cuda 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (227)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +463 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.0.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,7 @@ subsequent deallocation could further corrupt the CUDA context and causes the
10
10
  system to freeze in some cases.
11
11
 
12
12
  """
13
+
13
14
  import sys
14
15
  import os
15
16
  import ctypes
@@ -25,8 +26,17 @@ import tempfile
25
26
  import re
26
27
  from itertools import product
27
28
  from abc import ABCMeta, abstractmethod
28
- from ctypes import (c_int, byref, c_size_t, c_char, c_char_p, addressof,
29
- c_void_p, c_float, c_uint)
29
+ from ctypes import (
30
+ c_int,
31
+ byref,
32
+ c_size_t,
33
+ c_char,
34
+ c_char_p,
35
+ addressof,
36
+ c_void_p,
37
+ c_float,
38
+ c_uint,
39
+ )
30
40
  import contextlib
31
41
  import importlib
32
42
  import numpy as np
@@ -51,13 +61,14 @@ USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
51
61
 
52
62
  if USE_NV_BINDING:
53
63
  from cuda import cuda as binding
64
+
54
65
  # There is no definition of the default stream in the Nvidia bindings (nor
55
66
  # is there at the C/C++ level), so we define it here so we don't need to
56
67
  # use a magic number 0 in places where we want the default stream.
57
68
  CU_STREAM_DEFAULT = 0
58
69
 
59
70
  MIN_REQUIRED_CC = (3, 5)
60
- SUPPORTS_IPC = sys.platform.startswith('linux')
71
+ SUPPORTS_IPC = sys.platform.startswith("linux")
61
72
 
62
73
 
63
74
  _py_decref = ctypes.pythonapi.Py_DecRef
@@ -71,10 +82,9 @@ _MVC_ERROR_MESSAGE = (
71
82
  "to be available"
72
83
  )
73
84
 
74
- ENABLE_PYNVJITLINK = (
75
- _readenv("NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, False)
76
- or getattr(config, "CUDA_ENABLE_PYNVJITLINK", False)
77
- )
85
+ ENABLE_PYNVJITLINK = _readenv(
86
+ "NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, False
87
+ ) or getattr(config, "CUDA_ENABLE_PYNVJITLINK", False)
78
88
  if not hasattr(config, "CUDA_ENABLE_PYNVJITLINK"):
79
89
  config.CUDA_ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK
80
90
 
@@ -94,7 +104,7 @@ def make_logger():
94
104
  if config.CUDA_LOG_LEVEL:
95
105
  # create a simple handler that prints to stderr
96
106
  handler = logging.StreamHandler(sys.stderr)
97
- fmt = '== CUDA [%(relativeCreated)d] %(levelname)5s -- %(message)s'
107
+ fmt = "== CUDA [%(relativeCreated)d] %(levelname)5s -- %(message)s"
98
108
  handler.setFormatter(logging.Formatter(fmt=fmt))
99
109
  logger.addHandler(handler)
100
110
  else:
@@ -122,50 +132,52 @@ class CudaAPIError(CudaDriverError):
122
132
 
123
133
 
124
134
  def locate_driver_and_loader():
125
-
126
135
  envpath = config.CUDA_DRIVER
127
136
 
128
- if envpath == '0':
137
+ if envpath == "0":
129
138
  # Force fail
130
139
  _raise_driver_not_found()
131
140
 
132
141
  # Determine DLL type
133
- if sys.platform == 'win32':
142
+ if sys.platform == "win32":
134
143
  dlloader = ctypes.WinDLL
135
- dldir = ['\\windows\\system32']
136
- dlnames = ['nvcuda.dll']
137
- elif sys.platform == 'darwin':
144
+ dldir = ["\\windows\\system32"]
145
+ dlnames = ["nvcuda.dll"]
146
+ elif sys.platform == "darwin":
138
147
  dlloader = ctypes.CDLL
139
- dldir = ['/usr/local/cuda/lib']
140
- dlnames = ['libcuda.dylib']
148
+ dldir = ["/usr/local/cuda/lib"]
149
+ dlnames = ["libcuda.dylib"]
141
150
  else:
142
151
  # Assume to be *nix like
143
152
  dlloader = ctypes.CDLL
144
- dldir = ['/usr/lib', '/usr/lib64']
145
- dlnames = ['libcuda.so', 'libcuda.so.1']
153
+ dldir = ["/usr/lib", "/usr/lib64"]
154
+ dlnames = ["libcuda.so", "libcuda.so.1"]
146
155
 
147
156
  if envpath:
148
157
  try:
149
158
  envpath = os.path.abspath(envpath)
150
159
  except ValueError:
151
- raise ValueError("NUMBA_CUDA_DRIVER %s is not a valid path" %
152
- envpath)
160
+ raise ValueError(
161
+ "NUMBA_CUDA_DRIVER %s is not a valid path" % envpath
162
+ )
153
163
  if not os.path.isfile(envpath):
154
- raise ValueError("NUMBA_CUDA_DRIVER %s is not a valid file "
155
- "path. Note it must be a filepath of the .so/"
156
- ".dll/.dylib or the driver" % envpath)
164
+ raise ValueError(
165
+ "NUMBA_CUDA_DRIVER %s is not a valid file "
166
+ "path. Note it must be a filepath of the .so/"
167
+ ".dll/.dylib or the driver" % envpath
168
+ )
157
169
  candidates = [envpath]
158
170
  else:
159
171
  # First search for the name in the default library path.
160
172
  # If that is not found, try the specific path.
161
- candidates = dlnames + [os.path.join(x, y)
162
- for x, y in product(dldir, dlnames)]
173
+ candidates = dlnames + [
174
+ os.path.join(x, y) for x, y in product(dldir, dlnames)
175
+ ]
163
176
 
164
177
  return dlloader, candidates
165
178
 
166
179
 
167
180
  def load_driver(dlloader, candidates):
168
-
169
181
  # Load the driver; Collect driver error information
170
182
  path_not_exist = []
171
183
  driver_load_error = []
@@ -184,7 +196,7 @@ def load_driver(dlloader, candidates):
184
196
  if all(path_not_exist):
185
197
  _raise_driver_not_found()
186
198
  else:
187
- errmsg = '\n'.join(str(e) for e in driver_load_error)
199
+ errmsg = "\n".join(str(e) for e in driver_load_error)
188
200
  _raise_driver_error(errmsg)
189
201
 
190
202
 
@@ -216,7 +228,7 @@ def _raise_driver_error(e):
216
228
 
217
229
 
218
230
  def _build_reverse_error_map():
219
- prefix = 'CUDA_ERROR'
231
+ prefix = "CUDA_ERROR"
220
232
  map = utils.UniqueDict()
221
233
  for name in dir(enums):
222
234
  if name.startswith(prefix):
@@ -236,6 +248,7 @@ class Driver(object):
236
248
  """
237
249
  Driver API functions are lazily bound.
238
250
  """
251
+
239
252
  _singleton = None
240
253
 
241
254
  def __new__(cls):
@@ -254,9 +267,11 @@ class Driver(object):
254
267
  self.pid = None
255
268
  try:
256
269
  if config.DISABLE_CUDA:
257
- msg = ("CUDA is disabled due to setting NUMBA_DISABLE_CUDA=1 "
258
- "in the environment, or because CUDA is unsupported on "
259
- "32-bit systems.")
270
+ msg = (
271
+ "CUDA is disabled due to setting NUMBA_DISABLE_CUDA=1 "
272
+ "in the environment, or because CUDA is unsupported on "
273
+ "32-bit systems."
274
+ )
260
275
  raise CudaSupportError(msg)
261
276
  self.lib = find_driver()
262
277
  except CudaSupportError as e:
@@ -273,7 +288,7 @@ class Driver(object):
273
288
 
274
289
  self.is_initialized = True
275
290
  try:
276
- _logger.info('init')
291
+ _logger.info("init")
277
292
  self.cuInit(0)
278
293
  except CudaAPIError as e:
279
294
  description = f"{e.msg} ({e.code})"
@@ -292,8 +307,9 @@ class Driver(object):
292
307
  self.ensure_initialized()
293
308
 
294
309
  if self.initialization_error is not None:
295
- raise CudaSupportError("Error at driver init: \n%s:" %
296
- self.initialization_error)
310
+ raise CudaSupportError(
311
+ "Error at driver init: \n%s:" % self.initialization_error
312
+ )
297
313
 
298
314
  if USE_NV_BINDING:
299
315
  return self._cuda_python_wrap_fn(fname)
@@ -317,12 +333,12 @@ class Driver(object):
317
333
 
318
334
  def verbose_cuda_api_call(*args):
319
335
  argstr = ", ".join([str(arg) for arg in args])
320
- _logger.debug('call driver api: %s(%s)', libfn.__name__, argstr)
336
+ _logger.debug("call driver api: %s(%s)", libfn.__name__, argstr)
321
337
  retcode = libfn(*args)
322
338
  self._check_ctypes_error(fname, retcode)
323
339
 
324
340
  def safe_cuda_api_call(*args):
325
- _logger.debug('call driver api: %s', libfn.__name__)
341
+ _logger.debug("call driver api: %s", libfn.__name__)
326
342
  retcode = libfn(*args)
327
343
  self._check_ctypes_error(fname, retcode)
328
344
 
@@ -340,11 +356,11 @@ class Driver(object):
340
356
 
341
357
  def verbose_cuda_api_call(*args):
342
358
  argstr = ", ".join([str(arg) for arg in args])
343
- _logger.debug('call driver api: %s(%s)', libfn.__name__, argstr)
359
+ _logger.debug("call driver api: %s(%s)", libfn.__name__, argstr)
344
360
  return self._check_cuda_python_error(fname, libfn(*args))
345
361
 
346
362
  def safe_cuda_api_call(*args):
347
- _logger.debug('call driver api: %s', libfn.__name__)
363
+ _logger.debug("call driver api: %s", libfn.__name__)
348
364
  return self._check_cuda_python_error(fname, libfn(*args))
349
365
 
350
366
  if config.CUDA_LOG_API_ARGS:
@@ -361,27 +377,30 @@ class Driver(object):
361
377
  # binding. For the NVidia binding, it handles linking to the correct
362
378
  # variant.
363
379
  if config.CUDA_PER_THREAD_DEFAULT_STREAM and not USE_NV_BINDING:
364
- variants = ('_v2_ptds', '_v2_ptsz', '_ptds', '_ptsz', '_v2', '')
380
+ variants = ("_v2_ptds", "_v2_ptsz", "_ptds", "_ptsz", "_v2", "")
365
381
  else:
366
- variants = ('_v2', '')
382
+ variants = ("_v2", "")
383
+
384
+ if fname in ("cuCtxGetDevice", "cuCtxSynchronize"):
385
+ return getattr(self.lib, fname)
367
386
 
368
387
  for variant in variants:
369
388
  try:
370
- return getattr(self.lib, f'{fname}{variant}')
389
+ return getattr(self.lib, f"{fname}{variant}")
371
390
  except AttributeError:
372
391
  pass
373
392
 
374
393
  # Not found.
375
394
  # Delay missing function error to use
376
395
  def absent_function(*args, **kws):
377
- raise CudaDriverError(f'Driver missing function: {fname}')
396
+ raise CudaDriverError(f"Driver missing function: {fname}")
378
397
 
379
398
  setattr(self, fname, absent_function)
380
399
  return absent_function
381
400
 
382
401
  def _detect_fork(self):
383
402
  if self.pid is not None and _getpid() != self.pid:
384
- msg = 'pid %s forked from pid %s after CUDA driver init'
403
+ msg = "pid %s forked from pid %s after CUDA driver init"
385
404
  _logger.critical(msg, _getpid(), self.pid)
386
405
  raise CudaDriverError("CUDA initialized before forking")
387
406
 
@@ -425,13 +444,11 @@ class Driver(object):
425
444
  return count.value
426
445
 
427
446
  def list_devices(self):
428
- """Returns a list of active devices
429
- """
447
+ """Returns a list of active devices"""
430
448
  return list(self.devices.values())
431
449
 
432
450
  def reset(self):
433
- """Reset all devices
434
- """
451
+ """Reset all devices"""
435
452
  for dev in self.devices.values():
436
453
  dev.reset()
437
454
 
@@ -449,8 +466,7 @@ class Driver(object):
449
466
  return popped
450
467
 
451
468
  def get_active_context(self):
452
- """Returns an instance of ``_ActiveContext``.
453
- """
469
+ """Returns an instance of ``_ActiveContext``."""
454
470
  return _ActiveContext()
455
471
 
456
472
  def get_version(self):
@@ -477,12 +493,13 @@ class _ActiveContext(object):
477
493
  Once entering the context, it is assumed that the active CUDA context is
478
494
  not changed until the context is exited.
479
495
  """
496
+
480
497
  _tls_cache = threading.local()
481
498
 
482
499
  def __enter__(self):
483
500
  is_top = False
484
501
  # check TLS cache
485
- if hasattr(self._tls_cache, 'ctx_devnum'):
502
+ if hasattr(self._tls_cache, "ctx_devnum"):
486
503
  hctx, devnum = self._tls_cache.ctx_devnum
487
504
  # Not cached. Query the driver API.
488
505
  else:
@@ -515,11 +532,10 @@ class _ActiveContext(object):
515
532
 
516
533
  def __exit__(self, exc_type, exc_val, exc_tb):
517
534
  if self._is_top:
518
- delattr(self._tls_cache, 'ctx_devnum')
535
+ delattr(self._tls_cache, "ctx_devnum")
519
536
 
520
537
  def __bool__(self):
521
- """Returns True is there's a valid and active CUDA context.
522
- """
538
+ """Returns True is there's a valid and active CUDA context."""
523
539
  return self.context_handle is not None
524
540
 
525
541
  __nonzero__ = __bool__
@@ -533,7 +549,7 @@ def _build_reverse_device_attrs():
533
549
  map = utils.UniqueDict()
534
550
  for name in dir(enums):
535
551
  if name.startswith(prefix):
536
- map[name[len(prefix):]] = getattr(enums, name)
552
+ map[name[len(prefix) :]] = getattr(enums, name)
537
553
  return map
538
554
 
539
555
 
@@ -545,6 +561,7 @@ class Device(object):
545
561
  The device object owns the CUDA contexts. This is owned by the driver
546
562
  object. User should not construct devices directly.
547
563
  """
564
+
548
565
  @classmethod
549
566
  def from_identity(self, identity):
550
567
  """Create Device object from device identity created by
@@ -579,15 +596,17 @@ class Device(object):
579
596
  self.attributes = {}
580
597
 
581
598
  # Read compute capability
582
- self.compute_capability = (self.COMPUTE_CAPABILITY_MAJOR,
583
- self.COMPUTE_CAPABILITY_MINOR)
599
+ self.compute_capability = (
600
+ self.COMPUTE_CAPABILITY_MAJOR,
601
+ self.COMPUTE_CAPABILITY_MINOR,
602
+ )
584
603
 
585
604
  # Read name
586
605
  bufsz = 128
587
606
 
588
607
  if USE_NV_BINDING:
589
608
  buf = driver.cuDeviceGetName(bufsz, self.id)
590
- name = buf.decode('utf-8').rstrip('\0')
609
+ name = buf.decode("utf-8").rstrip("\0")
591
610
  else:
592
611
  buf = (c_char * bufsz)()
593
612
  driver.cuDeviceGetName(buf, bufsz, self.id)
@@ -604,31 +623,31 @@ class Device(object):
604
623
  driver.cuDeviceGetUuid(byref(uuid), self.id)
605
624
  uuid_vals = tuple(bytes(uuid))
606
625
 
607
- b = '%02x'
626
+ b = "%02x"
608
627
  b2 = b * 2
609
628
  b4 = b * 4
610
629
  b6 = b * 6
611
- fmt = f'GPU-{b4}-{b2}-{b2}-{b2}-{b6}'
630
+ fmt = f"GPU-{b4}-{b2}-{b2}-{b2}-{b6}"
612
631
  self.uuid = fmt % uuid_vals
613
632
 
614
633
  self.primary_context = None
615
634
 
616
635
  def get_device_identity(self):
617
636
  return {
618
- 'pci_domain_id': self.PCI_DOMAIN_ID,
619
- 'pci_bus_id': self.PCI_BUS_ID,
620
- 'pci_device_id': self.PCI_DEVICE_ID,
637
+ "pci_domain_id": self.PCI_DOMAIN_ID,
638
+ "pci_bus_id": self.PCI_BUS_ID,
639
+ "pci_device_id": self.PCI_DEVICE_ID,
621
640
  }
622
641
 
623
642
  def __repr__(self):
624
643
  return "<CUDA device %d '%s'>" % (self.id, self.name)
625
644
 
626
645
  def __getattr__(self, attr):
627
- """Read attributes lazily
628
- """
646
+ """Read attributes lazily"""
629
647
  if USE_NV_BINDING:
630
- code = getattr(binding.CUdevice_attribute,
631
- f'CU_DEVICE_ATTRIBUTE_{attr}')
648
+ code = getattr(
649
+ binding.CUdevice_attribute, f"CU_DEVICE_ATTRIBUTE_{attr}"
650
+ )
632
651
  value = driver.cuDeviceGetAttribute(code, self.id)
633
652
  else:
634
653
  try:
@@ -698,17 +717,18 @@ class Device(object):
698
717
 
699
718
  def met_requirement_for_device(device):
700
719
  if device.compute_capability < MIN_REQUIRED_CC:
701
- raise CudaSupportError("%s has compute capability < %s" %
702
- (device, MIN_REQUIRED_CC))
720
+ raise CudaSupportError(
721
+ "%s has compute capability < %s" % (device, MIN_REQUIRED_CC)
722
+ )
703
723
 
704
724
 
705
725
  class BaseCUDAMemoryManager(object, metaclass=ABCMeta):
706
726
  """Abstract base class for External Memory Management (EMM) Plugins."""
707
727
 
708
728
  def __init__(self, *args, **kwargs):
709
- if 'context' not in kwargs:
729
+ if "context" not in kwargs:
710
730
  raise RuntimeError("Memory manager requires a context")
711
- self.context = kwargs.pop('context')
731
+ self.context = kwargs.pop("context")
712
732
 
713
733
  @abstractmethod
714
734
  def memalloc(self, size):
@@ -864,8 +884,7 @@ class HostOnlyCUDAMemoryManager(BaseCUDAMemoryManager):
864
884
  else:
865
885
  raise
866
886
 
867
- def memhostalloc(self, size, mapped=False, portable=False,
868
- wc=False):
887
+ def memhostalloc(self, size, mapped=False, portable=False, wc=False):
869
888
  """Implements the allocation of pinned host memory.
870
889
 
871
890
  It is recommended that this method is not overridden by EMM Plugin
@@ -880,6 +899,7 @@ class HostOnlyCUDAMemoryManager(BaseCUDAMemoryManager):
880
899
  flags |= enums.CU_MEMHOSTALLOC_WRITECOMBINED
881
900
 
882
901
  if USE_NV_BINDING:
902
+
883
903
  def allocator():
884
904
  return driver.cuMemHostAlloc(size, flags)
885
905
 
@@ -946,16 +966,19 @@ class HostOnlyCUDAMemoryManager(BaseCUDAMemoryManager):
946
966
  ctx = weakref.proxy(self.context)
947
967
 
948
968
  if mapped:
949
- mem = MappedMemory(ctx, pointer, size, owner=owner,
950
- finalizer=finalizer)
969
+ mem = MappedMemory(
970
+ ctx, pointer, size, owner=owner, finalizer=finalizer
971
+ )
951
972
  self.allocations[alloc_key] = mem
952
973
  return mem.own()
953
974
  else:
954
- return PinnedMemory(ctx, pointer, size, owner=owner,
955
- finalizer=finalizer)
975
+ return PinnedMemory(
976
+ ctx, pointer, size, owner=owner, finalizer=finalizer
977
+ )
956
978
 
957
979
  def memallocmanaged(self, size, attach_global):
958
980
  if USE_NV_BINDING:
981
+
959
982
  def allocator():
960
983
  ma_flags = binding.CUmemAttach_flags
961
984
 
@@ -1014,8 +1037,7 @@ class HostOnlyCUDAMemoryManager(BaseCUDAMemoryManager):
1014
1037
 
1015
1038
 
1016
1039
  class GetIpcHandleMixin:
1017
- """A class that provides a default implementation of ``get_ipc_handle()``.
1018
- """
1040
+ """A class that provides a default implementation of ``get_ipc_handle()``."""
1019
1041
 
1020
1042
  def get_ipc_handle(self, memory):
1021
1043
  """Open an IPC memory handle by using ``cuMemGetAddressRange`` to
@@ -1034,8 +1056,9 @@ class GetIpcHandleMixin:
1034
1056
  offset = memory.handle.value - base
1035
1057
  source_info = self.context.device.get_device_identity()
1036
1058
 
1037
- return IpcHandle(memory, ipchandle, memory.size, source_info,
1038
- offset=offset)
1059
+ return IpcHandle(
1060
+ memory, ipchandle, memory.size, source_info, offset=offset
1061
+ )
1039
1062
 
1040
1063
 
1041
1064
  class NumbaCUDAMemoryManager(GetIpcHandleMixin, HostOnlyCUDAMemoryManager):
@@ -1050,6 +1073,7 @@ class NumbaCUDAMemoryManager(GetIpcHandleMixin, HostOnlyCUDAMemoryManager):
1050
1073
 
1051
1074
  def memalloc(self, size):
1052
1075
  if USE_NV_BINDING:
1076
+
1053
1077
  def allocator():
1054
1078
  return driver.cuMemAlloc(size)
1055
1079
 
@@ -1098,7 +1122,7 @@ def _ensure_memory_manager():
1098
1122
  if _memory_manager:
1099
1123
  return
1100
1124
 
1101
- if config.CUDA_MEMORY_MANAGER == 'default':
1125
+ if config.CUDA_MEMORY_MANAGER == "default":
1102
1126
  _memory_manager = NumbaCUDAMemoryManager
1103
1127
  return
1104
1128
 
@@ -1106,8 +1130,9 @@ def _ensure_memory_manager():
1106
1130
  mgr_module = importlib.import_module(config.CUDA_MEMORY_MANAGER)
1107
1131
  set_memory_manager(mgr_module._numba_memory_manager)
1108
1132
  except Exception:
1109
- raise RuntimeError("Failed to use memory manager from %s" %
1110
- config.CUDA_MEMORY_MANAGER)
1133
+ raise RuntimeError(
1134
+ "Failed to use memory manager from %s" % config.CUDA_MEMORY_MANAGER
1135
+ )
1111
1136
 
1112
1137
 
1113
1138
  def set_memory_manager(mm_plugin):
@@ -1124,8 +1149,10 @@ def set_memory_manager(mm_plugin):
1124
1149
  dummy = mm_plugin(context=None)
1125
1150
  iv = dummy.interface_version
1126
1151
  if iv != _SUPPORTED_EMM_INTERFACE_VERSION:
1127
- err = "EMM Plugin interface has version %d - version %d required" \
1128
- % (iv, _SUPPORTED_EMM_INTERFACE_VERSION)
1152
+ err = "EMM Plugin interface has version %d - version %d required" % (
1153
+ iv,
1154
+ _SUPPORTED_EMM_INTERFACE_VERSION,
1155
+ )
1129
1156
  raise RuntimeError(err)
1130
1157
 
1131
1158
  _memory_manager = mm_plugin
@@ -1140,7 +1167,7 @@ class _SizeNotSet(int):
1140
1167
  return super().__new__(cls, 0)
1141
1168
 
1142
1169
  def __str__(self):
1143
- return '?'
1170
+ return "?"
1144
1171
 
1145
1172
 
1146
1173
  _SizeNotSet = _SizeNotSet()
@@ -1153,6 +1180,7 @@ class _PendingDeallocs(object):
1153
1180
  modified later once the driver is initialized and the total memory capacity
1154
1181
  known.
1155
1182
  """
1183
+
1156
1184
  def __init__(self, capacity=_SizeNotSet):
1157
1185
  self._cons = deque()
1158
1186
  self._disable_count = 0
@@ -1172,11 +1200,13 @@ class _PendingDeallocs(object):
1172
1200
  byte size of the resource added. It is an optional argument. Some
1173
1201
  resources (e.g. CUModule) has an unknown memory footprint on the device.
1174
1202
  """
1175
- _logger.info('add pending dealloc: %s %s bytes', dtor.__name__, size)
1203
+ _logger.info("add pending dealloc: %s %s bytes", dtor.__name__, size)
1176
1204
  self._cons.append((dtor, handle, size))
1177
1205
  self._size += int(size)
1178
- if (len(self._cons) > config.CUDA_DEALLOCS_COUNT or
1179
- self._size > self._max_pending_bytes):
1206
+ if (
1207
+ len(self._cons) > config.CUDA_DEALLOCS_COUNT
1208
+ or self._size > self._max_pending_bytes
1209
+ ):
1180
1210
  self.clear()
1181
1211
 
1182
1212
  def clear(self):
@@ -1187,7 +1217,7 @@ class _PendingDeallocs(object):
1187
1217
  if not self.is_disabled:
1188
1218
  while self._cons:
1189
1219
  [dtor, handle, size] = self._cons.popleft()
1190
- _logger.info('dealloc: %s %s bytes', dtor.__name__, size)
1220
+ _logger.info("dealloc: %s %s bytes", dtor.__name__, size)
1191
1221
  dtor(handle)
1192
1222
  self._size = 0
1193
1223
 
@@ -1251,19 +1281,19 @@ class Context(object):
1251
1281
  Clean up all owned resources in this context.
1252
1282
  """
1253
1283
  # Free owned resources
1254
- _logger.info('reset context of device %s', self.device.id)
1284
+ _logger.info("reset context of device %s", self.device.id)
1255
1285
  self.memory_manager.reset()
1256
1286
  self.modules.clear()
1257
1287
  # Clear trash
1258
1288
  self.deallocations.clear()
1259
1289
 
1260
1290
  def get_memory_info(self):
1261
- """Returns (free, total) memory in bytes in the context.
1262
- """
1291
+ """Returns (free, total) memory in bytes in the context."""
1263
1292
  return self.memory_manager.get_memory_info()
1264
1293
 
1265
- def get_active_blocks_per_multiprocessor(self, func, blocksize, memsize,
1266
- flags=None):
1294
+ def get_active_blocks_per_multiprocessor(
1295
+ self, func, blocksize, memsize, flags=None
1296
+ ):
1267
1297
  """Return occupancy of a function.
1268
1298
  :param func: kernel for which occupancy is calculated
1269
1299
  :param blocksize: block size the kernel is intended to be launched with
@@ -1275,8 +1305,9 @@ class Context(object):
1275
1305
  else:
1276
1306
  return self._ctypes_active_blocks_per_multiprocessor(*args)
1277
1307
 
1278
- def _cuda_python_active_blocks_per_multiprocessor(self, func, blocksize,
1279
- memsize, flags):
1308
+ def _cuda_python_active_blocks_per_multiprocessor(
1309
+ self, func, blocksize, memsize, flags
1310
+ ):
1280
1311
  ps = [func.handle, blocksize, memsize]
1281
1312
 
1282
1313
  if not flags:
@@ -1285,8 +1316,9 @@ class Context(object):
1285
1316
  ps.append(flags)
1286
1317
  return driver.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(*ps)
1287
1318
 
1288
- def _ctypes_active_blocks_per_multiprocessor(self, func, blocksize,
1289
- memsize, flags):
1319
+ def _ctypes_active_blocks_per_multiprocessor(
1320
+ self, func, blocksize, memsize, flags
1321
+ ):
1290
1322
  retval = c_int()
1291
1323
  args = (byref(retval), func.handle, blocksize, memsize)
1292
1324
 
@@ -1297,8 +1329,9 @@ class Context(object):
1297
1329
 
1298
1330
  return retval.value
1299
1331
 
1300
- def get_max_potential_block_size(self, func, b2d_func, memsize,
1301
- blocksizelimit, flags=None):
1332
+ def get_max_potential_block_size(
1333
+ self, func, b2d_func, memsize, blocksizelimit, flags=None
1334
+ ):
1302
1335
  """Suggest a launch configuration with reasonable occupancy.
1303
1336
  :param func: kernel for which occupancy is calculated
1304
1337
  :param b2d_func: function that calculates how much per-block dynamic
@@ -1315,13 +1348,20 @@ class Context(object):
1315
1348
  else:
1316
1349
  return self._ctypes_max_potential_block_size(*args)
1317
1350
 
1318
- def _ctypes_max_potential_block_size(self, func, b2d_func, memsize,
1319
- blocksizelimit, flags):
1351
+ def _ctypes_max_potential_block_size(
1352
+ self, func, b2d_func, memsize, blocksizelimit, flags
1353
+ ):
1320
1354
  gridsize = c_int()
1321
1355
  blocksize = c_int()
1322
1356
  b2d_cb = cu_occupancy_b2d_size(b2d_func)
1323
- args = [byref(gridsize), byref(blocksize), func.handle, b2d_cb,
1324
- memsize, blocksizelimit]
1357
+ args = [
1358
+ byref(gridsize),
1359
+ byref(blocksize),
1360
+ func.handle,
1361
+ b2d_cb,
1362
+ memsize,
1363
+ blocksizelimit,
1364
+ ]
1325
1365
 
1326
1366
  if not flags:
1327
1367
  driver.cuOccupancyMaxPotentialBlockSize(*args)
@@ -1331,10 +1371,11 @@ class Context(object):
1331
1371
 
1332
1372
  return (gridsize.value, blocksize.value)
1333
1373
 
1334
- def _cuda_python_max_potential_block_size(self, func, b2d_func, memsize,
1335
- blocksizelimit, flags):
1374
+ def _cuda_python_max_potential_block_size(
1375
+ self, func, b2d_func, memsize, blocksizelimit, flags
1376
+ ):
1336
1377
  b2d_cb = ctypes.CFUNCTYPE(c_size_t, c_int)(b2d_func)
1337
- ptr = int.from_bytes(b2d_cb, byteorder='little')
1378
+ ptr = int.from_bytes(b2d_cb, byteorder="little")
1338
1379
  driver_b2d_cb = binding.CUoccupancyB2DSize(ptr)
1339
1380
  args = [func.handle, driver_b2d_cb, memsize, blocksizelimit]
1340
1381
 
@@ -1387,7 +1428,7 @@ class Context(object):
1387
1428
  Returns an *IpcHandle* from a GPU allocation.
1388
1429
  """
1389
1430
  if not SUPPORTS_IPC:
1390
- raise OSError('OS does not support CUDA IPC')
1431
+ raise OSError("OS does not support CUDA IPC")
1391
1432
  return self.memory_manager.get_ipc_handle(memory)
1392
1433
 
1393
1434
  def open_ipc_handle(self, handle, size):
@@ -1400,13 +1441,13 @@ class Context(object):
1400
1441
  driver.cuIpcOpenMemHandle(byref(dptr), handle, flags)
1401
1442
 
1402
1443
  # wrap it
1403
- return MemoryPointer(context=weakref.proxy(self), pointer=dptr,
1404
- size=size)
1444
+ return MemoryPointer(
1445
+ context=weakref.proxy(self), pointer=dptr, size=size
1446
+ )
1405
1447
 
1406
1448
  def enable_peer_access(self, peer_context, flags=0):
1407
- """Enable peer access between the current context and the peer context
1408
- """
1409
- assert flags == 0, '*flags* is reserved and MUST be zero'
1449
+ """Enable peer access between the current context and the peer context"""
1450
+ assert flags == 0, "*flags* is reserved and MUST be zero"
1410
1451
  driver.cuCtxEnablePeerAccess(peer_context, flags)
1411
1452
 
1412
1453
  def can_access_peer(self, peer_device):
@@ -1415,26 +1456,34 @@ class Context(object):
1415
1456
  """
1416
1457
  if USE_NV_BINDING:
1417
1458
  peer_device = binding.CUdevice(peer_device)
1418
- can_access_peer = driver.cuDeviceCanAccessPeer(self.device.id,
1419
- peer_device)
1459
+ can_access_peer = driver.cuDeviceCanAccessPeer(
1460
+ self.device.id, peer_device
1461
+ )
1420
1462
  else:
1421
1463
  can_access_peer = c_int()
1422
- driver.cuDeviceCanAccessPeer(byref(can_access_peer),
1423
- self.device.id, peer_device,)
1464
+ driver.cuDeviceCanAccessPeer(
1465
+ byref(can_access_peer),
1466
+ self.device.id,
1467
+ peer_device,
1468
+ )
1424
1469
 
1425
1470
  return bool(can_access_peer)
1426
1471
 
1427
1472
  def create_module_ptx(self, ptx):
1428
1473
  if isinstance(ptx, str):
1429
- ptx = ptx.encode('utf8')
1474
+ ptx = ptx.encode("utf8")
1430
1475
  if USE_NV_BINDING:
1431
1476
  image = ptx
1432
1477
  else:
1433
1478
  image = c_char_p(ptx)
1434
1479
  return self.create_module_image(image)
1435
1480
 
1436
- def create_module_image(self, image):
1437
- module = load_module_image(self, image)
1481
+ def create_module_image(
1482
+ self, image, setup_callbacks=None, teardown_callbacks=None
1483
+ ):
1484
+ module = load_module_image(
1485
+ self, image, setup_callbacks, teardown_callbacks
1486
+ )
1438
1487
  if USE_NV_BINDING:
1439
1488
  key = module.handle
1440
1489
  else:
@@ -1481,8 +1530,11 @@ class Context(object):
1481
1530
  else:
1482
1531
  handle = drvapi.cu_stream()
1483
1532
  driver.cuStreamCreate(byref(handle), 0)
1484
- return Stream(weakref.proxy(self), handle,
1485
- _stream_finalizer(self.deallocations, handle))
1533
+ return Stream(
1534
+ weakref.proxy(self),
1535
+ handle,
1536
+ _stream_finalizer(self.deallocations, handle),
1537
+ )
1486
1538
 
1487
1539
  def create_external_stream(self, ptr):
1488
1540
  if not isinstance(ptr, int):
@@ -1491,8 +1543,7 @@ class Context(object):
1491
1543
  handle = binding.CUstream(ptr)
1492
1544
  else:
1493
1545
  handle = drvapi.cu_stream(ptr)
1494
- return Stream(weakref.proxy(self), handle, None,
1495
- external=True)
1546
+ return Stream(weakref.proxy(self), handle, None, external=True)
1496
1547
 
1497
1548
  def create_event(self, timing=True):
1498
1549
  flags = 0
@@ -1503,8 +1554,11 @@ class Context(object):
1503
1554
  else:
1504
1555
  handle = drvapi.cu_event()
1505
1556
  driver.cuEventCreate(byref(handle), flags)
1506
- return Event(weakref.proxy(self), handle,
1507
- finalizer=_event_finalizer(self.deallocations, handle))
1557
+ return Event(
1558
+ weakref.proxy(self),
1559
+ handle,
1560
+ finalizer=_event_finalizer(self.deallocations, handle),
1561
+ )
1508
1562
 
1509
1563
  def synchronize(self):
1510
1564
  driver.cuCtxSynchronize()
@@ -1528,17 +1582,25 @@ class Context(object):
1528
1582
  return not self.__eq__(other)
1529
1583
 
1530
1584
 
1531
- def load_module_image(context, image):
1585
+ def load_module_image(
1586
+ context, image, setup_callbacks=None, teardown_callbacks=None
1587
+ ):
1532
1588
  """
1533
1589
  image must be a pointer
1534
1590
  """
1535
1591
  if USE_NV_BINDING:
1536
- return load_module_image_cuda_python(context, image)
1592
+ return load_module_image_cuda_python(
1593
+ context, image, setup_callbacks, teardown_callbacks
1594
+ )
1537
1595
  else:
1538
- return load_module_image_ctypes(context, image)
1596
+ return load_module_image_ctypes(
1597
+ context, image, setup_callbacks, teardown_callbacks
1598
+ )
1539
1599
 
1540
1600
 
1541
- def load_module_image_ctypes(context, image):
1601
+ def load_module_image_ctypes(
1602
+ context, image, setup_callbacks, teardown_callbacks
1603
+ ):
1542
1604
  logsz = config.CUDA_LOG_SIZE
1543
1605
 
1544
1606
  jitinfo = (c_char * logsz)()
@@ -1557,19 +1619,28 @@ def load_module_image_ctypes(context, image):
1557
1619
 
1558
1620
  handle = drvapi.cu_module()
1559
1621
  try:
1560
- driver.cuModuleLoadDataEx(byref(handle), image, len(options),
1561
- option_keys, option_vals)
1622
+ driver.cuModuleLoadDataEx(
1623
+ byref(handle), image, len(options), option_keys, option_vals
1624
+ )
1562
1625
  except CudaAPIError as e:
1563
1626
  msg = "cuModuleLoadDataEx error:\n%s" % jiterrors.value.decode("utf8")
1564
1627
  raise CudaAPIError(e.code, msg)
1565
1628
 
1566
1629
  info_log = jitinfo.value
1567
1630
 
1568
- return CtypesModule(weakref.proxy(context), handle, info_log,
1569
- _module_finalizer(context, handle))
1631
+ return CtypesModule(
1632
+ weakref.proxy(context),
1633
+ handle,
1634
+ info_log,
1635
+ _module_finalizer(context, handle),
1636
+ setup_callbacks,
1637
+ teardown_callbacks,
1638
+ )
1570
1639
 
1571
1640
 
1572
- def load_module_image_cuda_python(context, image):
1641
+ def load_module_image_cuda_python(
1642
+ context, image, setup_callbacks, teardown_callbacks
1643
+ ):
1573
1644
  """
1574
1645
  image must be a pointer
1575
1646
  """
@@ -1591,17 +1662,24 @@ def load_module_image_cuda_python(context, image):
1591
1662
  option_vals = [v for v in options.values()]
1592
1663
 
1593
1664
  try:
1594
- handle = driver.cuModuleLoadDataEx(image, len(options), option_keys,
1595
- option_vals)
1665
+ handle = driver.cuModuleLoadDataEx(
1666
+ image, len(options), option_keys, option_vals
1667
+ )
1596
1668
  except CudaAPIError as e:
1597
- err_string = jiterrors.decode('utf-8')
1669
+ err_string = jiterrors.decode("utf-8")
1598
1670
  msg = "cuModuleLoadDataEx error:\n%s" % err_string
1599
1671
  raise CudaAPIError(e.code, msg)
1600
1672
 
1601
- info_log = jitinfo.decode('utf-8')
1673
+ info_log = jitinfo.decode("utf-8")
1602
1674
 
1603
- return CudaPythonModule(weakref.proxy(context), handle, info_log,
1604
- _module_finalizer(context, handle))
1675
+ return CudaPythonModule(
1676
+ weakref.proxy(context),
1677
+ handle,
1678
+ info_log,
1679
+ _module_finalizer(context, handle),
1680
+ setup_callbacks,
1681
+ teardown_callbacks,
1682
+ )
1605
1683
 
1606
1684
 
1607
1685
  def _alloc_finalizer(memory_manager, ptr, alloc_key, size):
@@ -1704,6 +1782,7 @@ class _CudaIpcImpl(object):
1704
1782
  """Implementation of GPU IPC using CUDA driver API.
1705
1783
  This requires the devices to be peer accessible.
1706
1784
  """
1785
+
1707
1786
  def __init__(self, parent):
1708
1787
  self.base = parent.base
1709
1788
  self.handle = parent.handle
@@ -1717,10 +1796,10 @@ class _CudaIpcImpl(object):
1717
1796
  Import the IPC memory and returns a raw CUDA memory pointer object
1718
1797
  """
1719
1798
  if self.base is not None:
1720
- raise ValueError('opening IpcHandle from original process')
1799
+ raise ValueError("opening IpcHandle from original process")
1721
1800
 
1722
1801
  if self._opened_mem is not None:
1723
- raise ValueError('IpcHandle is already opened')
1802
+ raise ValueError("IpcHandle is already opened")
1724
1803
 
1725
1804
  mem = context.open_ipc_handle(self.handle, self.offset + self.size)
1726
1805
  # this object owns the opened allocation
@@ -1731,7 +1810,7 @@ class _CudaIpcImpl(object):
1731
1810
 
1732
1811
  def close(self):
1733
1812
  if self._opened_mem is None:
1734
- raise ValueError('IpcHandle not opened')
1813
+ raise ValueError("IpcHandle not opened")
1735
1814
  driver.cuIpcCloseMemHandle(self._opened_mem.handle)
1736
1815
  self._opened_mem = None
1737
1816
 
@@ -1740,6 +1819,7 @@ class _StagedIpcImpl(object):
1740
1819
  """Implementation of GPU IPC using custom staging logic to workaround
1741
1820
  CUDA IPC limitation on peer accessibility between devices.
1742
1821
  """
1822
+
1743
1823
  def __init__(self, parent, source_info):
1744
1824
  self.parent = parent
1745
1825
  self.base = parent.base
@@ -1795,6 +1875,7 @@ class IpcHandle(object):
1795
1875
  referred to by this IPC handle.
1796
1876
  :type offset: int
1797
1877
  """
1878
+
1798
1879
  def __init__(self, base, handle, size, source_info=None, offset=0):
1799
1880
  self.base = base
1800
1881
  self.handle = handle
@@ -1818,12 +1899,11 @@ class IpcHandle(object):
1818
1899
  return context.can_access_peer(source_device.id)
1819
1900
 
1820
1901
  def open_staged(self, context):
1821
- """Open the IPC by allowing staging on the host memory first.
1822
- """
1902
+ """Open the IPC by allowing staging on the host memory first."""
1823
1903
  self._sentry_source_info()
1824
1904
 
1825
1905
  if self._impl is not None:
1826
- raise ValueError('IpcHandle is already opened')
1906
+ raise ValueError("IpcHandle is already opened")
1827
1907
 
1828
1908
  self._impl = _StagedIpcImpl(self, self.source_info)
1829
1909
  return self._impl.open(context)
@@ -1833,7 +1913,7 @@ class IpcHandle(object):
1833
1913
  Import the IPC memory and returns a raw CUDA memory pointer object
1834
1914
  """
1835
1915
  if self._impl is not None:
1836
- raise ValueError('IpcHandle is already opened')
1916
+ raise ValueError("IpcHandle is already opened")
1837
1917
 
1838
1918
  self._impl = _CudaIpcImpl(self)
1839
1919
  return self._impl.open(context)
@@ -1864,12 +1944,13 @@ class IpcHandle(object):
1864
1944
  strides = dtype.itemsize
1865
1945
  dptr = self.open(context)
1866
1946
  # read the device pointer as an array
1867
- return devicearray.DeviceNDArray(shape=shape, strides=strides,
1868
- dtype=dtype, gpu_data=dptr)
1947
+ return devicearray.DeviceNDArray(
1948
+ shape=shape, strides=strides, dtype=dtype, gpu_data=dptr
1949
+ )
1869
1950
 
1870
1951
  def close(self):
1871
1952
  if self._impl is None:
1872
- raise ValueError('IpcHandle not opened')
1953
+ raise ValueError("IpcHandle not opened")
1873
1954
  self._impl.close()
1874
1955
  self._impl = None
1875
1956
 
@@ -1895,8 +1976,13 @@ class IpcHandle(object):
1895
1976
  else:
1896
1977
  handle = drvapi.cu_ipc_mem_handle()
1897
1978
  handle.reserved = handle_ary
1898
- return cls(base=None, handle=handle, size=size,
1899
- source_info=source_info, offset=offset)
1979
+ return cls(
1980
+ base=None,
1981
+ handle=handle,
1982
+ size=size,
1983
+ source_info=source_info,
1984
+ offset=offset,
1985
+ )
1900
1986
 
1901
1987
 
1902
1988
  class MemoryPointer(object):
@@ -1930,6 +2016,7 @@ class MemoryPointer(object):
1930
2016
  :param finalizer: A function that is called when the buffer is to be freed.
1931
2017
  :type finalizer: function
1932
2018
  """
2019
+
1933
2020
  __cuda_memory__ = True
1934
2021
 
1935
2022
  def __init__(self, context, pointer, size, owner=None, finalizer=None):
@@ -1965,8 +2052,9 @@ class MemoryPointer(object):
1965
2052
  def memset(self, byte, count=None, stream=0):
1966
2053
  count = self.size if count is None else count
1967
2054
  if stream:
1968
- driver.cuMemsetD8Async(self.device_pointer, byte, count,
1969
- stream.handle)
2055
+ driver.cuMemsetD8Async(
2056
+ self.device_pointer, byte, count, stream.handle
2057
+ )
1970
2058
  else:
1971
2059
  driver.cuMemsetD8(self.device_pointer, byte, count)
1972
2060
 
@@ -1980,12 +2068,12 @@ class MemoryPointer(object):
1980
2068
  if not self.device_pointer_value:
1981
2069
  if size != 0:
1982
2070
  raise RuntimeError("non-empty slice into empty slice")
1983
- view = self # new view is just a reference to self
2071
+ view = self # new view is just a reference to self
1984
2072
  # Handle normal case
1985
2073
  else:
1986
2074
  base = self.device_pointer_value + start
1987
2075
  if size < 0:
1988
- raise RuntimeError('size cannot be negative')
2076
+ raise RuntimeError("size cannot be negative")
1989
2077
  if USE_NV_BINDING:
1990
2078
  pointer = binding.CUdeviceptr()
1991
2079
  ctypes_ptr = drvapi.cu_device_ptr.from_address(pointer.getPtr())
@@ -2021,6 +2109,7 @@ class AutoFreePointer(MemoryPointer):
2021
2109
 
2022
2110
  Constructor arguments are the same as for :class:`MemoryPointer`.
2023
2111
  """
2112
+
2024
2113
  def __init__(self, *args, **kwargs):
2025
2114
  super(AutoFreePointer, self).__init__(*args, **kwargs)
2026
2115
  # Releease the self reference to the buffer, so that the finalizer
@@ -2063,8 +2152,9 @@ class MappedMemory(AutoFreePointer):
2063
2152
  self._bufptr_ = self.host_pointer.value
2064
2153
 
2065
2154
  self.device_pointer = devptr
2066
- super(MappedMemory, self).__init__(context, devptr, size,
2067
- finalizer=finalizer)
2155
+ super(MappedMemory, self).__init__(
2156
+ context, devptr, size, finalizer=finalizer
2157
+ )
2068
2158
  self.handle = self.host_pointer
2069
2159
 
2070
2160
  # For buffer interface
@@ -2179,8 +2269,7 @@ class OwnedPointer(object):
2179
2269
  weakref.finalize(self, deref)
2180
2270
 
2181
2271
  def __getattr__(self, fname):
2182
- """Proxy MemoryPointer methods
2183
- """
2272
+ """Proxy MemoryPointer methods"""
2184
2273
  return getattr(self._view, fname)
2185
2274
 
2186
2275
 
@@ -2211,18 +2300,15 @@ class Stream(object):
2211
2300
  if USE_NV_BINDING:
2212
2301
  default_streams = {
2213
2302
  CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
2214
- binding.CU_STREAM_LEGACY:
2215
- "<Legacy default CUDA stream on %s>",
2216
- binding.CU_STREAM_PER_THREAD:
2217
- "<Per-thread default CUDA stream on %s>",
2303
+ binding.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
2304
+ binding.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
2218
2305
  }
2219
2306
  ptr = int(self.handle) or 0
2220
2307
  else:
2221
2308
  default_streams = {
2222
2309
  drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
2223
2310
  drvapi.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
2224
- drvapi.CU_STREAM_PER_THREAD:
2225
- "<Per-thread default CUDA stream on %s>",
2311
+ drvapi.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
2226
2312
  }
2227
2313
  ptr = self.handle.value or drvapi.CU_STREAM_DEFAULT
2228
2314
 
@@ -2234,18 +2320,18 @@ class Stream(object):
2234
2320
  return "<CUDA stream %d on %s>" % (ptr, self.context)
2235
2321
 
2236
2322
  def synchronize(self):
2237
- '''
2323
+ """
2238
2324
  Wait for all commands in this stream to execute. This will commit any
2239
2325
  pending memory transfers.
2240
- '''
2326
+ """
2241
2327
  driver.cuStreamSynchronize(self.handle)
2242
2328
 
2243
2329
  @contextlib.contextmanager
2244
2330
  def auto_synchronize(self):
2245
- '''
2331
+ """
2246
2332
  A context manager that waits for all commands in this stream to execute
2247
2333
  and commits any pending memory transfers upon exiting the context.
2248
- '''
2334
+ """
2249
2335
  yield self
2250
2336
  self.synchronize()
2251
2337
 
@@ -2272,7 +2358,7 @@ class Stream(object):
2272
2358
  data = (self, callback, arg)
2273
2359
  _py_incref(data)
2274
2360
  if USE_NV_BINDING:
2275
- ptr = int.from_bytes(self._stream_callback, byteorder='little')
2361
+ ptr = int.from_bytes(self._stream_callback, byteorder="little")
2276
2362
  stream_callback = binding.CUstreamCallback(ptr)
2277
2363
  # The callback needs to receive a pointer to the data PyObject
2278
2364
  data = id(data)
@@ -2373,9 +2459,9 @@ class Event(object):
2373
2459
 
2374
2460
 
2375
2461
  def event_elapsed_time(evtstart, evtend):
2376
- '''
2462
+ """
2377
2463
  Compute the elapsed time between two events in milliseconds.
2378
- '''
2464
+ """
2379
2465
  if USE_NV_BINDING:
2380
2466
  return driver.cuEventElapsedTime(evtstart.handle, evtend.handle)
2381
2467
  else:
@@ -2387,13 +2473,27 @@ def event_elapsed_time(evtstart, evtend):
2387
2473
  class Module(metaclass=ABCMeta):
2388
2474
  """Abstract base class for modules"""
2389
2475
 
2390
- def __init__(self, context, handle, info_log, finalizer=None):
2476
+ def __init__(
2477
+ self,
2478
+ context,
2479
+ handle,
2480
+ info_log,
2481
+ finalizer=None,
2482
+ setup_callbacks=None,
2483
+ teardown_callbacks=None,
2484
+ ):
2391
2485
  self.context = context
2392
2486
  self.handle = handle
2393
2487
  self.info_log = info_log
2394
2488
  if finalizer is not None:
2395
2489
  self._finalizer = weakref.finalize(self, finalizer)
2396
2490
 
2491
+ self.initialized = False
2492
+ self.setup_functions = setup_callbacks
2493
+ self.teardown_functions = teardown_callbacks
2494
+
2495
+ self._set_finalizers()
2496
+
2397
2497
  def unload(self):
2398
2498
  """Unload this module from the context"""
2399
2499
  self.context.unload_module(self)
@@ -2406,36 +2506,66 @@ class Module(metaclass=ABCMeta):
  def get_global_symbol(self, name):
  """Return a MemoryPointer referring to the named symbol"""

+ def setup(self):
+ """Call the setup functions for the module"""
+ if self.initialized:
+ raise RuntimeError("The module has already been initialized.")

- class CtypesModule(Module):
+ if self.setup_functions is None:
+ return
+
+ for f in self.setup_functions:
+ f(self.handle)

+ self.initialized = True
+
+ def _set_finalizers(self):
+ """Create finalizers that tear down the module."""
+ if self.teardown_functions is None:
+ return
+
+ def _teardown(teardowns, handle):
+ for f in teardowns:
+ f(handle)
+
+ weakref.finalize(
+ self,
+ _teardown,
+ self.teardown_functions,
+ self.handle,
+ )
+
+
+ class CtypesModule(Module):
  def get_function(self, name):
  handle = drvapi.cu_function()
- driver.cuModuleGetFunction(byref(handle), self.handle,
- name.encode('utf8'))
+ driver.cuModuleGetFunction(
+ byref(handle), self.handle, name.encode("utf8")
+ )
  return CtypesFunction(weakref.proxy(self), handle, name)

  def get_global_symbol(self, name):
  ptr = drvapi.cu_device_ptr()
  size = drvapi.c_size_t()
- driver.cuModuleGetGlobal(byref(ptr), byref(size), self.handle,
- name.encode('utf8'))
+ driver.cuModuleGetGlobal(
+ byref(ptr), byref(size), self.handle, name.encode("utf8")
+ )
  return MemoryPointer(self.context, ptr, size), size.value


  class CudaPythonModule(Module):
-
  def get_function(self, name):
- handle = driver.cuModuleGetFunction(self.handle, name.encode('utf8'))
+ handle = driver.cuModuleGetFunction(self.handle, name.encode("utf8"))
  return CudaPythonFunction(weakref.proxy(self), handle, name)

  def get_global_symbol(self, name):
- ptr, size = driver.cuModuleGetGlobal(self.handle, name.encode('utf8'))
+ ptr, size = driver.cuModuleGetGlobal(self.handle, name.encode("utf8"))
  return MemoryPointer(self.context, ptr, size), size


- FuncAttr = namedtuple("FuncAttr", ["regs", "shared", "local", "const",
- "maxthreads"])
+ FuncAttr = namedtuple(
+ "FuncAttr", ["regs", "shared", "local", "const", "maxthreads"]
+ )


  class Function(metaclass=ABCMeta):
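The new setup/teardown hooks on Module follow a simple contract visible in the hunk above: each callback receives the raw module handle; setup() runs the setup callbacks exactly once (a second call raises RuntimeError), and the teardown callbacks are registered through weakref.finalize so they fire when the module object is collected. A rough sketch, with hypothetical callback names and placeholder constructor arguments:

    # Hypothetical callbacks; each receives the module handle.
    def init_module_state(handle):
        ...  # e.g. initialize module-scope device state

    def free_module_state(handle):
        ...  # release whatever init_module_state set up

    mod = CudaPythonModule(
        context,            # placeholders: a real Context, module
        handle,             # handle, and info log string
        info_log,
        setup_callbacks=[init_module_state],
        teardown_callbacks=[free_module_state],
    )
    mod.setup()  # runs init_module_state(mod.handle), sets initialized
    # free_module_state(mod.handle) runs via weakref.finalize when
    # `mod` is garbage collected.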
@@ -2458,8 +2588,9 @@ class Function(metaclass=ABCMeta):
  return self.module.context.device

  @abstractmethod
- def cache_config(self, prefer_equal=False, prefer_cache=False,
- prefer_shared=False):
+ def cache_config(
+ self, prefer_equal=False, prefer_cache=False, prefer_shared=False
+ ):
  """Set the cache configuration for this function."""

  @abstractmethod
@@ -2473,9 +2604,9 @@ class Function(metaclass=ABCMeta):


  class CtypesFunction(Function):
-
- def cache_config(self, prefer_equal=False, prefer_cache=False,
- prefer_shared=False):
+ def cache_config(
+ self, prefer_equal=False, prefer_cache=False, prefer_shared=False
+ ):
  prefer_equal = prefer_equal or (prefer_cache and prefer_shared)
  if prefer_equal:
  flag = enums.CU_FUNC_CACHE_PREFER_EQUAL
@@ -2498,15 +2629,17 @@ class CtypesFunction(Function):
  lmem = self.read_func_attr(enums.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
  smem = self.read_func_attr(enums.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
  maxtpb = self.read_func_attr(
- enums.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
- return FuncAttr(regs=nregs, const=cmem, local=lmem, shared=smem,
- maxthreads=maxtpb)
+ enums.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
+ )
+ return FuncAttr(
+ regs=nregs, const=cmem, local=lmem, shared=smem, maxthreads=maxtpb
+ )


  class CudaPythonFunction(Function):
-
- def cache_config(self, prefer_equal=False, prefer_cache=False,
- prefer_shared=False):
+ def cache_config(
+ self, prefer_equal=False, prefer_cache=False, prefer_shared=False
+ ):
  prefer_equal = prefer_equal or (prefer_cache and prefer_shared)
  attr = binding.CUfunction_attribute
  if prefer_equal:
@@ -2529,19 +2662,26 @@ class CudaPythonFunction(Function):
  lmem = self.read_func_attr(attr.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
  smem = self.read_func_attr(attr.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
  maxtpb = self.read_func_attr(
- attr.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
- return FuncAttr(regs=nregs, const=cmem, local=lmem, shared=smem,
- maxthreads=maxtpb)
-
+ attr.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
+ )
+ return FuncAttr(
+ regs=nregs, const=cmem, local=lmem, shared=smem, maxthreads=maxtpb
+ )

- def launch_kernel(cufunc_handle,
- gx, gy, gz,
- bx, by, bz,
- sharedmem,
- hstream,
- args,
- cooperative=False):

+ def launch_kernel(
+ cufunc_handle,
+ gx,
+ gy,
+ gz,
+ bx,
+ by,
+ bz,
+ sharedmem,
+ hstream,
+ args,
+ cooperative=False,
+ ):
  param_ptrs = [addressof(arg) for arg in args]
  params = (c_void_p * len(param_ptrs))(*param_ptrs)
 
@@ -2553,46 +2693,54 @@ def launch_kernel(cufunc_handle,
  extra = None

  if cooperative:
- driver.cuLaunchCooperativeKernel(cufunc_handle,
- gx, gy, gz,
- bx, by, bz,
- sharedmem,
- hstream,
- params_for_launch)
+ driver.cuLaunchCooperativeKernel(
+ cufunc_handle,
+ gx,
+ gy,
+ gz,
+ bx,
+ by,
+ bz,
+ sharedmem,
+ hstream,
+ params_for_launch,
+ )
  else:
- driver.cuLaunchKernel(cufunc_handle,
- gx, gy, gz,
- bx, by, bz,
- sharedmem,
- hstream,
- params_for_launch,
- extra)
+ driver.cuLaunchKernel(
+ cufunc_handle,
+ gx,
+ gy,
+ gz,
+ bx,
+ by,
+ bz,
+ sharedmem,
+ hstream,
+ params_for_launch,
+ extra,
+ )
 
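As the reflowed launch_kernel above shows, args must be ctypes objects, since addressof(arg) is taken per element to build the parameter array. A minimal call sketch (the handle values are placeholders, not part of this diff):

    import ctypes

    n = ctypes.c_int(1024)                 # scalar kernel parameter
    dptr = ctypes.c_void_p(devptr_value)   # device pointer, placeholder

    launch_kernel(
        cufunc_handle,     # e.g. from Module.get_function(...).handle
        4, 1, 1,           # grid dimensions  (gx, gy, gz)
        256, 1, 1,         # block dimensions (bx, by, bz)
        0,                 # dynamic shared memory, in bytes
        stream_handle,     # stream handle; placeholder here
        [dptr, n],         # ctypes args; addressof() is applied to each
    )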

  class Linker(metaclass=ABCMeta):
  """Abstract base class for linkers"""

  @classmethod
- def new(cls,
- max_registers=0,
- lineinfo=False,
- cc=None,
- lto=None,
- additional_flags=None
- ):
-
+ def new(
+ cls,
+ max_registers=0,
+ lineinfo=False,
+ cc=None,
+ lto=None,
+ additional_flags=None,
+ ):
  driver_ver = driver.get_version()
- if (
- config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY
- and driver_ver >= (12, 0)
+ if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and driver_ver >= (
+ 12,
+ 0,
  ):
- raise ValueError(
- "Use CUDA_ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC"
- )
+ raise ValueError("Use CUDA_ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC")
  if config.CUDA_ENABLE_PYNVJITLINK and driver_ver < (12, 0):
- raise ValueError(
- "Enabling pynvjitlink requires CUDA 12."
- )
+ raise ValueError("Enabling pynvjitlink requires CUDA 12.")
  if config.CUDA_ENABLE_PYNVJITLINK:
  linker = PyNvJitLinker
 
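Linker.new above now rejects inconsistent configuration up front: cubinlinker-based minor version compatibility is refused on CUDA 12+ drivers (pynvjitlink should be used instead), and pynvjitlink is refused on pre-12 drivers. A sketch of the observable behaviour (the config values are assumed to be set via Numba's usual NUMBA_* environment variables before import):

    # Default: picks a linker appropriate to the driver and config.
    linker = Linker.new(cc=(8, 0))

    # With CUDA_ENABLE_PYNVJITLINK set and a driver older than 12.0:
    #   ValueError: Enabling pynvjitlink requires CUDA 12.
    # With CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY set, driver >= 12.0:
    #   ValueError: Use CUDA_ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC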
@@ -2641,9 +2789,9 @@ class Linker(metaclass=ABCMeta):
  ptx, log = nvrtc.compile(cu, name, cc)

  if config.DUMP_ASSEMBLY:
- print(("ASSEMBLY %s" % name).center(80, '-'))
+ print(("ASSEMBLY %s" % name).center(80, "-"))
  print(ptx)
- print('=' * 80)
+ print("=" * 80)

  # Link the program's PTX using the normal linker mechanism
  ptx_name = os.path.splitext(name)[0] + ".ptx"
@@ -2654,7 +2802,7 @@ class Linker(metaclass=ABCMeta):
  """Add code from a file to the link"""

  def add_cu_file(self, path):
- with open(path, 'rb') as f:
+ with open(path, "rb") as f:
  cu = f.read()
  self.add_cu(cu, os.path.basename(path))
 
@@ -2672,24 +2820,24 @@ class Linker(metaclass=ABCMeta):

  if isinstance(path_or_code, str):
  ext = pathlib.Path(path_or_code).suffix
- if ext == '':
+ if ext == "":
  raise RuntimeError(
  "Don't know how to link file with no extension"
  )
- elif ext == '.cu':
+ elif ext == ".cu":
  self.add_cu_file(path_or_code)
  else:
- kind = FILE_EXTENSION_MAP.get(ext.lstrip('.'), None)
+ kind = FILE_EXTENSION_MAP.get(ext.lstrip("."), None)
  if kind is None:
  raise RuntimeError(
- "Don't know how to link file with extension "
- f"{ext}"
+ f"Don't know how to link file with extension {ext}"
  )

  if ignore_nonlto:
  warn_and_return = False
  if kind in (
- FILE_EXTENSION_MAP["fatbin"], FILE_EXTENSION_MAP["o"]
+ FILE_EXTENSION_MAP["fatbin"],
+ FILE_EXTENSION_MAP["o"],
  ):
  entry_types = inspect_obj_content(path_or_code)
  if "nvvm" not in entry_types:
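The dispatch in this hunk works by suffix: .cu is compiled through NVRTC, anything else is looked up in FILE_EXTENSION_MAP after stripping the leading dot, and with ignore_nonlto set, fatbins and objects are first inspected for NVVM IR. An illustration of the lookup (the mapped values here are stand-ins; the real table pairs extensions with JIT input kinds):

    import pathlib

    # Stand-in mapping for illustration only.
    FILE_EXTENSION_MAP = {"ptx": 1, "cubin": 2, "fatbin": 3, "a": 4, "o": 5}

    ext = pathlib.Path("kernels.fatbin").suffix        # ".fatbin"
    kind = FILE_EXTENSION_MAP.get(ext.lstrip("."), None)
    assert kind == 3
    # An unrecognized suffix leaves kind as None, which triggers the
    # RuntimeError shown in the hunk above.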
@@ -2754,6 +2902,7 @@ class MVCLinker(Linker):
  Linker supporting Minor Version Compatibility, backed by the cubinlinker
  package.
  """
+
  def __init__(self, max_registers=None, lineinfo=False, cc=None):
  try:
  from cubinlinker import CubinLinker
@@ -2761,18 +2910,20 @@ class MVCLinker(Linker):
  raise ImportError(_MVC_ERROR_MESSAGE) from err

  if cc is None:
- raise RuntimeError("MVCLinker requires Compute Capability to be "
- "specified, but cc is None")
+ raise RuntimeError(
+ "MVCLinker requires Compute Capability to be "
+ "specified, but cc is None"
+ )

  super().__init__(max_registers, lineinfo, cc)

  arch = f"sm_{cc[0] * 10 + cc[1]}"
- ptx_compile_opts = ['--gpu-name', arch, '-c']
+ ptx_compile_opts = ["--gpu-name", arch, "-c"]
  if max_registers:
  arg = f"--maxrregcount={max_registers}"
  ptx_compile_opts.append(arg)
  if lineinfo:
- ptx_compile_opts.append('--generate-line-info')
+ ptx_compile_opts.append("--generate-line-info")
  self.ptx_compile_options = tuple(ptx_compile_opts)

  self._linker = CubinLinker(f"--arch={arch}")
@@ -2785,7 +2936,7 @@ class MVCLinker(Linker):
  def error_log(self):
  return self._linker.error_log

- def add_ptx(self, ptx, name='<cudapy-ptx>'):
+ def add_ptx(self, ptx, name="<cudapy-ptx>"):
  try:
  from ptxcompiler import compile_ptx
  from cubinlinker import CubinLinkerError
@@ -2804,19 +2955,19 @@ class MVCLinker(Linker):
  raise ImportError(_MVC_ERROR_MESSAGE) from err

  try:
- with open(path, 'rb') as f:
+ with open(path, "rb") as f:
  data = f.read()
  except FileNotFoundError:
- raise LinkerError(f'{path} not found')
+ raise LinkerError(f"{path} not found")

  name = pathlib.Path(path).name
- if kind == FILE_EXTENSION_MAP['cubin']:
+ if kind == FILE_EXTENSION_MAP["cubin"]:
  fn = self._linker.add_cubin
- elif kind == FILE_EXTENSION_MAP['fatbin']:
+ elif kind == FILE_EXTENSION_MAP["fatbin"]:
  fn = self._linker.add_fatbin
- elif kind == FILE_EXTENSION_MAP['a']:
+ elif kind == FILE_EXTENSION_MAP["a"]:
  raise LinkerError(f"Don't know how to link {kind}")
- elif kind == FILE_EXTENSION_MAP['ptx']:
+ elif kind == FILE_EXTENSION_MAP["ptx"]:
  return self.add_ptx(data, name)
  else:
  raise LinkerError(f"Don't know how to link {kind}")
@@ -2842,6 +2993,7 @@ class CtypesLinker(Linker):
  """
  Links for current device if no CC given
  """
+
  def __init__(self, max_registers=0, lineinfo=False, cc=None):
  super().__init__(max_registers, lineinfo, cc)
 
@@ -2875,8 +3027,9 @@ class CtypesLinker(Linker):
  option_vals = (c_void_p * len(raw_values))(*raw_values)

  self.handle = handle = drvapi.cu_link_state()
- driver.cuLinkCreate(len(raw_keys), option_keys, option_vals,
- byref(self.handle))
+ driver.cuLinkCreate(
+ len(raw_keys), option_keys, option_vals, byref(self.handle)
+ )

  weakref.finalize(self, driver.cuLinkDestroy, handle)
 
@@ -2887,19 +3040,27 @@ class CtypesLinker(Linker):

  @property
  def info_log(self):
- return self.linker_info_buf.value.decode('utf8')
+ return self.linker_info_buf.value.decode("utf8")

  @property
  def error_log(self):
- return self.linker_errors_buf.value.decode('utf8')
+ return self.linker_errors_buf.value.decode("utf8")

- def add_ptx(self, ptx, name='<cudapy-ptx>'):
+ def add_ptx(self, ptx, name="<cudapy-ptx>"):
  ptxbuf = c_char_p(ptx)
- namebuf = c_char_p(name.encode('utf8'))
+ namebuf = c_char_p(name.encode("utf8"))
  self._keep_alive += [ptxbuf, namebuf]
  try:
- driver.cuLinkAddData(self.handle, enums.CU_JIT_INPUT_PTX,
- ptxbuf, len(ptx), namebuf, 0, None, None)
+ driver.cuLinkAddData(
+ self.handle,
+ enums.CU_JIT_INPUT_PTX,
+ ptxbuf,
+ len(ptx),
+ namebuf,
+ 0,
+ None,
+ None,
+ )
  except CudaAPIError as e:
  raise LinkerError("%s\n%s" % (e, self.error_log))
 
@@ -2911,7 +3072,7 @@ class CtypesLinker(Linker):
  driver.cuLinkAddFile(self.handle, kind, pathbuf, 0, None, None)
  except CudaAPIError as e:
  if e.code == enums.CUDA_ERROR_FILE_NOT_FOUND:
- msg = f'{path} not found'
+ msg = f"{path} not found"
  else:
  msg = "%s\n%s" % (e, self.error_log)
  raise LinkerError(msg)
@@ -2926,7 +3087,7 @@ class CtypesLinker(Linker):
  raise LinkerError("%s\n%s" % (e, self.error_log))

  size = size.value
- assert size > 0, 'linker returned a zero sized cubin'
+ assert size > 0, "linker returned a zero sized cubin"
  del self._keep_alive[:]

  # We return a copy of the cubin because it's owned by the linker
@@ -2938,6 +3099,7 @@ class CudaPythonLinker(Linker):
  """
  Links for current device if no CC given
  """
+
  def __init__(self, max_registers=0, lineinfo=False, cc=None):
  super().__init__(max_registers, lineinfo, cc)
 
@@ -2964,8 +3126,9 @@ class CudaPythonLinker(Linker):
  options[jit_option.CU_JIT_TARGET_FROM_CUCONTEXT] = 1
  else:
  cc_val = cc[0] * 10 + cc[1]
- cc_enum = getattr(binding.CUjit_target,
- f'CU_TARGET_COMPUTE_{cc_val}')
+ cc_enum = getattr(
+ binding.CUjit_target, f"CU_TARGET_COMPUTE_{cc_val}"
+ )
  options[jit_option.CU_JIT_TARGET] = cc_enum

  raw_keys = list(options.keys())
@@ -2982,19 +3145,20 @@ class CudaPythonLinker(Linker):

  @property
  def info_log(self):
- return self.linker_info_buf.decode('utf8')
+ return self.linker_info_buf.decode("utf8")

  @property
  def error_log(self):
- return self.linker_errors_buf.decode('utf8')
+ return self.linker_errors_buf.decode("utf8")

- def add_ptx(self, ptx, name='<cudapy-ptx>'):
- namebuf = name.encode('utf8')
+ def add_ptx(self, ptx, name="<cudapy-ptx>"):
+ namebuf = name.encode("utf8")
  self._keep_alive += [ptx, namebuf]
  try:
  input_ptx = binding.CUjitInputType.CU_JIT_INPUT_PTX
- driver.cuLinkAddData(self.handle, input_ptx, ptx, len(ptx),
- namebuf, 0, [], [])
+ driver.cuLinkAddData(
+ self.handle, input_ptx, ptx, len(ptx), namebuf, 0, [], []
+ )
  except CudaAPIError as e:
  raise LinkerError("%s\n%s" % (e, self.error_log))
 
@@ -3006,7 +3170,7 @@ class CudaPythonLinker(Linker):
  driver.cuLinkAddFile(self.handle, kind, pathbuf, 0, [], [])
  except CudaAPIError as e:
  if e.code == binding.CUresult.CUDA_ERROR_FILE_NOT_FOUND:
- msg = f'{path} not found'
+ msg = f"{path} not found"
  else:
  msg = "%s\n%s" % (e, self.error_log)
  raise LinkerError(msg)
@@ -3017,7 +3181,7 @@ class CudaPythonLinker(Linker):
  except CudaAPIError as e:
  raise LinkerError("%s\n%s" % (e, self.error_log))

- assert size > 0, 'linker returned a zero sized cubin'
+ assert size > 0, "linker returned a zero sized cubin"
  del self._keep_alive[:]
  # We return a copy of the cubin because it's owned by the linker
  cubin_ptr = ctypes.cast(cubin_buf, ctypes.POINTER(ctypes.c_char))
@@ -3151,6 +3315,7 @@ class PyNvJitLinker(Linker):
  except NvJitLinkError as e:
  raise LinkerError from e

+
  # -----------------------------------------------------------------------------

 
@@ -3200,7 +3365,7 @@ def device_memory_size(devmem):
  The result is cached in the device memory object.
  It may query the driver for the memory size of the device memory allocation.
  """
- sz = getattr(devmem, '_cuda_memsize_', None)
+ sz = getattr(devmem, "_cuda_memsize_", None)
  if sz is None:
  s, e = device_extents(devmem)
  if USE_NV_BINDING:
@@ -3213,10 +3378,9 @@ def device_memory_size(devmem):


  def _is_datetime_dtype(obj):
- """Returns True if the obj.dtype is datetime64 or timedelta64
- """
- dtype = getattr(obj, 'dtype', None)
- return dtype is not None and dtype.char in 'Mm'
+ """Returns True if the obj.dtype is datetime64 or timedelta64"""
+ dtype = getattr(obj, "dtype", None)
+ return dtype is not None and dtype.char in "Mm"


  def _workaround_for_datetime(obj):
@@ -3295,12 +3459,11 @@ def is_device_memory(obj):
  "device_pointer" which value is an int object carrying the pointer
  value of the device memory address. This is not tested in this method.
  """
- return getattr(obj, '__cuda_memory__', False)
+ return getattr(obj, "__cuda_memory__", False)


  def require_device_memory(obj):
- """A sentry for methods that accept CUDA memory object.
- """
+ """A sentry for methods that accept CUDA memory object."""
  if not is_device_memory(obj):
  raise Exception("Not a CUDA memory object.")
 
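is_device_memory is a pure duck-typing check, so any object exposing a truthy __cuda_memory__ attribute satisfies require_device_memory. A minimal sketch:

    class FakeDeviceMemory:
        # Only the attribute that is_device_memory() inspects; the
        # docstring above also mentions a device_pointer, but the
        # sentry itself does not test it.
        __cuda_memory__ = True
        device_pointer = 0

    require_device_memory(FakeDeviceMemory())   # passes silently
    require_device_memory(object())             # raises Exception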
@@ -3391,16 +3554,16 @@ def device_memset(dst, val, size, stream=0):


  def profile_start():
- '''
+ """
  Enable profile collection in the current context.
- '''
+ """
  driver.cuProfilerStart()


  def profile_stop():
- '''
+ """
  Disable profile collection in the current context.
- '''
+ """
  driver.cuProfilerStop()

 
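profile_start and profile_stop bracket a region of interest for the CUDA profiler; they only have a visible effect when the process runs under a profiler with API-controlled capture (for example, Nsight Systems launched with `nsys profile -c cudaProfilerApi`). Sketch:

    profile_start()
    run_workload()   # hypothetical function launching kernels to profile
    profile_stop()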
@@ -3427,18 +3590,21 @@ def inspect_obj_content(objpath: str):
  Given path to a fatbin or object, use `cuobjdump` to examine its content
  Return the set of entries in the object.
  """
- code_types :set[str] = set()
+ code_types: set[str] = set()

  try:
- out = subprocess.run(["cuobjdump", objpath], check=True,
- capture_output=True)
+ out = subprocess.run(
+ ["cuobjdump", objpath], check=True, capture_output=True
+ )
  except FileNotFoundError as e:
- msg = ("cuobjdump has not been found. You may need "
- "to install the CUDA toolkit and ensure that "
- "it is available on your PATH.\n")
+ msg = (
+ "cuobjdump has not been found. You may need "
+ "to install the CUDA toolkit and ensure that "
+ "it is available on your PATH.\n"
+ )
  raise RuntimeError(msg) from e

- objtable = out.stdout.decode('utf-8')
+ objtable = out.stdout.decode("utf-8")
  entry_pattern = r"Fatbin (.*) code"
  for line in objtable.split("\n"):
  if match := re.match(entry_pattern, line):