numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic; see the registry's advisory page for more details.

Files changed (172)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +7 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
  129. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  130. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  134. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  139. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  141. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  143. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  146. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  147. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  148. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  151. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  152. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  153. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  154. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  155. numba_cuda/numba/cuda/tests/support.py +55 -15
  156. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  157. numba_cuda/numba/cuda/types.py +56 -0
  158. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  159. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  160. numba_cuda/numba/cuda/typing/context.py +751 -0
  161. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  162. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  163. numba_cuda/numba/cuda/typing/templates.py +7 -6
  164. numba_cuda/numba/cuda/ufuncs.py +3 -3
  165. numba_cuda/numba/cuda/utils.py +6 -112
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
  167. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
  168. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
  172. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,453 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+ """
4
+ Tests for SSA reconstruction
5
+ """
6
+
7
+ import sys
8
+ import copy
9
+ import logging
10
+
11
+ import numpy as np
12
+
13
+ from numba import types, cuda
14
+ from numba.cuda import jit
15
+ from numba.core import errors
16
+
17
+ from numba.extending import overload
18
+ from numba.cuda.tests.support import override_config
19
+ from numba.cuda.testing import CUDATestCase, skip_on_cudasim
20
+
21
+
22
# Flip to True to stream debug records from the SSA reconstruction pass to
# stderr while these tests run.
_DEBUG = False

if _DEBUG:
    ssa_logger = logging.getLogger("numba.cuda.core.ssa")
    _stderr_handler = logging.StreamHandler(sys.stderr)
    ssa_logger.setLevel(level=logging.DEBUG)
    ssa_logger.addHandler(_stderr_handler)
29
+
30
+
31
class SSABaseTest(CUDATestCase):
    """
    Base class for the SSA tests, ported from Numba's CPU test-suite so that
    it drives CUDA kernels.

    Kernels cannot return values, so each test function takes an output array
    as its first parameter and writes its (possibly tuple-valued) result into
    elements of that array.
    """

    def check_func(self, func, result_array, *args):
        """Run ``func`` as a 1x1 CUDA launch and as plain Python; compare.

        ``result_array`` serves only as a template: zero-filled arrays with
        its shape and dtype receive the output of each run. Its values are
        never compared against anything. Each run gets its own deep copy of
        ``args``, so in-place mutation by one run cannot leak into the other.
        """
        # GPU run: single block, single thread, writing into a device buffer.
        device_out = cuda.to_device(np.zeros_like(result_array))
        func[1, 1](device_out, *copy.deepcopy(args))
        from_gpu = device_out.copy_to_host()

        # Reference run: the undecorated Python function on host memory.
        from_python = np.zeros_like(result_array)
        func.py_func(from_python, *copy.deepcopy(args))

        np.testing.assert_array_equal(from_gpu, from_python)
