numba-cuda 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +463 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.0.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -15,32 +15,48 @@ from numba.core.typing.typeof import Purpose, typeof
15
15
  from numba.core.types.functions import Function
16
16
  from numba.cuda.api import get_current_device
17
17
  from numba.cuda.args import wrap_arg
18
- from numba.cuda.compiler import (compile_cuda, CUDACompiler, kernel_fixup,
19
- ExternFunction)
18
+ from numba.cuda.compiler import (
19
+ compile_cuda,
20
+ CUDACompiler,
21
+ kernel_fixup,
22
+ ExternFunction,
23
+ )
20
24
  from numba.cuda.cudadrv import driver
21
25
  from numba.cuda.cudadrv.devices import get_context
22
26
  from numba.cuda.descriptor import cuda_target
23
- from numba.cuda.errors import (missing_launch_config_msg,
24
- normalize_kernel_dimensions)
27
+ from numba.cuda.errors import (
28
+ missing_launch_config_msg,
29
+ normalize_kernel_dimensions,
30
+ )
25
31
  from numba.cuda import types as cuda_types
26
32
  from numba.cuda.runtime.nrt import rtsys
33
+ from numba.cuda.locks import module_init_lock
27
34
 
28
35
  from numba import cuda
29
36
  from numba import _dispatcher
30
37
 
31
38
  from warnings import warn
32
39
 
33
- cuda_fp16_math_funcs = ['hsin', 'hcos',
34
- 'hlog', 'hlog10',
35
- 'hlog2',
36
- 'hexp', 'hexp10',
37
- 'hexp2',
38
- 'hsqrt', 'hrsqrt',
39
- 'hfloor', 'hceil',
40
- 'hrcp', 'hrint',
41
- 'htrunc', 'hdiv']
42
-
43
- reshape_funcs = ['nocopy_empty_reshape', 'numba_attempt_nocopy_reshape']
40
+ cuda_fp16_math_funcs = [
41
+ "hsin",
42
+ "hcos",
43
+ "hlog",
44
+ "hlog10",
45
+ "hlog2",
46
+ "hexp",
47
+ "hexp10",
48
+ "hexp2",
49
+ "hsqrt",
50
+ "hrsqrt",
51
+ "hfloor",
52
+ "hceil",
53
+ "hrcp",
54
+ "hrint",
55
+ "htrunc",
56
+ "hdiv",
57
+ ]
58
+
59
+ reshape_funcs = ["nocopy_empty_reshape", "numba_attempt_nocopy_reshape"]
44
60
 
45
61
 
46
62
  def get_cres_link_objects(cres):
@@ -51,17 +67,16 @@ def get_cres_link_objects(cres):
51
67
 
52
68
  # List of calls into declared device functions
53
69
  device_func_calls = [
54
- (name, v) for name, v in cres.fndesc.typemap.items() if (
55
- isinstance(v, cuda_types.CUDADispatcher)
56
- )
70
+ (name, v)
71
+ for name, v in cres.fndesc.typemap.items()
72
+ if (isinstance(v, cuda_types.CUDADispatcher))
57
73
  ]
58
74
 
59
75
  # List of tuples with SSA name of calls and corresponding signature
60
76
  call_signatures = [
61
77
  (call.func.name, sig)
62
- for call, sig in cres.fndesc.calltypes.items() if (
63
- isinstance(call, ir.Expr) and call.op == 'call'
64
- )
78
+ for call, sig in cres.fndesc.calltypes.items()
79
+ if (isinstance(call, ir.Expr) and call.op == "call")
65
80
  ]
66
81
 
67
82
  # Map SSA names to all invoked signatures
@@ -93,10 +108,10 @@ def get_cres_link_objects(cres):
93
108
 
94
109
 
95
110
  class _Kernel(serialize.ReduceMixin):
