mlx 0.30.7.2 → 0.30.7.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/extconf.rb +267 -8
- data/ext/mlx/native.cpp +112 -58
- data/ext/mlx-onnx/native.cpp +1402 -0
- data/ext/mlx-onnx/native.hpp +19 -0
- data/lib/mlx/core.rb +342 -117
- data/lib/mlx/distributed_utils/common.rb +1 -1
- data/lib/mlx/distributed_utils/config.rb +7 -4
- data/lib/mlx/distributed_utils/launch.rb +2 -0
- data/lib/mlx/dsl/attention.rb +132 -0
- data/lib/mlx/dsl/builder.rb +8 -0
- data/lib/mlx/dsl/config_schema.rb +133 -0
- data/lib/mlx/dsl/generate.rb +193 -0
- data/lib/mlx/dsl/kv_cache.rb +96 -0
- data/lib/mlx/dsl/masks.rb +32 -0
- data/lib/mlx/dsl/positions.rb +35 -0
- data/lib/mlx/dsl/run_stack.rb +68 -0
- data/lib/mlx/dsl/tensor.rb +126 -0
- data/lib/mlx/dsl/transformer_block.rb +113 -0
- data/lib/mlx/dsl/weight_map.rb +140 -0
- data/lib/mlx/dsl.rb +10 -0
- data/lib/mlx/nn/base.rb +4 -0
- data/lib/mlx/nn/layers/linear.rb +2 -3
- data/lib/mlx/onnx.rb +250 -0
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx-onnx/webgpu_harness.rb +289 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
- data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
- data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
- data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
- data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
- data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
- data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
- data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
- data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
- data/submodules/mlx-onnx/CMakeLists.txt +159 -0
- data/submodules/mlx-onnx/LICENSE +21 -0
- data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
- data/submodules/mlx-onnx/src/api.cpp +81 -0
- data/submodules/mlx-onnx/src/compat.cpp +111 -0
- data/submodules/mlx-onnx/src/detail.hpp +69 -0
- data/submodules/mlx-onnx/src/export.cpp +653 -0
- data/submodules/mlx-onnx/src/io.cpp +61 -0
- data/submodules/mlx-onnx/src/json.hpp +25 -0
- data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
- data/submodules/mlx-onnx/src/mappings.cpp +201 -0
- data/submodules/mlx-onnx/src/mappings.hpp +16 -0
- data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
- data/submodules/mlx-onnx/src/shared.cpp +206 -0
- metadata +665 -567
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
- /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
- /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <functional>
|
|
4
|
+
#include <string>
|
|
5
|
+
#include <unordered_map>
|
|
6
|
+
#include <vector>
|
|
7
|
+
|
|
8
|
+
#include <ruby.h>
|
|
9
|
+
|
|
10
|
+
#include "mlx/array.h"
|
|
11
|
+
#include "mlx/export.h"
|
|
12
|
+
|
|
13
|
+
mlx::core::array onnx_array_from_ruby(VALUE value);
|
|
14
|
+
std::vector<mlx::core::array> onnx_array_vector_from_ruby(VALUE value);
|
|
15
|
+
std::unordered_map<std::string, mlx::core::array> onnx_array_map_from_ruby_hash(VALUE value);
|
|
16
|
+
std::function<std::vector<mlx::core::array>(const mlx::core::Args&, const mlx::core::Kwargs&)>
|
|
17
|
+
onnx_args_kwargs_function_from_callable(VALUE callable);
|
|
18
|
+
|
|
19
|
+
extern "C" void init_onnx_native_bindings(VALUE mMLX);
|
data/lib/mlx/core.rb
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
require "open3"
|
|
4
4
|
require "tmpdir"
|
|
5
|
+
require_relative "onnx"
|
|
5
6
|
|
|
6
7
|
module MLX
|
|
7
8
|
module Core
|
|
@@ -334,6 +335,8 @@ module MLX
|
|
|
334
335
|
alias_method :native_vmap, :vmap if method_defined?(:vmap) && !method_defined?(:native_vmap)
|
|
335
336
|
alias_method :native_export_to_dot,
|
|
336
337
|
:export_to_dot if method_defined?(:export_to_dot) && !method_defined?(:native_export_to_dot)
|
|
338
|
+
alias_method :native_array, :array if method_defined?(:array) && !method_defined?(:native_array)
|
|
339
|
+
alias_method :native_mean, :mean if method_defined?(:mean) && !method_defined?(:native_mean)
|
|
337
340
|
|
|
338
341
|
%i[savez savez_compressed].each do |method_name|
|
|
339
342
|
if method_defined?(method_name) && instance_method(method_name).owner == self
|
|
@@ -343,6 +346,24 @@ module MLX
|
|
|
343
346
|
|
|
344
347
|
ARRAY_LEAF = :__mlx_array_leaf__
|
|
345
348
|
|
|
349
|
+
# Builds an array from +value+, accepting the dtype either as the second
# positional argument or through the +dtype:+ keyword.
#
# The two spellings are reconciled by resolve_array_dtype and the result is
# handed to the native constructor (native_array).
def array(value, positional_dtype = nil, dtype: nil)
  ensure_native!
  native_array(value, resolve_array_dtype(positional_dtype, dtype))