52
+
53
+
54
class TestSSA(SSABaseTest):
    """
    Contains tests to help isolate problems in SSA
    """

    def test_argument_name_reused(self):
        @jit
        def foo(result, x):
            x += 1
            result[0] = x

        self.check_func(foo, np.array([124.0]), 123)

    def test_if_else_redefine(self):
        @jit
        def foo(result, x, y):
            z = x * y
            if x < y:
                z = x
            else:
                z = y
            result[0] = z

        self.check_func(foo, np.array([2.0]), 3, 2)
        self.check_func(foo, np.array([2.0]), 2, 3)

    def test_sum_loop(self):
        @jit
        def foo(result, n):
            c = 0
            for i in range(n):
                c += i
            result[0] = c

        self.check_func(foo, np.array([0.0]), 0)
        self.check_func(foo, np.array([45.0]), 10)

    def test_sum_loop_2vars(self):
        @jit
        def foo(result, n):
            c = 0
            d = n
            for i in range(n):
                c += i
                d += n
            result[0] = c
            result[1] = d

        self.check_func(foo, np.array([0.0, 0.0]), 0)
        self.check_func(foo, np.array([45.0, 110.0]), 10)

    def test_sum_2d_loop(self):
        @jit
        def foo(result, n):
            c = 0
            for i in range(n):
                for j in range(n):
                    c += j
                c += i
            result[0] = c

        self.check_func(foo, np.array([0.0]), 0)
        self.check_func(foo, np.array([495.0]), 10)

    def check_undefined_var(self, should_warn):
        """Compile a kernel that reads ``c`` on a path where it is unbound.

        When ``should_warn`` is true, compilation must emit a NumbaWarning
        mentioning the uninitialized variable. In either mode, the plain
        Python version must raise UnboundLocalError on that path.
        """

        @jit
        def foo(result, n):
            if n:
                if n > 0:
                    c = 0
                result[0] = c
            else:
                # variable c is not defined in this branch
                c += 1
                result[0] = c

        if should_warn:
            # assertWarnsRegex passes if *any* NumbaWarning raised inside the
            # block matches. assertWarns(...).warning only exposes the first
            # one, so an unrelated earlier NumbaWarning could mask the one
            # under test.
            with self.assertWarnsRegex(
                errors.NumbaWarning, "Detected uninitialized variable c"
            ):
                # n=1 so we won't actually run the branch with the
                # uninitialized variable.
                self.check_func(foo, np.array([0]), 1)
        else:
            self.check_func(foo, np.array([0]), 1)

        # Only the call that is expected to raise goes inside assertRaises.
        result = np.array([0])
        with self.assertRaises(UnboundLocalError):
            foo.py_func(result, 0)

    @skip_on_cudasim(
        "Numba variable warnings are not supported in the simulator"
    )
    def test_undefined_var(self):
        with override_config("ALWAYS_WARN_UNINIT_VAR", 0):
            self.check_undefined_var(should_warn=False)
        with override_config("ALWAYS_WARN_UNINIT_VAR", 1):
            self.check_undefined_var(should_warn=True)

    def test_phi_propagation(self):
        @jit
        def foo(result, actions):
            n = 1

            i = 0
            ct = 0
            while n > 0 and i < len(actions):
                n -= 1

                while actions[i]:
                    if actions[i]:
                        if actions[i]:
                            n += 10
                        actions[i] -= 1
                    else:
                        if actions[i]:
                            n += 20
                        actions[i] += 1

                    ct += n
                ct += n
            result[0] = ct
            result[1] = n

        # The first array only fixes the output's shape/dtype for
        # check_func; the values actually produced are [65, 0].
        self.check_func(foo, np.array([1, 2]), np.array([1, 2]))

    def test_unhandled_undefined(self):
        @cuda.jit
        def function1(arg1, arg2, arg3, arg4, arg5):
            # This function is auto-generated.
            if arg1:
                var1 = arg2
                var2 = arg3
                var3 = var2
                var4 = arg1
                return
            else:
                if arg2:
                    if arg4:
                        var5 = arg4  # noqa: F841
                        return
                    else:
                        var6 = var4
                        return
                    return var6
                else:
                    if arg5:
                        if var1:
                            if arg5:
                                var1 = var6
                                return
                            else:
                                var7 = arg2  # noqa: F841
                                return arg2
                            return
                        else:
                            if var2:
                                arg5 = arg2
                                return arg1
                            else:
                                var6 = var3
                                return var4
                            return
                        return
                    else:
                        var8 = var1
                        return
                    return var8
                var9 = var3  # noqa: F841
                var10 = arg5  # noqa: F841
                return var1

        # Kernels cannot return values, so a `None` result from the device
        # function is mapped to this sentinel before being stored.
        NONE_SENTINEL = 99

        @cuda.jit
        def function1_caller(result, arg1, arg2, arg3, arg4, arg5):
            retval = function1(arg1, arg2, arg3, arg4, arg5)
            if retval is None:
                result[0] = NONE_SENTINEL
            else:
                result[0] = retval

        # The argument values are not critical for re-creating the bug
        # because the bug is in compile-time.

        expect = function1.py_func(2, 3, 6, 0, 7)
        if expect is None:
            expect = NONE_SENTINEL
        result = np.zeros(1, dtype=np.int64)
        function1_caller[1, 1](result, 2, 3, 6, 0, 7)
        got = result[0]
        self.assertEqual(expect, got)
