numba-cuda 0.21.1__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _numba_cuda_redirector.pth +4 -0
- _numba_cuda_redirector.py +89 -0
- numba_cuda/VERSION +1 -0
- numba_cuda/__init__.py +6 -0
- numba_cuda/_version.py +11 -0
- numba_cuda/numba/cuda/__init__.py +70 -0
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +16394 -0
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +8112 -0
- numba_cuda/numba/cuda/api.py +577 -0
- numba_cuda/numba/cuda/api_util.py +76 -0
- numba_cuda/numba/cuda/args.py +72 -0
- numba_cuda/numba/cuda/bf16.py +397 -0
- numba_cuda/numba/cuda/cache_hints.py +287 -0
- numba_cuda/numba/cuda/cext/__init__.py +2 -0
- numba_cuda/numba/cuda/cext/_devicearray.cp313-win_amd64.pyd +0 -0
- numba_cuda/numba/cuda/cext/_devicearray.cpp +159 -0
- numba_cuda/numba/cuda/cext/_devicearray.h +29 -0
- numba_cuda/numba/cuda/cext/_dispatcher.cp313-win_amd64.pyd +0 -0
- numba_cuda/numba/cuda/cext/_dispatcher.cpp +1098 -0
- numba_cuda/numba/cuda/cext/_hashtable.cpp +532 -0
- numba_cuda/numba/cuda/cext/_hashtable.h +135 -0
- numba_cuda/numba/cuda/cext/_helperlib.c +71 -0
- numba_cuda/numba/cuda/cext/_helperlib.cp313-win_amd64.pyd +0 -0
- numba_cuda/numba/cuda/cext/_helpermod.c +82 -0
- numba_cuda/numba/cuda/cext/_pymodule.h +38 -0
- numba_cuda/numba/cuda/cext/_typeconv.cp313-win_amd64.pyd +0 -0
- numba_cuda/numba/cuda/cext/_typeconv.cpp +206 -0
- numba_cuda/numba/cuda/cext/_typeof.cpp +1159 -0
- numba_cuda/numba/cuda/cext/_typeof.h +19 -0
- numba_cuda/numba/cuda/cext/capsulethunk.h +111 -0
- numba_cuda/numba/cuda/cext/mviewbuf.c +385 -0
- numba_cuda/numba/cuda/cext/mviewbuf.cp313-win_amd64.pyd +0 -0
- numba_cuda/numba/cuda/cext/typeconv.cpp +212 -0
- numba_cuda/numba/cuda/cext/typeconv.hpp +101 -0
- numba_cuda/numba/cuda/cg.py +67 -0
- numba_cuda/numba/cuda/cgutils.py +1294 -0
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +541 -0
- numba_cuda/numba/cuda/compiler.py +1396 -0
- numba_cuda/numba/cuda/core/analysis.py +758 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +288 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +305 -0
- numba_cuda/numba/cuda/core/base.py +1332 -0
- numba_cuda/numba/cuda/core/boxing.py +1411 -0
- numba_cuda/numba/cuda/core/bytecode.py +728 -0
- numba_cuda/numba/cuda/core/byteflow.py +2346 -0
- numba_cuda/numba/cuda/core/caching.py +744 -0
- numba_cuda/numba/cuda/core/callconv.py +392 -0
- numba_cuda/numba/cuda/core/codegen.py +171 -0
- numba_cuda/numba/cuda/core/compiler.py +199 -0
- numba_cuda/numba/cuda/core/compiler_lock.py +85 -0
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +650 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/controlflow.py +989 -0
- numba_cuda/numba/cuda/core/entrypoints.py +57 -0
- numba_cuda/numba/cuda/core/environment.py +66 -0
- numba_cuda/numba/cuda/core/errors.py +917 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/generators.py +387 -0
- numba_cuda/numba/cuda/core/imputils.py +509 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1787 -0
- numba_cuda/numba/cuda/core/interpreter.py +3617 -0
- numba_cuda/numba/cuda/core/ir.py +1812 -0
- numba_cuda/numba/cuda/core/ir_utils.py +2638 -0
- numba_cuda/numba/cuda/core/optional.py +129 -0
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1859 -0
- numba_cuda/numba/cuda/core/registry.py +46 -0
- numba_cuda/numba/cuda/core/removerefctpass.py +123 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +91 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +41 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +189 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +100 -0
- numba_cuda/numba/cuda/core/sigutils.py +68 -0
- numba_cuda/numba/cuda/core/ssa.py +498 -0
- numba_cuda/numba/cuda/core/targetconfig.py +330 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +956 -0
- numba_cuda/numba/cuda/core/typed_passes.py +867 -0
- numba_cuda/numba/cuda/core/typeinfer.py +1950 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1979 -0
- numba_cuda/numba/cuda/cpython/builtins.py +1153 -0
- numba_cuda/numba/cuda/cpython/charseq.py +1218 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/enumimpl.py +103 -0
- numba_cuda/numba/cuda/cpython/iterators.py +167 -0
- numba_cuda/numba/cuda/cpython/listobj.py +1326 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1475 -0
- numba_cuda/numba/cuda/cpython/rangeobj.py +289 -0
- numba_cuda/numba/cuda/cpython/slicing.py +322 -0
- numba_cuda/numba/cuda/cpython/tupleobj.py +456 -0
- numba_cuda/numba/cuda/cpython/unicode.py +2865 -0
- numba_cuda/numba/cuda/cpython/unicode_support.py +1597 -0
- numba_cuda/numba/cuda/cpython/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/cpython/unsafe/numbers.py +64 -0
- numba_cuda/numba/cuda/cpython/unsafe/tuple.py +92 -0
- numba_cuda/numba/cuda/cuda_paths.py +691 -0
- numba_cuda/numba/cuda/cudadecl.py +556 -0
- numba_cuda/numba/cuda/cudadrv/__init__.py +14 -0
- numba_cuda/numba/cuda/cudadrv/devicearray.py +951 -0
- numba_cuda/numba/cuda/cudadrv/devices.py +249 -0
- numba_cuda/numba/cuda/cudadrv/driver.py +3222 -0
- numba_cuda/numba/cuda/cudadrv/drvapi.py +435 -0
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +558 -0
- numba_cuda/numba/cuda/cudadrv/enums.py +613 -0
- numba_cuda/numba/cuda/cudadrv/error.py +48 -0
- numba_cuda/numba/cuda/cudadrv/libs.py +220 -0
- numba_cuda/numba/cuda/cudadrv/linkable_code.py +184 -0
- numba_cuda/numba/cuda/cudadrv/mappings.py +14 -0
- numba_cuda/numba/cuda/cudadrv/ndarray.py +26 -0
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +193 -0
- numba_cuda/numba/cuda/cudadrv/nvvm.py +756 -0
- numba_cuda/numba/cuda/cudadrv/rtapi.py +13 -0
- numba_cuda/numba/cuda/cudadrv/runtime.py +34 -0
- numba_cuda/numba/cuda/cudaimpl.py +995 -0
- numba_cuda/numba/cuda/cudamath.py +149 -0
- numba_cuda/numba/cuda/datamodel/__init__.py +7 -0
- numba_cuda/numba/cuda/datamodel/cuda_manager.py +66 -0
- numba_cuda/numba/cuda/datamodel/cuda_models.py +1446 -0
- numba_cuda/numba/cuda/datamodel/cuda_packer.py +224 -0
- numba_cuda/numba/cuda/datamodel/cuda_registry.py +22 -0
- numba_cuda/numba/cuda/datamodel/cuda_testing.py +153 -0
- numba_cuda/numba/cuda/datamodel/manager.py +11 -0
- numba_cuda/numba/cuda/datamodel/models.py +9 -0
- numba_cuda/numba/cuda/datamodel/packer.py +9 -0
- numba_cuda/numba/cuda/datamodel/registry.py +11 -0
- numba_cuda/numba/cuda/datamodel/testing.py +11 -0
- numba_cuda/numba/cuda/debuginfo.py +903 -0
- numba_cuda/numba/cuda/decorators.py +294 -0
- numba_cuda/numba/cuda/descriptor.py +35 -0
- numba_cuda/numba/cuda/device_init.py +158 -0
- numba_cuda/numba/cuda/deviceufunc.py +1021 -0
- numba_cuda/numba/cuda/dispatcher.py +2463 -0
- numba_cuda/numba/cuda/errors.py +72 -0
- numba_cuda/numba/cuda/extending.py +697 -0
- numba_cuda/numba/cuda/flags.py +178 -0
- numba_cuda/numba/cuda/fp16.py +357 -0
- numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/initialize.py +24 -0
- numba_cuda/numba/cuda/intrinsic_wrapper.py +41 -0
- numba_cuda/numba/cuda/intrinsics.py +382 -0
- numba_cuda/numba/cuda/itanium_mangler.py +214 -0
- numba_cuda/numba/cuda/kernels/__init__.py +2 -0
- numba_cuda/numba/cuda/kernels/reduction.py +265 -0
- numba_cuda/numba/cuda/kernels/transpose.py +65 -0
- numba_cuda/numba/cuda/libdevice.py +3386 -0
- numba_cuda/numba/cuda/libdevicedecl.py +20 -0
- numba_cuda/numba/cuda/libdevicefuncs.py +1060 -0
- numba_cuda/numba/cuda/libdeviceimpl.py +88 -0
- numba_cuda/numba/cuda/locks.py +19 -0
- numba_cuda/numba/cuda/lowering.py +1951 -0
- numba_cuda/numba/cuda/mathimpl.py +374 -0
- numba_cuda/numba/cuda/memory_management/__init__.py +4 -0
- numba_cuda/numba/cuda/memory_management/memsys.cu +99 -0
- numba_cuda/numba/cuda/memory_management/memsys.cuh +22 -0
- numba_cuda/numba/cuda/memory_management/nrt.cu +212 -0
- numba_cuda/numba/cuda/memory_management/nrt.cuh +48 -0
- numba_cuda/numba/cuda/memory_management/nrt.py +390 -0
- numba_cuda/numba/cuda/memory_management/nrt_context.py +438 -0
- numba_cuda/numba/cuda/misc/appdirs.py +594 -0
- numba_cuda/numba/cuda/misc/cffiimpl.py +24 -0
- numba_cuda/numba/cuda/misc/coverage_support.py +43 -0
- numba_cuda/numba/cuda/misc/dump_style.py +41 -0
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/misc/firstlinefinder.py +96 -0
- numba_cuda/numba/cuda/misc/gdb_hook.py +240 -0
- numba_cuda/numba/cuda/misc/literal.py +28 -0
- numba_cuda/numba/cuda/misc/llvm_pass_timings.py +412 -0
- numba_cuda/numba/cuda/misc/special.py +94 -0
- numba_cuda/numba/cuda/models.py +56 -0
- numba_cuda/numba/cuda/np/arraymath.py +5130 -0
- numba_cuda/numba/cuda/np/arrayobj.py +7635 -0
- numba_cuda/numba/cuda/np/extensions.py +11 -0
- numba_cuda/numba/cuda/np/linalg.py +3087 -0
- numba_cuda/numba/cuda/np/math/__init__.py +0 -0
- numba_cuda/numba/cuda/np/math/cmathimpl.py +558 -0
- numba_cuda/numba/cuda/np/math/mathimpl.py +487 -0
- numba_cuda/numba/cuda/np/math/numbers.py +1461 -0
- numba_cuda/numba/cuda/np/npdatetime.py +969 -0
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1808 -0
- numba_cuda/numba/cuda/np/npyimpl.py +1027 -0
- numba_cuda/numba/cuda/np/numpy_support.py +798 -0
- numba_cuda/numba/cuda/np/polynomial/__init__.py +4 -0
- numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +242 -0
- numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +380 -0
- numba_cuda/numba/cuda/np/ufunc/__init__.py +4 -0
- numba_cuda/numba/cuda/np/ufunc/decorators.py +203 -0
- numba_cuda/numba/cuda/np/ufunc/sigparse.py +68 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +65 -0
- numba_cuda/numba/cuda/np/ufunc_db.py +1282 -0
- numba_cuda/numba/cuda/np/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/np/unsafe/ndarray.py +84 -0
- numba_cuda/numba/cuda/nvvmutils.py +254 -0
- numba_cuda/numba/cuda/printimpl.py +126 -0
- numba_cuda/numba/cuda/random.py +308 -0
- numba_cuda/numba/cuda/reshape_funcs.cu +156 -0
- numba_cuda/numba/cuda/serialize.py +267 -0
- numba_cuda/numba/cuda/simulator/__init__.py +63 -0
- numba_cuda/numba/cuda/simulator/_internal/__init__.py +4 -0
- numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
- numba_cuda/numba/cuda/simulator/api.py +179 -0
- numba_cuda/numba/cuda/simulator/bf16.py +4 -0
- numba_cuda/numba/cuda/simulator/compiler.py +38 -0
- numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +11 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +462 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devices.py +122 -0
- numba_cuda/numba/cuda/simulator/cudadrv/driver.py +66 -0
- numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +7 -0
- numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +7 -0
- numba_cuda/numba/cuda/simulator/cudadrv/error.py +10 -0
- numba_cuda/numba/cuda/simulator/cudadrv/libs.py +10 -0
- numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +61 -0
- numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +11 -0
- numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +32 -0
- numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +22 -0
- numba_cuda/numba/cuda/simulator/dispatcher.py +11 -0
- numba_cuda/numba/cuda/simulator/kernel.py +320 -0
- numba_cuda/numba/cuda/simulator/kernelapi.py +509 -0
- numba_cuda/numba/cuda/simulator/memory_management/__init__.py +4 -0
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +21 -0
- numba_cuda/numba/cuda/simulator/reduction.py +19 -0
- numba_cuda/numba/cuda/simulator/tests/support.py +4 -0
- numba_cuda/numba/cuda/simulator/vector_types.py +65 -0
- numba_cuda/numba/cuda/simulator_init.py +18 -0
- numba_cuda/numba/cuda/stubs.py +635 -0
- numba_cuda/numba/cuda/target.py +505 -0
- numba_cuda/numba/cuda/testing.py +347 -0
- numba_cuda/numba/cuda/tests/__init__.py +62 -0
- numba_cuda/numba/cuda/tests/benchmarks/__init__.py +0 -0
- numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +119 -0
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/core/serialize_usecases.py +113 -0
- numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +371 -0
- numba_cuda/numba/cuda/tests/cudadrv/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +147 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +161 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +397 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +24 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +180 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +313 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +187 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +621 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +247 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +100 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +198 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_events.py +53 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +72 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_init.py +138 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +43 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +15 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_linkable_code.py +58 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +348 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +128 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +301 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +174 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +28 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +185 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +39 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +23 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +38 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +48 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +44 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +127 -0
- numba_cuda/numba/cuda/tests/cudapy/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +231 -0
- numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +50 -0
- numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +36 -0
- numba_cuda/numba/cuda/tests/cudapy/complex_usecases.py +116 -0
- numba_cuda/numba/cuda/tests/cudapy/enum_usecases.py +59 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +62 -0
- numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +28 -0
- numba_cuda/numba/cuda/tests/cudapy/overload_usecases.py +33 -0
- numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +104 -0
- numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +47 -0
- numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +1122 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +344 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +268 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +203 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +63 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +360 -0
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1815 -0
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +599 -0
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +377 -0
- numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +160 -0
- numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +27 -0
- numba_cuda/numba/cuda/tests/cudapy/test_byteflow.py +98 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py +210 -0
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +683 -0
- numba_cuda/numba/cuda/tests/cudapy/test_casting.py +265 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +42 -0
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +718 -0
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +370 -0
- numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +23 -0
- numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +142 -0
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +178 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +193 -0
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +131 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +438 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +94 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +101 -0
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +105 -0
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +889 -0
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
- numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +500 -0
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +820 -0
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +152 -0
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +111 -0
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +170 -0
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1088 -0
- numba_cuda/numba/cuda/tests/cudapy/test_extending_types.py +71 -0
- numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +265 -0
- numba_cuda/numba/cuda/tests/cudapy/test_flow_control.py +1433 -0
- numba_cuda/numba/cuda/tests/cudapy/test_forall.py +57 -0
- numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +34 -0
- numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +69 -0
- numba_cuda/numba/cuda/tests/cudapy/test_globals.py +62 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +474 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +167 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +92 -0
- numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +39 -0
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +170 -0
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +255 -0
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1219 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +263 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ir.py +598 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +276 -0
- numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +101 -0
- numba_cuda/numba/cuda/tests/cudapy/test_lang.py +68 -0
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +123 -0
- numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +194 -0
- numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +220 -0
- numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +173 -0
- numba_cuda/numba/cuda/tests/cudapy/test_make_function_to_jit_function.py +364 -0
- numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +47 -0
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +842 -0
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +76 -0
- numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +78 -0
- numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +25 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +145 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +39 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +82 -0
- numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +53 -0
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +504 -0
- numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +93 -0
- numba_cuda/numba/cuda/tests/cudapy/test_overload.py +402 -0
- numba_cuda/numba/cuda/tests/cudapy/test_powi.py +128 -0
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +193 -0
- numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +37 -0
- numba_cuda/numba/cuda/tests/cudapy/test_random.py +117 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +614 -0
- numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +94 -0
- numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +86 -0
- numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +40 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +457 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +233 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +454 -0
- numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +56 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +277 -0
- numba_cuda/numba/cuda/tests/cudapy/test_tracing.py +200 -0
- numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +90 -0
- numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +333 -0
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +585 -0
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +42 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +485 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +312 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +23 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +183 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +40 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +40 -0
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +206 -0
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +331 -0
- numba_cuda/numba/cuda/tests/cudasim/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/cudasim/support.py +9 -0
- numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +111 -0
- numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
- numba_cuda/numba/cuda/tests/data/cta_barrier.cu +28 -0
- numba_cuda/numba/cuda/tests/data/cuda_include.cu +10 -0
- numba_cuda/numba/cuda/tests/data/error.cu +12 -0
- numba_cuda/numba/cuda/tests/data/include/add.cuh +8 -0
- numba_cuda/numba/cuda/tests/data/jitlink.cu +28 -0
- numba_cuda/numba/cuda/tests/data/jitlink.ptx +49 -0
- numba_cuda/numba/cuda/tests/data/warn.cu +12 -0
- numba_cuda/numba/cuda/tests/doc_examples/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +54 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +8 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +14 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +86 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +68 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +81 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +141 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +160 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +180 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +119 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_random.py +66 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +80 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +206 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +53 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +76 -0
- numba_cuda/numba/cuda/tests/nocuda/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +391 -0
- numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +48 -0
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +63 -0
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +252 -0
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +59 -0
- numba_cuda/numba/cuda/tests/nrt/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +387 -0
- numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +124 -0
- numba_cuda/numba/cuda/tests/support.py +900 -0
- numba_cuda/numba/cuda/typeconv/__init__.py +4 -0
- numba_cuda/numba/cuda/typeconv/castgraph.py +137 -0
- numba_cuda/numba/cuda/typeconv/rules.py +63 -0
- numba_cuda/numba/cuda/typeconv/typeconv.py +121 -0
- numba_cuda/numba/cuda/types/__init__.py +233 -0
- numba_cuda/numba/cuda/types/__init__.pyi +167 -0
- numba_cuda/numba/cuda/types/abstract.py +9 -0
- numba_cuda/numba/cuda/types/common.py +9 -0
- numba_cuda/numba/cuda/types/containers.py +9 -0
- numba_cuda/numba/cuda/types/cuda_abstract.py +533 -0
- numba_cuda/numba/cuda/types/cuda_common.py +110 -0
- numba_cuda/numba/cuda/types/cuda_containers.py +971 -0
- numba_cuda/numba/cuda/types/cuda_function_type.py +230 -0
- numba_cuda/numba/cuda/types/cuda_functions.py +798 -0
- numba_cuda/numba/cuda/types/cuda_iterators.py +120 -0
- numba_cuda/numba/cuda/types/cuda_misc.py +569 -0
- numba_cuda/numba/cuda/types/cuda_npytypes.py +690 -0
- numba_cuda/numba/cuda/types/cuda_scalars.py +280 -0
- numba_cuda/numba/cuda/types/ext_types.py +101 -0
- numba_cuda/numba/cuda/types/function_type.py +11 -0
- numba_cuda/numba/cuda/types/functions.py +9 -0
- numba_cuda/numba/cuda/types/iterators.py +9 -0
- numba_cuda/numba/cuda/types/misc.py +9 -0
- numba_cuda/numba/cuda/types/npytypes.py +9 -0
- numba_cuda/numba/cuda/types/scalars.py +9 -0
- numba_cuda/numba/cuda/typing/__init__.py +19 -0
- numba_cuda/numba/cuda/typing/arraydecl.py +939 -0
- numba_cuda/numba/cuda/typing/asnumbatype.py +130 -0
- numba_cuda/numba/cuda/typing/bufproto.py +70 -0
- numba_cuda/numba/cuda/typing/builtins.py +1209 -0
- numba_cuda/numba/cuda/typing/cffi_utils.py +219 -0
- numba_cuda/numba/cuda/typing/cmathdecl.py +47 -0
- numba_cuda/numba/cuda/typing/collections.py +138 -0
- numba_cuda/numba/cuda/typing/context.py +782 -0
- numba_cuda/numba/cuda/typing/ctypes_utils.py +125 -0
- numba_cuda/numba/cuda/typing/dictdecl.py +63 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/listdecl.py +147 -0
- numba_cuda/numba/cuda/typing/mathdecl.py +158 -0
- numba_cuda/numba/cuda/typing/npdatetime.py +322 -0
- numba_cuda/numba/cuda/typing/npydecl.py +749 -0
- numba_cuda/numba/cuda/typing/setdecl.py +115 -0
- numba_cuda/numba/cuda/typing/templates.py +1446 -0
- numba_cuda/numba/cuda/typing/typeof.py +301 -0
- numba_cuda/numba/cuda/ufuncs.py +746 -0
- numba_cuda/numba/cuda/utils.py +724 -0
- numba_cuda/numba/cuda/vector_types.py +214 -0
- numba_cuda/numba/cuda/vectorizers.py +260 -0
- numba_cuda-0.21.1.dist-info/METADATA +109 -0
- numba_cuda-0.21.1.dist-info/RECORD +488 -0
- numba_cuda-0.21.1.dist-info/WHEEL +5 -0
- numba_cuda-0.21.1.dist-info/licenses/LICENSE +26 -0
- numba_cuda-0.21.1.dist-info/licenses/LICENSE.numba +24 -0
- numba_cuda-0.21.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2463 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
import ctypes
|
|
8
|
+
import collections
|
|
9
|
+
import functools
|
|
10
|
+
import types as pytypes
|
|
11
|
+
import weakref
|
|
12
|
+
from contextlib import ExitStack
|
|
13
|
+
from abc import abstractmethod
|
|
14
|
+
import uuid
|
|
15
|
+
import re
|
|
16
|
+
from warnings import warn
|
|
17
|
+
|
|
18
|
+
from numba.cuda.core import errors
|
|
19
|
+
from numba.cuda import serialize, utils
|
|
20
|
+
from numba import cuda
|
|
21
|
+
|
|
22
|
+
from numba.cuda.core.compiler_lock import global_compiler_lock
|
|
23
|
+
from numba.cuda.typeconv.rules import default_type_manager
|
|
24
|
+
from numba.cuda.typing.templates import fold_arguments
|
|
25
|
+
from numba.cuda.typing.typeof import Purpose, typeof
|
|
26
|
+
|
|
27
|
+
from numba.cuda import typing, types
|
|
28
|
+
from numba.cuda.types import ext_types
|
|
29
|
+
from numba.cuda.api import get_current_device
|
|
30
|
+
from numba.cuda.args import wrap_arg
|
|
31
|
+
from numba.cuda.core.bytecode import get_code_object
|
|
32
|
+
from numba.cuda.compiler import (
|
|
33
|
+
compile_cuda,
|
|
34
|
+
CUDACompiler,
|
|
35
|
+
kernel_fixup,
|
|
36
|
+
compile_extra,
|
|
37
|
+
compile_ir,
|
|
38
|
+
)
|
|
39
|
+
from numba.cuda.core import sigutils, config, entrypoints
|
|
40
|
+
from numba.cuda.flags import Flags
|
|
41
|
+
from numba.cuda.cudadrv import driver, nvvm
|
|
42
|
+
from numba.cuda.locks import module_init_lock
|
|
43
|
+
from numba.cuda.core.caching import Cache, CacheImpl, NullCache
|
|
44
|
+
from numba.cuda.descriptor import cuda_target
|
|
45
|
+
from numba.cuda.errors import (
|
|
46
|
+
missing_launch_config_msg,
|
|
47
|
+
normalize_kernel_dimensions,
|
|
48
|
+
)
|
|
49
|
+
from numba.cuda.cudadrv.linkable_code import LinkableCode
|
|
50
|
+
from numba.cuda.cudadrv.devices import get_context
|
|
51
|
+
from numba.cuda.memory_management.nrt import rtsys, NRT_LIBRARY
|
|
52
|
+
import numba.cuda.core.event as ev
|
|
53
|
+
from numba.cuda.cext import _dispatcher
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Names of the CUDA half-precision (fp16) math functions. Every entry carries
# the "h" prefix of the half-precision variants (e.g. "hsin", "hsqrt", "hrcp").
# Order is preserved from the original definition.
# NOTE(review): the consumers of this list are outside this chunk of the
# module — confirm how it is used before adding or removing names.
cuda_fp16_math_funcs = [
    "hsin",
    "hcos",
    "hlog",
    "hlog10",
    "hlog2",
    "hexp",
    "hexp10",
    "hexp2",
    "hsqrt",
    "hrsqrt",
    "hfloor",
    "hceil",
    "hrcp",
    "hrint",
    "htrunc",
    "hdiv",
]