96
- '''
111
+ """
97
112
  CUDA Kernel specialized for a given set of argument types. When called, this
98
113
  object launches the kernel on the device.
99
- '''
114
+ """
100
115
 
101
116
  NRT_functions = [
102
117
  "NRT_Allocate",
@@ -110,16 +125,27 @@ class _Kernel(serialize.ReduceMixin):
110
125
  "NRT_MemInfo_alloc_aligned",
111
126
  "NRT_Allocate_External",
112
127
  "NRT_decref",
113
- "NRT_incref"
128
+ "NRT_incref",
114
129
  ]
115
130
 
116
131
  @global_compiler_lock
117
- def __init__(self, py_func, argtypes, link=None, debug=False,
118
- lineinfo=False, inline=False, fastmath=False, extensions=None,
119
- max_registers=None, lto=False, opt=True, device=False):
120
-
132
+ def __init__(
133
+ self,
134
+ py_func,
135
+ argtypes,
136
+ link=None,
137
+ debug=False,
138
+ lineinfo=False,
139
+ inline=False,
140
+ fastmath=False,
141
+ extensions=None,
142
+ max_registers=None,
143
+ lto=False,
144
+ opt=True,
145
+ device=False,
146
+ ):
121
147
  if device:
122
- raise RuntimeError('Cannot compile a device function as a kernel')
148
+ raise RuntimeError("Cannot compile a device function as a kernel")
123
149
 
124
150
  super().__init__()
125
151
 
@@ -144,24 +170,25 @@ class _Kernel(serialize.ReduceMixin):
144
170
  self.lineinfo = lineinfo
145
171
  self.extensions = extensions or []
146
172
 
147
- nvvm_options = {
148
- 'fastmath': fastmath,
149
- 'opt': 3 if opt else 0
150
- }
173
+ nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0}
151
174
 
152
175
  if debug:
153
- nvvm_options['g'] = None
176
+ nvvm_options["g"] = None
154
177
 
155
178
  cc = get_current_device().compute_capability
156
- cres = compile_cuda(self.py_func, types.void, self.argtypes,
157
- debug=self.debug,
158
- lineinfo=lineinfo,
159
- inline=inline,
160
- fastmath=fastmath,
161
- nvvm_options=nvvm_options,
162
- cc=cc,
163
- max_registers=max_registers,
164
- lto=lto)
179
+ cres = compile_cuda(
180
+ self.py_func,
181
+ types.void,
182
+ self.argtypes,
183
+ debug=self.debug,
184
+ lineinfo=lineinfo,
185
+ inline=inline,
186
+ fastmath=fastmath,
187
+ nvvm_options=nvvm_options,
188
+ cc=cc,
189
+ max_registers=max_registers,
190
+ lto=lto,
191
+ )
165
192
  tgt_ctx = cres.target_context
166
193
  lib = cres.library
167
194
  kernel = lib.get_function(cres.fndesc.llvm_func_name)
@@ -174,24 +201,25 @@ class _Kernel(serialize.ReduceMixin):
174
201
  asm = lib.get_asm_str()
175
202
 
176
203
  # A kernel needs cooperative launch if grid_sync is being used.
177
- self.cooperative = 'cudaCGGetIntrinsicHandle' in asm
204
+ self.cooperative = "cudaCGGetIntrinsicHandle" in asm
178
205
  # We need to link against cudadevrt if grid sync is being used.
179
206
  if self.cooperative:
180
207
  lib.needs_cudadevrt = True
181
208
 
182
- def link_to_library_functions(library_functions, library_path,
183
- prefix=None):
209
+ def link_to_library_functions(
210
+ library_functions, library_path, prefix=None
211
+ ):
184
212
  """
185
213
  Dynamically links to library functions by searching for their names
186
214
  in the specified library and linking to the corresponding source
187
215
  file.
188
216
  """
189
217
  if prefix is not None:
190
- library_functions = [f"{prefix}{fn}" for fn in
191
- library_functions]
218
+ library_functions = [
219
+ f"{prefix}{fn}" for fn in library_functions
220
+ ]
192
221
 
193
- found_functions = [fn for fn in library_functions
194
- if f'{fn}' in asm]
222
+ found_functions = [fn for fn in library_functions if f"{fn}" in asm]
195
223
 
196
224
  if found_functions:
197
225
  basedir = os.path.dirname(os.path.abspath(__file__))
@@ -201,11 +229,11 @@ class _Kernel(serialize.ReduceMixin):
201
229
  return found_functions
202
230
 
203
231
  # Link to the helper library functions if needed
204
- link_to_library_functions(reshape_funcs, 'reshape_funcs.cu')
232
+ link_to_library_functions(reshape_funcs, "reshape_funcs.cu")
205
233
  # Link to the CUDA FP16 math library functions if needed
206
- link_to_library_functions(cuda_fp16_math_funcs,
207
- 'cpp_function_wrappers.cu',
208
- '__numba_wrapper_')
234
+ link_to_library_functions(
235
+ cuda_fp16_math_funcs, "cpp_function_wrappers.cu", "__numba_wrapper_"
236
+ )
209
237
 
210
238
  self.maybe_link_nrt(link, tgt_ctx, asm)
211
239
 
@@ -239,15 +267,16 @@ class _Kernel(serialize.ReduceMixin):
239
267
 
240
268
  all_nrt = "|".join(self.NRT_functions)
241
269
  pattern = (
242
- r'\.extern\s+\.func\s+(?:\s*\(.+\)\s*)?('
243
- + all_nrt + r')\s*\([^)]*\)\s*;'
270
+ r"\.extern\s+\.func\s+(?:\s*\(.+\)\s*)?("
271
+ + all_nrt
272
+ + r")\s*\([^)]*\)\s*;"
244
273
  )
245
274
 
246
275
  nrt_in_asm = re.findall(pattern, asm)
247
276
 
248
277
  basedir = os.path.dirname(os.path.abspath(__file__))
249
278
  if nrt_in_asm:
250
- nrt_path = os.path.join(basedir, 'runtime', 'nrt.cu')
279
+ nrt_path = os.path.join(basedir, "runtime", "nrt.cu")
251
280
  link.append(nrt_path)
252
281
 
253
282
  @property
@@ -270,8 +299,17 @@ class _Kernel(serialize.ReduceMixin):
270
299
  return tuple(self.signature.args)
271
300
 
272
301
  @classmethod
273
- def _rebuild(cls, cooperative, name, signature, codelibrary,
274
- debug, lineinfo, call_helper, extensions):
302
+ def _rebuild(
303
+ cls,
304
+ cooperative,
305
+ name,
306
+ signature,
307
+ codelibrary,
308
+ debug,
309
+ lineinfo,
310
+ call_helper,
311
+ extensions,
312
+ ):
275
313
  """