246
+
247
+
248
class TestReportedSSAIssues(SSABaseTest):
    """Regression tests reproducing code shapes from reported SSA issues.

    Each kernel body mirrors the control flow of a Numba GitHub issue.
    Statement order is deliberate and should not be "tidied up", because the
    exact shape is what triggers the original SSA bug.
    """

    # Tests from issues
    # https://github.com/numba/numba/issues?q=is%3Aopen+is%3Aissue+label%3ASSA

    def test_issue2194(self):
        # A np.uint32 value is used both as a range() stop and a range() start.
        @jit
        def foo(result, V):
            s = np.uint32(1)

            for i in range(s):
                V[i] = 1
            for i in range(s, 1):
                pass
            result[0] = V[0]

        V = np.empty(1)
        self.check_func(foo, np.array([1.0]), V)

    def test_issue3094(self):
        # Variable defined on both arms of an if/else, used after the merge.
        @jit
        def foo(result, pred):
            if pred:
                x = 1
            else:
                x = 0
            result[0] = x

        self.check_func(foo, np.array([0]), False)

    def test_issue3931(self):
        # An argument is rebound inside a loop (with a different array rank)
        # and read after the loop.
        @jit
        def foo(result, arr):
            for i in range(1):
                arr = arr.reshape(3 * 2)
                arr = arr.reshape(3, 2)
            # Copy result array elements
            for i in range(arr.shape[0]):
                for j in range(arr.shape[1]):
                    result[i, j] = arr[i, j]

        # Host-side template giving check_func the output shape/dtype.
        result_gpu = np.zeros((3, 2))
        self.check_func(foo, result_gpu, np.zeros((3, 2)))

    def test_issue3976(self):
        # `overload_this` is a plain Python function; the @overload below
        # supplies its jitted implementation. The overload is registered
        # after `foo` is decorated but before check_func first calls it.
        def overload_this(a):
            return 42

        @jit
        def foo(result, a):
            if a:
                s = 5
                s = overload_this(s)
            else:
                s = 99

            result[0] = s

        @overload(overload_this)
        def ol(a):
            return overload_this

        self.check_func(foo, np.array([42]), True)

    def test_issue3979(self):
        # Loop variables reassigned in two consecutive loops that reuse the
        # same induction variable name `i`.
        @jit
        def foo(result, A, B):
            x = A[0]
            y = B[0]
            for i in A:
                x = i
            for i in B:
                y = i
            result[0] = x
            result[1] = y

        self.check_func(
            foo, np.array([2, 4]), np.array([1, 2]), np.array([3, 4])
        )

    def test_issue5219(self):
        # Inside the overload's `impl`, the argument `b` is conditionally
        # rebound on a branch decided by a compile-time constant.
        def overload_this(a, b=None):
            if isinstance(b, tuple):
                b = b[0]
            return b

        @overload(overload_this)
        def ol(a, b=None):
            b_is_tuple = isinstance(b, (types.Tuple, types.UniTuple))

            def impl(a, b=None):
                if b_is_tuple is True:
                    b = b[0]
                return b

            return impl

        @jit
        def test_tuple(result, a, b):
            result[0] = overload_this(a, b)

        self.check_func(test_tuple, np.array([2]), 1, (2,))

    def test_issue5223(self):
        @jit
        def bar(result, x):
            if len(x) == 5:
                for i in range(len(x)):
                    result[i] = x[i]
            else:
                # Manual copy since .copy() not available in CUDA
                for i in range(len(x)):
                    result[i] = x[i] + 1

        a = np.ones(5)
        # NOTE(review): check_func deep-copies its args before each run.
        # NumPy deep copies are believed to come back writeable, so this
        # read-only flag may never reach the kernel -- confirm if read-only
        # input is meant to be part of this test.
        a.flags.writeable = False
        expected = np.ones(5)  # Since len(a) == 5, it should return unchanged
        self.check_func(bar, expected, a)

    def test_issue5243(self):
        # A dead store immediately overwritten by another dead store.
        @jit
        def foo(result, q, lin):
            stencil_val = 0.0  # noqa: F841
            stencil_val = q[0, 0]  # noqa: F841
            result[0] = lin[0]

        lin = np.array([0.1, 0.6, 0.3])
        self.check_func(foo, np.array([0.1]), np.zeros((2, 2)), lin)

    def test_issue5482_missing_variable_init(self):
        # Test error that lowering fails because variable is missing
        # a definition before use.
        @jit
        def foo(result, x, v, n):
            for i in range(n):
                if i == 0:
                    if i == x:
                        pass
                    else:
                        problematic = v
                else:
                    if i == x:
                        pass
                    else:
                        problematic = problematic + v
            result[0] = problematic

        self.check_func(foo, np.array([10]), 1, 5, 3)

    def test_issue5493_unneeded_phi(self):
        # Test error that unneeded phi is inserted because variable does not
        # have a dominance definition.
        data = (np.ones(2), np.ones(2))
        # A and B are captured by the kernel from this enclosing scope
        # rather than passed as arguments.
        A = np.ones(1)
        B = np.ones(1)

        @jit
        def foo(res, m, n, data):
            if len(data) == 1:
                v0 = data[0]
            else:
                v0 = data[0]
                # Unneeded PHI node for `problematic` would be placed here
                for _ in range(1, len(data)):
                    v0[0] += A[0]

            for t in range(1, m):
                for idx in range(n):
                    # The loop variable `t` is deliberately rebound here.
                    t = B

                    if idx == 0:
                        if idx == n - 1:
                            pass
                        else:
                            res[0] = t[0]
                    else:
                        if idx == n - 1:
                            pass
                        else:
                            res[0] += t[0]

        # The first array only fixes the output's shape/dtype for check_func.
        self.check_func(foo, np.array([10]), 10, 10, data)

    def test_issue5623_equal_statements_in_same_bb(self):
        # Two textually identical statement pairs in one basic block.
        def foo(pred, stack):
            i = 0
            c = 1

            if pred is True:
                stack[i] = c
                i += 1
                stack[i] = c
                i += 1

        # Reference result from plain Python.
        python = np.array([0, 666])
        foo(True, python)

        nb = np.array([0, 666])

        # Convert to CUDA kernel
        foo_cuda = jit(foo)
        # `nb` is a host array passed straight to the launch; the assertion
        # below reads it back on the host after the kernel runs.
        foo_cuda[1, 1](True, nb)

        expect = np.array([1, 1])

        np.testing.assert_array_equal(python, expect)
        np.testing.assert_array_equal(nb, expect)
@@ -4,7 +4,7 @@
4
4
  import numpy as np
5
5
  from numba import cuda, int32, float32
6
6
  from numba.cuda.testing import skip_on_cudasim, unittest, CUDATestCase
7
- from numba.core.config import ENABLE_CUDASIM
7
+ from numba.cuda.core.config import ENABLE_CUDASIM
8
8
 
9
9
 
10
10
  def useless_syncthreads(ary):