# Names of the no-copy reshape helper functions.
# NOTE(review): presumably these correspond to symbols provided by the
# package's reshape_funcs.cu source — confirm against the code that uses them.
reshape_funcs = ["nocopy_empty_reshape", "numba_attempt_nocopy_reshape"]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class _Kernel(serialize.ReduceMixin):
    """
    CUDA Kernel specialized for a given set of argument types. When called, this
    object launches the kernel on the device.
    """

    # Numba runtime (NRT) function names. maybe_link_nrt() links the NRT
    # library when the kernel PTX declares any of these as external functions.
    NRT_functions = [
        "NRT_Allocate",
        "NRT_MemInfo_init",
        "NRT_MemInfo_new",
        "NRT_Free",
        "NRT_dealloc",
        "NRT_MemInfo_destroy",
        "NRT_MemInfo_call_dtor",
        "NRT_MemInfo_data_fast",
        "NRT_MemInfo_alloc_aligned",
        "NRT_Allocate_External",
        "NRT_decref",
        "NRT_incref",
    ]

    @global_compiler_lock
    def __init__(
        self,
        py_func,
        argtypes,
        link=None,
        debug=False,
        lineinfo=False,
        inline=False,
        forceinline=False,
        fastmath=False,
        extensions=None,
        max_registers=None,
        lto=False,
        opt=True,
        device=False,
        launch_bounds=None,
    ):
        """
        Compile ``py_func`` as a kernel (void return type) for ``argtypes``,
        targeting the compute capability of the current device.

        :param py_func: Python function to compile.
        :param argtypes: Argument types to specialize the kernel for.
        :param link: Optional sequence of files / ``LinkableCode`` objects to
            link with the kernel. It is copied; the caller's sequence is never
            modified.
        :param debug: Adds the NVVM ``g`` option and enables the device-side
            exception reporting performed by :meth:`launch`.
        :param lineinfo: Forwarded to ``compile_cuda``.
        :param inline: Accepted for interface compatibility; unused here.
        :param forceinline: Forwarded to ``compile_cuda``.
        :param fastmath: Passed both as an NVVM option and to ``compile_cuda``.
        :param extensions: Objects whose ``prepare_args`` hook maps kernel
            arguments before launch (see :meth:`_prepare_args`).
        :param max_registers: Forwarded to ``compile_cuda``.
        :param lto: Forwarded to ``compile_cuda``.
        :param opt: NVVM optimization level 3 when true, 0 otherwise.
        :param device: Must be false; device functions cannot be kernels.
        :param launch_bounds: Applied via ``nvvm.set_launch_bounds``.
        :raises RuntimeError: If ``device`` is true.
        """
        if device:
            raise RuntimeError("Cannot compile a device function as a kernel")

        super().__init__()

        # _DispatcherBase.nopython_signatures() expects this attribute to be
        # present, because it assumes an overload is a CompileResult. In the
        # CUDA target, _Kernel instances are stored instead, so we provide this
        # attribute here to avoid duplicating nopython_signatures() in the CUDA
        # target with slight modifications.
        self.objectmode = False

        # The finalizer constructed by _DispatcherBase._make_finalizer also
        # expects overloads to be a CompileResult. It uses the entry_point to
        # remove a CompileResult from a target context. However, since we never
        # insert kernels into a target context (there is no need because they
        # cannot be called by other functions, only through the dispatcher) it
        # suffices to pretend we have an entry point of None.
        self.entry_point = None

        self.py_func = py_func
        self.argtypes = argtypes
        self.debug = debug
        self.lineinfo = lineinfo
        self.extensions = extensions or []
        self.launch_bounds = launch_bounds

        nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0}

        if debug:
            nvvm_options["g"] = None

        cc = get_current_device().compute_capability

        cres = compile_cuda(
            self.py_func,
            types.void,
            self.argtypes,
            debug=self.debug,
            lineinfo=lineinfo,
            forceinline=forceinline,
            fastmath=fastmath,
            nvvm_options=nvvm_options,
            cc=cc,
            max_registers=max_registers,
            lto=lto,
        )
        tgt_ctx = cres.target_context
        lib = cres.library
        kernel = lib.get_function(cres.fndesc.llvm_func_name)
        lib._entry_name = cres.fndesc.llvm_func_name
        kernel_fixup(kernel, self.debug)
        nvvm.set_launch_bounds(kernel, launch_bounds)

        # Work on a private copy: helper-library and NRT paths are appended
        # below. Appending to the caller's list in place would leak those
        # entries back to the caller, and a list reused across compilations
        # would gain duplicate link entries on every new specialization.
        link = list(link) if link else []

        asm = lib.get_asm_str()

        # The code library contains functions that require cooperative launch.
        self.cooperative = lib.use_cooperative
        # We need to link against cudadevrt if grid sync is being used.
        if self.cooperative:
            lib.needs_cudadevrt = True

        def link_to_library_functions(
            library_functions, library_path, prefix=None
        ):
            """
            Append ``library_path`` (resolved relative to this module's
            directory) to ``link`` if any of ``library_functions`` (each
            optionally prefixed with ``prefix``) occurs in the kernel PTX.
            Returns the list of names that were found.
            """
            if prefix is not None:
                library_functions = [
                    f"{prefix}{fn}" for fn in library_functions
                ]

            # Plain substring search over the PTX text.
            found_functions = [fn for fn in library_functions if fn in asm]

            if found_functions:
                basedir = os.path.dirname(os.path.abspath(__file__))
                source_file_path = os.path.join(basedir, library_path)
                link.append(source_file_path)

            return found_functions

        # Link to the helper library functions if needed
        link_to_library_functions(reshape_funcs, "reshape_funcs.cu")

        self.maybe_link_nrt(link, tgt_ctx, asm)

        for filepath in link:
            lib.add_linking_file(filepath)

        # populate members
        self.entry_name = kernel.name
        self.signature = cres.signature
        self._type_annotation = cres.type_annotation
        self._codelibrary = lib
        self.call_helper = cres.call_helper

        # The following are referred to by the cache implementation. Note:
        # - There are no referenced environments in CUDA.
        # - Kernels don't have lifted code.
        self.target_context = tgt_ctx
        self.fndesc = cres.fndesc
        self.environment = cres.environment
        self._referenced_environments = []
        self.lifted = []

    def maybe_link_nrt(self, link, tgt_ctx, asm):
        """
        Add the NRT source code to the link if the necessary conditions are met.

        NRT must be enabled for the CUDATargetContext, and either NRT functions
        must be declared as external functions in the kernel asm, or an NRT
        enabled LinkableCode object must be present in ``link``.

        :param link: List of link items; ``NRT_LIBRARY`` is appended in place.
        :param tgt_ctx: Target context whose ``enable_nrt`` flag gates linking.
        :param asm: The kernel PTX text to scan.
        """
        if not tgt_ctx.enable_nrt:
            return

        all_nrt = "|".join(self.NRT_functions)
        # Matches `.extern .func [(<ret>)] NRT_xxx(<params>);` declarations.
        pattern = (
            r"\.extern\s+\.func\s+(?:\s*\(.+\)\s*)?("
            + all_nrt
            + r")\s*\([^)]*\)\s*;"
        )
        link_nrt = re.search(pattern, asm) is not None or any(
            isinstance(file, LinkableCode) and file.nrt for file in link
        )

        if link_nrt:
            link.append(NRT_LIBRARY)

    @property
    def library(self):
        """The code library holding this kernel's compiled code."""
        return self._codelibrary

    @property
    def type_annotation(self):
        """Type annotation from compilation (None for cache-rebuilt kernels)."""
        return self._type_annotation

    def _find_referenced_environments(self):
        # Always empty for CUDA kernels (see __init__).
        return self._referenced_environments

    @property
    def codegen(self):
        """The code generator of this kernel's target context."""
        return self.target_context.codegen()

    @property
    def argument_types(self):
        """The kernel's argument types as a tuple."""
        return tuple(self.signature.args)

    @classmethod
    def _rebuild(
        cls,
        cooperative,
        name,
        signature,
        codelibrary,
        debug,
        lineinfo,
        call_helper,
        extensions,
    ):
        """
        Rebuild an instance from the state produced by ``_reduce_states``,
        without recompiling.
        """
        instance = cls.__new__(cls)
        # invoke parent constructor
        super(cls, instance).__init__()
        # populate members
        instance.entry_point = None
        # Mirror __init__: nopython_signatures() reads `objectmode` from every
        # overload, including kernels loaded from the cache.
        instance.objectmode = False
        instance.cooperative = cooperative
        instance.entry_name = name
        instance.signature = signature
        instance._type_annotation = None
        instance._codelibrary = codelibrary
        instance.debug = debug
        instance.lineinfo = lineinfo
        instance.call_helper = call_helper
        instance.extensions = extensions
        return instance

    def _reduce_states(self):
        """
        Reduce the instance for serialization.

        Returns the keyword arguments accepted by ``_rebuild``: the cooperative
        flag, entry name, signature, code library, debug and lineinfo flags,
        call helper and extensions. The type annotation is discarded, and no
        launch configuration or stream information is stored.
        """
        return dict(
            cooperative=self.cooperative,
            name=self.entry_name,
            signature=self.signature,
            codelibrary=self._codelibrary,
            debug=self.debug,
            lineinfo=self.lineinfo,
            call_helper=self.call_helper,
            extensions=self.extensions,
        )

    @module_init_lock
    def initialize_once(self, mod):
        """Run ``mod.setup()`` unless the module is already initialized."""
        if not mod.initialized:
            mod.setup()

    def bind(self):
        """
        Force binding to current CUDA context
        """
        cufunc = self._codelibrary.get_cufunc()

        self.initialize_once(cufunc.module)

        # Cache-rebuilt kernels have no target_context attribute, hence the
        # hasattr check.
        if (
            hasattr(self, "target_context")
            and self.target_context.enable_nrt
            and config.CUDA_NRT_STATS
        ):
            rtsys.ensure_initialized()
            rtsys.set_memsys_to_module(cufunc.module)
            # We don't know which stream the kernel will be launched on, so
            # we force synchronize here.
            cuda.synchronize()

    @property
    def regs_per_thread(self):
        """
        The number of registers used by each thread for this kernel.
        """
        return self._codelibrary.get_cufunc().attrs.regs

    @property
    def const_mem_size(self):
        """
        The amount of constant memory used by this kernel.
        """
        return self._codelibrary.get_cufunc().attrs.const

    @property
    def shared_mem_per_block(self):
        """
        The amount of shared memory used per block for this kernel.
        """
        return self._codelibrary.get_cufunc().attrs.shared

    @property
    def max_threads_per_block(self):
        """
        The maximum allowable threads per block.
        """
        return self._codelibrary.get_cufunc().attrs.maxthreads

    @property
    def local_mem_per_thread(self):
        """
        The amount of local memory used per thread for this kernel.
        """
        return self._codelibrary.get_cufunc().attrs.local

    def inspect_llvm(self):
        """
        Returns the LLVM IR for this kernel.
        """
        return self._codelibrary.get_llvm_str()

    def inspect_asm(self, cc):
        """
        Returns the PTX code for this kernel.
        """
        return self._codelibrary.get_asm_str(cc=cc)

    def inspect_lto_ptx(self, cc):
        """
        Returns the PTX code for the external functions linked to this kernel.
        """
        return self._codelibrary.get_lto_ptx(cc=cc)

    def inspect_sass_cfg(self):
        """
        Returns the CFG of the SASS for this kernel.

        Requires nvdisasm to be available on the PATH.
        """
        return self._codelibrary.get_sass_cfg()

    def inspect_sass(self):
        """
        Returns the SASS code for this kernel.

        Requires nvdisasm to be available on the PATH.
        """
        return self._codelibrary.get_sass()

    def inspect_types(self, file=None):
        """
        Produce a dump of the Python source of this function annotated with the
        corresponding Numba IR and type information. The dump is written to
        *file*, or *sys.stdout* if *file* is *None*.

        :raises ValueError: If no type annotation is available (e.g. for a
            kernel rebuilt from the cache).
        """
        if self._type_annotation is None:
            raise ValueError("Type annotation is not available")

        if file is None:
            file = sys.stdout

        print("%s %s" % (self.entry_name, self.argument_types), file=file)
        print("-" * 80, file=file)
        print(self._type_annotation, file=file)
        print("=" * 80, file=file)

    def max_cooperative_grid_blocks(self, blockdim, dynsmemsize=0):
        """
        Calculates the maximum number of blocks that can be launched for this
        kernel in a cooperative grid in the current context, for the given block
        and dynamic shared memory sizes.

        :param blockdim: Block dimensions, either as a scalar for a 1D block, or
                         a tuple for 2D or 3D blocks.
        :param dynsmemsize: Dynamic shared memory size in bytes.
        :return: The maximum number of blocks in the grid.
        """
        ctx = get_context()
        cufunc = self._codelibrary.get_cufunc()

        if isinstance(blockdim, tuple):
            # Total threads per block.
            blockdim = functools.reduce(lambda x, y: x * y, blockdim)
        active_per_sm = ctx.get_active_blocks_per_multiprocessor(
            cufunc, blockdim, dynsmemsize
        )
        sm_count = ctx.device.MULTIPROCESSOR_COUNT
        return active_per_sm * sm_count

    def launch(self, args, griddim, blockdim, stream=0, sharedmem=0):
        """
        Launch the kernel with the given arguments and configuration.

        In debug mode, the kernel's error-code global is cleared before launch
        and read back afterwards. A non-zero code is translated into the
        corresponding Python exception, prefixed with the source location (when
        known) and the recorded thread/block indices, and raised.
        """
        # Prepare kernel
        cufunc = self._codelibrary.get_cufunc()

        if self.debug:
            excname = cufunc.name + "__errcode__"
            excmem, excsz = cufunc.module.get_global_symbol(excname)
            assert excsz == ctypes.sizeof(ctypes.c_int)
            excval = ctypes.c_int()
            excmem.memset(0, stream=stream)

        # Prepare arguments
        retr = []  # hold functors for writeback

        kernelargs = []
        for t, v in zip(self.argument_types, args):
            self._prepare_args(t, v, stream, retr, kernelargs)

        stream_handle = driver._stream_handle(stream)

        # Invoke kernel
        driver.launch_kernel(
            cufunc.handle,
            *griddim,
            *blockdim,
            sharedmem,
            stream_handle,
            kernelargs,
            cooperative=self.cooperative,
        )

        if self.debug:
            driver.device_to_host(ctypes.addressof(excval), excmem, excsz)
            if excval.value != 0:
                # An error occurred
                def load_symbol(name):
                    mem, sz = cufunc.module.get_global_symbol(
                        "%s__%s__" % (cufunc.name, name)
                    )
                    val = ctypes.c_int()
                    driver.device_to_host(ctypes.addressof(val), mem, sz)
                    return val.value

                tid = [load_symbol("tid" + i) for i in "zyx"]
                ctaid = [load_symbol("ctaid" + i) for i in "zyx"]
                code = excval.value
                exccls, exc_args, loc = self.call_helper.get_exception(code)
                # Prefix the exception message with the source location
                if loc is None:
                    locinfo = ""
                else:
                    sym, filepath, lineno = loc
                    filepath = os.path.abspath(filepath)
                    locinfo = "In function %r, file %s, line %s, " % (
                        sym,
                        filepath,
                        lineno,
                    )
                # Prefix the exception message with the thread position
                prefix = "%stid=%s ctaid=%s" % (locinfo, tid, ctaid)
                if exc_args:
                    exc_args = (
                        "%s: %s" % (prefix, exc_args[0]),
                    ) + exc_args[1:]
                else:
                    exc_args = (prefix,)
                raise exccls(*exc_args)

        # retrieve auto converted arrays
        for wb in retr:
            wb()

    def _prepare_args(self, ty, val, stream, retr, kernelargs):
        """
        Convert arguments to ctypes and append to kernelargs

        Registered extensions are applied first (in reverse order). Arrays
        expand to meminfo, parent, nitems, itemsize, data pointer, then one
        entry per shape and per stride element. Tuples and enum members are
        flattened recursively.

        :raises NotImplementedError: For argument types with no conversion.
        """

        # map the arguments using any extension you've registered
        for extension in reversed(self.extensions):
            ty, val = extension.prepare_args(ty, val, stream=stream, retr=retr)

        if isinstance(ty, types.Array):
            devary = wrap_arg(val).to_device(retr, stream)
            c_intp = ctypes.c_ssize_t

            meminfo = ctypes.c_void_p(0)
            parent = ctypes.c_void_p(0)
            nitems = c_intp(devary.size)
            itemsize = c_intp(devary.dtype.itemsize)

            ptr = driver.device_pointer(devary)

            ptr = int(ptr)

            data = ctypes.c_void_p(ptr)

            kernelargs.append(meminfo)
            kernelargs.append(parent)
            kernelargs.append(nitems)
            kernelargs.append(itemsize)
            kernelargs.append(data)
            kernelargs.extend(map(c_intp, devary.shape))
            kernelargs.extend(map(c_intp, devary.strides))

        elif isinstance(ty, types.CPointer):
            # Pointer arguments should be a pointer-sized integer
            kernelargs.append(ctypes.c_uint64(val))

        elif isinstance(ty, types.Integer):
            cval = getattr(ctypes, "c_%s" % ty)(val)
            kernelargs.append(cval)

        elif ty == types.float16:
            # Passed as its raw 16-bit pattern.
            cval = ctypes.c_uint16(np.float16(val).view(np.uint16))
            kernelargs.append(cval)

        elif ty == types.float64:
            cval = ctypes.c_double(val)
            kernelargs.append(cval)

        elif ty == types.float32:
            cval = ctypes.c_float(val)
            kernelargs.append(cval)

        elif ty == types.boolean:
            cval = ctypes.c_uint8(int(val))
            kernelargs.append(cval)

        elif ty == types.complex64:
            kernelargs.append(ctypes.c_float(val.real))
            kernelargs.append(ctypes.c_float(val.imag))

        elif ty == types.complex128:
            kernelargs.append(ctypes.c_double(val.real))
            kernelargs.append(ctypes.c_double(val.imag))

        elif isinstance(ty, (types.NPDatetime, types.NPTimedelta)):
            kernelargs.append(ctypes.c_int64(val.view(np.int64)))

        elif isinstance(ty, types.Record):
            devrec = wrap_arg(val).to_device(retr, stream)
            ptr = devrec.device_ctypes_pointer
            kernelargs.append(ptr)

        elif isinstance(ty, types.BaseTuple):
            assert len(ty) == len(val)
            for t, v in zip(ty, val):
                self._prepare_args(t, v, stream, retr, kernelargs)

        elif isinstance(ty, types.EnumMember):
            try:
                self._prepare_args(
                    ty.dtype, val.value, stream, retr, kernelargs
                )
            except NotImplementedError:
                # Report the enum member type rather than its underlying dtype.
                raise NotImplementedError(ty, val)

        else:
            raise NotImplementedError(ty, val)
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
class ForAll:
    """
    Launch helper that covers ``ntasks`` work items with a 1D grid.

    The block size is the user-supplied threads-per-block value when non-zero,
    otherwise the value suggested by the driver. The grid size is the number
    of blocks needed to cover every task.
    """

    def __init__(self, dispatcher, ntasks, tpb, stream, sharedmem):
        if ntasks < 0:
            raise ValueError(
                "Can't create ForAll with negative task count: %s" % ntasks
            )
        self.dispatcher = dispatcher
        self.ntasks = ntasks
        self.thread_per_block = tpb
        self.stream = stream
        self.sharedmem = sharedmem

    def __call__(self, *args):
        # An empty task set launches nothing.
        if self.ntasks == 0:
            return

        disp = self.dispatcher
        kernel = disp if disp.specialized else disp.specialize(*args)

        threads = self._compute_thread_per_block(kernel)
        # Round up so that the last, partially filled block is included.
        blocks = (self.ntasks + threads - 1) // threads

        launcher = kernel[blocks, threads, self.stream, self.sharedmem]
        return launcher(*args)

    def _compute_thread_per_block(self, dispatcher):
        requested = self.thread_per_block
        # A non-zero user setting always wins.
        if requested != 0:
            return requested

        # Otherwise ask the driver to give a good config. The dispatcher is
        # specialized, so it holds exactly one kernel, whose cufunc we need.
        ctx = get_context()
        only_kernel = next(iter(dispatcher.overloads.values()))
        _, suggested = ctx.get_max_potential_block_size(
            func=only_kernel._codelibrary.get_cufunc(),
            b2d_func=0,  # dynamic-shared memory is constant to blksz
            memsize=self.sharedmem,
            blocksizelimit=1024,
        )
        return suggested
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
class _LaunchConfiguration:
    """
    A dispatcher paired with a fixed launch configuration: grid dimensions,
    block dimensions, stream and dynamic shared memory size.

    Calling the object forwards the arguments to ``dispatcher.call`` together
    with that configuration. When ``config.CUDA_LOW_OCCUPANCY_WARNINGS`` is
    set and ``config.DISABLE_PERFORMANCE_WARNINGS`` is not, construction emits
    a ``NumbaPerformanceWarning`` for grids with fewer than 128 blocks.
    """

    def __init__(self, dispatcher, griddim, blockdim, stream, sharedmem):
        self.dispatcher = dispatcher
        self.griddim = griddim
        self.blockdim = blockdim
        self.stream = stream
        self.sharedmem = sharedmem

        occupancy_warnings_enabled = (
            config.CUDA_LOW_OCCUPANCY_WARNINGS
            and not config.DISABLE_PERFORMANCE_WARNINGS
        )
        if not occupancy_warnings_enabled:
            return

        # The 128-block threshold is a heuristic. Ideally the minimum would be
        # twice the SM count, but that varies widely between devices (from
        # about 4 SMs on very small GPUs to 132 on an H100-SXM5). Kernels
        # should normally use hundreds or thousands of blocks, so this catches
        # the common beginner mistake of single-digit or low-tens grids.
        min_grid_size = 128
        grid_size = griddim[0] * griddim[1] * griddim[2]
        if grid_size >= min_grid_size:
            return

        msg = (
            f"Grid size {grid_size} will likely result in GPU "
            "under-utilization due to low occupancy."
        )
        warn(errors.NumbaPerformanceWarning(msg))

    def __call__(self, *args):
        cfg = (self.griddim, self.blockdim, self.stream, self.sharedmem)
        return self.dispatcher.call(args, *cfg)
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
class CUDACacheImpl(CacheImpl):
    """
    Cache implementation whose cached entries are CUDA ``_Kernel`` objects.

    A kernel is reduced with its own ``_reduce_states`` hook, and the
    resulting payload is turned back into a kernel with ``_Kernel._rebuild``.
    """

    def reduce(self, kernel):
        state = kernel._reduce_states()
        return state

    def rebuild(self, target_context, payload):
        # ``target_context`` is not needed: the payload holds everything
        # ``_Kernel._rebuild`` requires.
        rebuilt = _Kernel._rebuild(**payload)
        return rebuilt

    def check_cachable(self, cres):
        # CUDA kernels can always be cached. An entity is normally
        # uncachable only when it contains lifted loops or dynamic globals,
        # and neither applies to CUDA kernels.
        return True
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
class CUDACache(Cache):
    """
    Cache that saves and loads CUDA kernels and compile results, using
    ``CUDACacheImpl`` for (de)serialization.
    """

    _impl_class = CUDACacheImpl

    def load_overload(self, sig, target_context):
        """
        Load the cached overload for ``sig``, delegating to ``Cache``.

        Loading an overload refreshes the context to ensure it is
        initialized. The lookup therefore runs inside
        ``utils.numba_target_override()``, presumably so the CUDA target is
        the active one while that happens (not visible here; confirm against
        ``utils``).
        """
        with utils.numba_target_override():
            return super().load_overload(sig, target_context)
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
class OmittedArg:
    """
    Stand-in for an argument that the caller left out, carrying the default
    value it resolves to.
    """

    def __init__(self, value):
        # Default value of the omitted parameter.
        self.value = value

    def __repr__(self):
        return f"omitted arg({self.value!r})"

    @property
    def _numba_type_(self):
        # Typing hook: instances are typed as ``types.Omitted`` of the value.
        return types.Omitted(self.value)
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
class CompilingCounter:
    """
    Nesting counter used as a context manager.

    Entering a ``with`` block increments the count and leaving it decrements
    the count. The object is truthy while at least one block is active.
    Assertions guard against the count going negative.
    """

    def __init__(self):
        self.counter = 0

    def __enter__(self):
        assert self.counter >= 0
        self.counter += 1

    def __exit__(self, *args, **kwargs):
        # Returns None, so exceptions raised inside the block propagate.
        self.counter -= 1
        assert self.counter >= 0

    def __bool__(self):
        return 0 < self.counter

    # Python 2 spelling of __bool__, kept as an alias.
    __nonzero__ = __bool__
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
class _DispatcherBase(_dispatcher.Dispatcher):
|
|
774
|
+
"""
|
|
775
|
+
Common base class for dispatcher Implementations.
|
|
776
|
+
"""
|
|
777
|
+
|
|
778
|
+
__numba__ = "py_func"
|
|
779
|
+
|
|
780
|
+
def __init__(
|
|
781
|
+
self, arg_count, py_func, pysig, can_fallback, exact_match_required
|
|
782
|
+
):
|
|
783
|
+
self._tm = default_type_manager
|
|
784
|
+
|
|
785
|
+
# A mapping of signatures to compile results
|
|
786
|
+
self.overloads = collections.OrderedDict()
|
|
787
|
+
|
|
788
|
+
self.py_func = py_func
|
|
789
|
+
# other parts of Numba assume the old Python 2 name for code object
|
|
790
|
+
self.func_code = get_code_object(py_func)
|
|
791
|
+
# but newer python uses a different name
|
|
792
|
+
self.__code__ = self.func_code
|
|
793
|
+
# a place to keep an active reference to the types of the active call
|
|
794
|
+
self._types_active_call = set()
|
|
795
|
+
# Default argument values match the py_func
|
|
796
|
+
self.__defaults__ = py_func.__defaults__
|
|
797
|
+
|
|
798
|
+
argnames = tuple(pysig.parameters)
|
|
799
|
+
default_values = self.py_func.__defaults__ or ()
|
|
800
|
+
defargs = tuple(OmittedArg(val) for val in default_values)
|
|
801
|
+
try:
|
|
802
|
+
lastarg = list(pysig.parameters.values())[-1]
|
|
803
|
+
except IndexError:
|
|
804
|
+
has_stararg = False
|
|
805
|
+
else:
|
|
806
|
+
has_stararg = lastarg.kind == lastarg.VAR_POSITIONAL
|
|
807
|
+
_dispatcher.Dispatcher.__init__(
|
|
808
|
+
self,
|
|
809
|
+
self._tm.get_pointer(),
|
|
810
|
+
arg_count,
|
|
811
|
+
self._fold_args,
|
|
812
|
+
argnames,
|
|
813
|
+
defargs,
|
|
814
|
+
can_fallback,
|
|
815
|
+
has_stararg,
|
|
816
|
+
exact_match_required,
|
|
817
|
+
)
|
|
818
|
+
|
|
819
|
+
self.doc = py_func.__doc__
|
|
820
|
+
self._compiling_counter = CompilingCounter()
|
|
821
|
+
weakref.finalize(self, self._make_finalizer())
|
|
822
|
+
|
|
823
|
+
def _compilation_chain_init_hook(self):
|
|
824
|
+
"""
|
|
825
|
+
This will be called ahead of any part of compilation taking place (this
|
|
826
|
+
even includes being ahead of working out the types of the arguments).
|
|
827
|
+
This permits activities such as initialising extension entry points so
|
|
828
|
+
that the compiler knows about additional externally defined types etc
|
|
829
|
+
before it does anything.
|
|
830
|
+
"""
|
|
831
|
+
entrypoints.init_all()
|
|
832
|
+
|
|
833
|
+
def _reset_overloads(self):
|
|
834
|
+
self._clear()
|
|
835
|
+
self.overloads.clear()
|
|
836
|
+
|
|
837
|
+
def _make_finalizer(self):
|
|
838
|
+
"""
|
|
839
|
+
Return a finalizer function that will release references to
|
|
840
|
+
related compiled functions.
|
|
841
|
+
"""
|
|
842
|
+
overloads = self.overloads
|
|
843
|
+
targetctx = self.targetctx
|
|
844
|
+
|
|
845
|
+
# Early-bind utils.shutting_down() into the function's local namespace
|
|
846
|
+
# (see issue #689)
|
|
847
|
+
def finalizer(shutting_down=utils.shutting_down):
|
|
848
|
+
# The finalizer may crash at shutdown, skip it (resources
|
|
849
|
+
# will be cleared by the process exiting, anyway).
|
|
850
|
+
if shutting_down():
|
|
851
|
+
return
|
|
852
|
+
# This function must *not* hold any reference to self:
|
|
853
|
+
# we take care to bind the necessary objects in the closure.
|
|
854
|
+
for cres in overloads.values():
|
|
855
|
+
try:
|
|
856
|
+
targetctx.remove_user_function(cres.entry_point)
|
|
857
|
+
except KeyError:
|
|
858
|
+
pass
|
|
859
|
+
|
|
860
|
+
return finalizer
|
|
861
|
+
|
|
862
|
+
@property
|
|
863
|
+
def signatures(self):
|
|
864
|
+
"""
|
|
865
|
+
Returns a list of compiled function signatures.
|
|
866
|
+
"""
|
|
867
|
+
return list(self.overloads)
|
|
868
|
+
|
|
869
|
+
@property
|
|
870
|
+
def nopython_signatures(self):
|
|
871
|
+
return [
|
|
872
|
+
cres.signature
|
|
873
|
+
for cres in self.overloads.values()
|
|
874
|
+
if not cres.objectmode
|
|
875
|
+
]
|
|
876
|
+
|
|
877
|
+
def disable_compile(self, val=True):
|
|
878
|
+
"""Disable the compilation of new signatures at call time."""
|
|
879
|
+
# If disabling compilation then there must be at least one signature
|
|
880
|
+
assert (not val) or len(self.signatures) > 0
|
|
881
|
+
self._can_compile = not val
|
|
882
|
+
|
|
883
|
+
def add_overload(self, cres):
|
|
884
|
+
args = tuple(cres.signature.args)
|
|
885
|
+
sig = [a._code for a in args]
|
|
886
|
+
self._insert(sig, cres.entry_point, cres.objectmode)
|
|
887
|
+
self.overloads[args] = cres
|
|
888
|
+
|
|
889
|
+
def fold_argument_types(self, args, kws):
|
|
890
|
+
return self._compiler.fold_argument_types(args, kws)
|
|
891
|
+
|
|
892
|
+
def get_call_template(self, args, kws):
|
|
893
|
+
"""
|
|
894
|
+
Get a typing.ConcreteTemplate for this dispatcher and the given
|
|
895
|
+
*args* and *kws* types. This allows to resolve the return type.
|
|
896
|
+
|
|
897
|
+
A (template, pysig, args, kws) tuple is returned.
|
|
898
|
+
"""
|
|
899
|
+
# XXX how about a dispatcher template class automating the
|
|
900
|
+
# following?
|
|
901
|
+
|
|
902
|
+
# Fold keyword arguments and resolve default values
|
|
903
|
+
pysig, args = self._compiler.fold_argument_types(args, kws)
|
|
904
|
+
kws = {}
|
|
905
|
+
# Ensure an overload is available
|
|
906
|
+
if self._can_compile:
|
|
907
|
+
self.compile(tuple(args))
|
|
908
|
+
|
|
909
|
+
# Create function type for typing
|
|
910
|
+
func_name = self.py_func.__name__
|
|
911
|
+
name = "CallTemplate({0})".format(func_name)
|
|
912
|
+
# The `key` isn't really used except for diagnosis here,
|
|
913
|
+
# so avoid keeping a reference to `cfunc`.
|
|
914
|
+
call_template = typing.make_concrete_template(
|
|
915
|
+
name, key=func_name, signatures=self.nopython_signatures
|
|
916
|
+
)
|
|
917
|
+
return call_template, pysig, args, kws
|
|
918
|
+
|
|
919
|
+
def get_overload(self, sig):
|
|
920
|
+
"""
|
|
921
|
+
Return the compiled function for the given signature.
|
|
922
|
+
"""
|
|
923
|
+
args, return_type = sigutils.normalize_signature(sig)
|
|
924
|
+
return self.overloads[tuple(args)].entry_point
|
|
925
|
+
|
|
926
|
+
@property
|
|
927
|
+
def is_compiling(self):
|
|
928
|
+
"""
|
|
929
|
+
Whether a specialization is currently being compiled.
|
|
930
|
+
"""
|
|
931
|
+
return self._compiling_counter
|
|
932
|
+
|
|
933
|
+
def _compile_for_args(self, *args, **kws):
    """
    For internal use. Compile a specialized version of the function
    for the given *args* and *kws*, and return the resulting callable.

    Only positional arguments are supported (``kws`` must be empty). Each
    value is typed with ``self.typeof_pyval`` (``OmittedArg`` values become
    ``types.Omitted``) and the resulting type tuple is passed to
    ``self.compile``.

    An ``errors.ForceLiteralArg`` raised during compilation converts the
    requested positions with ``types.literal`` and re-enters this method.
    ``TypingError``, ``UnsupportedError``, ``NotDefinedError``,
    ``RedefinedError``, ``VerificationError`` and ``ConstantInferenceError``
    are re-raised through ``error_rewrite``. Any other exception may get
    the "reportable" help text appended and is re-raised with its
    traceback. ``self._types_active_call`` is cleared on every exit path.
    """
    assert not kws
    # call any initialisation required for the compilation chain (e.g.
    # extension point registration).
    self._compilation_chain_init_hook()

    def error_rewrite(e, issue_type):
        """
        Rewrite and raise Exception `e` with help supplied based on the
        specified issue_type.

        The help text ``errors.error_extras[issue_type]`` is appended only
        when ``config.SHOW_HELP`` is set. The traceback is kept only when
        ``config.FULL_TRACEBACKS`` is set. This function always raises.
        """
        if config.SHOW_HELP:
            help_msg = errors.error_extras[issue_type]
            e.patch_message("\n".join((str(e).rstrip(), help_msg)))
        if config.FULL_TRACEBACKS:
            raise e
        else:
            # Re-raise without the accumulated traceback.
            raise e.with_traceback(None)

    # Resolve the Numba type of every positional value.
    argtypes = []
    for a in args:
        if isinstance(a, OmittedArg):
            argtypes.append(types.Omitted(a.value))
        else:
            argtypes.append(self.typeof_pyval(a))

    return_val = None
    try:
        return_val = self.compile(tuple(argtypes))
    except errors.ForceLiteralArg as e:
        # Received request for compiler re-entry with the list of arguments
        # indicated by e.requested_args.
        # First, check if any of these args are already Literal-ized.
        # On re-entry (below) the requested positions hold the result of
        # types.literal (presumably a types.Literal), so a hit here means
        # the same request was made twice.
        already_lit_pos = [
            i
            for i in e.requested_args
            if isinstance(args[i], types.Literal)
        ]
        if already_lit_pos:
            # Abort compilation if any argument is already a Literal.
            # Letting this continue will cause infinite compilation loop.
            m = (
                "Repeated literal typing request.\n"
                "{}.\n"
                "This is likely caused by an error in typing. "
                "Please see nested and suppressed exceptions."
            )
            info = ", ".join(
                "Arg #{} is {}".format(i, args[i])
                for i in sorted(already_lit_pos)
            )
            raise errors.CompilerError(m.format(info))
        # Convert requested arguments into a Literal; all other positions
        # are passed through unchanged by the identity lambda.
        args = [
            (types.literal if i in e.requested_args else lambda x: x)(
                args[i]
            )
            for i, v in enumerate(args)
        ]
        # Re-enter compilation with the Literal-ized arguments
        return_val = self._compile_for_args(*args)

    except errors.TypingError as e:
        # Intercept typing error that may be due to an argument
        # that failed inferencing as a Numba type
        failed_args = []
        for i, arg in enumerate(args):
            val = arg.value if isinstance(arg, OmittedArg) else arg
            try:
                tp = typeof(val, Purpose.argument)
            except (errors.NumbaValueError, ValueError) as typeof_exc:
                failed_args.append((i, str(typeof_exc)))
            else:
                if tp is None:
                    failed_args.append(
                        (i, f"cannot determine Numba type of value {val}")
                    )
        if failed_args:
            # Patch error message to ease debugging
            args_str = "\n".join(
                f"- argument {i}: {err}" for i, err in failed_args
            )
            msg = (
                f"{str(e).rstrip()} \n\nThis error may have been caused "
                f"by the following argument(s):\n{args_str}\n"
            )
            e.patch_message(msg)

        error_rewrite(e, "typing")
    except errors.UnsupportedError as e:
        # Something unsupported is present in the user code, add help info
        error_rewrite(e, "unsupported_error")
    except (
        errors.NotDefinedError,
        errors.RedefinedError,
        errors.VerificationError,
    ) as e:
        # These errors are probably from an issue with either the code
        # supplied being syntactically or otherwise invalid
        error_rewrite(e, "interpreter")
    except errors.ConstantInferenceError as e:
        # this is from trying to infer something as constant when it isn't
        # or isn't supported as a constant
        error_rewrite(e, "constant_inference")
    except Exception as e:
        # Unexpected failure: optionally attach the "reportable" help text
        # (only to exceptions that support patch_message).
        if config.SHOW_HELP:
            if hasattr(e, "patch_message"):
                help_msg = errors.error_extras["reportable"]
                e.patch_message("\n".join((str(e).rstrip(), help_msg)))
        # ignore the FULL_TRACEBACKS config, this needs reporting!
        raise e
    finally:
        # Types recorded by typeof_pyval during this call are discarded
        # whether or not compilation succeeded.
        self._types_active_call.clear()
    return return_val
