numba-cuda 0.18.1__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic; see the package registry page for more details.

Files changed (88)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +2 -2
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +1 -1
  5. numba_cuda/numba/cuda/api.py +2 -7
  6. numba_cuda/numba/cuda/compiler.py +7 -4
  7. numba_cuda/numba/cuda/core/interpreter.py +3592 -0
  8. numba_cuda/numba/cuda/core/ir_utils.py +2645 -0
  9. numba_cuda/numba/cuda/core/sigutils.py +55 -0
  10. numba_cuda/numba/cuda/cuda_paths.py +9 -17
  11. numba_cuda/numba/cuda/cudadecl.py +1 -1
  12. numba_cuda/numba/cuda/cudadrv/driver.py +4 -19
  13. numba_cuda/numba/cuda/cudadrv/libs.py +1 -2
  14. numba_cuda/numba/cuda/cudadrv/nvrtc.py +44 -44
  15. numba_cuda/numba/cuda/cudadrv/nvvm.py +3 -18
  16. numba_cuda/numba/cuda/cudadrv/runtime.py +12 -1
  17. numba_cuda/numba/cuda/cudamath.py +1 -1
  18. numba_cuda/numba/cuda/decorators.py +4 -3
  19. numba_cuda/numba/cuda/deviceufunc.py +2 -1
  20. numba_cuda/numba/cuda/dispatcher.py +3 -2
  21. numba_cuda/numba/cuda/extending.py +1 -1
  22. numba_cuda/numba/cuda/itanium_mangler.py +211 -0
  23. numba_cuda/numba/cuda/libdevicedecl.py +1 -1
  24. numba_cuda/numba/cuda/libdevicefuncs.py +1 -1
  25. numba_cuda/numba/cuda/lowering.py +1 -1
  26. numba_cuda/numba/cuda/simulator/api.py +1 -1
  27. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +0 -7
  28. numba_cuda/numba/cuda/target.py +1 -2
  29. numba_cuda/numba/cuda/testing.py +4 -6
  30. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +80 -0
  31. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +1 -1
  32. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  33. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +1 -1
  34. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  35. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +1 -1
  36. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +1 -1
  37. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +1 -1
  38. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +4 -6
  39. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +0 -4
  40. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  41. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +1 -3
  42. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +1 -3
  43. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +146 -3
  44. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +1 -1
  45. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +0 -4
  46. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +1 -1
  47. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +1 -1
  48. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  49. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +1 -284
  50. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +473 -0
  51. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +1 -1
  52. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  53. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -6
  54. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +1 -1
  55. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +1 -1
  56. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +295 -0
  57. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +1 -1
  58. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  59. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +1 -1
  60. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +5 -1
  61. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +1 -1
  62. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +1 -1
  63. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +1 -1
  64. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +1 -1
  65. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +1 -1
  66. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +1 -1
  67. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +1 -1
  68. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +1 -1
  69. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +1 -1
  70. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +1 -1
  71. numba_cuda/numba/cuda/tests/nocuda/test_import.py +1 -1
  72. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -2
  73. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +1 -1
  74. numba_cuda/numba/cuda/tests/support.py +752 -0
  75. numba_cuda/numba/cuda/tests/test_binary_generation/Makefile +3 -3
  76. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +4 -1
  77. numba_cuda/numba/cuda/typing/__init__.py +8 -0
  78. numba_cuda/numba/cuda/typing/templates.py +1453 -0
  79. numba_cuda/numba/cuda/vector_types.py +3 -3
  80. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/METADATA +21 -28
  81. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/RECORD +84 -79
  82. numba_cuda/numba/cuda/include/11/cuda_bf16.h +0 -3749
  83. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +0 -2683
  84. numba_cuda/numba/cuda/include/11/cuda_fp16.h +0 -3794
  85. numba_cuda/numba/cuda/include/11/cuda_fp16.hpp +0 -2614
  86. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/WHEEL +0 -0
  87. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/licenses/LICENSE +0 -0
  88. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,211 @@
