numba-cuda: numba_cuda-0.17.0-py3-none-any.whl → numba_cuda-0.18.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +0 -8
  3. numba_cuda/numba/cuda/_internal/cuda_fp16.py +14225 -0
  4. numba_cuda/numba/cuda/api_util.py +6 -0
  5. numba_cuda/numba/cuda/cgutils.py +1291 -0
  6. numba_cuda/numba/cuda/codegen.py +32 -14
  7. numba_cuda/numba/cuda/compiler.py +113 -10
  8. numba_cuda/numba/cuda/core/caching.py +741 -0
  9. numba_cuda/numba/cuda/core/callconv.py +338 -0
  10. numba_cuda/numba/cuda/core/codegen.py +168 -0
  11. numba_cuda/numba/cuda/core/compiler.py +205 -0
  12. numba_cuda/numba/cuda/core/typed_passes.py +139 -0
  13. numba_cuda/numba/cuda/cudadecl.py +0 -268
  14. numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
  15. numba_cuda/numba/cuda/cudadrv/driver.py +2 -1
  16. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -1
  17. numba_cuda/numba/cuda/cudaimpl.py +4 -178
  18. numba_cuda/numba/cuda/debuginfo.py +469 -3
  19. numba_cuda/numba/cuda/device_init.py +0 -1
  20. numba_cuda/numba/cuda/dispatcher.py +309 -11
  21. numba_cuda/numba/cuda/extending.py +2 -1
  22. numba_cuda/numba/cuda/fp16.py +348 -0
  23. numba_cuda/numba/cuda/intrinsics.py +1 -1
  24. numba_cuda/numba/cuda/libdeviceimpl.py +2 -1
  25. numba_cuda/numba/cuda/lowering.py +1833 -8
  26. numba_cuda/numba/cuda/mathimpl.py +2 -90
  27. numba_cuda/numba/cuda/nvvmutils.py +2 -1
  28. numba_cuda/numba/cuda/printimpl.py +2 -1
  29. numba_cuda/numba/cuda/serialize.py +264 -0
  30. numba_cuda/numba/cuda/simulator/__init__.py +2 -0
  31. numba_cuda/numba/cuda/simulator/dispatcher.py +7 -0
  32. numba_cuda/numba/cuda/stubs.py +0 -308
  33. numba_cuda/numba/cuda/target.py +13 -5
  34. numba_cuda/numba/cuda/testing.py +156 -5
  35. numba_cuda/numba/cuda/tests/complex_usecases.py +113 -0
  36. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +110 -0
  37. numba_cuda/numba/cuda/tests/core/test_serialize.py +359 -0
  38. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +33 -0
  39. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +2 -2
  40. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +1 -0
  41. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  42. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +5 -10
  43. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  44. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +381 -0
  45. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +1 -1
  46. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
  47. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +94 -24
  48. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +37 -23
  49. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +43 -27
  50. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +26 -9
  51. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +27 -2
  52. numba_cuda/numba/cuda/tests/enum_usecases.py +56 -0
  53. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +1 -2
  54. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +1 -1
  55. numba_cuda/numba/cuda/utils.py +785 -0
  56. numba_cuda/numba/cuda/vector_types.py +1 -1
  57. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/METADATA +18 -4
  58. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/RECORD +61 -48
  59. numba_cuda/numba/cuda/cpp_function_wrappers.cu +0 -46
  60. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/WHEEL +0 -0
  61. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/licenses/LICENSE +0 -0
  62. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.0.dist-info}/top_level.txt +0 -0
@@ -542,314 +542,6 @@ class nanosleep(Stub):
542
542
  _description_ = "<nansleep()>"
543
543
 
544
544
 
