numba-cuda 0.19.1__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic. (The original page linked to further details here; the link is not preserved in this extract.)

Files changed (171)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +4 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  129. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  130. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  134. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  135. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  139. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  140. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  141. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  142. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  143. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  145. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  146. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  147. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  148. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  151. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  152. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  153. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  154. numba_cuda/numba/cuda/tests/support.py +55 -15
  155. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  156. numba_cuda/numba/cuda/types.py +56 -0
  157. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  158. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  159. numba_cuda/numba/cuda/typing/context.py +751 -0
  160. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  161. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  162. numba_cuda/numba/cuda/typing/templates.py +7 -6
  163. numba_cuda/numba/cuda/ufuncs.py +3 -3
  164. numba_cuda/numba/cuda/utils.py +6 -112
  165. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/METADATA +2 -1
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/RECORD +170 -115
  167. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  168. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/WHEEL +0 -0
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE.numba +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/top_level.txt +0 -0
@@ -10,18 +10,14 @@ from llvmlite import ir as llvm_ir
10
10
 
11
11
  from numba.core import (
12
12
  typing,
13
- utils,
14
13
  types,
15
14
  ir,
16
- debuginfo,
17
- funcdesc,
18
15
  generators,
19
- config,
20
- cgutils,
21
16
  removerefctpass,
22
- targetconfig,
23
17
  )
24
- from numba.cuda.core import ir_utils
18
+ from numba.cuda import debuginfo, cgutils, utils
19
+ from numba.cuda.core import ir_utils, targetconfig, funcdesc, config
20
+
25
21
  from numba.core.errors import (
26
22
  LoweringError,
27
23
  new_error_context,
@@ -30,8 +26,8 @@ from numba.core.errors import (
30
26
  UnsupportedError,
31
27
  NumbaDebugInfoWarning,
32
28
  )
33
- from numba.core.funcdesc import default_mangler
34
- from numba.core.environment import Environment
29
+ from numba.cuda.core.funcdesc import default_mangler
30
+ from numba.cuda.core.environment import Environment
35
31
  from numba.core.analysis import compute_use_defs, must_use_alloca
36
32
  from numba.misc.firstlinefinder import get_func_body_first_lineno
37
33
  from numba import version_info
@@ -466,7 +462,7 @@ class Lower(BaseLower):
466
462
  self._blk_local_varmap = {}
467
463
 
468
464
  def pre_block(self, block):
469
- from numba.core.unsafe import eh
465
+ from numba.cuda.core.unsafe import eh
470
466
 
471
467
  super(Lower, self).pre_block(block)
472
468
  self._cur_ir_block = block
@@ -1029,9 +1025,6 @@ class Lower(BaseLower):
1029
1025
  elif isinstance(fnty, types.RecursiveCall):
1030
1026
  res = self._lower_call_RecursiveCall(fnty, expr, signature)
1031
1027
 
1032
- elif isinstance(fnty, types.FunctionType):
1033
- res = self._lower_call_FunctionType(fnty, expr, signature)
1034
-
1035
1028
  else:
1036
1029
  res = self._lower_call_normal(fnty, expr, signature)
1037
1030
 
@@ -1052,7 +1045,7 @@ class Lower(BaseLower):
1052
1045
  )
1053
1046
 
1054
1047
  def _lower_call_ObjModeDispatcher(self, fnty, expr, signature):
1055
- from numba.core.pythonapi import ObjModeUtils
1048
+ from numba.cuda.core.pythonapi import ObjModeUtils
1056
1049
 
1057
1050
  self.init_pyapi()
1058
1051
  # Acquire the GIL
@@ -1229,136 +1222,6 @@ class Lower(BaseLower):
1229
1222
  )
1230
1223
  return res
1231
1224
 
