numba-cuda 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +0 -8
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +14225 -0
- numba_cuda/numba/cuda/api_util.py +6 -0
- numba_cuda/numba/cuda/cgutils.py +1291 -0
- numba_cuda/numba/cuda/codegen.py +32 -14
- numba_cuda/numba/cuda/compiler.py +113 -10
- numba_cuda/numba/cuda/core/caching.py +741 -0
- numba_cuda/numba/cuda/core/callconv.py +338 -0
- numba_cuda/numba/cuda/core/codegen.py +168 -0
- numba_cuda/numba/cuda/core/compiler.py +205 -0
- numba_cuda/numba/cuda/core/typed_passes.py +139 -0
- numba_cuda/numba/cuda/cudadecl.py +0 -268
- numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
- numba_cuda/numba/cuda/cudadrv/driver.py +2 -1
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +4 -178
- numba_cuda/numba/cuda/debuginfo.py +469 -3
- numba_cuda/numba/cuda/device_init.py +0 -1
- numba_cuda/numba/cuda/dispatcher.py +309 -11
- numba_cuda/numba/cuda/extending.py +2 -1
- numba_cuda/numba/cuda/fp16.py +348 -0
- numba_cuda/numba/cuda/intrinsics.py +1 -1
- numba_cuda/numba/cuda/libdeviceimpl.py +2 -1
- numba_cuda/numba/cuda/lowering.py +1833 -8
- numba_cuda/numba/cuda/mathimpl.py +2 -90
- numba_cuda/numba/cuda/nvvmutils.py +2 -1
- numba_cuda/numba/cuda/printimpl.py +2 -1
- numba_cuda/numba/cuda/serialize.py +264 -0
- numba_cuda/numba/cuda/simulator/__init__.py +2 -0
- numba_cuda/numba/cuda/simulator/dispatcher.py +7 -0
- numba_cuda/numba/cuda/stubs.py +0 -308
- numba_cuda/numba/cuda/target.py +13 -5
- numba_cuda/numba/cuda/testing.py +156 -5
- numba_cuda/numba/cuda/tests/complex_usecases.py +113 -0
- numba_cuda/numba/cuda/tests/core/serialize_usecases.py +110 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +359 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +33 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +2 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +5 -10
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +381 -0
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +94 -24
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +37 -23
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +43 -27
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +26 -9
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +27 -2
- numba_cuda/numba/cuda/tests/enum_usecases.py +56 -0
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +1 -2
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +1 -1
- numba_cuda/numba/cuda/utils.py +785 -0
- numba_cuda/numba/cuda/vector_types.py +1 -1
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/METADATA +18 -4
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/RECORD +61 -48
- numba_cuda/numba/cuda/cpp_function_wrappers.cu +0 -46
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/WHEEL +0 -0
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/top_level.txt +0 -0
numba_cuda/numba/cuda/stubs.py
CHANGED
|
@@ -542,314 +542,6 @@ class nanosleep(Stub):
|
|
|
542
542
|
_description_ = "<nansleep()>"
|
|
543
543
|
|
|
544
544
|
|
|
545
|
-
# -------------------------------------------------------------------------------
|
|
546
|
-
# Floating point 16
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
class fp16(Stub):
|
|
550
|
-
"""Namespace for fp16 operations"""
|
|
551
|
-
|
|
552
|
-
_description_ = "<fp16>"
|
|
553
|
-
|
|
554
|
-
class hadd(Stub):
|
|
555
|
-
"""hadd(a, b)
|
|
556
|
-
|
|
557
|
-
Perform fp16 addition, (a + b) in round to nearest mode. Supported
|
|
558
|
-
on fp16 operands only.
|
|
559
|
-
|
|
560
|
-
Returns the fp16 result of the addition.
|
|
561
|
-
|
|
562
|
-
"""
|
|
563
|
-
|
|
564
|
-
class hsub(Stub):
|
|
565
|
-
"""hsub(a, b)
|
|
566
|
-
|
|
567
|
-
Perform fp16 subtraction, (a - b) in round to nearest mode. Supported
|
|
568
|
-
on fp16 operands only.
|
|
569
|
-
|
|
570
|
-
Returns the fp16 result of the subtraction.
|
|
571
|
-
|
|
572
|
-
"""
|
|
573
|
-
|
|
574
|
-
class hmul(Stub):
|
|
575
|
-
"""hmul(a, b)
|
|
576
|
-
|
|
577
|
-
Perform fp16 multiplication, (a * b) in round to nearest mode. Supported
|
|
578
|
-
on fp16 operands only.
|
|
579
|
-
|
|
580
|
-
Returns the fp16 result of the multiplication.
|
|
581
|
-
|
|
582
|
-
"""
|
|
583
|
-
|
|
584
|
-
class hdiv(Stub):
|
|
585
|
-
"""hdiv(a, b)
|
|
586
|
-
|
|
587
|
-
Perform fp16 division, (a / b) in round to nearest mode. Supported
|
|
588
|
-
on fp16 operands only.
|
|
589
|
-
|
|
590
|
-
Returns the fp16 result of the division
|
|
591
|
-
|
|
592
|
-
"""
|
|
593
|
-
|
|
594
|
-
class hfma(Stub):
|
|
595
|
-
"""hfma(a, b, c)
|
|
596
|
-
|
|
597
|
-
Perform fp16 multiply and accumulate, (a * b) + c in round to nearest
|
|
598
|
-
mode. Supported on fp16 operands only.
|
|
599
|
-
|
|
600
|
-
Returns the fp16 result of the multiplication.
|
|
601
|
-
|
|
602
|
-
"""
|
|
603
|
-
|
|
604
|
-
class hneg(Stub):
|
|
605
|
-
"""hneg(a)
|
|
606
|
-
|
|
607
|
-
Perform fp16 negation, -(a). Supported on fp16 operands only.
|
|
608
|
-
|
|
609
|
-
Returns the fp16 result of the negation.
|
|
610
|
-
|
|
611
|
-
"""
|
|
612
|
-
|
|
613
|
-
class habs(Stub):
|
|
614
|
-
"""habs(a)
|
|
615
|
-
|
|
616
|
-
Perform fp16 absolute value, |a|. Supported on fp16 operands only.
|
|
617
|
-
|
|
618
|
-
Returns the fp16 result of the absolute value.
|
|
619
|
-
|
|
620
|
-
"""
|
|
621
|
-
|
|
622
|
-
class hsin(Stub):
|
|
623
|
-
"""hsin(a)
|
|
624
|
-
|
|
625
|
-
Calculate sine in round to nearest even mode. Supported on fp16
|
|
626
|
-
operands only.
|
|
627
|
-
|
|
628
|
-
Returns the sine result.
|
|
629
|
-
|
|
630
|
-
"""
|
|
631
|
-
|
|
632
|
-
class hcos(Stub):
|
|
633
|
-
"""hsin(a)
|
|
634
|
-
|
|
635
|
-
Calculate cosine in round to nearest even mode. Supported on fp16
|
|
636
|
-
operands only.
|
|
637
|
-
|
|
638
|
-
Returns the cosine result.
|
|
639
|
-
|
|
640
|
-
"""
|
|
641
|
-
|
|
642
|
-
class hlog(Stub):
|
|
643
|
-
"""hlog(a)
|
|
644
|
-
|
|
645
|
-
Calculate natural logarithm in round to nearest even mode. Supported
|
|
646
|
-
on fp16 operands only.
|
|
647
|
-
|
|
648
|
-
Returns the natural logarithm result.
|
|
649
|
-
|
|
650
|
-
"""
|
|
651
|
-
|
|
652
|
-
class hlog10(Stub):
|
|
653
|
-
"""hlog10(a)
|
|
654
|
-
|
|
655
|
-
Calculate logarithm base 10 in round to nearest even mode. Supported
|
|
656
|
-
on fp16 operands only.
|
|
657
|
-
|
|
658
|
-
Returns the logarithm base 10 result.
|
|
659
|
-
|
|
660
|
-
"""
|
|
661
|
-
|
|
662
|
-
class hlog2(Stub):
|
|
663
|
-
"""hlog2(a)
|
|
664
|
-
|
|
665
|
-
Calculate logarithm base 2 in round to nearest even mode. Supported
|
|
666
|
-
on fp16 operands only.
|
|
667
|
-
|
|
668
|
-
Returns the logarithm base 2 result.
|
|
669
|
-
|
|
670
|
-
"""
|
|
671
|
-
|
|
672
|
-
class hexp(Stub):
|
|
673
|
-
"""hexp(a)
|
|
674
|
-
|
|
675
|
-
Calculate natural exponential, exp(a), in round to nearest mode.
|
|
676
|
-
Supported on fp16 operands only.
|
|
677
|
-
|
|
678
|
-
Returns the natural exponential result.
|
|
679
|
-
|
|
680
|
-
"""
|
|
681
|
-
|
|
682
|
-
class hexp10(Stub):
|
|
683
|
-
"""hexp10(a)
|
|
684
|
-
|
|
685
|
-
Calculate exponential base 10 (10 ** a) in round to nearest mode.
|
|
686
|
-
Supported on fp16 operands only.
|
|
687
|
-
|
|
688
|
-
Returns the exponential base 10 result.
|
|
689
|
-
|
|
690
|
-
"""
|
|
691
|
-
|
|
692
|
-
class hexp2(Stub):
|
|
693
|
-
"""hexp2(a)
|
|
694
|
-
|
|
695
|
-
Calculate exponential base 2 (2 ** a) in round to nearest mode.
|
|
696
|
-
Supported on fp16 operands only.
|
|
697
|
-
|
|
698
|
-
Returns the exponential base 2 result.
|
|
699
|
-
|
|
700
|
-
"""
|
|
701
|
-
|
|
702
|
-
class hfloor(Stub):
|
|
703
|
-
"""hfloor(a)
|
|
704
|
-
|
|
705
|
-
Calculate the floor, the largest integer less than or equal to 'a'.
|
|
706
|
-
Supported on fp16 operands only.
|
|
707
|
-
|
|
708
|
-
Returns the floor result.
|
|
709
|
-
|
|
710
|
-
"""
|
|
711
|
-
|
|
712
|
-
class hceil(Stub):
|
|
713
|
-
"""hceil(a)
|
|
714
|
-
|
|
715
|
-
Calculate the ceil, the smallest integer greater than or equal to 'a'.
|
|
716
|
-
Supported on fp16 operands only.
|
|
717
|
-
|
|
718
|
-
Returns the ceil result.
|
|
719
|
-
|
|
720
|
-
"""
|
|
721
|
-
|
|
722
|
-
class hsqrt(Stub):
|
|
723
|
-
"""hsqrt(a)
|
|
724
|
-
|
|
725
|
-
Calculate the square root of the input argument in round to nearest
|
|
726
|
-
mode. Supported on fp16 operands only.
|
|
727
|
-
|
|
728
|
-
Returns the square root result.
|
|
729
|
-
|
|
730
|
-
"""
|
|
731
|
-
|
|
732
|
-
class hrsqrt(Stub):
|
|
733
|
-
"""hrsqrt(a)
|
|
734
|
-
|
|
735
|
-
Calculate the reciprocal square root of the input argument in round
|
|
736
|
-
to nearest even mode. Supported on fp16 operands only.
|
|
737
|
-
|
|
738
|
-
Returns the reciprocal square root result.
|
|
739
|
-
|
|
740
|
-
"""
|
|
741
|
-
|
|
742
|
-
class hrcp(Stub):
|
|
743
|
-
"""hrcp(a)
|
|
744
|
-
|
|
745
|
-
Calculate the reciprocal of the input argument in round to nearest
|
|
746
|
-
even mode. Supported on fp16 operands only.
|
|
747
|
-
|
|
748
|
-
Returns the reciprocal result.
|
|
749
|
-
|
|
750
|
-
"""
|
|
751
|
-
|
|
752
|
-
class hrint(Stub):
|
|
753
|
-
"""hrint(a)
|
|
754
|
-
|
|
755
|
-
Round the input argument to nearest integer value. Supported on fp16
|
|
756
|
-
operands only.
|
|
757
|
-
|
|
758
|
-
Returns the rounded result.
|
|
759
|
-
|
|
760
|
-
"""
|
|
761
|
-
|
|
762
|
-
class htrunc(Stub):
|
|
763
|
-
"""htrunc(a)
|
|
764
|
-
|
|
765
|
-
Truncate the input argument to its integer portion. Supported
|
|
766
|
-
on fp16 operands only.
|
|
767
|
-
|
|
768
|
-
Returns the truncated result.
|
|
769
|
-
|
|
770
|
-
"""
|
|
771
|
-
|
|
772
|
-
class heq(Stub):
|
|
773
|
-
"""heq(a, b)
|
|
774
|
-
|
|
775
|
-
Perform fp16 comparison, (a == b). Supported
|
|
776
|
-
on fp16 operands only.
|
|
777
|
-
|
|
778
|
-
Returns True if a and b are equal and False otherwise.
|
|
779
|
-
|
|
780
|
-
"""
|
|
781
|
-
|
|
782
|
-
class hne(Stub):
|
|
783
|
-
"""hne(a, b)
|
|
784
|
-
|
|
785
|
-
Perform fp16 comparison, (a != b). Supported
|
|
786
|
-
on fp16 operands only.
|
|
787
|
-
|
|
788
|
-
Returns True if a and b are not equal and False otherwise.
|
|
789
|
-
|
|
790
|
-
"""
|
|
791
|
-
|
|
792
|
-
class hge(Stub):
|
|
793
|
-
"""hge(a, b)
|
|
794
|
-
|
|
795
|
-
Perform fp16 comparison, (a >= b). Supported
|
|
796
|
-
on fp16 operands only.
|
|
797
|
-
|
|
798
|
-
Returns True if a is >= b and False otherwise.
|
|
799
|
-
|
|
800
|
-
"""
|
|
801
|
-
|
|
802
|
-
class hgt(Stub):
|
|
803
|
-
"""hgt(a, b)
|
|
804
|
-
|
|
805
|
-
Perform fp16 comparison, (a > b). Supported
|
|
806
|
-
on fp16 operands only.
|
|
807
|
-
|
|
808
|
-
Returns True if a is > b and False otherwise.
|
|
809
|
-
|
|
810
|
-
"""
|
|
811
|
-
|
|
812
|
-
class hle(Stub):
|
|
813
|
-
"""hle(a, b)
|
|
814
|
-
|
|
815
|
-
Perform fp16 comparison, (a <= b). Supported
|
|
816
|
-
on fp16 operands only.
|
|
817
|
-
|
|
818
|
-
Returns True if a is <= b and False otherwise.
|
|
819
|
-
|
|
820
|
-
"""
|
|
821
|
-
|
|
822
|
-
class hlt(Stub):
|
|
823
|
-
"""hlt(a, b)
|
|
824
|
-
|
|
825
|
-
Perform fp16 comparison, (a < b). Supported
|
|
826
|
-
on fp16 operands only.
|
|
827
|
-
|
|
828
|
-
Returns True if a is < b and False otherwise.
|
|
829
|
-
|
|
830
|
-
"""
|
|
831
|
-
|
|
832
|
-
class hmax(Stub):
|
|
833
|
-
"""hmax(a, b)
|
|
834
|
-
|
|
835
|
-
Perform fp16 maximum operation, max(a,b) Supported
|
|
836
|
-
on fp16 operands only.
|
|
837
|
-
|
|
838
|
-
Returns a if a is greater than b, returns b otherwise.
|
|
839
|
-
|
|
840
|
-
"""
|
|
841
|
-
|
|
842
|
-
class hmin(Stub):
|
|
843
|
-
"""hmin(a, b)
|
|
844
|
-
|
|
845
|
-
Perform fp16 minimum operation, min(a,b). Supported
|
|
846
|
-
on fp16 operands only.
|
|
847
|
-
|
|
848
|
-
Returns a if a is less than b, returns b otherwise.
|
|
849
|
-
|
|
850
|
-
"""
|
|
851
|
-
|
|
852
|
-
|
|
853
545
|
# -------------------------------------------------------------------------------
|
|
854
546
|
# vector types
|
|
855
547
|
|
numba_cuda/numba/cuda/target.py
CHANGED
|
@@ -3,9 +3,8 @@ from functools import cached_property
|
|
|
3
3
|
import llvmlite.binding as ll
|
|
4
4
|
from llvmlite import ir
|
|
5
5
|
import warnings
|
|
6
|
-
|
|
6
|
+
from numba.cuda import cgutils
|
|
7
7
|
from numba.core import (
|
|
8
|
-
cgutils,
|
|
9
8
|
compiler,
|
|
10
9
|
config,
|
|
11
10
|
itanium_mangler,
|
|
@@ -17,7 +16,7 @@ from numba.core.compiler_lock import global_compiler_lock
|
|
|
17
16
|
from numba.core.dispatcher import Dispatcher
|
|
18
17
|
from numba.core.errors import NumbaWarning
|
|
19
18
|
from numba.core.base import BaseContext
|
|
20
|
-
from numba.core.callconv import BaseCallConv, MinimalCallConv
|
|
19
|
+
from numba.cuda.core.callconv import BaseCallConv, MinimalCallConv
|
|
21
20
|
from numba.core.typing import cmathdecl
|
|
22
21
|
from numba.core import datamodel
|
|
23
22
|
|
|
@@ -33,7 +32,7 @@ from numba.cuda.models import cuda_data_manager
|
|
|
33
32
|
|
|
34
33
|
class CUDATypingContext(typing.BaseContext):
|
|
35
34
|
def load_additional_registries(self):
|
|
36
|
-
from . import cudadecl, cudamath, libdevicedecl, vector_types
|
|
35
|
+
from . import cudadecl, cudamath, fp16, libdevicedecl, vector_types
|
|
37
36
|
from numba.core.typing import enumdecl, cffi_utils
|
|
38
37
|
|
|
39
38
|
self.install_registry(cudadecl.registry)
|
|
@@ -43,6 +42,7 @@ class CUDATypingContext(typing.BaseContext):
|
|
|
43
42
|
self.install_registry(libdevicedecl.registry)
|
|
44
43
|
self.install_registry(enumdecl.registry)
|
|
45
44
|
self.install_registry(vector_types.typing_registry)
|
|
45
|
+
self.install_registry(fp16.typing_registry)
|
|
46
46
|
|
|
47
47
|
def resolve_value_type(self, val):
|
|
48
48
|
# treat other dispatcher object as another device function
|
|
@@ -148,7 +148,14 @@ class CUDATargetContext(BaseContext):
|
|
|
148
148
|
from numba.misc import cffiimpl
|
|
149
149
|
from numba.np import arrayobj # noqa: F401
|
|
150
150
|
from numba.np import npdatetime # noqa: F401
|
|
151
|
-
from . import cudaimpl, printimpl, libdeviceimpl, mathimpl, vector_types
|
|
151
|
+
from . import (
|
|
152
|
+
cudaimpl,
|
|
153
|
+
fp16,
|
|
154
|
+
printimpl,
|
|
155
|
+
libdeviceimpl,
|
|
156
|
+
mathimpl,
|
|
157
|
+
vector_types,
|
|
158
|
+
)
|
|
152
159
|
|
|
153
160
|
# fix for #8940
|
|
154
161
|
from numba.np.unsafe import ndarray # noqa F401
|
|
@@ -160,6 +167,7 @@ class CUDATargetContext(BaseContext):
|
|
|
160
167
|
self.install_registry(cmathimpl.registry)
|
|
161
168
|
self.install_registry(mathimpl.registry)
|
|
162
169
|
self.install_registry(vector_types.impl_registry)
|
|
170
|
+
self.install_registry(fp16.target_registry)
|
|
163
171
|
|
|
164
172
|
def codegen(self):
|
|
165
173
|
return self._internal_codegen
|
numba_cuda/numba/cuda/testing.py
CHANGED
|
@@ -1,25 +1,44 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import platform
|
|
3
3
|
import shutil
|
|
4
|
-
|
|
5
|
-
from numba.tests.support import SerialMixin
|
|
4
|
+
import pytest
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from numba.core.utils import PYVERSION
|
|
6
7
|
from numba.cuda.cuda_paths import get_conda_ctk
|
|
7
8
|
from numba.cuda.cudadrv import driver, devices, libs
|
|
9
|
+
from numba.cuda.dispatcher import CUDADispatcher
|
|
8
10
|
from numba.core import config
|
|
9
11
|
from numba.tests.support import TestCase
|
|
10
12
|
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
from typing import Iterable, Union
|
|
15
|
+
from io import StringIO
|
|
11
16
|
import unittest
|
|
12
17
|
|
|
18
|
+
if PYVERSION >= (3, 10):
|
|
19
|
+
from filecheck.matcher import Matcher
|
|
20
|
+
from filecheck.options import Options
|
|
21
|
+
from filecheck.parser import Parser, pattern_for_opts
|
|
22
|
+
from filecheck.finput import FInput
|
|
23
|
+
|
|
13
24
|
numba_cuda_dir = Path(__file__).parent
|
|
14
25
|
test_data_dir = numba_cuda_dir / "tests" / "data"
|
|
15
26
|
|
|
16
27
|
|
|
17
|
-
class CUDATestCase(SerialMixin, TestCase):
|
28
|
+
@pytest.mark.usefixtures("initialize_from_pytest_config")
|
|
29
|
+
class CUDATestCase(TestCase):
|
|
18
30
|
"""
|
|
19
31
|
For tests that use a CUDA device. Test methods in a CUDATestCase must not
|
|
20
32
|
be run out of module order, because the ContextResettingTestCase may reset
|
|
21
33
|
the context and destroy resources used by a normal CUDATestCase if any of
|
|
22
34
|
its tests are run between tests from a CUDATestCase.
|
|
35
|
+
|
|
36
|
+
Methods assertFileCheckAsm and assertFileCheckLLVM will inspect a
|
|
37
|
+
CUDADispatcher and assert that the compilation artifacts match the
|
|
38
|
+
FileCheck checks given in the kernel's docstring.
|
|
39
|
+
|
|
40
|
+
Method assertFileCheckMatches can be used to assert that a given string
|
|
41
|
+
matches FileCheck checks, and is not specific to CUDADispatcher.
|
|
23
42
|
"""
|
|
24
43
|
|
|
25
44
|
def setUp(self):
|
|
@@ -35,6 +54,134 @@ class CUDATestCase(SerialMixin, TestCase):
|
|
|
35
54
|
config.CUDA_LOW_OCCUPANCY_WARNINGS = self._low_occupancy_warnings
|
|
36
55
|
config.CUDA_WARN_ON_IMPLICIT_COPY = self._warn_on_implicit_copy
|
|
37
56
|
|
|
57
|
+
Signature = Union[tuple[type, ...], None]
|
|
58
|
+
|
|
59
|
+
def _getIRContents(
|
|
60
|
+
self,
|
|
61
|
+
ir_result: Union[dict[Signature, str], str],
|
|
62
|
+
signature: Union[Signature, None] = None,
|
|
63
|
+
) -> Iterable[str]:
|
|
64
|
+
if isinstance(ir_result, str):
|
|
65
|
+
assert signature is None, (
|
|
66
|
+
"Cannot use signature because the kernel was only compiled for one signature"
|
|
67
|
+
)
|
|
68
|
+
return [ir_result]
|
|
69
|
+
|
|
70
|
+
if signature is None:
|
|
71
|
+
return list(ir_result.values())
|
|
72
|
+
|
|
73
|
+
return [ir_result[signature]]
|
|
74
|
+
|
|
75
|
+
def assertFileCheckAsm(
|
|
76
|
+
self,
|
|
77
|
+
ir_producer: CUDADispatcher,
|
|
78
|
+
signature: Union[tuple[type, ...], None] = None,
|
|
79
|
+
check_prefixes: tuple[str] = ("ASM",),
|
|
80
|
+
**extra_filecheck_options,
|
|
81
|
+
) -> None:
|
|
82
|
+
"""
|
|
83
|
+
Assert that the assembly output of the given CUDADispatcher matches
|
|
84
|
+
the FileCheck checks given in the kernel's docstring.
|
|
85
|
+
"""
|
|
86
|
+
ir_contents = self._getIRContents(ir_producer.inspect_asm(), signature)
|
|
87
|
+
assert ir_contents, "No assembly output found for the given signature."
|
|
88
|
+
assert ir_producer.__doc__ is not None, (
|
|
89
|
+
"Kernel docstring is required. To pass checks explicitly, use assertFileCheckMatches."
|
|
90
|
+
)
|
|
91
|
+
check_patterns = ir_producer.__doc__
|
|
92
|
+
for ir_content in ir_contents:
|
|
93
|
+
self.assertFileCheckMatches(
|
|
94
|
+
ir_content,
|
|
95
|
+
check_patterns=check_patterns,
|
|
96
|
+
check_prefixes=check_prefixes,
|
|
97
|
+
**extra_filecheck_options,
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
def assertFileCheckLLVM(
|
|
101
|
+
self,
|
|
102
|
+
ir_producer: CUDADispatcher,
|
|
103
|
+
signature: Union[tuple[type, ...], None] = None,
|
|
104
|
+
check_prefixes: tuple[str] = ("LLVM",),
|
|
105
|
+
**extra_filecheck_options,
|
|
106
|
+
) -> None:
|
|
107
|
+
"""
|
|
108
|
+
Assert that the LLVM IR output of the given CUDADispatcher matches
|
|
109
|
+
the FileCheck checks given in the kernel's docstring.
|
|
110
|
+
"""
|
|
111
|
+
ir_contents = self._getIRContents(ir_producer.inspect_llvm(), signature)
|
|
112
|
+
assert ir_contents, "No LLVM IR output found for the given signature."
|
|
113
|
+
assert ir_producer.__doc__ is not None, (
|
|
114
|
+
"Kernel docstring is required. To pass checks explicitly, use assertFileCheckMatches."
|
|
115
|
+
)
|
|
116
|
+
check_patterns = ir_producer.__doc__
|
|
117
|
+
for ir_content in ir_contents:
|
|
118
|
+
assert ir_content, (
|
|
119
|
+
"LLVM IR content is empty for the given signature."
|
|
120
|
+
)
|
|
121
|
+
self.assertFileCheckMatches(
|
|
122
|
+
ir_content,
|
|
123
|
+
check_patterns=check_patterns,
|
|
124
|
+
check_prefixes=check_prefixes,
|
|
125
|
+
**extra_filecheck_options,
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
def assertFileCheckMatches(
|
|
129
|
+
self,
|
|
130
|
+
ir_content: str,
|
|
131
|
+
check_patterns: str,
|
|
132
|
+
check_prefixes: tuple[str] = ("CHECK",),
|
|
133
|
+
**extra_filecheck_options,
|
|
134
|
+
) -> None:
|
|
135
|
+
"""
|
|
136
|
+
Assert that the given string matches the passed FileCheck checks.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
ir_content: The string to check against.
|
|
140
|
+
check_patterns: The FileCheck checks to use.
|
|
141
|
+
check_prefixes: The prefixes to use for the FileCheck checks.
|
|
142
|
+
extra_filecheck_options: Extra options to pass to FileCheck.
|
|
143
|
+
"""
|
|
144
|
+
if PYVERSION < (3, 10):
|
|
145
|
+
self.skipTest("FileCheck requires Python 3.10 or later")
|
|
146
|
+
opts = Options(
|
|
147
|
+
match_filename="-",
|
|
148
|
+
check_prefixes=list(check_prefixes),
|
|
149
|
+
**extra_filecheck_options,
|
|
150
|
+
)
|
|
151
|
+
input_file = FInput(fname="-", content=ir_content)
|
|
152
|
+
parser = Parser(opts, StringIO(check_patterns), *pattern_for_opts(opts))
|
|
153
|
+
matcher = Matcher(opts, input_file, parser)
|
|
154
|
+
matcher.stderr = StringIO()
|
|
155
|
+
result = matcher.run()
|
|
156
|
+
if result != 0:
|
|
157
|
+
dump_instructions = ""
|
|
158
|
+
if self._dump_failed_filechecks:
|
|
159
|
+
dump_directory = Path(
|
|
160
|
+
datetime.now().strftime("numba-ir-%Y_%m_%d_%H_%M_%S")
|
|
161
|
+
)
|
|
162
|
+
if not dump_directory.exists():
|
|
163
|
+
dump_directory.mkdir(parents=True, exist_ok=True)
|
|
164
|
+
base_path = self.id().replace(".", "_")
|
|
165
|
+
ir_dump = dump_directory / Path(base_path).with_suffix(".ll")
|
|
166
|
+
checks_dump = dump_directory / Path(base_path).with_suffix(
|
|
167
|
+
".checks"
|
|
168
|
+
)
|
|
169
|
+
with (
|
|
170
|
+
open(ir_dump, "w") as ir_file,
|
|
171
|
+
open(checks_dump, "w") as checks_file,
|
|
172
|
+
):
|
|
173
|
+
_ = ir_file.write(ir_content + "\n")
|
|
174
|
+
_ = checks_file.write(check_patterns)
|
|
175
|
+
dump_instructions = f"Reproduce with:\n\nfilecheck --check-prefixes={','.join(check_prefixes)} {checks_dump} --input-file={ir_dump}"
|
|
176
|
+
|
|
177
|
+
self.fail(
|
|
178
|
+
f"FileCheck failed:\n{matcher.stderr.getvalue()}\n\n"
|
|
179
|
+
+ f"Check prefixes:\n{check_prefixes}\n\n"
|
|
180
|
+
+ f"Check patterns:\n{check_patterns}\n"
|
|
181
|
+
+ f"IR:\n{ir_content}\n\n"
|
|
182
|
+
+ dump_instructions
|
|
183
|
+
)
|
|
184
|
+
|
|
38
185
|
|
|
39
186
|
class ContextResettingTestCase(CUDATestCase):
|
|
40
187
|
"""
|
|
@@ -127,8 +274,8 @@ def skip_if_mvc_enabled(reason):
|
|
|
127
274
|
def skip_if_mvc_libraries_unavailable(fn):
|
|
128
275
|
libs_available = False
|
|
129
276
|
try:
|
|
130
|
-
import cubinlinker # noqa: F401
|
|
131
|
-
import ptxcompiler # noqa: F401
|
|
277
|
+
import cubinlinker # noqa: F401 # type: ignore
|
|
278
|
+
import ptxcompiler # noqa: F401 # type: ignore
|
|
132
279
|
|
|
133
280
|
libs_available = True
|
|
134
281
|
except ImportError:
|
|
@@ -189,6 +336,10 @@ def skip_if_cudadevrt_missing(fn):
|
|
|
189
336
|
return unittest.skipIf(cudadevrt_missing(), "cudadevrt missing")(fn)
|
|
190
337
|
|
|
191
338
|
|
|
339
|
+
def skip_if_nvjitlink_missing(reason):
|
|
340
|
+
return unittest.skipIf(not driver._have_nvjitlink(), reason)
|
|
341
|
+
|
|
342
|
+
|
|
192
343
|
class ForeignArray(object):
|
|
193
344
|
"""
|
|
194
345
|
Class for emulating an array coming from another library through the CUDA
|
|
numba_cuda/numba/cuda/tests/complex_usecases.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import cmath
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def div_usecase(x, y):
|
|
5
|
+
return x / y
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def real_usecase(x):
|
|
9
|
+
return x.real
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def imag_usecase(x):
|
|
13
|
+
return x.imag
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def conjugate_usecase(x):
|
|
17
|
+
return x.conjugate()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def acos_usecase(x):
|
|
21
|
+
return cmath.acos(x)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def cos_usecase(x):
|
|
25
|
+
return cmath.cos(x)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def asin_usecase(x):
|
|
29
|
+
return cmath.asin(x)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def sin_usecase(x):
|
|
33
|
+
return cmath.sin(x)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def atan_usecase(x):
|
|
37
|
+
return cmath.atan(x)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def tan_usecase(x):
|
|
41
|
+
return cmath.tan(x)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def acosh_usecase(x):
|
|
45
|
+
return cmath.acosh(x)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def cosh_usecase(x):
|
|
49
|
+
return cmath.cosh(x)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def asinh_usecase(x):
|
|
53
|
+
return cmath.asinh(x)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def sinh_usecase(x):
|
|
57
|
+
return cmath.sinh(x)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def atanh_usecase(x):
|
|
61
|
+
return cmath.atanh(x)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def tanh_usecase(x):
|
|
65
|
+
return cmath.tanh(x)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def exp_usecase(x):
|
|
69
|
+
return cmath.exp(x)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def isfinite_usecase(x):
|
|
73
|
+
return cmath.isfinite(x)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def isinf_usecase(x):
|
|
77
|
+
return cmath.isinf(x)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def isnan_usecase(x):
|
|
81
|
+
return cmath.isnan(x)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def log_usecase(x):
|
|
85
|
+
return cmath.log(x)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def log_base_usecase(x, base):
|
|
89
|
+
return cmath.log(x, base)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def log10_usecase(x):
|
|
93
|
+
return cmath.log10(x)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def phase_usecase(x):
|
|
97
|
+
return cmath.phase(x)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def polar_usecase(x):
|
|
101
|
+
return cmath.polar(x)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def polar_as_complex_usecase(x):
|
|
105
|
+
return complex(*cmath.polar(x))
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def rect_usecase(r, phi):
|
|
109
|
+
return cmath.rect(r, phi)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def sqrt_usecase(x):
|
|
113
|
+
return cmath.sqrt(x)
|