276
314
  Rebuild an instance.
277
315
  """
@@ -299,10 +337,21 @@ class _Kernel(serialize.ReduceMixin):
299
337
  Thread, block and shared memory configuration are serialized.
300
338
  Stream information is discarded.
301
339
  """
302
- return dict(cooperative=self.cooperative, name=self.entry_name,
303
- signature=self.signature, codelibrary=self._codelibrary,
304
- debug=self.debug, lineinfo=self.lineinfo,
305
- call_helper=self.call_helper, extensions=self.extensions)
340
+ return dict(
341
+ cooperative=self.cooperative,
342
+ name=self.entry_name,
343
+ signature=self.signature,
344
+ codelibrary=self._codelibrary,
345
+ debug=self.debug,
346
+ lineinfo=self.lineinfo,
347
+ call_helper=self.call_helper,
348
+ extensions=self.extensions,
349
+ )
350
+
351
+ @module_init_lock
352
+ def initialize_once(self, mod):
353
+ if not mod.initialized:
354
+ mod.setup()
306
355
 
307
356
  def bind(self):
308
357
  """
@@ -310,6 +359,8 @@ class _Kernel(serialize.ReduceMixin):
310
359
  """
311
360
  cufunc = self._codelibrary.get_cufunc()
312
361
 
362
+ self.initialize_once(cufunc.module)
363
+
313
364
  if (
314
365
  hasattr(self, "target_context")
315
366
  and self.target_context.enable_nrt
@@ -323,73 +374,73 @@ class _Kernel(serialize.ReduceMixin):
323
374
 
324
375
  @property
325
376
  def regs_per_thread(self):
326
- '''
377
+ """
327
378
  The number of registers used by each thread for this kernel.
328
- '''
379
+ """
329
380
  return self._codelibrary.get_cufunc().attrs.regs
330
381
 
331
382
  @property
332
383
  def const_mem_size(self):
333
- '''
384
+ """
334
385
  The amount of constant memory used by this kernel.
335
- '''
386
+ """
336
387
  return self._codelibrary.get_cufunc().attrs.const
337
388
 
338
389
  @property
339
390
  def shared_mem_per_block(self):
340
- '''
391
+ """
341
392
  The amount of shared memory used per block for this kernel.
342
- '''
393
+ """
343
394
  return self._codelibrary.get_cufunc().attrs.shared
344
395
 
345
396
  @property
346
397
  def max_threads_per_block(self):
347
- '''
398
+ """
348
399
  The maximum allowable threads per block.
349
- '''
400
+ """
350
401
  return self._codelibrary.get_cufunc().attrs.maxthreads
351
402
 
352
403
  @property
353
404
  def local_mem_per_thread(self):
354
- '''
405
+ """
355
406
  The amount of local memory used per thread for this kernel.
356
- '''
407
+ """
357
408
  return self._codelibrary.get_cufunc().attrs.local
358
409
 
359
410
  def inspect_llvm(self):
360
- '''
411
+ """
361
412
  Returns the LLVM IR for this kernel.
362
- '''
413
+ """
363
414
  return self._codelibrary.get_llvm_str()
364
415
 
365
416
  def inspect_asm(self, cc):
366
- '''
417
+ """
367
418
  Returns the PTX code for this kernel.
368
- '''
419
+ """
369
420
  return self._codelibrary.get_asm_str(cc=cc)
370
421
 
371
422
  def inspect_sass_cfg(self):
372
- '''
423
+ """
373
424
  Returns the CFG of the SASS for this kernel.
374
425
 
375
426
  Requires nvdisasm to be available on the PATH.
376
- '''
427
+ """
377
428
  return self._codelibrary.get_sass_cfg()
378
429
 
379
430
  def inspect_sass(self):
380
- '''
431
+ """
381
432
  Returns the SASS code for this kernel.
382
433
 
383
434
  Requires nvdisasm to be available on the PATH.
384
- '''
435
+ """
385
436
  return self._codelibrary.get_sass()
386
437
 
387
438
  def inspect_types(self, file=None):
388
- '''
439
+ """
389
440
  Produce a dump of the Python source of this function annotated with the
390
441
  corresponding Numba IR and type information. The dump is written to
391
442
  *file*, or *sys.stdout* if *file* is *None*.
392
- '''
443
+ """
393
444
  if self._type_annotation is None:
394
445
  raise ValueError("Type annotation is not available")
395
446
 
@@ -397,12 +448,12 @@ class _Kernel(serialize.ReduceMixin):
397
448
  file = sys.stdout
398
449
 
399
450
  print("%s %s" % (self.entry_name, self.argument_types), file=file)
400
- print('-' * 80, file=file)
451
+ print("-" * 80, file=file)
401
452
  print(self._type_annotation, file=file)
402
- print('=' * 80, file=file)
453
+ print("=" * 80, file=file)
403
454
 
404
455
  def max_cooperative_grid_blocks(self, blockdim, dynsmemsize=0):
405
- '''
456
+ """
406
457
  Calculates the maximum number of blocks that can be launched for this
407
458
  kernel in a cooperative grid in the current context, for the given block
408
459
  and dynamic shared memory sizes.
@@ -411,15 +462,15 @@ class _Kernel(serialize.ReduceMixin):
411
462
  a tuple for 2D or 3D blocks.
412
463
  :param dynsmemsize: Dynamic shared memory size in bytes.
413
464
  :return: The maximum number of blocks in the grid.
414
- '''
465
+ """
415
466
  ctx = get_context()
416
467
  cufunc = self._codelibrary.get_cufunc()
417
468
 
418
469
  if isinstance(blockdim, tuple):
419
470
  blockdim = functools.reduce(lambda x, y: x * y, blockdim)
420
- active_per_sm = ctx.get_active_blocks_per_multiprocessor(cufunc,
421
- blockdim,
422
- dynsmemsize)
471
+ active_per_sm = ctx.get_active_blocks_per_multiprocessor(
472
+ cufunc, blockdim, dynsmemsize
473
+ )
423
474
  sm_count = ctx.device.MULTIPROCESSOR_COUNT
424
475
  return active_per_sm * sm_count
425
476
 
@@ -435,7 +486,7 @@ class _Kernel(serialize.ReduceMixin):
435
486
  excmem.memset(0, stream=stream)
436
487
 
437
488
  # Prepare arguments
438
- retr = [] # hold functors for writeback
489
+ retr = [] # hold functors for writeback
439
490
 
440
491
  kernelargs = []
441
492
  for t, v in zip(self.argument_types, args):
@@ -449,46 +500,51 @@ class _Kernel(serialize.ReduceMixin):
449
500
  stream_handle = stream and stream.handle or zero_stream
450
501
 
451
502
  # Invoke kernel
452
- driver.launch_kernel(cufunc.handle,
453
- *griddim,
454
- *blockdim,
455
- sharedmem,
456
- stream_handle,
457
- kernelargs,
458
- cooperative=self.cooperative)
503
+ driver.launch_kernel(
504
+ cufunc.handle,
505
+ *griddim,
506
+ *blockdim,
507
+ sharedmem,
508
+ stream_handle,
509
+ kernelargs,
510
+ cooperative=self.cooperative,
511
+ )
459
512
 
460
513
  if self.debug:
461
514
  driver.device_to_host(ctypes.addressof(excval), excmem, excsz)
462
515
  if excval.value != 0:
463
516
  # An error occurred
464
517
  def load_symbol(name):
465
- mem, sz = cufunc.module.get_global_symbol("%s__%s__" %
466
- (cufunc.name,
467
- name))
518
+ mem, sz = cufunc.module.get_global_symbol(
519
+ "%s__%s__" % (cufunc.name, name)
520
+ )
468
521
  val = ctypes.c_int()
469
522
  driver.device_to_host(ctypes.addressof(val), mem, sz)
470
523
  return val.value
471
524
 
472
- tid = [load_symbol("tid" + i) for i in 'zyx']
473
- ctaid = [load_symbol("ctaid" + i) for i in 'zyx']
525
+ tid = [load_symbol("tid" + i) for i in "zyx"]
526
+ ctaid = [load_symbol("ctaid" + i) for i in "zyx"]
474
527
  code = excval.value
475
528
  exccls, exc_args, loc = self.call_helper.get_exception(code)
476
529
  # Prefix the exception message with the source location
477
530
  if loc is None:
478
- locinfo = ''
531
+ locinfo = ""
479
532
  else:
480
533
  sym, filepath, lineno = loc
481
534
  filepath = os.path.abspath(filepath)
482
- locinfo = 'In function %r, file %s, line %s, ' % (sym,
483
- filepath,
484
- lineno,)
535
+ locinfo = "In function %r, file %s, line %s, " % (
536
+ sym,
537
+ filepath,
538
+ lineno,
539
+ )
485
540
  # Prefix the exception message with the thread position
486
541
  prefix = "%stid=%s ctaid=%s" % (locinfo, tid, ctaid)
487
542
  if exc_args:
488
- exc_args = ("%s: %s" % (prefix, exc_args[0]),) + \
489
- exc_args[1:]
543
+ exc_args = ("%s: %s" % (prefix, exc_args[0]),) + exc_args[
544
+ 1:
545
+ ]
490
546
  else:
491
- exc_args = prefix,
547
+ exc_args = (prefix,)
492
548
  raise exccls(*exc_args)
493
549
 
494
550
  # retrieve auto converted arrays
@@ -502,11 +558,7 @@ class _Kernel(serialize.ReduceMixin):
502
558
 
503
559
  # map the arguments using any extension you've registered
504
560
  for extension in reversed(self.extensions):
505
- ty, val = extension.prepare_args(
506
- ty,
507
- val,
508
- stream=stream,
509
- retr=retr)
561
+ ty, val = extension.prepare_args(ty, val, stream=stream, retr=retr)
510
562
 
511
563
  if isinstance(ty, types.Array):
512
564
  devary = wrap_arg(val).to_device(retr, stream)
@@ -592,8 +644,9 @@ class _Kernel(serialize.ReduceMixin):
592
644
  class ForAll(object):
593
645
  def __init__(self, dispatcher, ntasks, tpb, stream, sharedmem):
594
646
  if ntasks < 0:
595
- raise ValueError("Can't create ForAll with negative task count: %s"
596
- % ntasks)
647
+ raise ValueError(
648
+ "Can't create ForAll with negative task count: %s" % ntasks
649
+ )
597
650
  self.dispatcher = dispatcher
598
651
  self.ntasks = ntasks
599
652
  self.thread_per_block = tpb
@@ -611,8 +664,9 @@ class ForAll(object):
611
664
  blockdim = self._compute_thread_per_block(specialized)
612
665
  griddim = (self.ntasks + blockdim - 1) // blockdim
613
666
 
614
- return specialized[griddim, blockdim, self.stream,
615
- self.sharedmem](*args)
667
+ return specialized[griddim, blockdim, self.stream, self.sharedmem](
668
+ *args
669
+ )
616
670
 
617
671
  def _compute_thread_per_block(self, dispatcher):
618
672
  tpb = self.thread_per_block
@@ -627,7 +681,7 @@ class ForAll(object):
627
681
  kernel = next(iter(dispatcher.overloads.values()))
628
682
  kwargs = dict(
629
683
  func=kernel._codelibrary.get_cufunc(),
630
- b2d_func=0, # dynamic-shared memory is constant to blksz
684
+ b2d_func=0, # dynamic-shared memory is constant to blksz
631
685
  memsize=self.sharedmem,
632
686
  blocksizelimit=1024,
633
687
  )
@@ -658,13 +712,16 @@ class _LaunchConfiguration:
658
712
  min_grid_size = 128
659
713
  grid_size = griddim[0] * griddim[1] * griddim[2]
660
714
  if grid_size < min_grid_size:
661
- msg = (f"Grid size {grid_size} will likely result in GPU "
662
- "under-utilization due to low occupancy.")
715
+ msg = (
716
+ f"Grid size {grid_size} will likely result in GPU "
717
+ "under-utilization due to low occupancy."
718
+ )
663
719
  warn(NumbaPerformanceWarning(msg))
664
720
 
665
721
  def __call__(self, *args):
666
- return self.dispatcher.call(args, self.griddim, self.blockdim,
667
- self.stream, self.sharedmem)
722
+ return self.dispatcher.call(
723
+ args, self.griddim, self.blockdim, self.stream, self.sharedmem
724
+ )
668
725
 
669
726
 
670
727
  class CUDACacheImpl(CacheImpl):
@@ -689,6 +746,7 @@ class CUDACache(Cache):
689
746
  """
690
747
  Implements a cache that saves and loads CUDA kernels and compile results.
691
748
  """
749
+
692
750
  _impl_class = CUDACacheImpl
693
751
 
694
752
  def load_overload(self, sig, target_context):
@@ -696,12 +754,13 @@ class CUDACache(Cache):
696
754
  # initialized. To initialize the correct (i.e. CUDA) target, we need to
697
755
  # enforce that the current target is the CUDA target.
698
756
  from numba.core.target_extension import target_override
699
- with target_override('cuda'):
757
+
758
+ with target_override("cuda"):
700
759
  return super().load_overload(sig, target_context)
701
760
 
702
761
 
703
762
  class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
704
- '''
763
+ """
705
764
  CUDA Dispatcher object. When configured and called, the dispatcher will
706
765
  specialize itself for the given arguments (if no suitable specialized
707
766
  version already exists) & compute capability, and launch on the device
@@ -709,7 +768,7 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
709
768
 
710
769
  Dispatcher objects are not to be constructed by the user, but instead are
711
770
  created using the :func:`numba.cuda.jit` decorator.
712
- '''
771
+ """
713
772
 
714
773
  # Whether to fold named arguments and default values. Default values are
715
774
  # presently unsupported on CUDA, so we can leave this as False in all
@@ -719,8 +778,9 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
719
778
  targetdescr = cuda_target
720
779
 
721
780
  def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler):