1
+ """
2
+ Itanium CXX ABI Mangler
3
+
4
+ Reference: https://itanium-cxx-abi.github.io/cxx-abi/abi.html
5
+
6
+ The basics of the mangling scheme.
7
+
8
+ We are hijacking the CXX mangling scheme for our use. We map Python modules
9
+ into CXX namespace. A `module1.submodule2.foo` is mapped to
10
+ `module1::submodule2::foo`. For parameterized numba types, we treat them as
11
+ templated types; for example, `array(int64, 1d, C)` becomes an
12
+ `array<int64, 1, C>`.
13
+
14
+ All mangled names are prefixed with "_Z". It is followed by the name of the
15
+ entity. A name contains one or more identifiers. Each identifier is encoded
16
+ as "<num of char><name>". If the name is namespaced and, therefore,
17
+ has multiple identifiers, the entire name is encoded as "N<name>E".
18
+
19
+ For functions, arguments types follow. There are condensed encodings for basic
20
+ built-in types; e.g. "i" for int, "f" for float. For other types, the
21
+ previously mentioned name encoding should be used.
22
+
23
+ For templated types, the template parameters are encoded immediately after the
24
+ name. If it is namespaced, it should be within the 'N' 'E' marker. Template
25
+ parameters are encoded in "I<params>E", where each parameter is encoded using
26
+ the mentioned name encoding scheme. Template parameters can contain literal
27
+ values like the '1' in the array type shown earlier. There is special encoding
28
+ scheme for them to avoid leading digits.
29
+ """
30
+
31
+ import re
32
+
33
+ from numba.core import types
34
+
35
+
36
+ # According the scheme, valid characters for mangled names are [a-zA-Z0-9_].
37
+ # We borrow the '_' as the escape character to encode invalid char into
38
+ # '_xx' where 'xx' is the hex codepoint.
39
+ _re_invalid_char = re.compile(r"[^a-z0-9_]", re.I)
40
+
41
# Every mangled name produced by this module starts with this prefix.
PREFIX = "_Z"

# Numba types to mangled type code. These correspond with the codes listed in
# https://itanium-cxx-abi.github.io/cxx-abi/abi.html#mangling-builtin
# Numba types not listed here are mangled as templated identifiers built
# from their ``mangling_args`` (see mangle_type_or_value).
N2CODE = {
    types.void: "v",  # void
    types.boolean: "b",  # bool
    types.uint8: "h",  # unsigned char
    types.int8: "a",  # signed char
    types.uint16: "t",  # unsigned short
    types.int16: "s",  # short
    types.uint32: "j",  # unsigned int
    types.int32: "i",  # int
    types.uint64: "y",  # unsigned long long
    types.int64: "x",  # long long
    types.float16: "Dh",  # IEEE 754r half-precision float
    types.float32: "f",  # float
    types.float64: "d",  # double
}
60
+
61
+
62
def _escape_string(text):
    """Escape *text* so that it contains only identifier-safe characters.

    Every character matched by ``_re_invalid_char`` is replaced. That pattern
    is intended to match everything outside ASCII [a-zA-Z0-9_]. The
    replacement is the character's UTF-8 encoding, with each byte written as
    "_xx", where "xx" is the two-digit lowercase hex value of the byte.
    Multibyte characters therefore expand into several "_xx" groups.

    '_' itself is left untouched, so the encoding is not injective: for
    example, "." and "_2e" both escape to "_2e".
    """

    def repl(match):
        return "".join("_%02x" % byte for byte in match.group(0).encode("utf8"))

    # `sub` on a str always returns a str, so no further conversion is needed.
    return _re_invalid_char.sub(repl, text)