|
|
1051
|
+
|
|
1052
|
+
def inspect_llvm(self, signature=None):
    """Get the LLVM intermediate representation generated by compilation.

    Parameters
    ----------
    signature : tuple of numba types, optional
        The signature whose LLVM IR is wanted. When None, the IR of every
        signature in ``self.signatures`` is collected.

    Returns
    -------
    llvm : dict[signature, str] or str
        The LLVM IR string for *signature*, or, when no signature is given,
        a dict mapping each available signature to its LLVM IR string.
    """
    if signature is None:
        return {sig: self.inspect_llvm(sig) for sig in self.signatures}
    library = self.overloads[signature].library
    return library.get_llvm_str()
|
|
1073
|
+
|
|
1074
|
+
def inspect_asm(self, signature=None):
    """Get the generated assembly code.

    Parameters
    ----------
    signature : tuple of numba types, optional
        The signature whose assembly is wanted. When None, the assembly of
        every signature in ``self.signatures`` is collected.

    Returns
    -------
    asm : dict[signature, str] or str
        The assembly code for *signature*, or, when no signature is given,
        a dict mapping each available signature to its assembly code.
    """
    if signature is None:
        return {sig: self.inspect_asm(sig) for sig in self.signatures}
    library = self.overloads[signature].library
    return library.get_asm_str()
