mlx 0.30.7.2 → 0.30.7.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/extconf.rb +267 -8
- data/ext/mlx/native.cpp +112 -58
- data/ext/mlx-onnx/native.cpp +1402 -0
- data/ext/mlx-onnx/native.hpp +19 -0
- data/lib/mlx/core.rb +342 -117
- data/lib/mlx/distributed_utils/common.rb +1 -1
- data/lib/mlx/distributed_utils/config.rb +7 -4
- data/lib/mlx/distributed_utils/launch.rb +2 -0
- data/lib/mlx/dsl/attention.rb +132 -0
- data/lib/mlx/dsl/builder.rb +8 -0
- data/lib/mlx/dsl/config_schema.rb +133 -0
- data/lib/mlx/dsl/generate.rb +193 -0
- data/lib/mlx/dsl/kv_cache.rb +96 -0
- data/lib/mlx/dsl/masks.rb +32 -0
- data/lib/mlx/dsl/positions.rb +35 -0
- data/lib/mlx/dsl/run_stack.rb +68 -0
- data/lib/mlx/dsl/tensor.rb +126 -0
- data/lib/mlx/dsl/transformer_block.rb +113 -0
- data/lib/mlx/dsl/weight_map.rb +140 -0
- data/lib/mlx/dsl.rb +10 -0
- data/lib/mlx/nn/base.rb +4 -0
- data/lib/mlx/nn/layers/linear.rb +2 -3
- data/lib/mlx/onnx.rb +250 -0
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx-onnx/webgpu_harness.rb +289 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
- data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
- data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
- data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
- data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
- data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
- data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
- data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
- data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
- data/submodules/mlx-onnx/CMakeLists.txt +159 -0
- data/submodules/mlx-onnx/LICENSE +21 -0
- data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
- data/submodules/mlx-onnx/src/api.cpp +81 -0
- data/submodules/mlx-onnx/src/compat.cpp +111 -0
- data/submodules/mlx-onnx/src/detail.hpp +69 -0
- data/submodules/mlx-onnx/src/export.cpp +653 -0
- data/submodules/mlx-onnx/src/io.cpp +61 -0
- data/submodules/mlx-onnx/src/json.hpp +25 -0
- data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
- data/submodules/mlx-onnx/src/mappings.cpp +201 -0
- data/submodules/mlx-onnx/src/mappings.hpp +16 -0
- data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
- data/submodules/mlx-onnx/src/shared.cpp +206 -0
- metadata +665 -567
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
- /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
- /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
|
@@ -2,6 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
#pragma once
|
|
4
4
|
|
|
5
|
+
#include <vector>
|
|
6
|
+
|
|
5
7
|
#include "mlx/api.h"
|
|
6
8
|
#include "mlx/device.h"
|
|
7
9
|
|
|
@@ -25,6 +27,9 @@ MLX_API Stream new_stream(Device d);
|
|
|
25
27
|
/** Get the stream with the given index. */
|
|
26
28
|
MLX_API Stream get_stream(int index);
|
|
27
29
|
|
|
30
|
+
/** Get all available streams. */
|
|
31
|
+
MLX_API std::vector<Stream> get_streams();
|
|
32
|
+
|
|
28
33
|
/** Two streams compare equal when their `index` members are equal. */
inline bool operator==(const Stream& lhs, const Stream& rhs) {
|
|
29
34
|
return lhs.index == rhs.index;
|
|
30
35
|
}
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
# Build script for the mlx_onnx C++ library.
cmake_minimum_required(VERSION 3.25)
project(
  mlx_onnx
  VERSION 0.30.7.1
  LANGUAGES C CXX)

# Language defaults: C++20 is mandatory and all objects are built
# position-independent.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# User-facing switches.
option(MLX_ONNX_USE_EXTERNAL_MLX "Build against an externally provided MLX install" OFF)
option(MLX_ONNX_BUILD_PYTHON_BINDINGS "Build Python IR bindings" OFF)
option(MLX_ONNX_INSTALL_CPP_ARTIFACTS "Install C++ library and headers" ON)