81
+
82
+
83
+ def _fix_lead_digit(text):
84
+ """
85
+ Fix text with leading digit
86
+ """
87
+ if text and text[0].isdigit():
88
+ return "_" + text
89
+ else:
90
+ return text
91
+
92
+
93
def _len_encoded(string):
    """
    Return *string* in the length-prefixed "<len><name>" form.

    A leading digit is first guarded with "_" via ``_fix_lead_digit``. The
    emitted length counts that extra underscore.
    """
    safe = _fix_lead_digit(string)
    return f"{len(safe)}{safe}"
100
+
101
+
102
def mangle_abi_tag(abi_tag: str) -> str:
    """Encode one ABI tag as ``B<len><tag>`` after escaping it."""
    escaped_tag = _escape_string(abi_tag)
    return f"B{_len_encoded(escaped_tag)}"
104
+
105
+
106
def mangle_identifier(ident, template_params="", *, abi_tags=(), uid=None):
    """
    Mangle *ident*, optionally followed by template parameters and ABI tags.

    Each '.'-separated component of *ident* becomes one length-prefixed
    name; '.' is therefore treated like '::' in C++. *template_params* (an
    already-mangled string) and the encoded ABI tags are appended after the
    last component. When *uid* is given, ``f"v{uid}"`` is emitted as the
    first ABI tag. Names with more than one component are wrapped in
    ``N...E``.
    """
    tags = abi_tags if uid is None else (f"v{uid}", *abi_tags)
    encoded_parts = "".join(
        _len_encoded(_escape_string(component)) for component in ident.split(".")
    )
    encoded_tags = [mangle_abi_tag(tag) for tag in tags]
    suffix = template_params + "".join(encoded_tags)
    body = encoded_parts + suffix
    # split(".") yields more than one part exactly when "." is present.
    if "." in ident:
        return f"N{body}E"
    return body
124
+
125
+
126
def mangle_type_or_value(typ):
    """
    Mangle a type parameter or an arbitrary value.

    * Numba types found in ``N2CODE`` use their builtin code; other Numba
      types are mangled as templated identifiers from ``typ.mangling_args``.
    * ``int`` values are encoded as integer literals ``Li<value>E``.
      Negative values use the Itanium ``n`` sign prefix (``-5`` ->
      ``Lin5E``), so the result never contains '-'.
    * ``str`` values are mangled as (possibly dotted) identifiers.
    * Anything else is converted with ``str()``, escaped and length-prefixed.
    """
    # Handle numba types
    if isinstance(typ, types.Type):
        if typ in N2CODE:
            return N2CODE[typ]
        else:
            return mangle_templated_ident(*typ.mangling_args)
    # Handle integer literal (bool is an int subclass and also lands here)
    elif isinstance(typ, int):
        # <number> ::= [n] <non-negative decimal integer>; '-' is not a
        # valid character in a mangled name.
        if typ < 0:
            return "Lin%dE" % -typ
        return "Li%dE" % typ
    # Handle str as identifier
    elif isinstance(typ, str):
        return mangle_identifier(typ)
    # Otherwise
    else:
        enc = _escape_string(str(typ))
        return _len_encoded(enc)
146
+
147
+
148
# Public aliases: both names refer to the same function, which accepts either
# a Numba type or a plain value.
mangle_type = mangle_value = mangle_type_or_value
151
+
152
+
153
def mangle_templated_ident(identifier, parameters):
    """
    Mangle *identifier* with an optional template parameter list.

    Non-empty *parameters* are each mangled with ``mangle_type_or_value`` and
    wrapped as ``I<params>E``. Empty or falsy *parameters* produce no
    template list.
    """
    if not parameters:
        return mangle_identifier(identifier, "")
    encoded = "".join(mangle_type_or_value(param) for param in parameters)
    return mangle_identifier(identifier, f"I{encoded}E")
163
+
164
+
165
def mangle_args(argtys):
    """
    Concatenate the mangled form of each Numba type or value in *argtys*.
    """
    return "".join(map(mangle_type_or_value, argtys))
