mlx 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlx might be problematic.
- checksums.yaml +7 -0
- data/ext/mlx/CMakeLists.txt +7 -0
- data/ext/mlx/Makefile +273 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/mkmf.log +44 -0
- data/ext/mlx/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
- data/ext/mlx/native.cpp +8027 -0
- data/ext/mlx/native.o +0 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version +1 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/.clang-format +87 -0
- data/mlx/.git +1 -0
- data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
- data/mlx/.github/actions/build-docs/action.yml +38 -0
- data/mlx/.github/actions/build-linux/action.yml +38 -0
- data/mlx/.github/actions/build-linux-release/action.yml +42 -0
- data/mlx/.github/actions/build-macos/action.yml +80 -0
- data/mlx/.github/actions/build-macos-release/action.yml +36 -0
- data/mlx/.github/actions/build-windows/action.yml +26 -0
- data/mlx/.github/actions/setup-linux/action.yml +93 -0
- data/mlx/.github/actions/setup-macos/action.yml +24 -0
- data/mlx/.github/actions/setup-windows/action.yml +42 -0
- data/mlx/.github/actions/test-linux/action.yml +69 -0
- data/mlx/.github/actions/test-windows/action.yml +20 -0
- data/mlx/.github/dependabot.yml +6 -0
- data/mlx/.github/pull_request_template.md +12 -0
- data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
- data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
- data/mlx/.github/workflows/build_and_test.yml +152 -0
- data/mlx/.github/workflows/documentation.yml +28 -0
- data/mlx/.github/workflows/nightly.yml +104 -0
- data/mlx/.github/workflows/release.yml +256 -0
- data/mlx/.gitignore +81 -0
- data/mlx/.pre-commit-config.yaml +27 -0
- data/mlx/ACKNOWLEDGMENTS.md +268 -0
- data/mlx/CITATION.cff +24 -0
- data/mlx/CMakeLists.txt +437 -0
- data/mlx/CODE_OF_CONDUCT.md +132 -0
- data/mlx/CONTRIBUTING.md +38 -0
- data/mlx/LICENSE +21 -0
- data/mlx/MANIFEST.in +6 -0
- data/mlx/README.md +121 -0
- data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- data/mlx/benchmarks/cpp/autograd.cpp +39 -0
- data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- data/mlx/benchmarks/cpp/time_utils.h +39 -0
- data/mlx/benchmarks/numpy/single_ops.py +39 -0
- data/mlx/benchmarks/numpy/time_utils.py +20 -0
- data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
- data/mlx/benchmarks/python/comparative/README.md +15 -0
- data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- data/mlx/benchmarks/python/comparative/compare.py +284 -0
- data/mlx/benchmarks/python/compile_bench.py +107 -0
- data/mlx/benchmarks/python/conv1d_bench.py +123 -0
- data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- data/mlx/benchmarks/python/conv_bench.py +135 -0
- data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- data/mlx/benchmarks/python/distributed_bench.py +66 -0
- data/mlx/benchmarks/python/einsum_bench.py +84 -0
- data/mlx/benchmarks/python/fft_bench.py +118 -0
- data/mlx/benchmarks/python/gather_bench.py +52 -0
- data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- data/mlx/benchmarks/python/hadamard_bench.py +70 -0
- data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
- data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
- data/mlx/benchmarks/python/masked_scatter.py +212 -0
- data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
- data/mlx/benchmarks/python/rope_bench.py +35 -0
- data/mlx/benchmarks/python/scatter_bench.py +96 -0
- data/mlx/benchmarks/python/sdpa_bench.py +223 -0
- data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
- data/mlx/benchmarks/python/single_ops.py +132 -0
- data/mlx/benchmarks/python/synchronize_bench.py +55 -0
- data/mlx/benchmarks/python/time_utils.py +38 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/docs/.clang-format +2 -0
- data/mlx/docs/.gitignore +3 -0
- data/mlx/docs/.nojekyll +0 -0
- data/mlx/docs/Doxyfile +51 -0
- data/mlx/docs/Makefile +18 -0
- data/mlx/docs/README.md +54 -0
- data/mlx/docs/index.html +1 -0
- data/mlx/docs/requirements.txt +5 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
- data/mlx/docs/src/_static/mlx_logo.png +0 -0
- data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
- data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
- data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
- data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
- data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
- data/mlx/docs/src/_templates/module-base-class.rst +33 -0
- data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
- data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
- data/mlx/docs/src/conf.py +99 -0
- data/mlx/docs/src/cpp/ops.rst +7 -0
- data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
- data/mlx/docs/src/dev/extensions.rst +811 -0
- data/mlx/docs/src/dev/metal_debugger.rst +68 -0
- data/mlx/docs/src/dev/metal_logging.rst +40 -0
- data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
- data/mlx/docs/src/examples/data_parallelism.rst +91 -0
- data/mlx/docs/src/examples/linear_regression.rst +77 -0
- data/mlx/docs/src/examples/llama-inference.rst +382 -0
- data/mlx/docs/src/examples/mlp.rst +134 -0
- data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
- data/mlx/docs/src/index.rst +96 -0
- data/mlx/docs/src/install.rst +340 -0
- data/mlx/docs/src/python/array.rst +65 -0
- data/mlx/docs/src/python/cuda.rst +9 -0
- data/mlx/docs/src/python/data_types.rst +78 -0
- data/mlx/docs/src/python/devices_and_streams.rst +21 -0
- data/mlx/docs/src/python/distributed.rst +22 -0
- data/mlx/docs/src/python/export.rst +14 -0
- data/mlx/docs/src/python/fast.rst +16 -0
- data/mlx/docs/src/python/fft.rst +24 -0
- data/mlx/docs/src/python/linalg.rst +27 -0
- data/mlx/docs/src/python/memory_management.rst +16 -0
- data/mlx/docs/src/python/metal.rst +12 -0
- data/mlx/docs/src/python/nn/distributed.rst +30 -0
- data/mlx/docs/src/python/nn/functions.rst +40 -0
- data/mlx/docs/src/python/nn/init.rst +45 -0
- data/mlx/docs/src/python/nn/layers.rst +74 -0
- data/mlx/docs/src/python/nn/losses.rst +25 -0
- data/mlx/docs/src/python/nn/module.rst +38 -0
- data/mlx/docs/src/python/nn.rst +186 -0
- data/mlx/docs/src/python/ops.rst +184 -0
- data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
- data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
- data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
- data/mlx/docs/src/python/optimizers.rst +78 -0
- data/mlx/docs/src/python/random.rst +48 -0
- data/mlx/docs/src/python/transforms.rst +22 -0
- data/mlx/docs/src/python/tree_utils.rst +23 -0
- data/mlx/docs/src/usage/compile.rst +516 -0
- data/mlx/docs/src/usage/distributed.rst +572 -0
- data/mlx/docs/src/usage/export.rst +288 -0
- data/mlx/docs/src/usage/function_transforms.rst +191 -0
- data/mlx/docs/src/usage/indexing.rst +194 -0
- data/mlx/docs/src/usage/launching_distributed.rst +234 -0
- data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
- data/mlx/docs/src/usage/numpy.rst +124 -0
- data/mlx/docs/src/usage/quick_start.rst +67 -0
- data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
- data/mlx/docs/src/usage/unified_memory.rst +78 -0
- data/mlx/docs/src/usage/using_streams.rst +18 -0
- data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
- data/mlx/examples/cmake_project/README.md +26 -0
- data/mlx/examples/cmake_project/example.cpp +14 -0
- data/mlx/examples/cpp/CMakeLists.txt +12 -0
- data/mlx/examples/cpp/distributed.cpp +22 -0
- data/mlx/examples/cpp/linear_regression.cpp +54 -0
- data/mlx/examples/cpp/logistic_regression.cpp +54 -0
- data/mlx/examples/cpp/metal_capture.cpp +31 -0
- data/mlx/examples/cpp/timer.h +20 -0
- data/mlx/examples/cpp/tutorial.cpp +99 -0
- data/mlx/examples/export/CMakeLists.txt +22 -0
- data/mlx/examples/export/README.md +49 -0
- data/mlx/examples/export/eval_mlp.cpp +25 -0
- data/mlx/examples/export/eval_mlp.py +52 -0
- data/mlx/examples/export/train_mlp.cpp +35 -0
- data/mlx/examples/export/train_mlp.py +76 -0
- data/mlx/examples/extensions/CMakeLists.txt +78 -0
- data/mlx/examples/extensions/README.md +24 -0
- data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
- data/mlx/examples/extensions/axpby/axpby.h +90 -0
- data/mlx/examples/extensions/axpby/axpby.metal +47 -0
- data/mlx/examples/extensions/bindings.cpp +39 -0
- data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
- data/mlx/examples/extensions/pyproject.toml +8 -0
- data/mlx/examples/extensions/requirements.txt +4 -0
- data/mlx/examples/extensions/setup.py +18 -0
- data/mlx/examples/extensions/test.py +12 -0
- data/mlx/examples/python/linear_regression.py +46 -0
- data/mlx/examples/python/logistic_regression.py +49 -0
- data/mlx/examples/python/qqmm.py +117 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- data/mlx/pyproject.toml +7 -0
- data/mlx/python/mlx/__main__.py +27 -0
- data/mlx/python/mlx/_distributed_utils/common.py +135 -0
- data/mlx/python/mlx/_distributed_utils/config.py +631 -0
- data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
- data/mlx/python/mlx/_reprlib_fix.py +16 -0
- data/mlx/python/mlx/_stub_patterns.txt +36 -0
- data/mlx/python/mlx/extension.py +88 -0
- data/mlx/python/mlx/nn/__init__.py +5 -0
- data/mlx/python/mlx/nn/init.py +441 -0
- data/mlx/python/mlx/nn/layers/__init__.py +105 -0
- data/mlx/python/mlx/nn/layers/activations.py +661 -0
- data/mlx/python/mlx/nn/layers/base.py +675 -0
- data/mlx/python/mlx/nn/layers/containers.py +24 -0
- data/mlx/python/mlx/nn/layers/convolution.py +232 -0
- data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
- data/mlx/python/mlx/nn/layers/distributed.py +601 -0
- data/mlx/python/mlx/nn/layers/dropout.py +137 -0
- data/mlx/python/mlx/nn/layers/embedding.py +53 -0
- data/mlx/python/mlx/nn/layers/linear.py +180 -0
- data/mlx/python/mlx/nn/layers/normalization.py +363 -0
- data/mlx/python/mlx/nn/layers/pooling.py +398 -0
- data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
- data/mlx/python/mlx/nn/layers/quantized.py +426 -0
- data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
- data/mlx/python/mlx/nn/layers/transformer.py +354 -0
- data/mlx/python/mlx/nn/layers/upsample.py +277 -0
- data/mlx/python/mlx/nn/losses.py +610 -0
- data/mlx/python/mlx/nn/utils.py +165 -0
- data/mlx/python/mlx/optimizers/__init__.py +4 -0
- data/mlx/python/mlx/optimizers/optimizers.py +976 -0
- data/mlx/python/mlx/optimizers/schedulers.py +158 -0
- data/mlx/python/mlx/py.typed +1 -0
- data/mlx/python/mlx/utils.py +325 -0
- data/mlx/python/src/CMakeLists.txt +96 -0
- data/mlx/python/src/array.cpp +1525 -0
- data/mlx/python/src/buffer.h +124 -0
- data/mlx/python/src/constants.cpp +15 -0
- data/mlx/python/src/convert.cpp +504 -0
- data/mlx/python/src/convert.h +50 -0
- data/mlx/python/src/cuda.cpp +19 -0
- data/mlx/python/src/device.cpp +98 -0
- data/mlx/python/src/distributed.cpp +352 -0
- data/mlx/python/src/export.cpp +356 -0
- data/mlx/python/src/fast.cpp +627 -0
- data/mlx/python/src/fft.cpp +514 -0
- data/mlx/python/src/indexing.cpp +1016 -0
- data/mlx/python/src/indexing.h +41 -0
- data/mlx/python/src/linalg.cpp +663 -0
- data/mlx/python/src/load.cpp +531 -0
- data/mlx/python/src/load.h +51 -0
- data/mlx/python/src/memory.cpp +125 -0
- data/mlx/python/src/metal.cpp +98 -0
- data/mlx/python/src/mlx.cpp +51 -0
- data/mlx/python/src/mlx_func.cpp +116 -0
- data/mlx/python/src/mlx_func.h +31 -0
- data/mlx/python/src/ops.cpp +5545 -0
- data/mlx/python/src/random.cpp +516 -0
- data/mlx/python/src/small_vector.h +76 -0
- data/mlx/python/src/stream.cpp +147 -0
- data/mlx/python/src/transforms.cpp +1542 -0
- data/mlx/python/src/trees.cpp +311 -0
- data/mlx/python/src/trees.h +62 -0
- data/mlx/python/src/utils.cpp +98 -0
- data/mlx/python/src/utils.h +78 -0
- data/mlx/python/tests/__main__.py +5 -0
- data/mlx/python/tests/cuda_skip.py +62 -0
- data/mlx/python/tests/mlx_distributed_tests.py +314 -0
- data/mlx/python/tests/mlx_tests.py +116 -0
- data/mlx/python/tests/mpi_test_distributed.py +142 -0
- data/mlx/python/tests/nccl_test_distributed.py +52 -0
- data/mlx/python/tests/ring_test_distributed.py +131 -0
- data/mlx/python/tests/test_array.py +2139 -0
- data/mlx/python/tests/test_autograd.py +880 -0
- data/mlx/python/tests/test_bf16.py +196 -0
- data/mlx/python/tests/test_blas.py +1429 -0
- data/mlx/python/tests/test_compile.py +1277 -0
- data/mlx/python/tests/test_constants.py +41 -0
- data/mlx/python/tests/test_conv.py +1198 -0
- data/mlx/python/tests/test_conv_transpose.py +810 -0
- data/mlx/python/tests/test_device.py +150 -0
- data/mlx/python/tests/test_double.py +306 -0
- data/mlx/python/tests/test_einsum.py +363 -0
- data/mlx/python/tests/test_eval.py +200 -0
- data/mlx/python/tests/test_export_import.py +614 -0
- data/mlx/python/tests/test_fast.py +923 -0
- data/mlx/python/tests/test_fast_sdpa.py +647 -0
- data/mlx/python/tests/test_fft.py +323 -0
- data/mlx/python/tests/test_graph.py +37 -0
- data/mlx/python/tests/test_init.py +139 -0
- data/mlx/python/tests/test_linalg.py +621 -0
- data/mlx/python/tests/test_load.py +447 -0
- data/mlx/python/tests/test_losses.py +427 -0
- data/mlx/python/tests/test_memory.py +77 -0
- data/mlx/python/tests/test_nn.py +1986 -0
- data/mlx/python/tests/test_ops.py +3261 -0
- data/mlx/python/tests/test_optimizers.py +584 -0
- data/mlx/python/tests/test_quantized.py +1160 -0
- data/mlx/python/tests/test_random.py +392 -0
- data/mlx/python/tests/test_reduce.py +223 -0
- data/mlx/python/tests/test_tree.py +96 -0
- data/mlx/python/tests/test_upsample.py +100 -0
- data/mlx/python/tests/test_vmap.py +860 -0
- data/mlx/setup.py +315 -0
- data/mlx/tests/CMakeLists.txt +44 -0
- data/mlx/tests/allocator_tests.cpp +41 -0
- data/mlx/tests/arg_reduce_tests.cpp +204 -0
- data/mlx/tests/array_tests.cpp +663 -0
- data/mlx/tests/autograd_tests.cpp +1399 -0
- data/mlx/tests/blas_tests.cpp +110 -0
- data/mlx/tests/compile_tests.cpp +818 -0
- data/mlx/tests/creations_tests.cpp +239 -0
- data/mlx/tests/custom_vjp_tests.cpp +55 -0
- data/mlx/tests/device_tests.cpp +35 -0
- data/mlx/tests/einsum_tests.cpp +85 -0
- data/mlx/tests/eval_tests.cpp +93 -0
- data/mlx/tests/export_import_tests.cpp +164 -0
- data/mlx/tests/fft_tests.cpp +366 -0
- data/mlx/tests/gpu_tests.cpp +523 -0
- data/mlx/tests/linalg_tests.cpp +639 -0
- data/mlx/tests/load_tests.cpp +270 -0
- data/mlx/tests/ops_tests.cpp +4159 -0
- data/mlx/tests/random_tests.cpp +716 -0
- data/mlx/tests/scheduler_tests.cpp +121 -0
- data/mlx/tests/tests.cpp +26 -0
- data/mlx/tests/utils_tests.cpp +67 -0
- data/mlx/tests/vmap_tests.cpp +547 -0
- metadata +958 -0

data/mlx/docs/src/usage/launching_distributed.rst
@@ -0,0 +1,234 @@

:orphan:

.. _usage_launch_distributed:

Launching Distributed Programs
==============================

.. currentmodule:: mlx.core.distributed

The MLX python package provides two utilities to help you configure your Macs
for distributed computation and launch distributed programs on multiple nodes
or with many processes on a single node. These utilities are aptly named

- ``mlx.launch``
- ``mlx.distributed_config``

See the :doc:`distributed docs <distributed>` for an introduction and
getting-started guides to the various backends.

``mlx.distributed_config``
--------------------------

Unless you are launching distributed jobs locally for development or in a
multi-GPU CUDA environment, you have several Macs that you need to configure
for distributed communication with MLX.

``mlx.distributed_config`` aims to automate the process of configuring the
network interfaces (especially for communication over thunderbolt) and
creating the hostfile to be used with ``mlx.launch``.

We will analyze three cases of using ``mlx.distributed_config``:

1. RDMA over thunderbolt using JACCL
2. TCP/IP over thunderbolt using the ring backend
3. TCP/IP over ethernet using the ring backend

JACCL
^^^^^

After following :ref:`the steps to enable RDMA <jaccl_section>` you can run
the following command to configure the nodes and create the hostfile.

.. code-block::

    mlx.distributed_config --verbose --backend jaccl \
        --hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \
        --auto-setup --output m3-ultra-jaccl.json

Let's walk through the steps that the script takes to configure the nodes:

1. ssh to all nodes to verify that they are reachable
2. Extract the thunderbolt connectivity, namely run commands on each node to
   calculate which node is connected to which other node
3. Verify that we have a valid fully connected mesh
4. Check that RDMA is enabled
5. Extract the ethernet IP from interface en0
6. Disable the thunderbolt bridge and set up peer-to-peer networks for each
   thunderbolt cable
7. Write the hostfile

Knowing the above steps allows you to configure the nodes manually but also to
debug any configuration issue. For instance, changing the Ethernet IP to a
different interface directly in the config is possible (as long as it is
reachable from all nodes).

The ``--auto-setup`` argument requires password-less sudo on each node. If it
isn't available, the configuration script will print the commands to be run on
each node.

Ring over thunderbolt
^^^^^^^^^^^^^^^^^^^^^

Setting up a ring backend over thunderbolt only requires changing
``--backend`` from ``jaccl`` to ``ring``.

The steps are very similar, the main difference being that instead of
verifying that the nodes are fully connected, the script attempts to identify
a ring topology (or multiple rings).

Ring over Ethernet
^^^^^^^^^^^^^^^^^^

Configuring the ring backend over ethernet doesn't require setting up network
interfaces, so the script simply extracts the ``en0`` IP from each node and
writes the hostfile.

Debugging cable connections
^^^^^^^^^^^^^^^^^^^^^^^^^^^

``mlx.distributed_config`` can help you debug the connectivity of your nodes
over thunderbolt by exporting a graph of the connections.

Running

.. code-block::

    mlx.distributed_config --verbose \
        --hosts host1,host2,host3,host4 \
        --over thunderbolt --dot

will export a `GraphViz <https://graphviz.org>`_ representation of the
connections between the nodes, which makes it easy to figure out which cable
is not connected correctly.

See :ref:`the JACCL section <jaccl_section>` for an example.


``mlx.launch``
--------------

The minimal usage example of ``mlx.launch`` is simply

.. code:: shell

    mlx.launch --hosts ip1,ip2 my_script.py

or, for testing on localhost,

.. code:: shell

    mlx.launch -n 2 my_script.py
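
As a point of reference, here is a minimal sketch of what such a
``my_script.py`` could look like. It uses the ``mlx.core.distributed`` API
covered in the :doc:`distributed docs <distributed>`; the script itself is
illustrative and not part of the package:

.. code-block:: python

    # my_script.py -- a minimal distributed program (illustrative sketch)
    import mlx.core as mx

    # Initialize the distributed group; rank and size identify this process.
    world = mx.distributed.init()

    # Each process contributes its own value ...
    x = mx.array([world.rank() + 1.0])

    # ... and all_sum reduces the values across all processes.
    total = mx.distributed.all_sum(x)
    mx.eval(total)
    print(f"rank {world.rank()} of {world.size()}: total = {total}")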

The ``mlx.launch`` command connects to the provided hosts and launches the
input script on each one. It monitors each of the launched processes and
terminates the rest if one of them fails unexpectedly or if ``mlx.launch`` is
terminated. It also takes care of forwarding the output of each remote process
to stdout and stderr respectively.

Importantly, it also broadcasts stdin to each process, which enables
interactive programs to work in distributed mode, as well as debugging with
the interactive debugger.

Providing Hosts
^^^^^^^^^^^^^^^

Hosts can be provided as command line arguments, like above, but the way to
fully define a list of hosts is via a JSON hostfile. The hostfile has a very
simple schema: it is a list of objects that define each host via a hostname to
ssh to and a list of IPs to utilize for the communication.

.. code:: json

    [
        {"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
        {"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
    ]

You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
with IPs corresponding to the ``en0`` interface.
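
If you prefer to assemble a hostfile yourself, a few lines of Python following
the schema above are enough; the hostnames and IPs below are placeholders:

.. code-block:: python

    # Write a hostfile matching the schema above (hypothetical hosts/IPs).
    import json

    hosts = [
        {"ssh": f"hostname{i}", "ips": [f"123.123.1.{i}", f"123.123.2.{i}"]}
        for i in (1, 2)
    ]
    with open("hosts.json", "w") as f:
        json.dump(hosts, f, indent=2)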

Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^

In order to be able to launch the script on each host, we need to be able to
connect via ssh. Moreover, the input script and python binary need to be on
each host at the same path. A good checklist for debugging errors is the
following:

* ``ssh hostname`` works without asking for a password or host confirmation
* the python binary is available on all hosts at the same path. You can use
  ``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path

If you are launching from a node with a completely different setup than the
nodes that the program will run on, you can specify ``--no-verify-script`` so
that ``mlx.launch`` does not attempt to verify that the executable and script
exist locally before launching the distributed job.

.. _ring_specifics:

Ring Specifics
^^^^^^^^^^^^^^

The :ref:`ring <ring_section>` backend, which is also the default backend, can
be explicitly selected with the argument ``--backend ring``. The ring backend
has some requirements and arguments that are different from other backends:

* The argument ``--hosts`` only accepts IPs, not hostnames. If we need to ssh
  to a hostname that does not correspond to the IP we want to bind to, we have
  to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
  Specifically, rank 0 with the first IP will use this port and each
  subsequent IP or rank will add 1 to this port (see the sketch after this
  list).
* ``--connections-per-ip`` allows us to increase the number of connections
  between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
  ``mpirun``.
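
The port assignment rule above is simple enough to state in a couple of lines
of Python; this is an illustration of the rule as described, not code from the
package:

.. code-block:: python

    # Port for the i-th (rank, IP) pair, counting flat across ranks and IPs:
    # the first pair gets --starting-port and each subsequent one adds 1.
    def ring_port(starting_port: int, flat_index: int) -> int:
        return starting_port + flat_index

    # e.g. with --starting-port 5000 and two IPs per rank:
    # rank 0 binds 5000 and 5001, rank 1 binds 5002 and 5003, ...
    assert [ring_port(5000, i) for i in range(4)] == [5000, 5001, 5002, 5003]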

.. _jaccl_specifics:

JACCL Specifics
^^^^^^^^^^^^^^^

The :ref:`JACCL <jaccl_section>` backend can be selected with the argument
``--backend jaccl``. A hostfile is necessary to launch with this backend
because it needs to contain the RDMA devices connecting each node to every
other node.

NCCL Specifics
^^^^^^^^^^^^^^

The :ref:`NCCL <nccl_section>` backend is the default backend for CUDA
environments. When launching from a Mac to a Linux machine with CUDA, the
backend should be selected using ``--backend nccl``.

The ``--repeat-hosts, -n`` argument should be used to launch multi-node,
multi-GPU jobs. For instance,

.. code-block::

    mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh

will attempt to launch 16 processes, 8 on each node, all running
``my-job.sh``.

.. _mpi_specifics:

MPI Specifics
^^^^^^^^^^^^^

One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,

* the IPs in the hostfile are ignored
* the ssh connectivity requirement is stronger, as every node needs to be able
  to connect to every other node
* ``mpirun`` needs to be available on every node at the same path

Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For
instance, to choose a specific interface for the byte transfer layer of MPI,
we can call ``mlx.launch`` as follows:

.. code:: shell

    mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
@@ -0,0 +1,144 @@

.. _lazy eval:

Lazy Evaluation
===============

.. currentmodule:: mlx.core

Why Lazy Evaluation
-------------------

When you perform operations in MLX, no computation actually happens. Instead,
a compute graph is recorded. The actual computation only happens when an
:func:`eval` is performed.

MLX uses lazy evaluation because it has some nice features, some of which we
describe below.
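
As a minimal illustration, nothing below is computed until the explicit
:func:`eval` call:

.. code-block:: python

   import mlx.core as mx

   a = mx.array([1.0, 2.0, 3.0])
   b = (a * 2).sum()  # only records the graph, no computation yet
   mx.eval(b)         # the multiply and the sum actually run here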

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Lazy evaluation lets us record a compute graph without actually doing any
computation. This is useful for function transformations like :func:`grad` and
:func:`vmap`, as well as for graph optimizations.

Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to
integrate compilation for future performance enhancements.
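
For example, a gradient can be composed with :func:`vmap` to map it over a
batch of inputs, and the whole pipeline stays a lazily recorded graph until it
is evaluated. A small sketch:

.. code-block:: python

   import mlx.core as mx

   def f(x):
       return (x ** 2).sum()

   batched_grad = mx.vmap(mx.grad(f))  # gradient of f, mapped over rows
   xs = mx.ones((4, 3))
   mx.eval(batched_grad(xs))           # each row's gradient is 2 * x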

Only Compute What You Use
^^^^^^^^^^^^^^^^^^^^^^^^^

In MLX you do not need to worry as much about computing outputs that are never
used. For example:

.. code-block:: python

   def fun(x):
       a = fun1(x)
       b = expensive_fun(a)
       return a, b

   y, _ = fun(x)

Here, we never actually compute the output of ``expensive_fun``. Use this
pattern with care though, as the graph of ``expensive_fun`` is still built,
and that has some cost associated with it.

Similarly, lazy evaluation can be beneficial for saving memory while keeping
code simple. Say you have a very large model ``Model`` derived from
:obj:`mlx.nn.Module`. You can instantiate this model with ``model = Model()``.
Typically, this will initialize all of the weights as ``float32``, but the
initialization does not actually compute anything until you perform an
:func:`eval`. If you update the model with ``float16`` weights, your maximum
consumed memory will be half that required if eager computation were used
instead.

This pattern is simple to do in MLX thanks to lazy computation:

.. code-block:: python

   model = Model()  # no memory used yet
   model.load_weights("weights_fp16.safetensors")

When to Evaluate
----------------

A common question is when to use :func:`eval`. The trade-off is between
letting graphs get too large and not batching enough useful work.

For example:

.. code-block:: python

   for _ in range(100):
       a = a + b
       mx.eval(a)
       b = b * 2
       mx.eval(b)

This is a bad idea because there is some fixed overhead with each graph
evaluation. On the other hand, there is some slight overhead which grows with
the compute graph size, so extremely large graphs (while computationally
correct) can be costly.

Luckily, a wide range of compute graph sizes work pretty well with MLX:
anything from a few tens of operations to many thousands of operations per
evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration
in stochastic gradient descent). A natural and usually efficient place to use
:func:`eval` is at each iteration of this outer loop.

Here is a concrete example:

.. code-block:: python

   for batch in dataset:

       # Nothing has been evaluated yet
       loss, grad = value_and_grad_fn(model, batch)

       # Still nothing has been evaluated
       optimizer.update(model, grad)

       # Evaluate the loss and the new parameters which will
       # run the full gradient computation and optimizer update
       mx.eval(loss, model.parameters())

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving function) will also evaluate the array.
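
For instance, each of the following lines triggers an evaluation of the
pending graph:

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.arange(3) + 1
   print(a)              # printing evaluates the graph
   b = np.array(a)       # so does converting to a numpy.ndarray
   mx.save("a_file", a)  # and so does saving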

Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar
to a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())``, then this will
be a partial evaluation, computing only the forward pass.
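
Here is a self-contained sketch of such a partial evaluation (the loss
function is illustrative):

.. code-block:: python

   import mlx.core as mx

   def loss_fn(w):
       return (w * w).sum()

   w = mx.ones((3,))
   loss, grad = mx.value_and_grad(loss_fn)(w)
   print(loss.item())  # partial evaluation: only the forward graph runs
   mx.eval(grad)       # the backward graph is evaluated here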

Also, calling :func:`eval` on an array or set of arrays multiple times is
perfectly fine. This is effectively a no-op.

.. warning::

   Using scalar arrays for control flow will cause an evaluation.

Here is an example:

.. code-block:: python

   def fun(x):
       h, y = first_layer(x)
       if y > 0:  # An evaluation is done here!
           z = second_layer_a(h)
       else:
           z = second_layer_b(h)
       return z

Using arrays for control flow should be done with care. The above example
works and can even be used with gradient transformations. However, this can be
very inefficient if evaluations are done too frequently.

@@ -0,0 +1,124 @@

.. _numpy:

Conversion to NumPy and Other Frameworks
========================================

MLX arrays support conversion to and from other frameworks with either:

* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

Let's convert an array to NumPy and back:

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.arange(3)
   b = np.array(a)  # copy of a
   c = mx.array(b)  # copy of b

.. note::

   Since NumPy does not support ``bfloat16`` arrays, you will need to convert
   to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
   Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
   buffer format string does not match the dtype V item size 0.``

By default, NumPy copies data to a new array. This can be prevented by
creating an array view:

.. code-block:: python

   a = mx.arange(3)
   a_view = np.array(a, copy=False)
   print(a_view.flags.owndata)  # False
   a_view[0] = 1
   print(a[0].item())  # 1

.. note::

   NumPy arrays with type ``float64`` will by default be converted to MLX
   arrays with type ``float32``.

A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.

While this is quite powerful for avoiding array copies, it should be noted
that external changes to the memory of arrays cannot be reflected in
gradients.

Let's demonstrate this in an example:

.. code-block:: python

   def f(x):
       x_view = np.array(x, copy=False)
       x_view[:] *= x_view  # modify memory without telling mx
       return x.sum()

   x = mx.array([3.0])
   y, df = mx.value_and_grad(f)(x)
   print("f(x) = x² =", y.item())     # 9.0
   print("f'(x) = 2x !=", df.item())  # 1.0

The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, which is the gradient of the sum operation
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated. A similar issue arises during array conversion and
copying. For instance, a function defined as
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
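
By contrast, keeping the computation inside MLX yields the expected gradient.
A minimal sketch:

.. code-block:: python

   import mlx.core as mx

   def g(x):
       return (x ** 2).sum()  # the square stays in the MLX graph

   x = mx.array([3.0])
   y, dg = mx.value_and_grad(g)(x)
   print(y.item())   # 9.0
   print(dg.item())  # 6.0, the correct gradient 2x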

PyTorch
-------

.. warning::

   PyTorch support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.

PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import torch

   a = mx.arange(3)
   b = torch.tensor(memoryview(a))
   c = mx.array(b.numpy())

Conversion from PyTorch tensors back to MLX arrays must be done via
intermediate NumPy arrays with ``numpy()``.

JAX
---

JAX fully supports the buffer protocol.

.. code-block:: python

   import mlx.core as mx
   import jax.numpy as jnp

   a = mx.arange(3)
   b = jnp.array(a)
   c = mx.array(b)

TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python

   import mlx.core as mx
   import tensorflow as tf

   a = mx.arange(3)
   b = tf.constant(memoryview(a))
   c = mx.array(b)

@@ -0,0 +1,67 @@

Quick Start Guide
=================

Basics
------

.. currentmodule:: mlx.core

Import ``mlx.core`` and make an :class:`array`:

.. code-block:: python

   >>> import mlx.core as mx
   >>> a = mx.array([1, 2, 3, 4])
   >>> a.shape
   [4]
   >>> a.dtype
   int32
   >>> b = mx.array([1.0, 2.0, 3.0, 4.0])
   >>> b.dtype
   float32

Operations in MLX are lazy. The outputs of MLX operations are not computed
until they are needed. To force an array to be evaluated, use :func:`eval`.
Arrays will automatically be evaluated in a few cases. For example, inspecting
a scalar with :meth:`array.item`, printing an array, and converting an
:class:`array` to a :class:`numpy.ndarray` all automatically evaluate the
array.

.. code-block:: python

   >>> c = a + b    # c not yet evaluated
   >>> mx.eval(c)   # evaluates c
   >>> c = a + b
   >>> print(c)     # Also evaluates c
   array([2, 4, 6, 8], dtype=float32)
   >>> c = a + b
   >>> import numpy as np
   >>> np.array(c)  # Also evaluates c
   array([2., 4., 6., 8.], dtype=float32)

See the page on :ref:`Lazy Evaluation <lazy eval>` for more details.

Function and Graph Transformations
----------------------------------

MLX has standard function transformations like :func:`grad` and :func:`vmap`.
Transformations can be composed arbitrarily. For example,
``grad(vmap(grad(fn)))`` (or any other composition) is allowed.

.. code-block:: python

   >>> x = mx.array(0.0)
   >>> mx.sin(x)
   array(0, dtype=float32)
   >>> mx.grad(mx.sin)(x)
   array(1, dtype=float32)
   >>> mx.grad(mx.grad(mx.sin))(x)
   array(-0, dtype=float32)

Other gradient transformations include :func:`vjp` for vector-Jacobian
products and :func:`jvp` for Jacobian-vector products.
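
For example, assuming the list-based signature in which primals and cotangents
are passed as lists of arrays, a :func:`vjp` call might look like:

.. code-block:: python

   >>> f = lambda x: x * x
   >>> outputs, vjps = mx.vjp(f, [mx.array(2.0)], [mx.array(1.0)])
   >>> vjps
   [array(4, dtype=float32)]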

Use :func:`value_and_grad` to efficiently compute both a function's output and
its gradient with respect to the function's input.
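
A short sketch, with an illustrative loss function:

.. code-block:: python

   >>> loss_fn = lambda w: mx.mean((w - 1.0) ** 2)
   >>> w = mx.zeros((3,))
   >>> loss, grad = mx.value_and_grad(loss_fn)(w)
   >>> loss.item()
   1.0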

@@ -0,0 +1,81 @@

.. _saving_and_loading:

Saving and Loading Arrays
=========================

.. currentmodule:: mlx.core

MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the file extension. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:

.. code-block:: python

   >>> a = mx.array([1.0])
   >>> mx.save("array", a)

The array ``a`` will be saved in the file ``array.npy``. Including the
extension is optional; if it is missing it will be added automatically. You
can load the array with:

.. code-block:: python

   >>> mx.load("array.npy")
   array([1], dtype=float32)

Here's an example of saving several arrays to a single file:

.. code-block:: python

   >>> a = mx.array([1.0])
   >>> b = mx.array([2.0])
   >>> mx.savez("arrays", a, b=b)

For compatibility with :func:`numpy.savez`, the MLX :func:`savez` takes arrays
as arguments. If the keywords are missing, then default names will be
provided. This can be loaded with:

.. code-block:: python

   >>> mx.load("arrays.npz")
   {'b': array([2], dtype=float32), 'arr_0': array([1], dtype=float32)}

In this case :func:`load` returns a dictionary of names to arrays.

The functions :func:`save_safetensors` and :func:`save_gguf` are similar to
:func:`savez`, but they take as input a :obj:`dict` of string names to arrays:

.. code-block:: python

   >>> a = mx.array([1.0])
   >>> b = mx.array([2.0])
   >>> mx.save_safetensors("arrays", {"a": a, "b": b})
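
Loading the file back returns a dictionary of names to arrays, analogous to
the ``.npz`` case. A sketch of the expected round trip:

.. code-block:: python

   >>> mx.load("arrays.safetensors")
   {'a': array([1], dtype=float32), 'b': array([2], dtype=float32)}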