include(FetchContent)

# The Python bindings are built from the bundled mlx tree, so the two options
# below are mutually exclusive.
if(MLX_ONNX_BUILD_PYTHON_BINDINGS AND MLX_ONNX_USE_EXTERNAL_MLX)
  message(FATAL_ERROR
    "MLX_ONNX_BUILD_PYTHON_BINDINGS requires bundled mlx sources; set MLX_ONNX_USE_EXTERNAL_MLX=OFF")
endif()
|
|
20
|
+
|
|
21
|
+
# Provide the `mlx` target that mlx_onnx links against: either an imported
# library taken from an external install, or the bundled sources under ./mlx.
if(MLX_ONNX_USE_EXTERNAL_MLX)
  set(MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR "" CACHE PATH "Path to MLX include root")
  set(MLX_ONNX_EXTERNAL_MLX_LIB_DIR "" CACHE PATH "Path to MLX library directory")

  # Both locations are mandatory in external mode.
  if(MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR STREQUAL "")
    message(FATAL_ERROR "MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR must be set when MLX_ONNX_USE_EXTERNAL_MLX=ON")
  endif()
  if(MLX_ONNX_EXTERNAL_MLX_LIB_DIR STREQUAL "")
    message(FATAL_ERROR "MLX_ONNX_EXTERNAL_MLX_LIB_DIR must be set when MLX_ONNX_USE_EXTERNAL_MLX=ON")
  endif()

  # Search only the directory the user supplied.
  find_library(
    MLX_EXTERNAL_LIBRARY
    NAMES mlx
    PATHS "${MLX_ONNX_EXTERNAL_MLX_LIB_DIR}"
    NO_DEFAULT_PATH)

  if(NOT MLX_EXTERNAL_LIBRARY)
    message(FATAL_ERROR "Could not find libmlx in ${MLX_ONNX_EXTERNAL_MLX_LIB_DIR}")
  endif()

  # find_library() may resolve to a shared object, a static archive or an
  # import library, so the imported target is declared UNKNOWN rather than
  # SHARED. Values are quoted so a list-valued path cannot split the
  # property/value pairs.
  add_library(mlx UNKNOWN IMPORTED GLOBAL)
  set_target_properties(
    mlx
    PROPERTIES
      IMPORTED_LOCATION "${MLX_EXTERNAL_LIBRARY}"
      INTERFACE_INCLUDE_DIRECTORIES "${MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR}")
else()
  # Fail with a clear message before add_subdirectory() trips over a missing
  # bundled tree.
  if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/mlx/CMakeLists.txt")
    message(FATAL_ERROR "Bundled mlx sources are missing at ${CMAKE_CURRENT_SOURCE_DIR}/mlx")
  endif()

  # Trim the bundled MLX build down to what mlx_onnx needs.
  set(MLX_BUILD_TESTS OFF CACHE BOOL "" FORCE)
  set(MLX_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
  set(MLX_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE)
  if(MLX_ONNX_BUILD_PYTHON_BINDINGS)
    set(MLX_BUILD_PYTHON_BINDINGS ON CACHE BOOL "" FORCE)
  else()
    set(MLX_BUILD_PYTHON_BINDINGS OFF CACHE BOOL "" FORCE)
  endif()
  set(MLX_BUILD_PYTHON_STUBS OFF CACHE BOOL "" FORCE)
  set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
  set(MLX_BUILD_SAFETENSORS OFF CACHE BOOL "" FORCE)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mlx)
endif()
|
|
62
|
+
|
|
63
|
+
# Reuse nlohmann_json when an enclosing project already defines the target;
# otherwise fetch the pinned release tag.
# NOTE(review): EXCLUDE_FROM_ALL in FetchContent_Declare() is documented as
# new in CMake 3.28, while this file declares a 3.25 minimum - confirm the
# behaviour on older CMake releases.
if(NOT TARGET nlohmann_json::nlohmann_json)
  FetchContent_Declare(nlohmann_json
    GIT_REPOSITORY https://github.com/nlohmann/json.git
    GIT_TAG v3.11.3
    EXCLUDE_FROM_ALL)
  FetchContent_MakeAvailable(nlohmann_json)
