numba-cuda 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +246 -114
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
  13. numba_cuda/numba/cuda/cuda_paths.py +293 -99
  14. numba_cuda/numba/cuda/cudadecl.py +93 -79
  15. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  16. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  17. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  18. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  19. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  20. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  21. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  22. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  23. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  24. numba_cuda/numba/cuda/cudadrv/linkable_code.py +27 -3
  25. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  26. numba_cuda/numba/cuda/cudadrv/nvrtc.py +146 -30
  27. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  28. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  29. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  30. numba_cuda/numba/cuda/cudaimpl.py +296 -275
  31. numba_cuda/numba/cuda/cudamath.py +1 -1
  32. numba_cuda/numba/cuda/debuginfo.py +99 -7
  33. numba_cuda/numba/cuda/decorators.py +87 -45
  34. numba_cuda/numba/cuda/descriptor.py +1 -1
  35. numba_cuda/numba/cuda/device_init.py +68 -18
  36. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  37. numba_cuda/numba/cuda/dispatcher.py +300 -213
  38. numba_cuda/numba/cuda/errors.py +13 -10
  39. numba_cuda/numba/cuda/extending.py +55 -1
  40. numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
  41. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
  42. numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +1090 -927
  43. numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +468 -319
  44. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  45. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  46. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  47. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  48. numba_cuda/numba/cuda/initialize.py +5 -3
  49. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
  50. numba_cuda/numba/cuda/intrinsics.py +203 -28
  51. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  52. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  53. numba_cuda/numba/cuda/libdevice.py +317 -317
  54. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  55. numba_cuda/numba/cuda/locks.py +16 -0
  56. numba_cuda/numba/cuda/lowering.py +43 -0
  57. numba_cuda/numba/cuda/mathimpl.py +62 -57
  58. numba_cuda/numba/cuda/models.py +1 -5
  59. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  60. numba_cuda/numba/cuda/printimpl.py +9 -5
  61. numba_cuda/numba/cuda/random.py +46 -36
  62. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  63. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  64. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  65. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  66. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  67. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  68. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  69. numba_cuda/numba/cuda/simulator/api.py +38 -22
  70. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  71. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  72. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  73. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  74. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  75. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  76. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  77. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  78. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  79. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  80. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  81. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  82. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  83. numba_cuda/numba/cuda/simulator_init.py +2 -4
  84. numba_cuda/numba/cuda/stubs.py +134 -108
  85. numba_cuda/numba/cuda/target.py +92 -47
  86. numba_cuda/numba/cuda/testing.py +24 -19
  87. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  88. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  89. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  90. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  91. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  92. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  93. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  94. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  95. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  96. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  97. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  98. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  99. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  100. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  102. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  103. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  104. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  105. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  107. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  108. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  109. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  110. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  111. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  112. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  113. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  114. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  115. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  117. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  118. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  119. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +10 -7
  120. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  121. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  123. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  124. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  125. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  127. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +59 -23
  129. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  130. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  131. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  132. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  133. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  134. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  135. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  136. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  137. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  138. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  139. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  140. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  141. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  142. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  143. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +77 -28
  144. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  145. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  146. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +24 -7
  147. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  148. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  149. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +21 -12
  150. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  151. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  152. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  153. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  154. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  155. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  156. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  157. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  158. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  159. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +59 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  161. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  162. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  163. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  164. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  165. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +7 -7
  166. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  167. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  168. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  169. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  170. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  171. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  172. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  173. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  174. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  175. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  176. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  178. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  179. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  180. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  181. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  182. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  183. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  184. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  185. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  186. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  187. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  188. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  189. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  190. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  191. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  192. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  193. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  194. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  195. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  196. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  197. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  198. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  199. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  200. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  201. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  202. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  203. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  204. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  205. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +81 -30
  206. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  207. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  208. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  209. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  210. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  211. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  212. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  213. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  214. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  216. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  217. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  218. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  219. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  220. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  221. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  222. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  223. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  224. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  225. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  226. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  227. numba_cuda/numba/cuda/types.py +5 -2
  228. numba_cuda/numba/cuda/ufuncs.py +382 -362
  229. numba_cuda/numba/cuda/utils.py +2 -2
  230. numba_cuda/numba/cuda/vector_types.py +5 -3
  231. numba_cuda/numba/cuda/vectorizers.py +38 -33
  232. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/METADATA +1 -1
  233. numba_cuda-0.10.0.dist-info/RECORD +263 -0
  234. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/WHEEL +1 -1
  235. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  236. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/licenses/LICENSE +0 -0
  237. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/top_level.txt +0 -0
