numba-cuda 0.17.0__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (64)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +0 -8
  3. numba_cuda/numba/cuda/_internal/cuda_fp16.py +14225 -0
  4. numba_cuda/numba/cuda/api_util.py +6 -0
  5. numba_cuda/numba/cuda/cgutils.py +1291 -0
  6. numba_cuda/numba/cuda/codegen.py +32 -14
  7. numba_cuda/numba/cuda/compiler.py +113 -10
  8. numba_cuda/numba/cuda/core/caching.py +741 -0
  9. numba_cuda/numba/cuda/core/callconv.py +338 -0
  10. numba_cuda/numba/cuda/core/codegen.py +168 -0
  11. numba_cuda/numba/cuda/core/compiler.py +205 -0
  12. numba_cuda/numba/cuda/core/typed_passes.py +139 -0
  13. numba_cuda/numba/cuda/cudadecl.py +0 -268
  14. numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
  15. numba_cuda/numba/cuda/cudadrv/driver.py +2 -1
  16. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -1
  17. numba_cuda/numba/cuda/cudaimpl.py +4 -178
  18. numba_cuda/numba/cuda/debuginfo.py +469 -3
  19. numba_cuda/numba/cuda/device_init.py +0 -1
  20. numba_cuda/numba/cuda/dispatcher.py +310 -11
  21. numba_cuda/numba/cuda/extending.py +2 -1
  22. numba_cuda/numba/cuda/fp16.py +348 -0
  23. numba_cuda/numba/cuda/intrinsics.py +1 -1
  24. numba_cuda/numba/cuda/libdeviceimpl.py +2 -1
  25. numba_cuda/numba/cuda/lowering.py +1833 -8
  26. numba_cuda/numba/cuda/mathimpl.py +2 -90
  27. numba_cuda/numba/cuda/nvvmutils.py +2 -1
  28. numba_cuda/numba/cuda/printimpl.py +2 -1
  29. numba_cuda/numba/cuda/serialize.py +264 -0
  30. numba_cuda/numba/cuda/simulator/__init__.py +2 -0
  31. numba_cuda/numba/cuda/simulator/dispatcher.py +7 -0
  32. numba_cuda/numba/cuda/stubs.py +0 -308
  33. numba_cuda/numba/cuda/target.py +13 -5
  34. numba_cuda/numba/cuda/testing.py +156 -5
  35. numba_cuda/numba/cuda/tests/complex_usecases.py +113 -0
  36. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +110 -0
  37. numba_cuda/numba/cuda/tests/core/test_serialize.py +359 -0
  38. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +10 -4
  39. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +33 -0
  40. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +2 -2
  41. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +1 -0
  42. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  43. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +5 -10
  44. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +15 -0
  45. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  46. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +381 -0
  47. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +1 -1
  48. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
  49. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +108 -24
  50. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +37 -23
  51. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +43 -27
  52. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +26 -9
  53. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +27 -2
  54. numba_cuda/numba/cuda/tests/enum_usecases.py +56 -0
  55. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +1 -2
  56. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +1 -1
  57. numba_cuda/numba/cuda/utils.py +785 -0
  58. numba_cuda/numba/cuda/vector_types.py +1 -1
  59. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/METADATA +18 -4
  60. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/RECORD +63 -50
  61. numba_cuda/numba/cuda/cpp_function_wrappers.cu +0 -46
  62. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/WHEEL +0 -0
  63. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/licenses/LICENSE +0 -0
  64. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/top_level.txt +0 -0
numba_cuda/numba/cuda/core/typed_passes.py
@@ -0,0 +1,139 @@
+ import abc
+ import warnings
+ from contextlib import contextmanager
+ from numba.core import errors, types, funcdesc
+ from numba.core.compiler_machinery import LoweringPass
+ from llvmlite import binding as llvm
+
+
+ @contextmanager
+ def fallback_context(state, msg):
+     """
+     Wraps code that would signal a fallback to object mode
+     """
+     try:
+         yield
+     except Exception as e:
+         if not state.status.can_fallback:
+             raise
+         else:
+             # Clear all references attached to the traceback
+             e = e.with_traceback(None)
+             # this emits a warning containing the error message body in the
+             # case of fallback from npm to objmode
+             loop_lift = "" if state.flags.enable_looplift else "OUT"
+             msg_rewrite = (
+                 "\nCompilation is falling back to object mode "
+                 "WITH%s looplifting enabled because %s" % (loop_lift, msg)
+             )
+             warnings.warn_explicit(
+                 "%s due to: %s" % (msg_rewrite, e),
+                 errors.NumbaWarning,
+                 state.func_id.filename,
+                 state.func_id.firstlineno,
+             )
+             raise
+
+
+ class BaseNativeLowering(abc.ABC, LoweringPass):
+     """The base class for a lowering pass. The lowering functionality must be
+     specified in inheriting classes by providing an appropriate lowering class
+     implementation in the overridden `lowering_class` property."""
+
+     _name = None
+
+     def __init__(self):
+         LoweringPass.__init__(self)
+
+     @property
+     @abc.abstractmethod
+     def lowering_class(self):
+         """Returns the class that performs the lowering of the IR describing the
+         function that is the target of the current compilation."""
+         pass
+
+     def run_pass(self, state):
+         if state.library is None:
+             codegen = state.targetctx.codegen()
+             state.library = codegen.create_library(state.func_id.func_qualname)
+             # Enable object caching upfront, so that the library can
+             # be later serialized.
+             state.library.enable_object_caching()
+
+         library = state.library
+         targetctx = state.targetctx
+         interp = state.func_ir  # why is it called this?!
+         typemap = state.typemap
+         restype = state.return_type
+         calltypes = state.calltypes
+         flags = state.flags
+         metadata = state.metadata
+         pre_stats = llvm.passmanagers.dump_refprune_stats()
+
+         msg = "Function %s failed at nopython mode lowering" % (
+             state.func_id.func_name,
+         )
+         with fallback_context(state, msg):
+             # Lowering
+             fndesc = (
+                 funcdesc.PythonFunctionDescriptor.from_specialized_function(
+                     interp,
+                     typemap,
+                     restype,
+                     calltypes,
+                     mangler=targetctx.mangler,
+                     inline=flags.forceinline,
+                     noalias=flags.noalias,
+                     abi_tags=[flags.get_mangle_string()],
+                 )
+             )
+
+             with targetctx.push_code_library(library):
+                 lower = self.lowering_class(
+                     targetctx, library, fndesc, interp, metadata=metadata
+                 )
+                 lower.lower()
+                 if not flags.no_cpython_wrapper:
+                     lower.create_cpython_wrapper(flags.release_gil)
+
+                 if not flags.no_cfunc_wrapper:
+                     # skip cfunc wrapper generation if unsupported
+                     # argument or return types are used
+                     for t in state.args:
+                         if isinstance(t, (types.Omitted, types.Generator)):
+                             break
+                     else:
+                         if isinstance(
+                             restype, (types.Optional, types.Generator)
+                         ):
+                             pass
+                         else:
+                             lower.create_cfunc_wrapper()
+
+                 env = lower.env
+                 call_helper = lower.call_helper
+                 del lower
+
+             from numba.core.compiler import _LowerResult  # TODO: move this
+
+             if flags.no_compile:
+                 state["cr"] = _LowerResult(
+                     fndesc, call_helper, cfunc=None, env=env
+                 )
+             else:
+                 # Prepare for execution
+                 # Insert native function for use by other jitted-functions.
+                 # We also register its library to allow for inlining.
+                 cfunc = targetctx.get_executable(library, fndesc, env)
+                 targetctx.insert_user_function(cfunc, fndesc, [library])
+                 state["cr"] = _LowerResult(
+                     fndesc, call_helper, cfunc=cfunc, env=env
+                 )
+
+             # capture pruning stats
+             post_stats = llvm.passmanagers.dump_refprune_stats()
+             metadata["prune_stats"] = post_stats - pre_stats
+
+             # Save the LLVM pass timings
+             metadata["llvm_pass_timings"] = library.recorded_timings
+         return True
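
The lowering_class property above is abstract, so this new pass is only usable through a subclass that names a concrete lowering implementation. A minimal sketch of such a subclass, assuming the module installs as numba.cuda.core.typed_passes and reusing numba.core.lowering.Lower purely for illustration (neither choice is taken from this diff):

    from numba.core.compiler_machinery import register_pass
    from numba.core.lowering import Lower
    from numba.cuda.core.typed_passes import BaseNativeLowering

    @register_pass(mutates_CFG=True, analysis_only=False)
    class ExampleNativeLowering(BaseNativeLowering):
        # run_pass() instantiates lowering_class and calls .lower() on the
        # typed IR, so any Lower-compatible class can be returned here.
        _name = "example_native_lowering"

        @property
        def lowering_class(self):
            return Lower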
numba_cuda/numba/cuda/cudadecl.py
@@ -1,4 +1,3 @@
- import operator
  from numba.core import errors, types
  from numba.core.typing.npydecl import (
      parse_dtype,
@@ -19,9 +18,7 @@ from numba.core.typing.templates import (
      Registry,
  )
  from numba.cuda.types import dim3
- from numba.core.typeconv import Conversion
  from numba import cuda
- from numba.cuda.compiler import declare_device_function

  registry = Registry()
  register = registry.register
@@ -188,14 +185,6 @@ class Cuda_fma(ConcreteTemplate):
      ]


- @register
- class Cuda_hfma(ConcreteTemplate):
-     key = cuda.fp16.hfma
-     cases = [
-         signature(types.float16, types.float16, types.float16, types.float16)
-     ]
-
-
  @register
  class Cuda_cbrt(ConcreteTemplate):
      key = cuda.cbrt
@@ -281,37 +270,6 @@ class Cuda_selp(AbstractTemplate):
          return signature(a, test, a, a)


- def _genfp16_unary(l_key):
-     @register
-     class Cuda_fp16_unary(ConcreteTemplate):
-         key = l_key
-         cases = [signature(types.float16, types.float16)]
-
-     return Cuda_fp16_unary
-
-
- def _genfp16_unary_operator(l_key):
-     @register_global(l_key)
-     class Cuda_fp16_unary(AbstractTemplate):
-         key = l_key
-
-         def generic(self, args, kws):
-             assert not kws
-             if len(args) == 1 and args[0] == types.float16:
-                 return signature(types.float16, types.float16)
-
-     return Cuda_fp16_unary
-
-
- def _genfp16_binary(l_key):
-     @register
-     class Cuda_fp16_binary(ConcreteTemplate):
-         key = l_key
-         cases = [signature(types.float16, types.float16, types.float16)]
-
-     return Cuda_fp16_binary
-
-
  @register_global(float)
  class Float(AbstractTemplate):
      def generic(self, args, kws):
@@ -323,16 +281,6 @@ class Float(AbstractTemplate):
          return signature(arg, arg)


- def _genfp16_binary_comparison(l_key):
-     @register
-     class Cuda_fp16_cmp(ConcreteTemplate):
-         key = l_key
-
-         cases = [signature(types.b1, types.float16, types.float16)]
-
-     return Cuda_fp16_cmp
-
-
  # If multiple ConcreteTemplates provide typing for a single function, then
  # function resolution will pick the first compatible typing it finds even if it
  # involves inserting a cast that would be considered undesirable (in this
@@ -347,124 +295,6 @@ def _genfp16_binary_comparison(l_key):
  # with a ConcreteTemplate to simplify the logic.


- def _fp16_binary_operator(l_key, retty):
-     @register_global(l_key)
-     class Cuda_fp16_operator(AbstractTemplate):
-         key = l_key
-
-         def generic(self, args, kws):
-             assert not kws
-
-             if len(args) == 2 and (
-                 args[0] == types.float16 or args[1] == types.float16
-             ):
-                 if args[0] == types.float16:
-                     convertible = self.context.can_convert(args[1], args[0])
-                 else:
-                     convertible = self.context.can_convert(args[0], args[1])
-
-                 # We allow three cases here:
-                 #
-                 # 1. fp16 to fp16 - Conversion.exact
-                 # 2. fp16 to other types fp16 can be promoted to
-                 #    - Conversion.promote
-                 # 3. fp16 to int8 (safe conversion) -
-                 #    - Conversion.safe
-
-                 if (
-                     (convertible == Conversion.exact)
-                     or (convertible == Conversion.promote)
-                     or (convertible == Conversion.safe)
-                 ):
-                     return signature(retty, types.float16, types.float16)
-
-     return Cuda_fp16_operator
-
-
- def _genfp16_comparison_operator(op):
-     return _fp16_binary_operator(op, types.b1)
-
-
- def _genfp16_binary_operator(op):
-     return _fp16_binary_operator(op, types.float16)
-
-
- Cuda_hadd = _genfp16_binary(cuda.fp16.hadd)
- Cuda_add = _genfp16_binary_operator(operator.add)
- Cuda_iadd = _genfp16_binary_operator(operator.iadd)
- Cuda_hsub = _genfp16_binary(cuda.fp16.hsub)
- Cuda_sub = _genfp16_binary_operator(operator.sub)
- Cuda_isub = _genfp16_binary_operator(operator.isub)
- Cuda_hmul = _genfp16_binary(cuda.fp16.hmul)
- Cuda_mul = _genfp16_binary_operator(operator.mul)
- Cuda_imul = _genfp16_binary_operator(operator.imul)
- Cuda_hmax = _genfp16_binary(cuda.fp16.hmax)
- Cuda_hmin = _genfp16_binary(cuda.fp16.hmin)
- Cuda_hneg = _genfp16_unary(cuda.fp16.hneg)
- Cuda_neg = _genfp16_unary_operator(operator.neg)
- Cuda_habs = _genfp16_unary(cuda.fp16.habs)
- Cuda_abs = _genfp16_unary_operator(abs)
- Cuda_heq = _genfp16_binary_comparison(cuda.fp16.heq)
- _genfp16_comparison_operator(operator.eq)
- Cuda_hne = _genfp16_binary_comparison(cuda.fp16.hne)
- _genfp16_comparison_operator(operator.ne)
- Cuda_hge = _genfp16_binary_comparison(cuda.fp16.hge)
- _genfp16_comparison_operator(operator.ge)
- Cuda_hgt = _genfp16_binary_comparison(cuda.fp16.hgt)
- _genfp16_comparison_operator(operator.gt)
- Cuda_hle = _genfp16_binary_comparison(cuda.fp16.hle)
- _genfp16_comparison_operator(operator.le)
- Cuda_hlt = _genfp16_binary_comparison(cuda.fp16.hlt)
- _genfp16_comparison_operator(operator.lt)
- _genfp16_binary_operator(operator.truediv)
- _genfp16_binary_operator(operator.itruediv)
-
-
- def _resolve_wrapped_unary(fname):
-     link = tuple()
-     decl = declare_device_function(
-         f"__numba_wrapper_{fname}",
-         types.float16,
-         (types.float16,),
-         link,
-         use_cooperative=False,
-     )
-     return types.Function(decl)
-
-
- def _resolve_wrapped_binary(fname):
-     link = tuple()
-     decl = declare_device_function(
-         f"__numba_wrapper_{fname}",
-         types.float16,
-         (
-             types.float16,
-             types.float16,
-         ),
-         link,
-         use_cooperative=False,
-     )
-     return types.Function(decl)
-
-
- hsin_device = _resolve_wrapped_unary("hsin")
- hcos_device = _resolve_wrapped_unary("hcos")
- hlog_device = _resolve_wrapped_unary("hlog")
- hlog10_device = _resolve_wrapped_unary("hlog10")
- hlog2_device = _resolve_wrapped_unary("hlog2")
- hexp_device = _resolve_wrapped_unary("hexp")
- hexp10_device = _resolve_wrapped_unary("hexp10")
- hexp2_device = _resolve_wrapped_unary("hexp2")
- hsqrt_device = _resolve_wrapped_unary("hsqrt")
- hrsqrt_device = _resolve_wrapped_unary("hrsqrt")
- hfloor_device = _resolve_wrapped_unary("hfloor")
- hceil_device = _resolve_wrapped_unary("hceil")
- hrcp_device = _resolve_wrapped_unary("hrcp")
- hrint_device = _resolve_wrapped_unary("hrint")
- htrunc_device = _resolve_wrapped_unary("htrunc")
- hdiv_device = _resolve_wrapped_binary("hdiv")
-
-
  # generate atomic operations
  def _gen(l_key, supported_types):
      @register
@@ -641,101 +471,6 @@ class CudaAtomicTemplate(AttributeTemplate):
          return types.Function(Cuda_atomic_cas)


- @register_attr
- class CudaFp16Template(AttributeTemplate):
-     key = types.Module(cuda.fp16)
-
-     def resolve_hadd(self, mod):
-         return types.Function(Cuda_hadd)
-
-     def resolve_hsub(self, mod):
-         return types.Function(Cuda_hsub)
-
-     def resolve_hmul(self, mod):
-         return types.Function(Cuda_hmul)
-
-     def resolve_hdiv(self, mod):
-         return hdiv_device
-
-     def resolve_hneg(self, mod):
-         return types.Function(Cuda_hneg)
-
-     def resolve_habs(self, mod):
-         return types.Function(Cuda_habs)
-
-     def resolve_hfma(self, mod):
-         return types.Function(Cuda_hfma)
-
-     def resolve_hsin(self, mod):
-         return hsin_device
-
-     def resolve_hcos(self, mod):
-         return hcos_device
-
-     def resolve_hlog(self, mod):
-         return hlog_device
-
-     def resolve_hlog10(self, mod):
-         return hlog10_device
-
-     def resolve_hlog2(self, mod):
-         return hlog2_device
-
-     def resolve_hexp(self, mod):
-         return hexp_device
-
-     def resolve_hexp10(self, mod):
-         return hexp10_device
-
-     def resolve_hexp2(self, mod):
-         return hexp2_device
-
-     def resolve_hfloor(self, mod):
-         return hfloor_device
-
-     def resolve_hceil(self, mod):
-         return hceil_device
-
-     def resolve_hsqrt(self, mod):
-         return hsqrt_device
-
-     def resolve_hrsqrt(self, mod):
-         return hrsqrt_device
-
-     def resolve_hrcp(self, mod):
-         return hrcp_device
-
-     def resolve_hrint(self, mod):
-         return hrint_device
-
-     def resolve_htrunc(self, mod):
-         return htrunc_device
-
-     def resolve_heq(self, mod):
-         return types.Function(Cuda_heq)
-
-     def resolve_hne(self, mod):
-         return types.Function(Cuda_hne)
-
-     def resolve_hge(self, mod):
-         return types.Function(Cuda_hge)
-
-     def resolve_hgt(self, mod):
-         return types.Function(Cuda_hgt)
-
-     def resolve_hle(self, mod):
-         return types.Function(Cuda_hle)
-
-     def resolve_hlt(self, mod):
-         return types.Function(Cuda_hlt)
-
-     def resolve_hmax(self, mod):
-         return types.Function(Cuda_hmax)
-
-     def resolve_hmin(self, mod):
-         return types.Function(Cuda_hmin)
-
-
  @register_attr
  class CudaModuleTemplate(AttributeTemplate):
      key = types.Module(cuda)
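
The CudaFp16Template removal above (and the related fp16 typings removed from cudadecl.py in the surrounding hunks) does not drop float16 support: the files-changed list shows new fp16.py and _internal/cuda_fp16.py modules that appear to take over this functionality. For orientation, a minimal sketch of the user-facing API those registrations serve, assuming a CUDA-capable device (the kernel itself is illustrative, not from this diff):

    import numpy as np
    from numba import cuda

    @cuda.jit
    def fma_kernel(x, y, out):
        i = cuda.grid(1)
        if i < out.size:
            # float16 arithmetic and intrinsics such as cuda.fp16.hfma
            # go through the typing/lowering machinery touched in this diff
            out[i] = cuda.fp16.hfma(x[i], x[i], y[i])

    x = np.arange(8, dtype=np.float16)
    y = np.ones(8, dtype=np.float16)
    out = np.zeros(8, dtype=np.float16)
    fma_kernel[1, 8](x, y, out)  # out == x * x + y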
@@ -815,9 +550,6 @@ class CudaModuleTemplate(AttributeTemplate):
      def resolve_atomic(self, mod):
          return types.Module(cuda.atomic)

-     def resolve_fp16(self, mod):
-         return types.Module(cuda.fp16)
-
      def resolve_const(self, mod):
          return types.Module(cuda.const)

numba_cuda/numba/cuda/cudadrv/devicearray.py
@@ -92,6 +92,9 @@ class DeviceNDArrayBase(_devicearray.DeviceArray):
          self._dummy = dummyarray.Array.from_desc(
              0, shape, strides, dtype.itemsize
          )
+         # confirm that all elements of shape are ints
+         if not all(isinstance(dim, (int, np.integer)) for dim in shape):
+             raise TypeError("all elements of shape must be ints")
          self.shape = tuple(shape)
          self.strides = tuple(strides)
          self.dtype = dtype
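
The three added lines above make DeviceNDArrayBase.__init__ reject shapes whose elements are not Python or NumPy integers. A small sketch of the new behaviour, assuming a CUDA device is available for the allocation (the error message is taken from the diff; constructing DeviceNDArray directly like this is only for illustration):

    import numpy as np
    from numba.cuda.cudadrv.devicearray import DeviceNDArray

    # Integer shape elements are accepted as before.
    d_ok = DeviceNDArray(shape=(2, 3), strides=(24, 8), dtype=np.float64)

    # A float in the shape now fails fast with a clear TypeError
    # instead of propagating a malformed shape downstream.
    try:
        DeviceNDArray(shape=(2.0, 3), strides=(24, 8), dtype=np.float64)
    except TypeError as e:
        print(e)  # all elements of shape must be ints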
numba_cuda/numba/cuda/cudadrv/driver.py
@@ -44,7 +44,8 @@ from collections import namedtuple, deque


  from numba import mviewbuf
- from numba.core import utils, serialize, config
+ from numba.core import config
+ from numba.cuda import utils, serialize
  from .error import CudaSupportError, CudaDriverError
  from .drvapi import API_PROTOTYPES
  from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
numba_cuda/numba/cuda/cudadrv/nvvm.py
@@ -14,7 +14,7 @@ from llvmlite import ir

  from .error import NvvmError, NvvmSupportError, NvvmWarning
  from .libs import get_libdevice, open_libdevice, open_cudalib
- from numba.core import cgutils
+ from numba.cuda import cgutils


  logger = logging.getLogger(__name__)
numba_cuda/numba/cuda/cudaimpl.py
@@ -6,15 +6,16 @@ import struct
  from llvmlite import ir
  import llvmlite.binding as ll

- from numba.core.imputils import Registry, lower_cast
+ from numba.core.imputils import Registry
  from numba.core.typing.npydecl import parse_dtype
  from numba.core.datamodel import models
- from numba.core import types, cgutils
+ from numba.core import types
+ from numba.cuda import cgutils
  from numba.np import ufunc_db
  from numba.np.npyimpl import register_ufuncs
  from .cudadrv import nvvm
  from numba import cuda
- from numba.cuda import nvvmutils, stubs, errors
+ from numba.cuda import nvvmutils, stubs
  from numba.cuda.types import dim3, CUDADispatcher

  registry = Registry()
@@ -346,181 +347,6 @@ def ptx_fma(context, builder, sig, args):
      return builder.fma(*args)


- def float16_float_ty_constraint(bitwidth):
-     typemap = {32: ("f32", "f"), 64: ("f64", "d")}
-
-     try:
-         return typemap[bitwidth]
-     except KeyError:
-         msg = f"Conversion between float16 and float{bitwidth} unsupported"
-         raise errors.CudaLoweringError(msg)
-
-
- @lower_cast(types.float16, types.Float)
- def float16_to_float_cast(context, builder, fromty, toty, val):
-     if fromty.bitwidth == toty.bitwidth:
-         return val
-
-     ty, constraint = float16_float_ty_constraint(toty.bitwidth)
-
-     fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
-     asm = ir.InlineAsm(fnty, f"cvt.{ty}.f16 $0, $1;", f"={constraint},h")
-     return builder.call(asm, [val])
-
-
- @lower_cast(types.Float, types.float16)
- def float_to_float16_cast(context, builder, fromty, toty, val):
-     if fromty.bitwidth == toty.bitwidth:
-         return val
-
-     ty, constraint = float16_float_ty_constraint(fromty.bitwidth)
-
-     fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
-     asm = ir.InlineAsm(fnty, f"cvt.rn.f16.{ty} $0, $1;", f"=h,{constraint}")
-     return builder.call(asm, [val])
-
-
- def float16_int_constraint(bitwidth):
-     typemap = {8: "c", 16: "h", 32: "r", 64: "l"}
-
-     try:
-         return typemap[bitwidth]
-     except KeyError:
-         msg = f"Conversion between float16 and int{bitwidth} unsupported"
-         raise errors.CudaLoweringError(msg)
-
-
- @lower_cast(types.float16, types.Integer)
- def float16_to_integer_cast(context, builder, fromty, toty, val):
-     bitwidth = toty.bitwidth
-     constraint = float16_int_constraint(bitwidth)
-     signedness = "s" if toty.signed else "u"
-
-     fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
-     asm = ir.InlineAsm(
-         fnty, f"cvt.rni.{signedness}{bitwidth}.f16 $0, $1;", f"={constraint},h"
-     )
-     return builder.call(asm, [val])
-
-
- @lower_cast(types.Integer, types.float16)
- @lower_cast(types.IntegerLiteral, types.float16)
- def integer_to_float16_cast(context, builder, fromty, toty, val):
-     bitwidth = fromty.bitwidth
-     constraint = float16_int_constraint(bitwidth)
-     signedness = "s" if fromty.signed else "u"
-
-     fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
-     asm = ir.InlineAsm(
-         fnty, f"cvt.rn.f16.{signedness}{bitwidth} $0, $1;", f"=h,{constraint}"
-     )
-     return builder.call(asm, [val])
-
-
- def lower_fp16_binary(fn, op):
-     @lower(fn, types.float16, types.float16)
-     def ptx_fp16_binary(context, builder, sig, args):
-         fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
-         asm = ir.InlineAsm(fnty, f"{op}.f16 $0,$1,$2;", "=h,h,h")
-         return builder.call(asm, args)
-
-
- lower_fp16_binary(stubs.fp16.hadd, "add")
- lower_fp16_binary(operator.add, "add")
- lower_fp16_binary(operator.iadd, "add")
- lower_fp16_binary(stubs.fp16.hsub, "sub")
- lower_fp16_binary(operator.sub, "sub")
- lower_fp16_binary(operator.isub, "sub")
- lower_fp16_binary(stubs.fp16.hmul, "mul")
- lower_fp16_binary(operator.mul, "mul")
- lower_fp16_binary(operator.imul, "mul")
-
-
- @lower(stubs.fp16.hneg, types.float16)
- def ptx_fp16_hneg(context, builder, sig, args):
-     fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
-     asm = ir.InlineAsm(fnty, "neg.f16 $0, $1;", "=h,h")
-     return builder.call(asm, args)
-
-
- @lower(operator.neg, types.float16)
- def operator_hneg(context, builder, sig, args):
-     return ptx_fp16_hneg(context, builder, sig, args)
-
-
- @lower(stubs.fp16.habs, types.float16)
- def ptx_fp16_habs(context, builder, sig, args):
-     fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
-     asm = ir.InlineAsm(fnty, "abs.f16 $0, $1;", "=h,h")
-     return builder.call(asm, args)
-
-
- @lower(abs, types.float16)
- def operator_habs(context, builder, sig, args):
-     return ptx_fp16_habs(context, builder, sig, args)
-
-
- @lower(stubs.fp16.hfma, types.float16, types.float16, types.float16)
- def ptx_hfma(context, builder, sig, args):
-     argtys = [ir.IntType(16), ir.IntType(16), ir.IntType(16)]
-     fnty = ir.FunctionType(ir.IntType(16), argtys)
-     asm = ir.InlineAsm(fnty, "fma.rn.f16 $0,$1,$2,$3;", "=h,h,h,h")
-     return builder.call(asm, args)
-
-
- @lower(operator.truediv, types.float16, types.float16)
- @lower(operator.itruediv, types.float16, types.float16)
- def fp16_div_impl(context, builder, sig, args):
-     def fp16_div(x, y):
-         return cuda.fp16.hdiv(x, y)
-
-     return context.compile_internal(builder, fp16_div, sig, args)
-
-
- _fp16_cmp = """{{
-     .reg .pred __$$f16_cmp_tmp;
-     setp.{op}.f16 __$$f16_cmp_tmp, $1, $2;
-     selp.u16 $0, 1, 0, __$$f16_cmp_tmp;
- }}"""
-
-
- def _gen_fp16_cmp(op):
-     def ptx_fp16_comparison(context, builder, sig, args):
-         fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
-         asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op), "=h,h,h")
-         result = builder.call(asm, args)
-
-         zero = context.get_constant(types.int16, 0)
-         int_result = builder.bitcast(result, ir.IntType(16))
-         return builder.icmp_unsigned("!=", int_result, zero)
-
-     return ptx_fp16_comparison
-
-
- lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
- lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
- lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
- lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
- lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
- lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
- lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
- lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
- lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp("le"))
- lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp("le"))
- lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
- lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
-
-
- def lower_fp16_minmax(fn, fname, op):
-     @lower(fn, types.float16, types.float16)
-     def ptx_fp16_minmax(context, builder, sig, args):
-         choice = _gen_fp16_cmp(op)(context, builder, sig, args)
-         return builder.select(choice, args[0], args[1])
-
-
- lower_fp16_minmax(stubs.fp16.hmax, "max", "gt")
- lower_fp16_minmax(stubs.fp16.hmin, "min", "lt")
-
  # See:
  # https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrt.html#__nv_cbrt
  # https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrtf.html#__nv_cbrtf