numba-cuda 0.0.1__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (233)
  1. _numba_cuda_redirector.pth +1 -0
  2. _numba_cuda_redirector.py +74 -0
  3. numba_cuda/VERSION +1 -0
  4. numba_cuda/__init__.py +5 -0
  5. numba_cuda/_version.py +19 -0
  6. numba_cuda/numba/cuda/__init__.py +22 -0
  7. numba_cuda/numba/cuda/api.py +526 -0
  8. numba_cuda/numba/cuda/api_util.py +30 -0
  9. numba_cuda/numba/cuda/args.py +77 -0
  10. numba_cuda/numba/cuda/cg.py +62 -0
  11. numba_cuda/numba/cuda/codegen.py +378 -0
  12. numba_cuda/numba/cuda/compiler.py +422 -0
  13. numba_cuda/numba/cuda/cpp_function_wrappers.cu +47 -0
  14. numba_cuda/numba/cuda/cuda_fp16.h +3631 -0
  15. numba_cuda/numba/cuda/cuda_fp16.hpp +2465 -0
  16. numba_cuda/numba/cuda/cuda_paths.py +258 -0
  17. numba_cuda/numba/cuda/cudadecl.py +806 -0
  18. numba_cuda/numba/cuda/cudadrv/__init__.py +9 -0
  19. numba_cuda/numba/cuda/cudadrv/devicearray.py +904 -0
  20. numba_cuda/numba/cuda/cudadrv/devices.py +248 -0
  21. numba_cuda/numba/cuda/cudadrv/driver.py +3201 -0
  22. numba_cuda/numba/cuda/cudadrv/drvapi.py +398 -0
  23. numba_cuda/numba/cuda/cudadrv/dummyarray.py +452 -0
  24. numba_cuda/numba/cuda/cudadrv/enums.py +607 -0
  25. numba_cuda/numba/cuda/cudadrv/error.py +36 -0
  26. numba_cuda/numba/cuda/cudadrv/libs.py +176 -0
  27. numba_cuda/numba/cuda/cudadrv/ndarray.py +20 -0
  28. numba_cuda/numba/cuda/cudadrv/nvrtc.py +260 -0
  29. numba_cuda/numba/cuda/cudadrv/nvvm.py +707 -0
  30. numba_cuda/numba/cuda/cudadrv/rtapi.py +10 -0
  31. numba_cuda/numba/cuda/cudadrv/runtime.py +142 -0
  32. numba_cuda/numba/cuda/cudaimpl.py +1055 -0
  33. numba_cuda/numba/cuda/cudamath.py +140 -0
  34. numba_cuda/numba/cuda/decorators.py +189 -0
  35. numba_cuda/numba/cuda/descriptor.py +33 -0
  36. numba_cuda/numba/cuda/device_init.py +89 -0
  37. numba_cuda/numba/cuda/deviceufunc.py +908 -0
  38. numba_cuda/numba/cuda/dispatcher.py +1057 -0
  39. numba_cuda/numba/cuda/errors.py +59 -0
  40. numba_cuda/numba/cuda/extending.py +7 -0
  41. numba_cuda/numba/cuda/initialize.py +13 -0
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +77 -0
  43. numba_cuda/numba/cuda/intrinsics.py +198 -0
  44. numba_cuda/numba/cuda/kernels/__init__.py +0 -0
  45. numba_cuda/numba/cuda/kernels/reduction.py +262 -0
  46. numba_cuda/numba/cuda/kernels/transpose.py +65 -0
  47. numba_cuda/numba/cuda/libdevice.py +3382 -0
  48. numba_cuda/numba/cuda/libdevicedecl.py +17 -0
  49. numba_cuda/numba/cuda/libdevicefuncs.py +1057 -0
  50. numba_cuda/numba/cuda/libdeviceimpl.py +83 -0
  51. numba_cuda/numba/cuda/mathimpl.py +448 -0
  52. numba_cuda/numba/cuda/models.py +48 -0
  53. numba_cuda/numba/cuda/nvvmutils.py +235 -0
  54. numba_cuda/numba/cuda/printimpl.py +86 -0
  55. numba_cuda/numba/cuda/random.py +292 -0
  56. numba_cuda/numba/cuda/simulator/__init__.py +38 -0
  57. numba_cuda/numba/cuda/simulator/api.py +110 -0
  58. numba_cuda/numba/cuda/simulator/compiler.py +9 -0
  59. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +2 -0
  60. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +432 -0
  61. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +117 -0
  62. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +62 -0
  63. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +4 -0
  64. numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +4 -0
  65. numba_cuda/numba/cuda/simulator/cudadrv/error.py +6 -0
  66. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +2 -0
  67. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +29 -0
  68. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +19 -0
  69. numba_cuda/numba/cuda/simulator/kernel.py +308 -0
  70. numba_cuda/numba/cuda/simulator/kernelapi.py +495 -0
  71. numba_cuda/numba/cuda/simulator/reduction.py +15 -0
  72. numba_cuda/numba/cuda/simulator/vector_types.py +58 -0
  73. numba_cuda/numba/cuda/simulator_init.py +17 -0
  74. numba_cuda/numba/cuda/stubs.py +902 -0
  75. numba_cuda/numba/cuda/target.py +440 -0
  76. numba_cuda/numba/cuda/testing.py +202 -0
  77. numba_cuda/numba/cuda/tests/__init__.py +58 -0
  78. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +8 -0
  79. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +145 -0
  80. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +145 -0
  81. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +375 -0
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +21 -0
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +179 -0
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +235 -0
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +22 -0
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +193 -0
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +547 -0
  88. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +249 -0
  89. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +81 -0
  90. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +192 -0
  91. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +38 -0
  92. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +65 -0
  93. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +139 -0
  94. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +37 -0
  95. numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +12 -0
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +317 -0
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +127 -0
  98. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +54 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +199 -0
  100. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +37 -0
  101. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +20 -0
  102. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +149 -0
  103. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +36 -0
  104. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +85 -0
  105. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +41 -0
  106. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +122 -0
  107. numba_cuda/numba/cuda/tests/cudapy/__init__.py +8 -0
  108. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +234 -0
  109. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +41 -0
  110. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +58 -0
  111. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +30 -0
  112. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +100 -0
  113. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +42 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_array.py +260 -0
  115. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +201 -0
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +35 -0
  117. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1620 -0
  118. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +120 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +24 -0
  120. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +545 -0
  121. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +257 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +33 -0
  123. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +276 -0
  124. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +296 -0
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +20 -0
  126. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +129 -0
  127. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +176 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +147 -0
  129. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +435 -0
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +90 -0
  131. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +94 -0
  132. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +101 -0
  133. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +221 -0
  134. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +222 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +700 -0
  136. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +121 -0
  137. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +79 -0
  138. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +174 -0
  139. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +155 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +244 -0
  141. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +52 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +29 -0
  143. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +66 -0
  144. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +60 -0
  145. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +456 -0
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +159 -0
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +95 -0
  148. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +37 -0
  149. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +165 -0
  150. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1106 -0
  151. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +318 -0
  152. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +99 -0
  153. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +64 -0
  154. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +119 -0
  155. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +187 -0
  156. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +199 -0
  157. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +164 -0
  158. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +37 -0
  159. numba_cuda/numba/cuda/tests/cudapy/test_math.py +786 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +74 -0
  161. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +113 -0
  162. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +22 -0
  163. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +140 -0
  164. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +46 -0
  165. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +101 -0
  166. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +49 -0
  167. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +401 -0
  168. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +86 -0
  169. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +335 -0
  170. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +124 -0
  171. numba_cuda/numba/cuda/tests/cudapy/test_print.py +128 -0
  172. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +33 -0
  173. numba_cuda/numba/cuda/tests/cudapy/test_random.py +104 -0
  174. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +610 -0
  175. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +125 -0
  176. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +76 -0
  177. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
  178. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +85 -0
  179. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +37 -0
  180. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +444 -0
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +205 -0
  182. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +271 -0
  183. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +80 -0
  184. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +277 -0
  185. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +47 -0
  186. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +307 -0
  187. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +283 -0
  188. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +20 -0
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +69 -0
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +36 -0
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +37 -0
  192. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +139 -0
  193. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +276 -0
  194. numba_cuda/numba/cuda/tests/cudasim/__init__.py +6 -0
  195. numba_cuda/numba/cuda/tests/cudasim/support.py +6 -0
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +102 -0
  197. numba_cuda/numba/cuda/tests/data/__init__.py +0 -0
  198. numba_cuda/numba/cuda/tests/data/cuda_include.cu +5 -0
  199. numba_cuda/numba/cuda/tests/data/error.cu +7 -0
  200. numba_cuda/numba/cuda/tests/data/jitlink.cu +23 -0
  201. numba_cuda/numba/cuda/tests/data/jitlink.ptx +51 -0
  202. numba_cuda/numba/cuda/tests/data/warn.cu +7 -0
  203. numba_cuda/numba/cuda/tests/doc_examples/__init__.py +6 -0
  204. numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +0 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +49 -0
  206. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +77 -0
  207. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +76 -0
  208. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +82 -0
  209. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +155 -0
  210. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +173 -0
  211. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +109 -0
  212. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +59 -0
  213. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +76 -0
  214. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +130 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +50 -0
  216. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +73 -0
  217. numba_cuda/numba/cuda/tests/nocuda/__init__.py +8 -0
  218. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +359 -0
  219. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +36 -0
  220. numba_cuda/numba/cuda/tests/nocuda/test_import.py +49 -0
  221. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +238 -0
  222. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +54 -0
  223. numba_cuda/numba/cuda/types.py +37 -0
  224. numba_cuda/numba/cuda/ufuncs.py +662 -0
  225. numba_cuda/numba/cuda/vector_types.py +209 -0
  226. numba_cuda/numba/cuda/vectorizers.py +252 -0
  227. numba_cuda-0.0.12.dist-info/LICENSE +25 -0
  228. numba_cuda-0.0.12.dist-info/METADATA +68 -0
  229. numba_cuda-0.0.12.dist-info/RECORD +231 -0
  230. {numba_cuda-0.0.1.dist-info → numba_cuda-0.0.12.dist-info}/WHEEL +1 -1
  231. numba_cuda-0.0.1.dist-info/METADATA +0 -10
  232. numba_cuda-0.0.1.dist-info/RECORD +0 -5
  233. {numba_cuda-0.0.1.dist-info → numba_cuda-0.0.12.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,359 @@
1
+ import unittest
2
+ import itertools
3
+ import numpy as np
4
+ from numba.cuda.cudadrv.dummyarray import Array
5
+ from numba.cuda.testing import skip_on_cudasim
6
+
7
+
8
@skip_on_cudasim("Tests internals of the CUDA driver device array")
class TestSlicing(unittest.TestCase):
    """Compare slicing of a dummy ``Array`` against NumPy slicing.

    Each test describes a NumPy array's layout (shape, strides, itemsize)
    with ``Array.from_desc``, applies the same slice to both objects, and
    checks that the resulting contiguity flags, shapes and strides agree.
    """

    @staticmethod
    def _as_dummy(nparr):
        # Build a dummy Array at base address 0 mirroring nparr's layout.
        return Array.from_desc(0, nparr.shape, nparr.strides,
                               nparr.dtype.itemsize)

    def assertSameContig(self, arr, nparr):
        """Fail if the C/F contiguity flags of *arr* and *nparr* differ.

        A mismatch is tolerated when both arrays are empty.
        """
        attrs = 'C_CONTIGUOUS', 'F_CONTIGUOUS'
        for attr in attrs:
            if arr.flags[attr] != nparr.flags[attr]:
                if arr.size == 0 and nparr.size == 0:
                    # numpy <=1.7 bug that some empty array are contiguous and
                    # some are not
                    pass
                else:
                    self.fail("contiguous flag mismatch:\ngot=%s\nexpect=%s" %
                              (arr.flags, nparr.flags))

    def _assertSameLayout(self, got, expect):
        # Contiguity, shape and strides must all match NumPy's result.
        self.assertSameContig(got, expect)
        self.assertEqual(got.shape, expect.shape)
        self.assertEqual(got.strides, expect.strides)

    # 1D

    def test_slice0_1d(self):
        nparr = np.empty(4)
        arr = self._as_dummy(nparr)
        self.assertSameContig(arr, nparr)
        xx = -2, -1, 0, 1, 2
        for x in xx:
            self._assertSameLayout(arr[x:], nparr[x:])

    def test_slice1_1d(self):
        nparr = np.empty(4)
        arr = self._as_dummy(nparr)
        xx = -2, -1, 0, 1, 2
        for x in xx:
            self._assertSameLayout(arr[:x], nparr[:x])

    def test_slice2_1d(self):
        nparr = np.empty(4)
        arr = self._as_dummy(nparr)
        xx = -2, -1, 0, 1, 2
        for x, y in itertools.product(xx, xx):
            self._assertSameLayout(arr[x:y], nparr[x:y])

    # 2D

    def test_slice0_2d(self):
        nparr = np.empty((4, 5))
        arr = self._as_dummy(nparr)
        xx = -2, 0, 1, 2
        for x in xx:
            self._assertSameLayout(arr[x:], nparr[x:])

        for x, y in itertools.product(xx, xx):
            self._assertSameLayout(arr[x:, y:], nparr[x:, y:])

    def test_slice1_2d(self):
        nparr = np.empty((4, 5))
        arr = self._as_dummy(nparr)
        xx = -2, 0, 2
        for x in xx:
            self._assertSameLayout(arr[:x], nparr[:x])

        for x, y in itertools.product(xx, xx):
            self._assertSameLayout(arr[:x, :y], nparr[:x, :y])

    def test_slice2_2d(self):
        nparr = np.empty((4, 5))
        arr = self._as_dummy(nparr)
        xx = -2, 0, 2
        for s, t, u, v in itertools.product(xx, xx, xx, xx):
            self._assertSameLayout(arr[s:t, u:v], nparr[s:t, u:v])

        # Both bounds on the leading axis only; the trailing axis is kept
        # whole. The slice must use this loop's x and y, not the s, t, u, v
        # left over from the loop above.
        for x, y in itertools.product(xx, xx):
            self._assertSameLayout(arr[x:y], nparr[x:y])

    # Strided

    def test_strided_1d(self):
        nparr = np.empty(4)
        arr = self._as_dummy(nparr)
        xx = -2, -1, 1, 2
        for x in xx:
            self._assertSameLayout(arr[::x], nparr[::x])

    def test_strided_2d(self):
        nparr = np.empty((4, 5))
        arr = self._as_dummy(nparr)
        xx = -2, -1, 1, 2
        for a, b in itertools.product(xx, xx):
            self._assertSameLayout(arr[::a, ::b], nparr[::a, ::b])

    def test_strided_3d(self):
        nparr = np.empty((4, 5, 6))
        arr = self._as_dummy(nparr)
        xx = -2, -1, 1, 2
        for a, b, c in itertools.product(xx, xx, xx):
            self._assertSameLayout(arr[::a, ::b, ::c], nparr[::a, ::b, ::c])

    def test_issue_2766(self):
        # Contiguity flags of a transposed array must match NumPy's.
        z = np.empty((1, 2, 3))
        z = np.transpose(z, axes=(2, 0, 1))
        arr = self._as_dummy(z)
        self.assertEqual(z.flags['C_CONTIGUOUS'], arr.flags['C_CONTIGUOUS'])
        self.assertEqual(z.flags['F_CONTIGUOUS'], arr.flags['F_CONTIGUOUS'])
165
+
166
+
167
@skip_on_cudasim("Tests internals of the CUDA driver device array")
class TestReshape(unittest.TestCase):
    """Check ``Array.reshape`` against ``numpy.ndarray.reshape``."""

    @staticmethod
    def _make(shape):
        # Return an empty NumPy array and a dummy Array with the same layout.
        host = np.empty(shape)
        dummy = Array.from_desc(0, host.shape, host.strides,
                                host.dtype.itemsize)
        return host, dummy

    def _check_reshape(self, shape, *newshape):
        # Reshape both arrays identically and compare shape and strides.
        host, dummy = self._make(shape)
        want = host.reshape(*newshape)
        # The reshaped dummy array is taken from index 0 of reshape's result.
        have = dummy.reshape(*newshape)[0]
        self.assertEqual(have.shape, want.shape)
        self.assertEqual(have.strides, want.strides)

    def test_reshape_2d2d(self):
        self._check_reshape((4, 5), 5, 4)

    def test_reshape_2d1d(self):
        self._check_reshape((4, 5), 5 * 4)

    def test_reshape_3d3d(self):
        self._check_reshape((3, 4, 5), 5, 3, 4)

    def test_reshape_3d2d(self):
        self._check_reshape((3, 4, 5), 3 * 4, 5)

    def test_reshape_3d1d(self):
        self._check_reshape((3, 4, 5), 3 * 4 * 5)

    def test_reshape_infer2d2d(self):
        self._check_reshape((4, 5), -1, 4)

    def test_reshape_infer2d1d(self):
        self._check_reshape((4, 5), -1)

    def test_reshape_infer3d3d(self):
        self._check_reshape((3, 4, 5), 5, -1, 4)

    def test_reshape_infer3d2d(self):
        self._check_reshape((3, 4, 5), 3, -1)

    def test_reshape_infer3d1d(self):
        self._check_reshape((3, 4, 5), -1)

    def test_reshape_infer_two_unknowns(self):
        _, dummy = self._make((3, 4, 5))

        with self.assertRaises(ValueError) as cm:
            dummy.reshape(-1, -1, 3)
        self.assertIn('can only specify one unknown dimension',
                      str(cm.exception))

    def test_reshape_infer_invalid_shape(self):
        _, dummy = self._make((3, 4, 5))

        # 60 elements cannot be split into rows of 7.
        with self.assertRaises(ValueError) as cm:
            dummy.reshape(-1, 7)

        wanted = 'cannot infer valid shape for unknown dimension'
        self.assertIn(wanted, str(cm.exception))
279
+
280
+
281
@skip_on_cudasim("Tests internals of the CUDA driver device array")
class TestSqueeze(unittest.TestCase):
    """Check ``Array.squeeze`` against ``numpy.ndarray.squeeze``."""

    @staticmethod
    def _make():
        # Axes 0, 2 and 4 have length 1 and are therefore squeezable.
        host = np.empty((1, 2, 1, 4, 1, 3))
        dummy = Array.from_desc(
            0, host.shape, host.strides, host.dtype.itemsize
        )
        return host, dummy

    def _assert_same_layout(self, a, b):
        self.assertEqual(a.shape, b.shape)
        self.assertEqual(a.strides, b.strides)

    def test_squeeze(self):
        host, dummy = self._make()
        self._assert_same_layout(dummy, host)
        self._assert_same_layout(dummy.squeeze()[0], host.squeeze())
        # Every non-empty subset of the length-1 axes.
        for axis in (0, 2, 4, (0, 2), (0, 4), (2, 4), (0, 2, 4)):
            self._assert_same_layout(
                dummy.squeeze(axis=axis)[0], host.squeeze(axis=axis)
            )

    def test_squeeze_invalid_axis(self):
        _, dummy = self._make()
        # Squeezing an axis whose length is not 1 must be rejected.
        with self.assertRaises(ValueError):
            dummy.squeeze(axis=1)
        with self.assertRaises(ValueError):
            dummy.squeeze(axis=(2, 3))
308
+
309
+
310
@skip_on_cudasim("Tests internals of the CUDA driver device array")
class TestExtent(unittest.TestCase):
    """Check the byte extents reported by a dummy ``Array``."""

    @staticmethod
    def _make(shape):
        # Return an empty NumPy array and a dummy Array with the same layout.
        host = np.empty(shape)
        dummy = Array.from_desc(0, host.shape, host.strides,
                                host.dtype.itemsize)
        return host, dummy

    def _check_extent_size(self, shape):
        host, dummy = self._make(shape)
        start, end = dummy.extent
        # The extent of this freshly allocated array covers every element.
        self.assertEqual(end - start, host.size * host.dtype.itemsize)

    def _check_single_extent(self, dummy):
        # A contiguous array yields exactly one extent, equal to .extent.
        [only] = list(dummy.iter_contiguous_extent())
        self.assertEqual(only, dummy.extent)

    def test_extent_1d(self):
        self._check_extent_size(4)

    def test_extent_2d(self):
        self._check_extent_size((4, 5))

    def test_extent_iter_1d(self):
        _, dummy = self._make(4)
        self._check_single_extent(dummy)

    def test_extent_iter_2d(self):
        _, dummy = self._make((4, 5))
        self._check_single_extent(dummy)

        # Selecting every other row of 4 leaves 2 separate extents.
        self.assertEqual(len(list(dummy[::2].iter_contiguous_extent())), 2)
341
+
342
+
343
@skip_on_cudasim("Tests internals of the CUDA driver device array")
class TestIterate(unittest.TestCase):
    """Iteration over a dummy ``Array``."""

    def test_for_loop(self):
        # Regression test for #4201: looping over the array must not raise
        # AssertionError.
        length = 5
        host = np.empty(length)
        dummy = Array.from_desc(0, host.shape, host.strides,
                                host.dtype.itemsize)

        for _ in dummy:
            pass
356
+
357
+
358
if __name__ == '__main__':
    # Allow this test module to be run directly as a script.
    unittest.main()
@@ -0,0 +1,36 @@
1
+ from numba.cuda.testing import unittest, skip_on_cudasim
2
+ import operator
3
+ from numba.core import types, typing
4
+ from numba.cuda.cudadrv import nvvm
5
+
6
+
7
@unittest.skipIf(not nvvm.is_available(), "No libNVVM")
@skip_on_cudasim("Skip on simulator due to use of cuda_target")
class TestFunctionResolution(unittest.TestCase):
    """Check that fp16 operators resolve to all-fp16 signatures on the CUDA
    typing context."""

    def _resolve(self, op, argtys):
        # Imported here rather than at module level; this class uses
        # cuda_target and is skipped on the simulator for that reason.
        from numba.cuda.descriptor import cuda_target
        typingctx = cuda_target.typing_context
        typingctx.refresh()
        fnty = typingctx.resolve_value_type(op)
        return typingctx.resolve_function_type(fnty, argtys, {})

    def test_fp16_binary_operators(self):
        fp16 = types.float16
        binary_ops = (operator.add, operator.iadd, operator.sub,
                      operator.isub, operator.mul, operator.imul)
        for op in binary_ops:
            out = self._resolve(op, (fp16, fp16))
            self.assertEqual(out, typing.signature(fp16, fp16, fp16),
                             msg=str(out))

    def test_fp16_unary_operators(self):
        fp16 = types.float16
        unary_ops = (operator.neg, abs)
        for op in unary_ops:
            out = self._resolve(op, (fp16,))
            self.assertEqual(out, typing.signature(fp16, fp16), msg=str(out))
33
+
34
+
35
if __name__ == '__main__':
    # Allow this test module to be run directly as a script.
    unittest.main()
@@ -0,0 +1,49 @@
1
+ from numba.tests.support import run_in_subprocess
2
+ import unittest
3
+
4
+
5
class TestImport(unittest.TestCase):
    def test_no_impl_import(self):
        """
        Tests that importing cuda doesn't trigger the import of modules
        containing lowering implementation that would likely install things in
        the builtins registry and have side effects impacting other targets.
        """
        import ast

        banlist = (
            'numba.cpython.slicing',
            'numba.cpython.tupleobj',
            'numba.cpython.enumimpl',
            'numba.cpython.hashing',
            'numba.cpython.heapq',
            'numba.cpython.iterators',
            'numba.cpython.numbers',
            'numba.cpython.rangeobj',
            'numba.cpython.cmathimpl',
            'numba.cpython.mathimpl',
            'numba.cpython.printimpl',
            'numba.cpython.randomimpl',
            'numba.core.optional',
            'numba.misc.gdb_hook',
            'numba.misc.literal',
            'numba.misc.cffiimpl',
            'numba.np.linalg',
            'numba.np.polynomial',
            'numba.np.arraymath',
            'numba.np.npdatetime',
            'numba.np.npyimpl',
            'numba.typed.typeddict',
            'numba.typed.typedlist',
            'numba.experimental.jitclass.base',
        )

        code = "import sys; from numba import cuda; print(list(sys.modules))"

        out, _ = run_in_subprocess(code)
        # The subprocess helper may hand back raw bytes; literal_eval needs
        # text.
        if isinstance(out, bytes):
            out = out.decode()
        # The child prints the repr of a list of module-name strings, so a
        # literal parse is sufficient; don't execute subprocess output via
        # eval().
        modlist = set(ast.literal_eval(out.strip()))
        unexpected = set(banlist) & modlist
        self.assertFalse(unexpected, "some modules unexpectedly imported")
46
+
47
+
48
if __name__ == '__main__':
    # Allow this test module to be run directly as a script.
    unittest.main()
@@ -0,0 +1,238 @@
1
+ import sys
2
+ import os
3
+ import multiprocessing as mp
4
+ import warnings
5
+
6
+ from numba.core.config import IS_WIN32, IS_OSX
7
+ from numba.core.errors import NumbaWarning
8
+ from numba.cuda.cudadrv import nvvm
9
+ from numba.cuda.testing import (
10
+ unittest,
11
+ skip_on_cudasim,
12
+ SerialMixin,
13
+ skip_unless_conda_cudatoolkit,
14
+ )
15
+ from numba.cuda.cuda_paths import (
16
+ _get_libdevice_path_decision,
17
+ _get_nvvm_path_decision,
18
+ _get_cudalib_dir_path_decision,
19
+ get_system_ctk,
20
+ )
21
+
22
+
23
# Result of nvvm.is_available() in the test-runner process; the expected
# outcome of the default (env-vars-cleared) lookup in the tests below is
# branched on this flag.
has_cuda = nvvm.is_available()
# The lookup tests spawn their child via mp.get_context('spawn'); they are
# skipped when that API is missing.
has_mp_get_context = hasattr(mp, 'get_context')
25
+
26
+
27
class LibraryLookupBase(SerialMixin, unittest.TestCase):
    """Base class for tests that run library-lookup actions in a child process.

    ``setUp`` starts a fresh ``spawn``-context child running
    ``check_lib_lookup`` for every test. Actions (zero-argument callables,
    sent through a queue and therefore picklable) are executed in the child,
    so they can modify environment variables and ``sys.prefix`` without
    touching the test runner's own process.
    """

    def setUp(self):
        ctx = mp.get_context('spawn')

        qrecv = ctx.Queue()
        qsend = ctx.Queue()
        self.qsend = qsend
        self.qrecv = qrecv
        self.child_process = ctx.Process(
            target=check_lib_lookup,
            args=(qrecv, qsend),
            daemon=True,
        )
        self.child_process.start()

    def tearDown(self):
        self.qsend.put(self.do_terminate)
        self.child_process.join(3)
        # Ensure the process is terminated. ``exitcode`` stays None while
        # the child is still running; the Process object itself is never
        # None, so checking it would always pass.
        exitcode = self.child_process.exitcode
        if exitcode is None:
            # Do not leave a stuck child behind before reporting the failure.
            self.child_process.terminate()
            self.child_process.join()
        self.assertIsNotNone(exitcode, "child process did not terminate")

    def remote_do(self, action):
        """Run *action* in the child and return its reply.

        Fails the test if the child replied with an exception instead.
        """
        self.qsend.put(action)
        out = self.qrecv.get()
        self.assertNotIsInstance(out, BaseException)
        return out

    @staticmethod
    def do_terminate():
        # A false status tells the child loop to stop.
        return False, None
57
+
58
+
59
def remove_env(name):
    """Unset environment variable *name* if it is set.

    Returns True when the variable was present (and is now removed),
    False when there was nothing to remove.
    """
    # Environment values are always strings, so None is an unambiguous
    # "was not present" marker.
    return os.environ.pop(name, None) is not None
66
+
67
+
68
def check_lib_lookup(qout, qin):
    """Child-process loop that executes actions sent by the parent test.

    Each item read from *qin* is called with no arguments and must return a
    ``(status, result)`` pair. While *status* is true, ``result + (w,)`` is
    put on *qout*, where ``w`` is the list of warnings recorded while the
    action ran. A false *status* ends the loop without sending a reply. Any
    exception raised while reading or running an action is put on *qout*
    and also ends the loop.
    """
    status = True
    while status:
        try:
            action = qin.get()
        except Exception as e:
            qout.put(e)
            status = False
        else:
            try:
                with warnings.catch_warnings(record=True) as w:
                    warnings.simplefilter("always", NumbaWarning)
                    status, result = action()
                # Reply only to real actions: a terminate request returns
                # (False, None), and ``None + (w,)`` would raise TypeError.
                if status:
                    qout.put(result + (w,))
            except Exception as e:
                qout.put(e)
                status = False
85
+
86
+
87
@skip_on_cudasim('Library detection unsupported in the simulator')
@unittest.skipUnless(has_mp_get_context, 'mp.get_context not available')
@skip_unless_conda_cudatoolkit('test assumes conda installed cudatoolkit')
class TestLibDeviceLookUp(LibraryLookupBase):
    """Check where ``_get_libdevice_path_decision`` finds libdevice under
    different environment configurations."""

    def test_libdevice_path_decision(self):
        # Default configuration: CUDA_HOME / CUDA_PATH unset in the child.
        by, info, warns = self.remote_do(self.do_clear_envs)
        self.assertEqual(by, 'Conda environment' if has_cuda else '<unknown>')
        if not has_cuda:
            self.assertIsNone(info)
        self.assertFalse(warns)

        # With the conda prefix hidden, CUDA_HOME decides the location.
        by, info, warns = self.remote_do(self.do_set_cuda_home)
        self.assertEqual(by, 'CUDA_HOME')
        expected_dir = os.path.join('mycudahome', 'nvvm', 'libdevice')
        self.assertEqual(info, expected_dir)
        self.assertFalse(warns)

        # Clearing the variables again (the conda prefix stays hidden in the
        # child) leaves only a system toolkit, if any, to be found.
        no_system_ctk = get_system_ctk() is None
        by, info, warns = self.remote_do(self.do_clear_envs)
        if no_system_ctk:
            self.assertEqual(by, '<unknown>')
            self.assertIsNone(info)
        else:
            self.assertEqual(by, 'System')
        self.assertFalse(warns)

    @staticmethod
    def do_clear_envs():
        """Unset CUDA_HOME and CUDA_PATH, then report the libdevice decision."""
        for var in ('CUDA_HOME', 'CUDA_PATH'):
            remove_env(var)
        return True, _get_libdevice_path_decision()

    @staticmethod
    def do_set_cuda_home():
        """Set a dummy CUDA_HOME, hide conda, then report the decision."""
        os.environ['CUDA_HOME'] = 'mycudahome'
        _fake_non_conda_env()
        return True, _get_libdevice_path_decision()
129
+
130
+
131
@skip_on_cudasim('Library detection unsupported in the simulator')
@unittest.skipUnless(has_mp_get_context, 'mp.get_context not available')
@skip_unless_conda_cudatoolkit('test assumes conda installed cudatoolkit')
class TestNvvmLookUp(LibraryLookupBase):
    """Check where ``_get_nvvm_path_decision`` finds NVVM under different
    environment configurations."""

    def test_nvvm_path_decision(self):
        # Default configuration: CUDA_HOME / CUDA_PATH unset in the child.
        by, info, warns = self.remote_do(self.do_clear_envs)
        self.assertEqual(by, 'Conda environment' if has_cuda else '<unknown>')
        if not has_cuda:
            self.assertIsNone(info)
        self.assertFalse(warns)

        # With the conda prefix hidden, CUDA_HOME decides the location.
        by, info, warns = self.remote_do(self.do_set_cuda_home)
        self.assertEqual(by, 'CUDA_HOME')
        self.assertFalse(warns)
        if IS_WIN32:
            libdir = 'bin'
        elif IS_OSX:
            libdir = 'lib'
        else:
            libdir = 'lib64'
        self.assertEqual(info, os.path.join('mycudahome', 'nvvm', libdir))

        # Clearing the variables again (the conda prefix stays hidden in the
        # child) leaves only a system toolkit, if any, to be found.
        no_system_ctk = get_system_ctk() is None
        by, info, warns = self.remote_do(self.do_clear_envs)
        if no_system_ctk:
            self.assertEqual(by, '<unknown>')
            self.assertIsNone(info)
        else:
            self.assertEqual(by, 'System')
        self.assertFalse(warns)

    @staticmethod
    def do_clear_envs():
        """Unset CUDA_HOME and CUDA_PATH, then report the NVVM decision."""
        for var in ('CUDA_HOME', 'CUDA_PATH'):
            remove_env(var)
        return True, _get_nvvm_path_decision()

    @staticmethod
    def do_set_cuda_home():
        """Set a dummy CUDA_HOME, hide conda, then report the decision."""
        os.environ['CUDA_HOME'] = 'mycudahome'
        _fake_non_conda_env()
        return True, _get_nvvm_path_decision()
178
+
179
+
180
@skip_on_cudasim('Library detection unsupported in the simulator')
@unittest.skipUnless(has_mp_get_context, 'mp.get_context not available')
@skip_unless_conda_cudatoolkit('test assumes conda installed cudatoolkit')
class TestCudaLibLookUp(LibraryLookupBase):
    """Check where ``_get_cudalib_dir_path_decision`` finds the CUDA library
    directory under different environment configurations."""

    def test_cudalib_path_decision(self):
        # Default configuration: CUDA_HOME / CUDA_PATH unset in the child.
        by, info, warns = self.remote_do(self.do_clear_envs)
        self.assertEqual(by, 'Conda environment' if has_cuda else '<unknown>')
        if not has_cuda:
            self.assertIsNone(info)
        self.assertFalse(warns)

        # Clear the variables once more before setting CUDA_HOME; with the
        # conda prefix hidden, CUDA_HOME decides the location.
        self.remote_do(self.do_clear_envs)
        by, info, warns = self.remote_do(self.do_set_cuda_home)
        self.assertEqual(by, 'CUDA_HOME')
        self.assertFalse(warns)
        if IS_WIN32:
            libdir = 'bin'
        elif IS_OSX:
            libdir = 'lib'
        else:
            libdir = 'lib64'
        self.assertEqual(info, os.path.join('mycudahome', libdir))

        # Clearing the variables again (the conda prefix stays hidden in the
        # child) leaves only a system toolkit, if any, to be found.
        no_system_ctk = get_system_ctk() is None
        by, info, warns = self.remote_do(self.do_clear_envs)
        if no_system_ctk:
            self.assertEqual(by, "<unknown>")
            self.assertIsNone(info)
        else:
            self.assertEqual(by, 'System')
        self.assertFalse(warns)

    @staticmethod
    def do_clear_envs():
        """Unset CUDA_HOME and CUDA_PATH, then report the cudalib decision."""
        for var in ('CUDA_HOME', 'CUDA_PATH'):
            remove_env(var)
        return True, _get_cudalib_dir_path_decision()

    @staticmethod
    def do_set_cuda_home():
        """Set a dummy CUDA_HOME, hide conda, then report the decision."""
        os.environ['CUDA_HOME'] = 'mycudahome'
        _fake_non_conda_env()
        return True, _get_cudalib_dir_path_decision()
228
+
229
+
230
def _fake_non_conda_env():
    """
    Pretend this interpreter is not running inside a conda environment.

    Blanks ``sys.prefix`` for the rest of the current process's lifetime;
    intended to be called only inside the disposable lookup child.
    """
    # An empty prefix gives lookup code no conda environment to point at.
    sys.prefix = ''
235
+
236
+
237
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()