end
|
|
354
|
+
|
|
355
|
+
# Mean reduction with optional keepdims handling layered on top of the
# native reduction.
#
# +keepdims+ may be given positionally or by keyword; the two are reconciled
# by resolve_keepdims_argument. When it resolves to true, every reduced axis
# is re-inserted via expand_dims, walking the normalized axes in ascending
# order so each insertion lands at its original position.
def mean(array, axis = nil, positional_keepdims = nil, keepdims: nil)
  ensure_native!
  keep = resolve_keepdims_argument(positional_keepdims, keepdims)
  collapsed = reduce_mean(array, axis)
  return collapsed unless keep

  normalize_reduction_axes(array, axis).reduce(collapsed) do |restored, axis_index|
    expand_dims(restored, axis_index)
  end
end
|
|
366
|
+
|
|
346
367
|
def load(file, format = nil, return_metadata = false)
|
|
347
368
|
ensure_native!
|
|
348
369
|
format_name = (format || infer_format(file)).to_s
|
|
@@ -560,6 +581,97 @@ module MLX
|
|
|
560
581
|
|
|
561
582
|
private
|
|
562
583
|
|
|
584
|
+
# Picks the dtype to use when it may have been supplied positionally, by
# keyword, or both.
#
# Each candidate is first passed through normalize_dtype_alias. If only one
# is present it wins; if both are present they must compare equal by name
# (dtype_name_for_compare), otherwise ArgumentError is raised. Returns nil
# when neither was given.
def resolve_array_dtype(positional_dtype, keyword_dtype)
  from_position = normalize_dtype_alias(positional_dtype)
  from_keyword = normalize_dtype_alias(keyword_dtype)

  if from_position.nil? || from_keyword.nil?
    return from_position.nil? ? from_keyword : from_position
  end

  return from_position if dtype_name_for_compare(from_position) == dtype_name_for_compare(from_keyword)

  raise ArgumentError,
        "array received conflicting dtype arguments (positional=#{positional_dtype.inspect}, keyword=#{keyword_dtype.inspect})"
end
|
|
597
|
+
|
|
598
|
+
# Canonicalizes shorthand dtype spellings (for example :f32, "fp16", "BOOL")
# to the symbol of the full dtype name.
#
# Symbols and Strings are matched case-insensitively after stripping
# surrounding whitespace. The Symbol/String check must run before any
# duck-typed +respond_to?(:name)+ test: Symbol itself responds to +name+ on
# Ruby 3.0+, so testing +name+ first would hand alias symbols such as :f32
# back untouched and the alias table below would only ever apply to Strings.
#
# Returns the canonical symbol for a known alias; otherwise returns +dtype+
# unchanged (this covers nil, unknown names, and any non-String/Symbol
# dtype object).
def normalize_dtype_alias(dtype)
  return dtype unless dtype.is_a?(::Symbol) || dtype.is_a?(::String)

  case dtype.to_s.strip.downcase
  when "bool", "bool_"
    :bool_
  when "f16", "fp16", "float16"
    :float16
  when "bf16", "bfloat16"
    :bfloat16
  when "f32", "fp32", "float32"
    :float32
  when "f64", "fp64", "float64"
    :float64
  when "c64", "complex64"
    :complex64
  else
    dtype
  end
end
|
|
620
|
+
|
|
621
|
+
# Produces a String label for a dtype so two dtype spellings can be compared.
#
# The value is first run through normalize_dtype_alias; objects exposing
# +name+ contribute that name, anything else contributes its +to_s+.
# Returns nil for nil.
def dtype_name_for_compare(dtype)
  return nil if dtype.nil?

  canonical = normalize_dtype_alias(dtype)
  label = canonical.respond_to?(:name) ? canonical.name : canonical
  label.to_s