722
- super().__init__(py_func, targetoptions=targetoptions,
723
- pipeline_class=pipeline_class)
781
+ super().__init__(
782
+ py_func, targetoptions=targetoptions, pipeline_class=pipeline_class
783
+ )
724
784
 
725
785
  # The following properties are for specialization of CUDADispatchers. A
726
786
  # specialized CUDADispatcher is one that is compiled for exactly one
@@ -748,7 +808,7 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
748
808
 
749
809
  def __getitem__(self, args):
750
810
  if len(args) not in [2, 3, 4]:
751
- raise ValueError('must specify at least the griddim and blockdim')
811
+ raise ValueError("must specify at least the griddim and blockdim")
752
812
  return self.configure(*args)
753
813
 
754
814
  def forall(self, ntasks, tpb=0, stream=0, sharedmem=0):
@@ -775,7 +835,7 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
775
835
 
776
836
  @property
777
837
  def extensions(self):
778
- '''
838
+ """
779
839
  A list of objects that must have a `prepare_args` function. When a
780
840
  specialized kernel is called, each argument will be passed through
781
841
  to the `prepare_args` (from the last object in this list to the
@@ -791,17 +851,17 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
791
851
  will be passed in turn to the next right-most `extension`. After all
792
852
  the extensions have been called, the resulting `(ty, val)` will be
793
853
  passed into Numba's default argument marshalling logic.
794
- '''
795
- return self.targetoptions.get('extensions')
854
+ """
855
+ return self.targetoptions.get("extensions")
796
856
 
797
857
  def __call__(self, *args, **kwargs):
798
858
  # An attempt to launch an unconfigured kernel
799
859
  raise ValueError(missing_launch_config_msg)
800
860
 
801
861
  def call(self, args, griddim, blockdim, stream, sharedmem):
802
- '''
862
+ """
803
863
  Compile if necessary and invoke this kernel with *args*.
