numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of numba-cuda might be problematic. Click here for more details.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +1 -1
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
- numba_cuda/numba/cuda/api.py +6 -1
- numba_cuda/numba/cuda/bf16.py +285 -2
- numba_cuda/numba/cuda/cgutils.py +2 -2
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +1 -1
- numba_cuda/numba/cuda/compiler.py +373 -30
- numba_cuda/numba/cuda/core/analysis.py +319 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
- numba_cuda/numba/cuda/core/base.py +1289 -0
- numba_cuda/numba/cuda/core/bytecode.py +727 -0
- numba_cuda/numba/cuda/core/caching.py +2 -2
- numba_cuda/numba/cuda/core/compiler.py +6 -14
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +747 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/cpu.py +370 -0
- numba_cuda/numba/cuda/core/environment.py +68 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
- numba_cuda/numba/cuda/core/interpreter.py +48 -26
- numba_cuda/numba/cuda/core/ir_utils.py +15 -26
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
- numba_cuda/numba/cuda/core/ssa.py +496 -0
- numba_cuda/numba/cuda/core/targetconfig.py +329 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +952 -0
- numba_cuda/numba/cuda/core/typed_passes.py +738 -7
- numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
- numba_cuda/numba/cuda/cuda_paths.py +422 -246
- numba_cuda/numba/cuda/cudadecl.py +1 -1
- numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
- numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
- numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
- numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
- numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +5 -1
- numba_cuda/numba/cuda/debuginfo.py +85 -2
- numba_cuda/numba/cuda/decorators.py +3 -3
- numba_cuda/numba/cuda/descriptor.py +3 -4
- numba_cuda/numba/cuda/deviceufunc.py +66 -2
- numba_cuda/numba/cuda/dispatcher.py +18 -39
- numba_cuda/numba/cuda/flags.py +141 -1
- numba_cuda/numba/cuda/fp16.py +0 -2
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/lowering.py +7 -144
- numba_cuda/numba/cuda/mathimpl.py +2 -1
- numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/models.py +9 -1
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
- numba_cuda/numba/cuda/np/numpy_support.py +553 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
- numba_cuda/numba/cuda/nvvmutils.py +1 -1
- numba_cuda/numba/cuda/printimpl.py +12 -1
- numba_cuda/numba/cuda/random.py +1 -1
- numba_cuda/numba/cuda/serialize.py +1 -1
- numba_cuda/numba/cuda/simulator/__init__.py +1 -1
- numba_cuda/numba/cuda/simulator/api.py +1 -1
- numba_cuda/numba/cuda/simulator/compiler.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
- numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
- numba_cuda/numba/cuda/target.py +35 -17
- numba_cuda/numba/cuda/testing.py +7 -19
- numba_cuda/numba/cuda/tests/__init__.py +1 -1
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
- numba_cuda/numba/cuda/tests/support.py +55 -15
- numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
- numba_cuda/numba/cuda/types.py +56 -0
- numba_cuda/numba/cuda/typing/__init__.py +9 -1
- numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
- numba_cuda/numba/cuda/typing/context.py +751 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/npydecl.py +658 -0
- numba_cuda/numba/cuda/typing/templates.py +7 -6
- numba_cuda/numba/cuda/ufuncs.py +3 -3
- numba_cuda/numba/cuda/utils.py +6 -112
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
|
@@ -10,18 +10,14 @@ from llvmlite import ir as llvm_ir
|
|
|
10
10
|
|
|
11
11
|
from numba.core import (
|
|
12
12
|
typing,
|
|
13
|
-
utils,
|
|
14
13
|
types,
|
|
15
14
|
ir,
|
|
16
|
-
debuginfo,
|
|
17
|
-
funcdesc,
|
|
18
15
|
generators,
|
|
19
|
-
config,
|
|
20
|
-
cgutils,
|
|
21
16
|
removerefctpass,
|
|
22
|
-
targetconfig,
|
|
23
17
|
)
|
|
24
|
-
from numba.cuda
|
|
18
|
+
from numba.cuda import debuginfo, cgutils, utils
|
|
19
|
+
from numba.cuda.core import ir_utils, targetconfig, funcdesc, config
|
|
20
|
+
|
|
25
21
|
from numba.core.errors import (
|
|
26
22
|
LoweringError,
|
|
27
23
|
new_error_context,
|
|
@@ -30,8 +26,8 @@ from numba.core.errors import (
|
|
|
30
26
|
UnsupportedError,
|
|
31
27
|
NumbaDebugInfoWarning,
|
|
32
28
|
)
|
|
33
|
-
from numba.core.funcdesc import default_mangler
|
|
34
|
-
from numba.core.environment import Environment
|
|
29
|
+
from numba.cuda.core.funcdesc import default_mangler
|
|
30
|
+
from numba.cuda.core.environment import Environment
|
|
35
31
|
from numba.core.analysis import compute_use_defs, must_use_alloca
|
|
36
32
|
from numba.misc.firstlinefinder import get_func_body_first_lineno
|
|
37
33
|
from numba import version_info
|
|
@@ -466,7 +462,7 @@ class Lower(BaseLower):
|
|
|
466
462
|
self._blk_local_varmap = {}
|
|
467
463
|
|
|
468
464
|
def pre_block(self, block):
|
|
469
|
-
from numba.core.unsafe import eh
|
|
465
|
+
from numba.cuda.core.unsafe import eh
|
|
470
466
|
|
|
471
467
|
super(Lower, self).pre_block(block)
|
|
472
468
|
self._cur_ir_block = block
|
|
@@ -1029,9 +1025,6 @@ class Lower(BaseLower):
|
|
|
1029
1025
|
elif isinstance(fnty, types.RecursiveCall):
|
|
1030
1026
|
res = self._lower_call_RecursiveCall(fnty, expr, signature)
|
|
1031
1027
|
|
|
1032
|
-
elif isinstance(fnty, types.FunctionType):
|
|
1033
|
-
res = self._lower_call_FunctionType(fnty, expr, signature)
|
|
1034
|
-
|
|
1035
1028
|
else:
|
|
1036
1029
|
res = self._lower_call_normal(fnty, expr, signature)
|
|
1037
1030
|
|
|
@@ -1052,7 +1045,7 @@ class Lower(BaseLower):
|
|
|
1052
1045
|
)
|
|
1053
1046
|
|
|
1054
1047
|
def _lower_call_ObjModeDispatcher(self, fnty, expr, signature):
|
|
1055
|
-
from numba.core.pythonapi import ObjModeUtils
|
|
1048
|
+
from numba.cuda.core.pythonapi import ObjModeUtils
|
|
1056
1049
|
|
|
1057
1050
|
self.init_pyapi()
|
|
1058
1051
|
# Acquire the GIL
|
|
@@ -1229,136 +1222,6 @@ class Lower(BaseLower):
|
|
|
1229
1222
|
)
|
|
1230
1223
|
return res
|
|
1231
1224
|
|
|
1232
|
-
def _lower_call_FunctionType(self, fnty, expr, signature):
|
|
1233
|
-
self.debug_print("# calling first-class function type")
|
|
1234
|
-
sig = types.unliteral(signature)
|
|
1235
|
-
if not fnty.check_signature(signature):
|
|
1236
|
-
# value dependent polymorphism?
|
|
1237
|
-
raise UnsupportedError(
|
|
1238
|
-
f"mismatch of function types:"
|
|
1239
|
-
f" expected {fnty} but got {types.FunctionType(sig)}"
|
|
1240
|
-
)
|
|
1241
|
-
argvals = self.fold_call_args(
|
|
1242
|
-
fnty,
|
|
1243
|
-
sig,
|
|
1244
|
-
expr.args,
|
|
1245
|
-
expr.vararg,
|
|
1246
|
-
expr.kws,
|
|
1247
|
-
)
|
|
1248
|
-
return self.__call_first_class_function_pointer(
|
|
1249
|
-
fnty.ftype,
|
|
1250
|
-
expr.func.name,
|
|
1251
|
-
sig,
|
|
1252
|
-
argvals,
|
|
1253
|
-
)
|
|
1254
|
-
|
|
1255
|
-
def __call_first_class_function_pointer(self, ftype, fname, sig, argvals):
|
|
1256
|
-
"""
|
|
1257
|
-
Calls a first-class function pointer.
|
|
1258
|
-
|
|
1259
|
-
This function is responsible for calling a first-class function pointer,
|
|
1260
|
-
which can either be a JIT-compiled function or a Python function. It
|
|
1261
|
-
determines if a JIT address is available, and if so, calls the function
|
|
1262
|
-
using the JIT address. Otherwise, it calls the function using a function
|
|
1263
|
-
pointer obtained from the `__get_first_class_function_pointer` method.
|
|
1264
|
-
|
|
1265
|
-
Args:
|
|
1266
|
-
ftype: The type of the function.
|
|
1267
|
-
fname: The name of the function.
|
|
1268
|
-
sig: The signature of the function.
|
|
1269
|
-
argvals: The argument values to pass to the function.
|
|
1270
|
-
|
|
1271
|
-
Returns:
|
|
1272
|
-
The result of calling the function.
|
|
1273
|
-
"""
|
|
1274
|
-
context = self.context
|
|
1275
|
-
builder = self.builder
|
|
1276
|
-
# Determine if jit address is available
|
|
1277
|
-
fstruct = self.loadvar(fname)
|
|
1278
|
-
struct = cgutils.create_struct_proxy(self.typeof(fname))(
|
|
1279
|
-
context, builder, value=fstruct
|
|
1280
|
-
)
|
|
1281
|
-
jit_addr = struct.jit_addr
|
|
1282
|
-
jit_addr.name = f"jit_addr_of_{fname}"
|
|
1283
|
-
|
|
1284
|
-
ctx = context
|
|
1285
|
-
res_slot = cgutils.alloca_once(
|
|
1286
|
-
builder, ctx.get_value_type(sig.return_type)
|
|
1287
|
-
)
|
|
1288
|
-
|
|
1289
|
-
if_jit_addr_is_null = builder.if_else(
|
|
1290
|
-
cgutils.is_null(builder, jit_addr), likely=False
|
|
1291
|
-
)
|
|
1292
|
-
with if_jit_addr_is_null as (then, orelse):
|
|
1293
|
-
with then:
|
|
1294
|
-
func_ptr = self.__get_first_class_function_pointer(
|
|
1295
|
-
ftype, fname, sig
|
|
1296
|
-
)
|
|
1297
|
-
res = builder.call(func_ptr, argvals)
|
|
1298
|
-
builder.store(res, res_slot)
|
|
1299
|
-
|
|
1300
|
-
with orelse:
|
|
1301
|
-
llty = ctx.call_conv.get_function_type(
|
|
1302
|
-
sig.return_type, sig.args
|
|
1303
|
-
).as_pointer()
|
|
1304
|
-
func_ptr = builder.bitcast(jit_addr, llty)
|
|
1305
|
-
# call
|
|
1306
|
-
status, res = ctx.call_conv.call_function(
|
|
1307
|
-
builder, func_ptr, sig.return_type, sig.args, argvals
|
|
1308
|
-
)
|
|
1309
|
-
with cgutils.if_unlikely(builder, status.is_error):
|
|
1310
|
-
context.call_conv.return_status_propagate(builder, status)
|
|
1311
|
-
builder.store(res, res_slot)
|
|
1312
|
-
return builder.load(res_slot)
|
|
1313
|
-
|
|
1314
|
-
def __get_first_class_function_pointer(self, ftype, fname, sig):
|
|
1315
|
-
from numba.experimental.function_type import lower_get_wrapper_address
|
|
1316
|
-
|
|
1317
|
-
llty = self.context.get_value_type(ftype)
|
|
1318
|
-
fstruct = self.loadvar(fname)
|
|
1319
|
-
addr = self.builder.extract_value(
|
|
1320
|
-
fstruct, 0, name="addr_of_%s" % (fname)
|
|
1321
|
-
)
|
|
1322
|
-
|
|
1323
|
-
fptr = cgutils.alloca_once(
|
|
1324
|
-
self.builder, llty, name="fptr_of_%s" % (fname)
|
|
1325
|
-
)
|
|
1326
|
-
with self.builder.if_else(
|
|
1327
|
-
cgutils.is_null(self.builder, addr), likely=False
|
|
1328
|
-
) as (then, orelse):
|
|
1329
|
-
with then:
|
|
1330
|
-
self.init_pyapi()
|
|
1331
|
-
# Acquire the GIL
|
|
1332
|
-
gil_state = self.pyapi.gil_ensure()
|
|
1333
|
-
pyaddr = self.builder.extract_value(
|
|
1334
|
-
fstruct, 1, name="pyaddr_of_%s" % (fname)
|
|
1335
|
-
)
|
|
1336
|
-
# try to recover the function address, see
|
|
1337
|
-
# test_zero_address BadToGood example in
|
|
1338
|
-
# test_function_type.py
|
|
1339
|
-
addr1 = lower_get_wrapper_address(
|
|
1340
|
-
self.context,
|
|
1341
|
-
self.builder,
|
|
1342
|
-
pyaddr,
|
|
1343
|
-
sig,
|
|
1344
|
-
failure_mode="ignore",
|
|
1345
|
-
)
|
|
1346
|
-
with self.builder.if_then(
|
|
1347
|
-
cgutils.is_null(self.builder, addr1), likely=False
|
|
1348
|
-
):
|
|
1349
|
-
self.return_exception(
|
|
1350
|
-
RuntimeError,
|
|
1351
|
-
exc_args=(f"{ftype} function address is null",),
|
|
1352
|
-
loc=self.loc,
|
|
1353
|
-
)
|
|
1354
|
-
addr2 = self.pyapi.long_as_voidptr(addr1)
|
|
1355
|
-
self.builder.store(self.builder.bitcast(addr2, llty), fptr)
|
|
1356
|
-
self.pyapi.decref(addr1)
|
|
1357
|
-
self.pyapi.gil_release(gil_state)
|
|
1358
|
-
with orelse:
|
|
1359
|
-
self.builder.store(self.builder.bitcast(addr, llty), fptr)
|
|
1360
|
-
return self.builder.load(fptr)
|
|
1361
|
-
|
|
1362
1225
|
def _lower_call_normal(self, fnty, expr, signature):
|
|
1363
1226
|
# Normal function resolution
|
|
1364
1227
|
self.debug_print("# calling normal function: {0}".format(fnty))
|
|
@@ -4,11 +4,12 @@
|
|
|
4
4
|
import math
|
|
5
5
|
import operator
|
|
6
6
|
from llvmlite import ir
|
|
7
|
-
from numba.core import types, typing
|
|
7
|
+
from numba.core import types, typing
|
|
8
8
|
from numba.cuda import cgutils
|
|
9
9
|
from numba.core.imputils import Registry
|
|
10
10
|
from numba.types import float32, float64, int64, uint64
|
|
11
11
|
from numba.cuda import libdevice
|
|
12
|
+
from numba.cuda.core import targetconfig
|
|
12
13
|
|
|
13
14
|
registry = Registry()
|
|
14
15
|
lower = registry.lower
|
|
@@ -6,7 +6,10 @@ import os
|
|
|
6
6
|
from functools import wraps
|
|
7
7
|
import numpy as np
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
|
|
10
|
+
from numba import cuda, types
|
|
11
|
+
from numba.cuda import config
|
|
12
|
+
|
|
10
13
|
from numba.core.runtime.nrt import _nrt_mstats
|
|
11
14
|
from numba.cuda.cudadrv.driver import (
|
|
12
15
|
_Linker,
|
|
@@ -17,24 +20,11 @@ from numba.cuda.cudadrv.driver import (
|
|
|
17
20
|
)
|
|
18
21
|
from numba.cuda.cudadrv import devices
|
|
19
22
|
from numba.cuda.api import get_current_device
|
|
20
|
-
from numba.cuda.utils import
|
|
23
|
+
from numba.cuda.utils import cached_file_read
|
|
21
24
|
from numba.cuda.cudadrv.linkable_code import CUSource
|
|
25
|
+
from numba.cuda.typing.templates import signature
|
|
22
26
|
|
|
23
|
-
|
|
24
|
-
# Check environment variable or config for NRT statistics enablement
|
|
25
|
-
NRT_STATS = _readenv("NUMBA_CUDA_NRT_STATS", bool, False) or getattr(
|
|
26
|
-
config, "NUMBA_CUDA_NRT_STATS", False
|
|
27
|
-
)
|
|
28
|
-
if not hasattr(config, "NUMBA_CUDA_NRT_STATS"):
|
|
29
|
-
config.CUDA_NRT_STATS = NRT_STATS
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
# Check environment variable or config for NRT enablement
|
|
33
|
-
ENABLE_NRT = _readenv("NUMBA_CUDA_ENABLE_NRT", bool, False) or getattr(
|
|
34
|
-
config, "NUMBA_CUDA_ENABLE_NRT", False
|
|
35
|
-
)
|
|
36
|
-
if not hasattr(config, "NUMBA_CUDA_ENABLE_NRT"):
|
|
37
|
-
config.CUDA_ENABLE_NRT = ENABLE_NRT
|
|
27
|
+
from numba.core.extending import intrinsic, overload_classmethod
|
|
38
28
|
|
|
39
29
|
|
|
40
30
|
def get_include():
|
|
@@ -42,6 +32,34 @@ def get_include():
|
|
|
42
32
|
return os.path.dirname(os.path.abspath(__file__))
|
|
43
33
|
|
|
44
34
|
|
|
35
|
+
# Provide an implementation of Array._allocate() for the CUDA target (used
|
|
36
|
+
# internally by Numba when generating the allocation of an array)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@intrinsic
|
|
40
|
+
def intrin_alloc(typingctx, allocsize, align):
|
|
41
|
+
"""Intrinsic to call into the allocator for Array"""
|
|
42
|
+
|
|
43
|
+
def codegen(context, builder, signature, args):
|
|
44
|
+
allocsize, align = args
|
|
45
|
+
meminfo = context.nrt.meminfo_alloc_aligned(builder, allocsize, align)
|
|
46
|
+
return meminfo
|
|
47
|
+
|
|
48
|
+
mip = types.MemInfoPointer(types.voidptr) # return untyped pointer
|
|
49
|
+
sig = signature(mip, allocsize, align)
|
|
50
|
+
return sig, codegen
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@overload_classmethod(types.Array, "_allocate", target="CUDA")
|
|
54
|
+
def _ol_array_allocate(cls, allocsize, align):
|
|
55
|
+
"""Implements a Numba-only CUDA-target classmethod on the array type."""
|
|
56
|
+
|
|
57
|
+
def impl(cls, allocsize, align):
|
|
58
|
+
return intrin_alloc(allocsize, align)
|
|
59
|
+
|
|
60
|
+
return impl
|
|
61
|
+
|
|
62
|
+
|
|
45
63
|
# Protect method to ensure NRT memory allocation and initialization
|
|
46
64
|
def _alloc_init_guard(method):
|
|
47
65
|
"""
|
|
@@ -69,10 +87,18 @@ class _Runtime:
|
|
|
69
87
|
|
|
70
88
|
def __init__(self):
|
|
71
89
|
"""Initialize memsys module and variable"""
|
|
90
|
+
self._reset()
|
|
91
|
+
|
|
92
|
+
def _reset(self):
|
|
93
|
+
"""Reset to the uninitialized state"""
|
|
72
94
|
self._memsys_module = None
|
|
73
95
|
self._memsys = None
|
|
74
96
|
self._initialized = False
|
|
75
97
|
|
|
98
|
+
def close(self):
|
|
99
|
+
"""Close and reset"""
|
|
100
|
+
self._reset()
|
|
101
|
+
|
|
76
102
|
def _compile_memsys_module(self):
|
|
77
103
|
"""
|
|
78
104
|
Compile memsys.cu and create a module from it in the current context
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
import sys
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_lib_dirs():
|
|
10
|
+
"""
|
|
11
|
+
Anaconda specific
|
|
12
|
+
"""
|
|
13
|
+
if sys.platform == "win32":
|
|
14
|
+
# CUDA 12 puts in "bin" directory, whereas CUDA 13 puts in "bin\x64" directory
|
|
15
|
+
dirnames = [
|
|
16
|
+
os.path.join("Library", "bin"),
|
|
17
|
+
os.path.join("Library", "bin", "x64"),
|
|
18
|
+
os.path.join("Library", "nvvm", "bin"),
|
|
19
|
+
os.path.join("Library", "nvvm", "bin", "x64"),
|
|
20
|
+
]
|
|
21
|
+
else:
|
|
22
|
+
dirnames = [
|
|
23
|
+
"lib",
|
|
24
|
+
]
|
|
25
|
+
libdirs = [os.path.join(sys.prefix, x) for x in dirnames]
|
|
26
|
+
return libdirs
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
DLLNAMEMAP = {
|
|
30
|
+
"linux": r"lib%(name)s\.so\.%(ver)s$",
|
|
31
|
+
"linux2": r"lib%(name)s\.so\.%(ver)s$",
|
|
32
|
+
"linux-static": r"lib%(name)s\.a$",
|
|
33
|
+
"darwin": r"lib%(name)s\.%(ver)s\.dylib$",
|
|
34
|
+
"win32": r"%(name)s%(ver)s\.dll$",
|
|
35
|
+
"win32-static": r"%(name)s\.lib$",
|
|
36
|
+
"bsd": r"lib%(name)s\.so\.%(ver)s$",
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
RE_VER = r"[0-9]*([_\.][0-9]+)*"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def find_lib(libname, libdir=None, platform=None, static=False):
|
|
43
|
+
platform = platform or sys.platform
|
|
44
|
+
platform = "bsd" if "bsd" in platform else platform
|
|
45
|
+
if static:
|
|
46
|
+
platform = f"{platform}-static"
|
|
47
|
+
if platform not in DLLNAMEMAP:
|
|
48
|
+
# Return empty list if platform name is undefined.
|
|
49
|
+
# Not all platforms define their static library paths.
|
|
50
|
+
return []
|
|
51
|
+
pat = DLLNAMEMAP[platform] % {"name": libname, "ver": RE_VER}
|
|
52
|
+
regex = re.compile(pat)
|
|
53
|
+
return find_file(regex, libdir)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def find_file(pat, libdir=None):
|
|
57
|
+
if libdir is None:
|
|
58
|
+
libdirs = get_lib_dirs()
|
|
59
|
+
elif isinstance(libdir, str):
|
|
60
|
+
libdirs = [
|
|
61
|
+
libdir,
|
|
62
|
+
]
|
|
63
|
+
else:
|
|
64
|
+
libdirs = list(libdir)
|
|
65
|
+
files = []
|
|
66
|
+
for ldir in libdirs:
|
|
67
|
+
try:
|
|
68
|
+
entries = os.listdir(ldir)
|
|
69
|
+
except FileNotFoundError:
|
|
70
|
+
continue
|
|
71
|
+
candidates = [
|
|
72
|
+
os.path.join(ldir, ent) for ent in entries if pat.match(ent)
|
|
73
|
+
]
|
|
74
|
+
files.extend([c for c in candidates if os.path.isfile(c)])
|
|
75
|
+
return files
|
numba_cuda/numba/cuda/models.py
CHANGED
|
@@ -6,9 +6,10 @@ import functools
|
|
|
6
6
|
from llvmlite import ir
|
|
7
7
|
|
|
8
8
|
from numba.core.datamodel.registry import DataModelManager, register
|
|
9
|
+
from numba.core.datamodel import PrimitiveModel
|
|
9
10
|
from numba.core.extending import models
|
|
10
11
|
from numba.core import types
|
|
11
|
-
from numba.cuda.types import Dim3, GridGroup, CUDADispatcher
|
|
12
|
+
from numba.cuda.types import Dim3, GridGroup, CUDADispatcher, Bfloat16
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
cuda_data_manager = DataModelManager()
|
|
@@ -45,3 +46,10 @@ class FloatModel(models.PrimitiveModel):
|
|
|
45
46
|
|
|
46
47
|
|
|
47
48
|
register_model(CUDADispatcher)(models.OpaqueModel)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@register_model(Bfloat16)
|
|
52
|
+
class _model___nv_bfloat16(PrimitiveModel):
|
|
53
|
+
def __init__(self, dmm, fe_type):
|
|
54
|
+
be_type = ir.IntType(16)
|
|
55
|
+
super(_model___nv_bfloat16, self).__init__(dmm, fe_type, be_type)
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Helper functions for np.timedelta64 and np.datetime64.
|
|
6
|
+
For now, multiples-of-units (for example timedeltas expressed in tens
|
|
7
|
+
of seconds) are not supported.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
DATETIME_UNITS = {
|
|
14
|
+
"Y": 0, # Years
|
|
15
|
+
"M": 1, # Months
|
|
16
|
+
"W": 2, # Weeks
|
|
17
|
+
# Yes, there's a gap here
|
|
18
|
+
"D": 4, # Days
|
|
19
|
+
"h": 5, # Hours
|
|
20
|
+
"m": 6, # Minutes
|
|
21
|
+
"s": 7, # Seconds
|
|
22
|
+
"ms": 8, # Milliseconds
|
|
23
|
+
"us": 9, # Microseconds
|
|
24
|
+
"ns": 10, # Nanoseconds
|
|
25
|
+
"ps": 11, # Picoseconds
|
|
26
|
+
"fs": 12, # Femtoseconds
|
|
27
|
+
"as": 13, # Attoseconds
|
|
28
|
+
"": 14, # "generic", i.e. unit-less
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
NAT = np.timedelta64("nat").astype(np.int64)
|
|
32
|
+
|
|
33
|
+
# NOTE: numpy has several inconsistent functions for timedelta casting:
|
|
34
|
+
# - can_cast_timedelta64_{metadata,units}() disallows "safe" casting
|
|
35
|
+
# to and from generic units
|
|
36
|
+
# - cast_timedelta_to_timedelta() allows casting from (but not to)
|
|
37
|
+
# generic units
|
|
38
|
+
# - compute_datetime_metadata_greatest_common_divisor() allows casting from
|
|
39
|
+
# generic units (used for promotion)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def same_kind(src, dest):
|
|
43
|
+
"""
|
|
44
|
+
Whether the *src* and *dest* units are of the same kind.
|
|
45
|
+
"""
|
|
46
|
+
return (DATETIME_UNITS[src] < 5) == (DATETIME_UNITS[dest] < 5)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def can_cast_timedelta_units(src, dest):
|
|
50
|
+
# Mimic NumPy's "safe" casting and promotion
|
|
51
|
+
# `dest` must be more precise than `src` and they must be compatible
|
|
52
|
+
# for conversion.
|
|
53
|
+
# XXX should we switch to enforcing "same-kind" for Numpy 1.10+ ?
|
|
54
|
+
src = DATETIME_UNITS[src]
|
|
55
|
+
dest = DATETIME_UNITS[dest]
|
|
56
|
+
if src == dest:
|
|
57
|
+
return True
|
|
58
|
+
if src == 14:
|
|
59
|
+
return True
|
|
60
|
+
if src > dest:
|
|
61
|
+
return False
|
|
62
|
+
if dest == 14:
|
|
63
|
+
# unit-less timedelta64 is not compatible with anything else
|
|
64
|
+
return False
|
|
65
|
+
if src <= 1 and dest > 1:
|
|
66
|
+
# Cannot convert between months or years and other units
|
|
67
|
+
return False
|
|
68
|
+
return True
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Exact conversion factors from one unit to the immediately more precise one
|
|
72
|
+
_factors = {
|
|
73
|
+
0: (1, 12), # Years -> Months
|
|
74
|
+
2: (4, 7), # Weeks -> Days
|
|
75
|
+
4: (5, 24), # Days -> Hours
|
|
76
|
+
5: (6, 60), # Hours -> Minutes
|
|
77
|
+
6: (7, 60), # Minutes -> Seconds
|
|
78
|
+
7: (8, 1000),
|
|
79
|
+
8: (9, 1000),
|
|
80
|
+
9: (10, 1000),
|
|
81
|
+
10: (11, 1000),
|
|
82
|
+
11: (12, 1000),
|
|
83
|
+
12: (13, 1000),
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _get_conversion_multiplier(big_unit_code, small_unit_code):
|
|
88
|
+
"""
|
|
89
|
+
Return an integer multiplier allowing to convert from *big_unit_code*
|
|
90
|
+
to *small_unit_code*.
|
|
91
|
+
None is returned if the conversion is not possible through a
|
|
92
|
+
simple integer multiplication.
|
|
93
|
+
"""
|
|
94
|
+
# Mimics get_datetime_units_factor() in NumPy's datetime.c,
|
|
95
|
+
# with a twist to allow no-op conversion from generic units.
|
|
96
|
+
if big_unit_code == 14:
|
|
97
|
+
return 1
|
|
98
|
+
c = big_unit_code
|
|
99
|
+
factor = 1
|
|
100
|
+
while c < small_unit_code:
|
|
101
|
+
try:
|
|
102
|
+
c, mult = _factors[c]
|
|
103
|
+
except KeyError:
|
|
104
|
+
# No possible conversion
|
|
105
|
+
return None
|
|
106
|
+
factor *= mult
|
|
107
|
+
if c == small_unit_code:
|
|
108
|
+
return factor
|
|
109
|
+
else:
|
|
110
|
+
return None
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_timedelta_conversion_factor(src_unit, dest_unit):
|
|
114
|
+
"""
|
|
115
|
+
Return an integer multiplier allowing to convert from timedeltas
|
|
116
|
+
of *src_unit* to *dest_unit*.
|
|
117
|
+
"""
|
|
118
|
+
return _get_conversion_multiplier(
|
|
119
|
+
DATETIME_UNITS[src_unit], DATETIME_UNITS[dest_unit]
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def get_datetime_timedelta_conversion(datetime_unit, timedelta_unit):
|
|
124
|
+
"""
|
|
125
|
+
Compute a possible conversion for combining *datetime_unit* and
|
|
126
|
+
*timedelta_unit* (presumably for adding or subtracting).
|
|
127
|
+
Return (result unit, integer datetime multiplier, integer timedelta
|
|
128
|
+
multiplier). RuntimeError is raised if the combination is impossible.
|
|
129
|
+
"""
|
|
130
|
+
# XXX now unused (I don't know where / how Numpy uses this)
|
|
131
|
+
dt_unit_code = DATETIME_UNITS[datetime_unit]
|
|
132
|
+
td_unit_code = DATETIME_UNITS[timedelta_unit]
|
|
133
|
+
if td_unit_code == 14 or dt_unit_code == 14:
|
|
134
|
+
return datetime_unit, 1, 1
|
|
135
|
+
if td_unit_code < 2 and dt_unit_code >= 2:
|
|
136
|
+
# Cannot combine Y or M timedelta64 with a finer-grained datetime64
|
|
137
|
+
raise RuntimeError(
|
|
138
|
+
"cannot combine datetime64(%r) and timedelta64(%r)"
|
|
139
|
+
% (datetime_unit, timedelta_unit)
|
|
140
|
+
)
|
|
141
|
+
dt_factor, td_factor = 1, 1
|
|
142
|
+
|
|
143
|
+
# If years or months, the datetime unit is first scaled to weeks or days,
|
|
144
|
+
# then conversion continues below. This is the same algorithm as used
|
|
145
|
+
# in Numpy's get_datetime_conversion_factor() (src/multiarray/datetime.c):
|
|
146
|
+
# """Conversions between years/months and other units use
|
|
147
|
+
# the factor averaged over the 400 year leap year cycle."""
|
|
148
|
+
if dt_unit_code == 0:
|
|
149
|
+
if td_unit_code >= 4:
|
|
150
|
+
dt_factor = 97 + 400 * 365
|
|
151
|
+
td_factor = 400
|
|
152
|
+
dt_unit_code = 4
|
|
153
|
+
elif td_unit_code == 2:
|
|
154
|
+
dt_factor = 97 + 400 * 365
|
|
155
|
+
td_factor = 400 * 7
|
|
156
|
+
dt_unit_code = 2
|
|
157
|
+
elif dt_unit_code == 1:
|
|
158
|
+
if td_unit_code >= 4:
|
|
159
|
+
dt_factor = 97 + 400 * 365
|
|
160
|
+
td_factor = 400 * 12
|
|
161
|
+
dt_unit_code = 4
|
|
162
|
+
elif td_unit_code == 2:
|
|
163
|
+
dt_factor = 97 + 400 * 365
|
|
164
|
+
td_factor = 400 * 12 * 7
|
|
165
|
+
dt_unit_code = 2
|
|
166
|
+
|
|
167
|
+
if td_unit_code >= dt_unit_code:
|
|
168
|
+
factor = _get_conversion_multiplier(dt_unit_code, td_unit_code)
|
|
169
|
+
assert factor is not None, (dt_unit_code, td_unit_code)
|
|
170
|
+
return timedelta_unit, dt_factor * factor, td_factor
|
|
171
|
+
else:
|
|
172
|
+
factor = _get_conversion_multiplier(td_unit_code, dt_unit_code)
|
|
173
|
+
assert factor is not None, (dt_unit_code, td_unit_code)
|
|
174
|
+
return datetime_unit, dt_factor, td_factor * factor
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def combine_datetime_timedelta_units(datetime_unit, timedelta_unit):
|
|
178
|
+
"""
|
|
179
|
+
Return the unit result of combining *datetime_unit* with *timedelta_unit*
|
|
180
|
+
(e.g. by adding or subtracting). None is returned if combining
|
|
181
|
+
those units is forbidden.
|
|
182
|
+
"""
|
|
183
|
+
dt_unit_code = DATETIME_UNITS[datetime_unit]
|
|
184
|
+
td_unit_code = DATETIME_UNITS[timedelta_unit]
|
|
185
|
+
if dt_unit_code == 14:
|
|
186
|
+
return timedelta_unit
|
|
187
|
+
elif td_unit_code == 14:
|
|
188
|
+
return datetime_unit
|
|
189
|
+
if td_unit_code < 2 and dt_unit_code >= 2:
|
|
190
|
+
return None
|
|
191
|
+
if dt_unit_code > td_unit_code:
|
|
192
|
+
return datetime_unit
|
|
193
|
+
else:
|
|
194
|
+
return timedelta_unit
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def get_best_unit(unit_a, unit_b):
|
|
198
|
+
"""
|
|
199
|
+
Get the best (i.e. finer-grained) of two units.
|
|
200
|
+
"""
|
|
201
|
+
a = DATETIME_UNITS[unit_a]
|
|
202
|
+
b = DATETIME_UNITS[unit_b]
|
|
203
|
+
if a == 14:
|
|
204
|
+
return unit_b
|
|
205
|
+
if b == 14:
|
|
206
|
+
return unit_a
|
|
207
|
+
if b > a:
|
|
208
|
+
return unit_b
|
|
209
|
+
return unit_a
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def datetime_minimum(a, b):
|
|
213
|
+
pass
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def datetime_maximum(a, b):
|
|
217
|
+
pass
|