170
+
171
+
172
def mangle(ident, argtys, *, abi_tags=(), uid=None):
    """
    Build a complete mangled symbol: ``_Z`` + mangled name + mangled args.

    *abi_tags* and *uid* are forwarded unchanged to ``mangle_identifier``.
    """
    name_part = mangle_identifier(ident, abi_tags=abi_tags, uid=uid)
    args_part = mangle_args(argtys)
    return f"{PREFIX}{name_part}{args_part}"
183
+
184
+
185
def prepend_namespace(mangled, ns):
    """
    Prepend namespace *ns* to the already-mangled symbol *mangled*.

    *ns* may be dotted (e.g. ``"a.b"``). Each component becomes one more
    namespace level, as in ``mangle_identifier``. The result is always a
    nested name of the form ``_ZN...``.

    Raises ValueError if *mangled* does not start with ``_Z``.
    """
    if not mangled.startswith(PREFIX):
        raise ValueError("input is not a mangled name")
    # Encode the namespace as bare length-prefixed names. mangle_identifier(ns)
    # would wrap a dotted namespace in its own N...E, which is invalid inside
    # an enclosing nested name.
    ns_parts = "".join(_len_encoded(_escape_string(x)) for x in ns.split("."))
    if mangled.startswith(PREFIX + "N"):
        # nested: insert the namespace right after "_ZN"
        remaining = mangled[len(PREFIX) + 1 :]
        return PREFIX + "N" + ns_parts + remaining
    # non-nested: wrap namespace + leading identifier in N...E
    remaining = mangled[len(PREFIX) :]
    head, tail = _split_mangled_ident(remaining)
    return PREFIX + "N" + ns_parts + head + "E" + tail
201
+
202
+
203
+ def _split_mangled_ident(mangled):
204
+ """
205
+ Returns `(head, tail)` where `head` is the `<len> + <name>` encoded
206
+ identifier and `tail` is the remaining.
207
+ """
208
+ ct = int(mangled)
209
+ ctlen = len(str(ct))
210
+ at = ctlen + ct
211
+ return mangled[:at], mangled[at:]
@@ -1,5 +1,5 @@
1
1
  from numba.cuda import libdevice, libdevicefuncs
2
- from numba.core.typing.templates import ConcreteTemplate, Registry
2
+ from numba.cuda.typing.templates import ConcreteTemplate, Registry
3
3
 
4
4
  registry = Registry()
5
5
  register_global = registry.register_global
@@ -2,7 +2,7 @@ from collections import namedtuple
2
2
  from textwrap import indent
3
3
 
4
4
  from numba.types import float32, float64, int16, int32, int64, void, Tuple
5
- from numba.core.typing.templates import signature
5
+ from numba.cuda.typing.templates import signature
6
6
 
7
7
  arg = namedtuple("arg", ("name", "ty", "is_ptr"))
8
8
 
@@ -14,11 +14,11 @@ from numba.core import (
14
14
  funcdesc,
15
15
  generators,
16
16
  config,
17
- ir_utils,
18
17
  cgutils,
19
18
  removerefctpass,
20
19
  targetconfig,
21
20
  )