804
- '''
864
+ """
805
865
  if self.specialized:
806
866
  kernel = next(iter(self.overloads.values()))
807
867
  else:
@@ -824,28 +884,30 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
824
884
  if cuda.is_cuda_array(val):
825
885
  # When typing, we don't need to synchronize on the array's
826
886
  # stream - this is done when the kernel is launched.
827
- return typeof(cuda.as_cuda_array(val, sync=False),
828
- Purpose.argument)
887
+ return typeof(
888
+ cuda.as_cuda_array(val, sync=False), Purpose.argument
889
+ )
829
890
  else:
830
891
  raise
831
892
 
832
893
  def specialize(self, *args):
833
- '''
894
+ """
834
895
  Create a new instance of this dispatcher specialized for the given
835
896
  *args*.
836
- '''
897
+ """
837
898
  cc = get_current_device().compute_capability
838
899
  argtypes = tuple(self.typeof_pyval(a) for a in args)
839
900
  if self.specialized:
840
- raise RuntimeError('Dispatcher already specialized')
901
+ raise RuntimeError("Dispatcher already specialized")
841
902
 
842
903
  specialization = self.specializations.get((cc, argtypes))
843
904
  if specialization:
844
905
  return specialization
845
906
 
846
907
  targetoptions = self.targetoptions