endif()
|
|
71
|
+
|
|
72
|
+
# The mlx_onnx library target; sources are attached separately so the list
# reads as one unit.
add_library(mlx_onnx)
target_sources(
  mlx_onnx
  PRIVATE src/export.cpp
          src/api.cpp
          src/compat.cpp
          src/io.cpp
          src/lowering.cpp
          src/mappings.cpp
          src/onnx.cpp
          src/shared.cpp)

set_target_properties(mlx_onnx PROPERTIES OUTPUT_NAME mlx_onnx)

# Public headers come from ./include (build tree) or <prefix>/include
# (install tree); ./src is visible to the library's own sources only.
target_include_directories(
  mlx_onnx
  PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
         $<INSTALL_INTERFACE:include>)
target_include_directories(mlx_onnx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)

# In external mode the MLX headers are also added explicitly.
if(MLX_ONNX_USE_EXTERNAL_MLX)
  target_include_directories(mlx_onnx PRIVATE ${MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR})
endif()

target_link_libraries(mlx_onnx PUBLIC mlx nlohmann_json::nlohmann_json)
|
|
96
|
+
|
|
97
|
+
# Python packaging: build the bindings and lay out the `mlx` and `mlx_onnx`
# packages under the `python` install component.
if(MLX_ONNX_BUILD_PYTHON_BINDINGS)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/python/src)

  # Prefer the checked-in package initializer; when it is absent, generate a
  # minimal one in the build tree.
  set(MLX_ONNX_PY_INIT_FILE ${CMAKE_CURRENT_SOURCE_DIR}/python/mlx_onnx/__init__.py)
  if(NOT EXISTS ${MLX_ONNX_PY_INIT_FILE})
    set(MLX_ONNX_PY_INIT_FILE ${CMAKE_CURRENT_BINARY_DIR}/mlx_onnx___init__.py)
    file(WRITE ${MLX_ONNX_PY_INIT_FILE} "from ._core import * # noqa: F401,F403\n")
  endif()

  # The bindings need both the bundled mlx tree and its `core` extension
  # target.
  if(NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/mlx/CMakeLists.txt)
    message(FATAL_ERROR "Bundled mlx sources are missing at ${CMAKE_CURRENT_SOURCE_DIR}/mlx")
  endif()
  if(NOT TARGET core)
    message(FATAL_ERROR "Bundled mlx Python extension target `core` was not built")
  endif()

  install(TARGETS core LIBRARY DESTINATION mlx COMPONENT python)

  if(APPLE AND MLX_BUILD_METAL)
    # MLX looks for mlx.metallib next to the extension module using the MLX
    # runtime; mlx_onnx._core also links MLX and resolves the same metallib,
    # so the file is installed into both packages.
    foreach(metallib_package IN ITEMS mlx mlx_onnx)
      install(
        FILES ${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib
        DESTINATION ${metallib_package}
        COMPONENT python)
    endforeach()
  endif()

  # Pure-Python part of the mlx package, minus bytecode caches.
  install(
    DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlx/python/mlx/
    DESTINATION mlx
    COMPONENT python
    PATTERN "__pycache__" EXCLUDE)

  # Package initializer selected (or generated) above.
  install(
    FILES ${MLX_ONNX_PY_INIT_FILE}
    DESTINATION mlx_onnx
    RENAME __init__.py
    COMPONENT python)

  # Vendored copy of the bundled mlx tree shipped inside the mlx_onnx package.
  set(MLX_ONNX_VENDOR_MLX_ROOT mlx_onnx/_vendor/mlx)
  install(
    FILES ${CMAKE_CURRENT_SOURCE_DIR}/mlx/CMakeLists.txt
          ${CMAKE_CURRENT_SOURCE_DIR}/mlx/mlx.pc.in
          ${CMAKE_CURRENT_SOURCE_DIR}/mlx/LICENSE
          ${CMAKE_CURRENT_SOURCE_DIR}/mlx/ACKNOWLEDGMENTS.md
    DESTINATION ${MLX_ONNX_VENDOR_MLX_ROOT}
    COMPONENT python)
  install(
    DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlx/cmake
              ${CMAKE_CURRENT_SOURCE_DIR}/mlx/mlx
    DESTINATION ${MLX_ONNX_VENDOR_MLX_ROOT}
    COMPONENT python)
endif()
|
|
147
|
+
|
|
148
|
+
# Install the C++ library and its public headers under the `cpp` component.
if(MLX_ONNX_INSTALL_CPP_ARTIFACTS)
  include(GNUInstallDirs)

  # COMPONENT is repeated for every artifact kind: options written after an
  # artifact keyword apply to that kind only, so a single trailing COMPONENT
  # placed after INCLUDES DESTINATION would not tag the library artifacts.
  install(
    TARGETS mlx_onnx
    LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} COMPONENT cpp
    ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} COMPONENT cpp
    RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} COMPONENT cpp
    INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})

  install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} COMPONENT cpp)
