numba-cuda 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/codegen.py +69 -2
  3. numba_cuda/numba/cuda/compiler.py +41 -17
  4. numba_cuda/numba/cuda/cudadecl.py +15 -5
  5. numba_cuda/numba/cuda/cudadrv/driver.py +103 -20
  6. numba_cuda/numba/cuda/cudadrv/linkable_code.py +10 -2
  7. numba_cuda/numba/cuda/cudaimpl.py +103 -11
  8. numba_cuda/numba/cuda/decorators.py +18 -2
  9. numba_cuda/numba/cuda/dispatcher.py +27 -66
  10. numba_cuda/numba/cuda/runtime/nrt.cu +2 -17
  11. numba_cuda/numba/cuda/runtime/nrt.cuh +41 -0
  12. numba_cuda/numba/cuda/runtime/nrt.py +13 -1
  13. numba_cuda/numba/cuda/stubs.py +23 -11
  14. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +236 -0
  15. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +140 -0
  16. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +98 -1
  17. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +122 -3
  18. numba_cuda/numba/cuda/tests/test_binary_generation/Makefile +11 -0
  19. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +5 -2
  20. numba_cuda/numba/cuda/tests/test_binary_generation/nrt_extern.cu +7 -0
  21. numba_cuda/numba/cuda/tests/test_binary_generation/test_device_functions.cu +4 -0
  22. numba_cuda/numba/cuda/utils.py +7 -0
  23. {numba_cuda-0.10.0.dist-info → numba_cuda-0.11.0.dist-info}/METADATA +1 -1
  24. {numba_cuda-0.10.0.dist-info → numba_cuda-0.11.0.dist-info}/RECORD +27 -24
  25. {numba_cuda-0.10.0.dist-info → numba_cuda-0.11.0.dist-info}/WHEEL +1 -1
  26. {numba_cuda-0.10.0.dist-info → numba_cuda-0.11.0.dist-info}/licenses/LICENSE +0 -0
  27. {numba_cuda-0.10.0.dist-info → numba_cuda-0.11.0.dist-info}/top_level.txt +0 -0
numba_cuda/VERSION CHANGED
@@ -1 +1 @@
1
- 0.10.0
1
+ 0.11.0
numba_cuda/numba/cuda/codegen.py CHANGED
@@ -5,6 +5,7 @@ from numba.core.codegen import Codegen, CodeLibrary
5
5
  from .cudadrv import devices, driver, nvvm, runtime
6
6
  from numba.cuda.cudadrv.libs import get_cudalib
7
7
  from numba.cuda.cudadrv.linkable_code import LinkableCode
8
+ from numba.cuda.runtime.nrt import NRT_LIBRARY
8
9
 
9
10
  import os
10
11
  import subprocess
@@ -57,6 +58,57 @@ def disassemble_cubin_for_cfg(cubin):
57
58
  return run_nvdisasm(cubin, flags)
58
59
 
59
60
 
61
class ExternalCodeLibrary(CodeLibrary):
    """Holds code produced externally, for linking with generated code.

    The library holds no LLVM IR. It only records the file paths or
    :class:`LinkableCode` objects to hand to the linker, together with the
    setup and teardown callbacks declared by those ``LinkableCode`` objects.
    """

    def __init__(self, codegen, name):
        super().__init__(codegen, name)
        # Files (paths or LinkableCode objects) to link
        self._linking_files = set()
        # Setup and teardown functions for the module.
        # The order is determined by the order they are added to the codelib.
        self._setup_functions = []
        self._teardown_functions = []

    @property
    def modules(self):
        # There are no LLVM IR modules in an ExternalCodeLibrary
        return set()

    def add_linking_file(self, path_or_obj):
        """Add a file path or :class:`LinkableCode` object to the link.

        The setup/teardown callbacks of a ``LinkableCode`` object are recorded
        when it is first added. Adding an object that is already present is a
        no-op. ``_linking_files`` is a set, so without this check the callbacks
        would be registered once per call while the file is linked only once.

        Raises (via ``_raise_if_finalized``) if the library is already
        finalized.
        """
        # Adding new files after finalization is prohibited, in case the list
        # of libraries has already been added to another code library; the
        # newly-added files would be omitted from their linking process.
        self._raise_if_finalized()

        if path_or_obj in self._linking_files:
            return

        if isinstance(path_or_obj, LinkableCode):
            if path_or_obj.setup_callback:
                self._setup_functions.append(path_or_obj.setup_callback)
            if path_or_obj.teardown_callback:
                self._teardown_functions.append(path_or_obj.teardown_callback)

        self._linking_files.add(path_or_obj)

    def add_ir_module(self, module):
        """Always raises: an external code library cannot hold LLVM IR."""
        raise NotImplementedError("Cannot add LLVM IR to external code")

    def add_linking_library(self, library):
        """Always raises: external code cannot link other libraries."""
        raise NotImplementedError("Cannot add libraries to external code")

    def finalize(self):
        """Mark the library as finalized; raises if it already is."""
        self._raise_if_finalized()
        self._finalized = True

    def get_asm_str(self):
        """Always raises: there is no assembly for external code."""
        raise NotImplementedError("No assembly for external code")

    def get_llvm_str(self):
        """Always raises: there is no LLVM IR for external code."""
        raise NotImplementedError("No LLVM IR for external code")

    def get_function(self, name):
        """Always raises: functions cannot be looked up in external code."""
        raise NotImplementedError("Cannot get function from external code")
