mlx 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mlx might be problematic. Click here for more details.
- checksums.yaml +7 -0
- data/ext/mlx/CMakeLists.txt +7 -0
- data/ext/mlx/Makefile +273 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/mkmf.log +44 -0
- data/ext/mlx/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
- data/ext/mlx/native.cpp +8027 -0
- data/ext/mlx/native.o +0 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version +1 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/.clang-format +87 -0
- data/mlx/.git +1 -0
- data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
- data/mlx/.github/actions/build-docs/action.yml +38 -0
- data/mlx/.github/actions/build-linux/action.yml +38 -0
- data/mlx/.github/actions/build-linux-release/action.yml +42 -0
- data/mlx/.github/actions/build-macos/action.yml +80 -0
- data/mlx/.github/actions/build-macos-release/action.yml +36 -0
- data/mlx/.github/actions/build-windows/action.yml +26 -0
- data/mlx/.github/actions/setup-linux/action.yml +93 -0
- data/mlx/.github/actions/setup-macos/action.yml +24 -0
- data/mlx/.github/actions/setup-windows/action.yml +42 -0
- data/mlx/.github/actions/test-linux/action.yml +69 -0
- data/mlx/.github/actions/test-windows/action.yml +20 -0
- data/mlx/.github/dependabot.yml +6 -0
- data/mlx/.github/pull_request_template.md +12 -0
- data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
- data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
- data/mlx/.github/workflows/build_and_test.yml +152 -0
- data/mlx/.github/workflows/documentation.yml +28 -0
- data/mlx/.github/workflows/nightly.yml +104 -0
- data/mlx/.github/workflows/release.yml +256 -0
- data/mlx/.gitignore +81 -0
- data/mlx/.pre-commit-config.yaml +27 -0
- data/mlx/ACKNOWLEDGMENTS.md +268 -0
- data/mlx/CITATION.cff +24 -0
- data/mlx/CMakeLists.txt +437 -0
- data/mlx/CODE_OF_CONDUCT.md +132 -0
- data/mlx/CONTRIBUTING.md +38 -0
- data/mlx/LICENSE +21 -0
- data/mlx/MANIFEST.in +6 -0
- data/mlx/README.md +121 -0
- data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- data/mlx/benchmarks/cpp/autograd.cpp +39 -0
- data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- data/mlx/benchmarks/cpp/time_utils.h +39 -0
- data/mlx/benchmarks/numpy/single_ops.py +39 -0
- data/mlx/benchmarks/numpy/time_utils.py +20 -0
- data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
- data/mlx/benchmarks/python/comparative/README.md +15 -0
- data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- data/mlx/benchmarks/python/comparative/compare.py +284 -0
- data/mlx/benchmarks/python/compile_bench.py +107 -0
- data/mlx/benchmarks/python/conv1d_bench.py +123 -0
- data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- data/mlx/benchmarks/python/conv_bench.py +135 -0
- data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- data/mlx/benchmarks/python/distributed_bench.py +66 -0
- data/mlx/benchmarks/python/einsum_bench.py +84 -0
- data/mlx/benchmarks/python/fft_bench.py +118 -0
- data/mlx/benchmarks/python/gather_bench.py +52 -0
- data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- data/mlx/benchmarks/python/hadamard_bench.py +70 -0
- data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
- data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
- data/mlx/benchmarks/python/masked_scatter.py +212 -0
- data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
- data/mlx/benchmarks/python/rope_bench.py +35 -0
- data/mlx/benchmarks/python/scatter_bench.py +96 -0
- data/mlx/benchmarks/python/sdpa_bench.py +223 -0
- data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
- data/mlx/benchmarks/python/single_ops.py +132 -0
- data/mlx/benchmarks/python/synchronize_bench.py +55 -0
- data/mlx/benchmarks/python/time_utils.py +38 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/docs/.clang-format +2 -0
- data/mlx/docs/.gitignore +3 -0
- data/mlx/docs/.nojekyll +0 -0
- data/mlx/docs/Doxyfile +51 -0
- data/mlx/docs/Makefile +18 -0
- data/mlx/docs/README.md +54 -0
- data/mlx/docs/index.html +1 -0
- data/mlx/docs/requirements.txt +5 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
- data/mlx/docs/src/_static/mlx_logo.png +0 -0
- data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
- data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
- data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
- data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
- data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
- data/mlx/docs/src/_templates/module-base-class.rst +33 -0
- data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
- data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
- data/mlx/docs/src/conf.py +99 -0
- data/mlx/docs/src/cpp/ops.rst +7 -0
- data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
- data/mlx/docs/src/dev/extensions.rst +811 -0
- data/mlx/docs/src/dev/metal_debugger.rst +68 -0
- data/mlx/docs/src/dev/metal_logging.rst +40 -0
- data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
- data/mlx/docs/src/examples/data_parallelism.rst +91 -0
- data/mlx/docs/src/examples/linear_regression.rst +77 -0
- data/mlx/docs/src/examples/llama-inference.rst +382 -0
- data/mlx/docs/src/examples/mlp.rst +134 -0
- data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
- data/mlx/docs/src/index.rst +96 -0
- data/mlx/docs/src/install.rst +340 -0
- data/mlx/docs/src/python/array.rst +65 -0
- data/mlx/docs/src/python/cuda.rst +9 -0
- data/mlx/docs/src/python/data_types.rst +78 -0
- data/mlx/docs/src/python/devices_and_streams.rst +21 -0
- data/mlx/docs/src/python/distributed.rst +22 -0
- data/mlx/docs/src/python/export.rst +14 -0
- data/mlx/docs/src/python/fast.rst +16 -0
- data/mlx/docs/src/python/fft.rst +24 -0
- data/mlx/docs/src/python/linalg.rst +27 -0
- data/mlx/docs/src/python/memory_management.rst +16 -0
- data/mlx/docs/src/python/metal.rst +12 -0
- data/mlx/docs/src/python/nn/distributed.rst +30 -0
- data/mlx/docs/src/python/nn/functions.rst +40 -0
- data/mlx/docs/src/python/nn/init.rst +45 -0
- data/mlx/docs/src/python/nn/layers.rst +74 -0
- data/mlx/docs/src/python/nn/losses.rst +25 -0
- data/mlx/docs/src/python/nn/module.rst +38 -0
- data/mlx/docs/src/python/nn.rst +186 -0
- data/mlx/docs/src/python/ops.rst +184 -0
- data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
- data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
- data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
- data/mlx/docs/src/python/optimizers.rst +78 -0
- data/mlx/docs/src/python/random.rst +48 -0
- data/mlx/docs/src/python/transforms.rst +22 -0
- data/mlx/docs/src/python/tree_utils.rst +23 -0
- data/mlx/docs/src/usage/compile.rst +516 -0
- data/mlx/docs/src/usage/distributed.rst +572 -0
- data/mlx/docs/src/usage/export.rst +288 -0
- data/mlx/docs/src/usage/function_transforms.rst +191 -0
- data/mlx/docs/src/usage/indexing.rst +194 -0
- data/mlx/docs/src/usage/launching_distributed.rst +234 -0
- data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
- data/mlx/docs/src/usage/numpy.rst +124 -0
- data/mlx/docs/src/usage/quick_start.rst +67 -0
- data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
- data/mlx/docs/src/usage/unified_memory.rst +78 -0
- data/mlx/docs/src/usage/using_streams.rst +18 -0
- data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
- data/mlx/examples/cmake_project/README.md +26 -0
- data/mlx/examples/cmake_project/example.cpp +14 -0
- data/mlx/examples/cpp/CMakeLists.txt +12 -0
- data/mlx/examples/cpp/distributed.cpp +22 -0
- data/mlx/examples/cpp/linear_regression.cpp +54 -0
- data/mlx/examples/cpp/logistic_regression.cpp +54 -0
- data/mlx/examples/cpp/metal_capture.cpp +31 -0
- data/mlx/examples/cpp/timer.h +20 -0
- data/mlx/examples/cpp/tutorial.cpp +99 -0
- data/mlx/examples/export/CMakeLists.txt +22 -0
- data/mlx/examples/export/README.md +49 -0
- data/mlx/examples/export/eval_mlp.cpp +25 -0
- data/mlx/examples/export/eval_mlp.py +52 -0
- data/mlx/examples/export/train_mlp.cpp +35 -0
- data/mlx/examples/export/train_mlp.py +76 -0
- data/mlx/examples/extensions/CMakeLists.txt +78 -0
- data/mlx/examples/extensions/README.md +24 -0
- data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
- data/mlx/examples/extensions/axpby/axpby.h +90 -0
- data/mlx/examples/extensions/axpby/axpby.metal +47 -0
- data/mlx/examples/extensions/bindings.cpp +39 -0
- data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
- data/mlx/examples/extensions/pyproject.toml +8 -0
- data/mlx/examples/extensions/requirements.txt +4 -0
- data/mlx/examples/extensions/setup.py +18 -0
- data/mlx/examples/extensions/test.py +12 -0
- data/mlx/examples/python/linear_regression.py +46 -0
- data/mlx/examples/python/logistic_regression.py +49 -0
- data/mlx/examples/python/qqmm.py +117 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- data/mlx/pyproject.toml +7 -0
- data/mlx/python/mlx/__main__.py +27 -0
- data/mlx/python/mlx/_distributed_utils/common.py +135 -0
- data/mlx/python/mlx/_distributed_utils/config.py +631 -0
- data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
- data/mlx/python/mlx/_reprlib_fix.py +16 -0
- data/mlx/python/mlx/_stub_patterns.txt +36 -0
- data/mlx/python/mlx/extension.py +88 -0
- data/mlx/python/mlx/nn/__init__.py +5 -0
- data/mlx/python/mlx/nn/init.py +441 -0
- data/mlx/python/mlx/nn/layers/__init__.py +105 -0
- data/mlx/python/mlx/nn/layers/activations.py +661 -0
- data/mlx/python/mlx/nn/layers/base.py +675 -0
- data/mlx/python/mlx/nn/layers/containers.py +24 -0
- data/mlx/python/mlx/nn/layers/convolution.py +232 -0
- data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
- data/mlx/python/mlx/nn/layers/distributed.py +601 -0
- data/mlx/python/mlx/nn/layers/dropout.py +137 -0
- data/mlx/python/mlx/nn/layers/embedding.py +53 -0
- data/mlx/python/mlx/nn/layers/linear.py +180 -0
- data/mlx/python/mlx/nn/layers/normalization.py +363 -0
- data/mlx/python/mlx/nn/layers/pooling.py +398 -0
- data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
- data/mlx/python/mlx/nn/layers/quantized.py +426 -0
- data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
- data/mlx/python/mlx/nn/layers/transformer.py +354 -0
- data/mlx/python/mlx/nn/layers/upsample.py +277 -0
- data/mlx/python/mlx/nn/losses.py +610 -0
- data/mlx/python/mlx/nn/utils.py +165 -0
- data/mlx/python/mlx/optimizers/__init__.py +4 -0
- data/mlx/python/mlx/optimizers/optimizers.py +976 -0
- data/mlx/python/mlx/optimizers/schedulers.py +158 -0
- data/mlx/python/mlx/py.typed +1 -0
- data/mlx/python/mlx/utils.py +325 -0
- data/mlx/python/src/CMakeLists.txt +96 -0
- data/mlx/python/src/array.cpp +1525 -0
- data/mlx/python/src/buffer.h +124 -0
- data/mlx/python/src/constants.cpp +15 -0
- data/mlx/python/src/convert.cpp +504 -0
- data/mlx/python/src/convert.h +50 -0
- data/mlx/python/src/cuda.cpp +19 -0
- data/mlx/python/src/device.cpp +98 -0
- data/mlx/python/src/distributed.cpp +352 -0
- data/mlx/python/src/export.cpp +356 -0
- data/mlx/python/src/fast.cpp +627 -0
- data/mlx/python/src/fft.cpp +514 -0
- data/mlx/python/src/indexing.cpp +1016 -0
- data/mlx/python/src/indexing.h +41 -0
- data/mlx/python/src/linalg.cpp +663 -0
- data/mlx/python/src/load.cpp +531 -0
- data/mlx/python/src/load.h +51 -0
- data/mlx/python/src/memory.cpp +125 -0
- data/mlx/python/src/metal.cpp +98 -0
- data/mlx/python/src/mlx.cpp +51 -0
- data/mlx/python/src/mlx_func.cpp +116 -0
- data/mlx/python/src/mlx_func.h +31 -0
- data/mlx/python/src/ops.cpp +5545 -0
- data/mlx/python/src/random.cpp +516 -0
- data/mlx/python/src/small_vector.h +76 -0
- data/mlx/python/src/stream.cpp +147 -0
- data/mlx/python/src/transforms.cpp +1542 -0
- data/mlx/python/src/trees.cpp +311 -0
- data/mlx/python/src/trees.h +62 -0
- data/mlx/python/src/utils.cpp +98 -0
- data/mlx/python/src/utils.h +78 -0
- data/mlx/python/tests/__main__.py +5 -0
- data/mlx/python/tests/cuda_skip.py +62 -0
- data/mlx/python/tests/mlx_distributed_tests.py +314 -0
- data/mlx/python/tests/mlx_tests.py +116 -0
- data/mlx/python/tests/mpi_test_distributed.py +142 -0
- data/mlx/python/tests/nccl_test_distributed.py +52 -0
- data/mlx/python/tests/ring_test_distributed.py +131 -0
- data/mlx/python/tests/test_array.py +2139 -0
- data/mlx/python/tests/test_autograd.py +880 -0
- data/mlx/python/tests/test_bf16.py +196 -0
- data/mlx/python/tests/test_blas.py +1429 -0
- data/mlx/python/tests/test_compile.py +1277 -0
- data/mlx/python/tests/test_constants.py +41 -0
- data/mlx/python/tests/test_conv.py +1198 -0
- data/mlx/python/tests/test_conv_transpose.py +810 -0
- data/mlx/python/tests/test_device.py +150 -0
- data/mlx/python/tests/test_double.py +306 -0
- data/mlx/python/tests/test_einsum.py +363 -0
- data/mlx/python/tests/test_eval.py +200 -0
- data/mlx/python/tests/test_export_import.py +614 -0
- data/mlx/python/tests/test_fast.py +923 -0
- data/mlx/python/tests/test_fast_sdpa.py +647 -0
- data/mlx/python/tests/test_fft.py +323 -0
- data/mlx/python/tests/test_graph.py +37 -0
- data/mlx/python/tests/test_init.py +139 -0
- data/mlx/python/tests/test_linalg.py +621 -0
- data/mlx/python/tests/test_load.py +447 -0
- data/mlx/python/tests/test_losses.py +427 -0
- data/mlx/python/tests/test_memory.py +77 -0
- data/mlx/python/tests/test_nn.py +1986 -0
- data/mlx/python/tests/test_ops.py +3261 -0
- data/mlx/python/tests/test_optimizers.py +584 -0
- data/mlx/python/tests/test_quantized.py +1160 -0
- data/mlx/python/tests/test_random.py +392 -0
- data/mlx/python/tests/test_reduce.py +223 -0
- data/mlx/python/tests/test_tree.py +96 -0
- data/mlx/python/tests/test_upsample.py +100 -0
- data/mlx/python/tests/test_vmap.py +860 -0
- data/mlx/setup.py +315 -0
- data/mlx/tests/CMakeLists.txt +44 -0
- data/mlx/tests/allocator_tests.cpp +41 -0
- data/mlx/tests/arg_reduce_tests.cpp +204 -0
- data/mlx/tests/array_tests.cpp +663 -0
- data/mlx/tests/autograd_tests.cpp +1399 -0
- data/mlx/tests/blas_tests.cpp +110 -0
- data/mlx/tests/compile_tests.cpp +818 -0
- data/mlx/tests/creations_tests.cpp +239 -0
- data/mlx/tests/custom_vjp_tests.cpp +55 -0
- data/mlx/tests/device_tests.cpp +35 -0
- data/mlx/tests/einsum_tests.cpp +85 -0
- data/mlx/tests/eval_tests.cpp +93 -0
- data/mlx/tests/export_import_tests.cpp +164 -0
- data/mlx/tests/fft_tests.cpp +366 -0
- data/mlx/tests/gpu_tests.cpp +523 -0
- data/mlx/tests/linalg_tests.cpp +639 -0
- data/mlx/tests/load_tests.cpp +270 -0
- data/mlx/tests/ops_tests.cpp +4159 -0
- data/mlx/tests/random_tests.cpp +716 -0
- data/mlx/tests/scheduler_tests.cpp +121 -0
- data/mlx/tests/tests.cpp +26 -0
- data/mlx/tests/utils_tests.cpp +67 -0
- data/mlx/tests/vmap_tests.cpp +547 -0
- metadata +958 -0
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
.. _export_usage:
|
|
2
|
+
|
|
3
|
+
Exporting Functions
|
|
4
|
+
===================
|
|
5
|
+
|
|
6
|
+
.. currentmodule:: mlx.core
|
|
7
|
+
|
|
8
|
+
MLX has an API to export and import functions to and from a file. This lets you
|
|
9
|
+
run computations written in one MLX front-end (e.g. Python) in another MLX
|
|
10
|
+
front-end (e.g. C++).
|
|
11
|
+
|
|
12
|
+
This guide walks through the basics of the MLX export API with some examples.
|
|
13
|
+
To see the full list of functions check-out the :ref:`API documentation
|
|
14
|
+
<export>`.
|
|
15
|
+
|
|
16
|
+
Basics of Exporting
|
|
17
|
+
-------------------
|
|
18
|
+
|
|
19
|
+
Let's start with a simple example:
|
|
20
|
+
|
|
21
|
+
.. code-block:: python
|
|
22
|
+
|
|
23
|
+
def fun(x, y):
|
|
24
|
+
return x + y
|
|
25
|
+
|
|
26
|
+
x = mx.array(1.0)
|
|
27
|
+
y = mx.array(1.0)
|
|
28
|
+
mx.export_function("add.mlxfn", fun, x, y)
|
|
29
|
+
|
|
30
|
+
To export a function, provide sample input arrays that the function
|
|
31
|
+
can be called with. The data doesn't matter, but the shapes and types of the
|
|
32
|
+
arrays do. In the above example we exported ``fun`` with two ``float32``
|
|
33
|
+
scalar arrays. We can then import the function and run it:
|
|
34
|
+
|
|
35
|
+
.. code-block:: python
|
|
36
|
+
|
|
37
|
+
add_fun = mx.import_function("add.mlxfn")
|
|
38
|
+
|
|
39
|
+
out, = add_fun(mx.array(1.0), mx.array(2.0))
|
|
40
|
+
# Prints: array(3, dtype=float32)
|
|
41
|
+
print(out)
|
|
42
|
+
|
|
43
|
+
out, = add_fun(mx.array(1.0), mx.array(3.0))
|
|
44
|
+
# Prints: array(4, dtype=float32)
|
|
45
|
+
print(out)
|
|
46
|
+
|
|
47
|
+
# Raises an exception
|
|
48
|
+
add_fun(mx.array(1), mx.array(3.0))
|
|
49
|
+
|
|
50
|
+
# Raises an exception
|
|
51
|
+
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
|
|
52
|
+
|
|
53
|
+
Notice the third and fourth calls to ``add_fun`` raise exceptions because the
|
|
54
|
+
shapes and types of the inputs are different than the shapes and types of the
|
|
55
|
+
example inputs we exported the function with.
|
|
56
|
+
|
|
57
|
+
Also notice that even though the original ``fun`` returns a single output
|
|
58
|
+
array, the imported function always returns a tuple of one or more arrays.
|
|
59
|
+
|
|
60
|
+
The inputs to :func:`export_function` and to an imported function can be
|
|
61
|
+
specified as variable positional arguments or as a tuple of arrays:
|
|
62
|
+
|
|
63
|
+
.. code-block:: python
|
|
64
|
+
|
|
65
|
+
def fun(x, y):
|
|
66
|
+
return x + y
|
|
67
|
+
|
|
68
|
+
x = mx.array(1.0)
|
|
69
|
+
y = mx.array(1.0)
|
|
70
|
+
|
|
71
|
+
# Both arguments to fun are positional
|
|
72
|
+
mx.export_function("add.mlxfn", fun, x, y)
|
|
73
|
+
|
|
74
|
+
# Same as above
|
|
75
|
+
mx.export_function("add.mlxfn", fun, (x, y))
|
|
76
|
+
|
|
77
|
+
imported_fun = mx.import_function("add.mlxfn")
|
|
78
|
+
|
|
79
|
+
# Ok
|
|
80
|
+
out, = imported_fun(x, y)
|
|
81
|
+
|
|
82
|
+
# Also ok
|
|
83
|
+
out, = imported_fun((x, y))
|
|
84
|
+
|
|
85
|
+
You can pass example inputs to functions as positional or keyword arguments. If
|
|
86
|
+
you use keyword arguments to export the function, then you have to use the same
|
|
87
|
+
keyword arguments when calling the imported function.
|
|
88
|
+
|
|
89
|
+
.. code-block:: python
|
|
90
|
+
|
|
91
|
+
def fun(x, y):
|
|
92
|
+
return x + y
|
|
93
|
+
|
|
94
|
+
# One argument to fun is positional, the other is a kwarg
|
|
95
|
+
mx.export_function("add.mlxfn", fun, x, y=y)
|
|
96
|
+
|
|
97
|
+
imported_fun = mx.import_function("add.mlxfn")
|
|
98
|
+
|
|
99
|
+
# Ok
|
|
100
|
+
out, = imported_fun(x, y=y)
|
|
101
|
+
|
|
102
|
+
# Also ok
|
|
103
|
+
out, = imported_fun((x,), {"y": y})
|
|
104
|
+
|
|
105
|
+
# Raises since the keyword argument is missing
|
|
106
|
+
out, = imported_fun(x, y)
|
|
107
|
+
|
|
108
|
+
# Raises since the keyword argument has the wrong key
|
|
109
|
+
out, = imported_fun(x, z=y)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
Exporting Modules
|
|
113
|
+
-----------------
|
|
114
|
+
|
|
115
|
+
An :obj:`mlx.nn.Module` can be exported with or without the parameters included
|
|
116
|
+
in the exported function. Here's an example:
|
|
117
|
+
|
|
118
|
+
.. code-block:: python
|
|
119
|
+
|
|
120
|
+
model = nn.Linear(4, 4)
|
|
121
|
+
mx.eval(model.parameters())
|
|
122
|
+
|
|
123
|
+
def call(x):
|
|
124
|
+
return model(x)
|
|
125
|
+
|
|
126
|
+
mx.export_function("model.mlxfn", call, mx.zeros(4))
|
|
127
|
+
|
|
128
|
+
In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
|
|
129
|
+
parameters are also saved to the ``model.mlxfn`` file.
|
|
130
|
+
|
|
131
|
+
.. note::
|
|
132
|
+
|
|
133
|
+
For enclosed arrays inside an exported function, be extra careful to ensure
|
|
134
|
+
they are evaluated. The computation graph that gets exported will include
|
|
135
|
+
the computation that produces enclosed inputs.
|
|
136
|
+
|
|
137
|
+
If the above example was missing ``mx.eval(model.parameters()``, the
|
|
138
|
+
exported function would include the random initialization of the
|
|
139
|
+
:obj:`mlx.nn.Module` parameters.
|
|
140
|
+
|
|
141
|
+
If you only want to export the ``Module.__call__`` function without the
|
|
142
|
+
parameters, pass them as inputs to the ``call`` wrapper:
|
|
143
|
+
|
|
144
|
+
.. code-block:: python
|
|
145
|
+
|
|
146
|
+
model = nn.Linear(4, 4)
|
|
147
|
+
mx.eval(model.parameters())
|
|
148
|
+
|
|
149
|
+
def call(x, **params):
|
|
150
|
+
# Set the model's parameters to the input parameters
|
|
151
|
+
model.update(tree_unflatten(list(params.items())))
|
|
152
|
+
return model(x)
|
|
153
|
+
|
|
154
|
+
params = tree_flatten(model.parameters(), destination={})
|
|
155
|
+
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
Shapeless Exports
|
|
159
|
+
-----------------
|
|
160
|
+
|
|
161
|
+
Just like :func:`compile`, functions can also be exported for dynamically shaped
|
|
162
|
+
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
|
|
163
|
+
to export a function which can be used for inputs with variable shapes:
|
|
164
|
+
|
|
165
|
+
.. code-block:: python
|
|
166
|
+
|
|
167
|
+
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
|
|
168
|
+
imported_abs = mx.import_function("fun.mlxfn")
|
|
169
|
+
|
|
170
|
+
# Ok
|
|
171
|
+
out, = imported_abs(mx.array([-1.0]))
|
|
172
|
+
|
|
173
|
+
# Also ok
|
|
174
|
+
out, = imported_abs(mx.array([-1.0, -2.0]))
|
|
175
|
+
|
|
176
|
+
With ``shapeless=False`` (which is the default), the second call to
|
|
177
|
+
``imported_abs`` would raise an exception with a shape mismatch.
|
|
178
|
+
|
|
179
|
+
Shapeless exporting works the same as shapeless compilation and should be
|
|
180
|
+
used carefully. See the :ref:`documentation on shapeless compilation
|
|
181
|
+
<shapeless_compile>` for more information.
|
|
182
|
+
|
|
183
|
+
Exporting Multiple Traces
|
|
184
|
+
-------------------------
|
|
185
|
+
|
|
186
|
+
In some cases, functions build different computation graphs for different
|
|
187
|
+
input arguments. A simple way to manage this is to export to a new file with
|
|
188
|
+
each set of inputs. This is a fine option in many cases. But it can be
|
|
189
|
+
suboptimal if the exported functions have a large amount of duplicate constant
|
|
190
|
+
data (for example the parameters of a :obj:`mlx.nn.Module`).
|
|
191
|
+
|
|
192
|
+
The export API in MLX lets you export multiple traces of the same function to
|
|
193
|
+
a single file by creating an exporting context manager with :func:`exporter`:
|
|
194
|
+
|
|
195
|
+
.. code-block:: python
|
|
196
|
+
|
|
197
|
+
def fun(x, y=None):
|
|
198
|
+
constant = mx.array(3.0)
|
|
199
|
+
if y is not None:
|
|
200
|
+
x += y
|
|
201
|
+
return x + constant
|
|
202
|
+
|
|
203
|
+
with mx.exporter("fun.mlxfn", fun) as exporter:
|
|
204
|
+
exporter(mx.array(1.0))
|
|
205
|
+
exporter(mx.array(1.0), y=mx.array(0.0))
|
|
206
|
+
|
|
207
|
+
imported_function = mx.import_function("fun.mlxfn")
|
|
208
|
+
|
|
209
|
+
# Call the function with y=None
|
|
210
|
+
out, = imported_function(mx.array(1.0))
|
|
211
|
+
print(out)
|
|
212
|
+
|
|
213
|
+
# Call the function with y specified
|
|
214
|
+
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
|
|
215
|
+
print(out)
|
|
216
|
+
|
|
217
|
+
In the above example the function constant data, (i.e. ``constant``), is only
|
|
218
|
+
saved once.
|
|
219
|
+
|
|
220
|
+
Transformations with Imported Functions
|
|
221
|
+
---------------------------------------
|
|
222
|
+
|
|
223
|
+
Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
|
|
224
|
+
on imported functions just like regular Python functions:
|
|
225
|
+
|
|
226
|
+
.. code-block:: python
|
|
227
|
+
|
|
228
|
+
def fun(x):
|
|
229
|
+
return mx.sin(x)
|
|
230
|
+
|
|
231
|
+
x = mx.array(0.0)
|
|
232
|
+
mx.export_function("sine.mlxfn", fun, x)
|
|
233
|
+
|
|
234
|
+
imported_fun = mx.import_function("sine.mlxfn")
|
|
235
|
+
|
|
236
|
+
# Take the derivative of the imported function
|
|
237
|
+
dfdx = mx.grad(lambda x: imported_fun(x)[0])
|
|
238
|
+
# Prints: array(1, dtype=float32)
|
|
239
|
+
print(dfdx(x))
|
|
240
|
+
|
|
241
|
+
# Compile the imported function
|
|
242
|
+
compiled_fun = mx.compile(imported_fun)
|
|
243
|
+
# Prints: array(0, dtype=float32)
|
|
244
|
+
print(compiled_fun(x)[0])
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
Importing Functions in C++
|
|
248
|
+
--------------------------
|
|
249
|
+
|
|
250
|
+
Importing and running functions in C++ is basically the same as importing and
|
|
251
|
+
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
|
|
252
|
+
setup a simple C++ project that uses MLX as a library.
|
|
253
|
+
|
|
254
|
+
Next, export a simple function from Python:
|
|
255
|
+
|
|
256
|
+
.. code-block:: python
|
|
257
|
+
|
|
258
|
+
def fun(x, y):
|
|
259
|
+
return mx.exp(x + y)
|
|
260
|
+
|
|
261
|
+
x = mx.array(1.0)
|
|
262
|
+
y = mx.array(1.0)
|
|
263
|
+
mx.export_function("fun.mlxfn", fun, x, y)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
Import and run the function in C++ with only a few lines of code:
|
|
267
|
+
|
|
268
|
+
.. code-block:: c++
|
|
269
|
+
|
|
270
|
+
auto fun = mx::import_function("fun.mlxfn");
|
|
271
|
+
|
|
272
|
+
auto inputs = {mx::array(1.0), mx::array(1.0)};
|
|
273
|
+
auto outputs = fun(inputs);
|
|
274
|
+
|
|
275
|
+
// Prints: array(2, dtype=float32)
|
|
276
|
+
std::cout << outputs[0] << std::endl;
|
|
277
|
+
|
|
278
|
+
Imported functions can be transformed in C++ just like in Python. Use
|
|
279
|
+
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
|
280
|
+
mx::array>`` for keyword arguments when calling imported functions in C++.
|
|
281
|
+
|
|
282
|
+
More Examples
|
|
283
|
+
-------------
|
|
284
|
+
|
|
285
|
+
Here are a few more complete examples exporting more complex functions from
|
|
286
|
+
Python and importing and running them in C++:
|
|
287
|
+
|
|
288
|
+
* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
.. _function_transforms:
|
|
2
|
+
|
|
3
|
+
Function Transforms
|
|
4
|
+
===================
|
|
5
|
+
|
|
6
|
+
.. currentmodule:: mlx.core
|
|
7
|
+
|
|
8
|
+
MLX uses composable function transformations for automatic differentiation,
|
|
9
|
+
vectorization, and compute graph optimizations. To see the complete list of
|
|
10
|
+
function transformations check-out the :ref:`API documentation <transforms>`.
|
|
11
|
+
|
|
12
|
+
The key idea behind composable function transformations is that every
|
|
13
|
+
transformation returns a function which can be further transformed.
|
|
14
|
+
|
|
15
|
+
Here is a simple example:
|
|
16
|
+
|
|
17
|
+
.. code-block:: shell
|
|
18
|
+
|
|
19
|
+
>>> dfdx = mx.grad(mx.sin)
|
|
20
|
+
>>> dfdx(mx.array(mx.pi))
|
|
21
|
+
array(-1, dtype=float32)
|
|
22
|
+
>>> mx.cos(mx.array(mx.pi))
|
|
23
|
+
array(-1, dtype=float32)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
The output of :func:`grad` on :func:`sin` is simply another function. In this
|
|
27
|
+
case it is the gradient of the sine function which is exactly the cosine
|
|
28
|
+
function. To get the second derivative you can do:
|
|
29
|
+
|
|
30
|
+
.. code-block:: shell
|
|
31
|
+
|
|
32
|
+
>>> d2fdx2 = mx.grad(mx.grad(mx.sin))
|
|
33
|
+
>>> d2fdx2(mx.array(mx.pi / 2))
|
|
34
|
+
array(-1, dtype=float32)
|
|
35
|
+
>>> mx.sin(mx.array(mx.pi / 2))
|
|
36
|
+
array(1, dtype=float32)
|
|
37
|
+
|
|
38
|
+
Using :func:`grad` on the output of :func:`grad` is always ok. You keep
|
|
39
|
+
getting higher order derivatives.
|
|
40
|
+
|
|
41
|
+
Any of the MLX function transformations can be composed in any order to any
|
|
42
|
+
depth. See the following sections for more information on :ref:`automatic
|
|
43
|
+
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
|
|
44
|
+
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
Automatic Differentiation
|
|
48
|
+
-------------------------
|
|
49
|
+
|
|
50
|
+
.. _auto diff:
|
|
51
|
+
|
|
52
|
+
Automatic differentiation in MLX works on functions rather than on implicit
|
|
53
|
+
graphs.
|
|
54
|
+
|
|
55
|
+
.. note::
|
|
56
|
+
|
|
57
|
+
If you are coming to MLX from PyTorch, you no longer need functions like
|
|
58
|
+
``backward``, ``zero_grad``, and ``detach``, or properties like
|
|
59
|
+
``requires_grad``.
|
|
60
|
+
|
|
61
|
+
The most basic example is taking the gradient of a scalar-valued function as we
|
|
62
|
+
saw above. You can use the :func:`grad` and :func:`value_and_grad` function to
|
|
63
|
+
compute gradients of more complex functions. By default these functions compute
|
|
64
|
+
the gradient with respect to the first argument:
|
|
65
|
+
|
|
66
|
+
.. code-block:: python
|
|
67
|
+
|
|
68
|
+
def loss_fn(w, x, y):
|
|
69
|
+
return mx.mean(mx.square(w * x - y))
|
|
70
|
+
|
|
71
|
+
w = mx.array(1.0)
|
|
72
|
+
x = mx.array([0.5, -0.5])
|
|
73
|
+
y = mx.array([1.5, -1.5])
|
|
74
|
+
|
|
75
|
+
# Computes the gradient of loss_fn with respect to w:
|
|
76
|
+
grad_fn = mx.grad(loss_fn)
|
|
77
|
+
dloss_dw = grad_fn(w, x, y)
|
|
78
|
+
# Prints array(-1, dtype=float32)
|
|
79
|
+
print(dloss_dw)
|
|
80
|
+
|
|
81
|
+
# To get the gradient with respect to x we can do:
|
|
82
|
+
grad_fn = mx.grad(loss_fn, argnums=1)
|
|
83
|
+
dloss_dx = grad_fn(w, x, y)
|
|
84
|
+
# Prints array([-1, 1], dtype=float32)
|
|
85
|
+
print(dloss_dx)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
One way to get the loss and gradient is to call ``loss_fn`` followed by
|
|
89
|
+
``grad_fn``, but this can result in a lot of redundant work. Instead, you
|
|
90
|
+
should use :func:`value_and_grad`. Continuing the above example:
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
.. code-block:: python
|
|
94
|
+
|
|
95
|
+
# Computes the gradient of loss_fn with respect to w:
|
|
96
|
+
loss_and_grad_fn = mx.value_and_grad(loss_fn)
|
|
97
|
+
loss, dloss_dw = loss_and_grad_fn(w, x, y)
|
|
98
|
+
|
|
99
|
+
# Prints array(1, dtype=float32)
|
|
100
|
+
print(loss)
|
|
101
|
+
|
|
102
|
+
# Prints array(-1, dtype=float32)
|
|
103
|
+
print(dloss_dw)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
You can also take the gradient with respect to arbitrarily nested Python
|
|
107
|
+
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
|
|
108
|
+
:obj:`dict`).
|
|
109
|
+
|
|
110
|
+
Suppose we wanted a weight and a bias parameter in the above example. A nice
|
|
111
|
+
way to do that is the following:
|
|
112
|
+
|
|
113
|
+
.. code-block:: python
|
|
114
|
+
|
|
115
|
+
def loss_fn(params, x, y):
|
|
116
|
+
w, b = params["weight"], params["bias"]
|
|
117
|
+
h = w * x + b
|
|
118
|
+
return mx.mean(mx.square(h - y))
|
|
119
|
+
|
|
120
|
+
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
|
|
121
|
+
x = mx.array([0.5, -0.5])
|
|
122
|
+
y = mx.array([1.5, -1.5])
|
|
123
|
+
|
|
124
|
+
# Computes the gradient of loss_fn with respect to both the
|
|
125
|
+
# weight and bias:
|
|
126
|
+
grad_fn = mx.grad(loss_fn)
|
|
127
|
+
grads = grad_fn(params, x, y)
|
|
128
|
+
|
|
129
|
+
# Prints
|
|
130
|
+
# {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
|
|
131
|
+
print(grads)
|
|
132
|
+
|
|
133
|
+
Notice the tree structure of the parameters is preserved in the gradients.
|
|
134
|
+
|
|
135
|
+
In some cases you may want to stop gradients from propagating through a
|
|
136
|
+
part of the function. You can use the :func:`stop_gradient` for that.
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
Automatic Vectorization
|
|
140
|
+
-----------------------
|
|
141
|
+
|
|
142
|
+
.. _vmap:
|
|
143
|
+
|
|
144
|
+
Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
|
|
145
|
+
through a basic and contrived example for the sake of clarity, but :func:`vmap`
|
|
146
|
+
can be quite powerful for more complex functions which are difficult to optimize
|
|
147
|
+
by hand.
|
|
148
|
+
|
|
149
|
+
.. warning::
|
|
150
|
+
|
|
151
|
+
Some operations are not yet supported with :func:`vmap`. If you encounter an error
|
|
152
|
+
like: ``ValueError: Primitive's vmap not implemented.`` file an `issue
|
|
153
|
+
<https://github.com/ml-explore/mlx/issues>`_ and include your function.
|
|
154
|
+
We will prioritize including it.
|
|
155
|
+
|
|
156
|
+
A naive way to add the elements from two sets of vectors is with a loop:
|
|
157
|
+
|
|
158
|
+
.. code-block:: python
|
|
159
|
+
|
|
160
|
+
xs = mx.random.uniform(shape=(4096, 100))
|
|
161
|
+
ys = mx.random.uniform(shape=(100, 4096))
|
|
162
|
+
|
|
163
|
+
def naive_add(xs, ys):
|
|
164
|
+
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
|
|
165
|
+
|
|
166
|
+
Instead you can use :func:`vmap` to automatically vectorize the addition:
|
|
167
|
+
|
|
168
|
+
.. code-block:: python
|
|
169
|
+
|
|
170
|
+
# Vectorize over the second dimension of x and the
|
|
171
|
+
# first dimension of y
|
|
172
|
+
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
|
|
173
|
+
|
|
174
|
+
The ``in_axes`` parameter can be used to specify which dimensions of the
|
|
175
|
+
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
|
|
176
|
+
where the vectorized axes should be in the outputs.
|
|
177
|
+
|
|
178
|
+
Let's time these two different versions:
|
|
179
|
+
|
|
180
|
+
.. code-block:: python
|
|
181
|
+
|
|
182
|
+
import timeit
|
|
183
|
+
|
|
184
|
+
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
|
|
185
|
+
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
|
|
186
|
+
|
|
187
|
+
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
|
|
188
|
+
vectorized version takes only ``0.024`` seconds, more than 200 times faster.
|
|
189
|
+
|
|
190
|
+
Of course, this operation is quite contrived. A better approach is to simply do
|
|
191
|
+
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
.. _indexing:
|
|
2
|
+
|
|
3
|
+
Indexing Arrays
|
|
4
|
+
===============
|
|
5
|
+
|
|
6
|
+
.. currentmodule:: mlx.core
|
|
7
|
+
|
|
8
|
+
For the most part, indexing an MLX :obj:`array` works the same as indexing a
|
|
9
|
+
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
|
|
10
|
+
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
|
|
11
|
+
how that works.
|
|
12
|
+
|
|
13
|
+
For example, you can use regular integers and slices (:obj:`slice`) to index arrays:
|
|
14
|
+
|
|
15
|
+
.. code-block:: shell
|
|
16
|
+
|
|
17
|
+
>>> arr = mx.arange(10)
|
|
18
|
+
>>> arr[3]
|
|
19
|
+
array(3, dtype=int32)
|
|
20
|
+
>>> arr[-2] # negative indexing works
|
|
21
|
+
array(8, dtype=int32)
|
|
22
|
+
>>> arr[2:8:2] # start, stop, stride
|
|
23
|
+
array([2, 4, 6], dtype=int32)
|
|
24
|
+
|
|
25
|
+
For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:
|
|
26
|
+
|
|
27
|
+
.. code-block:: shell
|
|
28
|
+
|
|
29
|
+
>>> arr = mx.arange(8).reshape(2, 2, 2)
|
|
30
|
+
>>> arr[:, :, 0]
|
|
31
|
+
array([[0, 2],
       [4, 6]], dtype=int32)
|
|
34
|
+
>>> arr[..., 0]
|
|
35
|
+
array([[0, 2],
       [4, 6]], dtype=int32)
|
|
37
|
+
|
|
38
|
+
You can index with ``None`` to create a new axis:
|
|
39
|
+
|
|
40
|
+
.. code-block:: shell
|
|
41
|
+
|
|
42
|
+
>>> arr = mx.arange(8)
|
|
43
|
+
>>> arr.shape
|
|
44
|
+
[8]
|
|
45
|
+
>>> arr[None].shape
|
|
46
|
+
[1, 8]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
You can also use an :obj:`array` to index another :obj:`array`:
|
|
50
|
+
|
|
51
|
+
.. code-block:: shell
|
|
52
|
+
|
|
53
|
+
>>> arr = mx.arange(10)
|
|
54
|
+
>>> idx = mx.array([5, 7])
|
|
55
|
+
>>> arr[idx]
|
|
56
|
+
array([5, 7], dtype=int32)
|
|
57
|
+
|
|
58
|
+
Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
|
|
59
|
+
works just as in NumPy.
|
|
60
|
+
|
|
61
|
+
Other functions which may be useful for indexing arrays are :func:`take` and
|
|
62
|
+
:func:`take_along_axis`.
|
|
63
|
+
|
|
64
|
+
Differences from NumPy
|
|
65
|
+
----------------------
|
|
66
|
+
|
|
67
|
+
.. Note::
|
|
68
|
+
|
|
69
|
+
MLX indexing is different from NumPy indexing in two important ways:
|
|
70
|
+
|
|
71
|
+
* Indexing does not perform bounds checking. Indexing out of bounds is
|
|
72
|
+
undefined behavior.
|
|
73
|
+
* Boolean mask based indexing is supported for assignment only (see
|
|
74
|
+
:ref:`boolean-mask-assignment`).
|
|
75
|
+
|
|
76
|
+
The reason for the lack of bounds checking is that exceptions cannot propagate
|
|
77
|
+
from the GPU. Performing bounds checking for array indices before launching the
|
|
78
|
+
kernel would be extremely inefficient.
|
|
79
|
+
|
|
80
|
+
Indexing with boolean masks is something that MLX may support in the future. In
|
|
81
|
+
general, MLX has limited support for operations for which output
|
|
82
|
+
*shapes* are dependent on input *data*. Other examples of these types of
|
|
83
|
+
operations which MLX does not yet support include :func:`numpy.nonzero` and the
|
|
84
|
+
single input version of :func:`numpy.where`.
|
|
85
|
+
|
|
86
|
+
In Place Updates
|
|
87
|
+
----------------
|
|
88
|
+
|
|
89
|
+
In place updates to indexed arrays are possible in MLX. For example:
|
|
90
|
+
|
|
91
|
+
.. code-block:: shell
|
|
92
|
+
|
|
93
|
+
>>> a = mx.array([1, 2, 3])
|
|
94
|
+
>>> a[2] = 0
|
|
95
|
+
>>> a
|
|
96
|
+
array([1, 2, 0], dtype=int32)
|
|
97
|
+
|
|
98
|
+
Just as in NumPy, in place updates will be reflected in all references to the
|
|
99
|
+
same array:
|
|
100
|
+
|
|
101
|
+
.. code-block:: shell
|
|
102
|
+
|
|
103
|
+
>>> a = mx.array([1, 2, 3])
|
|
104
|
+
>>> b = a
|
|
105
|
+
>>> b[2] = 0
|
|
106
|
+
>>> b
|
|
107
|
+
array([1, 2, 0], dtype=int32)
|
|
108
|
+
>>> a
|
|
109
|
+
array([1, 2, 0], dtype=int32)
|
|
110
|
+
|
|
111
|
+
Note that unlike NumPy, slicing an array creates a copy, not a view. So
|
|
112
|
+
mutating it does not mutate the original array:
|
|
113
|
+
|
|
114
|
+
.. code-block:: shell
|
|
115
|
+
|
|
116
|
+
>>> a = mx.array([1, 2, 3])
|
|
117
|
+
>>> b = a[:]
|
|
118
|
+
>>> b[2] = 0
|
|
119
|
+
>>> b
|
|
120
|
+
array([1, 2, 0], dtype=int32)
|
|
121
|
+
>>> a
|
|
122
|
+
array([1, 2, 3], dtype=int32)
|
|
123
|
+
|
|
124
|
+
Also unlike NumPy, updates to the same location are nondeterministic:
|
|
125
|
+
|
|
126
|
+
.. code-block:: shell
|
|
127
|
+
|
|
128
|
+
>>> a = mx.array([1, 2, 3])
|
|
129
|
+
>>> a[[0, 0]] = mx.array([4, 5])
|
|
130
|
+
|
|
131
|
+
The first element of ``a`` could be ``4`` or ``5``.
|
|
132
|
+
|
|
133
|
+
Transformations of functions which use in-place updates are allowed and work as
|
|
134
|
+
expected. For example:
|
|
135
|
+
|
|
136
|
+
.. code-block:: python
|
|
137
|
+
|
|
138
|
+
def fun(x, idx):
|
|
139
|
+
x[idx] = 2.0
|
|
140
|
+
return x.sum()
|
|
141
|
+
|
|
142
|
+
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
|
|
143
|
+
print(dfdx) # Prints: array([1, 0, 1], dtype=float32)
|
|
144
|
+
|
|
145
|
+
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
|
|
146
|
+
and ones elsewhere.
|
|
147
|
+
|
|
148
|
+
.. _boolean-mask-assignment:
|
|
149
|
+
|
|
150
|
+
Boolean Mask Assignment
|
|
151
|
+
-----------------------
|
|
152
|
+
|
|
153
|
+
MLX supports boolean indices using NumPy syntax. A mask must already be
|
|
154
|
+
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
|
|
155
|
+
Other index types are routed through the standard scatter code.
|
|
156
|
+
|
|
157
|
+
.. code-block:: shell
|
|
158
|
+
|
|
159
|
+
>>> a = mx.array([1.0, 2.0, 3.0])
|
|
160
|
+
>>> mask = mx.array([True, False, True])
|
|
161
|
+
>>> updates = mx.array([5.0, 6.0])
|
|
162
|
+
>>> a[mask] = updates
|
|
163
|
+
>>> a
|
|
164
|
+
array([5.0, 2.0, 6.0], dtype=float32)
|
|
165
|
+
|
|
166
|
+
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
|
|
167
|
+
assignments, ``updates`` must provide at least as many elements as there are
|
|
168
|
+
``True`` entries in ``mask``.
|
|
169
|
+
|
|
170
|
+
.. code-block:: shell
|
|
171
|
+
|
|
172
|
+
>>> a = mx.zeros((2, 3))
|
|
173
|
+
>>> mask = mx.array([[True, False, True],
|
|
174
|
+
[False, False, True]])
|
|
175
|
+
>>> a[mask] = 1.0
|
|
176
|
+
>>> a
|
|
177
|
+
array([[1.0, 0.0, 1.0],
|
|
178
|
+
[0.0, 0.0, 1.0]], dtype=float32)
|
|
179
|
+
|
|
180
|
+
Boolean masks follow NumPy semantics:
|
|
181
|
+
|
|
182
|
+
- The mask shape must match the shape of the axes it indexes exactly. The only
|
|
183
|
+
exception is a scalar boolean mask, which broadcasts to the full array.
|
|
184
|
+
- Any axes not covered by the mask are taken in full.
|
|
185
|
+
|
|
186
|
+
.. code-block:: shell
|
|
187
|
+
|
|
188
|
+
>>> a = mx.arange(1000).reshape(10, 10, 10)
|
|
189
|
+
>>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
|
190
|
+
|
|
191
|
+
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
|
|
192
|
+
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
|
|
193
|
+
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
|
|
194
|
+
axes and therefore raise errors.
|