847
- specialization = CUDADispatcher(self.py_func,
848
- targetoptions=targetoptions)
908
+ specialization = CUDADispatcher(
909
+ self.py_func, targetoptions=targetoptions
910
+ )
849
911
  specialization.compile(argtypes)
850
912
  specialization.disable_compile()
851
913
  specialization._specialized = True
@@ -860,7 +922,7 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
860
922
  return self._specialized
861
923
 
862
924
  def get_regs_per_thread(self, signature=None):
863
- '''
925
+ """
864
926
  Returns the number of registers used by each thread in this kernel for
865
927
  the device in the current context.
866
928
 
@@ -869,17 +931,19 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
869
931
  kernel.
870
932
  :return: The number of registers used by the compiled variant of the
871
933
  kernel for the given signature and current device.
872
- '''
934
+ """
873
935
  if signature is not None:
874
936
  return self.overloads[signature.args].regs_per_thread
875
937
  if self.specialized:
876
938
  return next(iter(self.overloads.values())).regs_per_thread
877
939
  else:
878
- return {sig: overload.regs_per_thread
879
- for sig, overload in self.overloads.items()}
940
+ return {
941
+ sig: overload.regs_per_thread
942
+ for sig, overload in self.overloads.items()
943
+ }
880
944
 
881
945
  def get_const_mem_size(self, signature=None):
882
- '''
946
+ """
883
947
  Returns the size in bytes of constant memory used by this kernel for
884
948
  the device in the current context.
885
949
 
@@ -889,17 +953,19 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
889
953
  :return: The size in bytes of constant memory allocated by the
890
954
  compiled variant of the kernel for the given signature and
891
955
  current device.
892
- '''
956
+ """
893
957
  if signature is not None:
894
958
  return self.overloads[signature.args].const_mem_size
895
959
  if self.specialized:
896
960
  return next(iter(self.overloads.values())).const_mem_size
897
961
  else:
898
- return {sig: overload.const_mem_size
899
- for sig, overload in self.overloads.items()}
962
+ return {
963
+ sig: overload.const_mem_size
964
+ for sig, overload in self.overloads.items()
965
+ }
900
966
 
901
967
  def get_shared_mem_per_block(self, signature=None):
902
- '''
968
+ """
903
969
  Returns the size in bytes of statically allocated shared memory
904
970
  for this kernel.
905
971
 
@@ -908,17 +974,19 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
908
974
  specialized kernel.
909
975
  :return: The amount of shared memory allocated by the compiled variant
910
976
  of the kernel for the given signature and current device.
911
- '''
977
+ """
912
978
  if signature is not None:
913
979
  return self.overloads[signature.args].shared_mem_per_block
914
980
  if self.specialized:
915
981
  return next(iter(self.overloads.values())).shared_mem_per_block
916
982
  else:
917
- return {sig: overload.shared_mem_per_block
918
- for sig, overload in self.overloads.items()}
983
+ return {
984
+ sig: overload.shared_mem_per_block
985
+ for sig, overload in self.overloads.items()
986
+ }
919
987
 
920
988
  def get_max_threads_per_block(self, signature=None):
921
- '''
989
+ """
922
990
  Returns the maximum allowable number of threads per block
923
991
  for this kernel. Exceeding this threshold will result in
924
992
  the kernel failing to launch.
@@ -929,17 +997,19 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
929
997
  :return: The maximum allowable threads per block for the compiled
930
998
  variant of the kernel for the given signature and current
931
999
  device.
932
- '''
1000
+ """
933
1001
  if signature is not None:
934
1002
  return self.overloads[signature.args].max_threads_per_block
935
1003
  if self.specialized:
936
1004
  return next(iter(self.overloads.values())).max_threads_per_block
937
1005
  else:
938
- return {sig: overload.max_threads_per_block
939
- for sig, overload in self.overloads.items()}
1006
+ return {
1007
+ sig: overload.max_threads_per_block
1008
+ for sig, overload in self.overloads.items()
1009
+ }
940
1010
 
941
1011
  def get_local_mem_per_thread(self, signature=None):
942
- '''
1012
+ """
943
1013
  Returns the size in bytes of local memory per thread