end
|
|
631
|
+
|
|
632
|
+
# Reconciles a keepdims flag that may arrive positionally, by keyword, or both.
#
# nil means "not supplied". Supplied values are reduced to true/false by
# truthiness; if both are supplied and disagree, ArgumentError is raised.
# The keyword value takes precedence, then the positional one, and the
# default is false.
def resolve_keepdims_argument(positional_keepdims, keyword_keepdims)
  supplied = [keyword_keepdims, positional_keepdims]
             .reject(&:nil?)
             .map { |flag| flag ? true : false }

  if supplied.uniq.length > 1
    raise ArgumentError,
          "mean received conflicting keepdims arguments (positional=#{positional_keepdims.inspect}, keyword=#{keyword_keepdims.inspect})"
  end

  supplied.fetch(0, false)
end
|
|
642
|
+
|
|
643
|
+
# Runs the native mean over +axis+.
#
# A non-Array axis (including nil) is passed straight to native_mean. An
# Array of axes is normalized and then reduced one axis at a time from the
# highest index down, so removing an axis never shifts the index of an axis
# still waiting to be reduced.
def reduce_mean(array, axis)
  return native_mean(array, axis) unless axis.is_a?(::Array)

  result = array
  normalize_reduction_axes(array, axis).reverse_each do |axis_index|
    result = native_mean(result, axis_index)
  end
  result
end
|
|
652
|
+
|
|
653
|
+
# Expands an axis specification into a sorted list of non-negative axis
# indices for +array+.
#
# nil selects every axis (0...ndim). A single value or an Array of values is
# resolved entry by entry through normalize_axis_index; duplicates after
# resolution raise ArgumentError.
def normalize_reduction_axes(array, axis)
  rank = array.ndim
  return (0...rank).to_a if axis.nil?

  requested = axis.is_a?(::Array) ? axis : [axis]
  resolved = requested.map { |entry| normalize_axis_index(entry, rank) }
  unless resolved.uniq.length == resolved.length
    raise ArgumentError, "axis contains duplicate values: #{requested.inspect}"
  end

  resolved.sort
end
|
|
663
|
+
|
|
664
|
+
# Resolves one axis entry to an index in 0...ndim.
#
# Negative values count from the end (ndim is added once). Raises TypeError
# for non-Integer entries and ArgumentError when the resolved index falls
# outside 0...ndim; the error message reports the axis as originally given.
def normalize_axis_index(axis, ndim)
  raise TypeError, "axis entries must be Integer" unless axis.is_a?(::Integer)

  resolved = axis.negative? ? axis + ndim : axis
  return resolved if resolved >= 0 && resolved < ndim

  raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{ndim}"
