numba-cuda 0.19.0__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of numba-cuda might be problematic. Click here for more details.
- _numba_cuda_redirector.pth +3 -0
- _numba_cuda_redirector.py +3 -0
- numba_cuda/VERSION +1 -1
- numba_cuda/__init__.py +2 -1
- numba_cuda/_version.py +2 -13
- numba_cuda/numba/cuda/__init__.py +4 -1
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +12708 -1469
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +2656 -8769
- numba_cuda/numba/cuda/api.py +9 -1
- numba_cuda/numba/cuda/api_util.py +3 -0
- numba_cuda/numba/cuda/args.py +3 -0
- numba_cuda/numba/cuda/bf16.py +288 -2
- numba_cuda/numba/cuda/cg.py +3 -0
- numba_cuda/numba/cuda/cgutils.py +5 -2
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +4 -1
- numba_cuda/numba/cuda/compiler.py +376 -30
- numba_cuda/numba/cuda/core/analysis.py +319 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
- numba_cuda/numba/cuda/core/base.py +1289 -0
- numba_cuda/numba/cuda/core/bytecode.py +727 -0
- numba_cuda/numba/cuda/core/caching.py +5 -2
- numba_cuda/numba/cuda/core/callconv.py +3 -0
- numba_cuda/numba/cuda/core/codegen.py +3 -0
- numba_cuda/numba/cuda/core/compiler.py +9 -14
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +747 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/cpu.py +370 -0
- numba_cuda/numba/cuda/core/environment.py +68 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
- numba_cuda/numba/cuda/core/interpreter.py +52 -27
- numba_cuda/numba/cuda/core/ir_utils.py +17 -29
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
- numba_cuda/numba/cuda/core/sigutils.py +3 -0
- numba_cuda/numba/cuda/core/ssa.py +496 -0
- numba_cuda/numba/cuda/core/targetconfig.py +329 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +952 -0
- numba_cuda/numba/cuda/core/typed_passes.py +741 -7
- numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
- numba_cuda/numba/cuda/cuda_paths.py +425 -246
- numba_cuda/numba/cuda/cudadecl.py +4 -1
- numba_cuda/numba/cuda/cudadrv/__init__.py +4 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +5 -1
- numba_cuda/numba/cuda/cudadrv/devices.py +3 -0
- numba_cuda/numba/cuda/cudadrv/driver.py +14 -140
- numba_cuda/numba/cuda/cudadrv/drvapi.py +3 -0
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +114 -24
- numba_cuda/numba/cuda/cudadrv/enums.py +3 -0
- numba_cuda/numba/cuda/cudadrv/error.py +4 -0
- numba_cuda/numba/cuda/cudadrv/libs.py +8 -5
- numba_cuda/numba/cuda/cudadrv/linkable_code.py +3 -0
- numba_cuda/numba/cuda/cudadrv/mappings.py +4 -1
- numba_cuda/numba/cuda/cudadrv/ndarray.py +3 -0
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +22 -8
- numba_cuda/numba/cuda/cudadrv/nvvm.py +4 -4
- numba_cuda/numba/cuda/cudadrv/rtapi.py +3 -0
- numba_cuda/numba/cuda/cudadrv/runtime.py +4 -1
- numba_cuda/numba/cuda/cudaimpl.py +8 -1
- numba_cuda/numba/cuda/cudamath.py +3 -0
- numba_cuda/numba/cuda/debuginfo.py +88 -2
- numba_cuda/numba/cuda/decorators.py +6 -3
- numba_cuda/numba/cuda/descriptor.py +6 -4
- numba_cuda/numba/cuda/device_init.py +3 -0
- numba_cuda/numba/cuda/deviceufunc.py +69 -2
- numba_cuda/numba/cuda/dispatcher.py +21 -39
- numba_cuda/numba/cuda/errors.py +10 -0
- numba_cuda/numba/cuda/extending.py +3 -0
- numba_cuda/numba/cuda/flags.py +143 -1
- numba_cuda/numba/cuda/fp16.py +3 -2
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/initialize.py +4 -0
- numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -0
- numba_cuda/numba/cuda/intrinsics.py +3 -0
- numba_cuda/numba/cuda/itanium_mangler.py +3 -0
- numba_cuda/numba/cuda/kernels/__init__.py +2 -0
- numba_cuda/numba/cuda/kernels/reduction.py +3 -0
- numba_cuda/numba/cuda/kernels/transpose.py +3 -0
- numba_cuda/numba/cuda/libdevice.py +4 -0
- numba_cuda/numba/cuda/libdevicedecl.py +3 -0
- numba_cuda/numba/cuda/libdevicefuncs.py +3 -0
- numba_cuda/numba/cuda/libdeviceimpl.py +3 -0
- numba_cuda/numba/cuda/locks.py +3 -0
- numba_cuda/numba/cuda/lowering.py +59 -159
- numba_cuda/numba/cuda/mathimpl.py +5 -1
- numba_cuda/numba/cuda/memory_management/__init__.py +3 -0
- numba_cuda/numba/cuda/memory_management/memsys.cu +5 -0
- numba_cuda/numba/cuda/memory_management/memsys.cuh +5 -0
- numba_cuda/numba/cuda/memory_management/nrt.cu +5 -0
- numba_cuda/numba/cuda/memory_management/nrt.cuh +5 -0
- numba_cuda/numba/cuda/memory_management/nrt.py +48 -18
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/models.py +12 -1
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
- numba_cuda/numba/cuda/np/numpy_support.py +553 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
- numba_cuda/numba/cuda/nvvmutils.py +4 -1
- numba_cuda/numba/cuda/printimpl.py +15 -1
- numba_cuda/numba/cuda/random.py +4 -1
- numba_cuda/numba/cuda/reshape_funcs.cu +5 -0
- numba_cuda/numba/cuda/serialize.py +4 -1
- numba_cuda/numba/cuda/simulator/__init__.py +4 -1
- numba_cuda/numba/cuda/simulator/_internal/__init__.py +3 -0
- numba_cuda/numba/cuda/simulator/_internal/cuda_bf16.py +2 -0
- numba_cuda/numba/cuda/simulator/api.py +4 -1
- numba_cuda/numba/cuda/simulator/bf16.py +3 -0
- numba_cuda/numba/cuda/simulator/compiler.py +7 -0
- numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +3 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +4 -1
- numba_cuda/numba/cuda/simulator/cudadrv/devices.py +3 -0
- numba_cuda/numba/cuda/simulator/cudadrv/driver.py +3 -0
- numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +3 -0
- numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +3 -0
- numba_cuda/numba/cuda/simulator/cudadrv/error.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/libs.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/linkable_code.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/nvrtc.py +3 -0
- numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -0
- numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -0
- numba_cuda/numba/cuda/simulator/dispatcher.py +4 -0
- numba_cuda/numba/cuda/simulator/kernel.py +3 -0
- numba_cuda/numba/cuda/simulator/kernelapi.py +4 -1
- numba_cuda/numba/cuda/simulator/memory_management/__init__.py +3 -0
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +17 -2
- numba_cuda/numba/cuda/simulator/reduction.py +3 -0
- numba_cuda/numba/cuda/simulator/vector_types.py +3 -0
- numba_cuda/numba/cuda/simulator_init.py +3 -0
- numba_cuda/numba/cuda/stubs.py +3 -0
- numba_cuda/numba/cuda/target.py +38 -17
- numba_cuda/numba/cuda/testing.py +7 -19
- numba_cuda/numba/cuda/tests/__init__.py +4 -1
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/complex_usecases.py +3 -0
- numba_cuda/numba/cuda/tests/core/serialize_usecases.py +3 -0
- numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +3 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +7 -4
- numba_cuda/numba/cuda/tests/cudadrv/__init__.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +4 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +4 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +4 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +9 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +4 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_events.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_init.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +21 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +5 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +4 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +4 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +3 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/__init__.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +5 -1
- numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +542 -2
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +84 -1
- numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +4 -3
- numba_cuda/numba/cuda/tests/cudapy/test_casting.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +5 -3
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +314 -3
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +5 -1
- numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_forall.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +21 -8
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +13 -37
- numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_lang.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_overload.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_powi.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +23 -0
- numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_random.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
- numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +266 -2
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +115 -6
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +3 -0
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +4 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +4 -1
- numba_cuda/numba/cuda/tests/cudasim/__init__.py +3 -0
- numba_cuda/numba/cuda/tests/cudasim/support.py +3 -0
- numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +3 -0
- numba_cuda/numba/cuda/tests/data/__init__.py +2 -0
- numba_cuda/numba/cuda/tests/data/cta_barrier.cu +5 -0
- numba_cuda/numba/cuda/tests/data/cuda_include.cu +5 -0
- numba_cuda/numba/cuda/tests/data/error.cu +5 -0
- numba_cuda/numba/cuda/tests/data/include/add.cuh +5 -0
- numba_cuda/numba/cuda/tests/data/jitlink.cu +5 -0
- numba_cuda/numba/cuda/tests/data/warn.cu +5 -0
- numba_cuda/numba/cuda/tests/doc_examples/__init__.py +3 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +2 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +5 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +5 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +5 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +3 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +3 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +6 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +3 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +3 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_random.py +3 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +3 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +3 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +3 -0
- numba_cuda/numba/cuda/tests/enum_usecases.py +3 -0
- numba_cuda/numba/cuda/tests/nocuda/__init__.py +3 -0
- numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +3 -0
- numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +3 -0
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +6 -1
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +27 -12
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +3 -0
- numba_cuda/numba/cuda/tests/nrt/__init__.py +3 -0
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +5 -1
- numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +3 -0
- numba_cuda/numba/cuda/tests/support.py +58 -15
- numba_cuda/numba/cuda/tests/test_binary_generation/Makefile +3 -0
- numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -1
- numba_cuda/numba/cuda/tests/test_binary_generation/nrt_extern.cu +5 -0
- numba_cuda/numba/cuda/tests/test_binary_generation/test_device_functions.cu +5 -0
- numba_cuda/numba/cuda/tests/test_binary_generation/undefined_extern.cu +5 -0
- numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
- numba_cuda/numba/cuda/types.py +59 -0
- numba_cuda/numba/cuda/typing/__init__.py +12 -1
- numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
- numba_cuda/numba/cuda/typing/context.py +751 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/npydecl.py +658 -0
- numba_cuda/numba/cuda/typing/templates.py +10 -14
- numba_cuda/numba/cuda/ufuncs.py +6 -3
- numba_cuda/numba/cuda/utils.py +9 -112
- numba_cuda/numba/cuda/vector_types.py +3 -0
- numba_cuda/numba/cuda/vectorizers.py +3 -0
- {numba_cuda-0.19.0.dist-info → numba_cuda-0.20.0.dist-info}/METADATA +6 -2
- numba_cuda-0.20.0.dist-info/RECORD +357 -0
- {numba_cuda-0.19.0.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE +1 -0
- numba_cuda-0.20.0.dist-info/licenses/LICENSE.numba +24 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -57
- numba_cuda-0.19.0.dist-info/RECORD +0 -301
- {numba_cuda-0.19.0.dist-info → numba_cuda-0.20.0.dist-info}/WHEEL +0 -0
- {numba_cuda-0.19.0.dist-info → numba_cuda-0.20.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
A subpackage hosting Numba IR rewrite passes.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .registry import register_rewrite, rewrite_registry, Rewrite
|
|
9
|
+
|
|
10
|
+
# Register various built-in rewrite passes
|
|
11
|
+
from numba.cuda.core.rewrites import (
|
|
12
|
+
static_getitem,
|
|
13
|
+
static_raise,
|
|
14
|
+
static_binop,
|
|
15
|
+
ir_print,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Names exported by the rewrites subpackage: the four built-in pass modules
# imported above, followed by the registry helpers pulled in from .registry.
__all__ = (
    "static_getitem",
    "static_raise",
    "static_binop",
    "ir_print",
    "register_rewrite",
    "rewrite_registry",
    "Rewrite",
)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from numba.core import errors, ir
|
|
5
|
+
from numba.cuda.core.rewrites import register_rewrite, Rewrite
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@register_rewrite("before-inference")
class RewritePrintCalls(Rewrite):
    """
    Turn calls to the global print() function into dedicated IR Print nodes.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Record every assignment in *block* whose right-hand side is a call
        expression that constant-infers to the builtin print().

        The matches are stored on ``self.prints`` (assignment -> call
        expression) and the block on ``self.block``. Returns True when at
        least one match was recorded. Raises errors.UnsupportedError when a
        matched call carries keyword arguments.
        """
        found = {}
        self.prints = found
        self.block = block
        for assign in block.find_insts(ir.Assign):
            call = assign.value
            if not isinstance(call, ir.Expr) or call.op != "call":
                continue
            try:
                target_fn = func_ir.infer_constant(call.func)
            except errors.ConstantInferenceError:
                # The callee is not a known constant, so skip this call.
                continue
            if target_fn is not print:
                continue
            if call.kws:
                # Only positional args are supported
                raise errors.UnsupportedError(
                    "Numba's print() function implementation does not "
                    "support keyword arguments.",
                    assign.loc,
                )
            found[assign] = call
        return bool(found)

    def apply(self):
        """
        Build a new block in which each matched
        `var = call <print function>(...)` is replaced by the pair
        `print(...)` and `var = const(None)`; every other instruction is
        carried over unchanged.
        """
        # Start from a copy of the matched block with its body emptied.
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            if stmt not in self.prints:
                rewritten.append(stmt)
                continue
            call = self.prints[stmt]
            rewritten.append(
                ir.Print(args=call.args, vararg=call.vararg, loc=call.loc)
            )
            # The original assignment target now receives a constant None.
            rewritten.append(
                ir.Assign(
                    value=ir.Const(None, loc=call.loc),
                    target=stmt.target,
                    loc=stmt.loc,
                )
            )
        return rewritten
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@register_rewrite("before-inference")
class DetectConstPrintArguments(Rewrite):
    """
    Find print() nodes whose arguments are constants and record those
    constants on the nodes.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        For each ir.Print instruction in *block* that has no ``consts`` yet,
        try to constant-infer every argument and remember the successes in
        ``self.consts`` as ``{print_node: {arg_index: constant}}``.

        Returns True when at least one constant argument was found.
        """
        detected = {}
        self.consts = detected
        self.block = block
        for node in block.find_insts(ir.Print):
            if node.consts:
                # Already rewritten
                continue
            for position, arg in enumerate(node.args):
                try:
                    value = func_ir.infer_constant(arg)
                except errors.ConstantInferenceError:
                    continue
                else:
                    detected.setdefault(node, {})[position] = value

        return bool(detected)

    def apply(self):
        """
        Attach the constants gathered by match() to their print nodes and
        return the (same) block.
        """
        for stmt in self.block.body:
            if stmt not in self.consts:
                continue
            stmt.consts = self.consts[stmt]
        return self.block
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
|
|
6
|
+
from numba.core import config
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Rewrite:
    """Abstract base class from which Numba rewrites derive."""

    def __init__(self, state=None):
        """Accept an optional *state* argument; the base class ignores it."""

    def match(self, func_ir, block, typemap, calltypes) -> bool:
        """Report whether *block* contains terms this rewrite handles.

        Subclasses override this method; the base implementation never
        matches and always returns False.
        """
        return False

    def apply(self):
        """Return the rewritten IR basic block after a successful match.

        Subclasses override this method; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError("Abstract Rewrite.apply() called!")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class RewriteRegistry:
    """Registry of Rewrite subclasses, grouped by the kind they run under."""

    _kinds = frozenset(["before-inference", "after-inference"])

    def __init__(self):
        """Create an empty registry.

        ``rewrites`` maps a kind to the list of rewrite classes registered
        for it; a missing kind yields a fresh empty list.
        """
        self.rewrites = defaultdict(list)

    def register(self, kind):
        """
        Return a class decorator that adds a Rewrite subclass to the
        registry under the given *kind*.

        Raises KeyError when *kind* is not one of the known kinds. The
        returned decorator raises TypeError when applied to a class that
        is not a subclass of Rewrite, and otherwise returns the class
        unchanged.
        """
        if kind not in self._kinds:
            raise KeyError("invalid kind %r" % (kind,))

        def do_register(rewrite_cls):
            if issubclass(rewrite_cls, Rewrite):
                self.rewrites[kind].append(rewrite_cls)
                return rewrite_cls
            raise TypeError(
                "{0} is not a subclass of Rewrite".format(rewrite_cls)
            )

        return do_register

    def apply(self, kind, state):
        """Run every rewrite registered for *kind* over the basic blocks of
        ``state.func_ir``, re-trying each block until the rewrite stops
        matching it, then verify the changed blocks and post-process the IR.
        """
        assert kind in self._kinds
        blocks = state.func_ir.blocks
        # Snapshot taken up front so changed blocks can be detected later.
        snapshot = blocks.copy()
        for rewrite_cls in self.rewrites[kind]:
            # Exhaustively apply a rewrite until it stops matching.
            rewrite = rewrite_cls(state)
            pending = list(blocks.items())
            while pending:
                label, current = pending.pop()
                if not rewrite.match(
                    state.func_ir, current, state.typemap, state.calltypes
                ):
                    continue
                if config.DEBUG or config.DUMP_IR:
                    print("_" * 70)
                    print("REWRITING (%s):" % rewrite_cls.__name__)
                    current.dump()
                    print("_" * 60)
                replacement = rewrite.apply()
                blocks[label] = replacement
                # Queue the rewritten block again in case it still matches.
                pending.append((label, replacement))
                if config.DEBUG or config.DUMP_IR:
                    replacement.dump()
                    print("_" * 70)
        # If any blocks were changed, perform a sanity check.
        for label, current in blocks.items():
            if current != snapshot[label]:
                current.verify()

        # Some passes, e.g. _inline_const_arraycall are known to occasionally
        # do invalid things WRT ir.Del, others, e.g. RewriteArrayExprs do valid
        # things with ir.Del, but the placement is not optimal. The lines below
        # fix-up the IR so that ref counts are valid and optimally placed,
        # see #4093 for context. This has to be run here opposed to in
        # apply() as the CFG needs computing so full IR is needed.
        from numba.core import postproc

        postproc.PostProcessor(state.func_ir).run()


# Module-wide registry instance and its decorator factory, exposed as the
# names other modules import to register rewrites.
rewrite_registry = RewriteRegistry()
register_rewrite = rewrite_registry.register
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from numba.core import errors, ir
|
|
5
|
+
from numba.cuda.core.rewrites import register_rewrite, Rewrite
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@register_rewrite("before-inference")
class DetectStaticBinops(Rewrite):
    """
    Record constant right-hand arguments of selected binary operations.
    """

    # Those operators can benefit from a constant-inferred argument
    rhs_operators = {"**"}

    def match(self, func_ir, block, typemap, calltypes):
        """
        Scan the binop expressions of *block* and, for those whose ``fn`` is
        in ``rhs_operators`` and whose ``static_rhs`` is still undefined, try
        to constant-infer the right-hand operand.

        Successes are stored in ``self.static_rhs`` (expression -> constant).
        Returns True when anything was recorded.
        """
        self.static_lhs = {}
        self.static_rhs = {}
        self.block = block
        # Find binop expressions with a constant lhs or rhs
        for binop in block.find_exprs(op="binop"):
            if binop.fn not in self.rhs_operators:
                continue
            if binop.static_rhs is not ir.UNDEFINED:
                continue
            try:
                self.static_rhs[binop] = func_ir.infer_constant(binop.rhs)
            except errors.ConstantInferenceError:
                continue

        return bool(self.static_lhs or self.static_rhs)

    def apply(self):
        """
        Write the constants found by match() onto their expressions and
        return the (same) block.
        """
        for binop, constant in self.static_rhs.items():
            binop.static_rhs = constant
        return self.block
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from numba.core import errors, types, ir
|
|
5
|
+
from numba.cuda.core.rewrites import register_rewrite, Rewrite
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@register_rewrite("before-inference")
class RewriteConstGetitems(Rewrite):
    """
    Replace `getitem(value=arr, index=$constXX)` expressions, where
    `$constXX` is a known constant, with
    `static_getitem(value=arr, index=<constant value>)`.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Collect the getitem expressions of *block* whose index can be
        constant-inferred into ``self.getitems`` (expression -> constant).

        Returns True when at least one such expression was found.
        """
        found = {}
        self.getitems = found
        self.block = block
        for expr in block.find_exprs(op="getitem"):
            if expr.op != "getitem":
                continue
            try:
                index_value = func_ir.infer_constant(expr.index)
            except errors.ConstantInferenceError:
                # Index is not a known constant; leave this getitem alone.
                continue
            else:
                found[expr] = index_value

        return bool(found)

    def apply(self):
        """
        Build a new block in which every matched getitem assignment is
        re-issued as a static_getitem assignment; other instructions are
        carried over unchanged.
        """
        # Start from a copy of the matched block with its body emptied.
        rebuilt = self.block.copy()
        rebuilt.clear()
        for stmt in self.block.body:
            if isinstance(stmt, ir.Assign) and stmt.value in self.getitems:
                old = stmt.value
                static_expr = ir.Expr.static_getitem(
                    value=old.value,
                    index=self.getitems[old],
                    index_var=old.index,
                    loc=old.loc,
                )
                stmt = ir.Assign(
                    value=static_expr, target=stmt.target, loc=stmt.loc
                )
            rebuilt.append(stmt)
        return rebuilt
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@register_rewrite("after-inference")
class RewriteStringLiteralGetitems(Rewrite):
    """
    Turn `getitem(value=arr, index=$XX)` expressions whose index `$XX`
    is typed as a StringLiteral into
    `static_getitem(value=arr, index=<literal value>)` expressions.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Record every getitem expression of `block` whose index variable
        is typed as a string literal in `typemap`.

        Returns True when at least one such expression was recorded.
        """
        self.getitems = found = {}
        self.block = block
        self.calltypes = calltypes
        for candidate in block.find_exprs(op="getitem"):
            if candidate.op != "getitem":
                continue
            idx_type = typemap[candidate.index.name]
            if not isinstance(idx_type, types.StringLiteral):
                continue
            found[candidate] = (candidate.index, idx_type.literal_value)

        return bool(found)

    def apply(self):
        """
        Build a new block in which each recorded getitem assignment is
        replaced by a static_getitem indexed by the string's literal value.
        """
        rewritten = ir.Block(self.block.scope, self.block.loc)
        for stmt in self.block.body:
            if isinstance(stmt, ir.Assign) and stmt.value in self.getitems:
                rhs = stmt.value
                _, literal = self.getitems[rhs]
                static_expr = ir.Expr.static_getitem(
                    value=rhs.value,
                    index=literal,
                    index_var=rhs.index,
                    loc=rhs.loc,
                )
                # The replacement expression reuses the call type entry
                # of the expression it replaces.
                self.calltypes[static_expr] = self.calltypes[rhs]
                stmt = ir.Assign(
                    value=static_expr, target=stmt.target, loc=stmt.loc
                )
            rewritten.append(stmt)
        return rewritten
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@register_rewrite("after-inference")
class RewriteStringLiteralSetitems(Rewrite):
    """
    Rewrite IR statements of the kind
    `setitem(target=arr, index=$XX, value=...)`
    where `$XX` is a StringLiteral value as
    `static_setitem(target=arr, index=<literal value>, value=...)`.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Detect all setitem statements and find which ones have
        string literal indexes.

        Returns True when at least one such statement was found.
        """
        self.setitems = setitems = {}
        self.block = block
        self.calltypes = calltypes
        for inst in block.find_insts(ir.SetItem):
            index_ty = typemap[inst.index.name]
            if isinstance(index_ty, types.StringLiteral):
                setitems[inst] = (inst.index, index_ty.literal_value)

        return len(setitems) > 0

    def apply(self):
        """
        Rewrite all matching setitems as static_setitems where the index
        is the literal value of the string.
        """
        new_block = ir.Block(self.block.scope, self.block.loc)
        for inst in self.block.body:
            if isinstance(inst, ir.SetItem) and inst in self.setitems:
                # Only the literal value is needed here; the index variable
                # is read from the statement itself.
                _, lit_val = self.setitems[inst]
                new_inst = ir.StaticSetItem(
                    target=inst.target,
                    index=lit_val,
                    index_var=inst.index,
                    value=inst.value,
                    loc=inst.loc,
                )
                # The replacement statement reuses the call type entry of
                # the statement it replaces.
                self.calltypes[new_inst] = self.calltypes[inst]
                inst = new_inst
            new_block.append(inst)
        return new_block
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@register_rewrite("before-inference")
class RewriteConstSetitems(Rewrite):
    """
    Turn `setitem(target=arr, index=$constXX, ...)` statements whose index
    `$constXX` can be inferred as a constant into
    `static_setitem(target=arr, index=<constant value>, ...)` statements.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Record every setitem statement of `block` whose index is a
        constant according to `func_ir.infer_constant`.

        Returns True when at least one such statement was recorded.
        `typemap` and `calltypes` are not consulted by this rewrite.
        """
        self.setitems = found = {}
        self.block = block
        for stmt in block.find_insts(ir.SetItem):
            try:
                value = func_ir.infer_constant(stmt.index)
            except errors.ConstantInferenceError:
                # The index is not a known constant: keep the plain setitem.
                continue
            else:
                found[stmt] = value

        return bool(found)

    def apply(self):
        """
        Build a fresh block in which each recorded setitem statement is
        replaced by the equivalent static_setitem statement.
        """
        rewritten = self.block.copy()
        rewritten.clear()
        for stmt in self.block.body:
            if stmt in self.setitems:
                stmt = ir.StaticSetItem(
                    target=stmt.target,
                    index=self.setitems[stmt],
                    index_var=stmt.index,
                    value=stmt.value,
                    loc=stmt.loc,
                )
            rewritten.append(stmt)
        return rewritten
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from numba.core import errors, consts, ir
|
|
5
|
+
from numba.cuda.core.rewrites import register_rewrite, Rewrite
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@register_rewrite("before-inference")
class RewriteConstRaises(Rewrite):
    """
    Rewrite IR statements of the kind `raise(value)`
    where `value` is the result of instantiating an exception with
    constant arguments
    into `static_raise(exception_type, constant args)`.

    This allows lowering in nopython mode, where one can't instantiate
    exception instances from runtime data.
    """

    def _is_exception_type(self, const):
        """
        Return True if `const` is a class deriving from `Exception`.
        """
        return isinstance(const, type) and issubclass(const, Exception)

    def _break_constant(self, const, loc):
        """
        Break down constant exception.

        Returns a pair `(exception class, args)`: for a tuple constant the
        args are `tuple(const[1])`; for a bare exception class they are
        None. Any other constant raises `errors.UnsupportedError` at `loc`.
        """
        if isinstance(const, tuple):  # it's a tuple(exception class, args)
            if not self._is_exception_type(const[0]):
                msg = "Encountered unsupported exception constant %r"
                raise errors.UnsupportedError(msg % (const[0],), loc)
            return const[0], tuple(const[1])
        elif self._is_exception_type(const):
            return const, None
        else:
            if isinstance(const, str):
                msg = (
                    "Directly raising a string constant as an exception is "
                    "not supported."
                )
            else:
                msg = "Encountered unsupported constant type used for exception"
            raise errors.UnsupportedError(msg, loc)

    def _try_infer_constant(self, func_ir, inst):
        """
        Return the constant inferred for `inst.exception`, or None when
        constant inference fails.
        """
        try:
            return func_ir.infer_constant(inst.exception)
        except consts.ConstantInferenceError:
            # not a static exception
            return None

    def match(self, func_ir, block, typemap, calltypes):
        """
        Detect all raise and try-raise statements of `block` that can be
        rewritten, recording their `(exception type, args)` pairs.

        Returns True when at least one statement was recorded.
        """
        self.raises = raises = {}
        self.tryraises = tryraises = {}
        self.block = block
        for inst in block.find_insts((ir.Raise, ir.TryRaise)):
            if inst.exception is None:
                # re-raise
                exc_type, exc_args = None, None
            else:
                # raise <something> => find the definition site for <something>
                const = self._try_infer_constant(func_ir, inst)

                # failure to infer constant indicates this isn't a static
                # exception
                if const is None:
                    continue

                loc = inst.exception.loc
                exc_type, exc_args = self._break_constant(const, loc)

            if isinstance(inst, ir.Raise):
                raises[inst] = exc_type, exc_args
            elif isinstance(inst, ir.TryRaise):
                tryraises[inst] = exc_type, exc_args
            else:
                raise ValueError("unexpected: {}".format(type(inst)))
        return (len(raises) + len(tryraises)) > 0

    def apply(self):
        """
        Rewrite all matching raises as static_raises and all matching
        try-raises as static_try_raises.
        """
        new_block = self.block.copy()
        new_block.clear()
        for inst in self.block.body:
            if inst in self.raises:
                exc_type, exc_args = self.raises[inst]
                new_inst = ir.StaticRaise(exc_type, exc_args, inst.loc)
                new_block.append(new_inst)
            elif inst in self.tryraises:
                exc_type, exc_args = self.tryraises[inst]
                new_inst = ir.StaticTryRaise(exc_type, exc_args, inst.loc)
                new_block.append(new_inst)
            else:
                new_block.append(inst)
        return new_block
|