|
|
1095
|
+
|
|
1096
|
+
def inspect_types(
    self, file=None, signature=None, pretty=False, style="default", **kwargs
):
    """Print/return Numba intermediate representation (IR)-annotated code.

    Parameters
    ----------
    file : file-like object, optional
        Destination for the printed output; ``sys.stdout`` is used when
        None. Must be None when ``pretty=True``.
    signature : tuple of numba types, optional
        Restrict the output to this single signature. When None, every
        compiled signature is included.
    pretty : bool, optional
        When True, return an ``Annotate`` object that renders the IR with
        colour highlighting in Jupyter and IPython instead of printing.
        Requires the ``pygments`` library, and ``file`` must be None.
    style : str, optional
        Rendering style, only used when ``pretty`` is True. It is handed
        straight to ``pygments`` formatters; the available names are
        listed by ``list(pygments.styles.get_all_styles())``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    annotated : Annotate object, optional
        Returned only when ``pretty=True``; otherwise the function prints
        and returns None.

    Raises
    ------
    ValueError
        If ``pretty=True`` and ``file`` is not None.
    """
    if signature is None:
        selected = self.overloads
    else:
        selected = {signature: self.overloads[signature]}

    if pretty:
        if file is not None:
            raise ValueError("`file` must be None if `pretty=True`")
        from numba.cuda.core.annotations.pretty_annotate import Annotate

        return Annotate(self, signature=signature, style=style)

    out = sys.stdout if file is None else file
    for ver, res in selected.items():
        print("%s %s" % (self.py_func.__name__, ver), file=out)
        print("-" * 80, file=out)
        print(res.type_annotation, file=out)
        print("=" * 80, file=out)