545
- # -------------------------------------------------------------------------------
546
- # Floating point 16
547
-
548
-
549
- class fp16(Stub):
550
- """Namespace for fp16 operations"""
551
-
552
- _description_ = "<fp16>"
553
-
554
- class hadd(Stub):
555
- """hadd(a, b)
556
-
557
- Perform fp16 addition, (a + b) in round to nearest mode. Supported
558
- on fp16 operands only.
559
-
560
- Returns the fp16 result of the addition.
561
-
562
- """
563
-
564
- class hsub(Stub):
565
- """hsub(a, b)
566
-
567
- Perform fp16 subtraction, (a - b) in round to nearest mode. Supported
568
- on fp16 operands only.
569
-
570
- Returns the fp16 result of the subtraction.
571
-
572
- """
573
-
574
- class hmul(Stub):
575
- """hmul(a, b)
576
-
577
- Perform fp16 multiplication, (a * b) in round to nearest mode. Supported
578
- on fp16 operands only.
579
-
580
- Returns the fp16 result of the multiplication.
581
-
582
- """
583
-
584
- class hdiv(Stub):
585
- """hdiv(a, b)
586
-
587
- Perform fp16 division, (a / b) in round to nearest mode. Supported
588
- on fp16 operands only.
589
-
590
- Returns the fp16 result of the division
591
-
592
- """
593
-
594
- class hfma(Stub):
595
- """hfma(a, b, c)
596
-
597
- Perform fp16 multiply and accumulate, (a * b) + c in round to nearest
598
- mode. Supported on fp16 operands only.
599
-
600
- Returns the fp16 result of the multiplication.
601
-
602
- """
603
-
604
- class hneg(Stub):
605
- """hneg(a)
606
-
607
- Perform fp16 negation, -(a). Supported on fp16 operands only.
608
-
609
- Returns the fp16 result of the negation.
610
-
611
- """
612
-
613
- class habs(Stub):
614
- """habs(a)
615
-
616
- Perform fp16 absolute value, |a|. Supported on fp16 operands only.
617
-
618
- Returns the fp16 result of the absolute value.
619
-
620
- """
621
-
622
- class hsin(Stub):
623
- """hsin(a)
624
-
625
- Calculate sine in round to nearest even mode. Supported on fp16
626
- operands only.
627
-
628
- Returns the sine result.
629
-
630
- """
631
-
632
- class hcos(Stub):
633
- """hsin(a)
634
-
635
- Calculate cosine in round to nearest even mode. Supported on fp16
636
- operands only.
637
-
638
- Returns the cosine result.
639
-
640
- """
641
-
642
- class hlog(Stub):
643
- """hlog(a)
644
-
645
- Calculate natural logarithm in round to nearest even mode. Supported
646
- on fp16 operands only.
647
-
648
- Returns the natural logarithm result.
649
-
650
- """
651
-
652
- class hlog10(Stub):
653
- """hlog10(a)
654
-
655
- Calculate logarithm base 10 in round to nearest even mode. Supported
656
- on fp16 operands only.
657
-
658
- Returns the logarithm base 10 result.
659
-
660
- """
661
-
662
- class hlog2(Stub):
663
- """hlog2(a)
664
-
665
- Calculate logarithm base 2 in round to nearest even mode. Supported
666
- on fp16 operands only.
667
-
668
- Returns the logarithm base 2 result.
669
-
670
- """
671
-
672
- class hexp(Stub):
673
- """hexp(a)
674
-
675
- Calculate natural exponential, exp(a), in round to nearest mode.
676
- Supported on fp16 operands only.
677
-
678
- Returns the natural exponential result.
679
-
680
- """
681
-
682
- class hexp10(Stub):
683
- """hexp10(a)
684
-
685
- Calculate exponential base 10 (10 ** a) in round to nearest mode.
686
- Supported on fp16 operands only.
687
-
688
- Returns the exponential base 10 result.
689
-
690
- """
691
-
692
- class hexp2(Stub):
693
- """hexp2(a)
694
-
695
- Calculate exponential base 2 (2 ** a) in round to nearest mode.
696
- Supported on fp16 operands only.
697
-
698
- Returns the exponential base 2 result.
699
-
700
- """
701
-
702
- class hfloor(Stub):
703
- """hfloor(a)
704
-
705
- Calculate the floor, the largest integer less than or equal to 'a'.
706
- Supported on fp16 operands only.
707
-
708
- Returns the floor result.
709
-
710
- """
711
-
712
- class hceil(Stub):
713
- """hceil(a)
714
-
715
- Calculate the ceil, the smallest integer greater than or equal to 'a'.
716
- Supported on fp16 operands only.
717
-
718
- Returns the ceil result.
719
-
720
- """
721
-
722
- class hsqrt(Stub):
723
- """hsqrt(a)
724
-
725
- Calculate the square root of the input argument in round to nearest
726
- mode. Supported on fp16 operands only.
727
-
728
- Returns the square root result.
729
-
730
- """
731
-
732
- class hrsqrt(Stub):
733
- """hrsqrt(a)
734
-
735
- Calculate the reciprocal square root of the input argument in round
736
- to nearest even mode. Supported on fp16 operands only.
737
-
738
- Returns the reciprocal square root result.
739
-
740
- """
741
-
742
- class hrcp(Stub):
743
- """hrcp(a)
744
-
745
- Calculate the reciprocal of the input argument in round to nearest
746
- even mode. Supported on fp16 operands only.
747
-
748
- Returns the reciprocal result.
749
-
750
- """
751
-
752
- class hrint(Stub):
753
- """hrint(a)
754
-
755
- Round the input argument to nearest integer value. Supported on fp16
756
- operands only.
757
-
758
- Returns the rounded result.
759
-
760
- """
761
-
762
- class htrunc(Stub):
763
- """htrunc(a)
764
-
765
- Truncate the input argument to its integer portion. Supported
766
- on fp16 operands only.
767
-
768
- Returns the truncated result.
769
-
770
- """
771
-
772
- class heq(Stub):
773
- """heq(a, b)
774
-
775
- Perform fp16 comparison, (a == b). Supported
776
- on fp16 operands only.
777
-
778
- Returns True if a and b are equal and False otherwise.
779
-
780
- """
781
-
782
- class hne(Stub):
783
- """hne(a, b)
784
-
785
- Perform fp16 comparison, (a != b). Supported
786
- on fp16 operands only.
787
-
788
- Returns True if a and b are not equal and False otherwise.
789
-
790
- """
791
-
792
- class hge(Stub):
793
- """hge(a, b)
794
-
795
- Perform fp16 comparison, (a >= b). Supported
796
- on fp16 operands only.
797
-
798
- Returns True if a is >= b and False otherwise.
799
-
800
- """
801
-
802
- class hgt(Stub):
803
- """hgt(a, b)
804
-
805
- Perform fp16 comparison, (a > b). Supported
806
- on fp16 operands only.
807
-
808
- Returns True if a is > b and False otherwise.
809
-
810
- """
811
-
812
- class hle(Stub):
813
- """hle(a, b)
814
-
815
- Perform fp16 comparison, (a <= b). Supported
816
- on fp16 operands only.
817
-
818
- Returns True if a is <= b and False otherwise.
819
-
820
- """
821
-
822
- class hlt(Stub):
823
- """hlt(a, b)
824
-
825
- Perform fp16 comparison, (a < b). Supported
826
- on fp16 operands only.
827
-
828
- Returns True if a is < b and False otherwise.
829
-
830
- """
831
-
832
- class hmax(Stub):
833
- """hmax(a, b)
834
-
835
- Perform fp16 maximum operation, max(a,b) Supported
836
- on fp16 operands only.
837
-
838
- Returns a if a is greater than b, returns b otherwise.
839
-
840
- """
841
-
842
- class hmin(Stub):
843
- """hmin(a, b)
844
-
845
- Perform fp16 minimum operation, min(a,b). Supported
846
- on fp16 operands only.
847
-
848
- Returns a if a is less than b, returns b otherwise.
849
-
850
- """
851
-
852
-
853
545
  # -------------------------------------------------------------------------------