endif()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 MLX Contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <cstdint>
|
|
4
|
+
#include <functional>
|
|
5
|
+
#include <string>
|
|
6
|
+
#include <vector>
|
|
7
|
+
|
|
8
|
+
#include "mlx/api.h"
|
|
9
|
+
#include "mlx/array.h"
|
|
10
|
+
#include "mlx/export.h"
|
|
11
|
+
|
|
12
|
+
namespace mlx::onnx {
|
|
13
|
+
|
|
14
|
+
// Graph IR version constant; `inline` gives the header-defined constant a
// single definition shared by every translation unit that includes it.
inline constexpr int64_t kGraphIrVersion = 1;
|
|
15
|
+
|
|
16
|
+
using IrCaptureFunction = std::function<std::vector<mlx::core::array>(
|
|
17
|
+
const mlx::core::Args&,
|
|
18
|
+
const mlx::core::Kwargs&)>;
|
|
19
|
+
|
|
20
|
+
// Options passed to the ONNX binary build/write entry points declared in this
// header.
// NOTE(review): field meanings are inferred from their names only; confirm
// against the writer implementation.
struct OnnxBinaryWriteOptions {
  // Off by default.
  bool external_data{false};
  // Defaults to "weights.bin".
  std::string external_data_file{"weights.bin"};
  // Defaults to 1024 (presumably a byte count - TODO confirm).
  int64_t external_data_size_threshold{1024};
};
|
|
25
|
+
|
|
26
|
+
// Value type returned by build_onnx_binary_artifact_from_onnx_json and
// consumed by write_onnx_binary_artifact_to_path.
// NOTE(review): field meanings are inferred from their names only; confirm
// against the implementation.
struct OnnxBinaryArtifact {
  // Empty by default.
  std::string model_bytes;
  // Empty by default.
  std::string external_data_bytes;
  // False by default.
  bool has_external_data{false};
};
|
|
31
|
+
|
|
32
|
+
MLX_API std::string ir_to_onnx_json(
|
|
33
|
+
const std::string& ir_json,
|
|
34
|
+
int64_t opset,
|
|
35
|
+
const std::string& model_name);
|
|
36
|
+
|
|
37
|
+
MLX_API std::string ir_to_onnx(
|
|
38
|
+
const std::string& target_path,
|
|
39
|
+
const std::string& ir_json,
|
|
40
|
+
int64_t opset,
|
|
41
|
+
const std::string& model_name,
|
|
42
|
+
const OnnxBinaryWriteOptions& options);
|
|
43
|
+
|
|
44
|
+
MLX_API std::string ir_compatibility_report_json(
|
|
45
|
+
const std::string& ir_json);
|
|
46
|
+
|
|
47
|
+
MLX_API std::string export_ir_json(
|
|
48
|
+
const IrCaptureFunction& fun,
|
|
49
|
+
const mlx::core::Args& args,
|
|
50
|
+
const mlx::core::Kwargs& kwargs,
|
|
51
|
+
bool shapeless);
|
|
52
|
+
|
|
53
|
+
MLX_API std::string export_onnx_compatibility_report_json(
|
|
54
|
+
const IrCaptureFunction& fun,
|
|
55
|
+
const mlx::core::Args& args,
|
|
56
|
+
const mlx::core::Kwargs& kwargs,
|
|
57
|
+
bool shapeless);
|
|
58
|
+
|
|
59
|
+
MLX_API std::string export_onnx_json(
|
|
60
|
+
const IrCaptureFunction& fun,
|
|
61
|
+
const mlx::core::Args& args,
|
|
62
|
+
const mlx::core::Kwargs& kwargs,
|
|
63
|
+
bool shapeless,
|
|
64
|
+
int64_t opset,
|
|
65
|
+
const std::string& model_name);
|
|
66
|
+
|
|
67
|
+
MLX_API std::string export_onnx(
|
|
68
|
+
const std::string& target_path,
|
|
69
|
+
const IrCaptureFunction& fun,
|
|
70
|
+
const mlx::core::Args& args,
|
|
71
|
+
const mlx::core::Kwargs& kwargs,
|
|
72
|
+
bool shapeless,
|
|
73
|
+
int64_t opset,
|
|
74
|
+
const std::string& model_name,
|
|
75
|
+
const OnnxBinaryWriteOptions& options);
|
|
76
|
+
|
|
77
|
+
MLX_API OnnxBinaryArtifact build_onnx_binary_artifact_from_onnx_json(
|
|
78
|
+
const std::string& onnx_json,
|
|
79
|
+
const OnnxBinaryWriteOptions& options);
|
|
80
|
+
|
|
81
|
+
MLX_API std::string write_onnx_binary_artifact_to_path(
|
|
82
|
+
const std::string& target_path,
|
|
83
|
+
const OnnxBinaryArtifact& artifact,
|
|
84
|
+
const OnnxBinaryWriteOptions& options);
|
|
85
|
+
|
|
86
|
+
MLX_API bool ir_is_unsupported_error_message(const std::string& message);
|
|
87
|
+
|
|
88
|
+
} // namespace mlx::onnx
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
#include "detail.hpp"
|
|
2
|
+
#include "json.hpp"
|
|
3
|
+
|
|
4
|
+
#include <sstream>
|
|
5
|
+
#include <stdexcept>
|
|
6
|
+
|
|
7
|
+
namespace mlx::onnx {
|
|
8
|
+
namespace {
|
|
9
|
+
|
|
10
|
+
// Parses `raw` as JSON and returns the resulting document.
//
// Any std::exception raised by the parser is rethrown as
// std::invalid_argument whose text is "failed to parse <label>: <reason>"
// passed through detail::tagged_error_message with the "ir.api" tag.
OrderedJson parse_json_payload_from_string(
    const std::string& raw,
    const char* label) {
  try {
    return OrderedJson::parse(raw);
  } catch (const std::exception& parse_failure) {
    std::ostringstream description;
    description << "failed to parse " << label << ": "
                << parse_failure.what();
    const std::string tagged =
        detail::tagged_error_message("ir.api", description.str());
    throw std::invalid_argument(tagged);
  }
}
|
|
22
|
+
|
|
23
|
+
template <typename Result, typename Callable>
|
|
24
|
+
Result with_ir_api_error_tag(Callable&& callable) {
|
|
25
|
+
try {
|
|
26
|
+
return callable();
|
|
27
|
+
} catch (const std::exception& error) {
|
|
28
|
+
throw std::runtime_error(
|
|
29
|
+
detail::tagged_error_message("ir.api", error.what()));
|
|
30
|
+
}
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
std::string ir_to_onnx_json_impl(
|
|
34
|
+
const std::string& ir_json,
|
|
35
|
+
int64_t opset,
|
|
36
|
+
const std::string& model_name) {
|
|
37
|
+
const auto payload =
|
|
38
|
+
parse_json_payload_from_string(ir_json, "graph ir json");
|
|
39
|
+
return ir_to_onnx_json_payload(payload, opset, model_name).dump();
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
std::string ir_compatibility_report_json_impl(
|
|
43
|
+
const std::string& ir_json) {
|
|
44
|
+
const auto payload =
|
|
45
|
+
parse_json_payload_from_string(ir_json, "graph ir json");
|
|
46
|
+
return ir_compatibility_report_payload(payload).dump();
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
// Parses `onnx_json` (reported as "onnx json" on parse failure) and hands the
// document plus `options` to build_onnx_binary_artifact_from_stub.
OnnxBinaryArtifact build_onnx_binary_artifact_from_onnx_json_impl(
    const std::string& onnx_json,
    const OnnxBinaryWriteOptions& options) {
  const OrderedJson onnx_stub =
      parse_json_payload_from_string(onnx_json, "onnx json");
  return build_onnx_binary_artifact_from_stub(onnx_stub, options);
}
|
|
55
|
+
|
|
56
|
+
} // namespace
|
|
57
|
+
|
|
58
|
+
std::string ir_to_onnx_json(
|
|
59
|
+
const std::string& ir_json,
|
|
60
|
+
int64_t opset,
|
|
61
|
+
const std::string& model_name) {
|
|
62
|
+
return with_ir_api_error_tag<std::string>([&]() {
|
|
63
|
+
return ir_to_onnx_json_impl(ir_json, opset, model_name);
|
|
64
|
+
});
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
std::string ir_compatibility_report_json(
|
|
68
|
+
const std::string& ir_json) {
|
|
69
|
+
return with_ir_api_error_tag<std::string>(
|
|
70
|
+
[&]() { return ir_compatibility_report_json_impl(ir_json); });
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
// Public entry point: runs build_onnx_binary_artifact_from_onnx_json_impl
// under with_ir_api_error_tag, so std::exception failures surface as
// std::runtime_error tagged "ir.api".
OnnxBinaryArtifact build_onnx_binary_artifact_from_onnx_json(
    const std::string& onnx_json,
    const OnnxBinaryWriteOptions& options) {
  const auto build_artifact = [&]() {
    return build_onnx_binary_artifact_from_onnx_json_impl(onnx_json, options);
  };
  return with_ir_api_error_tag<OnnxBinaryArtifact>(build_artifact);
}
|
|
80
|
+
|
|
81
|
+
} // namespace mlx::onnx
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
#include "detail.hpp"
|
|
2
|
+
|
|
3
|
+
#include <exception>
|
|
4
|
+
#include <limits>
|
|
5
|
+
#include <optional>
|
|
6
|
+
#include <set>
|
|
7
|
+
#include <utility>
|
|
8
|
+
|
|
9
|
+
namespace mlx::onnx::detail {
|
|
10
|
+
|
|
11
|
+
// Builds the ONNX compatibility report for a graph IR payload.
//
// Walks payload["nodes"] in order and probes each node with
// lower_onnx_node_default against copies of the accumulated lowering state.
// A probe that does not throw is committed, so later nodes see the
// initializers, tensor names, shapes and dtypes it produced, and the node is
// reported as supported. A probe that throws std::exception leaves the
// accumulated state untouched and the node is reported as unsupported.
//
// Returns a JSON object with the keys "format", "ir_version", "total_nodes",
// "supported_nodes", "unsupported_nodes", "unsupported_ops",
// "ready_for_stub_conversion" and "nodes" (one entry per input node with
// "index", "op", "supported" and "onnx_op_type").
//
// Exceptions from payload.at("nodes") and node.at("op") are not caught here
// and propagate to the caller.
OrderedJson ir_compatibility_report_payload_impl(
    const OrderedJson& payload) {
  // Simulate lowering node-by-node on cloned state. Successful probes commit
  // inferred state so later nodes see realistic tensor facts.
  OrderedJson probe_initializers = OrderedJson::array();
  auto probe_used_tensor_names = collect_payload_tensor_names(payload);
  auto probe_known_shapes = collect_known_tensor_shapes(payload);
  auto probe_known_dtypes = collect_known_tensor_dtypes(payload);

  OrderedJson node_support = OrderedJson::array();
  size_t unsupported_nodes = 0;
  // std::set keeps the reported op names unique and sorted.
  std::set<std::string> unsupported_ops;

  const auto& source_nodes = payload.at("nodes");
  for (size_t index = 0; index < source_nodes.size(); ++index) {
    const auto& node = source_nodes.at(index);
    const auto op = node.at("op").get<std::string>();

    bool supported = false;
    std::optional<std::string> mapped;

    try {
      // Work on copies so a failed probe cannot leak partial state.
      auto trial_initializers = probe_initializers;
      auto trial_used_tensor_names = probe_used_tensor_names;
      auto trial_known_shapes = probe_known_shapes;
      auto trial_known_dtypes = probe_known_dtypes;
      LoweringContext trial_lowering{
          trial_initializers,
          trial_used_tensor_names,
          trial_known_shapes,
          trial_known_dtypes};

      auto lowered = lower_onnx_node_default(node, index, trial_lowering);
      // Prefer the op type of the first lowered node; otherwise fall back to
      // the non-strict mapping lookup.
      if (!lowered.empty() && lowered.front().contains("op_type") &&
          lowered.front().at("op_type").is_string()) {
        mapped = lowered.front().at("op_type").get<std::string>();
      } else {
        mapped = onnx_op_type_for_node(node, false, &trial_known_shapes);
      }

      // Commit the probe.
      probe_initializers = std::move(trial_initializers);
      probe_used_tensor_names = std::move(trial_used_tensor_names);
      probe_known_shapes = std::move(trial_known_shapes);
      probe_known_dtypes = std::move(trial_known_dtypes);
      supported = true;
    } catch (const std::exception&) {
      // Lowering failed: still try to report which ONNX op the node would
      // map to, using the last committed shapes.
      try {
        mapped = onnx_op_type_for_node(node, false, &probe_known_shapes);
      } catch (const std::exception&) {
        mapped = std::nullopt;
      }
    }

    OrderedJson entry = OrderedJson::object();
    entry["index"] = index;
    entry["op"] = op;
    entry["supported"] = supported;
    if (mapped.has_value()) {
      entry["onnx_op_type"] = mapped.value();
    } else {
      entry["onnx_op_type"] = nullptr;
    }
    node_support.push_back(std::move(entry));

    if (!supported) {
      ++unsupported_nodes;
      unsupported_ops.insert(op);
    }
  }

  // Echo the payload's ir_version when it is an integer representable as
  // int64_t; otherwise keep the library default.
  int64_t ir_version = kGraphIrVersion;
  if (payload.contains("ir_version")) {
    const auto& value = payload.at("ir_version");
    // Test the unsigned case first. Assuming OrderedJson is an
    // nlohmann::basic_json specialization, is_number_integer() is also true
    // for unsigned values, so checking it first would make the unsigned
    // branch unreachable and bypass the int64_t range guard.
    if (value.is_number_unsigned()) {
      const auto raw = value.get<uint64_t>();
      if (raw <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
        ir_version = static_cast<int64_t>(raw);
      }
    } else if (value.is_number_integer()) {
      ir_version = value.get<int64_t>();
    }
  }

  OrderedJson unsupported_ops_json = OrderedJson::array();
  for (const auto& unsupported_op : unsupported_ops) {
    unsupported_ops_json.push_back(unsupported_op);
  }

  OrderedJson report = OrderedJson::object();
  report["format"] = "webgpu_compat_report_v1";
  report["ir_version"] = ir_version;
  report["total_nodes"] = source_nodes.size();
  report["supported_nodes"] = source_nodes.size() - unsupported_nodes;
  report["unsupported_nodes"] = unsupported_nodes;
  report["unsupported_ops"] = std::move(unsupported_ops_json);
  report["ready_for_stub_conversion"] = unsupported_nodes == 0;
  report["nodes"] = std::move(node_support);
  return report;
}
|
|
110
|
+
|
|
111
|
+
} // namespace mlx::onnx::detail
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <cstddef>
|
|
4
|
+
#include <cstdint>
|
|
5
|
+
#include <functional>
|
|
6
|
+
#include <map>
|
|
7
|
+
#include <optional>
|
|
8
|
+
#include <set>
|
|
9
|
+
#include <string>
|
|
10
|
+
#include <vector>
|
|
11
|
+
|
|
12
|
+
#include "json.hpp"
|
|
13
|
+
|
|
14
|
+
namespace mlx::onnx::detail {
|
|
15
|
+
|
|
16
|
+
using Shape = std::vector<int64_t>;
|
|
17
|
+
using ShapeMap = std::map<std::string, Shape>;
|
|
18
|
+
using DtypeMap = std::map<std::string, std::string>;
|
|
19
|
+
using NameSet = std::set<std::string>;
|
|
20
|
+
|
|
21
|
+
std::string tagged_error_message(
|
|
22
|
+
const std::string& tag,
|
|
23
|
+
const std::string& message);
|
|
24
|
+
bool json_is_numeric(const OrderedJson& value);
|
|
25
|
+
int64_t normalized_integer_scalar(
|
|
26
|
+
const OrderedJson& value,
|
|
27
|
+
const std::string& label);
|
|
28
|
+
std::vector<int64_t> normalize_integer_vector(
|
|
29
|
+
const OrderedJson& value,
|
|
30
|
+
const std::string& label);
|
|
31
|
+
std::vector<std::string> parse_string_array(
|
|
32
|
+
const OrderedJson& value,
|
|
33
|
+
const std::string& label);
|
|
34
|
+
std::string canonical_dtype(const std::string& dtype);
|
|
35
|
+
std::optional<std::string> canonical_dtype(
|
|
36
|
+
const std::optional<std::string>& dtype);
|
|
37
|
+
std::optional<std::string> onnx_effective_dtype(
|
|
38
|
+
const std::optional<std::string>& dtype);
|
|
39
|
+
std::string onnx_effective_dtype(const std::string& dtype);
|
|
40
|
+
std::string onnx_dtype_symbol(const std::string& dtype);
|
|
41
|
+
void for_each_declared_payload_tensor(
|
|
42
|
+
const OrderedJson& payload,
|
|
43
|
+
const std::function<void(const OrderedJson&)>& visitor);
|
|
44
|
+
|
|
45
|
+
// Bundle of references to the mutable state that lower_onnx_node_default
// receives while lowering a node. Every member is a reference, so the context
// does not own the state and the referenced objects must outlive it. Members
// are initialized positionally by callers, so their order is part of the
// interface.
struct LoweringContext {
|
|
46
|
+
OrderedJson& initializers;
|
|
47
|
+
NameSet& used_tensor_names;
|
|
48
|
+
ShapeMap& known_shapes;
|
|
49
|
+
DtypeMap& known_dtypes;
|
|
50
|
+
};
|
|
51
|
+
|
|
52
|
+
NameSet collect_payload_tensor_names(const OrderedJson& payload);
|
|
53
|
+
ShapeMap collect_known_tensor_shapes(const OrderedJson& payload);
|
|
54
|
+
DtypeMap collect_known_tensor_dtypes(const OrderedJson& payload);
|
|
55
|
+
std::optional<std::string> onnx_op_type_for_node(
|
|
56
|
+
const OrderedJson& node,
|
|
57
|
+
bool strict,
|
|
58
|
+
const ShapeMap* known_shapes);
|
|
59
|
+
std::vector<OrderedJson> lower_onnx_node_default(
|
|
60
|
+
const OrderedJson& node,
|
|
61
|
+
size_t node_index,
|
|
62
|
+
LoweringContext& lowering);
|
|
63
|
+
OrderedJson ir_compatibility_report_payload_impl(
|
|
64
|
+
const OrderedJson& payload);
|
|
65
|
+
OnnxBinaryArtifact build_onnx_binary_artifact_from_stub_impl(
|
|
66
|
+
const OrderedJson& onnx_stub,
|
|
67
|
+
const OnnxBinaryWriteOptions& options);
|
|
68
|
+
|
|
69
|
+
} // namespace mlx::onnx::detail
|