|
|
1146
|
+
|
|
1147
|
+
def inspect_cfg(self, signature=None, show_wrapper=None, **kwargs):
    """
    For inspecting the CFG of the function.

    By default the CFG of the user function is shown. The *show_wrapper*
    option can be set to "python" or "cfunc" to show the python wrapper
    function or the *cfunc* wrapper function, respectively.

    When *signature* is None, a dict mapping every signature in
    ``self.signatures`` to its CFG is returned. *show_wrapper* and all of
    the keyword arguments below are forwarded to each per-signature call.
    A ``filename`` given in that case is therefore passed unchanged for
    every signature.

    Parameters accepted in kwargs
    -----------------------------
    filename : string, optional
        the name of the output file, if given this will write the output to
        filename
    view : bool, optional
        whether to immediately view the optional output file
    highlight : bool, set, dict, optional
        what, if anything, to highlight, options are:
        { incref : bool, # highlight NRT_incref calls
          decref : bool, # highlight NRT_decref calls
          returns : bool, # highlight exits which are normal returns
          raises : bool, # highlight exits which are from raise
          meminfo : bool, # highlight calls to NRT*meminfo
          branches : bool, # highlight true/false branches
        }
        Default is True which sets all of the above to True. Supplying a set
        of strings is also accepted, these are interpreted as key:True with
        respect to the above dictionary. e.g. {'incref', 'decref'} would
        switch on highlighting on increfs and decrefs.
    interleave: bool, set, dict, optional
        what, if anything, to interleave in the LLVM IR, options are:
        { python: bool # interleave python source code with the LLVM IR
          lineinfo: bool # interleave line information markers with the LLVM
                         # IR
        }
        Default is True which sets all of the above to True. Supplying a set
        of strings is also accepted, these are interpreted as key:True with
        respect to the above dictionary. e.g. {'python',} would
        switch on interleaving of python source code in the LLVM IR.
    strip_ir : bool, optional
        Default is False. If set to True all LLVM IR that is superfluous to
        that requested in kwarg `highlight` will be removed.
    show_key : bool, optional
        Default is True. Create a "key" for the highlighting in the rendered
        CFG.
    fontsize : int, optional
        Default is 8. Set the fontsize in the output to this value.
    """
    if signature is None:
        # Forward kwargs so rendering options apply to every signature.
        return {
            sig: self.inspect_cfg(sig, show_wrapper=show_wrapper, **kwargs)
            for sig in self.signatures
        }

    cres = self.overloads[signature]
    lib = cres.library
    if show_wrapper == "python":
        fname = cres.fndesc.llvm_cpython_wrapper_name
    elif show_wrapper == "cfunc":
        fname = cres.fndesc.llvm_cfunc_wrapper_name
    else:
        fname = cres.fndesc.mangled_name
    return lib.get_function_cfg(fname, py_func=self.py_func, **kwargs)
|
|
1209
|
+
|
|
1210
|
+
def inspect_disasm_cfg(self, signature=None):
    """
    For inspecting the CFG of the disassembly of the function.

    Requires python package: r2pipe
    Requires radare2 binary on $PATH.
    Notebook rendering requires python package: graphviz

    signature : tuple of Numba types, optional
        Return the disassembly CFG for this signature only. When None, a
        dict mapping every available signature to its disassembly CFG is
        returned.
    """
    if signature is None:
        return {sig: self.inspect_disasm_cfg(sig) for sig in self.signatures}
    compiled = self.overloads[signature]
    return compiled.library.get_disasm_cfg(compiled.fndesc.mangled_name)
|
|
1230
|
+
|
|
1231
|
+
def get_annotation_info(self, signature=None):
    """
    Gets the annotation information for the function specified by
    signature. If no signature is supplied, the information for every
    signature in ``self.signatures`` is collected.

    Returns an ``OrderedDict`` keyed by ``("<filename>:<firstlineno + 1>",
    type-annotation signature)`` whose values are the matching entries of
    ``type_annotation.annotate_raw()``.
    """
    wanted = [signature] if signature is not None else self.signatures
    info = collections.OrderedDict()
    for sig in wanted:
        annotation = self.overloads[sig].type_annotation
        func_id = annotation.func_id
        location = func_id.filename + ":" + str(func_id.firstlineno + 1)
        key = (location, annotation.signature)
        info[key] = annotation.annotate_raw()[key]
    return info
|
|
1248
|
+
|
|
1249
|
+
def _explain_ambiguous(self, *args, **kws):
|
|
1250
|
+
"""
|
|
1251
|
+
Callback for the C _Dispatcher object.
|
|
1252
|
+
"""
|
|
1253
|
+
assert not kws, "kwargs not handled"
|
|
1254
|
+
args = tuple([self.typeof_pyval(a) for a in args])
|
|
1255
|
+
# The order here must be deterministic for testing purposes, which
|
|
1256
|
+
# is ensured by the OrderedDict.
|
|
1257
|
+
sigs = self.nopython_signatures
|
|
1258
|
+
# This will raise
|
|
1259
|
+
self.typingctx.resolve_overload(
|
|
1260
|
+
self.py_func, sigs, args, kws, allow_ambiguous=False
|
|
1261
|
+
)
|
|
1262
|
+
|
|
1263
|
+
def _explain_matching_error(self, *args, **kws):
|
|
1264
|
+
"""
|
|
1265
|
+
Callback for the C _Dispatcher object.
|
|
1266
|
+
"""
|
|
1267
|
+
assert not kws, "kwargs not handled"
|
|
1268
|
+
args = [self.typeof_pyval(a) for a in args]
|
|
1269
|
+
msg = "No matching definition for argument type(s) %s" % ", ".join(
|
|
1270
|
+
map(str, args)
|
|
1271
|
+
)
|
|
1272
|
+
raise TypeError(msg)
|
|
1273
|
+
|
|
1274
|
+
def _search_new_conversions(self, *args, **kws):
|
|
1275
|
+
"""
|
|
1276
|
+
Callback for the C _Dispatcher object.
|
|
1277
|
+
Search for approximately matching signatures for the given arguments,
|
|
1278
|
+
and ensure the corresponding conversions are registered in the C++
|
|
1279
|
+
type manager.
|
|
1280
|
+
"""
|
|
1281
|
+
assert not kws, "kwargs not handled"
|
|
1282
|
+
args = [self.typeof_pyval(a) for a in args]
|
|
1283
|
+
found = False
|
|
1284
|
+
for sig in self.nopython_signatures:
|
|
1285
|
+
conv = self.typingctx.install_possible_conversions(args, sig.args)
|
|
1286
|
+
if conv:
|
|
1287
|
+
found = True
|
|
1288
|
+
return found
|
|
1289
|
+
|
|
1290
|
+
def __repr__(self):
|
|
1291
|
+
return "%s(%s)" % (type(self).__name__, self.py_func)
|
|
1292
|
+
|
|
1293
|
+
def typeof_pyval(self, val):
    """
    Resolve the Numba type of Python value *val*.
    This is called from numba._dispatcher as a fallback if the native code
    cannot decide the type.

    Values that ``typeof`` cannot type (it raises, or returns None) are
    typed as ``types.pyobject``. The resolved type is recorded in
    ``self._types_active_call`` before it is returned.
    """
    try:
        resolved = typeof(val, Purpose.argument)
    except (errors.NumbaValueError, ValueError):
        resolved = None
    if resolved is None:
        resolved = types.pyobject
    self._types_active_call.add(resolved)
    return resolved
|
|
1308
|
+
|
|
1309
|
+
def _callback_add_timer(self, duration, cres, lock_name):
|
|
1310
|
+
md = cres.metadata
|
|
1311
|
+
# md can be None when code is loaded from cache
|
|
1312
|
+
if md is not None:
|
|
1313
|
+
timers = md.setdefault("timers", {})
|
|
1314
|
+
if lock_name not in timers:
|
|
1315
|
+
# Only write if the metadata does not exist
|
|
1316
|
+
timers[lock_name] = duration
|
|
1317
|
+
else:
|
|
1318
|
+
msg = f"'{lock_name} metadata is already defined."
|
|
1319
|
+
raise AssertionError(msg)
|
|
1320
|
+
|
|
1321
|
+
def _callback_add_compiler_timer(self, duration, cres):
|
|
1322
|
+
return self._callback_add_timer(
|
|
1323
|
+
duration, cres, lock_name="compiler_lock"
|
|
1324
|
+
)
|
|
1325
|
+
|
|
1326
|
+
def _callback_add_llvm_timer(self, duration, cres):
|
|
1327
|
+
return self._callback_add_timer(duration, cres, lock_name="llvm_lock")
|
|
1328
|
+
|
|
1329
|
+
|
|
1330
|
+
class _MemoMixin:
    """
    Mixin giving each instance a lazily created UUID, registered in a
    class-level memo, to avoid multiple deserializations of a given
    instance.
    """

    __uuid = None
    # {uuid -> instance} mapping used for deserialization. It holds weak
    # references, so an entry disappears once its instance is collected.
    _memo = weakref.WeakValueDictionary()
    # Strong references to the last FUNCTION_CACHE_SIZE registered
    # instances, keeping them in _memo even if nothing else refers to them.
    _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE)

    @property
    def _uuid(self):
        """
        An instance-specific UUID string.

        It is generated (and registered via ``_set_uuid``) on first access
        only, for performance reasons.
        """
        existing = self.__uuid
        if existing is not None:
            return existing
        fresh = str(uuid.uuid4())
        self._set_uuid(fresh)
        return fresh

    def _set_uuid(self, u):
        """Assign *u* as this instance's UUID and register it in the memo.

        Asserts that no UUID was assigned before.
        """
        assert self.__uuid is None
        self.__uuid = u
        self._memo[u] = self
        self._recent.append(self)
|
|
1357
|
+
|
|
1358
|
+
|
|
1359
|
+
_CompileStats = collections.namedtuple(
|
|
1360
|
+
"_CompileStats", ("cache_path", "cache_hits", "cache_misses")
|
|
1361
|
+
)
|
|
1362
|
+
|
|
1363
|
+
|
|
1364
|
+
class _FunctionCompiler(object):
|
|
1365
|
+
def __init__(self, py_func, targetdescr, targetoptions, pipeline_class):
|
|
1366
|
+
self.py_func = py_func
|
|
1367
|
+
self.targetdescr = targetdescr
|
|
1368
|
+
self.targetoptions = targetoptions
|
|
1369
|
+
self.locals = {}
|
|
1370
|
+
self.pysig = utils.pysignature(self.py_func)
|
|
1371
|
+
self.pipeline_class = pipeline_class
|
|
1372
|
+
# Remember key=(args, return_type) combinations that will fail
|
|
1373
|
+
# compilation to avoid compilation attempt on them. The values are
|
|
1374
|
+
# the exceptions.
|
|
1375
|
+
self._failed_cache = {}
|
|
1376
|
+
|
|
1377
|
+
def fold_argument_types(self, args, kws):
|
|
1378
|
+
"""
|
|
1379
|
+
Given positional and named argument types, fold keyword arguments
|
|
1380
|
+
and resolve defaults by inserting types.Omitted() instances.
|
|
1381
|
+
|
|
1382
|
+
A (pysig, argument types) tuple is returned.
|
|
1383
|
+
"""
|
|
1384
|
+
|
|
1385
|
+
def normal_handler(index, param, value):
|
|
1386
|
+
return value
|
|
1387
|
+
|
|
1388
|
+
def default_handler(index, param, default):
|
|
1389
|
+
return types.Omitted(default)
|
|
1390
|
+
|
|
1391
|
+
def stararg_handler(index, param, values):
|
|
1392
|
+
return types.StarArgTuple(values)
|
|
1393
|
+
|
|
1394
|
+
# For now, we take argument values from the @jit function
|
|
1395
|
+
args = fold_arguments(
|
|
1396
|
+
self.pysig,
|
|
1397
|
+
args,
|
|
1398
|
+
kws,
|
|
1399
|
+
normal_handler,
|
|
1400
|
+
default_handler,
|
|
1401
|
+
stararg_handler,
|
|
1402
|
+
)
|
|
1403
|
+
return self.pysig, args
|
|
1404
|
+
|
|
1405
|
+
def compile(self, args, return_type):
|
|
1406
|
+
status, retval = self._compile_cached(args, return_type)
|
|
1407
|
+
if status:
|
|
1408
|
+
return retval
|
|
1409
|
+
else:
|
|
1410
|
+
raise retval
|
|
1411
|
+
|
|
1412
|
+
def _compile_cached(self, args, return_type):
|
|
1413
|
+
key = tuple(args), return_type
|
|
1414
|
+
try:
|
|
1415
|
+
return False, self._failed_cache[key]
|
|
1416
|
+
except KeyError:
|
|
1417
|
+
pass
|
|
1418
|
+
|
|
1419
|
+
try:
|
|
1420
|
+
retval = self._compile_core(args, return_type)
|
|
1421
|
+
except errors.TypingError as e:
|
|
1422
|
+
self._failed_cache[key] = e
|
|
1423
|
+
return False, e
|
|
1424
|
+
else:
|
|
1425
|
+
return True, retval
|
|
1426
|
+
|
|
1427
|
+
def _compile_core(self, args, return_type):
|
|
1428
|
+
flags = Flags()
|
|
1429
|
+
self.targetdescr.options.parse_as_flags(flags, self.targetoptions)
|
|
1430
|
+
flags = self._customize_flags(flags)
|
|
1431
|
+
|
|
1432
|
+
impl = self._get_implementation(args, {})
|
|
1433
|
+
cres = compile_extra(
|
|
1434
|
+
self.targetdescr.typing_context,
|
|
1435
|
+
self.targetdescr.target_context,
|
|
1436
|
+
impl,
|
|
1437
|
+
args=args,
|
|
1438
|
+
return_type=return_type,
|
|
1439
|
+
flags=flags,
|
|
1440
|
+
locals=self.locals,
|
|
1441
|
+
pipeline_class=self.pipeline_class,
|
|
1442
|
+
)
|
|
1443
|
+
# Check typing error if object mode is used
|
|
1444
|
+
if cres.typing_error is not None and not flags.enable_pyobject:
|
|
1445
|
+
raise cres.typing_error
|
|
1446
|
+
return cres
|
|
1447
|
+
|
|
1448
|
+
def get_globals_for_reduction(self):
|
|
1449
|
+
return serialize._get_function_globals_for_reduction(self.py_func)
|
|
1450
|
+
|
|
1451
|
+
def _get_implementation(self, args, kws):
|
|
1452
|
+
return self.py_func
|
|
1453
|
+
|
|
1454
|
+
def _customize_flags(self, flags):
|
|
1455
|
+
return flags
|
|
1456
|
+
|
|
1457
|
+
|
|
1458
|
+
class CUDADispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
|
|
1459
|
+
"""
|
|
1460
|
+
CUDA Dispatcher object. When configured and called, the dispatcher will
|
|
1461
|
+
specialize itself for the given arguments (if no suitable specialized
|
|
1462
|
+
version already exists) & compute capability, and launch on the device
|
|
1463
|
+
associated with the current context.
|
|
1464
|
+
|
|
1465
|
+
Dispatcher objects are not to be constructed by the user, but instead are
|
|
1466
|
+
created using the :func:`numba.cuda.jit` decorator.
|
|
1467
|
+
"""
|
|
1468
|
+
|
|
1469
|
+
# Whether to fold named arguments and default values. Default values are
|
|
1470
|
+
# presently unsupported on CUDA, so we can leave this as False in all
|
|
1471
|
+
# cases.
|
|
1472
|
+
_fold_args = False
|
|
1473
|
+
|
|
1474
|
+
targetdescr = cuda_target
|
|
1475
|
+
|
|
1476
|
+
def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler):
    """
    Parameters
    ----------
    py_func: function object to be compiled
    targetoptions: dict
        Target-specific config options (required; the ``nopython`` and
        ``extensions`` keys are consulted).
    pipeline_class: type numba.compiler.CompilerBase
        The compiler pipeline type.
    """
    self.typingctx = self.targetdescr.typing_context
    self.targetctx = self.targetdescr.target_context

    pysig = utils.pysignature(py_func)
    n_params = len(pysig.parameters)
    # Fallback is permitted unless the "nopython" target option is set.
    allow_fallback = not targetoptions.get("nopython", False)

    _DispatcherBase.__init__(
        self,
        n_params,
        py_func,
        pysig,
        allow_fallback,
        exact_match_required=False,
    )

    functools.update_wrapper(self, py_func)

    self.targetoptions = targetoptions
    # Caching is off until enable_caching() installs a CUDACache.
    self._cache = NullCache()
    self._compiler = _FunctionCompiler(
        py_func, self.targetdescr, targetoptions, pipeline_class
    )
    self._cache_hits = collections.Counter()
    self._cache_misses = collections.Counter()

    # Specialization state. A specialized CUDADispatcher is compiled for
    # exactly one set of argument types and bypasses some argument type
    # checking for faster kernel launches.
    self._specialized = False
    # Specialized dispatchers produced from this one, keyed by
    # (compute capability, argument types).
    self.specializations = {}