@@ -4,12 +4,13 @@ from numba.core import config, serialize
4
4
  from numba.core.codegen import Codegen, CodeLibrary
5
5
  from .cudadrv import devices, driver, nvvm, runtime
6
6
  from numba.cuda.cudadrv.libs import get_cudalib
7
+ from numba.cuda.cudadrv.linkable_code import LinkableCode
7
8
 
8
9
  import os
9
10
  import subprocess
10
11
  import tempfile
11
12
 
12
- CUDA_TRIPLE = 'nvptx64-nvidia-cuda'
13
+ CUDA_TRIPLE = "nvptx64-nvidia-cuda"
13
14
 
14
15
 
15
16
  def run_nvdisasm(cubin, flags):
@@ -19,19 +20,24 @@ def run_nvdisasm(cubin, flags):
19
20
  fname = None
20
21
  try:
21
22
  fd, fname = tempfile.mkstemp()
22
- with open(fname, 'wb') as f:
23
+ with open(fname, "wb") as f:
23
24
  f.write(cubin)
24
25
 
25
26
  try:
26
- cp = subprocess.run(['nvdisasm', *flags, fname], check=True,
27
- stdout=subprocess.PIPE,
28
- stderr=subprocess.PIPE)
27
+ cp = subprocess.run(
28
+ ["nvdisasm", *flags, fname],
29
+ check=True,
30
+ stdout=subprocess.PIPE,
31
+ stderr=subprocess.PIPE,
32
+ )
29
33
  except FileNotFoundError as e:
30
- msg = ("nvdisasm has not been found. You may need "
31
- "to install the CUDA toolkit and ensure that "
32
- "it is available on your PATH.\n")
34
+ msg = (
35
+ "nvdisasm has not been found. You may need "
36
+ "to install the CUDA toolkit and ensure that "
37
+ "it is available on your PATH.\n"
38
+ )
33
39
  raise RuntimeError(msg) from e
34
- return cp.stdout.decode('utf-8')
40
+ return cp.stdout.decode("utf-8")
35
41
  finally:
36
42
  if fd is not None:
37
43
  os.close(fd)
@@ -41,13 +47,13 @@ def run_nvdisasm(cubin, flags):
41
47
 
42
48
  def disassemble_cubin(cubin):
43
49
  # Request lineinfo in disassembly
44
- flags = ['-gi']
50
+ flags = ["-gi"]
45
51
  return run_nvdisasm(cubin, flags)
46
52
 
47
53
 
48
54
  def disassemble_cubin_for_cfg(cubin):
49
55
  # Request control flow graph in disassembly
50
- flags = ['-cfg']
56
+ flags = ["-cfg"]
51
57
  return run_nvdisasm(cubin, flags)
52
58
 
53
59
 
@@ -65,7 +71,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
65
71
  entry_name=None,
66
72
  max_registers=None,
67
73
  lto=False,
68
- nvvm_options=None
74
+ nvvm_options=None,
69
75
  ):