110
+
111
+
60
112
  class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
61
113
  """
62
114
  The CUDACodeLibrary generates PTX, SASS, cubins for multiple different
@@ -297,6 +349,9 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
297
349
  self._raise_if_finalized()
298
350
 
299
351
  self._linking_libraries.add(library)
352
+ self._linking_files.update(library._linking_files)
353
+ self._setup_functions.extend(library._setup_functions)
354
+ self._teardown_functions.extend(library._teardown_functions)
300
355
 
301
356
  def add_linking_file(self, path_or_obj):
302
357
  if isinstance(path_or_obj, LinkableCode):
@@ -362,9 +417,17 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
362
417
  but loaded functions are discarded. They are recreated when needed
363
418
  after deserialization.
364
419
  """
420
+ nrt = False
365
421
  if self._linking_files:
366
- msg = "Cannot pickle CUDACodeLibrary with linking files"
367
- raise RuntimeError(msg)
422
+ if (
423
+ len(self._linking_files) == 1
424
+ and NRT_LIBRARY in self._linking_files
425
+ ):
426
+ nrt = True
427
+ else:
428
+ msg = "Cannot pickle CUDACodeLibrary with linking files"
429
+ raise RuntimeError(msg)
430
+
368
431
  if not self._finalized:
369
432
  raise RuntimeError("Cannot pickle unfinalized CUDACodeLibrary")
370
433
  return dict(
@@ -378,6 +441,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
378
441
  max_registers=self._max_registers,
379
442
  nvvm_options=self._nvvm_options,
380
443
  needs_cudadevrt=self.needs_cudadevrt,
444
+ nrt=nrt,
381
445
  )
382
446
 
383
447
  @classmethod
@@ -393,6 +457,7 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
393
457
  max_registers,
394
458
  nvvm_options,
395
459
  needs_cudadevrt,
460
+ nrt,
396
461
  ):
397
462
  """
398
463
  Rebuild an instance.
@@ -409,6 +474,8 @@ class CUDACodeLibrary(serialize.ReduceMixin, CodeLibrary):
409
474
  instance.needs_cudadevrt = needs_cudadevrt
410
475
 
411
476
  instance._finalized = True
477
+ if nrt:
478
+ instance._linking_files = {NRT_LIBRARY}
412
479
 
413
480
  return instance
414
481
 
numba_cuda/numba/cuda/compiler.py CHANGED
@@ -1,5 +1,4 @@
1
1
  from llvmlite import ir
2
- from numba.core.typing.templates import ConcreteTemplate
3
2
  from numba.core import ir as numba_ir
4
3
  from numba.core import (
5
4
  cgutils,
@@ -37,6 +36,7 @@ from numba.core.typed_passes import (
37
36
  from warnings import warn
38
37
  from numba.cuda import nvvmutils
39
38
  from numba.cuda.api import get_current_device
39
+ from numba.cuda.codegen import ExternalCodeLibrary
40
40
  from numba.cuda.cudadrv import nvvm
41
41
  from numba.cuda.descriptor import cuda_target
42
42
  from numba.cuda.target import CUDACABICallConv
@@ -278,7 +278,7 @@ def compile_cuda(
278
278
  args,
279
279
  debug=False,
280
280
  lineinfo=False,
281
- inline=False,
281
+ forceinline=False,
282
282
  fastmath=False,
283
283
  nvvm_options=None,
284
284
  cc=None,
@@ -316,7 +316,7 @@ def compile_cuda(
316
316
  else:
317
317
  flags.error_model = "numpy"
318
318
 
319
- if inline:
319
+ if forceinline:
320
320
  flags.forceinline = True
321
321
  if fastmath:
322
322
  flags.fastmath = True
@@ -574,6 +574,7 @@ def compile(
574
574
  abi="c",
575
575
  abi_info=None,
576
576
  output="ptx",
577
+ forceinline=False,
577
578
  ):
578
579
  """Compile a Python function to PTX or LTO-IR for a given set of argument
