numba-cuda 0.22.0__cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _numba_cuda_redirector.pth +4 -0
- _numba_cuda_redirector.py +89 -0
- numba_cuda/VERSION +1 -0
- numba_cuda/__init__.py +6 -0
- numba_cuda/_version.py +11 -0
- numba_cuda/numba/cuda/__init__.py +70 -0
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +16394 -0
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +8112 -0
- numba_cuda/numba/cuda/api.py +580 -0
- numba_cuda/numba/cuda/api_util.py +76 -0
- numba_cuda/numba/cuda/args.py +72 -0
- numba_cuda/numba/cuda/bf16.py +397 -0
- numba_cuda/numba/cuda/cache_hints.py +287 -0
- numba_cuda/numba/cuda/cext/__init__.py +2 -0
- numba_cuda/numba/cuda/cext/_devicearray.cpp +159 -0
- numba_cuda/numba/cuda/cext/_devicearray.cpython-313-aarch64-linux-gnu.so +0 -0
- numba_cuda/numba/cuda/cext/_devicearray.h +29 -0
- numba_cuda/numba/cuda/cext/_dispatcher.cpp +1098 -0
- numba_cuda/numba/cuda/cext/_dispatcher.cpython-313-aarch64-linux-gnu.so +0 -0
- numba_cuda/numba/cuda/cext/_hashtable.cpp +532 -0
- numba_cuda/numba/cuda/cext/_hashtable.h +135 -0
- numba_cuda/numba/cuda/cext/_helperlib.c +71 -0
- numba_cuda/numba/cuda/cext/_helperlib.cpython-313-aarch64-linux-gnu.so +0 -0
- numba_cuda/numba/cuda/cext/_helpermod.c +82 -0
- numba_cuda/numba/cuda/cext/_pymodule.h +38 -0
- numba_cuda/numba/cuda/cext/_typeconv.cpp +206 -0
- numba_cuda/numba/cuda/cext/_typeconv.cpython-313-aarch64-linux-gnu.so +0 -0
- numba_cuda/numba/cuda/cext/_typeof.cpp +1159 -0
- numba_cuda/numba/cuda/cext/_typeof.h +19 -0
- numba_cuda/numba/cuda/cext/capsulethunk.h +111 -0
- numba_cuda/numba/cuda/cext/mviewbuf.c +385 -0
- numba_cuda/numba/cuda/cext/mviewbuf.cpython-313-aarch64-linux-gnu.so +0 -0
- numba_cuda/numba/cuda/cext/typeconv.cpp +212 -0
- numba_cuda/numba/cuda/cext/typeconv.hpp +101 -0
- numba_cuda/numba/cuda/cg.py +67 -0
- numba_cuda/numba/cuda/cgutils.py +1294 -0
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +541 -0
- numba_cuda/numba/cuda/compiler.py +1396 -0
- numba_cuda/numba/cuda/core/analysis.py +758 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/pretty_annotate.py +288 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +305 -0
- numba_cuda/numba/cuda/core/base.py +1332 -0
- numba_cuda/numba/cuda/core/boxing.py +1411 -0
- numba_cuda/numba/cuda/core/bytecode.py +728 -0
- numba_cuda/numba/cuda/core/byteflow.py +2346 -0
- numba_cuda/numba/cuda/core/caching.py +744 -0
- numba_cuda/numba/cuda/core/callconv.py +392 -0
- numba_cuda/numba/cuda/core/codegen.py +171 -0
- numba_cuda/numba/cuda/core/compiler.py +199 -0
- numba_cuda/numba/cuda/core/compiler_lock.py +85 -0
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +650 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/controlflow.py +989 -0
- numba_cuda/numba/cuda/core/entrypoints.py +57 -0
- numba_cuda/numba/cuda/core/environment.py +66 -0
- numba_cuda/numba/cuda/core/errors.py +917 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/generators.py +387 -0
- numba_cuda/numba/cuda/core/imputils.py +509 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1787 -0
- numba_cuda/numba/cuda/core/interpreter.py +3617 -0
- numba_cuda/numba/cuda/core/ir.py +1812 -0
- numba_cuda/numba/cuda/core/ir_utils.py +2638 -0
- numba_cuda/numba/cuda/core/optional.py +129 -0
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1859 -0
- numba_cuda/numba/cuda/core/registry.py +46 -0
- numba_cuda/numba/cuda/core/removerefctpass.py +123 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +91 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +41 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +189 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +100 -0
- numba_cuda/numba/cuda/core/sigutils.py +68 -0
- numba_cuda/numba/cuda/core/ssa.py +498 -0
- numba_cuda/numba/cuda/core/targetconfig.py +330 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +956 -0
- numba_cuda/numba/cuda/core/typed_passes.py +867 -0
- numba_cuda/numba/cuda/core/typeinfer.py +1950 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1979 -0
- numba_cuda/numba/cuda/cpython/builtins.py +1153 -0
- numba_cuda/numba/cuda/cpython/charseq.py +1218 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/enumimpl.py +103 -0
- numba_cuda/numba/cuda/cpython/iterators.py +167 -0
- numba_cuda/numba/cuda/cpython/listobj.py +1326 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1475 -0
- numba_cuda/numba/cuda/cpython/rangeobj.py +289 -0
- numba_cuda/numba/cuda/cpython/slicing.py +322 -0
- numba_cuda/numba/cuda/cpython/tupleobj.py +456 -0
- numba_cuda/numba/cuda/cpython/unicode.py +2865 -0
- numba_cuda/numba/cuda/cpython/unicode_support.py +1597 -0
- numba_cuda/numba/cuda/cpython/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/cpython/unsafe/numbers.py +64 -0
- numba_cuda/numba/cuda/cpython/unsafe/tuple.py +92 -0
- numba_cuda/numba/cuda/cuda_paths.py +691 -0
- numba_cuda/numba/cuda/cudadecl.py +543 -0
- numba_cuda/numba/cuda/cudadrv/__init__.py +14 -0
- numba_cuda/numba/cuda/cudadrv/devicearray.py +954 -0
- numba_cuda/numba/cuda/cudadrv/devices.py +249 -0
- numba_cuda/numba/cuda/cudadrv/driver.py +3238 -0
- numba_cuda/numba/cuda/cudadrv/drvapi.py +435 -0
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +562 -0
- numba_cuda/numba/cuda/cudadrv/enums.py +613 -0
- numba_cuda/numba/cuda/cudadrv/error.py +48 -0
- numba_cuda/numba/cuda/cudadrv/libs.py +220 -0
- numba_cuda/numba/cuda/cudadrv/linkable_code.py +184 -0
- numba_cuda/numba/cuda/cudadrv/mappings.py +14 -0
- numba_cuda/numba/cuda/cudadrv/ndarray.py +26 -0
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +193 -0
- numba_cuda/numba/cuda/cudadrv/nvvm.py +756 -0
- numba_cuda/numba/cuda/cudadrv/rtapi.py +13 -0
- numba_cuda/numba/cuda/cudadrv/runtime.py +34 -0
- numba_cuda/numba/cuda/cudaimpl.py +983 -0
- numba_cuda/numba/cuda/cudamath.py +149 -0
- numba_cuda/numba/cuda/datamodel/__init__.py +7 -0
- numba_cuda/numba/cuda/datamodel/cuda_manager.py +66 -0
- numba_cuda/numba/cuda/datamodel/cuda_models.py +1446 -0
- numba_cuda/numba/cuda/datamodel/cuda_packer.py +224 -0
- numba_cuda/numba/cuda/datamodel/cuda_registry.py +22 -0
- numba_cuda/numba/cuda/datamodel/cuda_testing.py +153 -0
- numba_cuda/numba/cuda/datamodel/manager.py +11 -0
- numba_cuda/numba/cuda/datamodel/models.py +9 -0
- numba_cuda/numba/cuda/datamodel/packer.py +9 -0
- numba_cuda/numba/cuda/datamodel/registry.py +11 -0
- numba_cuda/numba/cuda/datamodel/testing.py +11 -0
- numba_cuda/numba/cuda/debuginfo.py +997 -0
- numba_cuda/numba/cuda/decorators.py +294 -0
- numba_cuda/numba/cuda/descriptor.py +35 -0
- numba_cuda/numba/cuda/device_init.py +155 -0
- numba_cuda/numba/cuda/deviceufunc.py +1021 -0
- numba_cuda/numba/cuda/dispatcher.py +2463 -0
- numba_cuda/numba/cuda/errors.py +72 -0
- numba_cuda/numba/cuda/extending.py +697 -0
- numba_cuda/numba/cuda/flags.py +178 -0
- numba_cuda/numba/cuda/fp16.py +357 -0
- numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/initialize.py +24 -0
- numba_cuda/numba/cuda/intrinsics.py +531 -0
- numba_cuda/numba/cuda/itanium_mangler.py +214 -0
- numba_cuda/numba/cuda/kernels/__init__.py +2 -0
- numba_cuda/numba/cuda/kernels/reduction.py +265 -0
- numba_cuda/numba/cuda/kernels/transpose.py +65 -0
- numba_cuda/numba/cuda/libdevice.py +3386 -0
- numba_cuda/numba/cuda/libdevicedecl.py +20 -0
- numba_cuda/numba/cuda/libdevicefuncs.py +1060 -0
- numba_cuda/numba/cuda/libdeviceimpl.py +88 -0
- numba_cuda/numba/cuda/locks.py +19 -0
- numba_cuda/numba/cuda/lowering.py +1980 -0
- numba_cuda/numba/cuda/mathimpl.py +374 -0
- numba_cuda/numba/cuda/memory_management/__init__.py +4 -0
- numba_cuda/numba/cuda/memory_management/memsys.cu +99 -0
- numba_cuda/numba/cuda/memory_management/memsys.cuh +22 -0
- numba_cuda/numba/cuda/memory_management/nrt.cu +212 -0
- numba_cuda/numba/cuda/memory_management/nrt.cuh +48 -0
- numba_cuda/numba/cuda/memory_management/nrt.py +390 -0
- numba_cuda/numba/cuda/memory_management/nrt_context.py +438 -0
- numba_cuda/numba/cuda/misc/appdirs.py +594 -0
- numba_cuda/numba/cuda/misc/cffiimpl.py +24 -0
- numba_cuda/numba/cuda/misc/coverage_support.py +43 -0
- numba_cuda/numba/cuda/misc/dump_style.py +41 -0
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/misc/firstlinefinder.py +96 -0
- numba_cuda/numba/cuda/misc/gdb_hook.py +240 -0
- numba_cuda/numba/cuda/misc/literal.py +28 -0
- numba_cuda/numba/cuda/misc/llvm_pass_timings.py +412 -0
- numba_cuda/numba/cuda/misc/special.py +94 -0
- numba_cuda/numba/cuda/models.py +56 -0
- numba_cuda/numba/cuda/np/arraymath.py +5130 -0
- numba_cuda/numba/cuda/np/arrayobj.py +7635 -0
- numba_cuda/numba/cuda/np/extensions.py +11 -0
- numba_cuda/numba/cuda/np/linalg.py +3087 -0
- numba_cuda/numba/cuda/np/math/__init__.py +0 -0
- numba_cuda/numba/cuda/np/math/cmathimpl.py +558 -0
- numba_cuda/numba/cuda/np/math/mathimpl.py +487 -0
- numba_cuda/numba/cuda/np/math/numbers.py +1461 -0
- numba_cuda/numba/cuda/np/npdatetime.py +969 -0
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1808 -0
- numba_cuda/numba/cuda/np/npyimpl.py +1027 -0
- numba_cuda/numba/cuda/np/numpy_support.py +798 -0
- numba_cuda/numba/cuda/np/polynomial/__init__.py +4 -0
- numba_cuda/numba/cuda/np/polynomial/polynomial_core.py +242 -0
- numba_cuda/numba/cuda/np/polynomial/polynomial_functions.py +380 -0
- numba_cuda/numba/cuda/np/ufunc/__init__.py +4 -0
- numba_cuda/numba/cuda/np/ufunc/decorators.py +203 -0
- numba_cuda/numba/cuda/np/ufunc/sigparse.py +68 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +65 -0
- numba_cuda/numba/cuda/np/ufunc_db.py +1282 -0
- numba_cuda/numba/cuda/np/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/np/unsafe/ndarray.py +84 -0
- numba_cuda/numba/cuda/nvvmutils.py +254 -0
- numba_cuda/numba/cuda/printimpl.py +126 -0
- numba_cuda/numba/cuda/random.py +308 -0
- numba_cuda/numba/cuda/reshape_funcs.cu +156 -0
- numba_cuda/numba/cuda/serialize.py +267 -0
- numba_cuda/numba/cuda/simulator/__init__.py +63 -0
- numba_cuda/numba/cuda/simulator/_internal/__init__.py +4 -0
- numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
- numba_cuda/numba/cuda/simulator/api.py +179 -0
- numba_cuda/numba/cuda/simulator/bf16.py +4 -0
- numba_cuda/numba/cuda/simulator/compiler.py +38 -0
- numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +11 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +462 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devices.py +122 -0
- numba_cuda/numba/cuda/simulator/cudadrv/driver.py +66 -0
- numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +7 -0
- numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +7 -0
- numba_cuda/numba/cuda/simulator/cudadrv/error.py +10 -0
- numba_cuda/numba/cuda/simulator/cudadrv/libs.py +10 -0
- numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +61 -0
- numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +11 -0
- numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +32 -0
- numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +22 -0
- numba_cuda/numba/cuda/simulator/dispatcher.py +11 -0
- numba_cuda/numba/cuda/simulator/kernel.py +320 -0
- numba_cuda/numba/cuda/simulator/kernelapi.py +509 -0
- numba_cuda/numba/cuda/simulator/memory_management/__init__.py +4 -0
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +21 -0
- numba_cuda/numba/cuda/simulator/reduction.py +19 -0
- numba_cuda/numba/cuda/simulator/tests/support.py +4 -0
- numba_cuda/numba/cuda/simulator/vector_types.py +65 -0
- numba_cuda/numba/cuda/simulator_init.py +18 -0
- numba_cuda/numba/cuda/stubs.py +624 -0
- numba_cuda/numba/cuda/target.py +505 -0
- numba_cuda/numba/cuda/testing.py +347 -0
- numba_cuda/numba/cuda/tests/__init__.py +62 -0
- numba_cuda/numba/cuda/tests/benchmarks/__init__.py +0 -0
- numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py +119 -0
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/core/serialize_usecases.py +113 -0
- numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +83 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +371 -0
- numba_cuda/numba/cuda/tests/cudadrv/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +147 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +161 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +397 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +24 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +180 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +313 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +191 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +621 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +247 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +100 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +200 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_events.py +53 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +72 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_init.py +138 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +43 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +15 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_linkable_code.py +58 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +348 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +128 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +301 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +174 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +28 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +185 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +39 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +23 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +38 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +48 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +44 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +127 -0
- numba_cuda/numba/cuda/tests/cudapy/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +231 -0
- numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +50 -0
- numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +36 -0
- numba_cuda/numba/cuda/tests/cudapy/complex_usecases.py +116 -0
- numba_cuda/numba/cuda/tests/cudapy/enum_usecases.py +59 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +62 -0
- numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +28 -0
- numba_cuda/numba/cuda/tests/cudapy/overload_usecases.py +33 -0
- numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +104 -0
- numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +47 -0
- numba_cuda/numba/cuda/tests/cudapy/test_analysis.py +1122 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +344 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +268 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +203 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +63 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +360 -0
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1815 -0
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +599 -0
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +377 -0
- numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +160 -0
- numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +27 -0
- numba_cuda/numba/cuda/tests/cudapy/test_byteflow.py +98 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py +210 -0
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +683 -0
- numba_cuda/numba/cuda/tests/cudapy/test_casting.py +265 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +42 -0
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +718 -0
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +370 -0
- numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +23 -0
- numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +142 -0
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +178 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +193 -0
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +131 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +438 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +94 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +101 -0
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +105 -0
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +978 -0
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +476 -0
- numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +500 -0
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +820 -0
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +152 -0
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +111 -0
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +170 -0
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1088 -0
- numba_cuda/numba/cuda/tests/cudapy/test_extending_types.py +71 -0
- numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +265 -0
- numba_cuda/numba/cuda/tests/cudapy/test_flow_control.py +1433 -0
- numba_cuda/numba/cuda/tests/cudapy/test_forall.py +57 -0
- numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +34 -0
- numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +69 -0
- numba_cuda/numba/cuda/tests/cudapy/test_globals.py +62 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +474 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +167 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +92 -0
- numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +39 -0
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +170 -0
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +255 -0
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1219 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +263 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ir.py +598 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +276 -0
- numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +101 -0
- numba_cuda/numba/cuda/tests/cudapy/test_lang.py +68 -0
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +123 -0
- numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +194 -0
- numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +220 -0
- numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +173 -0
- numba_cuda/numba/cuda/tests/cudapy/test_make_function_to_jit_function.py +364 -0
- numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +47 -0
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +842 -0
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +76 -0
- numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +78 -0
- numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +25 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +145 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +39 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +82 -0
- numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +53 -0
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +504 -0
- numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +93 -0
- numba_cuda/numba/cuda/tests/cudapy/test_overload.py +402 -0
- numba_cuda/numba/cuda/tests/cudapy/test_powi.py +128 -0
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +193 -0
- numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +37 -0
- numba_cuda/numba/cuda/tests/cudapy/test_random.py +117 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +614 -0
- numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +94 -0
- numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +86 -0
- numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +40 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +457 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +233 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +454 -0
- numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +56 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +277 -0
- numba_cuda/numba/cuda/tests/cudapy/test_tracing.py +200 -0
- numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +90 -0
- numba_cuda/numba/cuda/tests/cudapy/test_typeconv.py +333 -0
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +585 -0
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +42 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +485 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +312 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +23 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +183 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +40 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +40 -0
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +206 -0
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +446 -0
- numba_cuda/numba/cuda/tests/cudasim/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/cudasim/support.py +9 -0
- numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +111 -0
- numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
- numba_cuda/numba/cuda/tests/data/cta_barrier.cu +28 -0
- numba_cuda/numba/cuda/tests/data/cuda_include.cu +10 -0
- numba_cuda/numba/cuda/tests/data/error.cu +12 -0
- numba_cuda/numba/cuda/tests/data/include/add.cuh +8 -0
- numba_cuda/numba/cuda/tests/data/jitlink.cu +28 -0
- numba_cuda/numba/cuda/tests/data/jitlink.ptx +49 -0
- numba_cuda/numba/cuda/tests/data/warn.cu +12 -0
- numba_cuda/numba/cuda/tests/doc_examples/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +54 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +8 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +14 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +86 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +68 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +81 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +141 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +160 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +180 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +119 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_random.py +66 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +80 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +206 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +53 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +76 -0
- numba_cuda/numba/cuda/tests/nocuda/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +452 -0
- numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +48 -0
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +63 -0
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +252 -0
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +59 -0
- numba_cuda/numba/cuda/tests/nrt/__init__.py +9 -0
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +387 -0
- numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +124 -0
- numba_cuda/numba/cuda/tests/support.py +900 -0
- numba_cuda/numba/cuda/typeconv/__init__.py +4 -0
- numba_cuda/numba/cuda/typeconv/castgraph.py +137 -0
- numba_cuda/numba/cuda/typeconv/rules.py +63 -0
- numba_cuda/numba/cuda/typeconv/typeconv.py +121 -0
- numba_cuda/numba/cuda/types/__init__.py +233 -0
- numba_cuda/numba/cuda/types/__init__.pyi +167 -0
- numba_cuda/numba/cuda/types/abstract.py +9 -0
- numba_cuda/numba/cuda/types/common.py +9 -0
- numba_cuda/numba/cuda/types/containers.py +9 -0
- numba_cuda/numba/cuda/types/cuda_abstract.py +533 -0
- numba_cuda/numba/cuda/types/cuda_common.py +110 -0
- numba_cuda/numba/cuda/types/cuda_containers.py +971 -0
- numba_cuda/numba/cuda/types/cuda_function_type.py +230 -0
- numba_cuda/numba/cuda/types/cuda_functions.py +798 -0
- numba_cuda/numba/cuda/types/cuda_iterators.py +120 -0
- numba_cuda/numba/cuda/types/cuda_misc.py +569 -0
- numba_cuda/numba/cuda/types/cuda_npytypes.py +690 -0
- numba_cuda/numba/cuda/types/cuda_scalars.py +280 -0
- numba_cuda/numba/cuda/types/ext_types.py +101 -0
- numba_cuda/numba/cuda/types/function_type.py +11 -0
- numba_cuda/numba/cuda/types/functions.py +9 -0
- numba_cuda/numba/cuda/types/iterators.py +9 -0
- numba_cuda/numba/cuda/types/misc.py +9 -0
- numba_cuda/numba/cuda/types/npytypes.py +9 -0
- numba_cuda/numba/cuda/types/scalars.py +9 -0
- numba_cuda/numba/cuda/typing/__init__.py +19 -0
- numba_cuda/numba/cuda/typing/arraydecl.py +939 -0
- numba_cuda/numba/cuda/typing/asnumbatype.py +130 -0
- numba_cuda/numba/cuda/typing/bufproto.py +70 -0
- numba_cuda/numba/cuda/typing/builtins.py +1209 -0
- numba_cuda/numba/cuda/typing/cffi_utils.py +219 -0
- numba_cuda/numba/cuda/typing/cmathdecl.py +47 -0
- numba_cuda/numba/cuda/typing/collections.py +138 -0
- numba_cuda/numba/cuda/typing/context.py +782 -0
- numba_cuda/numba/cuda/typing/ctypes_utils.py +125 -0
- numba_cuda/numba/cuda/typing/dictdecl.py +63 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/listdecl.py +147 -0
- numba_cuda/numba/cuda/typing/mathdecl.py +158 -0
- numba_cuda/numba/cuda/typing/npdatetime.py +322 -0
- numba_cuda/numba/cuda/typing/npydecl.py +749 -0
- numba_cuda/numba/cuda/typing/setdecl.py +115 -0
- numba_cuda/numba/cuda/typing/templates.py +1446 -0
- numba_cuda/numba/cuda/typing/typeof.py +301 -0
- numba_cuda/numba/cuda/ufuncs.py +746 -0
- numba_cuda/numba/cuda/utils.py +724 -0
- numba_cuda/numba/cuda/vector_types.py +214 -0
- numba_cuda/numba/cuda/vectorizers.py +260 -0
- numba_cuda-0.22.0.dist-info/METADATA +109 -0
- numba_cuda-0.22.0.dist-info/RECORD +487 -0
- numba_cuda-0.22.0.dist-info/WHEEL +6 -0
- numba_cuda-0.22.0.dist-info/licenses/LICENSE +26 -0
- numba_cuda-0.22.0.dist-info/licenses/LICENSE.numba +24 -0
- numba_cuda-0.22.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from numba.cuda import utils
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DelayedRegistry(utils.UniqueDict):
    """
    A unique dictionary but with deferred initialisation of the values.

    Attributes
    ----------
    ondemand:
        A dictionary of key -> callable. The callable is invoked the first
        time the key is looked up through ``__getitem__``; its result is
        stored in this registry and the entry is then dropped from
        ``ondemand``. It is used for part of a deferred initialization
        strategy.
    key_type, value_type:
        Optional constraints taken from the keyword arguments of the same
        name (``None`` when not given). When set, they are asserted against
        every key / value passed to ``__setitem__``.
    """

    def __init__(self, *args, **kws):
        # Strip our own options before forwarding the rest to the base class.
        key_type = kws.pop("key_type", None)
        value_type = kws.pop("value_type", None)

        self.ondemand = utils.UniqueDict()
        self.key_type = key_type
        self.value_type = value_type
        # Truthy when at least one of the two constraints is set.
        self._type_check = key_type or value_type
        super().__init__(*args, **kws)

    def __getitem__(self, item):
        """Return the value for ``item``, materialising it first if pending."""
        try:
            factory = self.ondemand[item]
        except KeyError:
            # Nothing deferred under this key; fall through to a plain lookup.
            pass
        else:
            # Store the computed value first, then forget the factory, so a
            # failing store leaves the deferred entry in place.
            self[item] = factory()
            del self.ondemand[item]
        return super().__getitem__(item)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` after the optional type assertions."""
        if self._type_check:

            def verify(obj, expected):
                # A class constraint means ``obj`` itself must be a class
                # having ``expected`` in its MRO; any other constraint is
                # handed to ``isinstance`` as-is.
                if isinstance(expected, type):
                    assert expected in obj.__mro__, (obj, expected)
                else:
                    assert isinstance(obj, expected), (obj, expected)

            constraints = ((key, self.key_type), (value, self.value_type))
            for obj, expected in constraints:
                if expected is not None:
                    verify(obj, expected)
        return super().__setitem__(key, value)
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Implement a rewrite pass on an LLVM module to remove unnecessary
|
|
6
|
+
refcount operations.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from llvmlite.ir.transforms import CallVisitor
|
|
10
|
+
|
|
11
|
+
from numba.cuda import types
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class _MarkNrtCallVisitor(CallVisitor):
    """
    A pass that collects every call to NRT_incref and NRT_decref.

    After visiting, ``marked`` holds the call instructions whose callee name
    is listed in ``_accepted_nrtfns``.
    """

    def __init__(self):
        # Call instructions found so far.
        self.marked = set()

    def visit_Call(self, instr):
        # Callees without a ``name`` attribute fall back to "" and so never
        # match an accepted NRT function name.
        callee_name = getattr(instr.callee, "name", "")
        if callee_name in _accepted_nrtfns:
            self.marked.add(instr)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _rewrite_function(function):
    """
    Strip the marked NRT refcount calls out of ``function`` in place.

    The function is first scanned with ``_MarkNrtCallVisitor``; every
    instruction it marked is then removed from its basic block.
    """
    # Mark NRT usage
    visitor = _MarkNrtCallVisitor()
    visitor.visit_Function(function)
    marked = visitor.marked

    # Remove NRT usage
    for block in function.basic_blocks:
        # Snapshot the victims first so the instruction list is not mutated
        # while it is being scanned.
        doomed = [inst for inst in block.instructions if inst in marked]
        for inst in doomed:
            block.instructions.remove(inst)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Names of the NRT functions whose calls this module marks for removal; any
# other function whose name starts with "NRT_" makes the rewrite illegal.
_accepted_nrtfns = ("NRT_incref", "NRT_decref")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _legalize(module, dmm, fndesc):
    """
    Decide whether the refcount-removal rewrite may be applied.

    Parameters
    ----------
    module:
        Object providing ``get_named_metadata(name)`` and ``functions``.
    dmm:
        Mapping from a type to its data model; each model must provide
        ``contains_nrt_meminfo()``.
    fndesc:
        Object providing ``argtypes``, ``restype`` and ``calltypes``.

    Returns
    -------
    bool
        True if the module is legal for the rewrite pass that removes
        unnecessary refcounts, False otherwise.
    """

    def valid_output(ty):
        """A valid output is any type that does not need refcounting."""
        return not dmm[ty].contains_nrt_meminfo()

    def valid_input(ty):
        """A valid input is any valid output, or an Array."""
        return valid_output(ty) or isinstance(ty, types.Array)

    # Ensure no reference to function marked as
    # "numba_args_may_always_need_nrt"
    try:
        nmd = module.get_named_metadata("numba_args_may_always_need_nrt")
    except KeyError:
        # Nothing marked
        has_marked = False
    else:
        # The pass is illegal for this compilation unit as soon as at least
        # one function carries the marker.
        has_marked = len(nmd.operands) > 0
    if has_marked:
        return False

    # More legalization based on function type
    argtypes = fndesc.argtypes
    restype = fndesc.restype
    calltypes = fndesc.calltypes

    # Legalize function arguments
    if not all(valid_input(argty) for argty in argtypes):
        return False

    # Legalize function return
    if not valid_output(restype):
        return False

    # Legalize all called functions; entries that are None are skipped.
    if any(
        callty is not None and not valid_output(callty.return_type)
        for callty in calltypes.values()
    ):
        return False

    # Ensure no allocation: the only NRT_* functions tolerated in the module
    # are the accepted incref/decref ones.
    if any(
        fn.name.startswith("NRT_") and fn.name not in _accepted_nrtfns
        for fn in module.functions
    ):
        return False

    return True
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def remove_unnecessary_nrt_usage(function, context, fndesc):
    """
    Strip NRT incref/decref calls from an LLVM function that does not need
    them.

    High-level type information decides whether the function needs NRT.
    A function that does not need it:

    - does not return array object(s);
    - takes no refcounted arguments other than arrays;
    - calls no function that returns a refcounted object.

    Such a function neither captures nor creates references that would
    extend the lifetime of a refcounted object beyond its own lifetime.

    The rewrite happens in place. Returns True if the function was
    rewritten and False if the rewrite was not legal.
    """
    data_models = context.data_model_manager
    if not _legalize(function.module, data_models, fndesc):
        return False
    _rewrite_function(function)
    return True
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
A subpackage hosting Numba IR rewrite passes.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .registry import register_rewrite, rewrite_registry, Rewrite
|
|
9
|
+
|
|
10
|
+
# Register various built-in rewrite passes
|
|
11
|
+
from numba.cuda.core.rewrites import (
|
|
12
|
+
static_getitem,
|
|
13
|
+
static_raise,
|
|
14
|
+
static_binop,
|
|
15
|
+
ir_print,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Public names of this subpackage: the built-in rewrite modules imported
# above, followed by the registry API re-exported from ``.registry``.
__all__ = (
    "static_getitem",
    "static_raise",
    "static_binop",
    "ir_print",
    "register_rewrite",
    "rewrite_registry",
    "Rewrite",
)
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from numba.cuda.core import errors
|
|
5
|
+
from numba.cuda.core import ir
|
|
6
|
+
from numba.cuda.core.rewrites import register_rewrite, Rewrite
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@register_rewrite("before-inference")
class RewritePrintCalls(Rewrite):
    """
    Turn calls to the global print() function into dedicated IR Print nodes.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Collect the assignments in *block* whose right-hand side is a call
        to print(). Returns True if at least one was found.

        Raises errors.UnsupportedError when such a call uses keyword
        arguments.
        """
        found = {}
        self.prints = found
        self.block = block
        for assign in block.find_insts(ir.Assign):
            rhs = assign.value
            if not isinstance(rhs, ir.Expr) or rhs.op != "call":
                continue
            try:
                target_fn = func_ir.infer_constant(rhs.func)
            except errors.ConstantInferenceError:
                # The callee is not a compile-time constant: not our print.
                continue
            if target_fn is not print:
                continue
            if rhs.kws:
                # Only positional args are supported
                msg = (
                    "Numba's print() function implementation does not "
                    "support keyword arguments."
                )
                raise errors.UnsupportedError(msg, assign.loc)
            found[assign] = rhs
        return bool(found)

    def apply(self):
        """
        Build a new block in which each `var = call <print function>(...)`
        is replaced by `print(...)` followed by `var = const(None)`.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            if stmt not in self.prints:
                rewritten.append(stmt)
                continue
            call = self.prints[stmt]
            rewritten.append(
                ir.Print(args=call.args, vararg=call.vararg, loc=call.loc)
            )
            # The call's result was assigned to a variable; keep that
            # variable defined by binding it to None.
            none_const = ir.Const(None, loc=call.loc)
            rewritten.append(
                ir.Assign(value=none_const, target=stmt.target, loc=stmt.loc)
            )
        return rewritten
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@register_rewrite("before-inference")
class DetectConstPrintArguments(Rewrite):
    """
    Find print() nodes with constant arguments and record those constants.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        For each Print node in *block* that has no constants recorded yet,
        map argument position to inferred constant value. Returns True if
        any constant was found.
        """
        detected = {}
        self.consts = detected
        self.block = block
        for print_node in block.find_insts(ir.Print):
            if print_node.consts:
                # Already rewritten
                continue
            for position, arg in enumerate(print_node.args):
                try:
                    value = func_ir.infer_constant(arg)
                except errors.ConstantInferenceError:
                    # Not a constant: leave this argument alone.
                    pass
                else:
                    per_node = detected.setdefault(print_node, {})
                    per_node[position] = value

        return bool(detected)

    def apply(self):
        """
        Attach the constants found by match() to their Print nodes and
        return the (same) block.
        """
        detected = self.consts
        for stmt in self.block.body:
            if stmt in detected:
                stmt.consts = detected[stmt]
        return self.block
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
|
|
6
|
+
from numba.cuda import config
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Rewrite(object):
    """Abstract base class that every Numba rewrite derives from."""

    def __init__(self, state=None):
        """Accept (and ignore) an optional *state* argument."""

    def match(self, func_ir, block, typemap, calltypes) -> bool:
        """Return whether *block* contains terms this rewrite applies to.

        Subclasses override this; the base implementation never matches.
        """
        return False

    def apply(self):
        """Return the rewritten IR basic block after a successful match.

        Subclasses must override this; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError("Abstract Rewrite.apply() called!")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class RewriteRegistry(object):
    """Registry mapping each rewrite kind to its registered Rewrite classes."""

    # The only kinds a rewrite may be registered or applied under.
    _kinds = frozenset(["before-inference", "after-inference"])

    def __init__(self):
        """Create an empty registry; each kind starts with an empty list."""
        self.rewrites = defaultdict(list)

    def register(self, kind):
        """
        Return a class decorator that adds a Rewrite subclass to the
        registry under *kind*.

        Raises KeyError for an unknown *kind*; the decorator raises
        TypeError when applied to something that is not a Rewrite subclass.
        """
        if kind not in self._kinds:
            raise KeyError("invalid kind %r" % (kind,))

        def do_register(rewrite_cls):
            if issubclass(rewrite_cls, Rewrite):
                self.rewrites[kind].append(rewrite_cls)
                return rewrite_cls
            raise TypeError(
                "{0} is not a subclass of Rewrite".format(rewrite_cls)
            )

        return do_register

    def apply(self, kind, state):
        """
        Exhaustively apply every rewrite registered under *kind* to every
        basic block of ``state.func_ir``, replacing blocks in place.

        *state* must provide ``func_ir``, ``typemap`` and ``calltypes``.
        """
        assert kind in self._kinds
        blocks = state.func_ir.blocks
        originals = blocks.copy()
        for rewrite_cls in self.rewrites[kind]:
            # Exhaustively apply a rewrite until it stops matching.
            rewrite = rewrite_cls(state)
            pending = list(blocks.items())
            while pending:
                label, candidate = pending.pop()
                if not rewrite.match(
                    state.func_ir, candidate, state.typemap, state.calltypes
                ):
                    continue
                if config.DEBUG or config.DUMP_IR:
                    print("_" * 70)
                    print("REWRITING (%s):" % rewrite_cls.__name__)
                    candidate.dump()
                    print("_" * 60)
                replacement = rewrite.apply()
                blocks[label] = replacement
                # Re-queue the new block so the same rewrite can match again.
                pending.append((label, replacement))
                if config.DEBUG or config.DUMP_IR:
                    replacement.dump()
                    print("_" * 70)

        # If any blocks were changed, perform a sanity check.
        for label, current in blocks.items():
            if current != originals[label]:
                current.verify()

        # Some passes, e.g. _inline_const_arraycall are known to occasionally
        # do invalid things WRT ir.Del, others, e.g. RewriteArrayExprs do valid
        # things with ir.Del, but the placement is not optimal. The lines below
        # fix-up the IR so that ref counts are valid and optimally placed,
        # see #4093 for context. This has to be run here opposed to in
        # apply() as the CFG needs computing so full IR is needed.
        from numba.cuda.core import postproc

        postproc.PostProcessor(state.func_ir).run()
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Module-level registry instance, and its ``register`` method exposed under
# the name used as a class decorator by the rewrite modules.
rewrite_registry = RewriteRegistry()
register_rewrite = rewrite_registry.register
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from numba.cuda.core import errors
|
|
5
|
+
from numba.cuda.core import ir
|
|
6
|
+
from numba.cuda.core.rewrites import register_rewrite, Rewrite
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@register_rewrite("before-inference")
class DetectStaticBinops(Rewrite):
    """
    Detect constant arguments to select binops.
    """

    # Those operators can benefit from a constant-inferred argument
    rhs_operators = {"**"}

    def match(self, func_ir, block, typemap, calltypes):
        """
        Record, for each binop in *block* whose operator is listed in
        ``rhs_operators`` and whose ``static_rhs`` is still undefined, the
        constant inferred for its right-hand side. Returns True if any
        constant was recorded.
        """
        self.static_lhs = {}
        found_rhs = self.static_rhs = {}
        self.block = block
        # Find binop expressions with a constant lhs or rhs
        for binop in block.find_exprs(op="binop"):
            try:
                wants_const = binop.fn in self.rhs_operators
                if wants_const and binop.static_rhs is ir.UNDEFINED:
                    found_rhs[binop] = func_ir.infer_constant(binop.rhs)
            except errors.ConstantInferenceError:
                # The rhs is not a constant; skip this expression.
                pass

        return bool(self.static_lhs) or bool(found_rhs)

    def apply(self):
        """
        Write the constants found by match() onto their expressions and
        return the (same) block.
        """
        for binop, constant in self.static_rhs.items():
            binop.static_rhs = constant
        return self.block
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from numba.cuda.core import errors
|
|
5
|
+
from numba.cuda.core import ir
|
|
6
|
+
from numba.cuda import types
|
|
7
|
+
from numba.cuda.core.rewrites import register_rewrite, Rewrite
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@register_rewrite("before-inference")
class RewriteConstGetitems(Rewrite):
    """
    Rewrite IR expressions of the kind `getitem(value=arr, index=$constXX)`
    where `$constXX` is a known constant as
    `static_getitem(value=arr, index=<constant value>)`.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Map every getitem expression in *block* whose index can be inferred
        as a constant to that constant. Returns True if any was found.
        """
        candidates = {}
        self.getitems = candidates
        self.block = block
        for getitem in block.find_exprs(op="getitem"):
            if getitem.op != "getitem":
                continue
            try:
                constant = func_ir.infer_constant(getitem.index)
            except errors.ConstantInferenceError:
                # Index is not constant: this getitem stays as it is.
                continue
            candidates[getitem] = constant

        return bool(candidates)

    def apply(self):
        """
        Return a new block in which all matching getitems are rewritten as
        static_getitems; every other statement is carried over unchanged.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            replacement = stmt
            if isinstance(stmt, ir.Assign) and stmt.value in self.getitems:
                source = stmt.value
                constant = self.getitems[source]
                static_expr = ir.Expr.static_getitem(
                    value=source.value,
                    index=constant,
                    index_var=source.index,
                    loc=source.loc,
                )
                replacement = ir.Assign(
                    value=static_expr, target=stmt.target, loc=stmt.loc
                )
            rewritten.append(replacement)
        return rewritten
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@register_rewrite("after-inference")
class RewriteStringLiteralGetitems(Rewrite):
    """
    Rewrite IR expressions of the kind `getitem(value=arr, index=$XX)`
    where `$XX` is a StringLiteral value as
    `static_getitem(value=arr, index=<literal value>)`.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Map every getitem expression in *block* whose index is typed as a
        StringLiteral to an (index variable, literal value) pair. Returns
        True if any was found.
        """
        candidates = {}
        self.getitems = candidates
        self.block = block
        self.calltypes = calltypes
        for getitem in block.find_exprs(op="getitem"):
            if getitem.op != "getitem":
                continue
            index_type = typemap[getitem.index.name]
            if isinstance(index_type, types.StringLiteral):
                candidates[getitem] = (getitem.index, index_type.literal_value)

        return bool(candidates)

    def apply(self):
        """
        Return a new block in which all matching getitems are rewritten as
        static_getitems indexed by the string's literal value.
        """
        rewritten = ir.Block(self.block.scope, self.block.loc)
        for stmt in self.block.body:
            replacement = stmt
            if isinstance(stmt, ir.Assign) and stmt.value in self.getitems:
                source = stmt.value
                _, literal = self.getitems[source]
                static_expr = ir.Expr.static_getitem(
                    value=source.value,
                    index=literal,
                    index_var=source.index,
                    loc=source.loc,
                )
                # The new expression takes over the call type recorded for
                # the expression it replaces.
                self.calltypes[static_expr] = self.calltypes[source]
                replacement = ir.Assign(
                    value=static_expr, target=stmt.target, loc=stmt.loc
                )
            rewritten.append(replacement)
        return rewritten
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@register_rewrite("after-inference")
class RewriteStringLiteralSetitems(Rewrite):
    """
    Rewrite IR statements of the kind `setitem(target=arr, index=$XX, ...)`
    where `$XX` is a StringLiteral value as
    `static_setitem(target=arr, index=<literal value>, ...)`.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Map every SetItem statement in *block* whose index is typed as a
        StringLiteral to an (index variable, literal value) pair. Returns
        True if any was found.
        """
        candidates = {}
        self.setitems = candidates
        self.block = block
        self.calltypes = calltypes
        for setitem in block.find_insts(ir.SetItem):
            index_type = typemap[setitem.index.name]
            if isinstance(index_type, types.StringLiteral):
                candidates[setitem] = (setitem.index, index_type.literal_value)

        return bool(candidates)

    def apply(self):
        """
        Return a new block in which all matching setitems are rewritten as
        static_setitems indexed by the string's literal value.
        """
        rewritten = ir.Block(self.block.scope, self.block.loc)
        for stmt in self.block.body:
            replacement = stmt
            if isinstance(stmt, ir.SetItem) and stmt in self.setitems:
                _, literal = self.setitems[stmt]
                replacement = ir.StaticSetItem(
                    target=stmt.target,
                    index=literal,
                    index_var=stmt.index,
                    value=stmt.value,
                    loc=stmt.loc,
                )
                # The new statement takes over the call type recorded for
                # the statement it replaces.
                self.calltypes[replacement] = self.calltypes[stmt]
            rewritten.append(replacement)
        return rewritten
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@register_rewrite("before-inference")
class RewriteConstSetitems(Rewrite):
    """
    Rewrite IR statements of the kind `setitem(target=arr, index=$constXX, ...)`
    where `$constXX` is a known constant as
    `static_setitem(target=arr, index=<constant value>, ...)`.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Map every SetItem statement in *block* whose index can be inferred
        as a constant to that constant. Returns True if any was found.
        """
        candidates = {}
        self.setitems = candidates
        self.block = block
        for setitem in block.find_insts(ir.SetItem):
            try:
                constant = func_ir.infer_constant(setitem.index)
            except errors.ConstantInferenceError:
                # Index is not constant: this setitem stays as it is.
                pass
            else:
                candidates[setitem] = constant

        return bool(candidates)

    def apply(self):
        """
        Return a new block in which all matching setitems are rewritten as
        static_setitems; every other statement is carried over unchanged.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            if stmt not in self.setitems:
                rewritten.append(stmt)
                continue
            constant = self.setitems[stmt]
            rewritten.append(
                ir.StaticSetItem(
                    stmt.target, constant, stmt.index, stmt.value, stmt.loc
                )
            )
        return rewritten
|