end
|
|
674
|
+
|
|
563
675
|
def infer_format(file)
|
|
564
676
|
path = file_path(file)
|
|
565
677
|
ext = File.extname(path).delete_prefix(".")
|
|
@@ -601,7 +713,11 @@ module MLX
|
|
|
601
713
|
Dir.glob(File.join(dir, "**", "*.npy")).sort.each do |npy_path|
|
|
602
714
|
rel = npy_path.delete_prefix(dir + File::SEPARATOR)
|
|
603
715
|
key = rel.end_with?(".npy") ? rel[0...-4] : rel
|
|
604
|
-
|
|
716
|
+
# Force a materialized copy to avoid keeping many file-backed mmap handles open.
|
|
717
|
+
value = native_load(npy_path, "npy", false)
|
|
718
|
+
value = add(value, 0)
|
|
719
|
+
eval(value)
|
|
720
|
+
out[key] = value
|
|
605
721
|
end
|
|
606
722
|
out
|
|
607
723
|
end
|
|
@@ -677,27 +793,95 @@ module MLX
|
|
|
677
793
|
end
|
|
678
794
|
|
|
679
795
|
# Wraps +fun+ in a lambda that computes gradients (and, when +with_value+ is
# true, also returns the function's value) with respect to the selected
# positional (+argnums+) and keyword (+argnames+) arguments.
#
# Native transforms are cached per argument structure: the cache key comes
# from grad_selection_cache_key, so repeated calls with the same selection
# layout reuse one native_grad / native_value_and_grad closure. Because the
# cached closure outlives a single call, the per-call data it needs (args,
# kwargs, selections, captured return value) is published through a
# per-thread stack stored in the entry's call_state and removed again in
# the +ensure+ clause.
def build_grad_like_function(fun, argnums, argnames, with_value)
  cache = {}

  lambda do |*args, **kwargs|
    selections, flat_inputs = build_target_selections(args, kwargs, argnums, argnames)
    cache_key = grad_selection_cache_key(selections)
    entry = cache[cache_key]
    if entry.nil?
      call_state = { mutex: Mutex.new, stacks: {} }

      # Function handed to the native transform: rebuilds the user-facing
      # arguments from the flat array list and records the raw return value.
      lifted = lambda do |*flat_vars|
        active = grad_call_state_current(call_state)
        raise RuntimeError, "gradient transform invoked without call state" if active.nil?

        call_args, call_kwargs = apply_flat_vars_to_targets(
          active[:args], active[:kwargs], active[:selections], flat_vars
        )
        raw_value = fun.call(*call_args, **call_kwargs)
        active[:captured_value] = raw_value
        extract_loss(raw_value)
      end

      native_argnums = (0...flat_inputs.length).to_a
      native_fn = with_value ? native_value_and_grad(lifted, native_argnums) : native_grad(lifted, native_argnums)

      entry = { native_fn: native_fn, call_state: call_state }
      cache[cache_key] = entry
    end

    frame = { args: args, kwargs: kwargs, selections: selections, captured_value: nil }
    grad_call_state_push(entry[:call_state], frame)

    native_result = entry[:native_fn].call(*flat_inputs)
    if with_value
      _loss, raw_grads = native_result
      value = frame[:captured_value]
      # Fall back to a direct call if the lifted function never recorded a value.
      value = fun.call(*args, **kwargs) if value.nil?
      [value, rebuild_grad_result(raw_grads, selections, argnames)]
    else
      rebuild_grad_result(native_result, selections, argnames)
    end
  ensure
    # entry is still nil when the failure happened before the cache lookup
    # or while building a new entry; in that case nothing was pushed.
    grad_call_state_pop(entry[:call_state]) unless entry.nil?
  end
end
|
|
856
|
+
|
|
857
|
+
# Returns the most recently pushed state for the calling thread, or nil when
# that thread has no stack in +call_state+. The lookup happens under the
# call_state mutex.
def grad_call_state_current(call_state)
  owner = Thread.current
  call_state[:mutex].synchronize { call_state[:stacks][owner]&.last }
end
|
|
864
|
+
|
|
865
|
+
# Pushes +state+ onto the calling thread's stack inside +call_state+,
# creating the stack on first use. Runs under the call_state mutex and
# returns the stack.
def grad_call_state_push(call_state, state)
  owner = Thread.current
  call_state[:mutex].synchronize do
    frames = (call_state[:stacks][owner] ||= [])
    frames << state
  end
end
|
|
876
|
+
|
|
877
|
+
# Removes the top state from the calling thread's stack in +call_state+.
# Does nothing when the thread has no stack; drops the thread's entry from
# the stacks hash once its stack becomes empty. Runs under the call_state
# mutex.
def grad_call_state_pop(call_state)
  owner = Thread.current
  call_state[:mutex].synchronize do
    frames = call_state[:stacks][owner]
    next if frames.nil?

    frames.pop
    call_state[:stacks].delete(owner) if frames.empty?
  end
end
|
|
703
887
|
|
|
@@ -725,6 +909,119 @@ module MLX
|
|
|
725
909
|
end
|
|
726
910
|
end
|
|
727
911
|
|
|
912
|
+
# Pulls the differentiable loss out of a function's return value.
#
# Accepts either an MLX::Core::Array (returned as is) or a Ruby Array whose
# first element is an MLX::Core::Array (that element is returned). Anything
# else raises ArgumentError.
def extract_loss(output)
  return output if output.is_a?(MLX::Core::Array)

  head = output.is_a?(::Array) ? output[0] : nil
  return head if head.is_a?(MLX::Core::Array)

  raise ArgumentError,
        "function must return an MLX::Core::Array or an Array whose first element is an MLX::Core::Array"
