mlx 0.30.7.3 → 0.30.7.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/extconf.rb +267 -8
- data/ext/mlx/native.cpp +104 -56
- data/ext/mlx-onnx/native.cpp +1402 -0
- data/ext/mlx-onnx/native.hpp +19 -0
- data/lib/mlx/core.rb +342 -117
- data/lib/mlx/nn/base.rb +4 -0
- data/lib/mlx/nn/layers/linear.rb +2 -3
- data/lib/mlx/onnx.rb +250 -0
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx-onnx/webgpu_harness.rb +289 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
- data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
- data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
- data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
- data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
- data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
- data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
- data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
- data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
- data/submodules/mlx-onnx/CMakeLists.txt +159 -0
- data/submodules/mlx-onnx/LICENSE +21 -0
- data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
- data/submodules/mlx-onnx/src/api.cpp +81 -0
- data/submodules/mlx-onnx/src/compat.cpp +111 -0
- data/submodules/mlx-onnx/src/detail.hpp +69 -0
- data/submodules/mlx-onnx/src/export.cpp +653 -0
- data/submodules/mlx-onnx/src/io.cpp +61 -0
- data/submodules/mlx-onnx/src/json.hpp +25 -0
- data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
- data/submodules/mlx-onnx/src/mappings.cpp +201 -0
- data/submodules/mlx-onnx/src/mappings.hpp +16 -0
- data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
- data/submodules/mlx-onnx/src/shared.cpp +206 -0
- metadata +609 -563
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
- /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
- /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
|
@@ -10,6 +10,7 @@
|
|
|
10
10
|
#include <sstream>
|
|
11
11
|
|
|
12
12
|
#include "mlx/backend/cuda/cuda.h"
|
|
13
|
+
#include "mlx/backend/metal/metal.h"
|
|
13
14
|
#include "mlx/fast_primitives.h"
|
|
14
15
|
#include "mlx/ops.h"
|
|
15
16
|
#include "mlx/primitives.h"
|
|
@@ -2311,6 +2312,40 @@ array argmax(
|
|
|
2311
2312
|
return out;
|
|
2312
2313
|
}
|
|
2313
2314
|
|
|
2315
|
+
array hanning(int M, StreamOrDevice s /* = {} */) {
|
|
2316
|
+
if (M < 1) {
|
|
2317
|
+
return array({});
|
|
2318
|
+
}
|
|
2319
|
+
if (M == 1) {
|
|
2320
|
+
return ones({1}, float32, s);
|
|
2321
|
+
}
|
|
2322
|
+
|
|
2323
|
+
auto n = arange(0, M, float32, s);
|
|
2324
|
+
array factor(M_PI / (M - 1), float32);
|
|
2325
|
+
return square(sin(multiply(factor, n, s), s), s);
|
|
2326
|
+
}
|
|
2327
|
+
|
|
2328
|
+
array hamming(int M, StreamOrDevice s /* = {} */) {
|
|
2329
|
+
if (M < 1) {
|
|
2330
|
+
return array({});
|
|
2331
|
+
}
|
|
2332
|
+
if (M == 1) {
|
|
2333
|
+
return ones({1}, float32, s);
|
|
2334
|
+
}
|
|
2335
|
+
|
|
2336
|
+
auto n = arange(0, M, float32, s);
|
|
2337
|
+
float factor_val = (2.0 * M_PI) / (M - 1);
|
|
2338
|
+
auto factor = array(factor_val, float32);
|
|
2339
|
+
|
|
2340
|
+
auto arg = multiply(factor, n, s);
|
|
2341
|
+
auto cos_vals = cos(arg, s);
|
|
2342
|
+
|
|
2343
|
+
auto left_coef = array(0.54f, float32);
|
|
2344
|
+
auto right_coef = array(0.46f, float32);
|
|
2345
|
+
|
|
2346
|
+
return subtract(left_coef, multiply(right_coef, cos_vals, s), s);
|
|
2347
|
+
}
|
|
2348
|
+
|
|
2314
2349
|
/** Returns a sorted copy of the flattened array. */
|
|
2315
2350
|
array sort(const array& a, StreamOrDevice s /* = {} */) {
|
|
2316
2351
|
int size = a.size();
|
|
@@ -4209,6 +4244,34 @@ std::pair<Dtype, QuantizationMode> validate_mode_with_type(
|
|
|
4209
4244
|
}
|
|
4210
4245
|
}
|
|
4211
4246
|
|
|
4247
|
+
void validate_global_scale(
|
|
4248
|
+
std::string_view tag,
|
|
4249
|
+
QuantizationMode qmode,
|
|
4250
|
+
const std::optional<array>& global_scale) {
|
|
4251
|
+
if (global_scale.has_value()) {
|
|
4252
|
+
if (qmode != QuantizationMode::Nvfp4) {
|
|
4253
|
+
std::ostringstream msg;
|
|
4254
|
+
msg << "[" << tag << "] Global scale is only supported for 'nvfp4' "
|
|
4255
|
+
<< "quantization mode.";
|
|
4256
|
+
throw std::invalid_argument(msg.str());
|
|
4257
|
+
} else {
|
|
4258
|
+
if (global_scale->size() != 1) {
|
|
4259
|
+
std::ostringstream msg;
|
|
4260
|
+
msg << "[" << tag << "] Global scale must be a scalar but got shape "
|
|
4261
|
+
<< global_scale->shape() << ".";
|
|
4262
|
+
throw std::invalid_argument(msg.str());
|
|
4263
|
+
}
|
|
4264
|
+
// TODO: not sure if type should be restricted to float32
|
|
4265
|
+
if (global_scale->dtype() != float32) {
|
|
4266
|
+
std::ostringstream msg;
|
|
4267
|
+
msg << "[" << tag << "] Global scale must have dtype float32 but got "
|
|
4268
|
+
<< global_scale->dtype() << ".";
|
|
4269
|
+
throw std::invalid_argument(msg.str());
|
|
4270
|
+
}
|
|
4271
|
+
}
|
|
4272
|
+
}
|
|
4273
|
+
}
|
|
4274
|
+
|
|
4212
4275
|
array quantized_matmul(
|
|
4213
4276
|
array x,
|
|
4214
4277
|
array w,
|
|
@@ -4251,7 +4314,6 @@ array quantized_matmul(
|
|
|
4251
4314
|
if (x.ndim() > 2 && w.ndim() > 2) {
|
|
4252
4315
|
inputs = broadcast_arrays(inputs, {-2, -1}, s);
|
|
4253
4316
|
}
|
|
4254
|
-
|
|
4255
4317
|
auto out_shape = inputs[0].shape();
|
|
4256
4318
|
out_shape.back() = w_outer_dims;
|
|
4257
4319
|
return array(
|
|
@@ -4267,7 +4329,10 @@ void validate_qqmm_inputs(
|
|
|
4267
4329
|
array w,
|
|
4268
4330
|
std::optional<array> scales_w,
|
|
4269
4331
|
int group_size,
|
|
4270
|
-
int bits
|
|
4332
|
+
int bits,
|
|
4333
|
+
std::optional<array> global_scale_x,
|
|
4334
|
+
std::optional<array> global_scale_w,
|
|
4335
|
+
QuantizationMode qmode) {
|
|
4271
4336
|
// check 2D (for now)
|
|
4272
4337
|
if (x.ndim() > 2 || w.ndim() > 2) {
|
|
4273
4338
|
std::ostringstream msg;
|
|
@@ -4304,6 +4369,19 @@ void validate_qqmm_inputs(
|
|
|
4304
4369
|
<< "first argument dtype == " << x.dtype() << ".";
|
|
4305
4370
|
throw std::invalid_argument(msg.str());
|
|
4306
4371
|
}
|
|
4372
|
+
// validate global scales
|
|
4373
|
+
validate_global_scale("qqmm", qmode, global_scale_x);
|
|
4374
|
+
validate_global_scale("qqmm", qmode, global_scale_w);
|
|
4375
|
+
// For nvfp4 mode, both global scales must be provided together or neither
|
|
4376
|
+
if (qmode == QuantizationMode::Nvfp4) {
|
|
4377
|
+
bool has_x = global_scale_x.has_value();
|
|
4378
|
+
bool has_w = global_scale_w.has_value();
|
|
4379
|
+
if (has_x != has_w) {
|
|
4380
|
+
throw std::invalid_argument(
|
|
4381
|
+
"[qqmm] For nvfp4 mode, either both global_scale_x and "
|
|
4382
|
+
"global_scale_w must be provided, or neither.");
|
|
4383
|
+
}
|
|
4384
|
+
}
|
|
4307
4385
|
}
|
|
4308
4386
|
|
|
4309
4387
|
std::pair<int, int> extract_qqmm_dims(
|
|
@@ -4343,6 +4421,8 @@ array qqmm(
|
|
|
4343
4421
|
std::optional<int> group_size_ /* = std::nullopt */,
|
|
4344
4422
|
std::optional<int> bits_ /* = std::nullopt */,
|
|
4345
4423
|
const std::string& mode /* = "nvfp4" */,
|
|
4424
|
+
const std::optional<array> global_scale_x /* = std::nullopt */,
|
|
4425
|
+
const std::optional<array> global_scale_w /* = std::nullopt */,
|
|
4346
4426
|
StreamOrDevice s /* = {} */) {
|
|
4347
4427
|
auto stream = to_stream(s);
|
|
4348
4428
|
auto qmode = string_to_quantization_mode(mode, "qqmm");
|
|
@@ -4369,7 +4449,8 @@ array qqmm(
|
|
|
4369
4449
|
}
|
|
4370
4450
|
|
|
4371
4451
|
// validate inputs
|
|
4372
|
-
validate_qqmm_inputs(
|
|
4452
|
+
validate_qqmm_inputs(
|
|
4453
|
+
x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode);
|
|
4373
4454
|
// validate and extract shapes
|
|
4374
4455
|
auto [w_inner_dims, w_outer_dims] =
|
|
4375
4456
|
extract_qqmm_dims(x, w, scales_w, group_size, bits);
|
|
@@ -4380,6 +4461,11 @@ array qqmm(
|
|
|
4380
4461
|
if (scales_w.has_value()) {
|
|
4381
4462
|
inputs.push_back(*scales_w);
|
|
4382
4463
|
}
|
|
4464
|
+
if (global_scale_x.has_value() && global_scale_w.has_value()) {
|
|
4465
|
+
inputs.push_back(*global_scale_x);
|
|
4466
|
+
inputs.push_back(*global_scale_w);
|
|
4467
|
+
}
|
|
4468
|
+
|
|
4383
4469
|
auto out_shape = inputs[0].shape();
|
|
4384
4470
|
out_shape.back() = w_outer_dims;
|
|
4385
4471
|
auto out = array(
|
|
@@ -4515,6 +4601,7 @@ std::vector<array> fp_quantize(
|
|
|
4515
4601
|
int group_size,
|
|
4516
4602
|
int bits,
|
|
4517
4603
|
QuantizationMode mode,
|
|
4604
|
+
const std::optional<array>& global_scale /* = std::nullopt */,
|
|
4518
4605
|
Stream s) {
|
|
4519
4606
|
int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
|
|
4520
4607
|
int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
|
|
@@ -4532,6 +4619,12 @@ std::vector<array> fp_quantize(
|
|
|
4532
4619
|
<< bits << ".";
|
|
4533
4620
|
throw std::invalid_argument(msg.str());
|
|
4534
4621
|
}
|
|
4622
|
+
|
|
4623
|
+
auto inputs = std::vector<array>{w};
|
|
4624
|
+
if (global_scale.has_value()) {
|
|
4625
|
+
inputs.push_back(global_scale.value());
|
|
4626
|
+
}
|
|
4627
|
+
|
|
4535
4628
|
auto fallback = [bits = bits, group_size = group_size, s](
|
|
4536
4629
|
const std::vector<array>& inputs) -> std::vector<array> {
|
|
4537
4630
|
auto& w = inputs[0];
|
|
@@ -4543,8 +4636,13 @@ std::vector<array> fp_quantize(
|
|
|
4543
4636
|
divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s);
|
|
4544
4637
|
if (group_size == 16) {
|
|
4545
4638
|
// convert to e4m3
|
|
4639
|
+
auto scale_encode = inputs.size() > 1
|
|
4640
|
+
? divide(array(448.0f * 6.0f, float32), inputs[1], s)
|
|
4641
|
+
: array(1.0f, float32);
|
|
4642
|
+
scales = multiply(scales, scale_encode, s);
|
|
4546
4643
|
scales = to_fp8(scales, s);
|
|
4547
|
-
wq =
|
|
4644
|
+
wq = multiply(
|
|
4645
|
+
divide(wq, from_fp8(scales, w.dtype(), s), s), scale_encode, s);
|
|
4548
4646
|
} else {
|
|
4549
4647
|
// convert to e8m0
|
|
4550
4648
|
auto z = array(0, scales.dtype());
|
|
@@ -4600,9 +4698,9 @@ std::vector<array> fp_quantize(
|
|
|
4600
4698
|
{uint32, uint8},
|
|
4601
4699
|
std::make_shared<fast::Quantize>(
|
|
4602
4700
|
s, fallback, group_size, bits, mode, false),
|
|
4603
|
-
|
|
4701
|
+
inputs);
|
|
4604
4702
|
}
|
|
4605
|
-
return fallback(
|
|
4703
|
+
return fallback(inputs);
|
|
4606
4704
|
}
|
|
4607
4705
|
|
|
4608
4706
|
std::vector<array> quantize(
|
|
@@ -4610,6 +4708,7 @@ std::vector<array> quantize(
|
|
|
4610
4708
|
std::optional<int> group_size_ /* = std::nullopt */,
|
|
4611
4709
|
std::optional<int> bits_ /* = std::nullopt */,
|
|
4612
4710
|
const std::string& mode /* = "affine" */,
|
|
4711
|
+
const std::optional<array>& global_scale /* = std::nullopt */,
|
|
4613
4712
|
StreamOrDevice s /* = {} */) {
|
|
4614
4713
|
auto qmode = string_to_quantization_mode(mode, "quantize");
|
|
4615
4714
|
auto [group_size, bits] =
|
|
@@ -4636,11 +4735,17 @@ std::vector<array> quantize(
|
|
|
4636
4735
|
<< " matrix has shape " << w.shape();
|
|
4637
4736
|
throw std::invalid_argument(msg.str());
|
|
4638
4737
|
}
|
|
4639
|
-
|
|
4738
|
+
if (to_stream(s).device == Device::gpu && metal::is_available() &&
|
|
4739
|
+
global_scale.has_value()) {
|
|
4740
|
+
std::ostringstream msg;
|
|
4741
|
+
msg << "[quantize] Global scale is not supported on the Metal backend.";
|
|
4742
|
+
throw std::invalid_argument(msg.str());
|
|
4743
|
+
}
|
|
4744
|
+
validate_global_scale("quantize", qmode, global_scale);
|
|
4640
4745
|
if (qmode == QuantizationMode::Affine) {
|
|
4641
4746
|
return affine_quantize(w, group_size, bits, s);
|
|
4642
4747
|
} else {
|
|
4643
|
-
return fp_quantize(w, group_size, bits, qmode, to_stream(s));
|
|
4748
|
+
return fp_quantize(w, group_size, bits, qmode, global_scale, to_stream(s));
|
|
4644
4749
|
}
|
|
4645
4750
|
}
|
|
4646
4751
|
|
|
@@ -4745,6 +4850,7 @@ array fp_dequantize(
|
|
|
4745
4850
|
int bits,
|
|
4746
4851
|
Dtype out_type,
|
|
4747
4852
|
QuantizationMode mode,
|
|
4853
|
+
const std::optional<array>& global_scale /* = std::nullopt */,
|
|
4748
4854
|
Stream s) {
|
|
4749
4855
|
int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
|
|
4750
4856
|
int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
|
|
@@ -4789,6 +4895,11 @@ array fp_dequantize(
|
|
|
4789
4895
|
throw std::invalid_argument(msg.str());
|
|
4790
4896
|
}
|
|
4791
4897
|
|
|
4898
|
+
auto inputs = std::vector<array>{w, scales};
|
|
4899
|
+
if (global_scale.has_value()) {
|
|
4900
|
+
inputs.push_back(global_scale.value());
|
|
4901
|
+
}
|
|
4902
|
+
|
|
4792
4903
|
auto fallback =
|
|
4793
4904
|
[wshape = std::move(wshape),
|
|
4794
4905
|
sshape = std::move(sshape),
|
|
@@ -4831,13 +4942,17 @@ array fp_dequantize(
|
|
|
4831
4942
|
out = reshape(out, {-1, group_size}, s);
|
|
4832
4943
|
scales = reshape(scales, {-1, 1}, s);
|
|
4833
4944
|
if (group_size == 16) {
|
|
4834
|
-
|
|
4945
|
+
array inv_scale_enc = inputs.size() > 2
|
|
4946
|
+
? divide(inputs[2], array(448.0f * 6.0f, out_type), s)
|
|
4947
|
+
: array(1.0f, out_type);
|
|
4948
|
+
scales = multiply(from_fp8(scales, out_type, s), inv_scale_enc, s);
|
|
4835
4949
|
} else {
|
|
4836
4950
|
scales = subtract(astype(scales, out_type, s), array(127, out_type), s);
|
|
4837
4951
|
scales = power(array(2.0f, out_type), scales, s);
|
|
4838
4952
|
}
|
|
4839
4953
|
return {reshape(multiply(out, scales, s), wshape, s)};
|
|
4840
4954
|
};
|
|
4955
|
+
|
|
4841
4956
|
if (s.device == Device::gpu) {
|
|
4842
4957
|
auto out_shape = w.shape();
|
|
4843
4958
|
out_shape.back() = out_size;
|
|
@@ -4846,9 +4961,9 @@ array fp_dequantize(
|
|
|
4846
4961
|
out_type,
|
|
4847
4962
|
std::make_shared<fast::Quantize>(
|
|
4848
4963
|
s, fallback, group_size, bits, mode, true),
|
|
4849
|
-
|
|
4964
|
+
inputs);
|
|
4850
4965
|
}
|
|
4851
|
-
return fallback(
|
|
4966
|
+
return fallback(inputs)[0];
|
|
4852
4967
|
}
|
|
4853
4968
|
|
|
4854
4969
|
array dequantize(
|
|
@@ -4858,6 +4973,7 @@ array dequantize(
|
|
|
4858
4973
|
std::optional<int> group_size_ /* = std::nullopt */,
|
|
4859
4974
|
std::optional<int> bits_ /* = std::nullopt */,
|
|
4860
4975
|
const std::string& mode /* = "affine" */,
|
|
4976
|
+
const std::optional<array>& global_scale /* = std::nullopt */,
|
|
4861
4977
|
std::optional<Dtype> dtype /* = std::nullopt */,
|
|
4862
4978
|
StreamOrDevice s /* = {} */) {
|
|
4863
4979
|
auto [out_type, qmode] =
|
|
@@ -4884,6 +5000,14 @@ array dequantize(
|
|
|
4884
5000
|
<< "but it has only " << w.ndim() << ".";
|
|
4885
5001
|
throw std::invalid_argument(msg.str());
|
|
4886
5002
|
}
|
|
5003
|
+
if (global_scale.has_value()) {
|
|
5004
|
+
if (to_stream(s).device == Device::gpu && metal::is_available()) {
|
|
5005
|
+
std::ostringstream msg;
|
|
5006
|
+
msg << "[dequantize] Global scale is not supported on the Metal backend.";
|
|
5007
|
+
throw std::invalid_argument(msg.str());
|
|
5008
|
+
}
|
|
5009
|
+
}
|
|
5010
|
+
validate_global_scale("dequantize", qmode, global_scale);
|
|
4887
5011
|
|
|
4888
5012
|
if (qmode == QuantizationMode::Affine) {
|
|
4889
5013
|
return astype(
|
|
@@ -4892,7 +5016,14 @@ array dequantize(
|
|
|
4892
5016
|
s);
|
|
4893
5017
|
} else {
|
|
4894
5018
|
return fp_dequantize(
|
|
4895
|
-
w,
|
|
5019
|
+
w,
|
|
5020
|
+
scales,
|
|
5021
|
+
group_size,
|
|
5022
|
+
bits,
|
|
5023
|
+
out_type,
|
|
5024
|
+
qmode,
|
|
5025
|
+
global_scale,
|
|
5026
|
+
to_stream(s));
|
|
4896
5027
|
}
|
|
4897
5028
|
}
|
|
4898
5029
|
|
|
@@ -6091,4 +6222,4 @@ array contiguous(
|
|
|
6091
6222
|
{a});
|
|
6092
6223
|
}
|
|
6093
6224
|
|
|
6094
|
-
} // namespace mlx::core
|
|
6225
|
+
} // namespace mlx::core
|
|
@@ -666,6 +666,12 @@ min(const array& a,
|
|
|
666
666
|
MLX_API array
|
|
667
667
|
min(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});
|
|
668
668
|
|
|
669
|
+
/** Returns the Hanning window of size M. */
|
|
670
|
+
MLX_API array hanning(int M, StreamOrDevice s = {});
|
|
671
|
+
|
|
672
|
+
/** Returns the Hamming window of size M. */
|
|
673
|
+
MLX_API array hamming(int M, StreamOrDevice s = {});
|
|
674
|
+
|
|
669
675
|
/** Returns the index of the minimum value in the array. */
|
|
670
676
|
MLX_API array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
671
677
|
inline array argmin(const array& a, StreamOrDevice s = {}) {
|
|
@@ -1391,6 +1397,7 @@ MLX_API std::vector<array> quantize(
|
|
|
1391
1397
|
std::optional<int> group_size = std::nullopt,
|
|
1392
1398
|
std::optional<int> bits = std::nullopt,
|
|
1393
1399
|
const std::string& mode = "affine",
|
|
1400
|
+
const std::optional<array>& global_scale = std::nullopt,
|
|
1394
1401
|
StreamOrDevice s = {});
|
|
1395
1402
|
|
|
1396
1403
|
/** Dequantize a matrix produced by quantize() */
|
|
@@ -1401,17 +1408,20 @@ MLX_API array dequantize(
|
|
|
1401
1408
|
std::optional<int> group_size = std::nullopt,
|
|
1402
1409
|
std::optional<int> bits = std::nullopt,
|
|
1403
1410
|
const std::string& mode = "affine",
|
|
1411
|
+
const std::optional<array>& global_scale = std::nullopt,
|
|
1404
1412
|
std::optional<Dtype> dtype = std::nullopt,
|
|
1405
1413
|
StreamOrDevice s = {});
|
|
1406
1414
|
|
|
1407
1415
|
MLX_API array qqmm(
|
|
1408
1416
|
array x, // input activations
|
|
1409
1417
|
array w, // maybe quantized weights
|
|
1410
|
-
std::optional<array> w_scales = std::nullopt, // optional scales if w
|
|
1411
|
-
|
|
1418
|
+
const std::optional<array> w_scales = std::nullopt, // optional scales if w
|
|
1419
|
+
// is quantized
|
|
1412
1420
|
std::optional<int> group_size = std::nullopt,
|
|
1413
1421
|
std::optional<int> bits = std::nullopt,
|
|
1414
1422
|
const std::string& mode = "nvfp4",
|
|
1423
|
+
const std::optional<array> global_scale_x = std::nullopt,
|
|
1424
|
+
const std::optional<array> global_scale_w = std::nullopt,
|
|
1415
1425
|
StreamOrDevice s = {});
|
|
1416
1426
|
|
|
1417
1427
|
/** Convert an E4M3 float8 to the given floating point dtype. */
|
|
@@ -3424,6 +3424,7 @@ std::vector<array> QuantizedMatmul::vjp(
|
|
|
3424
3424
|
group_size_,
|
|
3425
3425
|
bits_,
|
|
3426
3426
|
quantization_mode_to_string(mode_),
|
|
3427
|
+
{}, // placeholder for amax
|
|
3427
3428
|
std::nullopt,
|
|
3428
3429
|
stream());
|
|
3429
3430
|
wq = unflatten(wq, -1, {-1, group_size_}, stream());
|
|
@@ -3484,14 +3485,14 @@ std::vector<Shape> QQMatmul::output_shapes(const std::vector<array>& inputs) {
|
|
|
3484
3485
|
}
|
|
3485
3486
|
|
|
3486
3487
|
std::vector<array> QQMatmul::vjp(
|
|
3487
|
-
const std::vector<array>& primals, // non quantized x, non quantized w
|
|
3488
|
+
const std::vector<array>& primals, // non quantized x, non quantized w, if
|
|
3489
|
+
// nvfp4 global_scale_x, global_scale_w
|
|
3488
3490
|
const std::vector<array>& cotangents, // non quantized upstream grads
|
|
3489
3491
|
const std::vector<int>& argnums,
|
|
3490
3492
|
const std::vector<array>&) {
|
|
3491
|
-
|
|
3492
|
-
|
|
3493
|
-
|
|
3494
|
-
}
|
|
3493
|
+
bool is_nvfp4 = mode_ == QuantizationMode::Nvfp4;
|
|
3494
|
+
assert(primals.size() == 2 || (is_nvfp4 && primals.size() == 4));
|
|
3495
|
+
|
|
3495
3496
|
std::vector<array> vjps;
|
|
3496
3497
|
auto& cotan = cotangents[0];
|
|
3497
3498
|
auto& s = stream();
|
|
@@ -3499,6 +3500,15 @@ std::vector<array> QQMatmul::vjp(
|
|
|
3499
3500
|
// primal[0] -- non quantized activations (M, K)
|
|
3500
3501
|
// cotan -- non quantized grads (M, N)
|
|
3501
3502
|
auto qmode = quantization_mode_to_string(mode_);
|
|
3503
|
+
std::optional<array> cotan_amax = (primals.size() == 4)
|
|
3504
|
+
? std::make_optional(astype(max(abs(cotan, s), s), float32, s))
|
|
3505
|
+
: std::nullopt;
|
|
3506
|
+
|
|
3507
|
+
auto get_primal_scale = [&](int idx) {
|
|
3508
|
+
return (primals.size() == 4) ? std::make_optional(primals[idx])
|
|
3509
|
+
: std::nullopt;
|
|
3510
|
+
};
|
|
3511
|
+
|
|
3502
3512
|
for (auto arg : argnums) {
|
|
3503
3513
|
if (arg == 0) { // gradient wrt to x
|
|
3504
3514
|
// We transpose weights -> quantize along N
|
|
@@ -3509,6 +3519,8 @@ std::vector<array> QQMatmul::vjp(
|
|
|
3509
3519
|
group_size_,
|
|
3510
3520
|
bits_,
|
|
3511
3521
|
qmode,
|
|
3522
|
+
cotan_amax,
|
|
3523
|
+
get_primal_scale(3), // global_scale_w (for w.T)
|
|
3512
3524
|
s));
|
|
3513
3525
|
} else if (arg == 1) { // gradient wrt to weights
|
|
3514
3526
|
vjps.push_back(qqmm(
|
|
@@ -3518,7 +3530,11 @@ std::vector<array> QQMatmul::vjp(
|
|
|
3518
3530
|
group_size_,
|
|
3519
3531
|
bits_,
|
|
3520
3532
|
qmode,
|
|
3533
|
+
cotan_amax,
|
|
3534
|
+
get_primal_scale(2), // global_scale_x (for x.T)
|
|
3521
3535
|
s));
|
|
3536
|
+
} else {
|
|
3537
|
+
vjps.push_back(zeros_like(primals[arg], s));
|
|
3522
3538
|
}
|
|
3523
3539
|
}
|
|
3524
3540
|
return vjps;
|
|
@@ -3643,6 +3659,7 @@ std::vector<array> GatherQMM::vjp(
|
|
|
3643
3659
|
bits_,
|
|
3644
3660
|
quantization_mode_to_string(mode_),
|
|
3645
3661
|
std::nullopt,
|
|
3662
|
+
std::nullopt, // amax placeholder
|
|
3646
3663
|
stream()),
|
|
3647
3664
|
-1,
|
|
3648
3665
|
{-1, group_size_},
|
|
@@ -26,6 +26,10 @@ Stream get_stream(int index) {
|
|
|
26
26
|
return scheduler::scheduler().get_stream(index);
|
|
27
27
|
}
|
|
28
28
|
|
|
29
|
+
std::vector<Stream> get_streams() {
|
|
30
|
+
return scheduler::scheduler().get_streams();
|
|
31
|
+
}
|
|
32
|
+
|
|
29
33
|
Stream new_stream(Device d) {
|
|
30
34
|
if (!gpu::is_available() && d == Device::gpu) {
|
|
31
35
|
throw std::invalid_argument(
|
|
@@ -99,6 +99,9 @@ class Scheduler {
|
|
|
99
99
|
Stream get_stream(int index) const {
|
|
100
100
|
return streams_.at(index);
|
|
101
101
|
}
|
|
102
|
+
std::vector<Stream> get_streams() const {
|
|
103
|
+
return streams_;
|
|
104
|
+
}
|
|
102
105
|
|
|
103
106
|
void set_default_stream(const Stream& s) {
|
|
104
107
|
default_streams_.at(s.device.type) = s;
|
|
@@ -2,6 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
#pragma once
|
|
4
4
|
|
|
5
|
+
#include <vector>
|
|
6
|
+
|
|
5
7
|
#include "mlx/api.h"
|
|
6
8
|
#include "mlx/device.h"
|
|
7
9
|
|
|
@@ -25,6 +27,9 @@ MLX_API Stream new_stream(Device d);
|
|
|
25
27
|
/** Get the stream with the given index. */
|
|
26
28
|
MLX_API Stream get_stream(int index);
|
|
27
29
|
|
|
30
|
+
/** Get all available streams. */
|
|
31
|
+
MLX_API std::vector<Stream> get_streams();
|
|
32
|
+
|
|
28
33
|
inline bool operator==(const Stream& lhs, const Stream& rhs) {
|
|
29
34
|
return lhs.index == rhs.index;
|
|
30
35
|
}
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.25)
|
|
2
|
+
|
|
3
|
+
project(mlx_onnx VERSION 0.30.7.1 LANGUAGES C CXX)
|
|
4
|
+
|
|
5
|
+
set(CMAKE_CXX_STANDARD 20)
|
|
6
|
+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|
7
|
+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
|
8
|
+
|
|
9
|
+
option(MLX_ONNX_USE_EXTERNAL_MLX "Build against an externally provided MLX install" OFF)
|
|
10
|
+
option(MLX_ONNX_BUILD_PYTHON_BINDINGS "Build Python IR bindings" OFF)
|
|
11
|
+
option(MLX_ONNX_INSTALL_CPP_ARTIFACTS "Install C++ library and headers" ON)
|
|
12
|
+
|
|
13
|
+
include(FetchContent)
|
|
14
|
+
|
|
15
|
+
if(MLX_ONNX_USE_EXTERNAL_MLX AND MLX_ONNX_BUILD_PYTHON_BINDINGS)
|
|
16
|
+
message(
|
|
17
|
+
FATAL_ERROR
|
|
18
|
+
"MLX_ONNX_BUILD_PYTHON_BINDINGS requires bundled mlx sources; set MLX_ONNX_USE_EXTERNAL_MLX=OFF")
|
|
19
|
+
endif()
|
|
20
|
+
|
|
21
|
+
if(MLX_ONNX_USE_EXTERNAL_MLX)
|
|
22
|
+
set(MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR "" CACHE PATH "Path to MLX include root")
|
|
23
|
+
set(MLX_ONNX_EXTERNAL_MLX_LIB_DIR "" CACHE PATH "Path to MLX library directory")
|
|
24
|
+
|
|
25
|
+
if(MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR STREQUAL "")
|
|
26
|
+
message(FATAL_ERROR "MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR must be set when MLX_ONNX_USE_EXTERNAL_MLX=ON")
|
|
27
|
+
endif()
|
|
28
|
+
if(MLX_ONNX_EXTERNAL_MLX_LIB_DIR STREQUAL "")
|
|
29
|
+
message(FATAL_ERROR "MLX_ONNX_EXTERNAL_MLX_LIB_DIR must be set when MLX_ONNX_USE_EXTERNAL_MLX=ON")
|
|
30
|
+
endif()
|
|
31
|
+
|
|
32
|
+
find_library(
|
|
33
|
+
MLX_EXTERNAL_LIBRARY
|
|
34
|
+
NAMES mlx
|
|
35
|
+
PATHS ${MLX_ONNX_EXTERNAL_MLX_LIB_DIR}
|
|
36
|
+
NO_DEFAULT_PATH)
|
|
37
|
+
|
|
38
|
+
if(NOT MLX_EXTERNAL_LIBRARY)
|
|
39
|
+
message(FATAL_ERROR "Could not find libmlx in ${MLX_ONNX_EXTERNAL_MLX_LIB_DIR}")
|
|
40
|
+
endif()
|
|
41
|
+
|
|
42
|
+
add_library(mlx SHARED IMPORTED GLOBAL)
|
|
43
|
+
set_target_properties(
|
|
44
|
+
mlx
|
|
45
|
+
PROPERTIES
|
|
46
|
+
IMPORTED_LOCATION ${MLX_EXTERNAL_LIBRARY}
|
|
47
|
+
INTERFACE_INCLUDE_DIRECTORIES ${MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR})
|
|
48
|
+
else()
|
|
49
|
+
set(MLX_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
|
50
|
+
set(MLX_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
|
|
51
|
+
set(MLX_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE)
|
|
52
|
+
if(MLX_ONNX_BUILD_PYTHON_BINDINGS)
|
|
53
|
+
set(MLX_BUILD_PYTHON_BINDINGS ON CACHE BOOL "" FORCE)
|
|
54
|
+
else()
|
|
55
|
+
set(MLX_BUILD_PYTHON_BINDINGS OFF CACHE BOOL "" FORCE)
|
|
56
|
+
endif()
|
|
57
|
+
set(MLX_BUILD_PYTHON_STUBS OFF CACHE BOOL "" FORCE)
|
|
58
|
+
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
|
|
59
|
+
set(MLX_BUILD_SAFETENSORS OFF CACHE BOOL "" FORCE)
|
|
60
|
+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mlx)
|
|
61
|
+
endif()
|
|
62
|
+
|
|
63
|
+
if(NOT TARGET nlohmann_json::nlohmann_json)
|
|
64
|
+
FetchContent_Declare(
|
|
65
|
+
nlohmann_json
|
|
66
|
+
GIT_REPOSITORY https://github.com/nlohmann/json.git
|
|
67
|
+
GIT_TAG v3.11.3
|
|
68
|
+
EXCLUDE_FROM_ALL)
|
|
69
|
+
FetchContent_MakeAvailable(nlohmann_json)
|
|
70
|
+
endif()
|
|
71
|
+
|
|
72
|
+
add_library(
|
|
73
|
+
mlx_onnx
|
|
74
|
+
src/export.cpp
|
|
75
|
+
src/api.cpp
|
|
76
|
+
src/compat.cpp
|
|
77
|
+
src/io.cpp
|
|
78
|
+
src/lowering.cpp
|
|
79
|
+
src/mappings.cpp
|
|
80
|
+
src/onnx.cpp
|
|
81
|
+
src/shared.cpp)
|
|
82
|
+
|
|
83
|
+
set_target_properties(mlx_onnx PROPERTIES OUTPUT_NAME mlx_onnx)
|
|
84
|
+
|
|
85
|
+
target_include_directories(
|
|
86
|
+
mlx_onnx
|
|
87
|
+
PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
|
88
|
+
$<INSTALL_INTERFACE:include>
|
|
89
|
+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
|
|
90
|
+
|
|
91
|
+
if(MLX_ONNX_USE_EXTERNAL_MLX)
|
|
92
|
+
target_include_directories(mlx_onnx PRIVATE ${MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR})
|
|
93
|
+
endif()
|
|
94
|
+
|
|
95
|
+
target_link_libraries(mlx_onnx PUBLIC mlx nlohmann_json::nlohmann_json)
|
|
96
|
+
|
|
97
|
+
if(MLX_ONNX_BUILD_PYTHON_BINDINGS)
|
|
98
|
+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/python/src)
|
|
99
|
+
set(MLX_ONNX_PY_INIT_FILE ${CMAKE_CURRENT_SOURCE_DIR}/python/mlx_onnx/__init__.py)
|
|
100
|
+
if(NOT EXISTS ${MLX_ONNX_PY_INIT_FILE})
|
|
101
|
+
set(MLX_ONNX_PY_INIT_FILE ${CMAKE_CURRENT_BINARY_DIR}/mlx_onnx___init__.py)
|
|
102
|
+
file(WRITE ${MLX_ONNX_PY_INIT_FILE} "from ._core import * # noqa: F401,F403\n")
|
|
103
|
+
endif()
|
|
104
|
+
if(NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/mlx/CMakeLists.txt)
|
|
105
|
+
message(FATAL_ERROR "Bundled mlx sources are missing at ${CMAKE_CURRENT_SOURCE_DIR}/mlx")
|
|
106
|
+
endif()
|
|
107
|
+
if(NOT TARGET core)
|
|
108
|
+
message(FATAL_ERROR "Bundled mlx Python extension target `core` was not built")
|
|
109
|
+
endif()
|
|
110
|
+
install(TARGETS core LIBRARY DESTINATION mlx COMPONENT python)
|
|
111
|
+
if(APPLE AND MLX_BUILD_METAL)
|
|
112
|
+
# MLX looks for mlx.metallib next to the extension module using MLX runtime.
|
|
113
|
+
install(
|
|
114
|
+
FILES ${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib
|
|
115
|
+
DESTINATION mlx
|
|
116
|
+
COMPONENT python)
|
|
117
|
+
# mlx_onnx._core also links MLX and resolves the same metallib at runtime.
|
|
118
|
+
install(
|
|
119
|
+
FILES ${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib
|
|
120
|
+
DESTINATION mlx_onnx
|
|
121
|
+
COMPONENT python)
|
|
122
|
+
endif()
|
|
123
|
+
install(
|
|
124
|
+
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlx/python/mlx/
|
|
125
|
+
DESTINATION mlx
|
|
126
|
+
COMPONENT python
|
|
127
|
+
PATTERN "__pycache__" EXCLUDE)
|
|
128
|
+
install(
|
|
129
|
+
FILES ${MLX_ONNX_PY_INIT_FILE}
|
|
130
|
+
DESTINATION mlx_onnx
|
|
131
|
+
RENAME __init__.py
|
|
132
|
+
COMPONENT python)
|
|
133
|
+
set(MLX_ONNX_VENDOR_MLX_ROOT mlx_onnx/_vendor/mlx)
|
|
134
|
+
install(
|
|
135
|
+
FILES ${CMAKE_CURRENT_SOURCE_DIR}/mlx/CMakeLists.txt
|
|
136
|
+
${CMAKE_CURRENT_SOURCE_DIR}/mlx/mlx.pc.in
|
|
137
|
+
${CMAKE_CURRENT_SOURCE_DIR}/mlx/LICENSE
|
|
138
|
+
${CMAKE_CURRENT_SOURCE_DIR}/mlx/ACKNOWLEDGMENTS.md
|
|
139
|
+
DESTINATION ${MLX_ONNX_VENDOR_MLX_ROOT}
|
|
140
|
+
COMPONENT python)
|
|
141
|
+
install(
|
|
142
|
+
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlx/cmake
|
|
143
|
+
${CMAKE_CURRENT_SOURCE_DIR}/mlx/mlx
|
|
144
|
+
DESTINATION ${MLX_ONNX_VENDOR_MLX_ROOT}
|
|
145
|
+
COMPONENT python)
|
|
146
|
+
endif()
|
|
147
|
+
|
|
148
|
+
if(MLX_ONNX_INSTALL_CPP_ARTIFACTS)
|
|
149
|
+
include(GNUInstallDirs)
|
|
150
|
+
install(
|
|
151
|
+
TARGETS mlx_onnx
|
|
152
|
+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
|
153
|
+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
|
154
|
+
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
|
155
|
+
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
|
156
|
+
COMPONENT cpp)
|
|
157
|
+
|
|
158
|
+
install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} COMPONENT cpp)
|
|
159
|
+
endif()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 MLX Contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|