944
1014
  for this kernel.
945
1015
 
@@ -948,14 +1018,16 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
948
1018
  specialized kernel.
949
1019
  :return: The amount of local memory allocated by the compiled variant
950
1020
  of the kernel for the given signature and current device.
951
- '''
1021
+ """
952
1022
  if signature is not None:
953
1023
  return self.overloads[signature.args].local_mem_per_thread
954
1024
  if self.specialized:
955
1025
  return next(iter(self.overloads.values())).local_mem_per_thread
956
1026
  else:
957
- return {sig: overload.local_mem_per_thread
958
- for sig, overload in self.overloads.items()}
1027
+ return {
1028
+ sig: overload.local_mem_per_thread
1029
+ for sig, overload in self.overloads.items()
1030
+ }
959
1031
 
960
1032
  def get_call_template(self, args, kws):
961
1033
  # Originally copied from _DispatcherBase.get_call_template. This
@@ -983,7 +1055,8 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
983
1055
  name = "CallTemplate({0})".format(func_name)
984
1056
 
985
1057
  call_template = typing.make_concrete_template(
986
- name, key=func_name, signatures=self.nopython_signatures)
1058
+ name, key=func_name, signatures=self.nopython_signatures
1059
+ )
987
1060
  pysig = utils.pysignature(self.py_func)
988
1061
 
989
1062
  return call_template, pysig, args, kws
@@ -998,33 +1071,36 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
998
1071
  """
999
1072
  if args not in self.overloads:
1000
1073
  with self._compiling_counter:
1001
-
1002
- debug = self.targetoptions.get('debug')
1003
- lineinfo = self.targetoptions.get('lineinfo')
1004
- inline = self.targetoptions.get('inline')
1005
- fastmath = self.targetoptions.get('fastmath')
1074
+ debug = self.targetoptions.get("debug")
1075
+ lineinfo = self.targetoptions.get("lineinfo")
1076
+ inline = self.targetoptions.get("inline")
1077
+ fastmath = self.targetoptions.get("fastmath")
1006
1078
 
1007
1079
  nvvm_options = {
1008
- 'opt': 3 if self.targetoptions.get('opt') else 0,
1009
- 'fastmath': fastmath
1080
+ "opt": 3 if self.targetoptions.get("opt") else 0,
1081
+ "fastmath": fastmath,
1010
1082
  }
1011
1083
 
1012
1084
  if debug:
1013
- nvvm_options['g'] = None
1085
+ nvvm_options["g"] = None
1014
1086
 
1015
1087
  cc = get_current_device().compute_capability
1016
- cres = compile_cuda(self.py_func, return_type, args,
1017
- debug=debug,
1018
- lineinfo=lineinfo,
1019
- inline=inline,
1020
- fastmath=fastmath,
1021
- nvvm_options=nvvm_options,
1022
- cc=cc)
1088
+ cres = compile_cuda(
1089
+ self.py_func,
1090
+ return_type,
1091
+ args,
1092
+ debug=debug,
1093
+ lineinfo=lineinfo,
1094
+ inline=inline,
1095
+ fastmath=fastmath,
1096
+ nvvm_options=nvvm_options,
1097
+ cc=cc,
1098
+ )
1023
1099
  self.overloads[args] = cres
1024
1100
 
1025
- cres.target_context.insert_user_function(cres.entry_point,
1026
- cres.fndesc,
1027
- [cres.library])
1101
+ cres.target_context.insert_user_function(
1102
+ cres.entry_point, cres.fndesc, [cres.library]
1103
+ )
1028
1104
  else:
1029
1105
  cres = self.overloads[args]
1030
1106
 
@@ -1035,11 +1111,12 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
1035
1111
  self._insert(c_sig, kernel, cuda=True)
1036
1112
  self.overloads[argtypes] = kernel
1037
1113
 
1114
+ @global_compiler_lock
1038
1115
  def compile(self, sig):
1039
- '''
1116
+ """
1040
1117
  Compile and bind to the current context a version of this kernel
1041
1118
  specialized for the given signature.
1042
- '''
1119
+ """
1043
1120
  argtypes, return_type = sigutils.normalize_signature(sig)
1044
1121
  assert return_type is None or return_type == types.none
1045
1122
 
@@ -1072,15 +1149,15 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
1072
1149
  return kernel
1073
1150
 
1074
1151
  def inspect_llvm(self, signature=None):