end
|
|
922
|
+
|
|
923
|
+
# Collects the arguments to differentiate and flattens them into one list
# of arrays.
#
# For each index in +argnums+ and each name in +argnames+ the matching
# argument is flattened with flatten_tree_spec, which appends to the shared
# +flat_inputs+ list (positional targets first, then keyword targets) and
# returns a spec describing the argument's structure.
#
# Returns [{ positional: [...], keyword: [...] }, flat_inputs]. Raises
# ArgumentError for a positional index beyond the supplied arguments or a
# keyword name that was not passed.
def build_target_selections(args, kwargs, argnums, argnames)
  flat_inputs = []

  positional = argnums.map do |index|
    if index >= args.length
      raise ArgumentError,
            "Can't compute gradient for positional argument #{index} when #{args.length} positional arguments were provided"
    end

    { index: index, spec: flatten_tree_spec(args[index], flat_inputs, true) }
  end

  keyword = argnames.map do |name|
    key = kwarg_key_for_name(kwargs, name)
    unless key
      raise ArgumentError,
            "Can't compute gradient for keyword argument '#{name}' because it was not provided"
    end

    { key: key, name: name, spec: flatten_tree_spec(kwargs[key], flat_inputs, true) }
  end

  [{ positional: positional, keyword: keyword }, flat_inputs]
end
|
|
949
|
+
|
|
950
|
+
# Builds the String used to key cached gradient transforms.
#
# The key encodes, for every positional target, its index and structure
# key, and for every keyword target, its name, the actual kwargs key and
# structure key, in the form "P[...]K[...]".
def grad_selection_cache_key(selections)
  positional_part = selections[:positional]
                    .map { |entry| "#{entry[:index]}:#{structure_cache_key(entry[:spec])}" }
                    .join(",")
  keyword_part = selections[:keyword]
                 .map { |entry| "#{entry[:name]}:#{entry[:key]}:#{structure_cache_key(entry[:spec])}" }
                 .join(",")

  "P[#{positional_part}]K[#{keyword_part}]"
end
|
|
959
|
+
|
|
960
|
+
# Normalizes the raw gradient output of a native transform by delegating to
# normalize_array_sequence with the context label "gradient".
def normalize_raw_grads(raw)
|
|
961
|
+
normalize_array_sequence(raw, "gradient")
|
|
962
|
+
end
|
|
963
|
+
|
|
964
|
+
# Re-inflates the flat gradient list into the shapes of the selected
# arguments.
#
# Gradients are consumed in selection order (positional targets, then
# keyword targets) using each target's spec; a leftover or shortfall raises
# RuntimeError.
#
# Return shape:
# - no argnames: the single positional gradient, or the list of them;
# - with argnames: [positional, keyword_hash], where positional is nil, the
#   single gradient, or the list, and keyword_hash maps each name to its
#   gradient.
def rebuild_grad_result(raw_grads, selections, argnames)
  flat = normalize_raw_grads(raw_grads)
  offset = 0

  by_position = selections[:positional].map do |entry|
    tree, offset = inflate_tree_from_arrays(entry[:spec], flat, offset)
    tree
  end

  by_keyword = selections[:keyword].each_with_object({}) do |entry, grads|
    tree, offset = inflate_tree_from_arrays(entry[:spec], flat, offset)
    grads[entry[:name]] = tree
  end

  raise RuntimeError, "internal gradient reconstruction mismatch" unless offset == flat.length

  unwrapped = by_position.length == 1 ? by_position[0] : by_position
  return unwrapped if argnames.empty?

  [by_position.empty? ? nil : unwrapped, by_keyword]
end
|
|
995
|
+
|
|
996
|
+
# Substitutes the flat variables back into copies of the original call
# arguments.
#
# +args+ and +kwargs+ are shallow-copied, then each selected positional and
# keyword target is replaced by the tree rebuilt from +flat_vars+ according
# to its spec. Every flat variable must be consumed, otherwise RuntimeError
# is raised. Returns [args_copy, kwargs_copy].
def apply_flat_vars_to_targets(args, kwargs, selections, flat_vars)
  new_args = args.dup
  new_kwargs = kwargs.dup
  offset = 0

  # One (container, slot, spec) triple per target, positional targets first.
  targets =
    selections[:positional].map { |entry| [new_args, entry[:index], entry[:spec]] } +
    selections[:keyword].map { |entry| [new_kwargs, entry[:key], entry[:spec]] }

  targets.each do |container, slot, spec|
    value, offset = inflate_tree_from_arrays(spec, flat_vars, offset)
    container[slot] = value
  end

  raise RuntimeError, "internal target reconstruction mismatch" unless offset == flat_vars.length

  [new_args, new_kwargs]
end
|
|
1016
|
+
|
|
1017
|
+
# Finds the key under which +name+ appears in +kwargs+.
#
# The Symbol form of the name is tried first, then the name exactly as
# given. Returns the matching key, or nil when neither is present.
def kwarg_key_for_name(kwargs, name)
  [name.to_sym, name].find { |candidate| kwargs.key?(candidate) }
