numba-cuda 0.17.0__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of numba-cuda might be problematic.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +0 -8
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +14225 -0
- numba_cuda/numba/cuda/api_util.py +6 -0
- numba_cuda/numba/cuda/cgutils.py +1291 -0
- numba_cuda/numba/cuda/codegen.py +32 -14
- numba_cuda/numba/cuda/compiler.py +113 -10
- numba_cuda/numba/cuda/core/caching.py +741 -0
- numba_cuda/numba/cuda/core/callconv.py +338 -0
- numba_cuda/numba/cuda/core/codegen.py +168 -0
- numba_cuda/numba/cuda/core/compiler.py +205 -0
- numba_cuda/numba/cuda/core/typed_passes.py +139 -0
- numba_cuda/numba/cuda/cudadecl.py +0 -268
- numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
- numba_cuda/numba/cuda/cudadrv/driver.py +2 -1
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +4 -178
- numba_cuda/numba/cuda/debuginfo.py +469 -3
- numba_cuda/numba/cuda/device_init.py +0 -1
- numba_cuda/numba/cuda/dispatcher.py +310 -11
- numba_cuda/numba/cuda/extending.py +2 -1
- numba_cuda/numba/cuda/fp16.py +348 -0
- numba_cuda/numba/cuda/intrinsics.py +1 -1
- numba_cuda/numba/cuda/libdeviceimpl.py +2 -1
- numba_cuda/numba/cuda/lowering.py +1833 -8
- numba_cuda/numba/cuda/mathimpl.py +2 -90
- numba_cuda/numba/cuda/nvvmutils.py +2 -1
- numba_cuda/numba/cuda/printimpl.py +2 -1
- numba_cuda/numba/cuda/serialize.py +264 -0
- numba_cuda/numba/cuda/simulator/__init__.py +2 -0
- numba_cuda/numba/cuda/simulator/dispatcher.py +7 -0
- numba_cuda/numba/cuda/stubs.py +0 -308
- numba_cuda/numba/cuda/target.py +13 -5
- numba_cuda/numba/cuda/testing.py +156 -5
- numba_cuda/numba/cuda/tests/complex_usecases.py +113 -0
- numba_cuda/numba/cuda/tests/core/serialize_usecases.py +110 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +359 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +10 -4
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +33 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +2 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +5 -10
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +15 -0
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +381 -0
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +108 -24
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +37 -23
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +43 -27
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +26 -9
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +27 -2
- numba_cuda/numba/cuda/tests/enum_usecases.py +56 -0
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +1 -2
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +1 -1
- numba_cuda/numba/cuda/utils.py +785 -0
- numba_cuda/numba/cuda/vector_types.py +1 -1
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/METADATA +18 -4
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/RECORD +63 -50
- numba_cuda/numba/cuda/cpp_function_wrappers.cu +0 -46
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/WHEEL +0 -0
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/top_level.txt +0 -0
numba_cuda/numba/cuda/core/typed_passes.py (new file)

@@ -0,0 +1,139 @@
+import abc
+import warnings
+from contextlib import contextmanager
+from numba.core import errors, types, funcdesc
+from numba.core.compiler_machinery import LoweringPass
+from llvmlite import binding as llvm
+
+
+@contextmanager
+def fallback_context(state, msg):
+    """
+    Wraps code that would signal a fallback to object mode
+    """
+    try:
+        yield
+    except Exception as e:
+        if not state.status.can_fallback:
+            raise
+        else:
+            # Clear all references attached to the traceback
+            e = e.with_traceback(None)
+            # this emits a warning containing the error message body in the
+            # case of fallback from npm to objmode
+            loop_lift = "" if state.flags.enable_looplift else "OUT"
+            msg_rewrite = (
+                "\nCompilation is falling back to object mode "
+                "WITH%s looplifting enabled because %s" % (loop_lift, msg)
+            )
+            warnings.warn_explicit(
+                "%s due to: %s" % (msg_rewrite, e),
+                errors.NumbaWarning,
+                state.func_id.filename,
+                state.func_id.firstlineno,
+            )
+            raise
+
+
+class BaseNativeLowering(abc.ABC, LoweringPass):
+    """The base class for a lowering pass. The lowering functionality must be
+    specified in inheriting classes by providing an appropriate lowering class
+    implementation in the overridden `lowering_class` property."""
+
+    _name = None
+
+    def __init__(self):
+        LoweringPass.__init__(self)
+
+    @property
+    @abc.abstractmethod
+    def lowering_class(self):
+        """Returns the class that performs the lowering of the IR describing the
+        function that is the target of the current compilation."""
+        pass
+
+    def run_pass(self, state):
+        if state.library is None:
+            codegen = state.targetctx.codegen()
+            state.library = codegen.create_library(state.func_id.func_qualname)
+            # Enable object caching upfront, so that the library can
+            # be later serialized.
+            state.library.enable_object_caching()
+
+        library = state.library
+        targetctx = state.targetctx
+        interp = state.func_ir  # why is it called this?!
+        typemap = state.typemap
+        restype = state.return_type
+        calltypes = state.calltypes
+        flags = state.flags
+        metadata = state.metadata
+        pre_stats = llvm.passmanagers.dump_refprune_stats()
+
+        msg = "Function %s failed at nopython mode lowering" % (
+            state.func_id.func_name,
+        )
+        with fallback_context(state, msg):
+            # Lowering
+            fndesc = (
+                funcdesc.PythonFunctionDescriptor.from_specialized_function(
+                    interp,
+                    typemap,
+                    restype,
+                    calltypes,
+                    mangler=targetctx.mangler,
+                    inline=flags.forceinline,
+                    noalias=flags.noalias,
+                    abi_tags=[flags.get_mangle_string()],
+                )
+            )
+
+            with targetctx.push_code_library(library):
+                lower = self.lowering_class(
+                    targetctx, library, fndesc, interp, metadata=metadata
+                )
+                lower.lower()
+                if not flags.no_cpython_wrapper:
+                    lower.create_cpython_wrapper(flags.release_gil)
+
+                if not flags.no_cfunc_wrapper:
+                    # skip cfunc wrapper generation if unsupported
+                    # argument or return types are used
+                    for t in state.args:
+                        if isinstance(t, (types.Omitted, types.Generator)):
+                            break
+                    else:
+                        if isinstance(
+                            restype, (types.Optional, types.Generator)
+                        ):
+                            pass
+                        else:
+                            lower.create_cfunc_wrapper()
+
+                env = lower.env
+                call_helper = lower.call_helper
+                del lower
+
+            from numba.core.compiler import _LowerResult  # TODO: move this
+
+            if flags.no_compile:
+                state["cr"] = _LowerResult(
+                    fndesc, call_helper, cfunc=None, env=env
+                )
+            else:
+                # Prepare for execution
+                # Insert native function for use by other jitted-functions.
+                # We also register its library to allow for inlining.
+                cfunc = targetctx.get_executable(library, fndesc, env)
+                targetctx.insert_user_function(cfunc, fndesc, [library])
+                state["cr"] = _LowerResult(
+                    fndesc, call_helper, cfunc=cfunc, env=env
+                )
+
+        # capture pruning stats
+        post_stats = llvm.passmanagers.dump_refprune_stats()
+        metadata["prune_stats"] = post_stats - pre_stats
+
+        # Save the LLVM pass timings
+        metadata["llvm_pass_timings"] = library.recorded_timings
+        return True
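The new `BaseNativeLowering` pass is abstract: `run_pass` above does all of the generic work, and a concrete pass only has to supply a `_name` and the `lowering_class` property. A minimal sketch of what such a subclass could look like; the subclass name and the `CUDALower` import location are illustrative assumptions, not taken from this diff:

```python
from numba.core.compiler_machinery import register_pass

from numba.cuda.core.typed_passes import BaseNativeLowering


@register_pass(mutates_CFG=True, analysis_only=False)
class SketchNativeLowering(BaseNativeLowering):
    """Hypothetical concrete pass: run_pass() in the base class instantiates
    lowering_class(targetctx, library, fndesc, func_ir, metadata=...) and
    drives lowering and wrapper creation."""

    _name = "sketch_native_lowering"

    @property
    def lowering_class(self):
        # Assumed location of a Lower implementation; any class exposing the
        # same constructor and lower()/create_*_wrapper() API would do.
        from numba.cuda.lowering import CUDALower

        return CUDALower
```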
numba_cuda/numba/cuda/cudadecl.py

@@ -1,4 +1,3 @@
-import operator
 from numba.core import errors, types
 from numba.core.typing.npydecl import (
     parse_dtype,
@@ -19,9 +18,7 @@ from numba.core.typing.templates import (
     Registry,
 )
 from numba.cuda.types import dim3
-from numba.core.typeconv import Conversion
 from numba import cuda
-from numba.cuda.compiler import declare_device_function
 
 registry = Registry()
 register = registry.register
@@ -188,14 +185,6 @@ class Cuda_fma(ConcreteTemplate):
     ]
 
 
-@register
-class Cuda_hfma(ConcreteTemplate):
-    key = cuda.fp16.hfma
-    cases = [
-        signature(types.float16, types.float16, types.float16, types.float16)
-    ]
-
-
 @register
 class Cuda_cbrt(ConcreteTemplate):
     key = cuda.cbrt
@@ -281,37 +270,6 @@ class Cuda_selp(AbstractTemplate):
         return signature(a, test, a, a)
 
 
-def _genfp16_unary(l_key):
-    @register
-    class Cuda_fp16_unary(ConcreteTemplate):
-        key = l_key
-        cases = [signature(types.float16, types.float16)]
-
-    return Cuda_fp16_unary
-
-
-def _genfp16_unary_operator(l_key):
-    @register_global(l_key)
-    class Cuda_fp16_unary(AbstractTemplate):
-        key = l_key
-
-        def generic(self, args, kws):
-            assert not kws
-            if len(args) == 1 and args[0] == types.float16:
-                return signature(types.float16, types.float16)
-
-    return Cuda_fp16_unary
-
-
-def _genfp16_binary(l_key):
-    @register
-    class Cuda_fp16_binary(ConcreteTemplate):
-        key = l_key
-        cases = [signature(types.float16, types.float16, types.float16)]
-
-    return Cuda_fp16_binary
-
-
 @register_global(float)
 class Float(AbstractTemplate):
     def generic(self, args, kws):
@@ -323,16 +281,6 @@ class Float(AbstractTemplate):
         return signature(arg, arg)
 
 
-def _genfp16_binary_comparison(l_key):
-    @register
-    class Cuda_fp16_cmp(ConcreteTemplate):
-        key = l_key
-
-        cases = [signature(types.b1, types.float16, types.float16)]
-
-    return Cuda_fp16_cmp
-
-
 # If multiple ConcreteTemplates provide typing for a single function, then
 # function resolution will pick the first compatible typing it finds even if it
 # involves inserting a cast that would be considered undesirable (in this
@@ -347,124 +295,6 @@ def _genfp16_binary_comparison(l_key):
 # with a ConcreteTemplate to simplify the logic.
 
 
-def _fp16_binary_operator(l_key, retty):
-    @register_global(l_key)
-    class Cuda_fp16_operator(AbstractTemplate):
-        key = l_key
-
-        def generic(self, args, kws):
-            assert not kws
-
-            if len(args) == 2 and (
-                args[0] == types.float16 or args[1] == types.float16
-            ):
-                if args[0] == types.float16:
-                    convertible = self.context.can_convert(args[1], args[0])
-                else:
-                    convertible = self.context.can_convert(args[0], args[1])
-
-                # We allow three cases here:
-                #
-                # 1. fp16 to fp16 - Conversion.exact
-                # 2. fp16 to other types fp16 can be promoted to
-                #    - Conversion.promote
-                # 3. fp16 to int8 (safe conversion) -
-                #    - Conversion.safe
-
-                if (
-                    (convertible == Conversion.exact)
-                    or (convertible == Conversion.promote)
-                    or (convertible == Conversion.safe)
-                ):
-                    return signature(retty, types.float16, types.float16)
-
-    return Cuda_fp16_operator
-
-
-def _genfp16_comparison_operator(op):
-    return _fp16_binary_operator(op, types.b1)
-
-
-def _genfp16_binary_operator(op):
-    return _fp16_binary_operator(op, types.float16)
-
-
-Cuda_hadd = _genfp16_binary(cuda.fp16.hadd)
-Cuda_add = _genfp16_binary_operator(operator.add)
-Cuda_iadd = _genfp16_binary_operator(operator.iadd)
-Cuda_hsub = _genfp16_binary(cuda.fp16.hsub)
-Cuda_sub = _genfp16_binary_operator(operator.sub)
-Cuda_isub = _genfp16_binary_operator(operator.isub)
-Cuda_hmul = _genfp16_binary(cuda.fp16.hmul)
-Cuda_mul = _genfp16_binary_operator(operator.mul)
-Cuda_imul = _genfp16_binary_operator(operator.imul)
-Cuda_hmax = _genfp16_binary(cuda.fp16.hmax)
-Cuda_hmin = _genfp16_binary(cuda.fp16.hmin)
-Cuda_hneg = _genfp16_unary(cuda.fp16.hneg)
-Cuda_neg = _genfp16_unary_operator(operator.neg)
-Cuda_habs = _genfp16_unary(cuda.fp16.habs)
-Cuda_abs = _genfp16_unary_operator(abs)
-Cuda_heq = _genfp16_binary_comparison(cuda.fp16.heq)
-_genfp16_comparison_operator(operator.eq)
-Cuda_hne = _genfp16_binary_comparison(cuda.fp16.hne)
-_genfp16_comparison_operator(operator.ne)
-Cuda_hge = _genfp16_binary_comparison(cuda.fp16.hge)
-_genfp16_comparison_operator(operator.ge)
-Cuda_hgt = _genfp16_binary_comparison(cuda.fp16.hgt)
-_genfp16_comparison_operator(operator.gt)
-Cuda_hle = _genfp16_binary_comparison(cuda.fp16.hle)
-_genfp16_comparison_operator(operator.le)
-Cuda_hlt = _genfp16_binary_comparison(cuda.fp16.hlt)
-_genfp16_comparison_operator(operator.lt)
-_genfp16_binary_operator(operator.truediv)
-_genfp16_binary_operator(operator.itruediv)
-
-
-def _resolve_wrapped_unary(fname):
-    link = tuple()
-    decl = declare_device_function(
-        f"__numba_wrapper_{fname}",
-        types.float16,
-        (types.float16,),
-        link,
-        use_cooperative=False,
-    )
-    return types.Function(decl)
-
-
-def _resolve_wrapped_binary(fname):
-    link = tuple()
-    decl = declare_device_function(
-        f"__numba_wrapper_{fname}",
-        types.float16,
-        (
-            types.float16,
-            types.float16,
-        ),
-        link,
-        use_cooperative=False,
-    )
-    return types.Function(decl)
-
-
-hsin_device = _resolve_wrapped_unary("hsin")
-hcos_device = _resolve_wrapped_unary("hcos")
-hlog_device = _resolve_wrapped_unary("hlog")
-hlog10_device = _resolve_wrapped_unary("hlog10")
-hlog2_device = _resolve_wrapped_unary("hlog2")
-hexp_device = _resolve_wrapped_unary("hexp")
-hexp10_device = _resolve_wrapped_unary("hexp10")
-hexp2_device = _resolve_wrapped_unary("hexp2")
-hsqrt_device = _resolve_wrapped_unary("hsqrt")
-hrsqrt_device = _resolve_wrapped_unary("hrsqrt")
-hfloor_device = _resolve_wrapped_unary("hfloor")
-hceil_device = _resolve_wrapped_unary("hceil")
-hrcp_device = _resolve_wrapped_unary("hrcp")
-hrint_device = _resolve_wrapped_unary("hrint")
-htrunc_device = _resolve_wrapped_unary("htrunc")
-hdiv_device = _resolve_wrapped_binary("hdiv")
-
-
 # generate atomic operations
 def _gen(l_key, supported_types):
     @register
@@ -641,101 +471,6 @@ class CudaAtomicTemplate(AttributeTemplate):
         return types.Function(Cuda_atomic_cas)
 
 
-@register_attr
-class CudaFp16Template(AttributeTemplate):
-    key = types.Module(cuda.fp16)
-
-    def resolve_hadd(self, mod):
-        return types.Function(Cuda_hadd)
-
-    def resolve_hsub(self, mod):
-        return types.Function(Cuda_hsub)
-
-    def resolve_hmul(self, mod):
-        return types.Function(Cuda_hmul)
-
-    def resolve_hdiv(self, mod):
-        return hdiv_device
-
-    def resolve_hneg(self, mod):
-        return types.Function(Cuda_hneg)
-
-    def resolve_habs(self, mod):
-        return types.Function(Cuda_habs)
-
-    def resolve_hfma(self, mod):
-        return types.Function(Cuda_hfma)
-
-    def resolve_hsin(self, mod):
-        return hsin_device
-
-    def resolve_hcos(self, mod):
-        return hcos_device
-
-    def resolve_hlog(self, mod):
-        return hlog_device
-
-    def resolve_hlog10(self, mod):
-        return hlog10_device
-
-    def resolve_hlog2(self, mod):
-        return hlog2_device
-
-    def resolve_hexp(self, mod):
-        return hexp_device
-
-    def resolve_hexp10(self, mod):
-        return hexp10_device
-
-    def resolve_hexp2(self, mod):
-        return hexp2_device
-
-    def resolve_hfloor(self, mod):
-        return hfloor_device
-
-    def resolve_hceil(self, mod):
-        return hceil_device
-
-    def resolve_hsqrt(self, mod):
-        return hsqrt_device
-
-    def resolve_hrsqrt(self, mod):
-        return hrsqrt_device
-
-    def resolve_hrcp(self, mod):
-        return hrcp_device
-
-    def resolve_hrint(self, mod):
-        return hrint_device
-
-    def resolve_htrunc(self, mod):
-        return htrunc_device
-
-    def resolve_heq(self, mod):
-        return types.Function(Cuda_heq)
-
-    def resolve_hne(self, mod):
-        return types.Function(Cuda_hne)
-
-    def resolve_hge(self, mod):
-        return types.Function(Cuda_hge)
-
-    def resolve_hgt(self, mod):
-        return types.Function(Cuda_hgt)
-
-    def resolve_hle(self, mod):
-        return types.Function(Cuda_hle)
-
-    def resolve_hlt(self, mod):
-        return types.Function(Cuda_hlt)
-
-    def resolve_hmax(self, mod):
-        return types.Function(Cuda_hmax)
-
-    def resolve_hmin(self, mod):
-        return types.Function(Cuda_hmin)
-
-
 @register_attr
 class CudaModuleTemplate(AttributeTemplate):
     key = types.Module(cuda)
@@ -815,9 +550,6 @@ class CudaModuleTemplate(AttributeTemplate):
     def resolve_atomic(self, mod):
        return types.Module(cuda.atomic)
 
-    def resolve_fp16(self, mod):
-        return types.Module(cuda.fp16)
-
     def resolve_const(self, mod):
         return types.Module(cuda.const)
 
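All of the float16 typing removed above (the operator templates, the `cuda.fp16.*` intrinsic templates, and the `CudaFp16Template` attribute template) leaves cudadecl.py; per the file listing, the new `numba_cuda/numba/cuda/fp16.py` module (+348 lines) takes over fp16 support. User-level kernels are expected to keep compiling unchanged. A small sketch, assuming a CUDA GPU with float16 arithmetic support (compute capability 5.3 or later):

```python
import numpy as np
from numba import cuda


@cuda.jit
def scale_and_add(x, y, out):
    i = cuda.grid(1)
    if i < out.size:
        # float16 operator typing, previously provided by the removed
        # templates in cudadecl.py
        out[i] = x[i] * y[i] + x[i]


x = np.arange(32, dtype=np.float16)
y = np.full(32, 2.0, dtype=np.float16)
out = np.zeros_like(x)
scale_and_add[1, 32](x, y, out)
print(out[:4])
```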
numba_cuda/numba/cuda/cudadrv/devicearray.py

@@ -92,6 +92,9 @@ class DeviceNDArrayBase(_devicearray.DeviceArray):
         self._dummy = dummyarray.Array.from_desc(
             0, shape, strides, dtype.itemsize
         )
+        # confirm that all elements of shape are ints
+        if not all(isinstance(dim, (int, np.integer)) for dim in shape):
+            raise TypeError("all elements of shape must be ints")
         self.shape = tuple(shape)
         self.strides = tuple(strides)
         self.dtype = dtype
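The added check rejects non-integer shape elements as soon as a device array is constructed. A quick sketch of the behaviour this is expected to produce through the usual `cuda.device_array` entry point (requires a CUDA device):

```python
from numba import cuda

# Integer shapes are fine.
arr = cuda.device_array((4, 8), dtype="float32")

# A float that sneaks into the shape now fails early with a clear
# TypeError ("all elements of shape must be ints") instead of surfacing
# as a confusing error later on.
try:
    cuda.device_array((4, 8.0), dtype="float32")
except TypeError as exc:
    print(exc)
```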
numba_cuda/numba/cuda/cudadrv/driver.py

@@ -44,7 +44,8 @@ from collections import namedtuple, deque
 
 
 from numba import mviewbuf
-from numba.core import
+from numba.core import config
+from numba.cuda import utils, serialize
 from .error import CudaSupportError, CudaDriverError
 from .drvapi import API_PROTOTYPES
 from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
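`utils` and `serialize` are now imported from the vendored copies under `numba.cuda` (the new `utils.py` and `serialize.py` in the file listing), while `config` still comes from `numba.core`. For reference, the new import locations as a stand-alone snippet:

```python
# Matches the added lines in the hunk above: helper modules are vendored
# under numba.cuda in 0.18.x, configuration is still read from numba.core.
from numba.core import config
from numba.cuda import utils, serialize
```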
numba_cuda/numba/cuda/cudadrv/nvvm.py

@@ -14,7 +14,7 @@ from llvmlite import ir
 
 from .error import NvvmError, NvvmSupportError, NvvmWarning
 from .libs import get_libdevice, open_libdevice, open_cudalib
-from numba.
+from numba.cuda import cgutils
 
 
 logger = logging.getLogger(__name__)

numba_cuda/numba/cuda/cudaimpl.py

@@ -6,15 +6,16 @@ import struct
 from llvmlite import ir
 import llvmlite.binding as ll
 
-from numba.core.imputils import Registry
+from numba.core.imputils import Registry
 from numba.core.typing.npydecl import parse_dtype
 from numba.core.datamodel import models
-from numba.core import types
+from numba.core import types
+from numba.cuda import cgutils
 from numba.np import ufunc_db
 from numba.np.npyimpl import register_ufuncs
 from .cudadrv import nvvm
 from numba import cuda
-from numba.cuda import nvvmutils, stubs
+from numba.cuda import nvvmutils, stubs
 from numba.cuda.types import dim3, CUDADispatcher
 
 registry = Registry()
@@ -346,181 +347,6 @@ def ptx_fma(context, builder, sig, args):
     return builder.fma(*args)
 
 
-def float16_float_ty_constraint(bitwidth):
-    typemap = {32: ("f32", "f"), 64: ("f64", "d")}
-
-    try:
-        return typemap[bitwidth]
-    except KeyError:
-        msg = f"Conversion between float16 and float{bitwidth} unsupported"
-        raise errors.CudaLoweringError(msg)
-
-
-@lower_cast(types.float16, types.Float)
-def float16_to_float_cast(context, builder, fromty, toty, val):
-    if fromty.bitwidth == toty.bitwidth:
-        return val
-
-    ty, constraint = float16_float_ty_constraint(toty.bitwidth)
-
-    fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
-    asm = ir.InlineAsm(fnty, f"cvt.{ty}.f16 $0, $1;", f"={constraint},h")
-    return builder.call(asm, [val])
-
-
-@lower_cast(types.Float, types.float16)
-def float_to_float16_cast(context, builder, fromty, toty, val):
-    if fromty.bitwidth == toty.bitwidth:
-        return val
-
-    ty, constraint = float16_float_ty_constraint(fromty.bitwidth)
-
-    fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
-    asm = ir.InlineAsm(fnty, f"cvt.rn.f16.{ty} $0, $1;", f"=h,{constraint}")
-    return builder.call(asm, [val])
-
-
-def float16_int_constraint(bitwidth):
-    typemap = {8: "c", 16: "h", 32: "r", 64: "l"}
-
-    try:
-        return typemap[bitwidth]
-    except KeyError:
-        msg = f"Conversion between float16 and int{bitwidth} unsupported"
-        raise errors.CudaLoweringError(msg)
-
-
-@lower_cast(types.float16, types.Integer)
-def float16_to_integer_cast(context, builder, fromty, toty, val):
-    bitwidth = toty.bitwidth
-    constraint = float16_int_constraint(bitwidth)
-    signedness = "s" if toty.signed else "u"
-
-    fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
-    asm = ir.InlineAsm(
-        fnty, f"cvt.rni.{signedness}{bitwidth}.f16 $0, $1;", f"={constraint},h"
-    )
-    return builder.call(asm, [val])
-
-
-@lower_cast(types.Integer, types.float16)
-@lower_cast(types.IntegerLiteral, types.float16)
-def integer_to_float16_cast(context, builder, fromty, toty, val):
-    bitwidth = fromty.bitwidth
-    constraint = float16_int_constraint(bitwidth)
-    signedness = "s" if fromty.signed else "u"
-
-    fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
-    asm = ir.InlineAsm(
-        fnty, f"cvt.rn.f16.{signedness}{bitwidth} $0, $1;", f"=h,{constraint}"
-    )
-    return builder.call(asm, [val])
-
-
-def lower_fp16_binary(fn, op):
-    @lower(fn, types.float16, types.float16)
-    def ptx_fp16_binary(context, builder, sig, args):
-        fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
-        asm = ir.InlineAsm(fnty, f"{op}.f16 $0,$1,$2;", "=h,h,h")
-        return builder.call(asm, args)
-
-
-lower_fp16_binary(stubs.fp16.hadd, "add")
-lower_fp16_binary(operator.add, "add")
-lower_fp16_binary(operator.iadd, "add")
-lower_fp16_binary(stubs.fp16.hsub, "sub")
-lower_fp16_binary(operator.sub, "sub")
-lower_fp16_binary(operator.isub, "sub")
-lower_fp16_binary(stubs.fp16.hmul, "mul")
-lower_fp16_binary(operator.mul, "mul")
-lower_fp16_binary(operator.imul, "mul")
-
-
-@lower(stubs.fp16.hneg, types.float16)
-def ptx_fp16_hneg(context, builder, sig, args):
-    fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
-    asm = ir.InlineAsm(fnty, "neg.f16 $0, $1;", "=h,h")
-    return builder.call(asm, args)
-
-
-@lower(operator.neg, types.float16)
-def operator_hneg(context, builder, sig, args):
-    return ptx_fp16_hneg(context, builder, sig, args)
-
-
-@lower(stubs.fp16.habs, types.float16)
-def ptx_fp16_habs(context, builder, sig, args):
-    fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
-    asm = ir.InlineAsm(fnty, "abs.f16 $0, $1;", "=h,h")
-    return builder.call(asm, args)
-
-
-@lower(abs, types.float16)
-def operator_habs(context, builder, sig, args):
-    return ptx_fp16_habs(context, builder, sig, args)
-
-
-@lower(stubs.fp16.hfma, types.float16, types.float16, types.float16)
-def ptx_hfma(context, builder, sig, args):
-    argtys = [ir.IntType(16), ir.IntType(16), ir.IntType(16)]
-    fnty = ir.FunctionType(ir.IntType(16), argtys)
-    asm = ir.InlineAsm(fnty, "fma.rn.f16 $0,$1,$2,$3;", "=h,h,h,h")
-    return builder.call(asm, args)
-
-
-@lower(operator.truediv, types.float16, types.float16)
-@lower(operator.itruediv, types.float16, types.float16)
-def fp16_div_impl(context, builder, sig, args):
-    def fp16_div(x, y):
-        return cuda.fp16.hdiv(x, y)
-
-    return context.compile_internal(builder, fp16_div, sig, args)
-
-
-_fp16_cmp = """{{
-    .reg .pred __$$f16_cmp_tmp;
-    setp.{op}.f16 __$$f16_cmp_tmp, $1, $2;
-    selp.u16 $0, 1, 0, __$$f16_cmp_tmp;
-}}"""
-
-
-def _gen_fp16_cmp(op):
-    def ptx_fp16_comparison(context, builder, sig, args):
-        fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
-        asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op), "=h,h,h")
-        result = builder.call(asm, args)
-
-        zero = context.get_constant(types.int16, 0)
-        int_result = builder.bitcast(result, ir.IntType(16))
-        return builder.icmp_unsigned("!=", int_result, zero)
-
-    return ptx_fp16_comparison
-
-
-lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
-lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
-lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
-lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
-lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
-lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
-lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
-lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
-lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp("le"))
-lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp("le"))
-lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
-lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
-
-
-def lower_fp16_minmax(fn, fname, op):
-    @lower(fn, types.float16, types.float16)
-    def ptx_fp16_minmax(context, builder, sig, args):
-        choice = _gen_fp16_cmp(op)(context, builder, sig, args)
-        return builder.select(choice, args[0], args[1])
-
-
-lower_fp16_minmax(stubs.fp16.hmax, "max", "gt")
-lower_fp16_minmax(stubs.fp16.hmin, "min", "lt")
-
 # See:
 # https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrt.html#__nv_cbrt
 # https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrtf.html#__nv_cbrtf