1232
- def _lower_call_FunctionType(self, fnty, expr, signature):
1233
- self.debug_print("# calling first-class function type")
1234
- sig = types.unliteral(signature)
1235
- if not fnty.check_signature(signature):
1236
- # value dependent polymorphism?
1237
- raise UnsupportedError(
1238
- f"mismatch of function types:"
1239
- f" expected {fnty} but got {types.FunctionType(sig)}"
1240
- )
1241
- argvals = self.fold_call_args(
1242
- fnty,
1243
- sig,
1244
- expr.args,
1245
- expr.vararg,
1246
- expr.kws,
1247
- )
1248
- return self.__call_first_class_function_pointer(
1249
- fnty.ftype,
1250
- expr.func.name,
1251
- sig,
1252
- argvals,
1253
- )
1254
-
1255
- def __call_first_class_function_pointer(self, ftype, fname, sig, argvals):
1256
- """
1257
- Calls a first-class function pointer.
1258
-
1259
- This function is responsible for calling a first-class function pointer,
1260
- which can either be a JIT-compiled function or a Python function. It
1261
- determines if a JIT address is available, and if so, calls the function
1262
- using the JIT address. Otherwise, it calls the function using a function
1263
- pointer obtained from the `__get_first_class_function_pointer` method.
1264
-
1265
- Args:
1266
- ftype: The type of the function.
1267
- fname: The name of the function.
1268
- sig: The signature of the function.
1269
- argvals: The argument values to pass to the function.
1270
-
1271
- Returns:
1272
- The result of calling the function.
1273
- """
1274
- context = self.context
1275
- builder = self.builder
1276
- # Determine if jit address is available
1277
- fstruct = self.loadvar(fname)
1278
- struct = cgutils.create_struct_proxy(self.typeof(fname))(
1279
- context, builder, value=fstruct
1280
- )
1281
- jit_addr = struct.jit_addr
1282
- jit_addr.name = f"jit_addr_of_{fname}"
1283
-
1284
- ctx = context
1285
- res_slot = cgutils.alloca_once(
1286
- builder, ctx.get_value_type(sig.return_type)
1287
- )
1288
-
1289
- if_jit_addr_is_null = builder.if_else(
1290
- cgutils.is_null(builder, jit_addr), likely=False
1291
- )
1292
- with if_jit_addr_is_null as (then, orelse):
1293
- with then:
1294
- func_ptr = self.__get_first_class_function_pointer(
1295
- ftype, fname, sig
1296
- )
1297
- res = builder.call(func_ptr, argvals)
1298
- builder.store(res, res_slot)
1299
-
1300
- with orelse:
1301
- llty = ctx.call_conv.get_function_type(
1302
- sig.return_type, sig.args
1303
- ).as_pointer()
1304
- func_ptr = builder.bitcast(jit_addr, llty)
1305
- # call
1306
- status, res = ctx.call_conv.call_function(
1307
- builder, func_ptr, sig.return_type, sig.args, argvals
1308
- )
1309
- with cgutils.if_unlikely(builder, status.is_error):
1310
- context.call_conv.return_status_propagate(builder, status)
1311
- builder.store(res, res_slot)
1312
- return builder.load(res_slot)
1313
-
1314
- def __get_first_class_function_pointer(self, ftype, fname, sig):
1315
- from numba.experimental.function_type import lower_get_wrapper_address
1316
-
1317
- llty = self.context.get_value_type(ftype)
1318
- fstruct = self.loadvar(fname)
1319
- addr = self.builder.extract_value(
1320
- fstruct, 0, name="addr_of_%s" % (fname)
1321
- )
1322
-
1323
- fptr = cgutils.alloca_once(
1324
- self.builder, llty, name="fptr_of_%s" % (fname)
1325
- )
1326
- with self.builder.if_else(
1327
- cgutils.is_null(self.builder, addr), likely=False
1328
- ) as (then, orelse):
1329
- with then:
1330
- self.init_pyapi()
1331
- # Acquire the GIL
1332
- gil_state = self.pyapi.gil_ensure()
1333
- pyaddr = self.builder.extract_value(
1334
- fstruct, 1, name="pyaddr_of_%s" % (fname)
1335
- )
1336
- # try to recover the function address, see
1337
- # test_zero_address BadToGood example in
1338
- # test_function_type.py
1339
- addr1 = lower_get_wrapper_address(
1340
- self.context,
1341
- self.builder,
1342
- pyaddr,
1343
- sig,
1344
- failure_mode="ignore",
1345
- )
1346
- with self.builder.if_then(
1347
- cgutils.is_null(self.builder, addr1), likely=False
1348
- ):
1349
- self.return_exception(
1350
- RuntimeError,
1351
- exc_args=(f"{ftype} function address is null",),
1352
- loc=self.loc,
1353
- )
1354
- addr2 = self.pyapi.long_as_voidptr(addr1)
1355
- self.builder.store(self.builder.bitcast(addr2, llty), fptr)
1356
- self.pyapi.decref(addr1)
1357
- self.pyapi.gil_release(gil_state)
1358
- with orelse:
1359
- self.builder.store(self.builder.bitcast(addr, llty), fptr)
1360
- return self.builder.load(fptr)
1361
-
1362
1225
  def _lower_call_normal(self, fnty, expr, signature):