579
580
  types.
@@ -614,6 +615,11 @@ def compile(
614
615
  :type abi_info: dict
615
616
  :param output: Type of output to generate, either ``"ptx"`` or ``"ltoir"``.
616
617
  :type output: str
618
+ :param forceinline: Enables inlining at the NVVM IR level when set to
619
+ ``True``. This is accomplished by adding the
620
+ ``alwaysinline`` function attribute to the function
621
+ definition. This is only valid when the output is
622
+ ``"ltoir"``.
617
623
  :return: (code, resty): The compiled code and inferred return type
618
624
  :rtype: tuple
619
625
  """
@@ -626,6 +632,12 @@ def compile(
626
632
  if output not in ("ptx", "ltoir"):
627
633
  raise NotImplementedError(f"Unsupported output type: {output}")
628
634
 
635
+ if forceinline and not device:
636
+ raise ValueError("Cannot force-inline kernels")
637
+
638
+ if forceinline and output != "ltoir":
639
+ raise ValueError("Can only designate forced inlining in LTO-IR")
640
+
629
641
  debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
630
642
  opt = (config.OPT != 0) if opt is None else opt
631
643
 
@@ -660,6 +672,7 @@ def compile(
660
672
  fastmath=fastmath,
661
673
  nvvm_options=nvvm_options,
662
674
  cc=cc,
675
+ forceinline=forceinline,
663
676
  )
664
677
  resty = cres.signature.return_type
665
678
 
@@ -699,6 +712,7 @@ def compile_for_current_device(
699
712
  abi="c",
700
713
  abi_info=None,
701
714
  output="ptx",
715
+ forceinline=False,
702
716
  ):
703
717
  """Compile a Python function to PTX or LTO-IR for a given signature for the
704
718
  current device's compute capabilility. This calls :func:`compile` with an
@@ -716,6 +730,7 @@ def compile_for_current_device(
716
730
  abi=abi,
717
731
  abi_info=abi_info,
718
732
  output=output,
733
+ forceinline=forceinline,
719
734
  )
720
735
 
721
736
 
@@ -730,6 +745,7 @@ def compile_ptx(
730
745
  opt=None,
731
746
  abi="numba",
732
747
  abi_info=None,
748
+ forceinline=False,
733
749
  ):
734
750
  """Compile a Python function to PTX for a given signature. See
735
751
  :func:`compile`. The defaults for this function are to compile a kernel
@@ -747,6 +763,7 @@ def compile_ptx(
747
763
  abi=abi,
748
764
  abi_info=abi_info,
749
765
  output="ptx",
766
+ forceinline=forceinline,
750
767
  )
751
768
 
752
769
 
@@ -760,6 +777,7 @@ def compile_ptx_for_current_device(
760
777
  opt=None,
761
778
  abi="numba",
762
779
  abi_info=None,
780
+ forceinline=False,
763
781
  ):
764
782
  """Compile a Python function to PTX for a given signature for the current
765
783
  device's compute capabilility. See :func:`compile_ptx`."""
@@ -775,36 +793,42 @@ def compile_ptx_for_current_device(
775
793
  opt=opt,
776
794
  abi=abi,
777
795
  abi_info=abi_info,
796
+ forceinline=forceinline,
778
797
  )
779
798
 
780
799
 
781
800
def declare_device_function(name, restype, argtypes, link):
    """Declare an external device function callable from Python kernels.

    Registers typing and lowering on the CUDA target for a function with
    symbol ``name``, return type ``restype`` and argument types ``argtypes``.
    Its implementation comes from the items in ``link`` (file paths or
    ``LinkableCode`` objects), which are attached to the lowering via an
    :class:`ExternalCodeLibrary`.

    :return: The concrete typing template created for the function.
    """
    from .descriptor import cuda_target

    typingctx = cuda_target.typing_context
    targetctx = cuda_target.target_context
    sig = typing.signature(restype, *argtypes)

    # extfn is the descriptor used to call the function from Python code, and
    # is used as the key for typing and lowering.
    extfn = ExternFunction(name, sig)

    # Typing
    device_function_template = typing.make_concrete_template(name, extfn, [sig])
    typingctx.insert_user_function(extfn, device_function_template)

    # Lowering
    # ExternalCodeLibrary.__init__ takes (codegen, name). Passing the name
    # first would store the codegen as the library's name and vice versa.
    lib = ExternalCodeLibrary(targetctx.codegen(), f"{name}_externals")
    for linkable in link:
        lib.add_linking_file(linkable)

    # ExternalFunctionDescriptor provides a lowering implementation for calling
    # external functions
    fndesc = funcdesc.ExternalFunctionDescriptor(name, restype, argtypes)
    targetctx.insert_user_function(extfn, fndesc, libs=(lib,))

    return device_function_template
804
826
 
805
827
 
806
828
class ExternFunction:
    """Descriptor through which a Python kernel calls an external function.

    Instances carry the function's symbol ``name`` and its typing signature
    ``sig``.
    """

    def __init__(self, name, sig):
        self.name, self.sig = name, sig
numba_cuda/numba/cuda/cudadecl.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import operator
2
- from numba.core import types
2
+ from numba.core import errors, types
3
3
  from numba.core.typing.npydecl import (
4
4
  parse_dtype,
5
5
  parse_shape,
@@ -21,7 +21,7 @@ from numba.core.typing.templates import (
21
21
  from numba.cuda.types import dim3
22
22
  from numba.core.typeconv import Conversion
23
23
  from numba import cuda
24
- from numba.cuda.compiler import declare_device_function_template
24
+ from numba.cuda.compiler import declare_device_function
25
25
 
26
26
  registry = Registry()
27
27
  register = registry.register
@@ -33,7 +33,7 @@ register_number_classes(register_global)
33
33
 
34
34
  class Cuda_array_decl(CallableTemplate):
35
35
  def generic(self):
36
- def typer(shape, dtype):
36
+ def typer(shape, dtype, alignment=None):
37
37
  # Only integer literals and tuples of integer literals are valid
38
38
  # shapes
39
39
  if isinstance(shape, types.Integer):
@@ -47,6 +47,16 @@ class Cuda_array_decl(CallableTemplate):
47
47
  else:
48
48
  return None
49
49
 
50
+ if alignment is not None:
51
+ permitted = (types.IntegerLiteral, types.NoneType)
52
+ if not isinstance(alignment, permitted):
53
+ msg = "alignment must be a constant integer"
54
+ raise errors.RequireLiteralValue(msg)
55
+
56
+ # N.B. We don't use alignment for typing; it's not part of
57
+ # types.Array. The value supplied to the array declaration
58
+ # is handled in the lowering.
59
+
50
60
  ndim = parse_shape(shape)
51
61
  nb_dtype = parse_dtype(dtype)
52
62
  if nb_dtype is not None and ndim is not None:
@@ -412,7 +422,7 @@ _genfp16_binary_operator(operator.itruediv)
412
422
 
413
423
def _resolve_wrapped_unary(fname):
    """Return a ``types.Function`` for the float16 unary wrapper of ``fname``.

    The wrapper is declared as the device function
    ``__numba_wrapper_<fname>``, taking one ``float16`` and returning
    ``float16``, with an empty link list.
    """
    template = declare_device_function(
        f"__numba_wrapper_{fname}", types.float16, (types.float16,), ()
    )
    return types.Function(template)
@@ -420,7 +430,7 @@ def _resolve_wrapped_unary(fname):
420
430
 
421
431
  def _resolve_wrapped_binary(fname):
422
432
  link = tuple()
423
- decl = declare_device_function_template(
433
+ decl = declare_device_function(
424
434
  f"__numba_wrapper_{fname}",
425
435
  types.float16,
426
436
  (
numba_cuda/numba/cuda/cudadrv/driver.py CHANGED
@@ -49,7 +49,7 @@ from .drvapi import API_PROTOTYPES
49
49
  from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
50
50
  from .mappings import FILE_EXTENSION_MAP
51
51
  from .linkable_code import LinkableCode, LTOIR, Fatbin, Object
52
- from numba.cuda.utils import _readenv
52
+ from numba.cuda.utils import _readenv, cached_file_read
53
53
  from numba.cuda.cudadrv import enums, drvapi, nvrtc
54
54
 
55
55
  try:
@@ -2797,13 +2797,16 @@ class Linker(metaclass=ABCMeta):
2797
2797
  ptx_name = os.path.splitext(name)[0] + ".ptx"
2798
2798
  self.add_ptx(ptx.encode(), ptx_name)
2799
2799
 
2800
@abstractmethod
def add_data(self, data, kind, name=None):
    """Add in-memory data to the link.

    :param data: The code to link, held in memory.
    :param kind: The kind of code, compared by implementations against the
        values of ``FILE_EXTENSION_MAP``.
    :param name: Optional name for the data. ``None`` lets the implementation
        pick its default name, matching the concrete linkers' signatures.
    """
2803
+
2800
2804
@abstractmethod
def add_file(self, path, kind):
    """Add the code stored in the file at ``path`` to the link.

    Abstract: each concrete linker implements this for the ``kind`` values
    it supports.
    """
2803
2807
 
2804
2808
def add_cu_file(self, path):
    """Add the CUDA C++ source file at ``path`` to the link.

    The file is read in binary mode via ``cached_file_read`` and registered
    under its base name.
    """
    self.add_cu(cached_file_read(path, how="rb"), os.path.basename(path))
2808
2811
 
2809
2812
  def add_file_guess_ext(self, path_or_code, ignore_nonlto=False):
@@ -2948,6 +2951,10 @@ class MVCLinker(Linker):
2948
2951
  except CubinLinkerError as e:
2949
2952
  raise LinkerError from e
2950
2953
 
2954
def add_data(self, data, kind, name=None):
    """Reject in-memory data: the MVC linker does not support it.

    ``name`` defaults to ``None`` to match the other linkers' ``add_data``.
    A call that omits it now reaches the intended ``LinkerError`` instead of
    failing with a ``TypeError``.

    :raises LinkerError: always.
    """
    msg = "Adding in-memory data unsupported in the MVC linker"
    raise LinkerError(msg)
2957
+
2951
2958
  def add_file(self, path, kind):
2952
2959
  try:
2953
2960
  from cubinlinker import CubinLinkerError
@@ -2955,8 +2962,7 @@ class MVCLinker(Linker):
2955
2962
  raise ImportError(_MVC_ERROR_MESSAGE) from err
2956
2963
 
2957
2964
  try:
2958
- with open(path, "rb") as f:
2959
- data = f.read()
2965
+ data = cached_file_read(path, how="rb")
2960
2966
  except FileNotFoundError:
2961
2967
  raise LinkerError(f"{path} not found")
2962
2968
 
@@ -3046,17 +3052,32 @@ class CtypesLinker(Linker):
3046
3052
  def error_log(self):
3047
3053
  return self.linker_errors_buf.value.decode("utf8")
3048
3054
 
3049
- def add_ptx(self, ptx, name="<cudapy-ptx>"):
3050
- ptxbuf = c_char_p(ptx)
3051
- namebuf = c_char_p(name.encode("utf8"))
3052
- self._keep_alive += [ptxbuf, namebuf]
3055
def add_cubin(self, cubin, name="<unnamed-cubin>"):
    """Add an in-memory cubin image to the link under ``name``."""
    input_type = enums.CU_JIT_INPUT_CUBIN
    return self._add_data(input_type, cubin, name)
3057
+
3058
def add_ptx(self, ptx, name="<unnamed-ptx>"):
    """Add in-memory PTX to the link under ``name``."""
    input_type = enums.CU_JIT_INPUT_PTX
    return self._add_data(input_type, ptx, name)
3060
+
3061
def add_object(self, object_, name="<unnamed-object>"):
    """Add an in-memory object file to the link under ``name``."""
    input_type = enums.CU_JIT_INPUT_OBJECT
    return self._add_data(input_type, object_, name)
3063
+
3064
def add_fatbin(self, fatbin, name="<unnamed-fatbin>"):
    """Add an in-memory fatbinary to the link under ``name``."""
    input_type = enums.CU_JIT_INPUT_FATBINARY
    return self._add_data(input_type, fatbin, name)
3066
+
3067
def add_library(self, library, name="<unnamed-library>"):
    """Add an in-memory library archive to the link under ``name``."""
    input_type = enums.CU_JIT_INPUT_LIBRARY
    return self._add_data(input_type, library, name)
3069
+
3070
+ def _add_data(self, input_type, data, name):
3071
+ data_buffer = c_char_p(data)
3072
+ name_buffer = c_char_p(name.encode("utf8"))
3073
+ self._keep_alive += [data_buffer, name_buffer]
3053
3074
  try:
3054
3075
  driver.cuLinkAddData(
3055
3076
  self.handle,
3056
- enums.CU_JIT_INPUT_PTX,
3057
- ptxbuf,
3058
- len(ptx),
3059
- namebuf,
3077
+ input_type,
3078
+ data_buffer,
3079
+ len(data),
3080
+ name_buffer,
3060
3081
  0,
3061
3082
  None,
3062
3083
  None,
@@ -3064,6 +3085,28 @@ class CtypesLinker(Linker):
3064
3085
  except CudaAPIError as e:
3065
3086
  raise LinkerError("%s\n%s" % (e, self.error_log))
3066
3087
 
3088
def add_data(self, data, kind, name=None):
    """Add in-memory ``data`` of the given ``kind`` to the link.

    ``kind`` is compared, in a fixed order, against ``FILE_EXTENSION_MAP``
    entries to select the matching ``add_*`` method.

    :raises LinkerError: for LTO-IR, which this linker cannot link, or for
        an unrecognized ``kind``.
    """
    # Forward the name only when one was supplied, so that each add_* method
    # falls back to its own default name for the input type.
    kws = {} if name is None else {"name": name}

    handlers = (
        ("cubin", self.add_cubin),
        ("fatbin", self.add_fatbin),
        ("a", self.add_library),
        ("ptx", self.add_ptx),
        ("o", self.add_object),
    )
    for extension, handler in handlers:
        if kind == FILE_EXTENSION_MAP[extension]:
            handler(data, **kws)
            return

    if kind == FILE_EXTENSION_MAP["ltoir"]:
        raise LinkerError("Ctypes linker cannot link LTO-IR")
    raise LinkerError(f"Don't know how to link {kind}")
3109
+
3067
3110
  def add_file(self, path, kind):
3068
3111
  pathbuf = c_char_p(path.encode("utf8"))
3069
3112
  self._keep_alive.append(pathbuf)
@@ -3151,17 +3194,58 @@ class CudaPythonLinker(Linker):
3151
3194
  def error_log(self):
3152
3195
  return self.linker_errors_buf.decode("utf8")
3153
3196
 
3154
- def add_ptx(self, ptx, name="<cudapy-ptx>"):
3155
- namebuf = name.encode("utf8")
3156
- self._keep_alive += [ptx, namebuf]
3197
def add_cubin(self, cubin, name="<unnamed-cubin>"):
    """Add an in-memory cubin image to the link under ``name``."""
    return self._add_data(
        binding.CUjitInputType.CU_JIT_INPUT_CUBIN, cubin, name
    )
3200
+
3201
def add_ptx(self, ptx, name="<unnamed-ptx>"):
    """Add in-memory PTX to the link under ``name``."""
    return self._add_data(binding.CUjitInputType.CU_JIT_INPUT_PTX, ptx, name)
3204
+
3205
def add_object(self, object_, name="<unnamed-object>"):
    """Add an in-memory object file to the link under ``name``."""
    return self._add_data(
        binding.CUjitInputType.CU_JIT_INPUT_OBJECT, object_, name
    )
3208
+
3209
def add_fatbin(self, fatbin, name="<unnamed-fatbin>"):
    """Add an in-memory fatbinary to the link under ``name``."""
    return self._add_data(
        binding.CUjitInputType.CU_JIT_INPUT_FATBINARY, fatbin, name
    )
3212
+
3213
def add_library(self, library, name="<unnamed-library>"):
    """Add an in-memory library archive to the link under ``name``."""
    return self._add_data(
        binding.CUjitInputType.CU_JIT_INPUT_LIBRARY, library, name
    )
3216
+
3217
def _add_data(self, input_type, data, name):
    """Pass ``data`` of ``input_type`` to ``cuLinkAddData`` under ``name``.

    The data and the UTF-8 encoded name are appended to ``self._keep_alive``.
    Presumably this keeps them referenced for the linker's lifetime; confirm
    against how ``_keep_alive`` is used elsewhere.

    :raises LinkerError: if the driver call raises ``CudaAPIError``; the
        message includes the linker's error log.
    """
    encoded_name = name.encode("utf8")
    self._keep_alive += [data, encoded_name]
    try:
        driver.cuLinkAddData(
            self.handle,
            input_type,
            data,
            len(data),
            encoded_name,
            0,
            [],
            [],
        )
    except CudaAPIError as err:
        raise LinkerError("%s\n%s" % (err, self.error_log))
3164
3226
 
3227
def add_data(self, data, kind, name=None):
    """Add in-memory ``data`` of the given ``kind`` to the link.

    ``kind`` is compared, in a fixed order, against ``FILE_EXTENSION_MAP``
    entries to select the matching ``add_*`` method.

    :raises LinkerError: for LTO-IR, which this linker cannot link, or for
        an unrecognized ``kind``.
    """
    # Forward the name only when one was supplied, so that each add_* method
    # falls back to its own default name for the input type.
    forwarded = {} if name is None else {"name": name}

    dispatch_order = (
        ("cubin", self.add_cubin),
        ("fatbin", self.add_fatbin),
        ("a", self.add_library),
        ("ptx", self.add_ptx),
        ("o", self.add_object),
    )
    for ext_key, add_method in dispatch_order:
        if kind == FILE_EXTENSION_MAP[ext_key]:
            add_method(data, **forwarded)
            return

    if kind == FILE_EXTENSION_MAP["ltoir"]:
        raise LinkerError("CudaPythonLinker cannot link LTO-IR")
    raise LinkerError(f"Don't know how to link {kind}")
3248
+
3165
3249
  def add_file(self, path, kind):
3166
3250
  pathbuf = path.encode("utf8")
3167
3251
  self._keep_alive.append(pathbuf)
@@ -3252,8 +3336,7 @@ class PyNvJitLinker(Linker):
3252
3336
 
3253
3337
  def add_file(self, path, kind):
3254
3338
  try:
3255
- with open(path, "rb") as f:
3256
- data = f.read()
3339
+ data = cached_file_read(path, "rb")
3257
3340
  except FileNotFoundError:
3258
3341
  raise LinkerError(f"{path} not found")
3259
3342
 
numba_cuda/numba/cuda/cudadrv/linkable_code.py CHANGED
@@ -16,16 +16,24 @@ class LinkableCode:
16
16
  :param teardown_callback: A function called just prior to the unloading of
17
17
  a module that has this code object linked into
18
18
  it.
19
+ :param nrt: If True, assume this object contains NRT function calls and
20
+ add NRT source code to the final link.
19
21
  """
20
22
 
21
23
  def __init__(
22
- self, data, name=None, setup_callback=None, teardown_callback=None
24
+ self,
25
+ data,
26
+ name=None,
27
+ setup_callback=None,
28
+ teardown_callback=None,
29
+ nrt=False,
23
30
  ):
24
31
  if setup_callback and not callable(setup_callback):
25
32
  raise TypeError("setup_callback must be callable")
26
33
  if teardown_callback and not callable(teardown_callback):
27
34
  raise TypeError("teardown_callback must be callable")
28
35
 
36
+ self.nrt = nrt
29
37
  self._name = name
30
38
  self._data = data
31
39
  self.setup_callback = setup_callback
@@ -87,5 +95,5 @@ class Object(LinkableCode):
87
95
class LTOIR(LinkableCode):
    """An LTO-IR object held in memory."""

    default_name = "<unnamed-ltoir>"
    # Link kind, taken from the shared extension map so it equals the value
    # the linkers' add_data implementations compare against.
    kind = FILE_EXTENSION_MAP["ltoir"]