end
|
|
1024
|
+
|
|
728
1025
|
def custom_jvp(fun, primals, tangents)
|
|
729
1026
|
primals_list = normalize_array_output(primals, "primals")
|
|
730
1027
|
tangents_list = normalize_array_output(tangents, "tangents")
|
|
@@ -768,44 +1065,6 @@ module MLX
|
|
|
768
1065
|
end
|
|
769
1066
|
end
|
|
770
1067
|
|
|
771
|
-
def extract_loss(output)
|
|
772
|
-
return output if output.is_a?(MLX::Core::Array)
|
|
773
|
-
|
|
774
|
-
if output.is_a?(::Array) && !output.empty? && output[0].is_a?(MLX::Core::Array)
|
|
775
|
-
return output[0]
|
|
776
|
-
end
|
|
777
|
-
|
|
778
|
-
raise ArgumentError,
|
|
779
|
-
"function must return an MLX::Core::Array or an Array whose first element is an MLX::Core::Array"
|
|
780
|
-
end
|
|
781
|
-
|
|
782
|
-
def build_target_selections(args, kwargs, argnums, argnames)
|
|
783
|
-
positional = []
|
|
784
|
-
keyword = []
|
|
785
|
-
flat_inputs = []
|
|
786
|
-
|
|
787
|
-
argnums.each do |index|
|
|
788
|
-
if index >= args.length
|
|
789
|
-
raise ArgumentError,
|
|
790
|
-
"Can't compute gradient for positional argument #{index} when #{args.length} positional arguments were provided"
|
|
791
|
-
end
|
|
792
|
-
spec = flatten_tree_spec(args[index], flat_inputs, true)
|
|
793
|
-
positional << { index: index, spec: spec }
|
|
794
|
-
end
|
|
795
|
-
|
|
796
|
-
argnames.each do |name|
|
|
797
|
-
key = kwarg_key_for_name(kwargs, name)
|
|
798
|
-
unless key
|
|
799
|
-
raise ArgumentError,
|
|
800
|
-
"Can't compute gradient for keyword argument '#{name}' because it was not provided"
|
|
801
|
-
end
|
|
802
|
-
spec = flatten_tree_spec(kwargs[key], flat_inputs, true)
|
|
803
|
-
keyword << { key: key, name: name, spec: spec }
|
|
804
|
-
end
|
|
805
|
-
|
|
806
|
-
[{ positional: positional, keyword: keyword }, flat_inputs]
|
|
807
|
-
end
|
|
808
|
-
|
|
809
1068
|
def flatten_tree_spec(value, arrays, strict_arrays)
|
|
810
1069
|
if value.is_a?(MLX::Core::Array)
|
|
811
1070
|
arrays << value
|
|
@@ -873,10 +1132,6 @@ module MLX
|
|
|
873
1132
|
end
|
|
874
1133
|
end
|
|
875
1134
|
|
|
876
|
-
def normalize_raw_grads(raw)
|
|
877
|
-
normalize_array_sequence(raw, "gradient")
|
|
878
|
-
end
|
|
879
|
-
|
|
880
1135
|
def normalize_array_sequence(raw, context)
|
|
881
1136
|
return [raw] if raw.is_a?(MLX::Core::Array)
|
|
882
1137
|
|
|
@@ -896,66 +1151,6 @@ module MLX
|
|
|
896
1151
|
end
|
|
897
1152
|
end
|
|
898
1153
|
|
|
899
|
-
def rebuild_grad_result(raw_grads, selections, argnames)
|
|
900
|
-
grad_arrays = normalize_raw_grads(raw_grads)
|
|
901
|
-
cursor = 0
|
|
902
|
-
|
|
903
|
-
positional_grads = selections[:positional].map do |entry|
|
|
904
|
-
value, cursor = inflate_tree_from_arrays(entry[:spec], grad_arrays, cursor)
|
|
905
|
-
value
|
|
906
|
-
end
|
|
907
|
-
keyword_grads = {}
|
|
908
|
-
selections[:keyword].each do |entry|
|
|
909
|
-
value, cursor = inflate_tree_from_arrays(entry[:spec], grad_arrays, cursor)
|
|
910
|
-
keyword_grads[entry[:name]] = value
|
|
911
|
-
end
|
|
912
|
-
unless cursor == grad_arrays.length
|
|
913
|
-
raise RuntimeError, "internal gradient reconstruction mismatch"
|
|
914
|
-
end
|
|
915
|
-
|
|
916
|
-
if argnames.empty?
|
|
917
|
-
return positional_grads[0] if positional_grads.length == 1
|
|
918
|
-
return positional_grads
|
|
919
|
-
end
|
|
920
|
-
|
|
921
|
-
positional_out = if positional_grads.empty?
|
|
922
|
-
nil
|
|
923
|
-
elsif positional_grads.length == 1
|
|
924
|
-
positional_grads[0]
|
|
925
|
-
else
|
|
926
|
-
positional_grads
|
|
927
|
-
end
|
|
928
|
-
[positional_out, keyword_grads]
|
|
929
|
-
end
|
|
930
|
-
|
|
931
|
-
def apply_flat_vars_to_targets(args, kwargs, selections, flat_vars)
|
|
932
|
-
rebuilt_args = args.dup
|
|
933
|
-
rebuilt_kwargs = kwargs.dup
|
|
934
|
-
cursor = 0
|
|
935
|
-
|
|
936
|
-
selections[:positional].each do |entry|
|
|
937
|
-
value, cursor = inflate_tree_from_arrays(entry[:spec], flat_vars, cursor)
|
|
938
|
-
rebuilt_args[entry[:index]] = value
|
|
939
|
-
end
|
|
940
|
-
|
|
941
|
-
selections[:keyword].each do |entry|
|
|
942
|
-
value, cursor = inflate_tree_from_arrays(entry[:spec], flat_vars, cursor)
|
|
943
|
-
rebuilt_kwargs[entry[:key]] = value
|
|
944
|
-
end
|
|
945
|
-
|
|
946
|
-
unless cursor == flat_vars.length
|
|
947
|
-
raise RuntimeError, "internal target reconstruction mismatch"
|
|
948
|
-
end
|
|
949
|
-
[rebuilt_args, rebuilt_kwargs]
|
|
950
|
-
end
|
|
951
|
-
|
|
952
|
-
def kwarg_key_for_name(kwargs, name)
|
|
953
|
-
symbol = name.to_sym
|
|
954
|
-
return symbol if kwargs.key?(symbol)
|
|
955
|
-
return name if kwargs.key?(name)
|
|
956
|
-
|
|
957
|
-
nil
|
|
958
|
-
end
|
|
959
1154
|
end
|
|
960
1155
|
|
|
961
1156
|
class Device
|
|
@@ -1034,8 +1229,8 @@ module MLX
|
|
|
1034
1229
|
MLX::Core.cos(self)
|
|
1035
1230
|
end
|
|
1036
1231
|
|
|
1037
|
-
def mean(axis = nil)
|
|
1038
|
-
MLX::Core.mean(self, axis)
|
|
1232
|
+
# Instance-level mean: forwards to MLX::Core.mean with self as the array.
# keepdims may be passed positionally (keepdims_positional) or by keyword;
# both are passed through unchanged.
def mean(axis = nil, keepdims_positional = nil, keepdims: nil)
|
|
1233
|
+
MLX::Core.mean(self, axis, keepdims_positional, keepdims: keepdims)
|
|
1039
1234
|
end
|
|
1040
1235
|
|
|
1041
1236
|
def sum(axis = nil)
|
|
@@ -1307,6 +1502,10 @@ module MLX
|
|
|
1307
1502
|
MLX::Core.negative(self)
|
|
1308
1503
|
end
|
|
1309
1504
|
|
|
1505
|
+
# Unary minus operator (-array); delegates to __neg__.
def -@
|
|
1506
|
+
__neg__
|
|
1507
|
+
end
|
|
1508
|
+
|
|
1310
1509
|
# Power hook: returns MLX::Core.power(self, other).
def __pow__(other)
|
|
1311
1510
|
MLX::Core.power(self, other)
|
|
1312
1511
|
end
|
|
@@ -1375,18 +1574,34 @@ module MLX
|
|
|
1375
1574
|
MLX::Core.less(self, other)
|
|
1376
1575
|
end
|
|
1377
1576
|
|
|
1577
|
+
# Less-than operator; delegates to __lt__.
def <(other)
|
|
1578
|
+
__lt__(other)
|
|
1579
|
+
end
|
|
1580
|
+
|
|
1378
1581
|
# Less-than-or-equal hook: returns MLX::Core.less_equal(self, other).
def __le__(other)
|
|
1379
1582
|
MLX::Core.less_equal(self, other)
|
|
1380
1583
|
end
|
|
1381
1584
|
|
|
1585
|
+
# Less-than-or-equal operator; delegates to __le__.
def <=(other)
|
|
1586
|
+
__le__(other)
|
|
1587
|
+
end
|
|
1588
|
+
|
|
1382
1589
|
# Greater-than hook: returns MLX::Core.greater(self, other).
def __gt__(other)
|
|
1383
1590
|
MLX::Core.greater(self, other)
|
|
1384
1591
|
end
|
|
1385
1592
|
|
|
1593
|
+
# Greater-than operator; delegates to __gt__.
def >(other)
|
|
1594
|
+
__gt__(other)
|
|
1595
|
+
end
|
|
1596
|
+
|
|
1386
1597
|
# Greater-than-or-equal hook: returns MLX::Core.greater_equal(self, other).
def __ge__(other)
|
|
1387
1598
|
MLX::Core.greater_equal(self, other)
|
|
1388
1599
|
end
|
|
1389
1600
|
|
|
1601
|
+
# Greater-than-or-equal operator; delegates to __ge__.
def >=(other)
|
|
1602
|
+
__ge__(other)
|
|
1603
|
+
end
|
|
1604
|
+
|
|
1390
1605
|
# In-place-add hook; delegates to __add__ and returns its result.
def __iadd__(other)
|
|
1391
1606
|
__add__(other)
|
|
1392
1607
|
end
|
|
@@ -1439,6 +1654,16 @@ module MLX
|
|
|
1439
1654
|
MLX::Core.floor_divide(other, self)
|
|
1440
1655
|
end
|
|
1441
1656
|
|
|
1657
|
+
# Implements Ruby's coerce protocol, returning [converted_other, self].
#
# Another MLX array is passed through untouched; a Ruby Numeric is wrapped
# in a new array created with this array's dtype; anything else raises
# TypeError.
#
# NOTE(review): wrapping the Numeric with the receiver's dtype presumably
# narrows a Float operand when the receiver has an integer dtype — confirm
# this is the intended promotion behaviour.
def coerce(other)
  case other
  when MLX::Core::Array
    [other, self]
  when ::Numeric
    [MLX::Core.array(other, dtype), self]
  else
    raise TypeError, "#{other.class} can't be coerced into MLX::Core::Array"
  end
