mlx 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlx might be problematic. Click here for more details.
- checksums.yaml +7 -0
- data/ext/mlx/CMakeLists.txt +7 -0
- data/ext/mlx/Makefile +273 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/mkmf.log +44 -0
- data/ext/mlx/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
- data/ext/mlx/native.cpp +8027 -0
- data/ext/mlx/native.o +0 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version +1 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/.clang-format +87 -0
- data/mlx/.git +1 -0
- data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
- data/mlx/.github/actions/build-docs/action.yml +38 -0
- data/mlx/.github/actions/build-linux/action.yml +38 -0
- data/mlx/.github/actions/build-linux-release/action.yml +42 -0
- data/mlx/.github/actions/build-macos/action.yml +80 -0
- data/mlx/.github/actions/build-macos-release/action.yml +36 -0
- data/mlx/.github/actions/build-windows/action.yml +26 -0
- data/mlx/.github/actions/setup-linux/action.yml +93 -0
- data/mlx/.github/actions/setup-macos/action.yml +24 -0
- data/mlx/.github/actions/setup-windows/action.yml +42 -0
- data/mlx/.github/actions/test-linux/action.yml +69 -0
- data/mlx/.github/actions/test-windows/action.yml +20 -0
- data/mlx/.github/dependabot.yml +6 -0
- data/mlx/.github/pull_request_template.md +12 -0
- data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
- data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
- data/mlx/.github/workflows/build_and_test.yml +152 -0
- data/mlx/.github/workflows/documentation.yml +28 -0
- data/mlx/.github/workflows/nightly.yml +104 -0
- data/mlx/.github/workflows/release.yml +256 -0
- data/mlx/.gitignore +81 -0
- data/mlx/.pre-commit-config.yaml +27 -0
- data/mlx/ACKNOWLEDGMENTS.md +268 -0
- data/mlx/CITATION.cff +24 -0
- data/mlx/CMakeLists.txt +437 -0
- data/mlx/CODE_OF_CONDUCT.md +132 -0
- data/mlx/CONTRIBUTING.md +38 -0
- data/mlx/LICENSE +21 -0
- data/mlx/MANIFEST.in +6 -0
- data/mlx/README.md +121 -0
- data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- data/mlx/benchmarks/cpp/autograd.cpp +39 -0
- data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- data/mlx/benchmarks/cpp/time_utils.h +39 -0
- data/mlx/benchmarks/numpy/single_ops.py +39 -0
- data/mlx/benchmarks/numpy/time_utils.py +20 -0
- data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
- data/mlx/benchmarks/python/comparative/README.md +15 -0
- data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- data/mlx/benchmarks/python/comparative/compare.py +284 -0
- data/mlx/benchmarks/python/compile_bench.py +107 -0
- data/mlx/benchmarks/python/conv1d_bench.py +123 -0
- data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- data/mlx/benchmarks/python/conv_bench.py +135 -0
- data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- data/mlx/benchmarks/python/distributed_bench.py +66 -0
- data/mlx/benchmarks/python/einsum_bench.py +84 -0
- data/mlx/benchmarks/python/fft_bench.py +118 -0
- data/mlx/benchmarks/python/gather_bench.py +52 -0
- data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- data/mlx/benchmarks/python/hadamard_bench.py +70 -0
- data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
- data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
- data/mlx/benchmarks/python/masked_scatter.py +212 -0
- data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
- data/mlx/benchmarks/python/rope_bench.py +35 -0
- data/mlx/benchmarks/python/scatter_bench.py +96 -0
- data/mlx/benchmarks/python/sdpa_bench.py +223 -0
- data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
- data/mlx/benchmarks/python/single_ops.py +132 -0
- data/mlx/benchmarks/python/synchronize_bench.py +55 -0
- data/mlx/benchmarks/python/time_utils.py +38 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/docs/.clang-format +2 -0
- data/mlx/docs/.gitignore +3 -0
- data/mlx/docs/.nojekyll +0 -0
- data/mlx/docs/Doxyfile +51 -0
- data/mlx/docs/Makefile +18 -0
- data/mlx/docs/README.md +54 -0
- data/mlx/docs/index.html +1 -0
- data/mlx/docs/requirements.txt +5 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
- data/mlx/docs/src/_static/mlx_logo.png +0 -0
- data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
- data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
- data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
- data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
- data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
- data/mlx/docs/src/_templates/module-base-class.rst +33 -0
- data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
- data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
- data/mlx/docs/src/conf.py +99 -0
- data/mlx/docs/src/cpp/ops.rst +7 -0
- data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
- data/mlx/docs/src/dev/extensions.rst +811 -0
- data/mlx/docs/src/dev/metal_debugger.rst +68 -0
- data/mlx/docs/src/dev/metal_logging.rst +40 -0
- data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
- data/mlx/docs/src/examples/data_parallelism.rst +91 -0
- data/mlx/docs/src/examples/linear_regression.rst +77 -0
- data/mlx/docs/src/examples/llama-inference.rst +382 -0
- data/mlx/docs/src/examples/mlp.rst +134 -0
- data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
- data/mlx/docs/src/index.rst +96 -0
- data/mlx/docs/src/install.rst +340 -0
- data/mlx/docs/src/python/array.rst +65 -0
- data/mlx/docs/src/python/cuda.rst +9 -0
- data/mlx/docs/src/python/data_types.rst +78 -0
- data/mlx/docs/src/python/devices_and_streams.rst +21 -0
- data/mlx/docs/src/python/distributed.rst +22 -0
- data/mlx/docs/src/python/export.rst +14 -0
- data/mlx/docs/src/python/fast.rst +16 -0
- data/mlx/docs/src/python/fft.rst +24 -0
- data/mlx/docs/src/python/linalg.rst +27 -0
- data/mlx/docs/src/python/memory_management.rst +16 -0
- data/mlx/docs/src/python/metal.rst +12 -0
- data/mlx/docs/src/python/nn/distributed.rst +30 -0
- data/mlx/docs/src/python/nn/functions.rst +40 -0
- data/mlx/docs/src/python/nn/init.rst +45 -0
- data/mlx/docs/src/python/nn/layers.rst +74 -0
- data/mlx/docs/src/python/nn/losses.rst +25 -0
- data/mlx/docs/src/python/nn/module.rst +38 -0
- data/mlx/docs/src/python/nn.rst +186 -0
- data/mlx/docs/src/python/ops.rst +184 -0
- data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
- data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
- data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
- data/mlx/docs/src/python/optimizers.rst +78 -0
- data/mlx/docs/src/python/random.rst +48 -0
- data/mlx/docs/src/python/transforms.rst +22 -0
- data/mlx/docs/src/python/tree_utils.rst +23 -0
- data/mlx/docs/src/usage/compile.rst +516 -0
- data/mlx/docs/src/usage/distributed.rst +572 -0
- data/mlx/docs/src/usage/export.rst +288 -0
- data/mlx/docs/src/usage/function_transforms.rst +191 -0
- data/mlx/docs/src/usage/indexing.rst +194 -0
- data/mlx/docs/src/usage/launching_distributed.rst +234 -0
- data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
- data/mlx/docs/src/usage/numpy.rst +124 -0
- data/mlx/docs/src/usage/quick_start.rst +67 -0
- data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
- data/mlx/docs/src/usage/unified_memory.rst +78 -0
- data/mlx/docs/src/usage/using_streams.rst +18 -0
- data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
- data/mlx/examples/cmake_project/README.md +26 -0
- data/mlx/examples/cmake_project/example.cpp +14 -0
- data/mlx/examples/cpp/CMakeLists.txt +12 -0
- data/mlx/examples/cpp/distributed.cpp +22 -0
- data/mlx/examples/cpp/linear_regression.cpp +54 -0
- data/mlx/examples/cpp/logistic_regression.cpp +54 -0
- data/mlx/examples/cpp/metal_capture.cpp +31 -0
- data/mlx/examples/cpp/timer.h +20 -0
- data/mlx/examples/cpp/tutorial.cpp +99 -0
- data/mlx/examples/export/CMakeLists.txt +22 -0
- data/mlx/examples/export/README.md +49 -0
- data/mlx/examples/export/eval_mlp.cpp +25 -0
- data/mlx/examples/export/eval_mlp.py +52 -0
- data/mlx/examples/export/train_mlp.cpp +35 -0
- data/mlx/examples/export/train_mlp.py +76 -0
- data/mlx/examples/extensions/CMakeLists.txt +78 -0
- data/mlx/examples/extensions/README.md +24 -0
- data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
- data/mlx/examples/extensions/axpby/axpby.h +90 -0
- data/mlx/examples/extensions/axpby/axpby.metal +47 -0
- data/mlx/examples/extensions/bindings.cpp +39 -0
- data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
- data/mlx/examples/extensions/pyproject.toml +8 -0
- data/mlx/examples/extensions/requirements.txt +4 -0
- data/mlx/examples/extensions/setup.py +18 -0
- data/mlx/examples/extensions/test.py +12 -0
- data/mlx/examples/python/linear_regression.py +46 -0
- data/mlx/examples/python/logistic_regression.py +49 -0
- data/mlx/examples/python/qqmm.py +117 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- data/mlx/pyproject.toml +7 -0
- data/mlx/python/mlx/__main__.py +27 -0
- data/mlx/python/mlx/_distributed_utils/common.py +135 -0
- data/mlx/python/mlx/_distributed_utils/config.py +631 -0
- data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
- data/mlx/python/mlx/_reprlib_fix.py +16 -0
- data/mlx/python/mlx/_stub_patterns.txt +36 -0
- data/mlx/python/mlx/extension.py +88 -0
- data/mlx/python/mlx/nn/__init__.py +5 -0
- data/mlx/python/mlx/nn/init.py +441 -0
- data/mlx/python/mlx/nn/layers/__init__.py +105 -0
- data/mlx/python/mlx/nn/layers/activations.py +661 -0
- data/mlx/python/mlx/nn/layers/base.py +675 -0
- data/mlx/python/mlx/nn/layers/containers.py +24 -0
- data/mlx/python/mlx/nn/layers/convolution.py +232 -0
- data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
- data/mlx/python/mlx/nn/layers/distributed.py +601 -0
- data/mlx/python/mlx/nn/layers/dropout.py +137 -0
- data/mlx/python/mlx/nn/layers/embedding.py +53 -0
- data/mlx/python/mlx/nn/layers/linear.py +180 -0
- data/mlx/python/mlx/nn/layers/normalization.py +363 -0
- data/mlx/python/mlx/nn/layers/pooling.py +398 -0
- data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
- data/mlx/python/mlx/nn/layers/quantized.py +426 -0
- data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
- data/mlx/python/mlx/nn/layers/transformer.py +354 -0
- data/mlx/python/mlx/nn/layers/upsample.py +277 -0
- data/mlx/python/mlx/nn/losses.py +610 -0
- data/mlx/python/mlx/nn/utils.py +165 -0
- data/mlx/python/mlx/optimizers/__init__.py +4 -0
- data/mlx/python/mlx/optimizers/optimizers.py +976 -0
- data/mlx/python/mlx/optimizers/schedulers.py +158 -0
- data/mlx/python/mlx/py.typed +1 -0
- data/mlx/python/mlx/utils.py +325 -0
- data/mlx/python/src/CMakeLists.txt +96 -0
- data/mlx/python/src/array.cpp +1525 -0
- data/mlx/python/src/buffer.h +124 -0
- data/mlx/python/src/constants.cpp +15 -0
- data/mlx/python/src/convert.cpp +504 -0
- data/mlx/python/src/convert.h +50 -0
- data/mlx/python/src/cuda.cpp +19 -0
- data/mlx/python/src/device.cpp +98 -0
- data/mlx/python/src/distributed.cpp +352 -0
- data/mlx/python/src/export.cpp +356 -0
- data/mlx/python/src/fast.cpp +627 -0
- data/mlx/python/src/fft.cpp +514 -0
- data/mlx/python/src/indexing.cpp +1016 -0
- data/mlx/python/src/indexing.h +41 -0
- data/mlx/python/src/linalg.cpp +663 -0
- data/mlx/python/src/load.cpp +531 -0
- data/mlx/python/src/load.h +51 -0
- data/mlx/python/src/memory.cpp +125 -0
- data/mlx/python/src/metal.cpp +98 -0
- data/mlx/python/src/mlx.cpp +51 -0
- data/mlx/python/src/mlx_func.cpp +116 -0
- data/mlx/python/src/mlx_func.h +31 -0
- data/mlx/python/src/ops.cpp +5545 -0
- data/mlx/python/src/random.cpp +516 -0
- data/mlx/python/src/small_vector.h +76 -0
- data/mlx/python/src/stream.cpp +147 -0
- data/mlx/python/src/transforms.cpp +1542 -0
- data/mlx/python/src/trees.cpp +311 -0
- data/mlx/python/src/trees.h +62 -0
- data/mlx/python/src/utils.cpp +98 -0
- data/mlx/python/src/utils.h +78 -0
- data/mlx/python/tests/__main__.py +5 -0
- data/mlx/python/tests/cuda_skip.py +62 -0
- data/mlx/python/tests/mlx_distributed_tests.py +314 -0
- data/mlx/python/tests/mlx_tests.py +116 -0
- data/mlx/python/tests/mpi_test_distributed.py +142 -0
- data/mlx/python/tests/nccl_test_distributed.py +52 -0
- data/mlx/python/tests/ring_test_distributed.py +131 -0
- data/mlx/python/tests/test_array.py +2139 -0
- data/mlx/python/tests/test_autograd.py +880 -0
- data/mlx/python/tests/test_bf16.py +196 -0
- data/mlx/python/tests/test_blas.py +1429 -0
- data/mlx/python/tests/test_compile.py +1277 -0
- data/mlx/python/tests/test_constants.py +41 -0
- data/mlx/python/tests/test_conv.py +1198 -0
- data/mlx/python/tests/test_conv_transpose.py +810 -0
- data/mlx/python/tests/test_device.py +150 -0
- data/mlx/python/tests/test_double.py +306 -0
- data/mlx/python/tests/test_einsum.py +363 -0
- data/mlx/python/tests/test_eval.py +200 -0
- data/mlx/python/tests/test_export_import.py +614 -0
- data/mlx/python/tests/test_fast.py +923 -0
- data/mlx/python/tests/test_fast_sdpa.py +647 -0
- data/mlx/python/tests/test_fft.py +323 -0
- data/mlx/python/tests/test_graph.py +37 -0
- data/mlx/python/tests/test_init.py +139 -0
- data/mlx/python/tests/test_linalg.py +621 -0
- data/mlx/python/tests/test_load.py +447 -0
- data/mlx/python/tests/test_losses.py +427 -0
- data/mlx/python/tests/test_memory.py +77 -0
- data/mlx/python/tests/test_nn.py +1986 -0
- data/mlx/python/tests/test_ops.py +3261 -0
- data/mlx/python/tests/test_optimizers.py +584 -0
- data/mlx/python/tests/test_quantized.py +1160 -0
- data/mlx/python/tests/test_random.py +392 -0
- data/mlx/python/tests/test_reduce.py +223 -0
- data/mlx/python/tests/test_tree.py +96 -0
- data/mlx/python/tests/test_upsample.py +100 -0
- data/mlx/python/tests/test_vmap.py +860 -0
- data/mlx/setup.py +315 -0
- data/mlx/tests/CMakeLists.txt +44 -0
- data/mlx/tests/allocator_tests.cpp +41 -0
- data/mlx/tests/arg_reduce_tests.cpp +204 -0
- data/mlx/tests/array_tests.cpp +663 -0
- data/mlx/tests/autograd_tests.cpp +1399 -0
- data/mlx/tests/blas_tests.cpp +110 -0
- data/mlx/tests/compile_tests.cpp +818 -0
- data/mlx/tests/creations_tests.cpp +239 -0
- data/mlx/tests/custom_vjp_tests.cpp +55 -0
- data/mlx/tests/device_tests.cpp +35 -0
- data/mlx/tests/einsum_tests.cpp +85 -0
- data/mlx/tests/eval_tests.cpp +93 -0
- data/mlx/tests/export_import_tests.cpp +164 -0
- data/mlx/tests/fft_tests.cpp +366 -0
- data/mlx/tests/gpu_tests.cpp +523 -0
- data/mlx/tests/linalg_tests.cpp +639 -0
- data/mlx/tests/load_tests.cpp +270 -0
- data/mlx/tests/ops_tests.cpp +4159 -0
- data/mlx/tests/random_tests.cpp +716 -0
- data/mlx/tests/scheduler_tests.cpp +121 -0
- data/mlx/tests/tests.cpp +26 -0
- data/mlx/tests/utils_tests.cpp +67 -0
- data/mlx/tests/vmap_tests.cpp +547 -0
- metadata +958 -0
|
@@ -0,0 +1,572 @@
|
|
|
1
|
+
.. _usage_distributed:
|
|
2
|
+
|
|
3
|
+
Distributed Communication
|
|
4
|
+
=========================
|
|
5
|
+
|
|
6
|
+
.. currentmodule:: mlx.core.distributed
|
|
7
|
+
|
|
8
|
+
MLX supports distributed communication operations that allow the computational cost
|
|
9
|
+
of training or inference to be shared across many physical machines. At the
|
|
10
|
+
moment we support several different communication backends introduced below.
|
|
11
|
+
|
|
12
|
+
.. list-table::
|
|
13
|
+
:widths: 20 80
|
|
14
|
+
:header-rows: 1
|
|
15
|
+
|
|
16
|
+
* - Backend
|
|
17
|
+
- Description
|
|
18
|
+
* - :ref:`MPI <mpi_section>`
|
|
19
|
+
- A full featured and mature distributed communications library.
|
|
20
|
+
* - :ref:`RING <ring_section>`
|
|
21
|
+
- Ring all reduce and all gather over TCP sockets. Always available and
|
|
22
|
+
usually faster than MPI.
|
|
23
|
+
* - :ref:`JACCL <jaccl_section>`
|
|
24
|
+
- Low latency communication with RDMA over thunderbolt. Necessary for
|
|
25
|
+
things like tensor parallelism.
|
|
26
|
+
* - :ref:`NCCL <nccl_section>`
|
|
27
|
+
- The backend of choice for CUDA environments.
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
The list of all currently supported operations and their documentation can be
|
|
31
|
+
seen in the :ref:`API docs<distributed>`.
|
|
32
|
+
|
|
33
|
+
Getting Started
|
|
34
|
+
---------------
|
|
35
|
+
|
|
36
|
+
A distributed program in MLX is as simple as:
|
|
37
|
+
|
|
38
|
+
.. code:: python
|
|
39
|
+
|
|
40
|
+
import mlx.core as mx
|
|
41
|
+
|
|
42
|
+
world = mx.distributed.init()
|
|
43
|
+
x = mx.distributed.all_sum(mx.ones(10))
|
|
44
|
+
print(world.rank(), x)
|
|
45
|
+
|
|
46
|
+
The program above sums the array ``mx.ones(10)`` across all
|
|
47
|
+
distributed processes. However, when this script is run with ``python``, only
|
|
48
|
+
one process is launched and no distributed communication takes place. Namely,
|
|
49
|
+
all operations in ``mx.distributed`` are noops when the distributed group has a
|
|
50
|
+
size of one. This property allows us to avoid code that checks if we are in a
|
|
51
|
+
distributed setting similar to the one below:
|
|
52
|
+
|
|
53
|
+
.. code:: python
|
|
54
|
+
|
|
55
|
+
import mlx.core as mx
|
|
56
|
+
|
|
57
|
+
x = ...
|
|
58
|
+
world = mx.distributed.init()
|
|
59
|
+
# No need for the check we can simply do x = mx.distributed.all_sum(x)
|
|
60
|
+
if world.size() > 1:
|
|
61
|
+
x = mx.distributed.all_sum(x)
|
|
62
|
+
|
|
63
|
+
Running Distributed Programs
|
|
64
|
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
65
|
+
|
|
66
|
+
MLX provides ``mlx.launch`` a helper script to launch distributed programs.
|
|
67
|
+
Continuing with our initial example we can run it on localhost with 4 processes using
|
|
68
|
+
|
|
69
|
+
.. code:: shell
|
|
70
|
+
|
|
71
|
+
$ mlx.launch -n 4 my_script.py
|
|
72
|
+
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
|
73
|
+
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
|
74
|
+
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
|
75
|
+
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
|
76
|
+
|
|
77
|
+
We can also run it on some remote hosts by providing their IPs (provided that
|
|
78
|
+
the script exists on all hosts and they are reachable by ssh)
|
|
79
|
+
|
|
80
|
+
.. code:: shell
|
|
81
|
+
|
|
82
|
+
$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
|
|
83
|
+
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
|
84
|
+
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
|
85
|
+
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
|
86
|
+
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
|
87
|
+
|
|
88
|
+
Consult the dedicated :doc:`usage guide<launching_distributed>` for more
|
|
89
|
+
information on using ``mlx.launch``.
|
|
90
|
+
|
|
91
|
+
Selecting Backend
|
|
92
|
+
^^^^^^^^^^^^^^^^^
|
|
93
|
+
|
|
94
|
+
You can select the backend you want to use when calling :func:`init` by passing
|
|
95
|
+
one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
|
|
96
|
+
available backends. If they all fail then a singleton group is created.
|
|
97
|
+
|
|
98
|
+
.. note::
|
|
99
|
+
After a distributed backend is successfully initialized :func:`init` will
|
|
100
|
+
return **the same backend** if called without arguments or with backend set to
|
|
101
|
+
``any``.
|
|
102
|
+
|
|
103
|
+
The following examples aim to clarify the backend initialization logic in MLX:
|
|
104
|
+
|
|
105
|
+
.. code:: python
|
|
106
|
+
|
|
107
|
+
# Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
|
|
108
|
+
world = mx.distributed.init(backend="mpi")
|
|
109
|
+
world2 = mx.distributed.init() # subsequent calls return the MPI backend!
|
|
110
|
+
|
|
111
|
+
# Case 2: Initialize any backend
|
|
112
|
+
world = mx.distributed.init(backend="any") # equivalent to no arguments
|
|
113
|
+
world2 = mx.distributed.init() # same as above
|
|
114
|
+
|
|
115
|
+
# Case 3: Initialize both backends at the same time
|
|
116
|
+
world_mpi = mx.distributed.init(backend="mpi")
|
|
117
|
+
world_ring = mx.distributed.init(backend="ring")
|
|
118
|
+
world_any = mx.distributed.init() # same as MPI because it was initialized first!
|
|
119
|
+
|
|
120
|
+
Distributed Program Examples
|
|
121
|
+
----------------------------
|
|
122
|
+
|
|
123
|
+
- :ref:`Data Parallelism <data_parallelism>`
|
|
124
|
+
- :ref:`Tensor Parallelism <tensor_parallelism>`
|
|
125
|
+
|
|
126
|
+
.. _ring_section:
|
|
127
|
+
|
|
128
|
+
Getting Started with Ring
|
|
129
|
+
-------------------------
|
|
130
|
+
|
|
131
|
+
The ring backend does not depend on any third party library so it is always
|
|
132
|
+
available. It uses TCP sockets so the nodes need to be reachable via a network.
|
|
133
|
+
As the name suggests the nodes are connected in a ring which means that rank 1
|
|
134
|
+
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
|
|
135
|
+
and so on and so forth. As a result :func:`send` and :func:`recv` with
|
|
136
|
+
arbitrary sender and receiver are not supported in the ring backend.
|
|
137
|
+
|
|
138
|
+
Defining a Ring
|
|
139
|
+
^^^^^^^^^^^^^^^
|
|
140
|
+
|
|
141
|
+
The easiest way to define and use a ring is via a JSON hostfile and the
|
|
142
|
+
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
|
|
143
|
+
defines a hostname to ssh into to run commands on this node and one or more IPs
|
|
144
|
+
that this node will listen to for connections.
|
|
145
|
+
|
|
146
|
+
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
|
|
147
|
+
rank 0, ``hostname2`` rank 1 etc.
|
|
148
|
+
|
|
149
|
+
.. code:: json
|
|
150
|
+
|
|
151
|
+
[
|
|
152
|
+
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
|
|
153
|
+
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
|
|
154
|
+
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
|
|
155
|
+
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
|
|
156
|
+
]
|
|
157
|
+
|
|
158
|
+
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
|
|
159
|
+
node, and run the script, which will listen for connections on each of the provided
|
|
160
|
+
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
|
|
161
|
+
connection from ``123.123.123.4`` and so on and so forth.
|
|
162
|
+
|
|
163
|
+
Thunderbolt Ring
|
|
164
|
+
^^^^^^^^^^^^^^^^
|
|
165
|
+
|
|
166
|
+
Although the ring backend can have benefits over MPI even for Ethernet, its
|
|
167
|
+
main purpose is to use Thunderbolt rings for higher bandwidth communication.
|
|
168
|
+
Setting up such thunderbolt rings can be done manually, but is a relatively
|
|
169
|
+
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
|
|
170
|
+
|
|
171
|
+
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
|
|
172
|
+
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
|
|
173
|
+
utility as follows:
|
|
174
|
+
|
|
175
|
+
.. code:: shell
|
|
176
|
+
|
|
177
|
+
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring
|
|
178
|
+
|
|
179
|
+
By default the script will attempt to discover the thunderbolt ring and provide
|
|
180
|
+
you with the commands to configure each node as well as the ``hostfile.json``
|
|
181
|
+
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
|
|
182
|
+
then ``--auto-setup`` can be used to configure them automatically.
|
|
183
|
+
|
|
184
|
+
If you want to go through the process manually, the steps are as follows:
|
|
185
|
+
|
|
186
|
+
* Disable the thunderbolt bridge interface
|
|
187
|
+
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
|
|
188
|
+
corresponding to that cable in nodes ``i`` and ``i + 1``.
|
|
189
|
+
* Set up a unique subnetwork connecting the two nodes for the corresponding
|
|
190
|
+
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
|
|
191
|
+
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
|
|
192
|
+
``192.168.0.2`` respectively to the two nodes. For more details you can see
|
|
193
|
+
the commands prepared by the utility script.
|
|
194
|
+
|
|
195
|
+
.. _jaccl_section:
|
|
196
|
+
|
|
197
|
+
Getting Started with JACCL
|
|
198
|
+
--------------------------
|
|
199
|
+
|
|
200
|
+
Starting from macOS 26.2, RDMA over thunderbolt is available and
|
|
201
|
+
enables low-latency communication between Macs with thunderbolt 5. MLX provides
|
|
202
|
+
the JACCL backend that uses this functionality to achieve communication latency
|
|
203
|
+
an order of magnitude lower than the ring backend.
|
|
204
|
+
|
|
205
|
+
.. note::
|
|
206
|
+
|
|
207
|
+
The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective
|
|
208
|
+
Communication Library* and it is an obvious pun on Nvidia's NCCL but also a
|
|
209
|
+
tribute to *Jack Beasley*, who led the development of RDMA over Thunderbolt
|
|
210
|
+
at Apple.
|
|
211
|
+
|
|
212
|
+
Enabling RDMA
|
|
213
|
+
^^^^^^^^^^^^^
|
|
214
|
+
|
|
215
|
+
Until the feature matures, enabling RDMA over thunderbolt is slightly more
|
|
216
|
+
involved and **cannot** be done remotely even with sudo. In fact, it has to be
|
|
217
|
+
done in macOS recovery:
|
|
218
|
+
|
|
219
|
+
1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
|
|
220
|
+
2. Open the Terminal by going to Utilities -> Terminal.
|
|
221
|
+
3. Run ``rdma_ctl enable``.
|
|
222
|
+
4. Reboot.
|
|
223
|
+
|
|
224
|
+
To verify that you have successfully enabled Thunderbolt RDMA you can run
|
|
225
|
+
``ibv_devices`` which should produce something like the following for an M3 Ultra.
|
|
226
|
+
|
|
227
|
+
.. code-block:: bash
|
|
228
|
+
|
|
229
|
+
~ % ibv_devices
|
|
230
|
+
device node GUID
|
|
231
|
+
------ ----------------
|
|
232
|
+
rdma_en2 8096a9d9edbaac05
|
|
233
|
+
rdma_en3 8196a9d9edbaac05
|
|
234
|
+
rdma_en5 8396a9d9edbaac05
|
|
235
|
+
rdma_en4 8296a9d9edbaac05
|
|
236
|
+
rdma_en6 8496a9d9edbaac05
|
|
237
|
+
rdma_en7 8596a9d9edbaac05
|
|
238
|
+
|
|
239
|
+
Defining a Mesh
|
|
240
|
+
^^^^^^^^^^^^^^^
|
|
241
|
+
|
|
242
|
+
The JACCL backend supports only fully connected topologies. Namely, there needs
|
|
243
|
+
to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
|
|
244
|
+
the following topology visualizations, the left one is valid because there is a
|
|
245
|
+
connection from any node to any other node, while for the one on the right M3
|
|
246
|
+
Ultra 1 is not connected to M3 Ultra 2.
|
|
247
|
+
|
|
248
|
+
.. raw:: html
|
|
249
|
+
|
|
250
|
+
<div style="display: flex; text-align: center; align-items: end; font-size: 80%;">
|
|
251
|
+
<div>
|
|
252
|
+
<img src="../_static/distributed/m3-ultra-mesh.png" alt="M3 Ultra thunderbolt mesh" style="width: 55%">
|
|
253
|
+
<p>Fully connected mesh of four M3 Ultra.</p>
|
|
254
|
+
</div>
|
|
255
|
+
<div>
|
|
256
|
+
<img src="../_static/distributed/m3-ultra-mesh-broken.png" alt="M3 Ultra broken thunderbolt mesh" style="width: 55%">
|
|
257
|
+
<p>Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).</p>
|
|
258
|
+
</div>
|
|
259
|
+
</div>
|
|
260
|
+
|
|
261
|
+
Similar to the ring backend, the easiest way to use JACCL with MLX is to write
|
|
262
|
+
a JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain
|
|
263
|
+
|
|
264
|
+
- Hostnames to use for launching scripts via ssh
|
|
265
|
+
- An IP for rank 0 that is reachable by all nodes
|
|
266
|
+
- A list of rdma devices that connect each node to each other node
|
|
267
|
+
|
|
268
|
+
The following JSON defines the valid 4-node mesh from the image above.
|
|
269
|
+
|
|
270
|
+
.. code-block:: json
|
|
271
|
+
|
|
272
|
+
[
|
|
273
|
+
{
|
|
274
|
+
"ssh": "m3-ultra-1",
|
|
275
|
+
"ips": ["123.123.123.1"],
|
|
276
|
+
"rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
|
|
277
|
+
},
|
|
278
|
+
{
|
|
279
|
+
"ssh": "m3-ultra-2",
|
|
280
|
+
"ips": [],
|
|
281
|
+
"rdma": ["rdma_en5", null, "rdma_en3", "rdma_en4"]
|
|
282
|
+
},
|
|
283
|
+
{
|
|
284
|
+
"ssh": "m3-ultra-3",
|
|
285
|
+
"ips": [],
|
|
286
|
+
"rdma": ["rdma_en4", "rdma_en3", null, "rdma_en5"]
|
|
287
|
+
},
|
|
288
|
+
{
|
|
289
|
+
"ssh": "m3-ultra-4",
|
|
290
|
+
"ips": [],
|
|
291
|
+
"rdma": ["rdma_en3", "rdma_en4", "rdma_en5", null]
|
|
292
|
+
}
|
|
293
|
+
]
|
|
294
|
+
|
|
295
|
+
Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
|
|
296
|
+
disabling the thunderbolt bridge is still required as well as setting up
|
|
297
|
+
isolated local networks for each thunderbolt connection.
|
|
298
|
+
|
|
299
|
+
All of the above can be done instead via ``mlx.distributed_config``. This helper
|
|
300
|
+
script will
|
|
301
|
+
|
|
302
|
+
- ssh into each node
|
|
303
|
+
- extract the thunderbolt connectivity
|
|
304
|
+
- check for a valid mesh
|
|
305
|
+
- provide the commands to configure each node (or run them if sudo is available)
|
|
306
|
+
- generate the hostfile to be used with ``mlx.launch``
|
|
307
|
+
|
|
308
|
+
Putting It All Together
|
|
309
|
+
^^^^^^^^^^^^^^^^^^^^^^^^
|
|
310
|
+
|
|
311
|
+
For example launching a distributed MLX script that uses JACCL is fairly simple
|
|
312
|
+
if the nodes are reachable via ssh and have password-less sudo.
|
|
313
|
+
|
|
314
|
+
First, connect all the thunderbolt cables. Then we can verify the connections
|
|
315
|
+
by using the ``mlx.distributed_config`` script to visualize them.
|
|
316
|
+
|
|
317
|
+
.. code-block::
|
|
318
|
+
|
|
319
|
+
mlx.distributed_config --verbose \
|
|
320
|
+
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
|
|
321
|
+
--over thunderbolt --dot | dot -Tpng | open -f -a Preview
|
|
322
|
+
|
|
323
|
+
After making sure that everything looks right we can auto-configure the nodes
|
|
324
|
+
and save the hostfile to ``m3-ultra-jaccl.json`` by running:
|
|
325
|
+
|
|
326
|
+
.. code-block::
|
|
327
|
+
|
|
328
|
+
mlx.distributed_config --verbose \
|
|
329
|
+
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
|
|
330
|
+
--over thunderbolt --backend jaccl \
|
|
331
|
+
--auto-setup --output m3-ultra-jaccl.json
|
|
332
|
+
|
|
333
|
+
And now we are ready to run a distributed MLX script such as distributed inference
|
|
334
|
+
of a gigantic model using MLX LM.
|
|
335
|
+
|
|
336
|
+
.. code-block::
|
|
337
|
+
|
|
338
|
+
mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
|
|
339
|
+
--env MLX_METAL_FAST_SYNCH=1 -- \ # <--- important
|
|
340
|
+
/path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-R1-0528-4bit
|
|
341
|
+
|
|
342
|
+
.. note::
|
|
343
|
+
|
|
344
|
+
Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a
|
|
345
|
+
different, faster way of synchronizing between the GPU and the CPU. It is
|
|
346
|
+
not specific to the JACCL backend and can be used in all cases where the CPU
|
|
347
|
+
and GPU need to collaborate for some computation. It is pretty critical for
|
|
348
|
+
low-latency communication since the communication is done by the CPU.
|
|
349
|
+
|
|
350
|
+
.. _nccl_section:
|
|
351
|
+
|
|
352
|
+
Getting Started with NCCL
|
|
353
|
+
-------------------------
|
|
354
|
+
|
|
355
|
+
MLX on CUDA environments ships with the ability to talk to `NCCL
|
|
356
|
+
<https://developer.nvidia.com/nccl>`_ which is a high-performance collective
|
|
357
|
+
communication library that supports both multi-gpu and multi-node setups.
|
|
358
|
+
|
|
359
|
+
For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
|
|
360
|
+
it takes to run a distributed job is
|
|
361
|
+
|
|
362
|
+
.. code-block::
|
|
363
|
+
|
|
364
|
+
mlx.launch -n 8 test.py
|
|
365
|
+
|
|
366
|
+
# perfect for interactive scripts
|
|
367
|
+
mlx.launch -n 8 python -m mlx_lm chat --model my-model
|
|
368
|
+
|
|
369
|
+
You can also use ``mlx.launch`` to ssh to a remote node and launch a script
|
|
370
|
+
with the same ease
|
|
371
|
+
|
|
372
|
+
.. code-block::
|
|
373
|
+
|
|
374
|
+
mlx.launch --hosts my-cuda-node -n 8 test.py
|
|
375
|
+
|
|
376
|
+
In many cases you may not want to use ``mlx.launch`` with the NCCL backend
|
|
377
|
+
because the cluster scheduler will be the one launching the processes. You can
|
|
378
|
+
:ref:`see which environment variables need to be defined <no_mlx_launch>` in
|
|
379
|
+
order for the MLX NCCL backend to be initialized correctly.
|
|
380
|
+
|
|
381
|
+
.. _mpi_section:
|
|
382
|
+
|
|
383
|
+
Getting Started with MPI
|
|
384
|
+
------------------------
|
|
385
|
+
|
|
386
|
+
MLX already comes with the ability to "talk" to `MPI
|
|
387
|
+
<https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ if it is installed
|
|
388
|
+
on the machine. Launching distributed MLX programs that use MPI can be done
|
|
389
|
+
with ``mpirun`` as expected. However, in the following examples we will be
|
|
390
|
+
using ``mlx.launch --backend mpi`` which takes care of some nuisances such as
|
|
391
|
+
setting absolute paths for the ``mpirun`` executable and the ``libmpi.dyld``
|
|
392
|
+
shared library.
|
|
393
|
+
|
|
394
|
+
The simplest possible usage is the following which, assuming the minimal
|
|
395
|
+
example in the beginning of this page, should result in:
|
|
396
|
+
|
|
397
|
+
.. code:: shell
|
|
398
|
+
|
|
399
|
+
$ mlx.launch --backend mpi -n 2 test.py
|
|
400
|
+
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
|
401
|
+
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
|
402
|
+
|
|
403
|
+
The above launches two processes on the same (local) machine and we can see
|
|
404
|
+
both standard output streams. The processes send the array of 1s to each other
|
|
405
|
+
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
|
|
406
|
+
print 4 etc.
|
|
407
|
+
|
|
408
|
+
Installing MPI
|
|
409
|
+
^^^^^^^^^^^^^^
|
|
410
|
+
|
|
411
|
+
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
|
|
412
|
+
compiled from source. Most of our testing is done using ``openmpi`` installed
|
|
413
|
+
with the Anaconda package manager as follows:
|
|
414
|
+
|
|
415
|
+
.. code:: shell
|
|
416
|
+
|
|
417
|
+
$ conda install conda-forge::openmpi
|
|
418
|
+
|
|
419
|
+
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
|
|
420
|
+
so that MLX can find it and load it at runtime. This can simply be achieved by
|
|
421
|
+
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
|
|
422
|
+
done automatically by ``mlx.launch``. Some environments use a non-standard
|
|
423
|
+
library filename that can be specified using the ``MPI_LIBNAME`` environment
|
|
424
|
+
variable. This is automatically taken care of by ``mlx.launch`` as well.
|
|
425
|
+
|
|
426
|
+
.. code:: shell
|
|
427
|
+
|
|
428
|
+
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
|
|
429
|
+
$ # or simply
|
|
430
|
+
$ mlx.launch -n 2 test.py
|
|
431
|
+
|
|
432
|
+
Setting up Remote Hosts
|
|
433
|
+
^^^^^^^^^^^^^^^^^^^^^^^
|
|
434
|
+
|
|
435
|
+
MPI can automatically connect to remote hosts and set up the communication over
|
|
436
|
+
the network if the remote hosts can be accessed via ssh. A good checklist to
|
|
437
|
+
debug connectivity issues is the following:
|
|
438
|
+
|
|
439
|
+
* ``ssh hostname`` works from all machines to all machines without asking for
|
|
440
|
+
password or host confirmation
|
|
441
|
+
* ``mpirun`` is accessible on all machines.
|
|
442
|
+
* Ensure that the ``hostname`` used by MPI is the one that you have configured
|
|
443
|
+
in the ``.ssh/config`` files on all machines.
|
|
444
|
+
|
|
445
|
+
Tuning MPI All Reduce
|
|
446
|
+
^^^^^^^^^^^^^^^^^^^^^
|
|
447
|
+
|
|
448
|
+
.. note::
|
|
449
|
+
|
|
450
|
+
For faster all reduce consider using the ring backend either with Thunderbolt
|
|
451
|
+
connections or over Ethernet.
|
|
452
|
+
|
|
453
|
+
Configure MPI to use N tcp connections between each host to improve bandwidth
|
|
454
|
+
by passing ``--mca btl_tcp_links N``.
|
|
455
|
+
|
|
456
|
+
Force MPI to use the most performant network interface by setting ``--mca
|
|
457
|
+
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
|
|
458
|
+
to use.
|
|
459
|
+
|
|
460
|
+
.. _no_mlx_launch:
|
|
461
|
+
|
|
462
|
+
Distributed Without ``mlx.launch``
|
|
463
|
+
----------------------------------
|
|
464
|
+
|
|
465
|
+
None of the implementations of the distributed backends require launching with
|
|
466
|
+
``mlx.launch``. The script simply connects to each host, starts a process per
|
|
467
|
+
rank and sets up the necessary environment variables before delegating to your
|
|
468
|
+
MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
|
|
469
|
+
for more details.
|
|
470
|
+
|
|
471
|
+
For many use-cases this will be the easiest way to perform distributed
|
|
472
|
+
computations in MLX. However, there may be reasons that you cannot or should
|
|
473
|
+
not use ``mlx.launch``. A common such case is the use of a scheduler that
|
|
474
|
+
starts all the processes for you on machines undetermined at the time of
|
|
475
|
+
scheduling the job.
|
|
476
|
+
|
|
477
|
+
Below we list the environment variables required to use each backend.
|
|
478
|
+
|
|
479
|
+
Ring
|
|
480
|
+
^^^^^^
|
|
481
|
+
|
|
482
|
+
**MLX_RANK** should contain a single 0-based integer that defines the rank of
|
|
483
|
+
the process.
|
|
484
|
+
|
|
485
|
+
**MLX_HOSTFILE** should contain the path to a json file that contains IPs and
|
|
486
|
+
ports for each rank to listen to, something like the following:
|
|
487
|
+
|
|
488
|
+
.. code-block:: json
|
|
489
|
+
|
|
490
|
+
[
|
|
491
|
+
["123.123.1.1:5000", "123.123.1.2:5000"],
|
|
492
|
+
["123.123.2.1:5000", "123.123.2.2:5000"],
|
|
493
|
+
["123.123.3.1:5000", "123.123.3.2:5000"],
|
|
494
|
+
["123.123.4.1:5000", "123.123.4.2:5000"]
|
|
495
|
+
]
|
|
496
|
+
|
|
497
|
+
**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging
|
|
498
|
+
from the distributed backend.
|
|
499
|
+
|
|
500
|
+
JACCL
|
|
501
|
+
^^^^^
|
|
502
|
+
|
|
503
|
+
**MLX_RANK** should contain a single 0-based integer that defines the rank of
|
|
504
|
+
the process.
|
|
505
|
+
|
|
506
|
+
**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 can listen
|
|
507
|
+
to and all the other ranks connect to in order to establish the RDMA connections.
|
|
508
|
+
|
|
509
|
+
**MLX_IBV_DEVICES** should contain the path to a json file that contains the
|
|
510
|
+
ibverbs device names that connect each node to each other node, something like
|
|
511
|
+
the following:
|
|
512
|
+
|
|
513
|
+
.. code-block:: json
|
|
514
|
+
|
|
515
|
+
[
|
|
516
|
+
[null, "rdma_en5", "rdma_en4", "rdma_en3"],
|
|
517
|
+
["rdma_en5", null, "rdma_en3", "rdma_en4"],
|
|
518
|
+
["rdma_en4", "rdma_en3", null, "rdma_en5"],
|
|
519
|
+
["rdma_en3", "rdma_en4", "rdma_en5", null]
|
|
520
|
+
]
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
NCCL
|
|
524
|
+
^^^^^
|
|
525
|
+
|
|
526
|
+
**MLX_RANK** should contain a single 0-based integer that defines the rank of
|
|
527
|
+
the process.
|
|
528
|
+
|
|
529
|
+
**MLX_WORLD_SIZE** should contain the total number of processes that will be
|
|
530
|
+
launched.
|
|
531
|
+
|
|
532
|
+
**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
|
|
533
|
+
hosts can connect to in order to establish the NCCL communication.
|
|
534
|
+
|
|
535
|
+
**CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that
|
|
536
|
+
corresponds to this process.
|
|
537
|
+
|
|
538
|
+
Of course any `other environment variable
|
|
539
|
+
<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
|
|
540
|
+
used by NCCL can be set.
|
|
541
|
+
|
|
542
|
+
.. _tips_and_tricks:
|
|
543
|
+
|
|
544
|
+
Tips and Tricks
|
|
545
|
+
----------------
|
|
546
|
+
|
|
547
|
+
This is a small collection of tips to help you better utilize the distributed
|
|
548
|
+
communication capabilities of MLX.
|
|
549
|
+
|
|
550
|
+
- *Test locally first.*
|
|
551
|
+
|
|
552
|
+
You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small
|
|
553
|
+
scale test on a single node first.
|
|
554
|
+
|
|
555
|
+
- *Batch your communication.*
|
|
556
|
+
|
|
557
|
+
As described in the :ref:`training example <training_example>`, performing a
|
|
558
|
+
lot of small communications can hurt performance. Copy the approach of
|
|
559
|
+
:func:`mlx.nn.average_gradients` to gather many small communications in a
|
|
560
|
+
single large one.
|
|
561
|
+
|
|
562
|
+
- *Visualize the connectivity.*
|
|
563
|
+
|
|
564
|
+
Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
|
|
565
|
+
visualize the connections and make sure that the cables are connected
|
|
566
|
+
correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
|
|
567
|
+
|
|
568
|
+
- *Use the debugger.*
|
|
569
|
+
|
|
570
|
+
``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
|
|
571
|
+
processes and gathers stdout from all processes. This makes using ``pdb`` a
|
|
572
|
+
breeze.
|