854
546
  # vector types
855
547
 
@@ -3,9 +3,8 @@ from functools import cached_property
3
3
  import llvmlite.binding as ll
4
4
  from llvmlite import ir
5
5
  import warnings
6
-
6
+ from numba.cuda import cgutils
7
7
  from numba.core import (
8
- cgutils,
9
8
  compiler,
10
9
  config,
11
10
  itanium_mangler,
@@ -17,7 +16,7 @@ from numba.core.compiler_lock import global_compiler_lock
17
16
  from numba.core.dispatcher import Dispatcher
18
17
  from numba.core.errors import NumbaWarning
19
18
  from numba.core.base import BaseContext
20
- from numba.core.callconv import BaseCallConv, MinimalCallConv
19
+ from numba.cuda.core.callconv import BaseCallConv, MinimalCallConv
21
20
  from numba.core.typing import cmathdecl
22
21
  from numba.core import datamodel
23
22
 
@@ -33,7 +32,7 @@ from numba.cuda.models import cuda_data_manager
33
32
 
34
33
  class CUDATypingContext(typing.BaseContext):
35
34
  def load_additional_registries(self):
36
- from . import cudadecl, cudamath, libdevicedecl, vector_types
35
+ from . import cudadecl, cudamath, fp16, libdevicedecl, vector_types
37
36
  from numba.core.typing import enumdecl, cffi_utils
38
37
 
39
38
  self.install_registry(cudadecl.registry)
@@ -43,6 +42,7 @@ class CUDATypingContext(typing.BaseContext):
43
42
  self.install_registry(libdevicedecl.registry)
44
43
  self.install_registry(enumdecl.registry)
45
44
  self.install_registry(vector_types.typing_registry)
45
+ self.install_registry(fp16.typing_registry)
46
46
 
47
47
  def resolve_value_type(self, val):
48
48
  # treat other dispatcher object as another device function
@@ -148,7 +148,14 @@ class CUDATargetContext(BaseContext):
148
148
  from numba.misc import cffiimpl
149
149
  from numba.np import arrayobj # noqa: F401
150
150
  from numba.np import npdatetime # noqa: F401
151
- from . import cudaimpl, printimpl, libdeviceimpl, mathimpl, vector_types
151
+ from . import (
152
+ cudaimpl,
153
+ fp16,
154
+ printimpl,
155
+ libdeviceimpl,
156
+ mathimpl,
157
+ vector_types,
158
+ )
152
159
 
153
160
  # fix for #8940
154
161
  from numba.np.unsafe import ndarray # noqa F401
@@ -160,6 +167,7 @@ class CUDATargetContext(BaseContext):
160
167
  self.install_registry(cmathimpl.registry)
161
168
  self.install_registry(mathimpl.registry)
162
169
  self.install_registry(vector_types.impl_registry)
170
+ self.install_registry(fp16.target_registry)
163
171
 
164
172
  def codegen(self):
165
173
  return self._internal_codegen
@@ -1,25 +1,44 @@
1
1
  import os
2
2
  import platform
3
3
  import shutil
4
-
5
- from numba.tests.support import SerialMixin
4
+ import pytest
5
+ from datetime import datetime
6
+ from numba.core.utils import PYVERSION
6
7
  from numba.cuda.cuda_paths import get_conda_ctk
7
8
  from numba.cuda.cudadrv import driver, devices, libs
9
+ from numba.cuda.dispatcher import CUDADispatcher
8
10
  from numba.core import config
9
11
  from numba.tests.support import TestCase
10
12
  from pathlib import Path
13
+
14
+ from typing import Iterable, Union
15
+ from io import StringIO
11
16
  import unittest
12
17
 
18
+ if PYVERSION >= (3, 10):
19
+ from filecheck.matcher import Matcher
20
+ from filecheck.options import Options
21
+ from filecheck.parser import Parser, pattern_for_opts
22
+ from filecheck.finput import FInput
23
+
13
24
  numba_cuda_dir = Path(__file__).parent
14
25
  test_data_dir = numba_cuda_dir / "tests" / "data"
15
26
 
16
27
 
17
- class CUDATestCase(SerialMixin, TestCase):
28
+ @pytest.mark.usefixtures("initialize_from_pytest_config")
29
+ class CUDATestCase(TestCase):
18
30
  """
19
31
  For tests that use a CUDA device. Test methods in a CUDATestCase must not
20
32
  be run out of module order, because the ContextResettingTestCase may reset
21
33
  the context and destroy resources used by a normal CUDATestCase if any of
22
34
  its tests are run between tests from a CUDATestCase.
35
+
36
+ Methods assertFileCheckAsm and assertFileCheckLLVM will inspect a
37
+ CUDADispatcher and assert that the compilation artifacts match the
38
+ FileCheck checks given in the kernel's docstring.
39
+
40
+ Method assertFileCheckMatches can be used to assert that a given string
41
+ matches FileCheck checks, and is not specific to CUDADispatcher.
23
42
  """
24
43
 
25
44
  def setUp(self):
@@ -35,6 +54,134 @@ class CUDATestCase(SerialMixin, TestCase):
35
54
  config.CUDA_LOW_OCCUPANCY_WARNINGS = self._low_occupancy_warnings
36
55
  config.CUDA_WARN_ON_IMPLICIT_COPY = self._warn_on_implicit_copy
37
56
 
57
+ Signature = Union[tuple[type, ...], None]
58
+
59
+ def _getIRContents(
60
+ self,
61
+ ir_result: Union[dict[Signature, str], str],
62
+ signature: Union[Signature, None] = None,
63
+ ) -> Iterable[str]:
64
+ if isinstance(ir_result, str):
65
+ assert signature is None, (
66
+ "Cannot use signature because the kernel was only compiled for one signature"
67
+ )
68
+ return [ir_result]
69
+
70
+ if signature is None:
71
+ return list(ir_result.values())
72
+
73
+ return [ir_result[signature]]
74
+
75
+ def assertFileCheckAsm(
76
+ self,
77
+ ir_producer: CUDADispatcher,
78
+ signature: Union[tuple[type, ...], None] = None,
79
+ check_prefixes: tuple[str] = ("ASM",),
80
+ **extra_filecheck_options,
81
+ ) -> None:
82
+ """
83
+ Assert that the assembly output of the given CUDADispatcher matches
84
+ the FileCheck checks given in the kernel's docstring.
85
+ """
86
+ ir_contents = self._getIRContents(ir_producer.inspect_asm(), signature)
87
+ assert ir_contents, "No assembly output found for the given signature."
88
+ assert ir_producer.__doc__ is not None, (
89
+ "Kernel docstring is required. To pass checks explicitly, use assertFileCheckMatches."
90
+ )
91
+ check_patterns = ir_producer.__doc__
92
+ for ir_content in ir_contents:
93
+ self.assertFileCheckMatches(
94
+ ir_content,
95
+ check_patterns=check_patterns,
96
+ check_prefixes=check_prefixes,
97
+ **extra_filecheck_options,
98
+ )
99
+
100
+ def assertFileCheckLLVM(
101
+ self,
102
+ ir_producer: CUDADispatcher,
103
+ signature: Union[tuple[type, ...], None] = None,
104
+ check_prefixes: tuple[str] = ("LLVM",),
105
+ **extra_filecheck_options,
106
+ ) -> None:
107
+ """
108
+ Assert that the LLVM IR output of the given CUDADispatcher matches
109
+ the FileCheck checks given in the kernel's docstring.
110
+ """
111
+ ir_contents = self._getIRContents(ir_producer.inspect_llvm(), signature)
112
+ assert ir_contents, "No LLVM IR output found for the given signature."
113
+ assert ir_producer.__doc__ is not None, (
114
+ "Kernel docstring is required. To pass checks explicitly, use assertFileCheckMatches."
115
+ )
116
+ check_patterns = ir_producer.__doc__
117
+ for ir_content in ir_contents:
118
+ assert ir_content, (
119
+ "LLVM IR content is empty for the given signature."
120
+ )
121
+ self.assertFileCheckMatches(
122
+ ir_content,
123
+ check_patterns=check_patterns,
124
+ check_prefixes=check_prefixes,
125
+ **extra_filecheck_options,
126
+ )
127
+
128
+ def assertFileCheckMatches(
129
+ self,
130
+ ir_content: str,
131
+ check_patterns: str,
132
+ check_prefixes: tuple[str] = ("CHECK",),
133
+ **extra_filecheck_options,
134
+ ) -> None:
135
+ """
136
+ Assert that the given string matches the passed FileCheck checks.
137
+
138
+ Args:
139
+ ir_content: The string to check against.
140
+ check_patterns: The FileCheck checks to use.
141
+ check_prefixes: The prefixes to use for the FileCheck checks.
142
+ extra_filecheck_options: Extra options to pass to FileCheck.
143
+ """
144
+ if PYVERSION < (3, 10):
145
+ self.skipTest("FileCheck requires Python 3.10 or later")
146
+ opts = Options(
147
+ match_filename="-",
148
+ check_prefixes=list(check_prefixes),
149
+ **extra_filecheck_options,
150
+ )
151
+ input_file = FInput(fname="-", content=ir_content)
152
+ parser = Parser(opts, StringIO(check_patterns), *pattern_for_opts(opts))
153
+ matcher = Matcher(opts, input_file, parser)
154
+ matcher.stderr = StringIO()
155
+ result = matcher.run()
156
+ if result != 0:
157
+ dump_instructions = ""
158
+ if self._dump_failed_filechecks:
159
+ dump_directory = Path(
160
+ datetime.now().strftime("numba-ir-%Y_%m_%d_%H_%M_%S")
161
+ )
162
+ if not dump_directory.exists():
163
+ dump_directory.mkdir(parents=True, exist_ok=True)
164
+ base_path = self.id().replace(".", "_")
165
+ ir_dump = dump_directory / Path(base_path).with_suffix(".ll")
166
+ checks_dump = dump_directory / Path(base_path).with_suffix(
167
+ ".checks"
168
+ )
169
+ with (
170
+ open(ir_dump, "w") as ir_file,
171
+ open(checks_dump, "w") as checks_file,
172
+ ):
173
+ _ = ir_file.write(ir_content + "\n")
174
+ _ = checks_file.write(check_patterns)
175
+ dump_instructions = f"Reproduce with:\n\nfilecheck --check-prefixes={','.join(check_prefixes)} {checks_dump} --input-file={ir_dump}"
176
+
177
+ self.fail(
178
+ f"FileCheck failed:\n{matcher.stderr.getvalue()}\n\n"
179
+ + f"Check prefixes:\n{check_prefixes}\n\n"
180
+ + f"Check patterns:\n{check_patterns}\n"
181
+ + f"IR:\n{ir_content}\n\n"
182
+ + dump_instructions
183
+ )
184
+
38
185
 
39
186
  class ContextResettingTestCase(CUDATestCase):
40
187
  """
@@ -127,8 +274,8 @@ def skip_if_mvc_enabled(reason):
127
274
  def skip_if_mvc_libraries_unavailable(fn):
128
275
  libs_available = False
129
276
  try:
130
- import cubinlinker # noqa: F401
131
- import ptxcompiler # noqa: F401
277
+ import cubinlinker # noqa: F401 # type: ignore
278
+ import ptxcompiler # noqa: F401 # type: ignore
132
279
 
133
280
  libs_available = True
134
281
  except ImportError:
@@ -189,6 +336,10 @@ def skip_if_cudadevrt_missing(fn):
189
336
  return unittest.skipIf(cudadevrt_missing(), "cudadevrt missing")(fn)
190
337
 
191
338
 
339
+ def skip_if_nvjitlink_missing(reason):
340
+ return unittest.skipIf(not driver._have_nvjitlink(), reason)
341
+
342
+
192
343
  class ForeignArray(object):
193
344
  """
194
345
  Class for emulating an array coming from another library through the CUDA
@@ -0,0 +1,113 @@
1
+ import cmath
2
+
3
+
4
+ def div_usecase(x, y):
5
+ return x / y
6
+
7
+
8
+ def real_usecase(x):
9
+ return x.real
10
+
11
+
12
+ def imag_usecase(x):
13
+ return x.imag
14
+
15
+
16
+ def conjugate_usecase(x):
17
+ return x.conjugate()
18
+
19
+
20
+ def acos_usecase(x):
21
+ return cmath.acos(x)
22
+
23
+
24
+ def cos_usecase(x):
25
+ return cmath.cos(x)
26
+
27
+
28
+ def asin_usecase(x):
29
+ return cmath.asin(x)
30
+
31
+
32
+ def sin_usecase(x):
33
+ return cmath.sin(x)
34
+
35
+
36
+ def atan_usecase(x):
37
+ return cmath.atan(x)
38
+
39
+
40
+ def tan_usecase(x):
41
+ return cmath.tan(x)
42
+
43
+
44
+ def acosh_usecase(x):
45
+ return cmath.acosh(x)
46
+
47
+
48
+ def cosh_usecase(x):
49
+ return cmath.cosh(x)
50
+
51
+
52
+ def asinh_usecase(x):
53
+ return cmath.asinh(x)
54
+
55
+
56
+ def sinh_usecase(x):
57
+ return cmath.sinh(x)
58
+
59
+
60
+ def atanh_usecase(x):
61
+ return cmath.atanh(x)
62
+
63
+
64
+ def tanh_usecase(x):
65
+ return cmath.tanh(x)
66
+
67
+
68
+ def exp_usecase(x):
69
+ return cmath.exp(x)
70
+
71
+
72
+ def isfinite_usecase(x):
73
+ return cmath.isfinite(x)
74
+
75
+
76
+ def isinf_usecase(x):
77
+ return cmath.isinf(x)
78
+
79
+
80
+ def isnan_usecase(x):
81
+ return cmath.isnan(x)
82
+
83
+
84
+ def log_usecase(x):
85
+ return cmath.log(x)
86
+
87
+
88
+ def log_base_usecase(x, base):
89
+ return cmath.log(x, base)
90
+
91
+
92
+ def log10_usecase(x):
93
+ return cmath.log10(x)
94
+
95
+
96
+ def phase_usecase(x):
97
+ return cmath.phase(x)
98
+
99
+
100
+ def polar_usecase(x):
101
+ return cmath.polar(x)
102
+
103
+
104
+ def polar_as_complex_usecase(x):
105
+ return complex(*cmath.polar(x))
106
+
107
+
108
+ def rect_usecase(r, phi):
109
+ return cmath.rect(r, phi)
110
+
111
+
112
+ def sqrt_usecase(x):
113
+ return cmath.sqrt(x)