1075
- '''
1152
+ """
1076
1153
  Return the LLVM IR for this kernel.
1077
1154
 
1078
1155
  :param signature: A tuple of argument types.
1079
1156
  :return: The LLVM IR for the given signature, or a dict of LLVM IR
1080
1157
  for all previously-encountered signatures.
1081
1158
 
1082
- '''
1083
- device = self.targetoptions.get('device')
1159
+ """
1160
+ device = self.targetoptions.get("device")
1084
1161
  if signature is not None:
1085
1162
  if device:
1086
1163
  return self.overloads[signature].library.get_llvm_str()
@@ -1088,23 +1165,27 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
1088
1165
  return self.overloads[signature].inspect_llvm()
1089
1166
  else:
1090
1167
  if device:
1091
- return {sig: overload.library.get_llvm_str()
1092
- for sig, overload in self.overloads.items()}
1168
+ return {
1169
+ sig: overload.library.get_llvm_str()
1170
+ for sig, overload in self.overloads.items()
1171
+ }
1093
1172
  else:
1094
- return {sig: overload.inspect_llvm()
1095
- for sig, overload in self.overloads.items()}
1173
+ return {
1174
+ sig: overload.inspect_llvm()
1175
+ for sig, overload in self.overloads.items()
1176
+ }
1096
1177
 
1097
1178
  def inspect_asm(self, signature=None):
1098
- '''
1179
+ """
1099
1180
  Return this kernel's PTX assembly code for the device in the
1100
1181
  current context.
1101
1182
 
1102
1183
  :param signature: A tuple of argument types.
1103
1184
  :return: The PTX code for the given signature, or a dict of PTX codes
1104
1185
  for all previously-encountered signatures.
1105
- '''
1186
+ """
1106
1187
  cc = get_current_device().compute_capability
1107
- device = self.targetoptions.get('device')
1188
+ device = self.targetoptions.get("device")
1108
1189
  if signature is not None:
1109
1190
  if device:
1110
1191
  return self.overloads[signature].library.get_asm_str(cc)
@@ -1112,14 +1193,18 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
1112
1193
  return self.overloads[signature].inspect_asm(cc)
1113
1194
  else:
1114
1195
  if device:
1115
- return {sig: overload.library.get_asm_str(cc)
1116
- for sig, overload in self.overloads.items()}
1196
+ return {
1197
+ sig: overload.library.get_asm_str(cc)
1198
+ for sig, overload in self.overloads.items()
1199
+ }
1117
1200
  else:
1118
- return {sig: overload.inspect_asm(cc)
1119
- for sig, overload in self.overloads.items()}
1201
+ return {
1202
+ sig: overload.inspect_asm(cc)
1203
+ for sig, overload in self.overloads.items()
1204
+ }
1120
1205
 
1121
1206
  def inspect_sass_cfg(self, signature=None):
1122
- '''
1207
+ """
1123
1208
  Return this kernel's CFG for the device in the current context.
1124
1209
 
1125
1210
  :param signature: A tuple of argument types.
@@ -1129,18 +1214,20 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
1129
1214
  The CFG for the device in the current context is returned.
1130
1215
 
1131
1216
  Requires nvdisasm to be available on the PATH.
1132
- '''
1133
- if self.targetoptions.get('device'):
1134
- raise RuntimeError('Cannot get the CFG of a device function')
1217
+ """
1218
+ if self.targetoptions.get("device"):
1219
+ raise RuntimeError("Cannot get the CFG of a device function")
1135
1220
 
1136
1221
  if signature is not None:
1137
1222
  return self.overloads[signature].inspect_sass_cfg()
1138
1223
  else:
1139
- return {sig: defn.inspect_sass_cfg()
1140
- for sig, defn in self.overloads.items()}
1224
+ return {
1225
+ sig: defn.inspect_sass_cfg()
1226
+ for sig, defn in self.overloads.items()
1227
+ }
1141
1228
 
1142
1229
  def inspect_sass(self, signature=None):
1143
- '''
1230
+ """
1144
1231
  Return this kernel's SASS assembly code for the device in the
1145
1232
  current context.
1146
1233
 
@@ -1151,22 +1238,23 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
1151
1238
  SASS for the device in the current context is returned.
1152
1239
 
1153
1240
  Requires nvdisasm to be available on the PATH.
1154
- '''
1155
- if self.targetoptions.get('device'):
1156
- raise RuntimeError('Cannot inspect SASS of a device function')
1241
+ """
1242
+ if self.targetoptions.get("device"):
1243
+ raise RuntimeError("Cannot inspect SASS of a device function")
1157
1244
 
1158
1245
  if signature is not None:
1159
1246
  return self.overloads[signature].inspect_sass()
1160
1247
  else:
1161
- return {sig: defn.inspect_sass()
1162
- for sig, defn in self.overloads.items()}
1248
+ return {
1249
+ sig: defn.inspect_sass() for sig, defn in self.overloads.items()
1250
+ }
1163
1251
 
1164
1252
  def inspect_types(self, file=None):
1165
- '''
1253
+ """
1166
1254
  Produce a dump of the Python source of this function annotated with the
1167
1255
  corresponding Numba IR and type information. The dump is written to
1168
1256
  *file*, or *sys.stdout* if *file* is *None*.
1169
- '''
1257
+ """
1170
1258
  if file is None:
1171
1259
  file = sys.stdout
1172
1260
 
@@ -1186,5 +1274,4 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
1186
1274
  Reduce the instance for serialization.
1187
1275
  Compiled definitions are discarded.
1188
1276
  """
1189
- return dict(py_func=self.py_func,
1190
- targetoptions=self.targetoptions)
1277
+ return dict(py_func=self.py_func, targetoptions=self.targetoptions)