1363
1226
  # Normal function resolution
1364
1227
  self.debug_print("# calling normal function: {0}".format(fnty))
@@ -4,11 +4,12 @@
4
4
  import math
5
5
  import operator
6
6
  from llvmlite import ir
7
- from numba.core import types, typing, targetconfig
7
+ from numba.core import types, typing
8
8
  from numba.cuda import cgutils
9
9
  from numba.core.imputils import Registry
10
10
  from numba.types import float32, float64, int64, uint64
11
11
  from numba.cuda import libdevice
12
+ from numba.cuda.core import targetconfig
12
13
 
13
14
  registry = Registry()
14
15
  lower = registry.lower
@@ -6,7 +6,10 @@ import os
6
6
  from functools import wraps
7
7
  import numpy as np
8
8
 
9
- from numba import cuda, config
9
+
10
+ from numba import cuda, types
11
+ from numba.cuda import config
12
+
10
13
  from numba.core.runtime.nrt import _nrt_mstats
11
14
  from numba.cuda.cudadrv.driver import (
12
15
  _Linker,
@@ -17,24 +20,11 @@ from numba.cuda.cudadrv.driver import (
17
20
  )
18
21
  from numba.cuda.cudadrv import devices
19
22
  from numba.cuda.api import get_current_device
20
- from numba.cuda.utils import _readenv, cached_file_read
23
+ from numba.cuda.utils import cached_file_read
21
24
  from numba.cuda.cudadrv.linkable_code import CUSource
25
+ from numba.cuda.typing.templates import signature
22
26
 
23
-
24
- # Check environment variable or config for NRT statistics enablement
25
- NRT_STATS = _readenv("NUMBA_CUDA_NRT_STATS", bool, False) or getattr(
26
- config, "NUMBA_CUDA_NRT_STATS", False
27
- )
28
- if not hasattr(config, "NUMBA_CUDA_NRT_STATS"):
29
- config.CUDA_NRT_STATS = NRT_STATS
30
-
31
-
32
- # Check environment variable or config for NRT enablement
33
- ENABLE_NRT = _readenv("NUMBA_CUDA_ENABLE_NRT", bool, False) or getattr(
34
- config, "NUMBA_CUDA_ENABLE_NRT", False
35
- )
36
- if not hasattr(config, "NUMBA_CUDA_ENABLE_NRT"):
37
- config.CUDA_ENABLE_NRT = ENABLE_NRT
27
+ from numba.core.extending import intrinsic, overload_classmethod
38
28
 
39
29
 
40
30
  def get_include():
@@ -42,6 +32,34 @@ def get_include():
42
32
  return os.path.dirname(os.path.abspath(__file__))
43
33
 
44
34
 
35
+ # Provide an implementation of Array._allocate() for the CUDA target (used
36
+ # internally by Numba when generating the allocation of an array)
37
+
38
+
39
+ @intrinsic
40
+ def intrin_alloc(typingctx, allocsize, align):
41
+ """Intrinsic to call into the allocator for Array"""
42
+
43
+ def codegen(context, builder, signature, args):
44
+ allocsize, align = args
45
+ meminfo = context.nrt.meminfo_alloc_aligned(builder, allocsize, align)
46
+ return meminfo
47
+
48
+ mip = types.MemInfoPointer(types.voidptr) # return untyped pointer
49
+ sig = signature(mip, allocsize, align)
50
+ return sig, codegen
51
+
52
+
53
+ @overload_classmethod(types.Array, "_allocate", target="CUDA")
54
+ def _ol_array_allocate(cls, allocsize, align):
55
+ """Implements a Numba-only CUDA-target classmethod on the array type."""
56
+
57
+ def impl(cls, allocsize, align):
58
+ return intrin_alloc(allocsize, align)
59
+
60
+ return impl
61
+
62
+
45
63
  # Protect method to ensure NRT memory allocation and initialization
46
64
  def _alloc_init_guard(method):
47
65
  """
@@ -69,10 +87,18 @@ class _Runtime:
69
87
 
70
88
  def __init__(self):
71
89
  """Initialize memsys module and variable"""
90
+ self._reset()
91
+
92
+ def _reset(self):
93
+ """Reset to the uninitialized state"""
72
94
  self._memsys_module = None
73
95
  self._memsys = None
74
96
  self._initialized = False
75
97
 
98
+ def close(self):
99
+ """Close and reset"""
100
+ self._reset()
101
+
76
102
  def _compile_memsys_module(self):
77
103
  """
78
104
  Compile memsys.cu and create a module from it in the current context
@@ -0,0 +1,75 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ import sys
5
+ import os
6
+ import re
7
+
8
+
9
+ def get_lib_dirs():
10
+ """
11
+ Anaconda specific
12
+ """
13
+ if sys.platform == "win32":
14
+ # CUDA 12 puts in "bin" directory, whereas CUDA 13 puts in "bin\x64" directory
15
+ dirnames = [
16
+ os.path.join("Library", "bin"),
17
+ os.path.join("Library", "bin", "x64"),
18
+ os.path.join("Library", "nvvm", "bin"),
19
+ os.path.join("Library", "nvvm", "bin", "x64"),
20
+ ]
21
+ else:
22
+ dirnames = [
23
+ "lib",
24
+ ]
25
+ libdirs = [os.path.join(sys.prefix, x) for x in dirnames]
26
+ return libdirs
27
+
28
+
29
+ DLLNAMEMAP = {
30
+ "linux": r"lib%(name)s\.so\.%(ver)s$",
31
+ "linux2": r"lib%(name)s\.so\.%(ver)s$",
32
+ "linux-static": r"lib%(name)s\.a$",
33
+ "darwin": r"lib%(name)s\.%(ver)s\.dylib$",
34
+ "win32": r"%(name)s%(ver)s\.dll$",
35
+ "win32-static": r"%(name)s\.lib$",
36
+ "bsd": r"lib%(name)s\.so\.%(ver)s$",
37
+ }
38
+
39
+ RE_VER = r"[0-9]*([_\.][0-9]+)*"
40
+
41
+
42
+ def find_lib(libname, libdir=None, platform=None, static=False):
43
+ platform = platform or sys.platform
44
+ platform = "bsd" if "bsd" in platform else platform
45
+ if static:
46
+ platform = f"{platform}-static"
47
+ if platform not in DLLNAMEMAP:
48
+ # Return empty list if platform name is undefined.
49
+ # Not all platforms define their static library paths.
50
+ return []
51
+ pat = DLLNAMEMAP[platform] % {"name": libname, "ver": RE_VER}
52
+ regex = re.compile(pat)
53
+ return find_file(regex, libdir)
54
+
55
+
56
+ def find_file(pat, libdir=None):
57
+ if libdir is None:
58
+ libdirs = get_lib_dirs()
59
+ elif isinstance(libdir, str):
60
+ libdirs = [
61
+ libdir,
62
+ ]
63
+ else:
64
+ libdirs = list(libdir)
65
+ files = []
66
+ for ldir in libdirs:
67
+ try:
68
+ entries = os.listdir(ldir)
69
+ except FileNotFoundError:
70
+ continue
71
+ candidates = [
72
+ os.path.join(ldir, ent) for ent in entries if pat.match(ent)
73
+ ]
74
+ files.extend([c for c in candidates if os.path.isfile(c)])
75
+ return files
@@ -6,9 +6,10 @@ import functools
6
6
  from llvmlite import ir
7
7
 
8
8
  from numba.core.datamodel.registry import DataModelManager, register
9
+ from numba.core.datamodel import PrimitiveModel
9
10
  from numba.core.extending import models
10
11
  from numba.core import types
11
- from numba.cuda.types import Dim3, GridGroup, CUDADispatcher
12
+ from numba.cuda.types import Dim3, GridGroup, CUDADispatcher, Bfloat16
12
13
 
13
14
 
14
15
  cuda_data_manager = DataModelManager()
@@ -45,3 +46,10 @@ class FloatModel(models.PrimitiveModel):
45
46
 
46
47
 
47
48
  register_model(CUDADispatcher)(models.OpaqueModel)
49
+
50
+
51
+ @register_model(Bfloat16)
52
+ class _model___nv_bfloat16(PrimitiveModel):
53
+ def __init__(self, dmm, fe_type):
54
+ be_type = ir.IntType(16)
55
+ super(_model___nv_bfloat16, self).__init__(dmm, fe_type, be_type)
@@ -0,0 +1,217 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ """
5
+ Helper functions for np.timedelta64 and np.datetime64.
6
+ For now, multiples-of-units (for example timedeltas expressed in tens
7
+ of seconds) are not supported.
8
+ """
9
+
10
+ import numpy as np
11
+
12
+
13
+ DATETIME_UNITS = {
14
+ "Y": 0, # Years
15
+ "M": 1, # Months
16
+ "W": 2, # Weeks
17
+ # Yes, there's a gap here
18
+ "D": 4, # Days
19
+ "h": 5, # Hours
20
+ "m": 6, # Minutes
21
+ "s": 7, # Seconds
22
+ "ms": 8, # Milliseconds
23
+ "us": 9, # Microseconds
24
+ "ns": 10, # Nanoseconds
25
+ "ps": 11, # Picoseconds
26
+ "fs": 12, # Femtoseconds
27
+ "as": 13, # Attoseconds
28
+ "": 14, # "generic", i.e. unit-less
29
+ }
30
+
31
+ NAT = np.timedelta64("nat").astype(np.int64)
32
+
33
+ # NOTE: numpy has several inconsistent functions for timedelta casting:
34
+ # - can_cast_timedelta64_{metadata,units}() disallows "safe" casting
35
+ # to and from generic units
36
+ # - cast_timedelta_to_timedelta() allows casting from (but not to)
37
+ # generic units
38
+ # - compute_datetime_metadata_greatest_common_divisor() allows casting from
39
+ # generic units (used for promotion)
40
+
41
+
42
+ def same_kind(src, dest):
43
+ """
44
+ Whether the *src* and *dest* units are of the same kind.
45
+ """
46
+ return (DATETIME_UNITS[src] < 5) == (DATETIME_UNITS[dest] < 5)
47
+
48
+
49
+ def can_cast_timedelta_units(src, dest):
50
+ # Mimic NumPy's "safe" casting and promotion
51
+ # `dest` must be more precise than `src` and they must be compatible
52
+ # for conversion.
53
+ # XXX should we switch to enforcing "same-kind" for Numpy 1.10+ ?
54
+ src = DATETIME_UNITS[src]
55
+ dest = DATETIME_UNITS[dest]
56
+ if src == dest:
57
+ return True
58
+ if src == 14:
59
+ return True
60
+ if src > dest:
61
+ return False
62
+ if dest == 14:
63
+ # unit-less timedelta64 is not compatible with anything else
64
+ return False
65
+ if src <= 1 and dest > 1:
66
+ # Cannot convert between months or years and other units
67
+ return False
68
+ return True
69
+
70
+
71
+ # Exact conversion factors from one unit to the immediately more precise one
72
+ _factors = {
73
+ 0: (1, 12), # Years -> Months
74
+ 2: (4, 7), # Weeks -> Days
75
+ 4: (5, 24), # Days -> Hours
76
+ 5: (6, 60), # Hours -> Minutes
77
+ 6: (7, 60), # Minutes -> Seconds
78
+ 7: (8, 1000),
79
+ 8: (9, 1000),
80
+ 9: (10, 1000),
81
+ 10: (11, 1000),
82
+ 11: (12, 1000),
83
+ 12: (13, 1000),
84
+ }
85
+
86
+
87
+ def _get_conversion_multiplier(big_unit_code, small_unit_code):
88
+ """
89
+ Return an integer multiplier allowing to convert from *big_unit_code*
90
+ to *small_unit_code*.
91
+ None is returned if the conversion is not possible through a
92
+ simple integer multiplication.
93
+ """
94
+ # Mimics get_datetime_units_factor() in NumPy's datetime.c,
95
+ # with a twist to allow no-op conversion from generic units.
96
+ if big_unit_code == 14:
97
+ return 1
98
+ c = big_unit_code
99
+ factor = 1
100
+ while c < small_unit_code:
101
+ try:
102
+ c, mult = _factors[c]
103
+ except KeyError:
104
+ # No possible conversion
105
+ return None
106
+ factor *= mult
107
+ if c == small_unit_code:
108
+ return factor
109
+ else:
110
+ return None
111
+
112
+
113
+ def get_timedelta_conversion_factor(src_unit, dest_unit):
114
+ """
115
+ Return an integer multiplier allowing to convert from timedeltas
116
+ of *src_unit* to *dest_unit*.
117
+ """
118
+ return _get_conversion_multiplier(
119
+ DATETIME_UNITS[src_unit], DATETIME_UNITS[dest_unit]
120
+ )
121
+
122
+
123
+ def get_datetime_timedelta_conversion(datetime_unit, timedelta_unit):
124
+ """
125
+ Compute a possible conversion for combining *datetime_unit* and
126
+ *timedelta_unit* (presumably for adding or subtracting).
127
+ Return (result unit, integer datetime multiplier, integer timedelta
128
+ multiplier). RuntimeError is raised if the combination is impossible.
129
+ """
130
+ # XXX now unused (I don't know where / how Numpy uses this)
131
+ dt_unit_code = DATETIME_UNITS[datetime_unit]
132
+ td_unit_code = DATETIME_UNITS[timedelta_unit]
133
+ if td_unit_code == 14 or dt_unit_code == 14:
134
+ return datetime_unit, 1, 1
135
+ if td_unit_code < 2 and dt_unit_code >= 2:
136
+ # Cannot combine Y or M timedelta64 with a finer-grained datetime64
137
+ raise RuntimeError(
138
+ "cannot combine datetime64(%r) and timedelta64(%r)"
139
+ % (datetime_unit, timedelta_unit)
140
+ )
141
+ dt_factor, td_factor = 1, 1
142
+
143
+ # If years or months, the datetime unit is first scaled to weeks or days,
144
+ # then conversion continues below. This is the same algorithm as used
145
+ # in Numpy's get_datetime_conversion_factor() (src/multiarray/datetime.c):
146
+ # """Conversions between years/months and other units use
147
+ # the factor averaged over the 400 year leap year cycle."""
148
+ if dt_unit_code == 0:
149
+ if td_unit_code >= 4:
150
+ dt_factor = 97 + 400 * 365
151
+ td_factor = 400
152
+ dt_unit_code = 4
153
+ elif td_unit_code == 2:
154
+ dt_factor = 97 + 400 * 365
155
+ td_factor = 400 * 7
156
+ dt_unit_code = 2
157
+ elif dt_unit_code == 1:
158
+ if td_unit_code >= 4:
159
+ dt_factor = 97 + 400 * 365
160
+ td_factor = 400 * 12
161
+ dt_unit_code = 4
162
+ elif td_unit_code == 2:
163
+ dt_factor = 97 + 400 * 365
164
+ td_factor = 400 * 12 * 7
165
+ dt_unit_code = 2
166
+
167
+ if td_unit_code >= dt_unit_code:
168
+ factor = _get_conversion_multiplier(dt_unit_code, td_unit_code)
169
+ assert factor is not None, (dt_unit_code, td_unit_code)
170
+ return timedelta_unit, dt_factor * factor, td_factor
171
+ else:
172
+ factor = _get_conversion_multiplier(td_unit_code, dt_unit_code)
173
+ assert factor is not None, (dt_unit_code, td_unit_code)
174
+ return datetime_unit, dt_factor, td_factor * factor
175
+
176
+
177
+ def combine_datetime_timedelta_units(datetime_unit, timedelta_unit):
178
+ """
179
+ Return the unit result of combining *datetime_unit* with *timedelta_unit*
180
+ (e.g. by adding or subtracting). None is returned if combining
181
+ those units is forbidden.
182
+ """
183
+ dt_unit_code = DATETIME_UNITS[datetime_unit]
184
+ td_unit_code = DATETIME_UNITS[timedelta_unit]
185
+ if dt_unit_code == 14:
186
+ return timedelta_unit
187
+ elif td_unit_code == 14:
188
+ return datetime_unit
189
+ if td_unit_code < 2 and dt_unit_code >= 2:
190
+ return None
191
+ if dt_unit_code > td_unit_code:
192
+ return datetime_unit
193
+ else:
194
+ return timedelta_unit
195
+
196
+
197
+ def get_best_unit(unit_a, unit_b):
198
+ """
199
+ Get the best (i.e. finer-grained) of two units.
200
+ """
201
+ a = DATETIME_UNITS[unit_a]
202
+ b = DATETIME_UNITS[unit_b]
203
+ if a == 14:
204
+ return unit_b
205
+ if b == 14:
206
+ return unit_a
207
+ if b > a:
208
+ return unit_b
209
+ return unit_a
210
+
211
+
212
+ def datetime_minimum(a, b):
213
+ pass
214
+
215
+
216
+ def datetime_maximum(a, b):
217
+ pass