mlx 0.30.7.2 → 0.30.7.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/extconf.rb +267 -8
- data/ext/mlx/native.cpp +112 -58
- data/ext/mlx-onnx/native.cpp +1402 -0
- data/ext/mlx-onnx/native.hpp +19 -0
- data/lib/mlx/core.rb +342 -117
- data/lib/mlx/distributed_utils/common.rb +1 -1
- data/lib/mlx/distributed_utils/config.rb +7 -4
- data/lib/mlx/distributed_utils/launch.rb +2 -0
- data/lib/mlx/dsl/attention.rb +132 -0
- data/lib/mlx/dsl/builder.rb +8 -0
- data/lib/mlx/dsl/config_schema.rb +133 -0
- data/lib/mlx/dsl/generate.rb +193 -0
- data/lib/mlx/dsl/kv_cache.rb +96 -0
- data/lib/mlx/dsl/masks.rb +32 -0
- data/lib/mlx/dsl/positions.rb +35 -0
- data/lib/mlx/dsl/run_stack.rb +68 -0
- data/lib/mlx/dsl/tensor.rb +126 -0
- data/lib/mlx/dsl/transformer_block.rb +113 -0
- data/lib/mlx/dsl/weight_map.rb +140 -0
- data/lib/mlx/dsl.rb +10 -0
- data/lib/mlx/nn/base.rb +4 -0
- data/lib/mlx/nn/layers/linear.rb +2 -3
- data/lib/mlx/onnx.rb +250 -0
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx-onnx/webgpu_harness.rb +289 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
- data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
- data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
- data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
- data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
- data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
- data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
- data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
- data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
- data/submodules/mlx-onnx/CMakeLists.txt +159 -0
- data/submodules/mlx-onnx/LICENSE +21 -0
- data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
- data/submodules/mlx-onnx/src/api.cpp +81 -0
- data/submodules/mlx-onnx/src/compat.cpp +111 -0
- data/submodules/mlx-onnx/src/detail.hpp +69 -0
- data/submodules/mlx-onnx/src/export.cpp +653 -0
- data/submodules/mlx-onnx/src/io.cpp +61 -0
- data/submodules/mlx-onnx/src/json.hpp +25 -0
- data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
- data/submodules/mlx-onnx/src/mappings.cpp +201 -0
- data/submodules/mlx-onnx/src/mappings.hpp +16 -0
- data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
- data/submodules/mlx-onnx/src/shared.cpp +206 -0
- metadata +665 -567
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
- /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
- /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
|
@@ -0,0 +1,1029 @@
|
|
|
1
|
+
#include "detail.hpp"
|
|
2
|
+
#include "mappings.hpp"
|
|
3
|
+
|
|
4
|
+
#include <array>
|
|
5
|
+
#include <bit>
|
|
6
|
+
#include <cctype>
|
|
7
|
+
#include <cerrno>
|
|
8
|
+
#include <cmath>
|
|
9
|
+
#include <cstdint>
|
|
10
|
+
#include <cstdlib>
|
|
11
|
+
#include <cstring>
|
|
12
|
+
#include <sstream>
|
|
13
|
+
#include <stdexcept>
|
|
14
|
+
#include <string>
|
|
15
|
+
#include <string_view>
|
|
16
|
+
#include <utility>
|
|
17
|
+
#include <vector>
|
|
18
|
+
|
|
19
|
+
namespace mlx::onnx::detail {
|
|
20
|
+
|
|
21
|
+
struct OnnxAttributeModel {
|
|
22
|
+
std::string name;
|
|
23
|
+
OrderedJson value;
|
|
24
|
+
};
|
|
25
|
+
|
|
26
|
+
struct OnnxNodeModel {
|
|
27
|
+
std::string name;
|
|
28
|
+
std::string op_type;
|
|
29
|
+
std::vector<std::string> inputs;
|
|
30
|
+
std::vector<std::string> outputs;
|
|
31
|
+
std::vector<OnnxAttributeModel> attributes;
|
|
32
|
+
};
|
|
33
|
+
|
|
34
|
+
struct OnnxValueInfoModel {
|
|
35
|
+
std::string name;
|
|
36
|
+
std::vector<int64_t> shape;
|
|
37
|
+
std::string dtype;
|
|
38
|
+
int elem_type = 0;
|
|
39
|
+
};
|
|
40
|
+
|
|
41
|
+
struct OnnxInitializerModel {
|
|
42
|
+
std::string name;
|
|
43
|
+
std::vector<int64_t> shape;
|
|
44
|
+
std::string dtype;
|
|
45
|
+
int elem_type = 0;
|
|
46
|
+
OrderedJson values;
|
|
47
|
+
};
|
|
48
|
+
|
|
49
|
+
struct OnnxGraphModel {
|
|
50
|
+
std::string name;
|
|
51
|
+
std::vector<OnnxNodeModel> nodes;
|
|
52
|
+
std::vector<OnnxInitializerModel> initializers;
|
|
53
|
+
std::vector<OnnxValueInfoModel> inputs;
|
|
54
|
+
std::vector<OnnxValueInfoModel> outputs;
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
struct OnnxStubModel {
|
|
58
|
+
int64_t opset = 0;
|
|
59
|
+
std::string producer_name;
|
|
60
|
+
OnnxGraphModel graph;
|
|
61
|
+
};
|
|
62
|
+
|
|
63
|
+
enum class PbWireType : uint8_t {
|
|
64
|
+
kVarint = 0,
|
|
65
|
+
kFixed64 = 1,
|
|
66
|
+
kLengthDelimited = 2,
|
|
67
|
+
kFixed32 = 5,
|
|
68
|
+
};
|
|
69
|
+
|
|
70
|
+
OnnxStubModel onnx_stub_model_from_json(const OrderedJson& onnx_stub);
|
|
71
|
+
|
|
72
|
+
void pb_write_varint(std::string& out, uint64_t value);
|
|
73
|
+
void pb_write_key(std::string& out, int field_number, PbWireType wire_type);
|
|
74
|
+
void pb_write_varint_field(std::string& out, int field_number, uint64_t value);
|
|
75
|
+
void pb_write_int64_field(std::string& out, int field_number, int64_t value);
|
|
76
|
+
void pb_write_string_field(
|
|
77
|
+
std::string& out,
|
|
78
|
+
int field_number,
|
|
79
|
+
const std::string& value);
|
|
80
|
+
void pb_write_bytes_field(
|
|
81
|
+
std::string& out,
|
|
82
|
+
int field_number,
|
|
83
|
+
const std::string& value);
|
|
84
|
+
void pb_write_message_field(
|
|
85
|
+
std::string& out,
|
|
86
|
+
int field_number,
|
|
87
|
+
const std::string& message);
|
|
88
|
+
void pb_write_fixed32_field(std::string& out, int field_number, uint32_t value);
|
|
89
|
+
|
|
90
|
+
std::string pb_encode_tensor(
|
|
91
|
+
const OnnxInitializerModel& tensor,
|
|
92
|
+
const OnnxBinaryWriteOptions& options,
|
|
93
|
+
std::string& external_data,
|
|
94
|
+
uint64_t& external_offset,
|
|
95
|
+
bool& has_external_data);
|
|
96
|
+
|
|
97
|
+
} // namespace mlx::onnx::detail
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
namespace mlx::onnx::detail {
|
|
103
|
+
namespace {
|
|
104
|
+
|
|
105
|
+
std::vector<int64_t> parse_shape(
|
|
106
|
+
const OrderedJson& shape,
|
|
107
|
+
const std::string& label) {
|
|
108
|
+
if (!shape.is_array()) {
|
|
109
|
+
throw std::invalid_argument(label + " shape must be an Array");
|
|
110
|
+
}
|
|
111
|
+
return normalize_integer_vector(shape, label + " shape dim");
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
int parse_elem_type_for_dtype(
|
|
115
|
+
const std::string& dtype,
|
|
116
|
+
const std::string& label) {
|
|
117
|
+
const auto symbol = onnx_dtype_symbol(onnx_effective_dtype(dtype));
|
|
118
|
+
try {
|
|
119
|
+
return onnx_elem_type_from_symbol_lookup(symbol);
|
|
120
|
+
} catch (const std::exception& error) {
|
|
121
|
+
throw std::invalid_argument(
|
|
122
|
+
label + " unsupported dtype " + dtype + ": " + error.what());
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
int parse_elem_type_from_value_info(const OrderedJson& info) {
|
|
127
|
+
if (info.contains("onnx_elem_type") &&
|
|
128
|
+
info.at("onnx_elem_type").is_string()) {
|
|
129
|
+
return onnx_elem_type_from_symbol_lookup(
|
|
130
|
+
info.at("onnx_elem_type").get<std::string>());
|
|
131
|
+
}
|
|
132
|
+
return parse_elem_type_for_dtype(
|
|
133
|
+
info.at("dtype").get<std::string>(), "value_info");
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
// Parses one graph input/output entry. Fields are read in a fixed order:
// name, shape, dtype, then element type.
OnnxValueInfoModel parse_value_info(const OrderedJson& info) {
  if (!info.is_object()) {
    throw std::invalid_argument("value_info must be an Object");
  }
  const OrderedJson& name_json = info.at("name");
  OnnxValueInfoModel parsed;
  parsed.name = name_json.get<std::string>();
  parsed.shape = parse_shape(info.at("shape"), "value_info");
  const OrderedJson& dtype_json = info.at("dtype");
  parsed.dtype = dtype_json.get<std::string>();
  parsed.elem_type = parse_elem_type_from_value_info(info);
  return parsed;
}
|
|
147
|
+
|
|
148
|
+
// Wraps an attribute's name and JSON value; the value is copied unchanged.
OnnxAttributeModel parse_attribute(
    const std::string& name,
    const OrderedJson& value) {
  return OnnxAttributeModel{name, value};
}
|
|
156
|
+
|
|
157
|
+
// Parses a node object. Attributes come from a JSON object and are kept in
// that object's iteration order.
OnnxNodeModel parse_node(const OrderedJson& node) {
  if (!node.is_object()) {
    throw std::invalid_argument("node must be an Object");
  }

  OnnxNodeModel parsed;
  parsed.name = node.at("name").get<std::string>();
  parsed.op_type = node.at("op_type").get<std::string>();
  parsed.inputs = parse_string_array(node.at("inputs"), "node inputs");
  parsed.outputs = parse_string_array(node.at("outputs"), "node outputs");

  const OrderedJson& attribute_map = node.at("attributes");
  if (!attribute_map.is_object()) {
    throw std::invalid_argument("node attributes must be an Object");
  }
  parsed.attributes.reserve(attribute_map.size());
  for (auto entry = attribute_map.begin(), last = attribute_map.end();
       entry != last;
       ++entry) {
    parsed.attributes.push_back(parse_attribute(entry.key(), entry.value()));
  }
  return parsed;
}
|
|
175
|
+
|
|
176
|
+
// Parses a constant tensor entry. The element values are kept as raw JSON;
// encoding happens later.
OnnxInitializerModel parse_initializer(const OrderedJson& initializer) {
  if (!initializer.is_object()) {
    throw std::invalid_argument("initializer must be an Object");
  }
  OnnxInitializerModel tensor;
  tensor.name = initializer.at("name").get<std::string>();
  tensor.shape = parse_shape(initializer.at("shape"), "initializer");
  tensor.dtype = initializer.at("dtype").get<std::string>();
  // Element type is derived from the dtype string that was just stored.
  tensor.elem_type = parse_elem_type_for_dtype(tensor.dtype, "initializer");
  tensor.values = initializer.at("values");
  return tensor;
}
|
|
188
|
+
|
|
189
|
+
// Parses the graph object.
// Sections are validated and parsed in this order: name, nodes,
// initializers, inputs, outputs. Each list section must be a JSON array.
OnnxGraphModel parse_graph(const OrderedJson& graph) {
  if (!graph.is_object()) {
    throw std::invalid_argument("onnx stub graph must be an Object");
  }

  // Fetches graph[key]; throws `error_message` unless it is an array.
  const auto require_array = [&graph](
                                 const char* key,
                                 const char* error_message)
      -> const OrderedJson& {
    const OrderedJson& section = graph.at(key);
    if (!section.is_array()) {
      throw std::invalid_argument(error_message);
    }
    return section;
  };

  OnnxGraphModel parsed;
  parsed.name = graph.at("name").get<std::string>();

  const OrderedJson& node_list =
      require_array("nodes", "onnx stub graph nodes must be an Array");
  parsed.nodes.reserve(node_list.size());
  for (const auto& entry : node_list) {
    parsed.nodes.push_back(parse_node(entry));
  }

  const OrderedJson& initializer_list = require_array(
      "initializers", "onnx stub graph initializers must be an Array");
  parsed.initializers.reserve(initializer_list.size());
  for (const auto& entry : initializer_list) {
    parsed.initializers.push_back(parse_initializer(entry));
  }

  const OrderedJson& input_list =
      require_array("inputs", "onnx stub graph inputs must be an Array");
  parsed.inputs.reserve(input_list.size());
  for (const auto& entry : input_list) {
    parsed.inputs.push_back(parse_value_info(entry));
  }

  const OrderedJson& output_list =
      require_array("outputs", "onnx stub graph outputs must be an Array");
  parsed.outputs.reserve(output_list.size());
  for (const auto& entry : output_list) {
    parsed.outputs.push_back(parse_value_info(entry));
  }

  return parsed;
}
|
|
235
|
+
|
|
236
|
+
} // namespace
|
|
237
|
+
|
|
238
|
+
// Entry point: validates the top-level stub object and builds the model.
// "producer_name" falls back to "mlx-ruby" when it is absent or not a string.
OnnxStubModel onnx_stub_model_from_json(const OrderedJson& onnx_stub) {
  if (!onnx_stub.is_object()) {
    throw std::invalid_argument("onnx stub must be a JSON object");
  }
  const bool has_graph_object =
      onnx_stub.contains("graph") && onnx_stub.at("graph").is_object();
  if (!has_graph_object) {
    throw std::invalid_argument("onnx stub must include graph object");
  }

  OnnxStubModel model;
  model.opset =
      normalized_integer_scalar(onnx_stub.at("opset"), "onnx_stub opset");
  model.producer_name = "mlx-ruby";
  if (onnx_stub.contains("producer_name") &&
      onnx_stub.at("producer_name").is_string()) {
    model.producer_name = onnx_stub.at("producer_name").get<std::string>();
  }
  model.graph = parse_graph(onnx_stub.at("graph"));
  return model;
}
|
|
256
|
+
|
|
257
|
+
} // namespace mlx::onnx::detail
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
namespace mlx::onnx::detail {
|
|
262
|
+
|
|
263
|
+
// Appends `value` as a protobuf base-128 varint. Seven bits go into each byte,
// least-significant group first; the high bit is set on every byte except the
// last.
void pb_write_varint(std::string& out, uint64_t value) {
  for (;;) {
    const auto low7 = static_cast<unsigned char>(value & 0x7fU);
    value >>= 7;
    if (value == 0) {
      out.push_back(static_cast<char>(low7));
      return;
    }
    out.push_back(static_cast<char>(low7 | 0x80U));
  }
}
|
|
270
|
+
|
|
271
|
+
// Appends a protobuf field key, i.e. (field_number << 3) | wire_type, as a
// varint.
void pb_write_key(std::string& out, int field_number, PbWireType wire_type) {
  const auto field_bits = static_cast<uint64_t>(field_number) << 3;
  const auto wire_bits = static_cast<uint64_t>(wire_type);
  pb_write_varint(out, field_bits | wire_bits);
}
|
|
276
|
+
|
|
277
|
+
// Appends a complete varint field: the key (wire type 0), then the payload.
void pb_write_varint_field(std::string& out, int field_number, uint64_t value) {
  constexpr auto kWire = PbWireType::kVarint;
  pb_write_key(out, field_number, kWire);
  pb_write_varint(out, value);
}
|
|
281
|
+
|
|
282
|
+
// Appends a protobuf int64 field. The value is reinterpreted as unsigned
// (two's complement), so a negative number is written as a 10-byte varint.
void pb_write_int64_field(std::string& out, int field_number, int64_t value) {
  const auto as_unsigned = static_cast<uint64_t>(value);
  pb_write_varint_field(out, field_number, as_unsigned);
}
|
|
285
|
+
|
|
286
|
+
void pb_write_string_field(
|
|
287
|
+
std::string& out,
|
|
288
|
+
int field_number,
|
|
289
|
+
const std::string& value) {
|
|
290
|
+
pb_write_key(out, field_number, PbWireType::kLengthDelimited);
|
|
291
|
+
pb_write_varint(out, static_cast<uint64_t>(value.size()));
|
|
292
|
+
out.append(value);
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
void pb_write_bytes_field(
|
|
296
|
+
std::string& out,
|
|
297
|
+
int field_number,
|
|
298
|
+
const std::string& value) {
|
|
299
|
+
pb_write_string_field(out, field_number, value);
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
void pb_write_message_field(
|
|
303
|
+
std::string& out,
|
|
304
|
+
int field_number,
|
|
305
|
+
const std::string& message) {
|
|
306
|
+
pb_write_key(out, field_number, PbWireType::kLengthDelimited);
|
|
307
|
+
pb_write_varint(out, static_cast<uint64_t>(message.size()));
|
|
308
|
+
out.append(message);
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
void pb_write_fixed32_field(
|
|
312
|
+
std::string& out,
|
|
313
|
+
int field_number,
|
|
314
|
+
uint32_t value) {
|
|
315
|
+
pb_write_key(out, field_number, PbWireType::kFixed32);
|
|
316
|
+
std::array<char, 4> bytes = {
|
|
317
|
+
static_cast<char>(value & 0xffU),
|
|
318
|
+
static_cast<char>((value >> 8) & 0xffU),
|
|
319
|
+
static_cast<char>((value >> 16) & 0xffU),
|
|
320
|
+
static_cast<char>((value >> 24) & 0xffU)};
|
|
321
|
+
out.append(bytes.data(), bytes.size());
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
} // namespace mlx::onnx::detail
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
namespace mlx::onnx::detail {
|
|
329
|
+
namespace {
|
|
330
|
+
|
|
331
|
+
bool json_integer_like(const OrderedJson& value) {
|
|
332
|
+
if (value.is_number_integer() || value.is_number_unsigned()) {
|
|
333
|
+
return true;
|
|
334
|
+
}
|
|
335
|
+
if (value.is_number_float()) {
|
|
336
|
+
const double v = value.get<double>();
|
|
337
|
+
return std::isfinite(v) && std::trunc(v) == v;
|
|
338
|
+
}
|
|
339
|
+
return false;
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
// Returns the number of elements in a tensor with shape `dims`.
// An empty shape is a scalar (1 element); any zero dim gives 0.
// Throws std::invalid_argument if a dim is negative, or if the product does
// not fit in size_t. Without the overflow check, a shape such as
// {2^32, 2^32} would wrap to 0 and silently match an empty value list.
size_t expected_initializer_value_count(const std::vector<int64_t>& dims) {
  // Validate every dim first so a negative dim is always reported.
  bool has_zero_dim = false;
  for (const auto dim : dims) {
    if (dim < 0) {
      throw std::invalid_argument(
          "initializer shape values must be non-negative");
    }
    has_zero_dim = has_zero_dim || dim == 0;
  }
  // A zero dim makes the product 0 regardless of intermediate overflow.
  if (has_zero_dim) {
    return 0;
  }

  size_t total = 1;
  for (const auto dim : dims) {
    const auto udim = static_cast<uint64_t>(dim);
    if (udim > static_cast<uint64_t>(SIZE_MAX / total)) {
      throw std::invalid_argument(
          "initializer shape element count overflows size_t");
    }
    total *= static_cast<size_t>(udim);
  }
  return total;
}
|
|
356
|
+
|
|
357
|
+
void collect_initializer_leaves(
|
|
358
|
+
const OrderedJson& value,
|
|
359
|
+
std::vector<const OrderedJson*>& out) {
|
|
360
|
+
if (value.is_array()) {
|
|
361
|
+
for (const auto& item : value) {
|
|
362
|
+
collect_initializer_leaves(item, out);
|
|
363
|
+
}
|
|
364
|
+
return;
|
|
365
|
+
}
|
|
366
|
+
out.push_back(&value);
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
// Converts an IEEE-754 binary32 value to binary16 (half precision) bits.
// - Rounds to nearest, ties to even (the IEEE default rounding).
// - Magnitudes below half the smallest subnormal (2^-25 and under) become
//   signed zero.
// - Overflow saturates to signed infinity.
// - NaN maps to a quiet NaN. Previously it fell into the overflow branch and
//   came out as infinity.
uint16_t float32_to_float16_bits(float value) {
  uint32_t bits = 0;
  std::memcpy(&bits, &value, sizeof(bits));

  const uint32_t sign = (bits >> 16) & 0x8000U;
  const uint32_t magnitude = bits & 0x7fffffffU;

  // Inf/NaN: handle before the exponent arithmetic.
  if (magnitude >= 0x7f800000U) {
    const uint32_t special = magnitude > 0x7f800000U ? 0x7e00U : 0x7c00U;
    return static_cast<uint16_t>(sign | special);
  }

  // Re-bias the exponent from float32 (127) to float16 (15).
  const int32_t exponent = static_cast<int32_t>(magnitude >> 23) - 127 + 15;
  uint32_t mantissa = magnitude & 0x7fffffU;

  if (exponent >= 0x1f) {
    return static_cast<uint16_t>(sign | 0x7c00U);
  }

  if (exponent <= 0) {
    // The result is a float16 subnormal (or zero).
    if (exponent < -10) {
      return static_cast<uint16_t>(sign);
    }
    mantissa |= 0x800000U;  // restore the implicit leading 1
    const uint32_t shift = static_cast<uint32_t>(14 - exponent);  // 14..24
    uint32_t half_mantissa = mantissa >> shift;
    const uint32_t remainder = mantissa & ((1U << shift) - 1U);
    const uint32_t halfway = 1U << (shift - 1U);
    if (remainder > halfway ||
        (remainder == halfway && (half_mantissa & 1U) != 0U)) {
      // A carry into bit 10 yields the smallest normal, which is correct.
      half_mantissa += 1U;
    }
    return static_cast<uint16_t>(sign | half_mantissa);
  }

  uint32_t half =
      sign | (static_cast<uint32_t>(exponent) << 10) | (mantissa >> 13);
  const uint32_t remainder = mantissa & 0x1fffU;  // the 13 dropped bits
  if (remainder > 0x1000U || (remainder == 0x1000U && (half & 1U) != 0U)) {
    // A carry into the exponent (up to infinity) is the correct rounding.
    half += 1U;
  }
  return static_cast<uint16_t>(half);
}
|
|
401
|
+
|
|
402
|
+
// Converts an IEEE-754 binary32 value to bfloat16 bits (the upper 16 bits of
// the float). Uses round to nearest, ties to even.
// NaN is handled separately: plain rounding would carry NaN payloads into
// infinity (e.g. 0x7f800001 -> 0x7f80), or through the sign bit into -0
// (0x7fffffff -> 0x8000).
uint16_t float32_to_bfloat16_bits(float value) {
  uint32_t bits = 0;
  std::memcpy(&bits, &value, sizeof(bits));
  if ((bits & 0x7fffffffU) > 0x7f800000U) {
    // Truncate and force the quiet bit so the result is still a NaN.
    return static_cast<uint16_t>((bits >> 16) | 0x0040U);
  }
  // Adding 0x7fff plus the kept LSB rounds halfway cases to an even result.
  const uint32_t lsb = (bits >> 16) & 1U;
  return static_cast<uint16_t>((bits + 0x7fffU + lsb) >> 16);
}
|
|
407
|
+
|
|
408
|
+
template <typename T>
// Appends the object representation of `value` to `out` in little-endian
// byte order, which ONNX tensor raw_data requires.
// The original copied host-order bytes straight through, which was only
// correct on little-endian hosts. On a big-endian host the bytes are now
// reversed; on little-endian hosts the output is unchanged.
void append_le_bytes(std::string& out, T value) {
  std::array<char, sizeof(T)> bytes{};
  std::memcpy(bytes.data(), &value, sizeof(T));

  // Probe the host byte order (the compiler folds this to a constant).
  const uint16_t probe = 1;
  unsigned char first_byte = 0;
  std::memcpy(&first_byte, &probe, 1);
  if (first_byte != 1) {
    for (size_t lo = 0, hi = sizeof(T) - 1; lo < hi; ++lo, --hi) {
      std::swap(bytes[lo], bytes[hi]);
    }
  }
  out.append(bytes.data(), bytes.size());
}
|
|
414
|
+
|
|
415
|
+
// Throws std::invalid_argument("<label> unsupported complex literal").
// Marked [[noreturn]] so callers' control flow is explicit to the compiler.
[[noreturn]] void raise_invalid_complex_literal(const std::string& label) {
  throw std::invalid_argument(label + " unsupported complex literal");
}

// Returns a copy of `value` with every whitespace character removed.
std::string normalize_complex_literal(std::string_view value) {
  std::string normalized;
  normalized.reserve(value.size());
  for (const char ch : value) {
    if (!std::isspace(static_cast<unsigned char>(ch))) {
      normalized.push_back(ch);
    }
  }
  return normalized;
}

// Returns the index of the last '+'/'-' that separates the real and imaginary
// parts, or npos if there is none.
// A sign at index 0 is skipped (it belongs to the first number), as is a sign
// directly after 'e'/'E' (an exponent sign).
std::size_t find_complex_literal_split(std::string_view value) {
  for (std::size_t idx = value.size(); idx > 0; --idx) {
    const std::size_t pos = idx - 1;
    if (pos == 0) {
      continue;
    }
    const char ch = value[pos];
    if ((ch == '+' || ch == '-') && value[pos - 1] != 'e' &&
        value[pos - 1] != 'E') {
      return pos;
    }
  }
  return std::string_view::npos;
}

// Parses `text` as a double with std::strtod. The whole string must be
// consumed; empty text and ERANGE results (overflow or underflow) are
// rejected as well.
double parse_complex_literal_double(
    const std::string& text,
    const std::string& label) {
  if (text.empty()) {
    raise_invalid_complex_literal(label);
  }
  char* end = nullptr;
  errno = 0;
  const double value = std::strtod(text.c_str(), &end);
  if (text.c_str() == end) {
    raise_invalid_complex_literal(label);
  }
  if (errno == ERANGE) {
    raise_invalid_complex_literal(label);
  }
  if (static_cast<size_t>(end - text.c_str()) != text.size()) {
    raise_invalid_complex_literal(label);
  }
  return value;
}

// Parses a literal of the form "[real](+|-)[imag]i" ('i' or 'I'; whitespace is
// ignored) into (real, imag) floats.
// Examples: "i" -> (0,1), "-2.5i" -> (0,-2.5), "3-4i" -> (3,-4),
// "1+i" -> (1,1).
// A bare sign before 'i' means a unit imaginary part. This is now accepted in
// both branches; previously "1+i" and "2-i" were rejected even though "+i"
// and "-i" were allowed.
std::pair<float, float> complex64_pair_from_string(
    const std::string& raw,
    const std::string& label) {
  const std::string normalized = normalize_complex_literal(raw);
  if (normalized.empty()) {
    raise_invalid_complex_literal(label);
  }
  const char last = normalized.back();
  if (last != 'i' && last != 'I') {
    raise_invalid_complex_literal(label);
  }
  std::string_view remaining(normalized.data(), normalized.size() - 1);
  if (remaining.empty()) {
    return {0.0f, 1.0f};
  }

  // Maps a bare sign to a unit magnitude; other text is passed through.
  const auto imaginary_text = [](std::string_view part) -> std::string {
    if (part == "+") {
      return "1";
    }
    if (part == "-") {
      return "-1";
    }
    return std::string(part);
  };

  const std::size_t split = find_complex_literal_split(remaining);
  std::string real_text;
  std::string imag_text;
  if (split == std::string_view::npos) {
    real_text = "0";
    imag_text = imaginary_text(remaining);
  } else {
    real_text = std::string(remaining.substr(0, split));
    imag_text = imaginary_text(remaining.substr(split));
  }
  const double real = parse_complex_literal_double(real_text, label);
  const double imag = parse_complex_literal_double(imag_text, label);
  return {static_cast<float>(real), static_cast<float>(imag)};
}
|
|
501
|
+
|
|
502
|
+
std::pair<float, float> complex64_pair_from_json(
|
|
503
|
+
const OrderedJson& value,
|
|
504
|
+
const std::string& label) {
|
|
505
|
+
if (value.is_object() && value.contains("__mlx_complex__")) {
|
|
506
|
+
const auto& pair = value.at("__mlx_complex__");
|
|
507
|
+
if (!pair.is_array() || pair.size() != 2 || !json_is_numeric(pair.at(0)) ||
|
|
508
|
+
!json_is_numeric(pair.at(1))) {
|
|
509
|
+
throw std::invalid_argument(label + " invalid complex marker");
|
|
510
|
+
}
|
|
511
|
+
return {
|
|
512
|
+
static_cast<float>(pair.at(0).get<double>()),
|
|
513
|
+
static_cast<float>(pair.at(1).get<double>())};
|
|
514
|
+
}
|
|
515
|
+
if (value.is_string()) {
|
|
516
|
+
return complex64_pair_from_string(
|
|
517
|
+
value.get_ref<const std::string&>(), label);
|
|
518
|
+
}
|
|
519
|
+
if (value.is_boolean()) {
|
|
520
|
+
return {value.get<bool>() ? 1.0f : 0.0f, 0.0f};
|
|
521
|
+
}
|
|
522
|
+
if (json_is_numeric(value)) {
|
|
523
|
+
return {static_cast<float>(value.get<double>()), 0.0f};
|
|
524
|
+
}
|
|
525
|
+
throw std::invalid_argument(
|
|
526
|
+
label + " unsupported complex64 initializer leaf");
|
|
527
|
+
}
|
|
528
|
+
|
|
529
|
+
template <typename IntegerType>
|
|
530
|
+
std::string tensor_raw_integer_initializer(
|
|
531
|
+
const std::vector<const OrderedJson*>& leaves,
|
|
532
|
+
size_t expected,
|
|
533
|
+
const std::string& label) {
|
|
534
|
+
std::string raw;
|
|
535
|
+
raw.reserve(expected * sizeof(IntegerType));
|
|
536
|
+
for (const auto* item : leaves) {
|
|
537
|
+
append_le_bytes<IntegerType>(
|
|
538
|
+
raw, static_cast<IntegerType>(normalized_integer_scalar(*item, label)));
|
|
539
|
+
}
|
|
540
|
+
return raw;
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
double numeric_initializer_leaf(
|
|
544
|
+
const OrderedJson& value,
|
|
545
|
+
const std::string& numeric_error_message) {
|
|
546
|
+
if (!json_is_numeric(value)) {
|
|
547
|
+
throw std::invalid_argument(numeric_error_message);
|
|
548
|
+
}
|
|
549
|
+
return value.get<double>();
|
|
550
|
+
}
|
|
551
|
+
|
|
552
|
+
template <typename FloatType>
|
|
553
|
+
std::string tensor_raw_float_initializer(
|
|
554
|
+
const std::vector<const OrderedJson*>& leaves,
|
|
555
|
+
size_t expected,
|
|
556
|
+
const std::string& numeric_error_message) {
|
|
557
|
+
std::string raw;
|
|
558
|
+
raw.reserve(expected * sizeof(FloatType));
|
|
559
|
+
for (const auto* item : leaves) {
|
|
560
|
+
append_le_bytes<FloatType>(
|
|
561
|
+
raw,
|
|
562
|
+
static_cast<FloatType>(
|
|
563
|
+
numeric_initializer_leaf(*item, numeric_error_message)));
|
|
564
|
+
}
|
|
565
|
+
return raw;
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
std::string tensor_raw_bool_initializer(
|
|
569
|
+
const std::vector<const OrderedJson*>& leaves,
|
|
570
|
+
size_t expected) {
|
|
571
|
+
std::string raw;
|
|
572
|
+
raw.reserve(expected);
|
|
573
|
+
for (const auto* item : leaves) {
|
|
574
|
+
uint8_t value = 0;
|
|
575
|
+
if (item->is_boolean()) {
|
|
576
|
+
value = item->get<bool>() ? 1 : 0;
|
|
577
|
+
} else if (json_integer_like(*item)) {
|
|
578
|
+
value = normalized_integer_scalar(*item, "bool initializer leaf") == 0
|
|
579
|
+
? 0
|
|
580
|
+
: 1;
|
|
581
|
+
} else if (item->is_number_float()) {
|
|
582
|
+
value = item->get<double>() == 0.0 ? 0 : 1;
|
|
583
|
+
} else {
|
|
584
|
+
throw std::invalid_argument(
|
|
585
|
+
"bool initializer values must be numeric/boolean");
|
|
586
|
+
}
|
|
587
|
+
raw.push_back(static_cast<char>(value));
|
|
588
|
+
}
|
|
589
|
+
return raw;
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
std::string tensor_raw_float16_initializer(
|
|
593
|
+
const std::vector<const OrderedJson*>& leaves,
|
|
594
|
+
size_t expected) {
|
|
595
|
+
std::string raw;
|
|
596
|
+
raw.reserve(expected * sizeof(uint16_t));
|
|
597
|
+
for (const auto* item : leaves) {
|
|
598
|
+
const float value = static_cast<float>(numeric_initializer_leaf(
|
|
599
|
+
*item, "float16 initializer values must be numeric"));
|
|
600
|
+
append_le_bytes<uint16_t>(raw, float32_to_float16_bits(value));
|
|
601
|
+
}
|
|
602
|
+
return raw;
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
std::string tensor_raw_bfloat16_initializer(
|
|
606
|
+
const std::vector<const OrderedJson*>& leaves,
|
|
607
|
+
size_t expected) {
|
|
608
|
+
std::string raw;
|
|
609
|
+
raw.reserve(expected * sizeof(uint16_t));
|
|
610
|
+
for (const auto* item : leaves) {
|
|
611
|
+
const float value = static_cast<float>(numeric_initializer_leaf(
|
|
612
|
+
*item, "bfloat16 initializer values must be numeric"));
|
|
613
|
+
append_le_bytes<uint16_t>(raw, float32_to_bfloat16_bits(value));
|
|
614
|
+
}
|
|
615
|
+
return raw;
|
|
616
|
+
}
|
|
617
|
+
|
|
618
|
+
std::string tensor_raw_complex64_initializer(
|
|
619
|
+
const std::vector<const OrderedJson*>& leaves,
|
|
620
|
+
size_t expected) {
|
|
621
|
+
std::string raw;
|
|
622
|
+
raw.reserve(expected * sizeof(float) * 2);
|
|
623
|
+
for (const auto* item : leaves) {
|
|
624
|
+
auto [real, imag] =
|
|
625
|
+
complex64_pair_from_json(*item, "complex64 initializer leaf");
|
|
626
|
+
append_le_bytes<float>(raw, real);
|
|
627
|
+
append_le_bytes<float>(raw, imag);
|
|
628
|
+
}
|
|
629
|
+
return raw;
|
|
630
|
+
}
|
|
631
|
+
|
|
632
|
+
std::string tensor_raw_bytes_from_initializer(
|
|
633
|
+
const OnnxInitializerModel& tensor) {
|
|
634
|
+
const size_t expected = expected_initializer_value_count(tensor.shape);
|
|
635
|
+
|
|
636
|
+
std::vector<const OrderedJson*> leaves;
|
|
637
|
+
collect_initializer_leaves(tensor.values, leaves);
|
|
638
|
+
if (leaves.size() != expected) {
|
|
639
|
+
std::ostringstream out;
|
|
640
|
+
out << "initializer " << tensor.name << " has " << leaves.size()
|
|
641
|
+
<< " values but expected " << expected;
|
|
642
|
+
throw std::invalid_argument(out.str());
|
|
643
|
+
}
|
|
644
|
+
|
|
645
|
+
const std::string& dtype = tensor.dtype;
|
|
646
|
+
if (dtype == "bool" || dtype == "bool_") {
|
|
647
|
+
return tensor_raw_bool_initializer(leaves, expected);
|
|
648
|
+
}
|
|
649
|
+
if (dtype == "uint8") {
|
|
650
|
+
return tensor_raw_integer_initializer<uint8_t>(
|
|
651
|
+
leaves, expected, "uint8 initializer leaf");
|
|
652
|
+
}
|
|
653
|
+
if (dtype == "uint16") {
|
|
654
|
+
return tensor_raw_integer_initializer<uint16_t>(
|
|
655
|
+
leaves, expected, "uint16 initializer leaf");
|
|
656
|
+
}
|
|
657
|
+
if (dtype == "uint32") {
|
|
658
|
+
return tensor_raw_integer_initializer<uint32_t>(
|
|
659
|
+
leaves, expected, "uint32 initializer leaf");
|
|
660
|
+
}
|
|
661
|
+
if (dtype == "uint64") {
|
|
662
|
+
return tensor_raw_integer_initializer<uint64_t>(
|
|
663
|
+
leaves, expected, "uint64 initializer leaf");
|
|
664
|
+
}
|
|
665
|
+
if (dtype == "int8") {
|
|
666
|
+
return tensor_raw_integer_initializer<int8_t>(
|
|
667
|
+
leaves, expected, "int8 initializer leaf");
|
|
668
|
+
}
|
|
669
|
+
if (dtype == "int16") {
|
|
670
|
+
return tensor_raw_integer_initializer<int16_t>(
|
|
671
|
+
leaves, expected, "int16 initializer leaf");
|
|
672
|
+
}
|
|
673
|
+
if (dtype == "int32") {
|
|
674
|
+
return tensor_raw_integer_initializer<int32_t>(
|
|
675
|
+
leaves, expected, "int32 initializer leaf");
|
|
676
|
+
}
|
|
677
|
+
if (dtype == "int64") {
|
|
678
|
+
return tensor_raw_integer_initializer<int64_t>(
|
|
679
|
+
leaves, expected, "int64 initializer leaf");
|
|
680
|
+
}
|
|
681
|
+
if (dtype == "float16") {
|
|
682
|
+
return tensor_raw_float16_initializer(leaves, expected);
|
|
683
|
+
}
|
|
684
|
+
if (dtype == "bfloat16") {
|
|
685
|
+
return tensor_raw_bfloat16_initializer(leaves, expected);
|
|
686
|
+
}
|
|
687
|
+
if (dtype == "float32") {
|
|
688
|
+
return tensor_raw_float_initializer<float>(
|
|
689
|
+
leaves, expected, "float32 initializer values must be numeric");
|
|
690
|
+
}
|
|
691
|
+
if (dtype == "float64") {
|
|
692
|
+
return tensor_raw_float_initializer<double>(
|
|
693
|
+
leaves, expected, "float64 initializer values must be numeric");
|
|
694
|
+
}
|
|
695
|
+
if (dtype == "complex64") {
|
|
696
|
+
return tensor_raw_complex64_initializer(leaves, expected);
|
|
697
|
+
}
|
|
698
|
+
|
|
699
|
+
throw std::invalid_argument(
|
|
700
|
+
"unsupported initializer dtype for native ONNX binary export: " + dtype);
|
|
701
|
+
}
|
|
702
|
+
|
|
703
|
+
std::string pb_encode_string_string_entry(
|
|
704
|
+
const std::string& key,
|
|
705
|
+
const std::string& value) {
|
|
706
|
+
std::string out;
|
|
707
|
+
pb_write_string_field(out, 1, key);
|
|
708
|
+
pb_write_string_field(out, 2, value);
|
|
709
|
+
return out;
|
|
710
|
+
}
|
|
711
|
+
|
|
712
|
+
void pb_encode_tensor_header(
|
|
713
|
+
std::string& out,
|
|
714
|
+
const std::vector<int64_t>& shape,
|
|
715
|
+
int elem_type,
|
|
716
|
+
const std::string& name) {
|
|
717
|
+
for (const auto dim : shape) {
|
|
718
|
+
pb_write_int64_field(out, 1, dim);
|
|
719
|
+
}
|
|
720
|
+
pb_write_varint_field(out, 2, static_cast<uint64_t>(elem_type));
|
|
721
|
+
pb_write_string_field(out, 8, name);
|
|
722
|
+
}
|
|
723
|
+
|
|
724
|
+
// Reads four bytes of `raw` starting at `offset` as a little-endian uint32.
//
// No bounds checking: callers must guarantee offset + 4 <= raw.size().
uint32_t raw_bytes_u32_le_at(const std::string& raw, size_t offset) {
  uint32_t word = 0;
  // Walk from the most significant (last) byte down to the first, shifting
  // the accumulated value left one byte each step.
  for (size_t k = 4; k-- > 0;) {
    word = (word << 8) |
        static_cast<uint32_t>(static_cast<unsigned char>(raw[offset + k]));
  }
  return word;
}
|
|
735
|
+
|
|
736
|
+
void pb_encode_tensor_inline_complex64_data(
|
|
737
|
+
std::string& out,
|
|
738
|
+
const std::string& raw) {
|
|
739
|
+
if ((raw.size() % sizeof(float)) != 0) {
|
|
740
|
+
throw std::invalid_argument(
|
|
741
|
+
"complex64 initializer raw byte count must be divisible by 4");
|
|
742
|
+
}
|
|
743
|
+
for (size_t offset = 0; offset < raw.size(); offset += sizeof(float)) {
|
|
744
|
+
pb_write_fixed32_field(out, 4, raw_bytes_u32_le_at(raw, offset));
|
|
745
|
+
}
|
|
746
|
+
}
|
|
747
|
+
|
|
748
|
+
void pb_encode_tensor_inline_data(
|
|
749
|
+
std::string& out,
|
|
750
|
+
const std::string& dtype,
|
|
751
|
+
const std::string& raw) {
|
|
752
|
+
if (dtype == "complex64") {
|
|
753
|
+
pb_encode_tensor_inline_complex64_data(out, raw);
|
|
754
|
+
return;
|
|
755
|
+
}
|
|
756
|
+
pb_write_bytes_field(out, 9, raw);
|
|
757
|
+
}
|
|
758
|
+
|
|
759
|
+
void pb_encode_tensor_external_data_entries(
|
|
760
|
+
std::string& out,
|
|
761
|
+
const std::string& external_data_file,
|
|
762
|
+
uint64_t external_offset,
|
|
763
|
+
size_t raw_size) {
|
|
764
|
+
pb_write_message_field(
|
|
765
|
+
out, 13, pb_encode_string_string_entry("location", external_data_file));
|
|
766
|
+
pb_write_message_field(
|
|
767
|
+
out,
|
|
768
|
+
13,
|
|
769
|
+
pb_encode_string_string_entry("offset", std::to_string(external_offset)));
|
|
770
|
+
pb_write_message_field(
|
|
771
|
+
out,
|
|
772
|
+
13,
|
|
773
|
+
pb_encode_string_string_entry("length", std::to_string(raw_size)));
|
|
774
|
+
pb_write_varint_field(out, 14, 1);
|
|
775
|
+
}
|
|
776
|
+
|
|
777
|
+
// Appends `raw` to the external-data buffer and advances the running offset
// by the number of bytes appended.
void append_tensor_external_data(
    std::string& external_data,
    uint64_t& external_offset,
    const std::string& raw) {
  const uint64_t appended = raw.size();
  external_data += raw;
  external_offset += appended;
}
|
|
784
|
+
|
|
785
|
+
bool should_externalize_tensor_raw_bytes(
|
|
786
|
+
const OnnxBinaryWriteOptions& options,
|
|
787
|
+
const std::string& raw) {
|
|
788
|
+
return options.external_data &&
|
|
789
|
+
static_cast<int64_t>(raw.size()) >= options.external_data_size_threshold;
|
|
790
|
+
}
|
|
791
|
+
|
|
792
|
+
} // namespace
|
|
793
|
+
|
|
794
|
+
std::string pb_encode_tensor(
|
|
795
|
+
const OnnxInitializerModel& tensor,
|
|
796
|
+
const OnnxBinaryWriteOptions& options,
|
|
797
|
+
std::string& external_data,
|
|
798
|
+
uint64_t& external_offset,
|
|
799
|
+
bool& has_external_data) {
|
|
800
|
+
const std::string raw = tensor_raw_bytes_from_initializer(tensor);
|
|
801
|
+
|
|
802
|
+
std::string out;
|
|
803
|
+
pb_encode_tensor_header(out, tensor.shape, tensor.elem_type, tensor.name);
|
|
804
|
+
|
|
805
|
+
const bool externalize = should_externalize_tensor_raw_bytes(options, raw);
|
|
806
|
+
if (!externalize) {
|
|
807
|
+
pb_encode_tensor_inline_data(out, tensor.dtype, raw);
|
|
808
|
+
return out;
|
|
809
|
+
}
|
|
810
|
+
|
|
811
|
+
has_external_data = true;
|
|
812
|
+
pb_encode_tensor_external_data_entries(
|
|
813
|
+
out, options.external_data_file, external_offset, raw.size());
|
|
814
|
+
append_tensor_external_data(external_data, external_offset, raw);
|
|
815
|
+
return out;
|
|
816
|
+
}
|
|
817
|
+
|
|
818
|
+
} // namespace mlx::onnx::detail
|
|
819
|
+
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
namespace mlx::onnx::detail {
|
|
824
|
+
namespace {
|
|
825
|
+
|
|
826
|
+
std::string pb_encode_tensor_shape(const std::vector<int64_t>& shape) {
|
|
827
|
+
std::string out;
|
|
828
|
+
for (const auto dim : shape) {
|
|
829
|
+
std::string dim_msg;
|
|
830
|
+
pb_write_int64_field(dim_msg, 1, dim);
|
|
831
|
+
pb_write_message_field(out, 1, dim_msg);
|
|
832
|
+
}
|
|
833
|
+
return out;
|
|
834
|
+
}
|
|
835
|
+
|
|
836
|
+
std::string pb_encode_tensor_type_proto(
|
|
837
|
+
int elem_type,
|
|
838
|
+
const std::vector<int64_t>& shape) {
|
|
839
|
+
std::string tensor_type;
|
|
840
|
+
pb_write_varint_field(tensor_type, 1, static_cast<uint64_t>(elem_type));
|
|
841
|
+
pb_write_message_field(tensor_type, 2, pb_encode_tensor_shape(shape));
|
|
842
|
+
|
|
843
|
+
std::string type_proto;
|
|
844
|
+
pb_write_message_field(type_proto, 1, tensor_type);
|
|
845
|
+
return type_proto;
|
|
846
|
+
}
|
|
847
|
+
|
|
848
|
+
// Serializes a ValueInfoProto: name (field 1) and tensor type (field 2).
std::string pb_encode_value_info(const OnnxValueInfoModel& info) {
  const std::string type_msg =
      pb_encode_tensor_type_proto(info.elem_type, info.shape);
  std::string encoded;
  pb_write_string_field(encoded, 1, info.name);
  pb_write_message_field(encoded, 2, type_msg);
  return encoded;
}
|
|
855
|
+
|
|
856
|
+
std::string pb_encode_attribute(
|
|
857
|
+
const std::string& op_type,
|
|
858
|
+
const std::string& name,
|
|
859
|
+
const OrderedJson& value) {
|
|
860
|
+
std::string out;
|
|
861
|
+
pb_write_string_field(out, 1, name);
|
|
862
|
+
|
|
863
|
+
if (op_type == "Cast" && name == "to" && value.is_string()) {
|
|
864
|
+
const int cast_to =
|
|
865
|
+
onnx_elem_type_from_symbol_lookup(value.get<std::string>());
|
|
866
|
+
pb_write_varint_field(out, 20, 2);
|
|
867
|
+
pb_write_int64_field(out, 3, cast_to);
|
|
868
|
+
return out;
|
|
869
|
+
}
|
|
870
|
+
|
|
871
|
+
if (value.is_boolean()) {
|
|
872
|
+
pb_write_varint_field(out, 20, 2);
|
|
873
|
+
pb_write_int64_field(out, 3, value.get<bool>() ? 1 : 0);
|
|
874
|
+
return out;
|
|
875
|
+
}
|
|
876
|
+
if (value.is_number_integer() || value.is_number_unsigned()) {
|
|
877
|
+
pb_write_varint_field(out, 20, 2);
|
|
878
|
+
pb_write_int64_field(out, 3, normalized_integer_scalar(value, "attribute"));
|
|
879
|
+
return out;
|
|
880
|
+
}
|
|
881
|
+
if (value.is_number_float()) {
|
|
882
|
+
pb_write_varint_field(out, 20, 1);
|
|
883
|
+
pb_write_fixed32_field(
|
|
884
|
+
out,
|
|
885
|
+
2,
|
|
886
|
+
std::bit_cast<uint32_t>(static_cast<float>(value.get<double>())));
|
|
887
|
+
return out;
|
|
888
|
+
}
|
|
889
|
+
if (value.is_string()) {
|
|
890
|
+
pb_write_varint_field(out, 20, 3);
|
|
891
|
+
pb_write_bytes_field(out, 4, value.get<std::string>());
|
|
892
|
+
return out;
|
|
893
|
+
}
|
|
894
|
+
if (value.is_array()) {
|
|
895
|
+
bool all_integer_typed = true;
|
|
896
|
+
bool all_numeric = true;
|
|
897
|
+
bool all_string = true;
|
|
898
|
+
for (const auto& item : value) {
|
|
899
|
+
all_integer_typed = all_integer_typed &&
|
|
900
|
+
(item.is_boolean() || item.is_number_integer() ||
|
|
901
|
+
item.is_number_unsigned());
|
|
902
|
+
all_numeric = all_numeric && json_is_numeric(item);
|
|
903
|
+
all_string = all_string && item.is_string();
|
|
904
|
+
}
|
|
905
|
+
if (value.empty() || all_integer_typed) {
|
|
906
|
+
pb_write_varint_field(out, 20, 7);
|
|
907
|
+
for (const auto& item : value) {
|
|
908
|
+
pb_write_int64_field(
|
|
909
|
+
out, 8, normalized_integer_scalar(item, "attribute vector"));
|
|
910
|
+
}
|
|
911
|
+
return out;
|
|
912
|
+
}
|
|
913
|
+
if (all_numeric) {
|
|
914
|
+
pb_write_varint_field(out, 20, 6);
|
|
915
|
+
for (const auto& item : value) {
|
|
916
|
+
pb_write_fixed32_field(
|
|
917
|
+
out,
|
|
918
|
+
7,
|
|
919
|
+
std::bit_cast<uint32_t>(static_cast<float>(item.get<double>())));
|
|
920
|
+
}
|
|
921
|
+
return out;
|
|
922
|
+
}
|
|
923
|
+
if (all_string) {
|
|
924
|
+
pb_write_varint_field(out, 20, 8);
|
|
925
|
+
for (const auto& item : value) {
|
|
926
|
+
pb_write_bytes_field(out, 9, item.get<std::string>());
|
|
927
|
+
}
|
|
928
|
+
return out;
|
|
929
|
+
}
|
|
930
|
+
}
|
|
931
|
+
|
|
932
|
+
throw std::invalid_argument(
|
|
933
|
+
"unsupported ONNX attribute type for " + op_type + "." + name);
|
|
934
|
+
}
|
|
935
|
+
|
|
936
|
+
std::string pb_encode_node(const OnnxNodeModel& node) {
|
|
937
|
+
std::string out;
|
|
938
|
+
for (const auto& input : node.inputs) {
|
|
939
|
+
pb_write_string_field(out, 1, input);
|
|
940
|
+
}
|
|
941
|
+
for (const auto& output : node.outputs) {
|
|
942
|
+
pb_write_string_field(out, 2, output);
|
|
943
|
+
}
|
|
944
|
+
pb_write_string_field(out, 3, node.name);
|
|
945
|
+
pb_write_string_field(out, 4, node.op_type);
|
|
946
|
+
for (const auto& attribute : node.attributes) {
|
|
947
|
+
pb_write_message_field(
|
|
948
|
+
out,
|
|
949
|
+
5,
|
|
950
|
+
pb_encode_attribute(node.op_type, attribute.name, attribute.value));
|
|
951
|
+
}
|
|
952
|
+
return out;
|
|
953
|
+
}
|
|
954
|
+
|
|
955
|
+
std::string pb_encode_graph(
|
|
956
|
+
const OnnxGraphModel& graph,
|
|
957
|
+
const OnnxBinaryWriteOptions& options,
|
|
958
|
+
std::string& external_data,
|
|
959
|
+
bool& has_external_data) {
|
|
960
|
+
std::string out;
|
|
961
|
+
pb_write_string_field(out, 2, graph.name);
|
|
962
|
+
|
|
963
|
+
for (const auto& node : graph.nodes) {
|
|
964
|
+
pb_write_message_field(out, 1, pb_encode_node(node));
|
|
965
|
+
}
|
|
966
|
+
|
|
967
|
+
uint64_t external_offset = 0;
|
|
968
|
+
for (const auto& initializer : graph.initializers) {
|
|
969
|
+
pb_write_message_field(
|
|
970
|
+
out,
|
|
971
|
+
5,
|
|
972
|
+
pb_encode_tensor(
|
|
973
|
+
initializer,
|
|
974
|
+
options,
|
|
975
|
+
external_data,
|
|
976
|
+
external_offset,
|
|
977
|
+
has_external_data));
|
|
978
|
+
}
|
|
979
|
+
|
|
980
|
+
for (const auto& input : graph.inputs) {
|
|
981
|
+
pb_write_message_field(out, 11, pb_encode_value_info(input));
|
|
982
|
+
}
|
|
983
|
+
for (const auto& output : graph.outputs) {
|
|
984
|
+
pb_write_message_field(out, 12, pb_encode_value_info(output));
|
|
985
|
+
}
|
|
986
|
+
|
|
987
|
+
return out;
|
|
988
|
+
}
|
|
989
|
+
|
|
990
|
+
// Serializes an OperatorSetIdProto for the default operator domain.
//
// The domain (field 1) is written as an empty string, which onnx.proto
// defines as the default ai.onnx domain. The version (field 2) is `opset`.
std::string pb_encode_opset_import(int64_t opset) {
  constexpr int kDomainField = 1;
  constexpr int kVersionField = 2;
  std::string encoded;
  pb_write_string_field(encoded, kDomainField, "");
  pb_write_int64_field(encoded, kVersionField, opset);
  return encoded;
}
|
|
996
|
+
|
|
997
|
+
} // namespace
|
|
998
|
+
|
|
999
|
+
// Builds a serialized ModelProto (plus any external-data bytes) from a JSON
// ONNX stub.
//
// The model message is written in this order:
//   - ir_version (field 1), set to kIrVersion;
//   - producer_name (field 2);
//   - graph (field 7);
//   - a single opset_import (field 8) for the default domain.
// Initializer payloads that pb_encode_graph externalizes are returned in
// external_data_bytes, with has_external_data set accordingly.
OnnxBinaryArtifact build_onnx_binary_artifact_from_stub_impl(
    const OrderedJson& onnx_stub,
    const OnnxBinaryWriteOptions& options) {
  constexpr int64_t kIrVersion = 10;

  const auto spec = onnx_stub_model_from_json(onnx_stub);

  std::string side_data;
  bool uses_side_data = false;
  const std::string graph_msg =
      pb_encode_graph(spec.graph, options, side_data, uses_side_data);

  std::string model_msg;
  pb_write_int64_field(model_msg, 1, kIrVersion);
  pb_write_string_field(model_msg, 2, spec.producer_name);
  pb_write_message_field(model_msg, 7, graph_msg);
  pb_write_message_field(model_msg, 8, pb_encode_opset_import(spec.opset));

  OnnxBinaryArtifact artifact;
  artifact.model_bytes = std::move(model_msg);
  artifact.external_data_bytes = std::move(side_data);
  artifact.has_external_data = uses_side_data;
  return artifact;
}
|
|
1021
|
+
|
|
1022
|
+
} // namespace mlx::onnx::detail
|
|
1023
|
+
|
|
1024
|
+
|
|
1025
|
+
namespace mlx::onnx::detail {
|
|
1026
|
+
// Intentionally kept as a lightweight aggregation unit after splitting ONNX
|
|
1027
|
+
// binary encoding into wire/tensor/assembly translation units.
|
|
1028
|
+
} // namespace mlx::onnx::detail
|
|
1029
|
+
|