|
|
1524
|
+
|
|
1525
|
+
def dump(self, tab=""):
    """
    Print a debug dump of this dispatcher.

    The header and footer lines are prefixed with *tab*. Each compiled
    overload dumps itself with one extra space of indentation.
    """
    cls_name = type(self).__name__
    func_name = self.py_func.__name__
    print(f"{tab}DUMP {cls_name}[{func_name}, type code={self._type._code}]")
    nested_tab = tab + " "
    for cres in self.overloads.values():
        cres.dump(tab=nested_tab)
    print(f"{tab}END DUMP {cls_name}[{func_name}]")
|
|
1533
|
+
|
|
1534
|
+
@property
def _numba_type_(self):
    """
    The Numba type of this dispatcher, as consulted by Numba's ``typeof``.

    A new ``ext_types.CUDADispatcher`` wrapping ``self`` is created on
    every access.
    """
    dispatcher_type = ext_types.CUDADispatcher(self)
    return dispatcher_type
|
|
1537
|
+
|
|
1538
|
+
def enable_caching(self):
    """Replace the current cache (a ``NullCache`` after construction) with
    a ``CUDACache`` for ``self.py_func``."""
    cache = CUDACache(self.py_func)
    self._cache = cache
|
|
1540
|
+
|
|
1541
|
+
def __get__(self, obj, objtype=None):
|
|
1542
|
+
"""Allow a JIT function to be bound as a method to an object"""
|
|
1543
|
+
if obj is None: # Unbound method
|
|
1544
|
+
return self
|
|
1545
|
+
else: # Bound method
|
|
1546
|
+
return pytypes.MethodType(self, obj)
|
|
1547
|
+
|
|
1548
|
+
@functools.lru_cache(maxsize=128)
def configure(self, griddim, blockdim, stream=0, sharedmem=0):
    """
    Return a ``_LaunchConfiguration`` for this dispatcher.

    *griddim* and *blockdim* are normalized with
    ``normalize_kernel_dimensions``. Results are memoized by
    ``functools.lru_cache`` (128 entries, keyed on ``self`` and all
    arguments). Every argument must therefore be hashable, and the cache
    keeps strong references to the dispatchers and streams it has seen.
    """
    grid, block = normalize_kernel_dimensions(griddim, blockdim)
    return _LaunchConfiguration(self, grid, block, stream, sharedmem)
|
|
1552
|
+
|
|
1553
|
+
def __getitem__(self, args):
|
|
1554
|
+
if len(args) not in [2, 3, 4]:
|
|
1555
|
+
raise ValueError("must specify at least the griddim and blockdim")
|
|
1556
|
+
return self.configure(*args)
|
|
1557
|
+
|
|
1558
|
+
def forall(self, ntasks, tpb=0, stream=0, sharedmem=0):
    """Return a 1D-configured dispatcher covering *ntasks* tasks.

    The kernel is assumed to:

    - map the Global Thread ID ``cuda.grid(1)`` one-to-one onto tasks, and
    - check that the Global Thread ID is below ``ntasks`` and do nothing
      otherwise.

    :param ntasks: The number of tasks.
    :param tpb: The block size. When not supplied (0), an appropriate value
                is chosen.
    :param stream: The stream on which the configured dispatcher will be
                   launched.
    :param sharedmem: The number of bytes of dynamic shared memory the
                      kernel requires.
    :return: A configured dispatcher (a ``ForAll``), ready to be launched
             on a set of arguments."""
    configured = ForAll(
        self, ntasks, tpb=tpb, stream=stream, sharedmem=sharedmem
    )
    return configured
|
|
1579
|
+
|
|
1580
|
+
@property
def extensions(self):
    """
    The ``extensions`` target option, or None when it was not given.

    It is a list of objects providing a ``prepare_args`` function. When a
    specialized kernel is called, each argument is passed through the
    ``prepare_args`` functions, from the last object in the list to the
    first. ``prepare_args`` receives:

    - `ty` the numba type of the argument
    - `val` the argument value itself
    - `stream` the CUDA stream used for the current call to the kernel
    - `retr` a list of zero-arg functions to which post-call cleanup work
      may be appended.

    ``prepare_args`` must return a ``(ty, val)`` tuple, which is passed to
    the next extension to the left. Once every extension has run, the final
    ``(ty, val)`` goes through Numba's default argument marshalling logic.
    """
    options = self.targetoptions
    return options.get("extensions")
|
|
1600
|
+
|
|
1601
|
+
def __call__(self, *args, **kwargs):
    """Reject direct calls: a kernel must be launched through a launch
    configuration, e.g. ``kernel[griddim, blockdim](*args)``."""
    # An attempt to launch an unconfigured kernel
    raise ValueError(missing_launch_config_msg)
|
|
1604
|
+
|
|
1605
|
+
def call(self, args, griddim, blockdim, stream, sharedmem):
    """
    Compile if necessary and invoke this kernel with *args*.

    A specialized dispatcher launches its single overload directly.
    Otherwise the overload is obtained from the C dispatcher's
    ``_cuda_call``, which presumably compiles it when needed.
    """
    kernel = (
        next(iter(self.overloads.values()))
        if self.specialized
        else _dispatcher.Dispatcher._cuda_call(self, *args)
    )
    kernel.launch(args, griddim, blockdim, stream, sharedmem)
|
|
1615
|
+
|
|
1616
|
+
def _compile_for_args(self, *args, **kws):
|
|
1617
|
+
# Based on _DispatcherBase._compile_for_args.
|
|
1618
|
+
assert not kws
|
|
1619
|
+
argtypes = [self.typeof_pyval(a) for a in args]
|
|
1620
|
+
return self.compile(tuple(argtypes))
|
|
1621
|
+
|
|
1622
|
+
def typeof_pyval(self, val):
    """
    Resolve the Numba argument type of *val*.

    Based on _DispatcherBase.typeof_pyval, but also supports the CUDA Array
    Interface. When ``typeof`` raises ``ValueError`` and *val* has a
    non-None ``__cuda_array_interface__``, the type of the array built by
    ``cuda.from_cuda_array_interface`` is returned. Otherwise the original
    ``ValueError`` propagates.
    """
    try:
        return typeof(val, Purpose.argument)
    except ValueError:
        # The default is required so that objects without the attribute
        # fall through to re-raising the ValueError, not an AttributeError.
        interface = getattr(val, "__cuda_array_interface__", None)
        if interface is None:
            raise
        # When typing, we don't need to synchronize on the array's
        # stream - this is done when the kernel is launched.
        return typeof(
            cuda.from_cuda_array_interface(interface, sync=False),
            Purpose.argument,
        )
|
|
1640
|
+
|
|
1641
|
+
def specialize(self, *args):
    """
    Create a new instance of this dispatcher specialized for the given
    *args*.

    Specializations are cached per ``(compute capability, argument types)``
    of the current device. A new specialization:
    - reuses this dispatcher's target options and compiler pipeline class;
    - is compiled for exactly the argument types;
    - has further compilation disabled.

    Raises
    ------
    RuntimeError
        If this dispatcher is itself already specialized.
    """
    # Fail fast before querying the device or typing the arguments.
    if self.specialized:
        raise RuntimeError("Dispatcher already specialized")

    cc = get_current_device().compute_capability
    argtypes = tuple(self.typeof_pyval(a) for a in args)
    key = (cc, argtypes)

    cached = self.specializations.get(key)
    if cached is not None:
        return cached

    specialization = CUDADispatcher(
        self.py_func,
        targetoptions=self.targetoptions,
        # Keep a custom pipeline so the specialization compiles the same
        # way as this dispatcher.
        pipeline_class=self._compiler.pipeline_class,
    )
    specialization.compile(argtypes)
    specialization.disable_compile()
    specialization._specialized = True
    self.specializations[key] = specialization
    return specialization
|
|
1664
|
+
|
|
1665
|
+
@property
def specialized(self):
    """
    True if the Dispatcher has been specialized.

    ``specialize`` sets this on the dispatchers it creates, which have
    compilation disabled.
    """
    flag = self._specialized
    return flag
|
|
1671
|
+
|
|
1672
|
+
def get_regs_per_thread(self, signature=None):
    """
    Returns the number of registers used by each thread in this kernel for
    the device in the current context.

    :param signature: The signature of the compiled kernel to get register
                      usage for: either a signature object (its ``args``
                      are used) or a tuple of argument types, as used to
                      key ``self.overloads``. This may be omitted for a
                      specialized kernel.
    :return: The number of registers used by the compiled variant of the
             kernel for the given signature and current device. With no
             signature on a non-specialized kernel, a dict mapping each
             compiled signature to its register count.
    """
    if signature is not None:
        # Accept a bare tuple of argument types as well as a signature
        # object, like inspect_llvm / inspect_asm do.
        arg_types = getattr(signature, "args", signature)
        return self.overloads[arg_types].regs_per_thread
    if self.specialized:
        return next(iter(self.overloads.values())).regs_per_thread
    return {
        sig: overload.regs_per_thread
        for sig, overload in self.overloads.items()
    }
