numba-cuda 0.21.1__cp313-cp313-win_amd64.whl → 0.24.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +4 -1
  3. numba_cuda/numba/cuda/_compat.py +47 -0
  4. numba_cuda/numba/cuda/api.py +4 -1
  5. numba_cuda/numba/cuda/cext/_dispatcher.cp313-win_amd64.pyd +0 -0
  6. numba_cuda/numba/cuda/cext/_dispatcher.cpp +8 -40
  7. numba_cuda/numba/cuda/cext/_hashtable.cpp +5 -0
  8. numba_cuda/numba/cuda/cext/_helperlib.cp313-win_amd64.pyd +0 -0
  9. numba_cuda/numba/cuda/cext/_pymodule.h +1 -1
  10. numba_cuda/numba/cuda/cext/_typeconv.cp313-win_amd64.pyd +0 -0
  11. numba_cuda/numba/cuda/cext/_typeof.cpp +56 -119
  12. numba_cuda/numba/cuda/cext/mviewbuf.c +7 -1
  13. numba_cuda/numba/cuda/cext/mviewbuf.cp313-win_amd64.pyd +0 -0
  14. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +4 -5
  15. numba_cuda/numba/cuda/codegen.py +46 -12
  16. numba_cuda/numba/cuda/compiler.py +15 -9
  17. numba_cuda/numba/cuda/core/analysis.py +29 -21
  18. numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +1 -1
  19. numba_cuda/numba/cuda/core/annotations/type_annotations.py +4 -4
  20. numba_cuda/numba/cuda/core/base.py +12 -11
  21. numba_cuda/numba/cuda/core/bytecode.py +21 -13
  22. numba_cuda/numba/cuda/core/byteflow.py +336 -90
  23. numba_cuda/numba/cuda/core/compiler.py +3 -4
  24. numba_cuda/numba/cuda/core/compiler_machinery.py +3 -3
  25. numba_cuda/numba/cuda/core/config.py +5 -7
  26. numba_cuda/numba/cuda/core/consts.py +1 -1
  27. numba_cuda/numba/cuda/core/controlflow.py +17 -9
  28. numba_cuda/numba/cuda/core/cuda_errors.py +917 -0
  29. numba_cuda/numba/cuda/core/errors.py +4 -912
  30. numba_cuda/numba/cuda/core/inline_closurecall.py +82 -67
  31. numba_cuda/numba/cuda/core/interpreter.py +334 -160
  32. numba_cuda/numba/cuda/core/ir.py +191 -119
  33. numba_cuda/numba/cuda/core/ir_utils.py +149 -128
  34. numba_cuda/numba/cuda/core/postproc.py +8 -8
  35. numba_cuda/numba/cuda/core/pythonapi.py +3 -0
  36. numba_cuda/numba/cuda/core/rewrites/ir_print.py +6 -3
  37. numba_cuda/numba/cuda/core/rewrites/static_binop.py +1 -1
  38. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +5 -5
  39. numba_cuda/numba/cuda/core/rewrites/static_raise.py +3 -3
  40. numba_cuda/numba/cuda/core/ssa.py +5 -5
  41. numba_cuda/numba/cuda/core/transforms.py +29 -16
  42. numba_cuda/numba/cuda/core/typed_passes.py +10 -10
  43. numba_cuda/numba/cuda/core/typeinfer.py +42 -27
  44. numba_cuda/numba/cuda/core/untyped_passes.py +82 -65
  45. numba_cuda/numba/cuda/cpython/unicode.py +2 -2
  46. numba_cuda/numba/cuda/cpython/unicode_support.py +1 -3
  47. numba_cuda/numba/cuda/cudadecl.py +0 -13
  48. numba_cuda/numba/cuda/cudadrv/devicearray.py +10 -9
  49. numba_cuda/numba/cuda/cudadrv/driver.py +142 -519
  50. numba_cuda/numba/cuda/cudadrv/dummyarray.py +4 -0
  51. numba_cuda/numba/cuda/cudadrv/nvrtc.py +87 -32
  52. numba_cuda/numba/cuda/cudaimpl.py +0 -12
  53. numba_cuda/numba/cuda/debuginfo.py +25 -0
  54. numba_cuda/numba/cuda/descriptor.py +1 -1
  55. numba_cuda/numba/cuda/device_init.py +4 -7
  56. numba_cuda/numba/cuda/deviceufunc.py +3 -6
  57. numba_cuda/numba/cuda/dispatcher.py +39 -49
  58. numba_cuda/numba/cuda/intrinsics.py +150 -1
  59. numba_cuda/numba/cuda/libdeviceimpl.py +1 -2
  60. numba_cuda/numba/cuda/lowering.py +36 -29
  61. numba_cuda/numba/cuda/memory_management/nrt.py +10 -14
  62. numba_cuda/numba/cuda/np/arrayobj.py +61 -9
  63. numba_cuda/numba/cuda/np/numpy_support.py +32 -9
  64. numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +4 -3
  65. numba_cuda/numba/cuda/printimpl.py +20 -0
  66. numba_cuda/numba/cuda/serialize.py +10 -0
  67. numba_cuda/numba/cuda/stubs.py +0 -11
  68. numba_cuda/numba/cuda/testing.py +4 -8
  69. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +21 -4
  70. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +1 -2
  71. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +195 -51
  72. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +6 -2
  73. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +3 -1
  74. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  75. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +6 -7
  76. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +11 -12
  77. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +53 -23
  78. numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +61 -9
  79. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +6 -0
  80. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +47 -0
  81. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +22 -1
  82. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +13 -0
  83. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  85. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +94 -0
  86. numba_cuda/numba/cuda/tests/cudapy/test_device_array_capture.py +243 -0
  87. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +3 -3
  88. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
  89. numba_cuda/numba/cuda/tests/cudapy/test_numba_interop.py +35 -0
  90. numba_cuda/numba/cuda/tests/cudapy/test_print.py +51 -0
  91. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +37 -35
  92. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +117 -1
  93. numba_cuda/numba/cuda/tests/doc_examples/test_globals.py +111 -0
  94. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +61 -0
  95. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +31 -0
  96. numba_cuda/numba/cuda/tests/support.py +11 -0
  97. numba_cuda/numba/cuda/types/cuda_functions.py +1 -1
  98. numba_cuda/numba/cuda/typing/asnumbatype.py +37 -2
  99. numba_cuda/numba/cuda/typing/context.py +3 -1
  100. numba_cuda/numba/cuda/typing/typeof.py +51 -2
  101. {numba_cuda-0.21.1.dist-info → numba_cuda-0.24.0.dist-info}/METADATA +4 -13
  102. {numba_cuda-0.21.1.dist-info → numba_cuda-0.24.0.dist-info}/RECORD +106 -105
  103. numba_cuda/numba/cuda/cext/_devicearray.cp313-win_amd64.pyd +0 -0
  104. numba_cuda/numba/cuda/cext/_devicearray.cpp +0 -159
  105. numba_cuda/numba/cuda/cext/_devicearray.h +0 -29
  106. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -41
  107. {numba_cuda-0.21.1.dist-info → numba_cuda-0.24.0.dist-info}/WHEEL +0 -0
  108. {numba_cuda-0.21.1.dist-info → numba_cuda-0.24.0.dist-info}/licenses/LICENSE +0 -0
  109. {numba_cuda-0.21.1.dist-info → numba_cuda-0.24.0.dist-info}/licenses/LICENSE.numba +0 -0
  110. {numba_cuda-0.21.1.dist-info → numba_cuda-0.24.0.dist-info}/top_level.txt +0 -0
numba_cuda/numba/cuda/cudadrv/dummyarray.py
@@ -279,6 +279,10 @@ class Array(object):
         if not self.dims:
             return {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}
 
+        # All 0-size arrays are considered contiguous, even if they are multidimensional
+        if self.size == 0:
+            return {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}
+
         # If this is a broadcast array then it is not contiguous
         if any([dim.stride == 0 for dim in self.dims]):
             return {"C_CONTIGUOUS": False, "F_CONTIGUOUS": False}
numba_cuda/numba/cuda/cudadrv/nvrtc.py
@@ -12,7 +12,7 @@ import os
 import warnings
 import functools
 
-from cuda.core.experimental import Program, ProgramOptions
+from numba.cuda._compat import Program, ProgramOptions
 from cuda.bindings import nvrtc as bindings_nvrtc
 
 NVRTC_EXTRA_SEARCH_PATHS = _readenv(
@@ -30,6 +30,44 @@ def _get_nvrtc_version():
     return (major, minor)
 
 
+def _verify_cc_tuple(cc):
+    version = _get_nvrtc_version()
+    ver_str = lambda version: ".".join(str(v) for v in version)
+
+    if len(cc) == 3:
+        cc, arch = (cc[0], cc[1]), cc[2]
+    else:
+        arch = ""
+
+    if arch not in ("", "a", "f"):
+        raise ValueError(
+            f"Invalid architecture suffix '{arch}' in compute capability "
+            f"{ver_str(cc)}{arch}. Expected '', 'a', or 'f'."
+        )
+
+    supported_ccs = get_supported_ccs()
+    try:
+        found = max(filter(lambda v: v <= cc, [v for v in supported_ccs]))
+    except ValueError:
+        raise RuntimeError(
+            f"Device compute capability {ver_str(cc)} is less than the "
+            f"minimum supported by NVRTC {ver_str(version)}. Supported "
+            "compute capabilities are "
+            f"{', '.join([ver_str(v) for v in supported_ccs])}."
+        )
+
+    if found != cc:
+        found = (found[0], found[1], arch)
+        warnings.warn(
+            f"Device compute capability {ver_str(cc)} is not supported by "
+            f"NVRTC {ver_str(version)}. Using {ver_str(found)} instead."
+        )
+    else:
+        found = (cc[0], cc[1], arch)
+
+    return found
+
+
 def compile(src, name, cc, ltoir=False, lineinfo=False, debug=False):
     """
     Compile a CUDA C/C++ source to PTX or LTOIR for a given compute capability.
@@ -38,7 +76,8 @@ def compile(src, name, cc, ltoir=False, lineinfo=False, debug=False):
     :type src: str
     :param name: The filename of the source (for information only)
     :type name: str
-    :param cc: A tuple ``(major, minor)`` of the compute capability
+    :param cc: A tuple ``(major, minor)`` or ``(major, minor, arch)`` of the
+        compute capability
     :type cc: tuple
     :param ltoir: Compile into LTOIR if True, otherwise into PTX
     :type ltoir: bool
@@ -49,34 +88,18 @@ def compile(src, name, cc, ltoir=False, lineinfo=False, debug=False):
     :return: The compiled PTX or LTOIR and compilation log
     :rtype: tuple
     """
+    found = _verify_cc_tuple(cc)
     version = _get_nvrtc_version()
 
-    ver_str = lambda version: ".".join(str(v) for v in version)
-    supported_ccs = get_supported_ccs()
-    try:
-        found = max(filter(lambda v: v <= cc, [v for v in supported_ccs]))
-    except ValueError:
-        raise RuntimeError(
-            f"Device compute capability {ver_str(cc)} is less than the "
-            f"minimum supported by NVRTC {ver_str(version)}. Supported "
-            "compute capabilities are "
-            f"{', '.join([ver_str(v) for v in supported_ccs])}."
-        )
-
-    if found != cc:
-        warnings.warn(
-            f"Device compute capability {ver_str(cc)} is not supported by "
-            f"NVRTC {ver_str(version)}. Using {ver_str(found)} instead."
-        )
-
     # Compilation options:
     # - Compile for the current device's compute capability.
     # - The CUDA include path is added.
     # - Relocatable Device Code (rdc) is needed to prevent device functions
     #   being optimized away.
-    major, minor = found
+    major, minor = found[0], found[1]
+    cc_arch = found[2] if len(found) == 3 else ""
 
-    arch = f"sm_{major}{minor}"
+    arch = f"sm_{major}{minor}{cc_arch}"
 
     cuda_include_dir = get_cuda_paths()["include_dir"].info
     cuda_includes = [f"{cuda_include_dir}"]
@@ -109,6 +132,22 @@ def compile(src, name, cc, ltoir=False, lineinfo=False, debug=False):
 
     includes = [numba_include, *cuda_includes, nrt_include, *extra_includes]
 
+    # TODO: move all this into Program/ProgramOptions
+    # logsz = config.CUDA_LOG_SIZE
+    #
+    # jitinfo = bytearray(logsz)
+    # jiterrors = bytearray(logsz)
+    #
+    # jit_option = binding.CUjit_option
+    # options = {
+    #     jit_option.CU_JIT_INFO_LOG_BUFFER: jitinfo,
+    #     jit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES: logsz,
+    #     jit_option.CU_JIT_ERROR_LOG_BUFFER: jiterrors,
+    #     jit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES: logsz,
+    #     jit_option.CU_JIT_LOG_VERBOSE: config.CUDA_VERBOSE_JIT_LOG,
+    # }
+    # info_log = jitinfo.decode("utf-8")
+
     options = ProgramOptions(
         arch=arch,
         include_path=includes,
@@ -140,7 +179,7 @@ def compile(src, name, cc, ltoir=False, lineinfo=False, debug=False):
     return result, log
 
 
-def find_closest_arch(mycc):
+def find_closest_arch(cc):
     """
     Given a compute capability, return the closest compute capability supported
     by the CUDA toolkit.
@@ -150,17 +189,17 @@
     """
     supported_ccs = get_supported_ccs()
 
-    for i, cc in enumerate(supported_ccs):
-        if cc == mycc:
+    for i, supported_cc in enumerate(supported_ccs):
+        if supported_cc == cc:
             # Matches
-            return cc
-        elif cc > mycc:
+            return supported_cc
+        elif supported_cc > cc:
             # Exceeded
             if i == 0:
                 # CC lower than supported
                 msg = (
                     "GPU compute capability %d.%d is not supported"
-                    "(requires >=%d.%d)" % (mycc + cc)
+                    "(requires >=%d.%d)" % (cc + supported_cc)
                 )
                 raise CCSupportError(msg)
             else:
@@ -171,13 +210,29 @@
     return supported_ccs[-1]  # Choose the highest
 
 
-def get_arch_option(major, minor):
+def get_arch_option(major, minor, arch=""):
     """Matches with the closest architecture option"""
     if config.FORCE_CUDA_CC:
-        arch = config.FORCE_CUDA_CC
+        fcc = config.FORCE_CUDA_CC
+        major, minor = fcc[0], fcc[1]
+        if len(fcc) == 3:
+            arch = fcc[2]
+        else:
+            arch = ""
     else:
-        arch = find_closest_arch((major, minor))
-    return "compute_%d%d" % arch
+        new_major, new_minor = find_closest_arch((major, minor))
+        if (new_major, new_minor) != (major, minor):
+            # If we picked a different major / minor, then using an
+            # arch-specific version is invalid
+            if arch != "":
+                raise ValueError(
+                    f"Can't use arch-specific compute_{major}{minor}{arch} with "
+                    "closest found compute capability "
+                    f"compute_{new_major}{new_minor}"
+                )
+            major, minor = new_major, new_minor
+
+    return f"compute_{major}{minor}{arch}"
 
 
 def get_lowest_supported_cc():
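Illustrative usage sketch (not part of the diff): the nvrtc.py hunks above let an optional architecture suffix ("a" or "f") flow through compute-capability handling. The source string and file name below are placeholders, and the call assumes an NVRTC new enough for the arch-specific target:

    from numba.cuda.cudadrv import nvrtc

    src = 'extern "C" __device__ int add_one(int* out, int x) { *out = x + 1; return 0; }'

    # A plain (major, minor) tuple still works; (major, minor, "a") or (..., "f")
    # selects an architecture-specific target such as sm_90a.
    ptx, log = nvrtc.compile(src, "add_one.cu", (9, 0, "a"))

    # get_arch_option accepts the optional suffix as well, e.g. "compute_90a".
    print(nvrtc.get_arch_option(9, 0, "a"))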
numba_cuda/numba/cuda/cudaimpl.py
@@ -280,18 +280,6 @@ def ptx_syncwarp_mask(context, builder, sig, args):
     return context.get_dummy_value()
 
 
-@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
-def ptx_vote_sync(context, builder, sig, args):
-    fname = "llvm.nvvm.vote.sync"
-    lmod = builder.module
-    fnty = ir.FunctionType(
-        ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
-        (ir.IntType(32), ir.IntType(32), ir.IntType(1)),
-    )
-    func = cgutils.get_or_insert_function(lmod, fnty, fname)
-    return builder.call(func, args)
-
-
 @lower(stubs.match_any_sync, types.i4, types.i4)
 @lower(stubs.match_any_sync, types.i4, types.i8)
 @lower(stubs.match_any_sync, types.i4, types.f4)
numba_cuda/numba/cuda/debuginfo.py
@@ -4,6 +4,7 @@
 import abc
 import os
 from contextlib import contextmanager
+from enum import IntEnum
 
 import llvmlite
 from llvmlite import ir
@@ -71,6 +72,16 @@ if not hasattr(config, "CUDA_DEBUG_POLY_USE_TYPED_CONST"):
     config.CUDA_DEBUG_POLY_USE_TYPED_CONST = DEBUG_POLY_USE_TYPED_CONST
 
 
+class DwarfAddressClass(IntEnum):
+    GENERIC = 0x00
+    GLOBAL = 0x01
+    REGISTER = 0x02
+    CONSTANT = 0x05
+    LOCAL = 0x06
+    PARAMETER = 0x07
+    SHARED = 0x08
+
+
 @contextmanager
 def suspend_emission(builder):
     """Suspends the emission of debug_metadata for the duration of the context
@@ -179,6 +190,19 @@ class DIBuilder(AbstractDIBuilder):
         # constructing subprograms
         self.dicompileunit = self._di_compile_unit()
 
+    def get_dwarf_address_class(self, addrspace):
+        # Map NVVM address space to DWARF address class.
+        from numba.cuda.cudadrv import nvvm
+
+        addrspace_to_addrclass_dict = {
+            nvvm.ADDRSPACE_GENERIC: None,
+            nvvm.ADDRSPACE_GLOBAL: DwarfAddressClass.GLOBAL,
+            nvvm.ADDRSPACE_SHARED: DwarfAddressClass.SHARED,
+            nvvm.ADDRSPACE_CONSTANT: DwarfAddressClass.CONSTANT,
+            nvvm.ADDRSPACE_LOCAL: DwarfAddressClass.LOCAL,
+        }
+        return addrspace_to_addrclass_dict.get(addrspace)
+
     def _var_type(self, lltype, size, datamodel=None):
         if self._DEBUG:
             print(
@@ -796,6 +820,7 @@ class CUDADIBuilder(DIBuilder):
                 },
                 is_distinct=True,
             )
+
         # For other cases, use upstream Numba implementation
         return super()._var_type(lltype, size, datamodel=datamodel)
 
numba_cuda/numba/cuda/descriptor.py
@@ -28,7 +28,7 @@ class CUDATarget:
     @property
     def target_context(self):
         if self._targetctx is None:
-            self._targetctx = CUDATargetContext(self._typingctx)
+            self._targetctx = CUDATargetContext(self.typing_context)
         return self._targetctx
 
 
numba_cuda/numba/cuda/device_init.py
@@ -27,7 +27,6 @@ from .stubs import (
     local,
     const,
     atomic,
-    vote_sync_intrinsic,
     match_any_sync,
     match_all_sync,
     threadfence_block,
@@ -56,6 +55,10 @@ from .intrinsics import (
     shfl_up_sync,
     shfl_down_sync,
     shfl_xor_sync,
+    all_sync,
+    any_sync,
+    eq_sync,
+    ballot_sync,
 )
 from .cudadrv.error import CudaSupportError
 from numba.cuda.cudadrv.driver import (
@@ -79,12 +82,6 @@ from .api import *
 from .api import _auto_device
 from .args import In, Out, InOut
 
-from .intrinsic_wrapper import (
-    all_sync,
-    any_sync,
-    eq_sync,
-    ballot_sync,
-)
 
 from .kernels import reduction
 from numba.cuda.cudadrv.linkable_code import (
numba_cuda/numba/cuda/deviceufunc.py
@@ -682,12 +682,9 @@ class GUFuncEngine(object):
             inner_shapes.append(inner_shape)
 
         # solve output shape
-        oshapes = []
-        for outsig in self.sout:
-            oshape = []
-            for sym in outsig:
-                oshape.append(symbolmap[sym])
-            oshapes.append(tuple(oshape))
+        oshapes = [
+            tuple(map(symbolmap.__getitem__, outsig)) for outsig in self.sout
+        ]
 
         # find the biggest outershape as looping dimension
         sizes = [reduce(operator.mul, s, 1) for s in outer_shapes]
numba_cuda/numba/cuda/dispatcher.py
@@ -15,6 +15,8 @@ import uuid
 import re
 from warnings import warn
 
+from numba.cuda._compat import launch, LaunchConfig
+
 from numba.cuda.core import errors
 from numba.cuda import serialize, utils
 from numba import cuda
@@ -39,6 +41,7 @@ from numba.cuda.compiler import (
 from numba.cuda.core import sigutils, config, entrypoints
 from numba.cuda.flags import Flags
 from numba.cuda.cudadrv import driver, nvvm
+
 from numba.cuda.locks import module_init_lock
 from numba.cuda.core.caching import Cache, CacheImpl, NullCache
 from numba.cuda.descriptor import cuda_target
@@ -475,18 +478,15 @@ class _Kernel(serialize.ReduceMixin):
         for t, v in zip(self.argument_types, args):
             self._prepare_args(t, v, stream, retr, kernelargs)
 
-        stream_handle = driver._stream_handle(stream)
-
         # Invoke kernel
-        driver.launch_kernel(
-            cufunc.handle,
-            *griddim,
-            *blockdim,
-            sharedmem,
-            stream_handle,
-            kernelargs,
-            cooperative=self.cooperative,
+        config = LaunchConfig(
+            grid=griddim,
+            block=blockdim,
+            shmem_size=sharedmem,
+            cooperative_launch=self.cooperative,
         )
+        kernel = cufunc.kernel
+        launch(stream, config, kernel, *kernelargs)
 
         if self.debug:
             driver.device_to_host(ctypes.addressof(excval), excmem, excsz)
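For orientation, a rough standalone sketch (hypothetical helper, not part of the diff) of the cuda.core-style launch path the kernel invocation above now follows, assuming numba.cuda._compat re-exports LaunchConfig and launch from cuda.core:

    from numba.cuda._compat import launch, LaunchConfig

    def launch_like_kernel(stream, kernel, griddim, blockdim, sharedmem,
                           kernelargs, cooperative=False):
        # Grid/block dimensions, dynamic shared memory and the cooperative-launch
        # flag all travel in a single LaunchConfig object ...
        config = LaunchConfig(
            grid=griddim,
            block=blockdim,
            shmem_size=sharedmem,
            cooperative_launch=cooperative,
        )
        # ... and the flattened kernel arguments are passed positionally to launch().
        launch(stream, config, kernel, *kernelargs)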
@@ -540,30 +540,26 @@
 
         if isinstance(ty, types.Array):
             devary = wrap_arg(val).to_device(retr, stream)
-            c_intp = ctypes.c_ssize_t
 
-            meminfo = ctypes.c_void_p(0)
-            parent = ctypes.c_void_p(0)
-            nitems = c_intp(devary.size)
-            itemsize = c_intp(devary.dtype.itemsize)
-
-            ptr = driver.device_pointer(devary)
-
-            ptr = int(ptr)
-
-            data = ctypes.c_void_p(ptr)
+            meminfo = 0
+            parent = 0
 
             kernelargs.append(meminfo)
             kernelargs.append(parent)
-            kernelargs.append(nitems)
-            kernelargs.append(itemsize)
-            kernelargs.append(data)
-            kernelargs.extend(map(c_intp, devary.shape))
-            kernelargs.extend(map(c_intp, devary.strides))
+
+            # non-pointer-arguments-without-ctypes might be dicey, since we're
+            # assuming shape, strides, size, and itemsize fit into intptr_t
+            # however, this saves a noticeable amount of overhead in kernel
+            # invocation
+            kernelargs.append(devary.size)
+            kernelargs.append(devary.dtype.itemsize)
+            kernelargs.append(devary.device_ctypes_pointer.value)
+            kernelargs.extend(devary.shape)
+            kernelargs.extend(devary.strides)
 
         elif isinstance(ty, types.CPointer):
             # Pointer arguments should be a pointer-sized integer
-            kernelargs.append(ctypes.c_uint64(val))
+            kernelargs.append(val)
 
         elif isinstance(ty, types.Integer):
             cval = getattr(ctypes, "c_%s" % ty)(val)
@@ -582,8 +578,7 @@
             kernelargs.append(cval)
 
         elif ty == types.boolean:
-            cval = ctypes.c_uint8(int(val))
-            kernelargs.append(cval)
+            kernelargs.append(val)
 
         elif ty == types.complex64:
             kernelargs.append(ctypes.c_float(val.real))
@@ -598,8 +593,7 @@
 
         elif isinstance(ty, types.Record):
             devrec = wrap_arg(val).to_device(retr, stream)
-            ptr = devrec.device_ctypes_pointer
-            kernelargs.append(ptr)
+            kernelargs.append(devrec.device_ctypes_pointer.value)
 
         elif isinstance(ty, types.BaseTuple):
             assert len(ty) == len(val)
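Illustrative sketch (hypothetical helper, not the real _prepare_args) of the flattened argument list the hunks above now build for a device-array argument: plain Python integers instead of ctypes wrappers, in the same member order as before.

    def flatten_array_arg(devary):
        args = [
            0,                                   # meminfo (NULL)
            0,                                   # parent (NULL)
            devary.size,                         # nitems
            devary.dtype.itemsize,               # itemsize
            devary.device_ctypes_pointer.value,  # data pointer as an int
        ]
        args.extend(devary.shape)    # one entry per dimension
        args.extend(devary.strides)  # one entry per dimension
        return args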
@@ -671,7 +665,7 @@ class _LaunchConfiguration:
         self.dispatcher = dispatcher
         self.griddim = griddim
         self.blockdim = blockdim
-        self.stream = stream
+        self.stream = driver._to_core_stream(stream)
         self.sharedmem = sharedmem
 
         if (
@@ -700,6 +694,16 @@ class _LaunchConfiguration:
             args, self.griddim, self.blockdim, self.stream, self.sharedmem
         )
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["stream"] = int(state["stream"].handle)
+        return state
+
+    def __setstate__(self, state):
+        handle = state.pop("stream")
+        self.__dict__.update(state)
+        self.stream = driver._to_core_stream(handle)
+
 
 class CUDACacheImpl(CacheImpl):
     def reduce(self, kernel):
@@ -854,7 +858,7 @@ class _DispatcherBase(_dispatcher.Dispatcher):
             for cres in overloads.values():
                 try:
                     targetctx.remove_user_function(cres.entry_point)
-                except KeyError:
+                except KeyError:  # noqa: PERF203
                     pass
 
         return finalizer
1622
1626
  def typeof_pyval(self, val):
1623
1627
  # Based on _DispatcherBase.typeof_pyval, but differs from it to support
1624
1628
  # the CUDA Array Interface.
1625
- try:
1626
- return typeof(val, Purpose.argument)
1627
- except ValueError:
1628
- if (
1629
- interface := getattr(val, "__cuda_array_interface__")
1630
- ) is not None:
1631
- # When typing, we don't need to synchronize on the array's
1632
- # stream - this is done when the kernel is launched.
1633
-
1634
- return typeof(
1635
- cuda.from_cuda_array_interface(interface, sync=False),
1636
- Purpose.argument,
1637
- )
1638
- else:
1639
- raise
1629
+ return typeof(val, Purpose.argument)
1640
1630
 
1641
1631
  def specialize(self, *args):
1642
1632
  """
@@ -2100,7 +2090,7 @@ class CUDADispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
2100
2090
  if file is None:
2101
2091
  file = sys.stdout
2102
2092
 
2103
- for _, defn in self.overloads.items():
2093
+ for defn in self.overloads.values():
2104
2094
  defn.inspect_types(file=file)
2105
2095
 
2106
2096
  @classmethod
@@ -6,7 +6,11 @@ from llvmlite import ir
6
6
  from numba import cuda
7
7
  from numba.cuda import types
8
8
  from numba.cuda import cgutils
9
- from numba.cuda.core.errors import RequireLiteralValue, TypingError
9
+ from numba.cuda.core.errors import (
10
+ RequireLiteralValue,
11
+ TypingError,
12
+ NumbaTypeError,
13
+ )
10
14
  from numba.cuda.typing import signature
11
15
  from numba.cuda.extending import overload_attribute, overload_method
12
16
  from numba.cuda import nvvmutils
@@ -380,3 +384,148 @@ def shfl_sync_intrinsic(
     sig = signature(a_type, membermask_type, a_type, b_type)
 
     return sig, codegen
+
+
+# -------------------------------------------------------------------------------
+# Warp vote functions
+#
+# References:
+#
+# - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-vote-functions
+# - https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html?highlight=data%2520movement#vote
+#
+# Notes:
+#
+# - The NVVM IR specification requires some of the mode parameter to be
+#   constants. It's therefore essential that we pass in mode values to the
+#   vote_sync_intrinsic.
+
+
+@intrinsic
+def all_sync(typingctx, mask_type, predicate_type):
+    """
+    If for all threads in the masked warp the predicate is true, then
+    a non-zero value is returned, otherwise 0 is returned.
+    """
+    mode_value = 0
+    sig, codegen_inner = vote_sync_intrinsic(
+        typingctx, mask_type, mode_value, predicate_type
+    )
+
+    def codegen(context, builder, sig_outer, args):
+        # Call vote_sync_intrinsic and extract the boolean result (index 1)
+        result_tuple = codegen_inner(context, builder, sig, args)
+        return builder.extract_value(result_tuple, 1)
+
+    sig_outer = signature(types.b1, mask_type, predicate_type)
+    return sig_outer, codegen
+
+
+@intrinsic
+def any_sync(typingctx, mask_type, predicate_type):
+    """
+    If for any thread in the masked warp the predicate is true, then
+    a non-zero value is returned, otherwise 0 is returned.
+    """
+    mode_value = 1
+    sig, codegen_inner = vote_sync_intrinsic(
+        typingctx, mask_type, mode_value, predicate_type
+    )
+
+    def codegen(context, builder, sig_outer, args):
+        result_tuple = codegen_inner(context, builder, sig, args)
+        return builder.extract_value(result_tuple, 1)
+
+    sig_outer = signature(types.b1, mask_type, predicate_type)
+    return sig_outer, codegen
+
+
+@intrinsic
+def eq_sync(typingctx, mask_type, predicate_type):
+    """
+    If for all threads in the masked warp the boolean predicate is the same,
+    then a non-zero value is returned, otherwise 0 is returned.
+    """
+    mode_value = 2
+    sig, codegen_inner = vote_sync_intrinsic(
+        typingctx, mask_type, mode_value, predicate_type
+    )
+
+    def codegen(context, builder, sig_outer, args):
+        result_tuple = codegen_inner(context, builder, sig, args)
+        return builder.extract_value(result_tuple, 1)
+
+    sig_outer = signature(types.b1, mask_type, predicate_type)
+    return sig_outer, codegen
+
+
+@intrinsic
+def ballot_sync(typingctx, mask_type, predicate_type):
+    """
+    Returns a mask of all threads in the warp whose predicate is true,
+    and are within the given mask.
+    """
+    mode_value = 3
+    sig, codegen_inner = vote_sync_intrinsic(
+        typingctx, mask_type, mode_value, predicate_type
+    )
+
+    def codegen(context, builder, sig_outer, args):
+        result_tuple = codegen_inner(context, builder, sig, args)
+        return builder.extract_value(
+            result_tuple, 0
+        )  # Extract ballot result (index 0)
+
+    sig_outer = signature(types.i4, mask_type, predicate_type)
+    return sig_outer, codegen
+
+
+def vote_sync_intrinsic(typingctx, mask_type, mode_value, predicate_type):
+    # Validate mode value
+    if mode_value not in (0, 1, 2, 3):
+        raise ValueError("Mode must be 0 (all), 1 (any), 2 (eq), or 3 (ballot)")
+
+    if types.unliteral(mask_type) not in types.integer_domain:
+        raise NumbaTypeError(f"Mask type must be an integer. Got {mask_type}")
+    predicate_types = types.integer_domain | {types.boolean}
+
+    if types.unliteral(predicate_type) not in predicate_types:
+        raise NumbaTypeError(
+            f"Predicate must be an integer or boolean. Got {predicate_type}"
+        )
+
+    def codegen(context, builder, sig, args):
+        mask, predicate = args
+
+        # Types
+        i1 = ir.IntType(1)
+        i32 = ir.IntType(32)
+
+        # NVVM intrinsic definition
+        arg_types = (i32, i32, i1)
+        vote_return_type = ir.LiteralStructType((i32, i1))
+        fnty = ir.FunctionType(vote_return_type, arg_types)
+
+        fname = "llvm.nvvm.vote.sync"
+        lmod = builder.module
+        vote_sync = cgutils.get_or_insert_function(lmod, fnty, fname)
+
+        # Intrinsic arguments
+        mode = ir.Constant(i32, mode_value)
+        mask_i32 = builder.trunc(mask, i32)
+
+        # Convert predicate to i1
+        if predicate.type != ir.IntType(1):
+            predicate_bool = builder.icmp_signed(
+                "!=", predicate, ir.Constant(predicate.type, 0)
+            )
+        else:
+            predicate_bool = predicate
+
+        return builder.call(vote_sync, [mask_i32, mode, predicate_bool])
+
+    sig = signature(
+        types.Tuple((types.i4, types.b1)), mask_type, predicate_type
+    )
+
+    return sig, codegen
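Illustrative device-side usage (not part of the diff): the vote intrinsics remain exposed as cuda.all_sync, cuda.any_sync, cuda.eq_sync and cuda.ballot_sync, as the device_init.py hunk above shows. A minimal sketch, assuming a single full warp:

    import numpy as np
    from numba import cuda

    @cuda.jit
    def vote_demo(out):
        i = cuda.grid(1)
        full_mask = 0xFFFFFFFF
        pred = (i % 2) == 0
        out[0] = cuda.all_sync(full_mask, pred)     # False: not every lane agrees
        out[1] = cuda.any_sync(full_mask, pred)     # True: at least one lane agrees
        out[2] = cuda.ballot_sync(full_mask, pred)  # bitmask of the agreeing lanes

    out = cuda.to_device(np.zeros(3, dtype=np.uint32))
    vote_demo[1, 32](out)
    print(out.copy_to_host())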
numba_cuda/numba/cuda/libdeviceimpl.py
@@ -69,8 +69,7 @@ def libdevice_implement_multiple_returns(func, retty, prototype_args):
         tuple_args = []
         if retty != types.void:
             tuple_args.append(ret)
-        for arg in virtual_args:
-            tuple_args.append(builder.load(arg))
+        tuple_args.extend(map(builder.load, virtual_args))
 
         if isinstance(nb_retty, types.UniTuple):
             return cgutils.pack_array(builder, tuple_args)