70
76
  """
71
77
  codegen:
@@ -94,6 +100,12 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
94
100
  # Files to link with the generated PTX. These are linked using the
95
101
  # Driver API at link time.
96
102
  self._linking_files = set()
103
+ # List of setup functions to the loaded module
104
+ # the order is determined by the order they are added to the codelib.
105
+ self._setup_functions = []
106
+ # List of teardown functions to the loaded module
107
+ # the order is determined by the order they are added to the codelib.
108
+ self._teardown_functions = []
97
109
  # Should we link libcudadevrt?
98
110
  self.needs_cudadevrt = False
99
111
 
@@ -142,7 +154,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
142
154
 
143
155
  arch = nvvm.get_arch_option(*cc)
144
156
  options = self._nvvm_options.copy()
145
- options['arch'] = arch
157
+ options["arch"] = arch
146
158
 
147
159
  irs = self.llvm_strs
148
160
 
@@ -151,12 +163,12 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
151
163
  # Sometimes the result from NVVM contains trailing whitespace and
152
164
  # nulls, which we strip so that the assembly dump looks a little
153
165
  # tidier.
154
- ptx = ptx.decode().strip('\x00').strip()
166
+ ptx = ptx.decode().strip("\x00").strip()
155
167
 
156
168
  if config.DUMP_ASSEMBLY:
157
- print(("ASSEMBLY %s" % self._name).center(80, '-'))
169
+ print(("ASSEMBLY %s" % self._name).center(80, "-"))
158
170
  print(ptx)
159
- print('=' * 80)
171
+ print("=" * 80)
160
172
 
161
173
  self._ptx_cache[cc] = ptx
162
174
 
@@ -171,8 +183,8 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
171
183
 
172
184
  arch = nvvm.get_arch_option(*cc)
173
185
  options = self._nvvm_options.copy()
174
- options['arch'] = arch
175
- options['gen-lto'] = None
186
+ options["arch"] = arch
187
+ options["gen-lto"] = None
176
188
 
177
189
  irs = self.llvm_strs
178
190
  ltoir = nvvm.compile_ir(irs, **options)
@@ -192,7 +204,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
192
204
  linker.add_file_guess_ext(path, ignore_nonlto)
193
205
  if self.needs_cudadevrt:
194
206
  linker.add_file_guess_ext(
195
- get_cudalib('cudadevrt', static=True), ignore_nonlto
207
+ get_cudalib("cudadevrt", static=True), ignore_nonlto
196
208
  )
197
209
 
198
210
  def get_cubin(self, cc=None):
@@ -207,22 +219,20 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
207
219
  max_registers=self._max_registers,
208
220
  cc=cc,
209
221
  additional_flags=["-ptx"],
210
- lto=self._lto
222
+ lto=self._lto,
211
223
  )
212
224
  # `-ptx` flag is meant to view the optimized PTX for LTO objects.
213
225
  # Non-LTO objects are not passed to linker.
214
226
  self._link_all(linker, cc, ignore_nonlto=True)
215
227
 
216
- ptx = linker.get_linked_ptx().decode('utf-8')
228
+ ptx = linker.get_linked_ptx().decode("utf-8")
217
229
 
218
- print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, '-'))
230
+ print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, "-"))
219
231
  print(ptx)
220
- print('=' * 80)
232
+ print("=" * 80)
221
233
 
222
234
  linker = driver.Linker.new(
223
- max_registers=self._max_registers,
224
- cc=cc,
225
- lto=self._lto
235
+ max_registers=self._max_registers, cc=cc, lto=self._lto
226
236
  )
227
237
  self._link_all(linker, cc, ignore_nonlto=False)
228
238
  cubin = linker.complete()
@@ -234,8 +244,10 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
234
244
 
235
245
  def get_cufunc(self):
236
246
  if self._entry_name is None:
237
- msg = "Missing entry_name - are you trying to get the cufunc " \
238
- "for a device function?"
247
+ msg = (
248
+ "Missing entry_name - are you trying to get the cufunc "
249
+ "for a device function?"
250
+ )
239
251
  raise RuntimeError(msg)
240
252
 
241
253
  ctx = devices.get_context()
@@ -246,7 +258,9 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
246
258
  return cufunc
247
259
 
248
260
  cubin = self.get_cubin(cc=device.compute_capability)
249
- module = ctx.create_module_image(cubin)
261
+ module = ctx.create_module_image(
262
+ cubin, self._setup_functions, self._teardown_functions
263
+ )
250
264
 
251
265
  # Load
252
266
  cufunc = module.get_function(self._entry_name)
@@ -260,7 +274,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
260
274
  try:
261
275
  return self._linkerinfo_cache[cc]
262
276
  except KeyError:
263
- raise KeyError(f'No linkerinfo for CC {cc}')
277
+ raise KeyError(f"No linkerinfo for CC {cc}")
264
278
 
265
279
  def get_sass(self, cc=None):
266
280
  return disassemble_cubin(self.get_cubin(cc=cc))
@@ -271,7 +285,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
271
285
  def add_ir_module(self, mod):
272
286
  self._raise_if_finalized()
273
287
  if self._module is not None:
274
- raise RuntimeError('CUDACodeLibrary only supports one module')
288
+ raise RuntimeError("CUDACodeLibrary only supports one module")
275
289
  self._module = mod
276
290
 
277
291
  def add_linking_library(self, library):
@@ -284,19 +298,26 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
284
298
 
285
299
  self._linking_libraries.add(library)
286
300
 
287
- def add_linking_file(self, filepath):
288
- self._linking_files.add(filepath)
301
+ def add_linking_file(self, path_or_obj):
302
+ if isinstance(path_or_obj, LinkableCode):
303
+ if path_or_obj.setup_callback:
304
+ self._setup_functions.append(path_or_obj.setup_callback)
305
+ if path_or_obj.teardown_callback:
306
+ self._teardown_functions.append(path_or_obj.teardown_callback)
307
+
308
+ self._linking_files.add(path_or_obj)
289
309
 
290
310
  def get_function(self, name):
291
311
  for fn in self._module.functions:
292
312
  if fn.name == name:
293
313
  return fn
294
- raise KeyError(f'Function {name} not found')
314
+ raise KeyError(f"Function {name} not found")
295
315
 
296
316
  @property
297
317
  def modules(self):
298
- return [self._module] + [mod for lib in self._linking_libraries
299
- for mod in lib.modules]
318
+ return [self._module] + [
319
+ mod for lib in self._linking_libraries for mod in lib.modules
320
+ ]
300
321
 
301
322
  @property
302
323
  def linking_libraries(self):
@@ -331,7 +352,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
331
352
  for mod in library.modules:
332
353
  for fn in mod.functions:
333
354
  if not fn.is_declaration:
334
- fn.linkage = 'linkonce_odr'
355
+ fn.linkage = "linkonce_odr"
335
356
 
336
357
  self._finalized = True
337
358
 
@@ -342,10 +363,10 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
342
363
  after deserialization.
343
364
  """