|
|
1692
|
+
|
|
1693
|
+
def get_const_mem_size(self, signature=None):
|
|
1694
|
+
"""
|
|
1695
|
+
Returns the size in bytes of constant memory used by this kernel for
|
|
1696
|
+
the device in the current context.
|
|
1697
|
+
|
|
1698
|
+
:param signature: The signature of the compiled kernel to get constant
|
|
1699
|
+
memory usage for. This may be omitted for a
|
|
1700
|
+
specialized kernel.
|
|
1701
|
+
:return: The size in bytes of constant memory allocated by the
|
|
1702
|
+
compiled variant of the kernel for the given signature and
|
|
1703
|
+
current device.
|
|
1704
|
+
"""
|
|
1705
|
+
if signature is not None:
|
|
1706
|
+
return self.overloads[signature.args].const_mem_size
|
|
1707
|
+
if self.specialized:
|
|
1708
|
+
return next(iter(self.overloads.values())).const_mem_size
|
|
1709
|
+
else:
|
|
1710
|
+
return {
|
|
1711
|
+
sig: overload.const_mem_size
|
|
1712
|
+
for sig, overload in self.overloads.items()
|
|
1713
|
+
}
|
|
1714
|
+
|
|
1715
|
+
def get_shared_mem_per_block(self, signature=None):
|
|
1716
|
+
"""
|
|
1717
|
+
Returns the size in bytes of statically allocated shared memory
|
|
1718
|
+
for this kernel.
|
|
1719
|
+
|
|
1720
|
+
:param signature: The signature of the compiled kernel to get shared
|
|
1721
|
+
memory usage for. This may be omitted for a
|
|
1722
|
+
specialized kernel.
|
|
1723
|
+
:return: The amount of shared memory allocated by the compiled variant
|
|
1724
|
+
of the kernel for the given signature and current device.
|
|
1725
|
+
"""
|
|
1726
|
+
if signature is not None:
|
|
1727
|
+
return self.overloads[signature.args].shared_mem_per_block
|
|
1728
|
+
if self.specialized:
|
|
1729
|
+
return next(iter(self.overloads.values())).shared_mem_per_block
|
|
1730
|
+
else:
|
|
1731
|
+
return {
|
|
1732
|
+
sig: overload.shared_mem_per_block
|
|
1733
|
+
for sig, overload in self.overloads.items()
|
|
1734
|
+
}
|
|
1735
|
+
|
|
1736
|
+
def get_max_threads_per_block(self, signature=None):
|
|
1737
|
+
"""
|
|
1738
|
+
Returns the maximum allowable number of threads per block
|
|
1739
|
+
for this kernel. Exceeding this threshold will result in
|
|
1740
|
+
the kernel failing to launch.
|
|
1741
|
+
|
|
1742
|
+
:param signature: The signature of the compiled kernel to get the max
|
|
1743
|
+
threads per block for. This may be omitted for a
|
|
1744
|
+
specialized kernel.
|
|
1745
|
+
:return: The maximum allowable threads per block for the compiled
|
|
1746
|
+
variant of the kernel for the given signature and current
|
|
1747
|
+
device.
|
|
1748
|
+
"""
|
|
1749
|
+
if signature is not None:
|
|
1750
|
+
return self.overloads[signature.args].max_threads_per_block
|
|
1751
|
+
if self.specialized:
|
|
1752
|
+
return next(iter(self.overloads.values())).max_threads_per_block
|
|
1753
|
+
else:
|
|
1754
|
+
return {
|
|
1755
|
+
sig: overload.max_threads_per_block
|
|
1756
|
+
for sig, overload in self.overloads.items()
|
|
1757
|
+
}
|
|
1758
|
+
|
|
1759
|
+
def get_local_mem_per_thread(self, signature=None):
|
|
1760
|
+
"""
|
|
1761
|
+
Returns the size in bytes of local memory per thread
|
|
1762
|
+
for this kernel.
|
|
1763
|
+
|
|
1764
|
+
:param signature: The signature of the compiled kernel to get local
|
|
1765
|
+
memory usage for. This may be omitted for a
|
|
1766
|
+
specialized kernel.
|
|
1767
|
+
:return: The amount of local memory allocated by the compiled variant
|
|
1768
|
+
of the kernel for the given signature and current device.
|
|
1769
|
+
"""
|
|
1770
|
+
if signature is not None:
|
|
1771
|
+
return self.overloads[signature.args].local_mem_per_thread
|
|
1772
|
+
if self.specialized:
|
|
1773
|
+
return next(iter(self.overloads.values())).local_mem_per_thread
|
|
1774
|
+
else:
|
|
1775
|
+
return {
|
|
1776
|
+
sig: overload.local_mem_per_thread
|
|
1777
|
+
for sig, overload in self.overloads.items()
|
|
1778
|
+
}
|
|
1779
|
+
|
|
1780
|
+
def get_call_template(self, args, kws):
|
|
1781
|
+
# Originally copied from _DispatcherBase.get_call_template. This
|
|
1782
|
+
# version deviates slightly from the _DispatcherBase version in order
|
|
1783
|
+
# to force casts when calling device functions. See e.g.
|
|
1784
|
+
# TestDeviceFunc.test_device_casting, added in PR #7496.
|
|
1785
|
+
"""
|
|
1786
|
+
Get a typing.ConcreteTemplate for this dispatcher and the given
|
|
1787
|
+
*args* and *kws* types. This allows resolution of the return type.
|
|
1788
|
+
|
|
1789
|
+
A (template, pysig, args, kws) tuple is returned.
|
|
1790
|
+
"""
|
|
1791
|
+
# Fold keyword arguments and resolve default values
|
|
1792
|
+
pysig, args = self.fold_argument_types(args, kws)
|
|
1793
|
+
kws = {}
|
|
1794
|
+
|
|
1795
|
+
# Ensure an exactly-matching overload is available if we can
|
|
1796
|
+
# compile. We proceed with the typing even if we can't compile
|
|
1797
|
+
# because we may be able to force a cast on the caller side.
|
|
1798
|
+
if self._can_compile:
|
|
1799
|
+
self.compile_device(tuple(args))
|
|
1800
|
+
|
|
1801
|
+
# Create function type for typing
|
|
1802
|
+
func_name = self.py_func.__name__
|
|
1803
|
+
name = "CallTemplate({0})".format(func_name)
|
|
1804
|
+
|
|
1805
|
+
call_template = typing.make_concrete_template(
|
|
1806
|
+
name, key=func_name, signatures=self.nopython_signatures
|
|
1807
|
+
)
|
|
1808
|
+
pysig = utils.pysignature(self.py_func)
|
|
1809
|
+
|
|
1810
|
+
return call_template, pysig, args, kws
|
|
1811
|
+
|
|
1812
|
+
def compile_device(self, args, return_type=None):
|
|
1813
|
+
"""Compile the device function for the given argument types.
|
|
1814
|
+
|
|
1815
|
+
Each signature is compiled once by caching the compiled function inside
|
|
1816
|
+
this object.
|
|
1817
|
+
|
|
1818
|
+
Returns the `CompileResult`.
|
|
1819
|
+
"""
|
|
1820
|
+
if args not in self.overloads:
|
|
1821
|
+
with self._compiling_counter:
|
|
1822
|
+
debug = self.targetoptions.get("debug")
|
|
1823
|
+
lineinfo = self.targetoptions.get("lineinfo")
|
|
1824
|
+
forceinline = self.targetoptions.get("forceinline")
|
|
1825
|
+
fastmath = self.targetoptions.get("fastmath")
|
|
1826
|
+
|
|
1827
|
+
nvvm_options = {
|
|
1828
|
+
"opt": 3 if self.targetoptions.get("opt") else 0,
|
|
1829
|
+
"fastmath": fastmath,
|
|
1830
|
+
}
|
|
1831
|
+
|
|
1832
|
+
if debug:
|
|
1833
|
+
nvvm_options["g"] = None
|
|
1834
|
+
|
|
1835
|
+
cc = get_current_device().compute_capability
|
|
1836
|
+
cres = compile_cuda(
|
|
1837
|
+
self.py_func,
|
|
1838
|
+
return_type,
|
|
1839
|
+
args,
|
|
1840
|
+
debug=debug,
|
|
1841
|
+
lineinfo=lineinfo,
|
|
1842
|
+
forceinline=forceinline,
|
|
1843
|
+
fastmath=fastmath,
|
|
1844
|
+
nvvm_options=nvvm_options,
|
|
1845
|
+
cc=cc,
|
|
1846
|
+
)
|
|
1847
|
+
self.overloads[args] = cres
|
|
1848
|
+
|
|
1849
|
+
cres.target_context.insert_user_function(
|
|
1850
|
+
cres.entry_point, cres.fndesc, [cres.library]
|
|
1851
|
+
)
|
|
1852
|
+
else:
|
|
1853
|
+
cres = self.overloads[args]
|
|
1854
|
+
|
|
1855
|
+
return cres
|
|
1856
|
+
|
|
1857
|
+
def add_overload(self, kernel, argtypes):
|
|
1858
|
+
c_sig = [a._code for a in argtypes]
|
|
1859
|
+
self._insert(c_sig, kernel, cuda=True)
|
|
1860
|
+
self.overloads[argtypes] = kernel
|
|
1861
|
+
|
|
1862
|
+
@global_compiler_lock
|
|
1863
|
+
def compile(self, sig):
|
|
1864
|
+
"""
|
|
1865
|
+
Compile and bind to the current context a version of this kernel
|
|
1866
|
+
specialized for the given signature.
|
|
1867
|
+
"""
|
|
1868
|
+
argtypes, return_type = sigutils.normalize_signature(sig)
|
|
1869
|
+
assert return_type is None or return_type == types.none
|
|
1870
|
+
|
|
1871
|
+
# Do we already have an in-memory compiled kernel?
|
|
1872
|
+
if self.specialized:
|
|
1873
|
+
return next(iter(self.overloads.values()))
|
|
1874
|
+
else:
|
|
1875
|
+
kernel = self.overloads.get(argtypes)
|
|
1876
|
+
if kernel is not None:
|
|
1877
|
+
return kernel
|
|
1878
|
+
|
|
1879
|
+
# Can we load from the disk cache?
|
|
1880
|
+
kernel = self._cache.load_overload(sig, self.targetctx)
|
|
1881
|
+
|
|
1882
|
+
if kernel is not None:
|
|
1883
|
+
self._cache_hits[sig] += 1
|
|
1884
|
+
else:
|
|
1885
|
+
# We need to compile a new kernel
|
|
1886
|
+
self._cache_misses[sig] += 1
|
|
1887
|
+
if not self._can_compile:
|
|
1888
|
+
raise RuntimeError("Compilation disabled")
|
|
1889
|
+
|
|
1890
|
+
kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
|
|
1891
|
+
# We call bind to force codegen, so that there is a cubin to cache
|
|
1892
|
+
kernel.bind()
|
|
1893
|
+
self._cache.save_overload(sig, kernel)
|
|
1894
|
+
|
|
1895
|
+
self.add_overload(kernel, argtypes)
|
|
1896
|
+
|
|
1897
|
+
return kernel
|
|
1898
|
+
|
|
1899
|
+
def get_compile_result(self, sig):
|
|
1900
|
+
"""Compile (if needed) and return the compilation result with the
|
|
1901
|
+
given signature.
|
|
1902
|
+
|
|
1903
|
+
Returns ``CompileResult``.
|
|
1904
|
+
Raises ``NumbaError`` if the signature is incompatible.
|
|
1905
|
+
"""
|
|
1906
|
+
atypes = tuple(sig.args)
|
|
1907
|
+
if atypes not in self.overloads:
|
|
1908
|
+
if self._can_compile:
|
|
1909
|
+
# Compiling may raise any NumbaError
|
|
1910
|
+
self.compile(atypes)
|
|
1911
|
+
else:
|
|
1912
|
+
msg = f"{sig} not available and compilation disabled"
|
|
1913
|
+
raise errors.TypingError(msg)
|
|
1914
|
+
return self.overloads[atypes]
|
|
1915
|
+
|
|
1916
|
+
def recompile(self):
|
|
1917
|
+
"""
|
|
1918
|
+
Recompile all signatures afresh.
|
|
1919
|
+
"""
|
|
1920
|
+
sigs = list(self.overloads)
|
|
1921
|
+
old_can_compile = self._can_compile
|
|
1922
|
+
# Ensure the old overloads are disposed of,
|
|
1923
|
+
# including compiled functions.
|
|
1924
|
+
self._make_finalizer()()
|
|
1925
|
+
self._reset_overloads()
|
|
1926
|
+
self._cache.flush()
|
|
1927
|
+
self._can_compile = True
|
|
1928
|
+
try:
|
|
1929
|
+
for sig in sigs:
|
|
1930
|
+
self.compile(sig)
|
|
1931
|
+
finally:
|
|
1932
|
+
self._can_compile = old_can_compile
|
|
1933
|
+
|
|
1934
|
+
@property
|
|
1935
|
+
def stats(self):
|
|
1936
|
+
return _CompileStats(
|
|
1937
|
+
cache_path=self._cache.cache_path,
|
|
1938
|
+
cache_hits=self._cache_hits,
|
|
1939
|
+
cache_misses=self._cache_misses,
|
|
1940
|
+
)
|
|
1941
|
+
|
|
1942
|
+
def get_metadata(self, signature=None):
|
|
1943
|
+
"""
|
|
1944
|
+
Obtain the compilation metadata for a given signature.
|
|
1945
|
+
"""
|
|
1946
|
+
if signature is not None:
|
|
1947
|
+
return self.overloads[signature].metadata
|
|
1948
|
+
else:
|
|
1949
|
+
return dict(
|
|
1950
|
+
(sig, self.overloads[sig].metadata) for sig in self.signatures
|
|
1951
|
+
)
|
|
1952
|
+
|
|
1953
|
+
def get_function_type(self):
|
|
1954
|
+
"""Return unique function type of dispatcher when possible, otherwise
|
|
1955
|
+
return None.
|
|
1956
|
+
|
|
1957
|
+
A Dispatcher instance has unique function type when it
|
|
1958
|
+
contains exactly one compilation result and its compilation
|
|
1959
|
+
has been disabled (via its disable_compile method).
|
|
1960
|
+
"""
|
|
1961
|
+
if not self._can_compile and len(self.overloads) == 1:
|
|
1962
|
+
cres = tuple(self.overloads.values())[0]
|
|
1963
|
+
return types.FunctionType(cres.signature)
|
|
1964
|
+
|
|
1965
|
+
def inspect_llvm(self, signature=None):
|
|
1966
|
+
"""
|
|
1967
|
+
Return the LLVM IR for this kernel.
|
|
1968
|
+
|
|
1969
|
+
:param signature: A tuple of argument types.
|
|
1970
|
+
:return: The LLVM IR for the given signature, or a dict of LLVM IR
|
|
1971
|
+
for all previously-encountered signatures.
|
|
1972
|
+
|
|
1973
|
+
"""
|
|
1974
|
+
device = self.targetoptions.get("device")
|
|
1975
|
+
if signature is not None:
|
|
1976
|
+
if device:
|
|
1977
|
+
return self.overloads[signature].library.get_llvm_str()
|
|
1978
|
+
else:
|
|
1979
|
+
return self.overloads[signature].inspect_llvm()
|
|
1980
|
+
else:
|
|
1981
|
+
if device:
|
|
1982
|
+
return {
|
|
1983
|
+
sig: overload.library.get_llvm_str()
|
|
1984
|
+
for sig, overload in self.overloads.items()
|
|
1985
|
+
}
|
|
1986
|
+
else:
|
|
1987
|
+
return {
|
|
1988
|
+
sig: overload.inspect_llvm()
|
|
1989
|
+
for sig, overload in self.overloads.items()
|
|
1990
|
+
}
|
|
1991
|
+
|
|
1992
|
+
def inspect_asm(self, signature=None):
|
|
1993
|
+
"""
|
|
1994
|
+
Return this kernel's PTX assembly code for for the device in the
|
|
1995
|
+
current context.
|
|
1996
|
+
|
|
1997
|
+
:param signature: A tuple of argument types.
|
|
1998
|
+
:return: The PTX code for the given signature, or a dict of PTX codes
|
|
1999
|
+
for all previously-encountered signatures.
|
|
2000
|
+
"""
|
|
2001
|
+
cc = get_current_device().compute_capability
|
|
2002
|
+
device = self.targetoptions.get("device")
|
|
2003
|
+
if signature is not None:
|
|
2004
|
+
if device:
|
|
2005
|
+
return self.overloads[signature].library.get_asm_str(cc)
|
|
2006
|
+
else:
|
|
2007
|
+
return self.overloads[signature].inspect_asm(cc)
|
|
2008
|
+
else:
|
|
2009
|
+
if device:
|
|
2010
|
+
return {
|
|
2011
|
+
sig: overload.library.get_asm_str(cc)
|
|
2012
|
+
for sig, overload in self.overloads.items()
|
|
2013
|
+
}
|
|
2014
|
+
else:
|
|
2015
|
+
return {
|
|
2016
|
+
sig: overload.inspect_asm(cc)
|
|
2017
|
+
for sig, overload in self.overloads.items()
|
|
2018
|
+
}
|
|
2019
|
+
|
|
2020
|
+
def inspect_lto_ptx(self, signature=None):
|
|
2021
|
+
"""
|
|
2022
|
+
Return link-time optimized PTX code for the given signature.
|
|
2023
|
+
|
|
2024
|
+
:param signature: A tuple of argument types.
|
|
2025
|
+
:return: The PTX code for the given signature, or a dict of PTX codes
|
|
2026
|
+
for all previously-encountered signatures.
|
|
2027
|
+
"""
|
|
2028
|
+
cc = get_current_device().compute_capability
|
|
2029
|
+
device = self.targetoptions.get("device")
|
|
2030
|
+
|
|
2031
|
+
if signature is not None:
|
|
2032
|
+
if device:
|
|
2033
|
+
return self.overloads[signature].library.get_lto_ptx(cc)
|
|
2034
|
+
else:
|
|
2035
|
+
return self.overloads[signature].inspect_lto_ptx(cc)
|
|
2036
|
+
else:
|
|
2037
|
+
if device:
|
|
2038
|
+
return {
|
|
2039
|
+
sig: overload.library.get_lto_ptx(cc)
|
|
2040
|
+
for sig, overload in self.overloads.items()
|
|
2041
|
+
}
|
|
2042
|
+
else:
|
|
2043
|
+
return {
|
|
2044
|
+
sig: overload.inspect_lto_ptx(cc)
|
|
2045
|
+
for sig, overload in self.overloads.items()
|
|
2046
|
+
}
|
|
2047
|
+
|
|
2048
|
+
def inspect_sass_cfg(self, signature=None):
|
|
2049
|
+
"""
|
|
2050
|
+
Return this kernel's CFG for the device in the current context.
|
|
2051
|
+
|
|
2052
|
+
:param signature: A tuple of argument types.
|
|
2053
|
+
:return: The CFG for the given signature, or a dict of CFGs
|
|
2054
|
+
for all previously-encountered signatures.
|
|
2055
|
+
|
|
2056
|
+
The CFG for the device in the current context is returned.
|
|
2057
|
+
|
|
2058
|
+
Requires nvdisasm to be available on the PATH.
|
|
2059
|
+
"""
|
|
2060
|
+
if self.targetoptions.get("device"):
|
|
2061
|
+
raise RuntimeError("Cannot get the CFG of a device function")
|
|
2062
|
+
|
|
2063
|
+
if signature is not None:
|
|
2064
|
+
return self.overloads[signature].inspect_sass_cfg()
|
|
2065
|
+
else:
|
|
2066
|
+
return {
|
|
2067
|
+
sig: defn.inspect_sass_cfg()
|
|
2068
|
+
for sig, defn in self.overloads.items()
|
|
2069
|
+
}
|
|
2070
|
+
|
|
2071
|
+
def inspect_sass(self, signature=None):
|
|
2072
|
+
"""
|
|
2073
|
+
Return this kernel's SASS assembly code for for the device in the
|
|
2074
|
+
current context.
|
|
2075
|
+
|
|
2076
|
+
:param signature: A tuple of argument types.
|
|
2077
|
+
:return: The SASS code for the given signature, or a dict of SASS codes
|
|
2078
|
+
for all previously-encountered signatures.
|
|
2079
|
+
|
|
2080
|
+
SASS for the device in the current context is returned.
|
|
2081
|
+
|
|
2082
|
+
Requires nvdisasm to be available on the PATH.
|
|
2083
|
+
"""
|
|
2084
|
+
if self.targetoptions.get("device"):
|
|
2085
|
+
raise RuntimeError("Cannot inspect SASS of a device function")
|
|
2086
|
+
|
|
2087
|
+
if signature is not None:
|
|
2088
|
+
return self.overloads[signature].inspect_sass()
|
|
2089
|
+
else:
|
|
2090
|
+
return {
|
|
2091
|
+
sig: defn.inspect_sass() for sig, defn in self.overloads.items()
|
|
2092
|
+
}
|
|
2093
|
+
|
|
2094
|
+
def inspect_types(self, file=None):
|
|
2095
|
+
"""
|
|
2096
|
+
Produce a dump of the Python source of this function annotated with the
|
|
2097
|
+
corresponding Numba IR and type information. The dump is written to
|
|
2098
|
+
*file*, or *sys.stdout* if *file* is *None*.
|
|
2099
|
+
"""
|
|
2100
|
+
if file is None:
|
|
2101
|
+
file = sys.stdout
|
|
2102
|
+
|
|
2103
|
+
for _, defn in self.overloads.items():
|
|
2104
|
+
defn.inspect_types(file=file)
|
|
2105
|
+
|
|
2106
|
+
@classmethod
|
|
2107
|
+
def _rebuild(cls, py_func, targetoptions):
|
|
2108
|
+
"""
|
|
2109
|
+
Rebuild an instance.
|
|
2110
|
+
"""
|
|
2111
|
+
instance = cls(py_func, targetoptions)
|
|
2112
|
+
return instance
|
|
2113
|
+
|
|
2114
|
+
def _reduce_states(self):
|
|
2115
|
+
"""
|
|
2116
|
+
Reduce the instance for serialization.
|
|
2117
|
+
Compiled definitions are discarded.
|
|
2118
|
+
"""
|
|
2119
|
+
return dict(py_func=self.py_func, targetoptions=self.targetoptions)
|
|
2120
|
+
|
|
2121
|
+
|
|
2122
|
+
class LiftedCode(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
|
|
2123
|
+
"""
|
|
2124
|
+
Implementation of the hidden dispatcher objects used for lifted code
|
|
2125
|
+
(a lifted loop is really compiled as a separate function).
|
|
2126
|
+
"""
|
|
2127
|
+
|
|
2128
|
+
_fold_args = False
|
|
2129
|
+
can_cache = False
|
|
2130
|
+
|
|
2131
|
+
def __init__(self, func_ir, typingctx, targetctx, flags, locals):
|
|
2132
|
+
self.func_ir = func_ir
|
|
2133
|
+
self.lifted_from = None
|
|
2134
|
+
|
|
2135
|
+
self.typingctx = typingctx
|
|
2136
|
+
self.targetctx = targetctx
|
|
2137
|
+
self.flags = flags
|
|
2138
|
+
self.locals = locals
|
|
2139
|
+
|
|
2140
|
+
_DispatcherBase.__init__(
|
|
2141
|
+
self,
|
|
2142
|
+
self.func_ir.arg_count,
|
|
2143
|
+
self.func_ir.func_id.func,
|
|
2144
|
+
self.func_ir.func_id.pysig,
|
|
2145
|
+
can_fallback=True,
|
|
2146
|
+
exact_match_required=False,
|
|
2147
|
+
)
|
|
2148
|
+
|
|
2149
|
+
def _reduce_states(self):
|
|
2150
|
+
"""
|
|
2151
|
+
Reduce the instance for pickling. This will serialize
|
|
2152
|
+
the original function as well the compilation options and
|
|
2153
|
+
compiled signatures, but not the compiled code itself.
|
|
2154
|
+
|
|
2155
|
+
NOTE: part of ReduceMixin protocol
|
|
2156
|
+
"""
|
|
2157
|
+
return dict(
|
|
2158
|
+
uuid=self._uuid,
|
|
2159
|
+
func_ir=self.func_ir,
|
|
2160
|
+
flags=self.flags,
|
|
2161
|
+
locals=self.locals,
|
|
2162
|
+
extras=self._reduce_extras(),
|
|
2163
|
+
)
|
|
2164
|
+
|
|
2165
|
+
def _reduce_extras(self):
|
|
2166
|
+
"""
|
|
2167
|
+
NOTE: sub-class can override to add extra states
|
|
2168
|
+
"""
|
|
2169
|
+
return {}
|
|
2170
|
+
|
|
2171
|
+
@classmethod
|
|
2172
|
+
def _rebuild(cls, uuid, func_ir, flags, locals, extras):
|
|
2173
|
+
"""
|
|
2174
|
+
Rebuild an Dispatcher instance after it was __reduce__'d.
|
|
2175
|
+
|
|
2176
|
+
NOTE: part of ReduceMixin protocol
|
|
2177
|
+
"""
|
|
2178
|
+
try:
|
|
2179
|
+
return cls._memo[uuid]
|
|
2180
|
+
except KeyError:
|
|
2181
|
+
pass
|
|
2182
|
+
|
|
2183
|
+
from numba.cuda.descriptor import cuda_target
|
|
2184
|
+
|
|
2185
|
+
typingctx = cuda_target.typing_context
|
|
2186
|
+
targetctx = cuda_target.target_context
|
|
2187
|
+
|
|
2188
|
+
self = cls(func_ir, typingctx, targetctx, flags, locals, **extras)
|
|
2189
|
+
self._set_uuid(uuid)
|
|
2190
|
+
return self
|
|
2191
|
+
|
|
2192
|
+
def get_source_location(self):
|
|
2193
|
+
"""Return the starting line number of the loop."""
|
|
2194
|
+
return self.func_ir.loc.line
|
|
2195
|
+
|
|
2196
|
+
def _pre_compile(self, args, return_type, flags):
|
|
2197
|
+
"""Pre-compile actions"""
|
|
2198
|
+
pass
|
|
2199
|
+
|
|
2200
|
+
@abstractmethod
|
|
2201
|
+
def compile(self, sig):
|
|
2202
|
+
"""Lifted code should implement a compilation method that will return
|
|
2203
|
+
a CompileResult.entry_point for the given signature."""
|
|
2204
|
+
pass
|
|
2205
|
+
|
|
2206
|
+
def _get_dispatcher_for_current_target(self):
|
|
2207
|
+
# Lifted code does not honor the target switch currently.
|
|
2208
|
+
# No work has been done to check if this can be allowed.
|
|
2209
|
+
return self
|
|
2210
|
+
|
|
2211
|
+
|
|
2212
|
+
class LiftedLoop(LiftedCode):
|
|
2213
|
+
def _pre_compile(self, args, return_type, flags):
|
|
2214
|
+
assert not flags.enable_looplift, "Enable looplift flags is on"
|
|
2215
|
+
|
|
2216
|
+
def compile(self, sig):
|
|
2217
|
+
with ExitStack() as scope:
|
|
2218
|
+
cres = None
|
|
2219
|
+
|
|
2220
|
+
def cb_compiler(dur):
|
|
2221
|
+
if cres is not None:
|
|
2222
|
+
self._callback_add_compiler_timer(dur, cres)
|
|
2223
|
+
|
|
2224
|
+
def cb_llvm(dur):
|
|
2225
|
+
if cres is not None:
|
|
2226
|
+
self._callback_add_llvm_timer(dur, cres)
|
|
2227
|
+
|
|
2228
|
+
scope.enter_context(
|
|
2229
|
+
ev.install_timer("numba:compiler_lock", cb_compiler)
|
|
2230
|
+
)
|
|
2231
|
+
scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm))
|
|
2232
|
+
scope.enter_context(global_compiler_lock)
|
|
2233
|
+
|
|
2234
|
+
# Use counter to track recursion compilation depth
|
|
2235
|
+
with self._compiling_counter:
|
|
2236
|
+
# XXX this is mostly duplicated from Dispatcher.
|
|
2237
|
+
flags = self.flags
|
|
2238
|
+
args, return_type = sigutils.normalize_signature(sig)
|
|
2239
|
+
|
|
2240
|
+
# Don't recompile if signature already exists
|
|
2241
|
+
# (e.g. if another thread compiled it before we got the lock)
|
|
2242
|
+
existing = self.overloads.get(tuple(args))
|
|
2243
|
+
if existing is not None:
|
|
2244
|
+
return existing.entry_point
|
|
2245
|
+
|
|
2246
|
+
self._pre_compile(args, return_type, flags)
|
|
2247
|
+
|
|
2248
|
+
# copy the flags, use nopython first
|
|
2249
|
+
npm_loop_flags = flags.copy()
|
|
2250
|
+
npm_loop_flags.force_pyobject = False
|
|
2251
|
+
|
|
2252
|
+
pyobject_loop_flags = flags.copy()
|
|
2253
|
+
pyobject_loop_flags.force_pyobject = True
|
|
2254
|
+
|
|
2255
|
+
# Clone IR to avoid (some of the) mutation in the rewrite pass
|
|
2256
|
+
cloned_func_ir_npm = self.func_ir.copy()
|
|
2257
|
+
cloned_func_ir_fbk = self.func_ir.copy()
|
|
2258
|
+
|
|
2259
|
+
ev_details = dict(
|
|
2260
|
+
dispatcher=self,
|
|
2261
|
+
args=args,
|
|
2262
|
+
return_type=return_type,
|
|
2263
|
+
)
|
|
2264
|
+
with ev.trigger_event("numba:compile", data=ev_details):
|
|
2265
|
+
# this emulates "object mode fall-back", try nopython, if it
|
|
2266
|
+
# fails, then try again in object mode.
|
|
2267
|
+
try:
|
|
2268
|
+
cres = compile_ir(
|
|
2269
|
+
typingctx=self.typingctx,
|
|
2270
|
+
targetctx=self.targetctx,
|
|
2271
|
+
func_ir=cloned_func_ir_npm,
|
|
2272
|
+
args=args,
|
|
2273
|
+
return_type=return_type,
|
|
2274
|
+
flags=npm_loop_flags,
|
|
2275
|
+
locals=self.locals,
|
|
2276
|
+
lifted=(),
|
|
2277
|
+
lifted_from=self.lifted_from,
|
|
2278
|
+
is_lifted_loop=True,
|
|
2279
|
+
)
|
|
2280
|
+
except errors.TypingError:
|
|
2281
|
+
cres = compile_ir(
|
|
2282
|
+
typingctx=self.typingctx,
|
|
2283
|
+
targetctx=self.targetctx,
|
|
2284
|
+
func_ir=cloned_func_ir_fbk,
|
|
2285
|
+
args=args,
|
|
2286
|
+
return_type=return_type,
|
|
2287
|
+
flags=pyobject_loop_flags,
|
|
2288
|
+
locals=self.locals,
|
|
2289
|
+
lifted=(),
|
|
2290
|
+
lifted_from=self.lifted_from,
|
|
2291
|
+
is_lifted_loop=True,
|
|
2292
|
+
)
|
|
2293
|
+
# Check typing error if object mode is used
|
|
2294
|
+
if cres.typing_error is not None:
|
|
2295
|
+
raise cres.typing_error
|
|
2296
|
+
self.add_overload(cres)
|
|
2297
|
+
return cres.entry_point
|
|
2298
|
+
|
|
2299
|
+
|
|
2300
|
+
class LiftedWith(LiftedCode):
|
|
2301
|
+
can_cache = True
|
|
2302
|
+
|
|
2303
|
+
def _reduce_extras(self):
|
|
2304
|
+
return dict(output_types=self.output_types)
|
|
2305
|
+
|
|
2306
|
+
@property
|
|
2307
|
+
def _numba_type_(self):
|
|
2308
|
+
return types.Dispatcher(self)
|
|
2309
|
+
|
|
2310
|
+
def get_call_template(self, args, kws):
|
|
2311
|
+
"""
|
|
2312
|
+
Get a typing.ConcreteTemplate for this dispatcher and the given
|
|
2313
|
+
*args* and *kws* types. This enables the resolving of the return type.
|
|
2314
|
+
|
|
2315
|
+
A (template, pysig, args, kws) tuple is returned.
|
|
2316
|
+
"""
|
|
2317
|
+
# Ensure an overload is available
|
|
2318
|
+
if self._can_compile:
|
|
2319
|
+
self.compile(tuple(args))
|
|
2320
|
+
|
|
2321
|
+
pysig = None
|
|
2322
|
+
# Create function type for typing
|
|
2323
|
+
func_name = self.py_func.__name__
|
|
2324
|
+
name = "CallTemplate({0})".format(func_name)
|
|
2325
|
+
# The `key` isn't really used except for diagnosis here,
|
|
2326
|
+
# so avoid keeping a reference to `cfunc`.
|
|
2327
|
+
call_template = typing.make_concrete_template(
|
|
2328
|
+
name, key=func_name, signatures=self.nopython_signatures
|
|
2329
|
+
)
|
|
2330
|
+
return call_template, pysig, args, kws
|
|
2331
|
+
|
|
2332
|
+
def compile(self, sig):
|
|
2333
|
+
# this is similar to LiftedLoop's compile but does not have the
|
|
2334
|
+
# "fallback" to object mode part.
|
|
2335
|
+
with ExitStack() as scope:
|
|
2336
|
+
cres = None
|
|
2337
|
+
|
|
2338
|
+
def cb_compiler(dur):
|
|
2339
|
+
if cres is not None:
|
|
2340
|
+
self._callback_add_compiler_timer(dur, cres)
|
|
2341
|
+
|
|
2342
|
+
def cb_llvm(dur):
|
|
2343
|
+
if cres is not None:
|
|
2344
|
+
self._callback_add_llvm_timer(dur, cres)
|
|
2345
|
+
|
|
2346
|
+
scope.enter_context(
|
|
2347
|
+
ev.install_timer("numba:compiler_lock", cb_compiler)
|
|
2348
|
+
)
|
|
2349
|
+
scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm))
|
|
2350
|
+
scope.enter_context(global_compiler_lock)
|
|
2351
|
+
|
|
2352
|
+
# Use counter to track recursion compilation depth
|
|
2353
|
+
with self._compiling_counter:
|
|
2354
|
+
# XXX this is mostly duplicated from Dispatcher.
|
|
2355
|
+
flags = self.flags
|
|
2356
|
+
args, return_type = sigutils.normalize_signature(sig)
|
|
2357
|
+
|
|
2358
|
+
# Don't recompile if signature already exists
|
|
2359
|
+
# (e.g. if another thread compiled it before we got the lock)
|
|
2360
|
+
existing = self.overloads.get(tuple(args))
|
|
2361
|
+
if existing is not None:
|
|
2362
|
+
return existing.entry_point
|
|
2363
|
+
|
|
2364
|
+
self._pre_compile(args, return_type, flags)
|
|
2365
|
+
|
|
2366
|
+
# Clone IR to avoid (some of the) mutation in the rewrite pass
|
|
2367
|
+
cloned_func_ir = self.func_ir.copy()
|
|
2368
|
+
|
|
2369
|
+
ev_details = dict(
|
|
2370
|
+
dispatcher=self,
|
|
2371
|
+
args=args,
|
|
2372
|
+
return_type=return_type,
|
|
2373
|
+
)
|
|
2374
|
+
with ev.trigger_event("numba:compile", data=ev_details):
|
|
2375
|
+
cres = compile_ir(
|
|
2376
|
+
typingctx=self.typingctx,
|
|
2377
|
+
targetctx=self.targetctx,
|
|
2378
|
+
func_ir=cloned_func_ir,
|
|
2379
|
+
args=args,
|
|
2380
|
+
return_type=return_type,
|
|
2381
|
+
flags=flags,
|
|
2382
|
+
locals=self.locals,
|
|
2383
|
+
lifted=(),
|
|
2384
|
+
lifted_from=self.lifted_from,
|
|
2385
|
+
is_lifted_loop=True,
|
|
2386
|
+
)
|
|
2387
|
+
|
|
2388
|
+
# Check typing error if object mode is used
|
|
2389
|
+
if (
|
|
2390
|
+
cres.typing_error is not None
|
|
2391
|
+
and not flags.enable_pyobject
|
|
2392
|
+
):
|
|
2393
|
+
raise cres.typing_error
|
|
2394
|
+
self.add_overload(cres)
|
|
2395
|
+
return cres.entry_point
|
|
2396
|
+
|
|
2397
|
+
|
|
2398
|
+
class ObjModeLiftedWith(LiftedWith):
|
|
2399
|
+
def __init__(self, *args, **kwargs):
|
|
2400
|
+
self.output_types = kwargs.pop("output_types", None)
|
|
2401
|
+
super(LiftedWith, self).__init__(*args, **kwargs)
|
|
2402
|
+
if not self.flags.force_pyobject:
|
|
2403
|
+
raise ValueError("expecting `flags.force_pyobject`")
|
|
2404
|
+
if self.output_types is None:
|
|
2405
|
+
raise TypeError("`output_types` must be provided")
|
|
2406
|
+
# switch off rewrites, they have no effect
|
|
2407
|
+
self.flags.no_rewrites = True
|
|
2408
|
+
|
|
2409
|
+
@property
|
|
2410
|
+
def _numba_type_(self):
|
|
2411
|
+
return types.ObjModeDispatcher(self)
|
|
2412
|
+
|
|
2413
|
+
def get_call_template(self, args, kws):
|
|
2414
|
+
"""
|
|
2415
|
+
Get a typing.ConcreteTemplate for this dispatcher and the given
|
|
2416
|
+
*args* and *kws* types. This enables the resolving of the return type.
|
|
2417
|
+
|
|
2418
|
+
A (template, pysig, args, kws) tuple is returned.
|
|
2419
|
+
"""
|
|
2420
|
+
assert not kws
|
|
2421
|
+
self._legalize_arg_types(args)
|
|
2422
|
+
# Coerce to object mode
|
|
2423
|
+
args = [types.ffi_forced_object] * len(args)
|
|
2424
|
+
|
|
2425
|
+
if self._can_compile:
|
|
2426
|
+
self.compile(tuple(args))
|
|
2427
|
+
|
|
2428
|
+
signatures = [typing.signature(self.output_types, *args)]
|
|
2429
|
+
pysig = None
|
|
2430
|
+
func_name = self.py_func.__name__
|
|
2431
|
+
name = "CallTemplate({0})".format(func_name)
|
|
2432
|
+
call_template = typing.make_concrete_template(
|
|
2433
|
+
name, key=func_name, signatures=signatures
|
|
2434
|
+
)
|
|
2435
|
+
|
|
2436
|
+
return call_template, pysig, args, kws
|
|
2437
|
+
|
|
2438
|
+
def _legalize_arg_types(self, args):
|
|
2439
|
+
for i, a in enumerate(args, start=1):
|
|
2440
|
+
if isinstance(a, types.List):
|
|
2441
|
+
msg = (
|
|
2442
|
+
"Does not support list type inputs into "
|
|
2443
|
+
"with-context for arg {}"
|
|
2444
|
+
)
|
|
2445
|
+
raise errors.TypingError(msg.format(i))
|
|
2446
|
+
elif isinstance(a, types.Dispatcher):
|
|
2447
|
+
msg = (
|
|
2448
|
+
"Does not support function type inputs into "
|
|
2449
|
+
"with-context for arg {}"
|
|
2450
|
+
)
|
|
2451
|
+
raise errors.TypingError(msg.format(i))
|
|
2452
|
+
|
|
2453
|
+
@global_compiler_lock
|
|
2454
|
+
def compile(self, sig):
|
|
2455
|
+
args, _ = sigutils.normalize_signature(sig)
|
|
2456
|
+
sig = (types.ffi_forced_object,) * len(args)
|
|
2457
|
+
return super().compile(sig)
|
|
2458
|
+
|
|
2459
|
+
|
|
2460
|
+
# Initialize typeof machinery
|
|
2461
|
+
_dispatcher.typeof_init(
|
|
2462
|
+
OmittedArg, dict((str(t), t._code) for t in types.number_domain)
|
|
2463
|
+
)
|