21
+ from numba.cuda.core import ir_utils
22
22
  from numba.core.errors import (
23
23
  LoweringError,
24
24
  new_error_context,
@@ -17,8 +17,8 @@ from .cudadrv.linkable_code import (
17
17
  LTOIR, # noqa: F401
18
18
  ) # noqa: F401
19
19
  from .kernel import FakeCUDAKernel
20
- from numba.core.sigutils import is_signature
21
20
  from numba.core import config
21
+ from numba.cuda.core.sigutils import is_signature
22
22
  from warnings import warn
23
23
  from ..args import In, Out, InOut # noqa: F401
24
24
 
@@ -3,8 +3,6 @@ Most of the driver API is unsupported in the simulator, but some stubs are
3
3
  provided to allow tests to import correctly.
4
4
  """
5
5
 
6
- from numba import config
7
-
8
6
 
9
7
  def device_memset(dst, val, size, stream=0):
10
8
  dst.view("u1")[:size].fill(bytes([val])[0])
@@ -63,11 +61,6 @@ def launch_kernel(*args, **kwargs):
63
61
 
64
62
  USE_NV_BINDING = False
65
63
 
66
- PyNvJitLinker = None
67
-
68
- if config.ENABLE_CUDASIM:
69
- config.CUDA_ENABLE_PYNVJITLINK = False
70
-
71
64
 
72
65
  def _have_nvjitlink():
73
66
  return False
@@ -3,11 +3,10 @@ from functools import cached_property
3
3
  import llvmlite.binding as ll
4
4
  from llvmlite import ir
5
5
  import warnings
6
- from numba.cuda import cgutils
6
+ from numba.cuda import cgutils, itanium_mangler
7
7
  from numba.core import (
8
8
  compiler,
9
9
  config,
10
- itanium_mangler,
11
10
  targetconfig,
12
11
  types,
13
12
  typing,
@@ -8,7 +8,7 @@ from numba.cuda.cuda_paths import get_conda_ctk
8
8
  from numba.cuda.cudadrv import driver, devices, libs
9
9
  from numba.cuda.dispatcher import CUDADispatcher
10
10
  from numba.core import config
11
- from numba.tests.support import TestCase
11
+ from numba.cuda.tests.support import TestCase
12
12
  from pathlib import Path
13
13
 
14
14
  from typing import Iterable, Union
@@ -154,7 +154,6 @@ class CUDATestCase(TestCase):
154
154
  matcher.stderr = StringIO()
155
155
  result = matcher.run()
156
156
  if result != 0:
157
- dump_instructions = ""
158
157
  if self._dump_failed_filechecks:
159
158
  dump_directory = Path(
160
159
  datetime.now().strftime("numba-ir-%Y_%m_%d_%H_%M_%S")
@@ -172,13 +171,12 @@ class CUDATestCase(TestCase):
172
171
  ):
173
172
  _ = ir_file.write(ir_content + "\n")
174
173
  _ = checks_file.write(check_patterns)
175
- dump_instructions = f"Reproduce with:\n\nfilecheck --check-prefixes={','.join(check_prefixes)} {checks_dump} --input-file={ir_dump}"
174
+ dump_instructions = f"Reproduce with:\n\nfilecheck --check-prefixes={','.join(check_prefixes)} {checks_dump} --input-file {ir_dump}"
175
+ else:
176
+ dump_instructions = "Rerun with --dump-failed-filechecks to generate a reproducer."
176
177
 
177
178
  self.fail(
178
179
  f"FileCheck failed:\n{matcher.stderr.getvalue()}\n\n"
179
- + f"Check prefixes:\n{check_prefixes}\n\n"
180
- + f"Check patterns:\n{check_patterns}\n"
181
- + f"IR:\n{ir_content}\n\n"
182
180
  + dump_instructions
183
181
  )
184
182
 
@@ -0,0 +1,80 @@
1
+ # -*- coding: utf-8 -*-
2
+ from numba import int32, int64, uint32, uint64, float32, float64
3
+ from numba.core.types import range_iter32_type
4
+ from numba.cuda import itanium_mangler
5
+ import unittest
6
+
7
+
8
class TestItaniumManager(unittest.TestCase):
    """Table-driven checks for ``numba.cuda.itanium_mangler``."""

    def test_ident(self):
        cases = [
            ("apple", "5apple"),
            ("ap_ple", "6ap_ple"),
            ("apple213", "8apple213"),
        ]
        for ident, expect in cases:
            with self.subTest(ident=ident):
                self.assertEqual(
                    expect, itanium_mangler.mangle_identifier(ident)
                )

    def test_types(self):
        cases = [
            (int32, "i"),
            (int64, "x"),
            (uint32, "j"),
            (uint64, "y"),
            (float32, "f"),
            (float64, "d"),
        ]
        for ty, expect in cases:
            with self.subTest(ty=ty):
                self.assertEqual(expect, itanium_mangler.mangle_type(ty))

    def test_function(self):
        cases = [
            ("what", [int32, float32], "_Z4whatif"),
            (
                "a_little_brown_fox",
                [uint64, uint32, float64],
                "_Z18a_little_brown_foxyjd",
            ),
        ]
        for ident, argtys, expect in cases:
            with self.subTest(ident=ident):
                self.assertEqual(expect, itanium_mangler.mangle(ident, argtys))

    def test_custom_type(self):
        # Types without a builtin code are mangled from their name.
        name = str(range_iter32_type)
        expect = f"{len(name)}{name}"
        self.assertEqual(expect, itanium_mangler.mangle_type(range_iter32_type))

    def test_mangle_literal(self):
        # int literals use the standard L<type><value>E form
        self.assertEqual("Li123E", itanium_mangler.mangle_value(123))
        # floats are not covered by the standard; they fall back to an
        # escaped, length-prefixed string
        self.assertRegex(
            itanium_mangler.mangle_value(12.3), r"^\d+_12_[0-9a-z][0-9a-z]3$"
        )

    def test_mangle_unicode(self):
        got = itanium_mangler.mangle_identifier("f∂ƒ©z")
        self.assertRegex(got, r"^\d+f(_[a-z0-9][a-z0-9])+z$")


if __name__ == "__main__":
    unittest.main()
@@ -4,7 +4,7 @@ from numba.cuda.cudadrv import devicearray
4
4
  from numba import cuda
5
5
  from numba.cuda.testing import unittest, CUDATestCase
6
6
  from numba.cuda.testing import skip_on_cudasim
7
- from numba.tests.support import IS_NUMPY_2
7
+ from numba.cuda.tests.support import IS_NUMPY_2
8
8
 
9
9
 
10
10
  class TestCudaNDArray(CUDATestCase):
@@ -9,7 +9,7 @@ from numba.cuda.testing import (
9
9
  skip_if_external_memmgr,
10
10
  CUDATestCase,
11
11
  )
12
- from numba.tests.support import captured_stderr
12
+ from numba.cuda.tests.support import captured_stderr
13
13
  from numba.core import config
14
14
 
15
15
 
@@ -9,7 +9,7 @@ from numba.cuda.testing import (
9
9
  skip_on_cudasim,
10
10
  skip_under_cuda_memcheck,
11
11
  )
12
- from numba.tests.support import captured_stdout
12
+ from numba.cuda.tests.support import captured_stdout
13
13
 
14
14
 
15
15
  class TestCudaDetect(CUDATestCase):
@@ -5,7 +5,7 @@ import weakref
5
5
  from numba import cuda
6
6
  from numba.core import config
7
7
  from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim
8
- from numba.tests.support import linux_only
8
+ from numba.cuda.tests.support import linux_only
9
9
 
10
10
  if not config.ENABLE_CUDASIM:
11
11
 
@@ -6,7 +6,7 @@ from numba.cuda.testing import skip_on_cudasim, skip_if_cuda_includes_missing
6
6
  from numba.cuda.testing import CUDATestCase, test_data_dir
7
7
  from numba.cuda.cudadrv.driver import CudaAPIError, _Linker, LinkerError
8
8
  from numba.cuda import require_context
9
- from numba.tests.support import ignore_internal_warnings
9
+ from numba.cuda.tests.support import ignore_internal_warnings
10
10
  from numba import cuda, void, float64, int64, int32, typeof, float32
11
11
  from numba.cuda.cudadrv.error import NvrtcError
12
12
 
@@ -4,7 +4,7 @@ from numba.cuda.cudadrv.driver import device_memset, driver, USE_NV_BINDING
4
4
  from numba import cuda
5
5
  from numba.cuda.testing import unittest, ContextResettingTestCase
6
6
  from numba.cuda.testing import skip_on_cudasim, skip_on_arm
7
- from numba.tests.support import linux_only
7
+ from numba.cuda.tests.support import linux_only
8
8
 
9
9
 
10
10
  @skip_on_cudasim("CUDA Driver API unsupported in the simulator")
@@ -6,7 +6,7 @@ from numba.cuda.testing import (
6
6
  skip_under_cuda_memcheck,
7
7
  skip_if_mvc_libraries_unavailable,
8
8
  )
9
- from numba.tests.support import linux_only
9
+ from numba.cuda.tests.support import linux_only
10
10
 
11
11
 
12
12
  def child_test():
@@ -13,13 +13,11 @@ class TestArchOption(unittest.TestCase):
13
13
  self.assertEqual(nvrtc.get_arch_option(8, 5), "compute_80")
14
14
  self.assertEqual(nvrtc.get_arch_option(9, 1), "compute_90")
15
15
  # Test known arch.
16
- supported_cc = nvrtc.NVRTC().get_supported_archs()
17
- for arch in supported_cc:
18
- self.assertEqual(
19
- nvrtc.get_arch_option(*arch), "compute_%d%d" % arch
20
- )
16
+ supported_ccs = nvrtc.get_supported_ccs()
17
+ for cc in supported_ccs:
18
+ self.assertEqual(nvrtc.get_arch_option(*cc), "compute_%d%d" % cc)
21
19
  self.assertEqual(
22
- nvrtc.get_arch_option(1000, 0), "compute_%d%d" % supported_cc[-1]
20
+ nvrtc.get_arch_option(1000, 0), "compute_%d%d" % supported_ccs[-1]
23
21
  )
24
22
 
25
23
 
@@ -25,10 +25,6 @@ class TestNvvmDriver(unittest.TestCase):
25
25
  # ("-gen-lto") - all other NVVM options are of the form
26
26
  # "-<name>=<value>"
27
27
 
28
- # -gen-lto is not available prior to CUDA 11.5
29
- if runtime.get_version() < (11, 5):
30
- self.skipTest("-gen-lto unavailable in this toolkit version")
31
-
32
28
  nvvmir = self.get_nvvmir()
33
29
  arch = "compute_%d%d" % nvrtc.get_lowest_supported_cc()
34
30
  ltoir = nvvm.compile_ir(nvvmir, opt=3, gen_lto=None, arch=arch)
@@ -7,7 +7,7 @@ from numba.cuda.testing import (
7
7
  skip_with_cuda_python,
8
8
  skip_under_cuda_memcheck,
9
9
  )
10
- from numba.tests.support import linux_only
10
+ from numba.cuda.tests.support import linux_only
11
11
 
12
12
 
13
13
  def child_test():
@@ -8,9 +8,7 @@ import math
8
8
  class TestBfloat16HighLevelBindings(CUDATestCase):
9
9
  def skip_unsupported(self):
10
10
  if not cuda.is_bfloat16_supported():
11
- self.skipTest(
12
- "bfloat16 requires compute capability 8.0+ and CUDA version>= 12.0"
13
- )
11
+ self.skipTest("bfloat16 requires compute capability 8.0+")
14
12
 
15
13
  def test_use_type_in_kernel(self):
16
14
  self.skip_unsupported()
@@ -43,9 +43,7 @@ dtypes = [int16, int32, int64, uint16, uint32, uint64, float32]
43
43
  class Bfloat16Test(CUDATestCase):
44
44
  def skip_unsupported(self):
45
45
  if not cuda.is_bfloat16_supported():
46
- self.skipTest(
47
- "bfloat16 requires compute capability 8.0+ and CUDA version>= 12.0"
48
- )
46
+ self.skipTest("bfloat16 requires compute capability 8.0+")
49
47
 
50
48
  def test_ctor(self):
51
49
  self.skip_unsupported()