end
|
|
1666
|
+
|
|
1442
1667
|
# Named form of index access; forwards to self[index].
def __getitem__(index)
|
|
1443
1668
|
self[index]
|
|
1444
1669
|
end
|
|
@@ -8,13 +8,14 @@ require "shellwords"
|
|
|
8
8
|
|
|
9
9
|
module MLX
|
|
10
10
|
module DistributedUtils
|
|
11
|
-
SSHInfo =
|
|
11
|
+
# Value object with two fields, can_ssh and has_sudo, built with Data.define.
# to_bool reports only the can_ssh field.
SSHInfo = Data.define(:can_ssh, :has_sudo) do
|
|
12
12
|
def to_bool
|
|
13
13
|
can_ssh
|
|
14
14
|
end
|
|
15
15
|
end
|
|
16
|
-
ThunderboltPort =
|
|
17
|
-
ThunderboltHost =
|
|
16
|
+
ThunderboltPort = Data.define(:iface, :uuid, :connected_to)
|
|
17
|
+
ThunderboltHost = Data.define(:name, :ports)
|
|
18
|
+
CommandResult = Data.define(:stdout, :stderr, :status)
|
|
18
19
|
|
|
19
20
|
class IPConfigurator
|
|
20
21
|
attr_reader :ips, :hosts, :tb_hosts
|
|
@@ -509,6 +510,8 @@ module MLX
|
|
|
509
510
|
end
|
|
510
511
|
|
|
511
512
|
def config_main(argv = ARGV, runner: nil)
|
|
513
|
+
Process.warmup if Process.respond_to?(:warmup)
|
|
514
|
+
|
|
512
515
|
opts = {
|
|
513
516
|
verbose: false,
|
|
514
517
|
hosts: "127.0.0.1",
|
|
@@ -577,7 +580,7 @@ module MLX
|
|
|
577
580
|
return runner.call(cmd) unless runner.nil?
|
|
578
581
|
|
|
579
582
|
stdout, stderr, status = Open3.capture3(*cmd)
|
|
580
|
-
|
|
583
|
+
CommandResult.new(stdout: stdout, stderr: stderr, status: status)
|
|
581
584
|
end
|
|
582
585
|
|
|
583
586
|
def stdout_for(result)
|