344
365
  if self._linking_files:
345
- msg = 'Cannot pickle CUDACodeLibrary with linking files'
366
+ msg = "Cannot pickle CUDACodeLibrary with linking files"
346
367
  raise RuntimeError(msg)
347
368
  if not self._finalized:
348
- raise RuntimeError('Cannot pickle unfinalized CUDACodeLibrary')
369
+ raise RuntimeError("Cannot pickle unfinalized CUDACodeLibrary")
349
370
  return dict(
350
371
  codegen=None,
351
372
  name=self.name,
@@ -356,13 +377,23 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
356
377
  linkerinfo_cache=self._linkerinfo_cache,
357
378
  max_registers=self._max_registers,
358
379
  nvvm_options=self._nvvm_options,
359
- needs_cudadevrt=self.needs_cudadevrt
380
+ needs_cudadevrt=self.needs_cudadevrt,
360
381
  )
361
382
 
362
383
  @classmethod
363
- def _rebuild(cls, codegen, name, entry_name, llvm_strs, ptx_cache,
364
- cubin_cache, linkerinfo_cache, max_registers, nvvm_options,
365
- needs_cudadevrt):
384
+ def _rebuild(
385
+ cls,
386
+ codegen,
387
+ name,
388
+ entry_name,
389
+ llvm_strs,
390
+ ptx_cache,
391
+ cubin_cache,
392
+ linkerinfo_cache,
393
+ max_registers,
394
+ nvvm_options,
395
+ needs_cudadevrt,
396
+ ):
366
397
  """
367
398
  Rebuild an instance.
368
399
  """