mlx 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mlx might be problematic.
- checksums.yaml +7 -0
- data/ext/mlx/CMakeLists.txt +7 -0
- data/ext/mlx/Makefile +273 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/mkmf.log +44 -0
- data/ext/mlx/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
- data/ext/mlx/native.cpp +8027 -0
- data/ext/mlx/native.o +0 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version +1 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/.clang-format +87 -0
- data/mlx/.git +1 -0
- data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
- data/mlx/.github/actions/build-docs/action.yml +38 -0
- data/mlx/.github/actions/build-linux/action.yml +38 -0
- data/mlx/.github/actions/build-linux-release/action.yml +42 -0
- data/mlx/.github/actions/build-macos/action.yml +80 -0
- data/mlx/.github/actions/build-macos-release/action.yml +36 -0
- data/mlx/.github/actions/build-windows/action.yml +26 -0
- data/mlx/.github/actions/setup-linux/action.yml +93 -0
- data/mlx/.github/actions/setup-macos/action.yml +24 -0
- data/mlx/.github/actions/setup-windows/action.yml +42 -0
- data/mlx/.github/actions/test-linux/action.yml +69 -0
- data/mlx/.github/actions/test-windows/action.yml +20 -0
- data/mlx/.github/dependabot.yml +6 -0
- data/mlx/.github/pull_request_template.md +12 -0
- data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
- data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
- data/mlx/.github/workflows/build_and_test.yml +152 -0
- data/mlx/.github/workflows/documentation.yml +28 -0
- data/mlx/.github/workflows/nightly.yml +104 -0
- data/mlx/.github/workflows/release.yml +256 -0
- data/mlx/.gitignore +81 -0
- data/mlx/.pre-commit-config.yaml +27 -0
- data/mlx/ACKNOWLEDGMENTS.md +268 -0
- data/mlx/CITATION.cff +24 -0
- data/mlx/CMakeLists.txt +437 -0
- data/mlx/CODE_OF_CONDUCT.md +132 -0
- data/mlx/CONTRIBUTING.md +38 -0
- data/mlx/LICENSE +21 -0
- data/mlx/MANIFEST.in +6 -0
- data/mlx/README.md +121 -0
- data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- data/mlx/benchmarks/cpp/autograd.cpp +39 -0
- data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- data/mlx/benchmarks/cpp/time_utils.h +39 -0
- data/mlx/benchmarks/numpy/single_ops.py +39 -0
- data/mlx/benchmarks/numpy/time_utils.py +20 -0
- data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
- data/mlx/benchmarks/python/comparative/README.md +15 -0
- data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- data/mlx/benchmarks/python/comparative/compare.py +284 -0
- data/mlx/benchmarks/python/compile_bench.py +107 -0
- data/mlx/benchmarks/python/conv1d_bench.py +123 -0
- data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- data/mlx/benchmarks/python/conv_bench.py +135 -0
- data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- data/mlx/benchmarks/python/distributed_bench.py +66 -0
- data/mlx/benchmarks/python/einsum_bench.py +84 -0
- data/mlx/benchmarks/python/fft_bench.py +118 -0
- data/mlx/benchmarks/python/gather_bench.py +52 -0
- data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- data/mlx/benchmarks/python/hadamard_bench.py +70 -0
- data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
- data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
- data/mlx/benchmarks/python/masked_scatter.py +212 -0
- data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
- data/mlx/benchmarks/python/rope_bench.py +35 -0
- data/mlx/benchmarks/python/scatter_bench.py +96 -0
- data/mlx/benchmarks/python/sdpa_bench.py +223 -0
- data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
- data/mlx/benchmarks/python/single_ops.py +132 -0
- data/mlx/benchmarks/python/synchronize_bench.py +55 -0
- data/mlx/benchmarks/python/time_utils.py +38 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/docs/.clang-format +2 -0
- data/mlx/docs/.gitignore +3 -0
- data/mlx/docs/.nojekyll +0 -0
- data/mlx/docs/Doxyfile +51 -0
- data/mlx/docs/Makefile +18 -0
- data/mlx/docs/README.md +54 -0
- data/mlx/docs/index.html +1 -0
- data/mlx/docs/requirements.txt +5 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
- data/mlx/docs/src/_static/mlx_logo.png +0 -0
- data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
- data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
- data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
- data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
- data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
- data/mlx/docs/src/_templates/module-base-class.rst +33 -0
- data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
- data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
- data/mlx/docs/src/conf.py +99 -0
- data/mlx/docs/src/cpp/ops.rst +7 -0
- data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
- data/mlx/docs/src/dev/extensions.rst +811 -0
- data/mlx/docs/src/dev/metal_debugger.rst +68 -0
- data/mlx/docs/src/dev/metal_logging.rst +40 -0
- data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
- data/mlx/docs/src/examples/data_parallelism.rst +91 -0
- data/mlx/docs/src/examples/linear_regression.rst +77 -0
- data/mlx/docs/src/examples/llama-inference.rst +382 -0
- data/mlx/docs/src/examples/mlp.rst +134 -0
- data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
- data/mlx/docs/src/index.rst +96 -0
- data/mlx/docs/src/install.rst +340 -0
- data/mlx/docs/src/python/array.rst +65 -0
- data/mlx/docs/src/python/cuda.rst +9 -0
- data/mlx/docs/src/python/data_types.rst +78 -0
- data/mlx/docs/src/python/devices_and_streams.rst +21 -0
- data/mlx/docs/src/python/distributed.rst +22 -0
- data/mlx/docs/src/python/export.rst +14 -0
- data/mlx/docs/src/python/fast.rst +16 -0
- data/mlx/docs/src/python/fft.rst +24 -0
- data/mlx/docs/src/python/linalg.rst +27 -0
- data/mlx/docs/src/python/memory_management.rst +16 -0
- data/mlx/docs/src/python/metal.rst +12 -0
- data/mlx/docs/src/python/nn/distributed.rst +30 -0
- data/mlx/docs/src/python/nn/functions.rst +40 -0
- data/mlx/docs/src/python/nn/init.rst +45 -0
- data/mlx/docs/src/python/nn/layers.rst +74 -0
- data/mlx/docs/src/python/nn/losses.rst +25 -0
- data/mlx/docs/src/python/nn/module.rst +38 -0
- data/mlx/docs/src/python/nn.rst +186 -0
- data/mlx/docs/src/python/ops.rst +184 -0
- data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
- data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
- data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
- data/mlx/docs/src/python/optimizers.rst +78 -0
- data/mlx/docs/src/python/random.rst +48 -0
- data/mlx/docs/src/python/transforms.rst +22 -0
- data/mlx/docs/src/python/tree_utils.rst +23 -0
- data/mlx/docs/src/usage/compile.rst +516 -0
- data/mlx/docs/src/usage/distributed.rst +572 -0
- data/mlx/docs/src/usage/export.rst +288 -0
- data/mlx/docs/src/usage/function_transforms.rst +191 -0
- data/mlx/docs/src/usage/indexing.rst +194 -0
- data/mlx/docs/src/usage/launching_distributed.rst +234 -0
- data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
- data/mlx/docs/src/usage/numpy.rst +124 -0
- data/mlx/docs/src/usage/quick_start.rst +67 -0
- data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
- data/mlx/docs/src/usage/unified_memory.rst +78 -0
- data/mlx/docs/src/usage/using_streams.rst +18 -0
- data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
- data/mlx/examples/cmake_project/README.md +26 -0
- data/mlx/examples/cmake_project/example.cpp +14 -0
- data/mlx/examples/cpp/CMakeLists.txt +12 -0
- data/mlx/examples/cpp/distributed.cpp +22 -0
- data/mlx/examples/cpp/linear_regression.cpp +54 -0
- data/mlx/examples/cpp/logistic_regression.cpp +54 -0
- data/mlx/examples/cpp/metal_capture.cpp +31 -0
- data/mlx/examples/cpp/timer.h +20 -0
- data/mlx/examples/cpp/tutorial.cpp +99 -0
- data/mlx/examples/export/CMakeLists.txt +22 -0
- data/mlx/examples/export/README.md +49 -0
- data/mlx/examples/export/eval_mlp.cpp +25 -0
- data/mlx/examples/export/eval_mlp.py +52 -0
- data/mlx/examples/export/train_mlp.cpp +35 -0
- data/mlx/examples/export/train_mlp.py +76 -0
- data/mlx/examples/extensions/CMakeLists.txt +78 -0
- data/mlx/examples/extensions/README.md +24 -0
- data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
- data/mlx/examples/extensions/axpby/axpby.h +90 -0
- data/mlx/examples/extensions/axpby/axpby.metal +47 -0
- data/mlx/examples/extensions/bindings.cpp +39 -0
- data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
- data/mlx/examples/extensions/pyproject.toml +8 -0
- data/mlx/examples/extensions/requirements.txt +4 -0
- data/mlx/examples/extensions/setup.py +18 -0
- data/mlx/examples/extensions/test.py +12 -0
- data/mlx/examples/python/linear_regression.py +46 -0
- data/mlx/examples/python/logistic_regression.py +49 -0
- data/mlx/examples/python/qqmm.py +117 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- data/mlx/pyproject.toml +7 -0
- data/mlx/python/mlx/__main__.py +27 -0
- data/mlx/python/mlx/_distributed_utils/common.py +135 -0
- data/mlx/python/mlx/_distributed_utils/config.py +631 -0
- data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
- data/mlx/python/mlx/_reprlib_fix.py +16 -0
- data/mlx/python/mlx/_stub_patterns.txt +36 -0
- data/mlx/python/mlx/extension.py +88 -0
- data/mlx/python/mlx/nn/__init__.py +5 -0
- data/mlx/python/mlx/nn/init.py +441 -0
- data/mlx/python/mlx/nn/layers/__init__.py +105 -0
- data/mlx/python/mlx/nn/layers/activations.py +661 -0
- data/mlx/python/mlx/nn/layers/base.py +675 -0
- data/mlx/python/mlx/nn/layers/containers.py +24 -0
- data/mlx/python/mlx/nn/layers/convolution.py +232 -0
- data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
- data/mlx/python/mlx/nn/layers/distributed.py +601 -0
- data/mlx/python/mlx/nn/layers/dropout.py +137 -0
- data/mlx/python/mlx/nn/layers/embedding.py +53 -0
- data/mlx/python/mlx/nn/layers/linear.py +180 -0
- data/mlx/python/mlx/nn/layers/normalization.py +363 -0
- data/mlx/python/mlx/nn/layers/pooling.py +398 -0
- data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
- data/mlx/python/mlx/nn/layers/quantized.py +426 -0
- data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
- data/mlx/python/mlx/nn/layers/transformer.py +354 -0
- data/mlx/python/mlx/nn/layers/upsample.py +277 -0
- data/mlx/python/mlx/nn/losses.py +610 -0
- data/mlx/python/mlx/nn/utils.py +165 -0
- data/mlx/python/mlx/optimizers/__init__.py +4 -0
- data/mlx/python/mlx/optimizers/optimizers.py +976 -0
- data/mlx/python/mlx/optimizers/schedulers.py +158 -0
- data/mlx/python/mlx/py.typed +1 -0
- data/mlx/python/mlx/utils.py +325 -0
- data/mlx/python/src/CMakeLists.txt +96 -0
- data/mlx/python/src/array.cpp +1525 -0
- data/mlx/python/src/buffer.h +124 -0
- data/mlx/python/src/constants.cpp +15 -0
- data/mlx/python/src/convert.cpp +504 -0
- data/mlx/python/src/convert.h +50 -0
- data/mlx/python/src/cuda.cpp +19 -0
- data/mlx/python/src/device.cpp +98 -0
- data/mlx/python/src/distributed.cpp +352 -0
- data/mlx/python/src/export.cpp +356 -0
- data/mlx/python/src/fast.cpp +627 -0
- data/mlx/python/src/fft.cpp +514 -0
- data/mlx/python/src/indexing.cpp +1016 -0
- data/mlx/python/src/indexing.h +41 -0
- data/mlx/python/src/linalg.cpp +663 -0
- data/mlx/python/src/load.cpp +531 -0
- data/mlx/python/src/load.h +51 -0
- data/mlx/python/src/memory.cpp +125 -0
- data/mlx/python/src/metal.cpp +98 -0
- data/mlx/python/src/mlx.cpp +51 -0
- data/mlx/python/src/mlx_func.cpp +116 -0
- data/mlx/python/src/mlx_func.h +31 -0
- data/mlx/python/src/ops.cpp +5545 -0
- data/mlx/python/src/random.cpp +516 -0
- data/mlx/python/src/small_vector.h +76 -0
- data/mlx/python/src/stream.cpp +147 -0
- data/mlx/python/src/transforms.cpp +1542 -0
- data/mlx/python/src/trees.cpp +311 -0
- data/mlx/python/src/trees.h +62 -0
- data/mlx/python/src/utils.cpp +98 -0
- data/mlx/python/src/utils.h +78 -0
- data/mlx/python/tests/__main__.py +5 -0
- data/mlx/python/tests/cuda_skip.py +62 -0
- data/mlx/python/tests/mlx_distributed_tests.py +314 -0
- data/mlx/python/tests/mlx_tests.py +116 -0
- data/mlx/python/tests/mpi_test_distributed.py +142 -0
- data/mlx/python/tests/nccl_test_distributed.py +52 -0
- data/mlx/python/tests/ring_test_distributed.py +131 -0
- data/mlx/python/tests/test_array.py +2139 -0
- data/mlx/python/tests/test_autograd.py +880 -0
- data/mlx/python/tests/test_bf16.py +196 -0
- data/mlx/python/tests/test_blas.py +1429 -0
- data/mlx/python/tests/test_compile.py +1277 -0
- data/mlx/python/tests/test_constants.py +41 -0
- data/mlx/python/tests/test_conv.py +1198 -0
- data/mlx/python/tests/test_conv_transpose.py +810 -0
- data/mlx/python/tests/test_device.py +150 -0
- data/mlx/python/tests/test_double.py +306 -0
- data/mlx/python/tests/test_einsum.py +363 -0
- data/mlx/python/tests/test_eval.py +200 -0
- data/mlx/python/tests/test_export_import.py +614 -0
- data/mlx/python/tests/test_fast.py +923 -0
- data/mlx/python/tests/test_fast_sdpa.py +647 -0
- data/mlx/python/tests/test_fft.py +323 -0
- data/mlx/python/tests/test_graph.py +37 -0
- data/mlx/python/tests/test_init.py +139 -0
- data/mlx/python/tests/test_linalg.py +621 -0
- data/mlx/python/tests/test_load.py +447 -0
- data/mlx/python/tests/test_losses.py +427 -0
- data/mlx/python/tests/test_memory.py +77 -0
- data/mlx/python/tests/test_nn.py +1986 -0
- data/mlx/python/tests/test_ops.py +3261 -0
- data/mlx/python/tests/test_optimizers.py +584 -0
- data/mlx/python/tests/test_quantized.py +1160 -0
- data/mlx/python/tests/test_random.py +392 -0
- data/mlx/python/tests/test_reduce.py +223 -0
- data/mlx/python/tests/test_tree.py +96 -0
- data/mlx/python/tests/test_upsample.py +100 -0
- data/mlx/python/tests/test_vmap.py +860 -0
- data/mlx/setup.py +315 -0
- data/mlx/tests/CMakeLists.txt +44 -0
- data/mlx/tests/allocator_tests.cpp +41 -0
- data/mlx/tests/arg_reduce_tests.cpp +204 -0
- data/mlx/tests/array_tests.cpp +663 -0
- data/mlx/tests/autograd_tests.cpp +1399 -0
- data/mlx/tests/blas_tests.cpp +110 -0
- data/mlx/tests/compile_tests.cpp +818 -0
- data/mlx/tests/creations_tests.cpp +239 -0
- data/mlx/tests/custom_vjp_tests.cpp +55 -0
- data/mlx/tests/device_tests.cpp +35 -0
- data/mlx/tests/einsum_tests.cpp +85 -0
- data/mlx/tests/eval_tests.cpp +93 -0
- data/mlx/tests/export_import_tests.cpp +164 -0
- data/mlx/tests/fft_tests.cpp +366 -0
- data/mlx/tests/gpu_tests.cpp +523 -0
- data/mlx/tests/linalg_tests.cpp +639 -0
- data/mlx/tests/load_tests.cpp +270 -0
- data/mlx/tests/ops_tests.cpp +4159 -0
- data/mlx/tests/random_tests.cpp +716 -0
- data/mlx/tests/scheduler_tests.cpp +121 -0
- data/mlx/tests/tests.cpp +26 -0
- data/mlx/tests/utils_tests.cpp +67 -0
- data/mlx/tests/vmap_tests.cpp +547 -0
- metadata +958 -0
data/mlx/docs/src/examples/tensor_parallelism.rst

@@ -0,0 +1,239 @@

.. _tensor_parallelism:

Tensor Parallelism
==================

In this example, we will explore how tensor parallelism (TP) works in MLX. We
will start with an overview of the distributed layers in ``mlx.nn`` and then
show how to apply tensor parallelism to Llama-style transformer models.

Sharded Layers
--------------

:class:`AllToShardedLinear <mlx.nn.AllToShardedLinear>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This layer replicates a common input and shards the weight matrix along the
output dimension across all devices in the :class:`mlx.core.distributed.Group`.
The layer produces a sharded output.

For example, consider an :class:`mlx.nn.AllToShardedLinear` layer with
``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
and a device group with 2 devices. The layer shards the weight matrix along the
output dimension across the two devices, where each device receives the full
input and computes a partial output.

.. raw:: html

    <div>
    <img src="../_static/tp_inference/all-to-sharded-linear.png" alt="column-wise tensor parallelism" style="width: 100%">
    </div>

This layer does not automatically gather all outputs from each device. This is
an intended and :ref:`useful design choice <useful_design_choices>`.
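As a minimal sketch of the shapes involved (the zero-filled input is a
placeholder and, as in the examples below, the layer is assumed to use the
globally initialized group):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init()   # 2 devices in this example

    x = mx.zeros((4, 2))            # full input, replicated on every device
    layer = nn.AllToShardedLinear(2, 2, bias=False)

    # The full (2, 2) weight is split along the output dimension, so each
    # device holds a (1, 2) slice and produces only its shard of the output.
    y_shard = layer(x)              # shape (4, 1) on each device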
:class:`QuantizedAllToShardedLinear <mlx.nn.QuantizedAllToShardedLinear>` is
the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. Similar to
:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
included in any gradient computation.


:class:`ShardedToAllLinear <mlx.nn.ShardedToAllLinear>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This layer expects inputs that are sharded along the feature dimension and
shards the weight matrix along the input dimension across all devices in the
:class:`mlx.core.distributed.Group`. The layer automatically aggregates the
results using :class:`mlx.core.distributed.all_sum`, so all devices in the
group will have the same result.

For example, consider a :class:`mlx.nn.ShardedToAllLinear` layer with
``input_dims=2`` and ``output_dims=2``, a batched input of shape ``(4, 2)``,
and a device group with 2 devices. The layer shards the weight matrix along the
input dimension across the two devices. Each device computes a ``(4, 2)``
partial output, which is then aggregated with the outputs of all other devices
to get the layer output.

.. raw:: html

    <div>
    <img src="../_static/tp_inference/sharded-to-all-linear.png" alt="row-wise tensor parallelism" style="width: 100%">
    </div>

This layer does not automatically shard the inputs along the feature dimension
for you. It is necessary to create a "partial" input structure to feed into the
layer. This is an intended and :ref:`useful design choice
<useful_design_choices>`.
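As a rough sketch of the math performed on each device (illustrative arrays,
not the actual implementation): multiply the local input shard by the local
weight shard, then sum the partial results across the group:

.. code-block:: python

    import mlx.core as mx

    # Illustrative shards held by one of the 2 devices:
    x_shard = mx.zeros((4, 1))   # this device's slice of the (4, 2) input
    w_shard = mx.zeros((2, 1))   # this device's slice of the (2, 2) weight

    partial = x_shard @ w_shard.T           # (4, 2) partial output on this device
    out = mx.distributed.all_sum(partial)   # (4, 2), identical on every device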
:class:`QuantizedShardedToAllLinear <mlx.nn.QuantizedShardedToAllLinear>` is
the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. Similar to
:class:`mlx.nn.QuantizedLinear`, its parameters are frozen and will not be
included in any gradient computation.


Shard Utility Functions
-----------------------

:func:`shard_linear <mlx.nn.layers.distributed.shard_linear>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Converts a regular linear layer into a tensor parallel layer that distributes
computation across multiple devices. Takes an existing :class:`mlx.nn.Linear`
or :class:`mlx.nn.QuantizedLinear` layer and returns a new distributed layer
(either :class:`mlx.nn.AllToShardedLinear` or
:class:`mlx.nn.ShardedToAllLinear`, depending on the sharding type). The
original layer is not modified.
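For example, a minimal usage sketch (the layer size and variable names are
illustrative):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init()
    linear = nn.Linear(512, 512)

    # Returns a new AllToShardedLinear holding only this rank's weight shard;
    # the original `linear` is left unchanged.
    sharded = nn.layers.distributed.shard_linear(linear, "all-to-sharded", group=group)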
:func:`shard_inplace <mlx.nn.layers.distributed.shard_inplace>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Splits the parameters of an existing layer across multiple devices by modifying
the layer in-place. Unlike :func:`shard_linear
<mlx.nn.layers.distributed.shard_linear>`, this function does not create a new
layer or add distributed communication. The layer itself must handle
distributed communication if needed.
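A hedged sketch of how this might be used, assuming ``shard_inplace`` accepts
the same ``(layer, sharding, group)`` arguments as ``shard_linear``:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init()
    linear = nn.Linear(512, 512)

    # Assumed to mirror shard_linear's argument pattern:
    nn.layers.distributed.shard_inplace(linear, "all-to-sharded", group=group)

    # `linear` now holds only this rank's parameter shard; any all_sum or
    # gather needed at its boundaries must be added by the surrounding module.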
.. _useful_design_choices:

Useful Design Choices
---------------------

The design choices above regarding when operations are done automatically are
intentional and make model training and inference easier.

All-to-sharded and sharded-to-all layers naturally go together because the
output of the former layer is exactly the input needed for the latter. This
removes the need for an intermediate gather step between the layers, reducing
communication overhead.

This is why :class:`mlx.nn.AllToShardedLinear` does not aggregate results
automatically and why :class:`mlx.nn.ShardedToAllLinear` does not shard inputs
automatically: the two layers can be placed back to back and work together
easily.

We can demonstrate this through a simple model using our two types of
distributed layers.

.. code-block:: python

    x = ...  # some (4, 2) model input: batch size 4, feature size 2

    l1 = nn.AllToShardedLinear(2, 2, bias=False)  # initialize the layer
    l1_out = l1(x)  # (4, 1) output

    l2 = nn.ShardedToAllLinear(2, 2, bias=False)
    l2_out = l2(l1_out)  # (4, 2) output

.. raw:: html

    <div>
    <img src="../_static/tp_inference/column-row-tp.png" alt="two layer tensor parallelism" style="width: 100%">
    <p style="font-size: 0.85em; margin-top: 0.5em;"><small>A visualization of the simple MLX model using all-to-sharded then sharded-to-all tensor parallelism across 2 devices.</small></p>
    </div>


LLM Inference with Tensor Parallelism
-------------------------------------

We can apply these TP techniques to LLMs in order to enable inference for much
larger models by sharding parameters from huge layers across multiple devices.

To demonstrate this, let's apply TP to the Transformer block of our :doc:`Llama
Inference <llama-inference>` example. In this example, we will use the same
inference script as the Llama Inference example, which can be found in
`mlx-examples`_.

Our first edit is to initialize the distributed communication group and get the
current process rank:

.. code-block:: python

    world = mx.distributed.init()
    rank = world.rank()

Next, let's look at the current architecture of the transformer block and see
how we can apply tensor parallelism:

.. raw:: html

    <div>
    <img src="../_static/tp_inference/llama-transformer.png" alt="llama transformer example" style="width: 100%">
    </div>


This architecture has two natural places where tensor parallelism can be
applied: the attention block and the FFN block. Both follow the same pattern:
multiple parallel linear layers operating on the same input, followed by a
single output linear layer. In the attention block, the Q, K, and V projections
are sharded along the output dimension (all-to-sharded), and the output
projection is sharded along the input dimension (sharded-to-all). Similarly, in
the FFN block, the gate and up projections become all-to-sharded layers, and
the down projection becomes a sharded-to-all layer.

The intermediate operations between the linear layers (RoPE, softmax, scaled
dot-product attention in the attention block, and element-wise multiplication
in the FFN block) do not impede the use of our TP paradigm. These operations
are either:

- **Element-wise operations** (RoPE, element-wise multiplication): These
  operate independently on each element or position, preserving the sharding
  pattern without requiring cross-device communication.

- **Operations on non-sharded dimensions** (softmax, scaled dot-product
  attention): These operate along dimensions that are not sharded (such as the
  sequence length or head dimensions), so they can be computed independently on
  each device. The attention computation ``Q @ K^T`` and ``scores @ V`` work
  correctly with sharded Q, K, V tensors because the matrix multiplications are
  performed along the sharded feature dimension, and the results remain
  properly sharded for the subsequent sharded-to-all layer (see the shape
  sketch after this list).
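As a rough shape check (the head count and dimensions below are illustrative,
not taken from the example model):

.. code-block:: python

    # With 2 devices, 8 total heads, head_dim 64, and sequence length L:
    #   heads per device:        8 // 2 = 4
    #   Q, K, V on each device:  (batch, 4, L, 64)  -- only this device's heads
    #   Q @ K^T on each device:  (batch, 4, L, L)   -- softmax is applied per head
    #   scores @ V:              (batch, 4, L, 64)  -- still sharded by head
    # Reshaped to (batch, L, 4 * 64), this is exactly the "partial" input that
    # the sharded-to-all output projection expects.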

To implement sharding in our Llama inference, we use :func:`shard_linear
<mlx.nn.layers.distributed.shard_linear>` to get sharded linear layers with
distributed communication. This is easier than using :func:`shard_inplace
<mlx.nn.layers.distributed.shard_inplace>` and implementing the steps manually
in the :code:`__call__` function.

The following code shows how to shard the Attention block. The Q, K, and V
projection layers are converted to all-to-sharded layers, while the output
projection is converted to a sharded-to-all layer. The number of heads is also
adjusted to account for the sharding:

.. code-block:: python

    # ... in Attention class
    def shard(self, group: mx.distributed.Group):
        self.n_heads = self.n_heads // group.size()
        self.n_kv_heads = self.n_kv_heads // group.size()

        self.wq = nn.layers.distributed.shard_linear(self.wq, "all-to-sharded", group=group)
        self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group)
        self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group)
        self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group)

Similarly, the FeedForward block is sharded by converting the gate (w1) and up
(w3) projections to all-to-sharded layers, and the down projection (w2) to
a sharded-to-all layer:

.. code-block:: python

    # ... in FeedForward class
    def shard(self, group: mx.distributed.Group):
        self.w1 = nn.layers.distributed.shard_linear(self.w1, "all-to-sharded", group=group)
        self.w2 = nn.layers.distributed.shard_linear(self.w2, "sharded-to-all", group=group)
        self.w3 = nn.layers.distributed.shard_linear(self.w3, "all-to-sharded", group=group)

Finally, in our :code:`load_model` function, we need to apply our sharding
functions to all transformer layers when using multiple devices:

.. code-block:: python

    # ... in load_model function
    if world.size() > 1:
        # convert Linear layers in Transformer/FFN to appropriate Sharded Layers
        for layer in model.layers:
            layer.attention.shard(group=world)
            layer.feed_forward.shard(group=world)

This allows us to use the Llama inference script as usual when running
:code:`python llama.py`, but now we can also run it across two (or more)
devices via :code:`mlx.launch -n 2 llama.py`.

.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama
@@ -0,0 +1,96 @@

MLX
===

MLX is a NumPy-like array framework designed for efficient and flexible machine
learning on Apple silicon, brought to you by Apple machine learning research.

The Python API closely follows NumPy with a few exceptions. MLX also has a
fully featured C++ API which closely follows the Python API.

The main differences between MLX and NumPy are:

- **Composable function transformations**: MLX has composable function
  transformations for automatic differentiation, automatic vectorization,
  and computation graph optimization (a short sketch follows this list).
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
  materialized when needed.
- **Multi-device**: Operations can run on any of the supported devices (CPU,
  GPU, ...).
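
As a brief, minimal sketch of the first two points (the values are just an
illustration):

.. code-block:: python

    import mlx.core as mx

    # mx.grad is a composable transformation: it returns a new function.
    f = lambda x: (x**2).sum()
    df = mx.grad(f)

    x = mx.array([1.0, 2.0, 3.0])
    y = df(x)   # builds a lazy computation graph; nothing is computed yet
    mx.eval(y)  # materializes the result
    print(y)    # array([2, 4, 6], dtype=float32)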

The design of MLX is inspired by frameworks like `PyTorch
<https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
`ArrayFire <https://arrayfire.org/>`_. A notable difference between MLX and
these frameworks is the *unified memory model*. Arrays in MLX live in shared
memory. Operations on MLX arrays can be performed on any of the supported
device types without performing data copies. Currently supported device types
are the CPU and GPU.
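
For example (a minimal sketch of the idea, not a full tour of stream handling),
the same arrays can be consumed by operations on either device without any
copies:

.. code-block:: python

    import mlx.core as mx

    a = mx.random.normal((256, 256))
    b = mx.random.normal((256, 256))

    c = mx.add(a, b, stream=mx.cpu)  # runs on the CPU
    d = mx.add(a, b, stream=mx.gpu)  # runs on the GPU
    mx.eval(c, d)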

.. toctree::
    :caption: Install
    :maxdepth: 1

    install

.. toctree::
    :caption: Usage
    :maxdepth: 1

    usage/quick_start
    usage/lazy_evaluation
    usage/unified_memory
    usage/indexing
    usage/saving_and_loading
    usage/function_transforms
    usage/compile
    usage/numpy
    usage/distributed
    usage/using_streams
    usage/export

.. toctree::
    :caption: Examples
    :maxdepth: 1

    examples/linear_regression
    examples/mlp
    examples/llama-inference
    examples/data_parallelism
    examples/tensor_parallelism

.. toctree::
    :caption: Python API Reference
    :maxdepth: 1

    python/array
    python/data_types
    python/devices_and_streams
    python/export
    python/ops
    python/random
    python/transforms
    python/fast
    python/fft
    python/linalg
    python/metal
    python/cuda
    python/memory_management
    python/nn
    python/optimizers
    python/distributed
    python/tree_utils

.. toctree::
    :caption: C++ API Reference
    :maxdepth: 1

    cpp/ops

.. toctree::
    :caption: Further Reading
    :maxdepth: 1

    dev/extensions
    dev/metal_debugger
    dev/metal_logging
    dev/custom_metal_kernels
    dev/mlx_in_cpp
@@ -0,0 +1,340 @@

.. _build_and_install:

Build and Install
=================

Python Installation
-------------------

MLX is available on PyPI. All you have to do to use MLX with your own Apple
silicon computer is

.. code-block:: shell

    pip install mlx

To install from PyPI your system must meet the following requirements:

- Using an M series chip (Apple silicon)
- Using a native Python >= 3.10
- macOS >= 14.0

.. note::
    MLX is only available on devices running macOS >= 14.0.

CUDA
^^^^

MLX has a CUDA backend which you can install with:

.. code-block:: shell

    pip install mlx[cuda12]

To install the CUDA package from PyPI your system must meet the following
requirements:

- Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10

For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.

CPU-only (Linux)
^^^^^^^^^^^^^^^^

For a CPU-only version of MLX that runs on Linux use:

.. code-block:: shell

    pip install mlx[cpu]

To install the CPU-only package from PyPI your system must meet the following
requirements:

- Linux distribution with glibc >= 2.35
- Python >= 3.10

Troubleshooting
^^^^^^^^^^^^^^^

*My OS and Python versions are in the required range but pip still does not find
a matching distribution.*

You are probably using a non-native Python. The output of

.. code-block:: shell

    python -c "import platform; print(platform.processor())"

should be ``arm``. If it is ``i386`` (and you have an M series machine) then
you are using a non-native Python. Switch your Python to a native Python. A
good way to do this is with `Conda <https://stackoverflow.com/q/65415996>`_.

Build from source
-----------------

Build Requirements
^^^^^^^^^^^^^^^^^^

- A C++ compiler with C++20 support (e.g. Clang >= 15.0)
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0

.. note::
    Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
    the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section <build shell>` below.

Python API
^^^^^^^^^^

.. _python install:

To build and install the MLX Python library from source, first clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:

.. code-block:: shell

    git clone git@github.com:ml-explore/mlx.git mlx && cd mlx

Then simply build and install MLX using pip:

.. code-block:: shell

    pip install .

For development, install the package with development dependencies and use an
editable install:

.. code-block:: shell

    pip install -e ".[dev]"

Once the development dependencies are installed, you can build faster with:

.. code-block:: shell

    python setup.py build_ext --inplace

Run the tests with:

.. code-block:: shell

    python -m unittest discover python/tests

C++ API
^^^^^^^

.. _cpp install:

Currently, MLX must be built and installed from source.

Similarly to the Python library, to build and install the MLX C++ library start
by cloning MLX from `its GitHub repo <https://github.com/ml-explore/mlx>`_:

.. code-block:: shell

    git clone git@github.com:ml-explore/mlx.git mlx && cd mlx

Create a build directory and run CMake and make:

.. code-block:: shell

    mkdir -p build && cd build
    cmake .. && make -j

Run tests with:

.. code-block:: shell

    make test

Install with:

.. code-block:: shell

    make install

Note that the built ``mlx.metallib`` file should either be in the same
directory as the executable statically linked to ``libmlx.a``, or the
preprocessor constant ``METAL_PATH`` should be defined at build time to point
to the built Metal library.

.. list-table:: Build Options
    :widths: 25 8
    :header-rows: 1

    * - Option
      - Default
    * - MLX_BUILD_TESTS
      - ON
    * - MLX_BUILD_EXAMPLES
      - OFF
    * - MLX_BUILD_BENCHMARKS
      - OFF
    * - MLX_BUILD_METAL
      - ON
    * - MLX_BUILD_CPU
      - ON
    * - MLX_BUILD_PYTHON_BINDINGS
      - OFF
    * - MLX_METAL_DEBUG
      - OFF
    * - MLX_BUILD_SAFETENSORS
      - ON
    * - MLX_BUILD_GGUF
      - ON
    * - MLX_METAL_JIT
      - OFF

.. note::

    If you have multiple Xcode installations and wish to use
    a specific one while building, you can do so by setting the
    following environment variable before building:

    .. code-block:: shell

        export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"

    Further, you can use the following command to find out which
    macOS SDK will be used:

    .. code-block:: shell

        xcrun -sdk macosx --show-sdk-version

Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~

To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.

The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:

.. code-block:: shell

    cmake .. \
      -DCMAKE_BUILD_TYPE=MinSizeRel \
      -DBUILD_SHARED_LIBS=ON \
      -DMLX_BUILD_CPU=OFF \
      -DMLX_BUILD_SAFETENSORS=OFF \
      -DMLX_BUILD_GGUF=OFF \
      -DMLX_METAL_JIT=ON

The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library, which
contains pre-built GPU kernels. It substantially reduces the size of the Metal
library by compiling kernels at run time the first time they are used in MLX on
a given machine. Note that run-time compilation incurs a cold-start cost which
can be anywhere from a few hundred milliseconds to a few seconds depending on
the application. Once a kernel is compiled, it will be cached by the system.
The Metal kernel cache persists across reboots.

Linux
^^^^^

To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:

.. code-block:: shell

    apt-get update -y
    apt-get install libblas-dev liblapack-dev liblapacke-dev -y

From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.

CUDA
^^^^

To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:

.. code-block:: shell

    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
    dpkg -i cuda-keyring_1.1-1_all.deb
    apt-get update -y
    apt-get -y install cuda-toolkit-12-9
    apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y

When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:

.. code-block:: shell

    CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"

To build the C++ package run:

.. code-block:: shell

    mkdir -p build && cd build
    cmake .. -DMLX_BUILD_CUDA=ON && make -j

Troubleshooting
^^^^^^^^^^^^^^^

Metal not found
~~~~~~~~~~~~~~~

You see the following error when you try to build:

.. code-block:: shell

    error: unable to find utility "metal", not a developer tool or in PATH

To fix this, first make sure you have Xcode installed:

.. code-block:: shell

    xcode-select --install

Then set the active developer directory:

.. code-block:: shell

    sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer

x86 Shell
~~~~~~~~~

.. _build shell:

If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.

To fix this, find the application in Finder (``/Applications`` for iTerm,
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
terminal.

Verify that the terminal is now running natively with the following command:

.. code-block:: shell

    $ uname -p
    arm

Also check that cmake is using the correct architecture:

.. code-block:: shell

    $ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR
    CMAKE_HOST_SYSTEM_PROCESSOR "arm64"

If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported.",
wipe your build cache with ``rm -rf build/`` and try again.