numba-cuda 0.21.1__cp313-cp313-win_amd64.whl → 0.23.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/api.py +4 -1
  3. numba_cuda/numba/cuda/cext/_dispatcher.cp313-win_amd64.pyd +0 -0
  4. numba_cuda/numba/cuda/cext/_dispatcher.cpp +0 -38
  5. numba_cuda/numba/cuda/cext/_helperlib.cp313-win_amd64.pyd +0 -0
  6. numba_cuda/numba/cuda/cext/_typeconv.cp313-win_amd64.pyd +0 -0
  7. numba_cuda/numba/cuda/cext/_typeof.cpp +0 -111
  8. numba_cuda/numba/cuda/cext/mviewbuf.cp313-win_amd64.pyd +0 -0
  9. numba_cuda/numba/cuda/codegen.py +42 -10
  10. numba_cuda/numba/cuda/compiler.py +10 -4
  11. numba_cuda/numba/cuda/core/analysis.py +29 -21
  12. numba_cuda/numba/cuda/core/annotations/type_annotations.py +4 -4
  13. numba_cuda/numba/cuda/core/base.py +6 -1
  14. numba_cuda/numba/cuda/core/consts.py +1 -1
  15. numba_cuda/numba/cuda/core/cuda_errors.py +917 -0
  16. numba_cuda/numba/cuda/core/errors.py +4 -912
  17. numba_cuda/numba/cuda/core/inline_closurecall.py +71 -57
  18. numba_cuda/numba/cuda/core/interpreter.py +79 -64
  19. numba_cuda/numba/cuda/core/ir.py +191 -119
  20. numba_cuda/numba/cuda/core/ir_utils.py +142 -112
  21. numba_cuda/numba/cuda/core/postproc.py +8 -8
  22. numba_cuda/numba/cuda/core/rewrites/ir_print.py +6 -3
  23. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +5 -5
  24. numba_cuda/numba/cuda/core/rewrites/static_raise.py +3 -3
  25. numba_cuda/numba/cuda/core/ssa.py +3 -3
  26. numba_cuda/numba/cuda/core/transforms.py +25 -10
  27. numba_cuda/numba/cuda/core/typed_passes.py +9 -9
  28. numba_cuda/numba/cuda/core/typeinfer.py +39 -24
  29. numba_cuda/numba/cuda/core/untyped_passes.py +71 -55
  30. numba_cuda/numba/cuda/cudadecl.py +0 -13
  31. numba_cuda/numba/cuda/cudadrv/devicearray.py +6 -5
  32. numba_cuda/numba/cuda/cudadrv/driver.py +132 -511
  33. numba_cuda/numba/cuda/cudadrv/dummyarray.py +4 -0
  34. numba_cuda/numba/cuda/cudadrv/nvrtc.py +16 -0
  35. numba_cuda/numba/cuda/cudaimpl.py +0 -12
  36. numba_cuda/numba/cuda/debuginfo.py +104 -10
  37. numba_cuda/numba/cuda/descriptor.py +1 -1
  38. numba_cuda/numba/cuda/device_init.py +4 -7
  39. numba_cuda/numba/cuda/dispatcher.py +36 -32
  40. numba_cuda/numba/cuda/intrinsics.py +150 -1
  41. numba_cuda/numba/cuda/lowering.py +64 -29
  42. numba_cuda/numba/cuda/memory_management/nrt.py +10 -14
  43. numba_cuda/numba/cuda/np/arrayobj.py +54 -0
  44. numba_cuda/numba/cuda/np/numpy_support.py +26 -0
  45. numba_cuda/numba/cuda/printimpl.py +20 -0
  46. numba_cuda/numba/cuda/serialize.py +10 -0
  47. numba_cuda/numba/cuda/stubs.py +0 -11
  48. numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +21 -4
  49. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +1 -2
  50. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +130 -48
  51. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +6 -2
  52. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +3 -1
  53. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +5 -6
  54. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +11 -12
  55. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +27 -19
  56. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +47 -0
  57. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +10 -0
  58. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +89 -0
  59. numba_cuda/numba/cuda/tests/cudapy/test_device_array_capture.py +243 -0
  60. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +3 -3
  61. numba_cuda/numba/cuda/tests/cudapy/test_numba_interop.py +35 -0
  62. numba_cuda/numba/cuda/tests/cudapy/test_print.py +51 -0
  63. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +116 -1
  64. numba_cuda/numba/cuda/tests/doc_examples/test_globals.py +111 -0
  65. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +61 -0
  66. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +31 -0
  67. numba_cuda/numba/cuda/typing/context.py +3 -1
  68. numba_cuda/numba/cuda/typing/typeof.py +56 -0
  69. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/METADATA +1 -1
  70. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/RECORD +74 -74
  71. numba_cuda/numba/cuda/cext/_devicearray.cp313-win_amd64.pyd +0 -0
  72. numba_cuda/numba/cuda/cext/_devicearray.cpp +0 -159
  73. numba_cuda/numba/cuda/cext/_devicearray.h +0 -29
  74. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -41
  75. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/WHEEL +0 -0
  76. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/licenses/LICENSE +0 -0
  77. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/licenses/LICENSE.numba +0 -0
  78. {numba_cuda-0.21.1.dist-info → numba_cuda-0.23.0.dist-info}/top_level.txt +0 -0
numba_cuda/numba/cuda/cudadrv/dummyarray.py
@@ -279,6 +279,10 @@ class Array(object):
         if not self.dims:
             return {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}
 
+        # All 0-size arrays are considered contiguous, even if they are multidimensional
+        if self.size == 0:
+            return {"C_CONTIGUOUS": True, "F_CONTIGUOUS": True}
+
         # If this is a broadcast array then it is not contiguous
         if any([dim.stride == 0 for dim in self.dims]):
             return {"C_CONTIGUOUS": False, "F_CONTIGUOUS": False}
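The early return brings the dummy array's flag computation in line with NumPy, which reports every 0-size array as both C- and F-contiguous regardless of dimensionality. A minimal sketch of the behavior being mirrored (assuming a modern NumPy with relaxed strides):

import numpy as np

# NumPy reports 0-size arrays as both C- and F-contiguous, whatever
# their dimensionality; the hunk above mirrors this for device arrays.
a = np.empty((0, 3))
assert a.flags["C_CONTIGUOUS"] and a.flags["F_CONTIGUOUS"]

b = np.empty((4, 0, 2))
assert b.flags["C_CONTIGUOUS"] and b.flags["F_CONTIGUOUS"]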
numba_cuda/numba/cuda/cudadrv/nvrtc.py
@@ -109,6 +109,22 @@ def compile(src, name, cc, ltoir=False, lineinfo=False, debug=False):
 
     includes = [numba_include, *cuda_includes, nrt_include, *extra_includes]
 
+    # TODO: move all this into Program/ProgramOptions
+    # logsz = config.CUDA_LOG_SIZE
+    #
+    # jitinfo = bytearray(logsz)
+    # jiterrors = bytearray(logsz)
+    #
+    # jit_option = binding.CUjit_option
+    # options = {
+    #     jit_option.CU_JIT_INFO_LOG_BUFFER: jitinfo,
+    #     jit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES: logsz,
+    #     jit_option.CU_JIT_ERROR_LOG_BUFFER: jiterrors,
+    #     jit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES: logsz,
+    #     jit_option.CU_JIT_LOG_VERBOSE: config.CUDA_VERBOSE_JIT_LOG,
+    # }
+    # info_log = jitinfo.decode("utf-8")
+
     options = ProgramOptions(
         arch=arch,
         include_path=includes,
numba_cuda/numba/cuda/cudaimpl.py
@@ -280,18 +280,6 @@ def ptx_syncwarp_mask(context, builder, sig, args):
     return context.get_dummy_value()
 
 
-@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
-def ptx_vote_sync(context, builder, sig, args):
-    fname = "llvm.nvvm.vote.sync"
-    lmod = builder.module
-    fnty = ir.FunctionType(
-        ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
-        (ir.IntType(32), ir.IntType(32), ir.IntType(1)),
-    )
-    func = cgutils.get_or_insert_function(lmod, fnty, fname)
-    return builder.call(func, args)
-
-
 @lower(stubs.match_any_sync, types.i4, types.i4)
 @lower(stubs.match_any_sync, types.i4, types.i8)
 @lower(stubs.match_any_sync, types.i4, types.f4)
numba_cuda/numba/cuda/debuginfo.py
@@ -4,6 +4,7 @@
 import abc
 import os
 from contextlib import contextmanager
+from enum import IntEnum
 
 import llvmlite
 from llvmlite import ir
@@ -71,6 +72,16 @@ if not hasattr(config, "CUDA_DEBUG_POLY_USE_TYPED_CONST"):
     config.CUDA_DEBUG_POLY_USE_TYPED_CONST = DEBUG_POLY_USE_TYPED_CONST
 
 
+class DwarfAddressClass(IntEnum):
+    GENERIC = 0x00
+    GLOBAL = 0x01
+    REGISTER = 0x02
+    CONSTANT = 0x05
+    LOCAL = 0x06
+    PARAMETER = 0x07
+    SHARED = 0x08
+
+
 @contextmanager
 def suspend_emission(builder):
     """Suspends the emission of debug_metadata for the duration of the context
@@ -179,6 +190,19 @@ class DIBuilder(AbstractDIBuilder):
         # constructing subprograms
         self.dicompileunit = self._di_compile_unit()
 
+    def get_dwarf_address_class(self, addrspace):
+        # Map NVVM address space to DWARF address class.
+        from numba.cuda.cudadrv import nvvm
+
+        addrspace_to_addrclass_dict = {
+            nvvm.ADDRSPACE_GENERIC: None,
+            nvvm.ADDRSPACE_GLOBAL: DwarfAddressClass.GLOBAL,
+            nvvm.ADDRSPACE_SHARED: DwarfAddressClass.SHARED,
+            nvvm.ADDRSPACE_CONSTANT: DwarfAddressClass.CONSTANT,
+            nvvm.ADDRSPACE_LOCAL: DwarfAddressClass.LOCAL,
+        }
+        return addrspace_to_addrclass_dict.get(addrspace)
+
     def _var_type(self, lltype, size, datamodel=None):
         if self._DEBUG:
             print(
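NVVM address space numbers and DWARF address classes use different numbering, which is why an explicit table is needed. A self-contained sketch of the same mapping (the ADDRSPACE_* values below are assumed from numba.cuda.cudadrv.nvvm; generic pointers deliberately map to None so no dwarfAddressSpace is emitted for them):

from enum import IntEnum

class DwarfAddressClass(IntEnum):  # subset of the enum added above
    GLOBAL = 0x01
    CONSTANT = 0x05
    LOCAL = 0x06
    SHARED = 0x08

# NVVM address space numbers, as defined in numba.cuda.cudadrv.nvvm
ADDRSPACE_GENERIC = 0
ADDRSPACE_GLOBAL = 1
ADDRSPACE_SHARED = 3
ADDRSPACE_CONSTANT = 4
ADDRSPACE_LOCAL = 5

def get_dwarf_address_class(addrspace):
    return {
        ADDRSPACE_GLOBAL: DwarfAddressClass.GLOBAL,
        ADDRSPACE_SHARED: DwarfAddressClass.SHARED,
        ADDRSPACE_CONSTANT: DwarfAddressClass.CONSTANT,
        ADDRSPACE_LOCAL: DwarfAddressClass.LOCAL,
    }.get(addrspace)  # generic (0) and unknown spaces return None

assert get_dwarf_address_class(ADDRSPACE_SHARED) == DwarfAddressClass.SHARED
assert get_dwarf_address_class(ADDRSPACE_GENERIC) is None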
@@ -622,6 +646,11 @@ class CUDADIBuilder(DIBuilder):
         super().__init__(module, filepath, cgctx, directives_only)
         # Cache for local variable metadata type and line deduplication
         self._vartypelinemap = {}
+        # Variable address space dictionary
+        self._var_addrspace_map = {}
+
+    def _set_addrspace_map(self, map):
+        self._var_addrspace_map = map
 
     def _var_type(self, lltype, size, datamodel=None):
         is_bool = False
@@ -796,6 +825,65 @@ class CUDADIBuilder(DIBuilder):
             },
             is_distinct=True,
         )
+
+        # Check if there's actually address space info to handle
+        addrspace = getattr(self, "_addrspace", None)
+        if (
+            isinstance(lltype, ir.LiteralStructType)
+            and datamodel is not None
+            and datamodel.inner_models()
+            and addrspace not in (None, 0)
+        ):
+            # Process struct with datamodel that has address space info
+            meta = []
+            offset = 0
+            for element, field, model in zip(
+                lltype.elements, datamodel._fields, datamodel.inner_models()
+            ):
+                size_field = self.cgctx.get_abi_sizeof(element)
+                if isinstance(element, ir.PointerType) and field == "data":
+                    # Create pointer type with correct address space
+                    pointee_size = self.cgctx.get_abi_sizeof(element.pointee)
+                    pointee_model = getattr(model, "_pointee_model", None)
+                    pointee_type = self._var_type(
+                        element.pointee, pointee_size, datamodel=pointee_model
+                    )
+                    meta_ptr = {
+                        "tag": ir.DIToken("DW_TAG_pointer_type"),
+                        "baseType": pointee_type,
+                        "size": _BYTE_SIZE * size_field,
+                    }
+                    dwarf_addrclass = self.get_dwarf_address_class(addrspace)
+                    if dwarf_addrclass is not None:
+                        meta_ptr["dwarfAddressSpace"] = int(dwarf_addrclass)
+                    basetype = m.add_debug_info("DIDerivedType", meta_ptr)
+                else:
+                    basetype = self._var_type(
+                        element, size_field, datamodel=model
+                    )
+                derived_type = m.add_debug_info(
+                    "DIDerivedType",
+                    {
+                        "tag": ir.DIToken("DW_TAG_member"),
+                        "name": field,
+                        "baseType": basetype,
+                        "size": _BYTE_SIZE * size_field,
+                        "offset": offset,
+                    },
+                )
+                meta.append(derived_type)
+                offset += _BYTE_SIZE * size_field
+
+            return m.add_debug_info(
+                "DICompositeType",
+                {
+                    "tag": ir.DIToken("DW_TAG_structure_type"),
+                    "name": f"{datamodel.fe_type}",
+                    "elements": m.add_metadata(meta),
+                    "size": offset,
+                },
+                is_distinct=True,
+            )
         # For other cases, use upstream Numba implementation
         return super()._var_type(lltype, size, datamodel=datamodel)
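This path only runs when debug info is generated for struct-typed values such as array arguments. A minimal sketch of a kernel that would exercise it (the kernel itself is illustrative; compiling with debug=True is what engages the CUDADIBuilder):

from numba import cuda
import numpy as np

# With debug=True, the array argument `arr` is described as a DWARF
# structure whose `data` member is a pointer carrying a DWARF address
# class, so debuggers like cuda-gdb can dereference it correctly.
@cuda.jit(debug=True, opt=False)
def scale(arr, factor):
    i = cuda.grid(1)
    if i < arr.size:
        arr[i] *= factor

arr = np.arange(16, dtype=np.float32)
scale[1, 16](arr, np.float32(2.0))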
@@ -848,16 +936,22 @@ class CUDADIBuilder(DIBuilder):
                 # to llvm.dbg.value
                 return
         else:
-            return super().mark_variable(
-                builder,
-                allocavalue,
-                name,
-                lltype,
-                size,
-                line,
-                datamodel,
-                argidx,
-            )
+            # Look up address space for this variable
+            self._addrspace = self._var_addrspace_map.get(name)
+            try:
+                return super().mark_variable(
+                    builder,
+                    allocavalue,
+                    name,
+                    lltype,
+                    size,
+                    line,
+                    datamodel,
+                    argidx,
+                )
+            finally:
+                # Clean up address space info
+                self._addrspace = None
 
     def update_variable(
         self,
numba_cuda/numba/cuda/descriptor.py
@@ -28,7 +28,7 @@ class CUDATarget:
     @property
     def target_context(self):
         if self._targetctx is None:
-            self._targetctx = CUDATargetContext(self._typingctx)
+            self._targetctx = CUDATargetContext(self.typing_context)
         return self._targetctx
 
 
numba_cuda/numba/cuda/device_init.py
@@ -27,7 +27,6 @@ from .stubs import (
     local,
     const,
     atomic,
-    vote_sync_intrinsic,
     match_any_sync,
     match_all_sync,
     threadfence_block,
@@ -56,6 +55,10 @@ from .intrinsics import (
     shfl_up_sync,
     shfl_down_sync,
     shfl_xor_sync,
+    all_sync,
+    any_sync,
+    eq_sync,
+    ballot_sync,
 )
 from .cudadrv.error import CudaSupportError
 from numba.cuda.cudadrv.driver import (
@@ -79,12 +82,6 @@ from .api import *
 from .api import _auto_device
 from .args import In, Out, InOut
 
-from .intrinsic_wrapper import (
-    all_sync,
-    any_sync,
-    eq_sync,
-    ballot_sync,
-)
 
 from .kernels import reduction
 from numba.cuda.cudadrv.linkable_code import (
numba_cuda/numba/cuda/dispatcher.py
@@ -15,6 +15,8 @@ import uuid
 import re
 from warnings import warn
 
+from cuda.core.experimental import launch
+
 from numba.cuda.core import errors
 from numba.cuda import serialize, utils
 from numba import cuda
@@ -39,6 +41,7 @@ from numba.cuda.compiler import (
 from numba.cuda.core import sigutils, config, entrypoints
 from numba.cuda.flags import Flags
 from numba.cuda.cudadrv import driver, nvvm
+from cuda.core.experimental import LaunchConfig
 from numba.cuda.locks import module_init_lock
 from numba.cuda.core.caching import Cache, CacheImpl, NullCache
 from numba.cuda.descriptor import cuda_target
@@ -475,18 +478,15 @@ class _Kernel(serialize.ReduceMixin):
         for t, v in zip(self.argument_types, args):
             self._prepare_args(t, v, stream, retr, kernelargs)
 
-        stream_handle = driver._stream_handle(stream)
-
         # Invoke kernel
-        driver.launch_kernel(
-            cufunc.handle,
-            *griddim,
-            *blockdim,
-            sharedmem,
-            stream_handle,
-            kernelargs,
-            cooperative=self.cooperative,
+        config = LaunchConfig(
+            grid=griddim,
+            block=blockdim,
+            shmem_size=sharedmem,
+            cooperative_launch=self.cooperative,
         )
+        kernel = cufunc.kernel
+        launch(stream, config, kernel, *kernelargs)
 
         if self.debug:
             driver.device_to_host(ctypes.addressof(excval), excmem, excsz)
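Kernel launches now go through cuda.core's launch API instead of a hand-rolled cuLaunchKernel wrapper. A rough standalone sketch of that API, following cuda.core's documented Program/launch flow (the trivial kernel source here is illustrative, not from numba-cuda):

from cuda.core.experimental import (
    Device, LaunchConfig, Program, ProgramOptions, launch
)

src = 'extern "C" __global__ void noop() {}'

dev = Device()
dev.set_current()
stream = dev.create_stream()

arch = "sm_" + "".join(str(d) for d in dev.compute_capability)
prog = Program(src, code_type="c++", options=ProgramOptions(arch=arch))
kernel = prog.compile("cubin").get_kernel("noop")

# Same call shape as the dispatcher now uses:
# launch(stream, config, kernel, *kernelargs)
config = LaunchConfig(grid=(1,), block=(32,), shmem_size=0)
launch(stream, config, kernel)
stream.sync()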
@@ -540,30 +540,26 @@ class _Kernel(serialize.ReduceMixin):
 
         if isinstance(ty, types.Array):
             devary = wrap_arg(val).to_device(retr, stream)
-            c_intp = ctypes.c_ssize_t
-
-            meminfo = ctypes.c_void_p(0)
-            parent = ctypes.c_void_p(0)
-            nitems = c_intp(devary.size)
-            itemsize = c_intp(devary.dtype.itemsize)
-
-            ptr = driver.device_pointer(devary)
-
-            ptr = int(ptr)
 
-            data = ctypes.c_void_p(ptr)
+            meminfo = 0
+            parent = 0
 
             kernelargs.append(meminfo)
             kernelargs.append(parent)
-            kernelargs.append(nitems)
-            kernelargs.append(itemsize)
-            kernelargs.append(data)
-            kernelargs.extend(map(c_intp, devary.shape))
-            kernelargs.extend(map(c_intp, devary.strides))
+
+            # Non-pointer arguments without ctypes might be dicey, since
+            # we're assuming shape, strides, size, and itemsize fit into
+            # intptr_t. However, this saves a noticeable amount of overhead
+            # in kernel invocation.
+            kernelargs.append(devary.size)
+            kernelargs.append(devary.dtype.itemsize)
+            kernelargs.append(devary.device_ctypes_pointer.value)
+            kernelargs.extend(devary.shape)
+            kernelargs.extend(devary.strides)
 
         elif isinstance(ty, types.CPointer):
             # Pointer arguments should be a pointer-sized integer
-            kernelargs.append(ctypes.c_uint64(val))
+            kernelargs.append(val)
 
         elif isinstance(ty, types.Integer):
             cval = getattr(ctypes, "c_%s" % ty)(val)
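The net effect is that an N-dimensional array argument is flattened into plain Python integers: meminfo, parent, nitems, itemsize, the data pointer, then shape and strides. An illustration for a hypothetical 4x8 float32 array (the pointer value is made up):

import numpy as np

arr = np.zeros((4, 8), dtype=np.float32)
data_ptr = 0x7F0000000000  # stands in for devary.device_ctypes_pointer.value

kernelargs = [
    0,                   # meminfo (NULL: the kernel does not own the data)
    0,                   # parent (NULL)
    arr.size,            # nitems = 32
    arr.dtype.itemsize,  # itemsize = 4
    data_ptr,            # device data pointer
    *arr.shape,          # 4, 8
    *arr.strides,        # 32, 4 (in bytes, C-contiguous)
]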
@@ -582,8 +578,7 @@
             kernelargs.append(cval)
 
         elif ty == types.boolean:
-            cval = ctypes.c_uint8(int(val))
-            kernelargs.append(cval)
+            kernelargs.append(val)
 
         elif ty == types.complex64:
             kernelargs.append(ctypes.c_float(val.real))
@@ -598,8 +593,7 @@
 
         elif isinstance(ty, types.Record):
             devrec = wrap_arg(val).to_device(retr, stream)
-            ptr = devrec.device_ctypes_pointer
-            kernelargs.append(ptr)
+            kernelargs.append(devrec.device_ctypes_pointer.value)
 
         elif isinstance(ty, types.BaseTuple):
             assert len(ty) == len(val)
@@ -671,7 +665,7 @@ class _LaunchConfiguration:
         self.dispatcher = dispatcher
         self.griddim = griddim
         self.blockdim = blockdim
-        self.stream = stream
+        self.stream = driver._to_core_stream(stream)
         self.sharedmem = sharedmem
 
         if (
@@ -700,6 +694,16 @@ class _LaunchConfiguration:
             args, self.griddim, self.blockdim, self.stream, self.sharedmem
         )
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["stream"] = int(state["stream"].handle)
+        return state
+
+    def __setstate__(self, state):
+        handle = state.pop("stream")
+        self.__dict__.update(state)
+        self.stream = driver._to_core_stream(handle)
+
 
 class CUDACacheImpl(CacheImpl):
     def reduce(self, kernel):
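The __getstate__/__setstate__ pair exists because the cuda.core Stream held in self.stream is not picklable, so it is swapped for its integer handle during serialization and rebuilt on load. A generic, runnable sketch of the same pattern (FakeStream and to_stream stand in for the core Stream and driver._to_core_stream):

import pickle

class FakeStream:
    """Stand-in for a cuda.core Stream: exposes .handle, not picklable."""
    def __init__(self, handle):
        self.handle = handle
    def __reduce__(self):
        raise TypeError("streams are not picklable")

def to_stream(handle):  # stand-in for driver._to_core_stream
    return FakeStream(handle)

class Config:
    def __init__(self, stream):
        self.stream = stream

    def __getstate__(self):
        state = self.__dict__.copy()
        state["stream"] = int(state["stream"].handle)  # keep only the handle
        return state

    def __setstate__(self, state):
        handle = state.pop("stream")
        self.__dict__.update(state)
        self.stream = to_stream(handle)  # rebuild the stream wrapper

cfg = pickle.loads(pickle.dumps(Config(FakeStream(42))))
assert cfg.stream.handle == 42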
numba_cuda/numba/cuda/intrinsics.py
@@ -6,7 +6,11 @@ from llvmlite import ir
 from numba import cuda
 from numba.cuda import types
 from numba.cuda import cgutils
-from numba.cuda.core.errors import RequireLiteralValue, TypingError
+from numba.cuda.core.errors import (
+    RequireLiteralValue,
+    TypingError,
+    NumbaTypeError,
+)
 from numba.cuda.typing import signature
 from numba.cuda.extending import overload_attribute, overload_method
 from numba.cuda import nvvmutils
@@ -380,3 +384,148 @@ def shfl_sync_intrinsic(
     sig = signature(a_type, membermask_type, a_type, b_type)
 
     return sig, codegen
+
+
+# -------------------------------------------------------------------------------
+# Warp vote functions
+#
+# References:
+#
+# - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-vote-functions
+# - https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html?highlight=data%2520movement#vote
+#
+# Notes:
+#
+# - The NVVM IR specification requires the mode parameter to be a
+#   constant. It's therefore essential that we pass mode values to
+#   vote_sync_intrinsic as compile-time constants.
+
+
+@intrinsic
+def all_sync(typingctx, mask_type, predicate_type):
+    """
+    If for all threads in the masked warp the predicate is true, then
+    a non-zero value is returned, otherwise 0 is returned.
+    """
+    mode_value = 0
+    sig, codegen_inner = vote_sync_intrinsic(
+        typingctx, mask_type, mode_value, predicate_type
+    )
+
+    def codegen(context, builder, sig_outer, args):
+        # Call vote_sync_intrinsic and extract the boolean result (index 1)
+        result_tuple = codegen_inner(context, builder, sig, args)
+        return builder.extract_value(result_tuple, 1)
+
+    sig_outer = signature(types.b1, mask_type, predicate_type)
+    return sig_outer, codegen
+
+
+@intrinsic
+def any_sync(typingctx, mask_type, predicate_type):
+    """
+    If for any thread in the masked warp the predicate is true, then
+    a non-zero value is returned, otherwise 0 is returned.
+    """
+    mode_value = 1
+    sig, codegen_inner = vote_sync_intrinsic(
+        typingctx, mask_type, mode_value, predicate_type
+    )
+
+    def codegen(context, builder, sig_outer, args):
+        result_tuple = codegen_inner(context, builder, sig, args)
+        return builder.extract_value(result_tuple, 1)
+
+    sig_outer = signature(types.b1, mask_type, predicate_type)
+    return sig_outer, codegen
+
+
+@intrinsic
+def eq_sync(typingctx, mask_type, predicate_type):
+    """
+    If for all threads in the masked warp the boolean predicate is the same,
+    then a non-zero value is returned, otherwise 0 is returned.
+    """
+    mode_value = 2
+    sig, codegen_inner = vote_sync_intrinsic(
+        typingctx, mask_type, mode_value, predicate_type
+    )
+
+    def codegen(context, builder, sig_outer, args):
+        result_tuple = codegen_inner(context, builder, sig, args)
+        return builder.extract_value(result_tuple, 1)
+
+    sig_outer = signature(types.b1, mask_type, predicate_type)
+    return sig_outer, codegen
+
+
+@intrinsic
+def ballot_sync(typingctx, mask_type, predicate_type):
+    """
+    Returns a mask of all threads in the warp whose predicate is true,
+    and are within the given mask.
+    """
+    mode_value = 3
+    sig, codegen_inner = vote_sync_intrinsic(
+        typingctx, mask_type, mode_value, predicate_type
+    )
+
+    def codegen(context, builder, sig_outer, args):
+        result_tuple = codegen_inner(context, builder, sig, args)
+        return builder.extract_value(
+            result_tuple, 0
+        )  # Extract ballot result (index 0)
+
+    sig_outer = signature(types.i4, mask_type, predicate_type)
+    return sig_outer, codegen
+
+
+def vote_sync_intrinsic(typingctx, mask_type, mode_value, predicate_type):
+    # Validate mode value
+    if mode_value not in (0, 1, 2, 3):
+        raise ValueError("Mode must be 0 (all), 1 (any), 2 (eq), or 3 (ballot)")
+
+    if types.unliteral(mask_type) not in types.integer_domain:
+        raise NumbaTypeError(f"Mask type must be an integer. Got {mask_type}")
+    predicate_types = types.integer_domain | {types.boolean}
+
+    if types.unliteral(predicate_type) not in predicate_types:
+        raise NumbaTypeError(
+            f"Predicate must be an integer or boolean. Got {predicate_type}"
+        )
+
+    def codegen(context, builder, sig, args):
+        mask, predicate = args
+
+        # Types
+        i1 = ir.IntType(1)
+        i32 = ir.IntType(32)
+
+        # NVVM intrinsic definition
+        arg_types = (i32, i32, i1)
+        vote_return_type = ir.LiteralStructType((i32, i1))
+        fnty = ir.FunctionType(vote_return_type, arg_types)
+
+        fname = "llvm.nvvm.vote.sync"
+        lmod = builder.module
+        vote_sync = cgutils.get_or_insert_function(lmod, fnty, fname)
+
+        # Intrinsic arguments
+        mode = ir.Constant(i32, mode_value)
+        mask_i32 = builder.trunc(mask, i32)
+
+        # Convert predicate to i1
+        if predicate.type != ir.IntType(1):
+            predicate_bool = builder.icmp_signed(
+                "!=", predicate, ir.Constant(predicate.type, 0)
+            )
+        else:
+            predicate_bool = predicate
+
+        return builder.call(vote_sync, [mask_i32, mode, predicate_bool])
+
+    sig = signature(
+        types.Tuple((types.i4, types.b1)), mask_type, predicate_type
+    )
+
+    return sig, codegen
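With the lowering above and the device_init changes earlier in this diff, the vote operations keep their public names on the cuda module. A short usage sketch (expected results follow the CUDA warp-vote semantics):

from numba import cuda
import numpy as np

@cuda.jit
def vote_demo(flags, out_all, out_any, out_ballot):
    i = cuda.grid(1)
    mask = 0xFFFFFFFF
    pred = flags[i] != 0
    out_all[i] = cuda.all_sync(mask, pred)
    out_any[i] = cuda.any_sync(mask, pred)
    out_ballot[i] = cuda.ballot_sync(mask, pred)

flags = np.array([1] * 31 + [0], dtype=np.int32)
out_all = np.zeros(32, dtype=np.int32)
out_any = np.zeros(32, dtype=np.int32)
out_ballot = np.zeros(32, dtype=np.int32)

vote_demo[1, 32](flags, out_all, out_any, out_ballot)
# all_sync    -> 0 everywhere (lane 31's predicate is false)
# any_sync    -> 1 everywhere
# ballot_sync -> bitmask of lanes whose predicate is true (0x7FFFFFFF)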