mlx 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlx might be problematic. Click here for more details.
- checksums.yaml +7 -0
- data/ext/mlx/CMakeLists.txt +7 -0
- data/ext/mlx/Makefile +273 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/mkmf.log +44 -0
- data/ext/mlx/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
- data/ext/mlx/native.cpp +8027 -0
- data/ext/mlx/native.o +0 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version +1 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/.clang-format +87 -0
- data/mlx/.git +1 -0
- data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
- data/mlx/.github/actions/build-docs/action.yml +38 -0
- data/mlx/.github/actions/build-linux/action.yml +38 -0
- data/mlx/.github/actions/build-linux-release/action.yml +42 -0
- data/mlx/.github/actions/build-macos/action.yml +80 -0
- data/mlx/.github/actions/build-macos-release/action.yml +36 -0
- data/mlx/.github/actions/build-windows/action.yml +26 -0
- data/mlx/.github/actions/setup-linux/action.yml +93 -0
- data/mlx/.github/actions/setup-macos/action.yml +24 -0
- data/mlx/.github/actions/setup-windows/action.yml +42 -0
- data/mlx/.github/actions/test-linux/action.yml +69 -0
- data/mlx/.github/actions/test-windows/action.yml +20 -0
- data/mlx/.github/dependabot.yml +6 -0
- data/mlx/.github/pull_request_template.md +12 -0
- data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
- data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
- data/mlx/.github/workflows/build_and_test.yml +152 -0
- data/mlx/.github/workflows/documentation.yml +28 -0
- data/mlx/.github/workflows/nightly.yml +104 -0
- data/mlx/.github/workflows/release.yml +256 -0
- data/mlx/.gitignore +81 -0
- data/mlx/.pre-commit-config.yaml +27 -0
- data/mlx/ACKNOWLEDGMENTS.md +268 -0
- data/mlx/CITATION.cff +24 -0
- data/mlx/CMakeLists.txt +437 -0
- data/mlx/CODE_OF_CONDUCT.md +132 -0
- data/mlx/CONTRIBUTING.md +38 -0
- data/mlx/LICENSE +21 -0
- data/mlx/MANIFEST.in +6 -0
- data/mlx/README.md +121 -0
- data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- data/mlx/benchmarks/cpp/autograd.cpp +39 -0
- data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- data/mlx/benchmarks/cpp/time_utils.h +39 -0
- data/mlx/benchmarks/numpy/single_ops.py +39 -0
- data/mlx/benchmarks/numpy/time_utils.py +20 -0
- data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
- data/mlx/benchmarks/python/comparative/README.md +15 -0
- data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- data/mlx/benchmarks/python/comparative/compare.py +284 -0
- data/mlx/benchmarks/python/compile_bench.py +107 -0
- data/mlx/benchmarks/python/conv1d_bench.py +123 -0
- data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- data/mlx/benchmarks/python/conv_bench.py +135 -0
- data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- data/mlx/benchmarks/python/distributed_bench.py +66 -0
- data/mlx/benchmarks/python/einsum_bench.py +84 -0
- data/mlx/benchmarks/python/fft_bench.py +118 -0
- data/mlx/benchmarks/python/gather_bench.py +52 -0
- data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- data/mlx/benchmarks/python/hadamard_bench.py +70 -0
- data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
- data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
- data/mlx/benchmarks/python/masked_scatter.py +212 -0
- data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
- data/mlx/benchmarks/python/rope_bench.py +35 -0
- data/mlx/benchmarks/python/scatter_bench.py +96 -0
- data/mlx/benchmarks/python/sdpa_bench.py +223 -0
- data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
- data/mlx/benchmarks/python/single_ops.py +132 -0
- data/mlx/benchmarks/python/synchronize_bench.py +55 -0
- data/mlx/benchmarks/python/time_utils.py +38 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/docs/.clang-format +2 -0
- data/mlx/docs/.gitignore +3 -0
- data/mlx/docs/.nojekyll +0 -0
- data/mlx/docs/Doxyfile +51 -0
- data/mlx/docs/Makefile +18 -0
- data/mlx/docs/README.md +54 -0
- data/mlx/docs/index.html +1 -0
- data/mlx/docs/requirements.txt +5 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
- data/mlx/docs/src/_static/mlx_logo.png +0 -0
- data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
- data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
- data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
- data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
- data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
- data/mlx/docs/src/_templates/module-base-class.rst +33 -0
- data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
- data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
- data/mlx/docs/src/conf.py +99 -0
- data/mlx/docs/src/cpp/ops.rst +7 -0
- data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
- data/mlx/docs/src/dev/extensions.rst +811 -0
- data/mlx/docs/src/dev/metal_debugger.rst +68 -0
- data/mlx/docs/src/dev/metal_logging.rst +40 -0
- data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
- data/mlx/docs/src/examples/data_parallelism.rst +91 -0
- data/mlx/docs/src/examples/linear_regression.rst +77 -0
- data/mlx/docs/src/examples/llama-inference.rst +382 -0
- data/mlx/docs/src/examples/mlp.rst +134 -0
- data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
- data/mlx/docs/src/index.rst +96 -0
- data/mlx/docs/src/install.rst +340 -0
- data/mlx/docs/src/python/array.rst +65 -0
- data/mlx/docs/src/python/cuda.rst +9 -0
- data/mlx/docs/src/python/data_types.rst +78 -0
- data/mlx/docs/src/python/devices_and_streams.rst +21 -0
- data/mlx/docs/src/python/distributed.rst +22 -0
- data/mlx/docs/src/python/export.rst +14 -0
- data/mlx/docs/src/python/fast.rst +16 -0
- data/mlx/docs/src/python/fft.rst +24 -0
- data/mlx/docs/src/python/linalg.rst +27 -0
- data/mlx/docs/src/python/memory_management.rst +16 -0
- data/mlx/docs/src/python/metal.rst +12 -0
- data/mlx/docs/src/python/nn/distributed.rst +30 -0
- data/mlx/docs/src/python/nn/functions.rst +40 -0
- data/mlx/docs/src/python/nn/init.rst +45 -0
- data/mlx/docs/src/python/nn/layers.rst +74 -0
- data/mlx/docs/src/python/nn/losses.rst +25 -0
- data/mlx/docs/src/python/nn/module.rst +38 -0
- data/mlx/docs/src/python/nn.rst +186 -0
- data/mlx/docs/src/python/ops.rst +184 -0
- data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
- data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
- data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
- data/mlx/docs/src/python/optimizers.rst +78 -0
- data/mlx/docs/src/python/random.rst +48 -0
- data/mlx/docs/src/python/transforms.rst +22 -0
- data/mlx/docs/src/python/tree_utils.rst +23 -0
- data/mlx/docs/src/usage/compile.rst +516 -0
- data/mlx/docs/src/usage/distributed.rst +572 -0
- data/mlx/docs/src/usage/export.rst +288 -0
- data/mlx/docs/src/usage/function_transforms.rst +191 -0
- data/mlx/docs/src/usage/indexing.rst +194 -0
- data/mlx/docs/src/usage/launching_distributed.rst +234 -0
- data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
- data/mlx/docs/src/usage/numpy.rst +124 -0
- data/mlx/docs/src/usage/quick_start.rst +67 -0
- data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
- data/mlx/docs/src/usage/unified_memory.rst +78 -0
- data/mlx/docs/src/usage/using_streams.rst +18 -0
- data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
- data/mlx/examples/cmake_project/README.md +26 -0
- data/mlx/examples/cmake_project/example.cpp +14 -0
- data/mlx/examples/cpp/CMakeLists.txt +12 -0
- data/mlx/examples/cpp/distributed.cpp +22 -0
- data/mlx/examples/cpp/linear_regression.cpp +54 -0
- data/mlx/examples/cpp/logistic_regression.cpp +54 -0
- data/mlx/examples/cpp/metal_capture.cpp +31 -0
- data/mlx/examples/cpp/timer.h +20 -0
- data/mlx/examples/cpp/tutorial.cpp +99 -0
- data/mlx/examples/export/CMakeLists.txt +22 -0
- data/mlx/examples/export/README.md +49 -0
- data/mlx/examples/export/eval_mlp.cpp +25 -0
- data/mlx/examples/export/eval_mlp.py +52 -0
- data/mlx/examples/export/train_mlp.cpp +35 -0
- data/mlx/examples/export/train_mlp.py +76 -0
- data/mlx/examples/extensions/CMakeLists.txt +78 -0
- data/mlx/examples/extensions/README.md +24 -0
- data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
- data/mlx/examples/extensions/axpby/axpby.h +90 -0
- data/mlx/examples/extensions/axpby/axpby.metal +47 -0
- data/mlx/examples/extensions/bindings.cpp +39 -0
- data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
- data/mlx/examples/extensions/pyproject.toml +8 -0
- data/mlx/examples/extensions/requirements.txt +4 -0
- data/mlx/examples/extensions/setup.py +18 -0
- data/mlx/examples/extensions/test.py +12 -0
- data/mlx/examples/python/linear_regression.py +46 -0
- data/mlx/examples/python/logistic_regression.py +49 -0
- data/mlx/examples/python/qqmm.py +117 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- data/mlx/pyproject.toml +7 -0
- data/mlx/python/mlx/__main__.py +27 -0
- data/mlx/python/mlx/_distributed_utils/common.py +135 -0
- data/mlx/python/mlx/_distributed_utils/config.py +631 -0
- data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
- data/mlx/python/mlx/_reprlib_fix.py +16 -0
- data/mlx/python/mlx/_stub_patterns.txt +36 -0
- data/mlx/python/mlx/extension.py +88 -0
- data/mlx/python/mlx/nn/__init__.py +5 -0
- data/mlx/python/mlx/nn/init.py +441 -0
- data/mlx/python/mlx/nn/layers/__init__.py +105 -0
- data/mlx/python/mlx/nn/layers/activations.py +661 -0
- data/mlx/python/mlx/nn/layers/base.py +675 -0
- data/mlx/python/mlx/nn/layers/containers.py +24 -0
- data/mlx/python/mlx/nn/layers/convolution.py +232 -0
- data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
- data/mlx/python/mlx/nn/layers/distributed.py +601 -0
- data/mlx/python/mlx/nn/layers/dropout.py +137 -0
- data/mlx/python/mlx/nn/layers/embedding.py +53 -0
- data/mlx/python/mlx/nn/layers/linear.py +180 -0
- data/mlx/python/mlx/nn/layers/normalization.py +363 -0
- data/mlx/python/mlx/nn/layers/pooling.py +398 -0
- data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
- data/mlx/python/mlx/nn/layers/quantized.py +426 -0
- data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
- data/mlx/python/mlx/nn/layers/transformer.py +354 -0
- data/mlx/python/mlx/nn/layers/upsample.py +277 -0
- data/mlx/python/mlx/nn/losses.py +610 -0
- data/mlx/python/mlx/nn/utils.py +165 -0
- data/mlx/python/mlx/optimizers/__init__.py +4 -0
- data/mlx/python/mlx/optimizers/optimizers.py +976 -0
- data/mlx/python/mlx/optimizers/schedulers.py +158 -0
- data/mlx/python/mlx/py.typed +1 -0
- data/mlx/python/mlx/utils.py +325 -0
- data/mlx/python/src/CMakeLists.txt +96 -0
- data/mlx/python/src/array.cpp +1525 -0
- data/mlx/python/src/buffer.h +124 -0
- data/mlx/python/src/constants.cpp +15 -0
- data/mlx/python/src/convert.cpp +504 -0
- data/mlx/python/src/convert.h +50 -0
- data/mlx/python/src/cuda.cpp +19 -0
- data/mlx/python/src/device.cpp +98 -0
- data/mlx/python/src/distributed.cpp +352 -0
- data/mlx/python/src/export.cpp +356 -0
- data/mlx/python/src/fast.cpp +627 -0
- data/mlx/python/src/fft.cpp +514 -0
- data/mlx/python/src/indexing.cpp +1016 -0
- data/mlx/python/src/indexing.h +41 -0
- data/mlx/python/src/linalg.cpp +663 -0
- data/mlx/python/src/load.cpp +531 -0
- data/mlx/python/src/load.h +51 -0
- data/mlx/python/src/memory.cpp +125 -0
- data/mlx/python/src/metal.cpp +98 -0
- data/mlx/python/src/mlx.cpp +51 -0
- data/mlx/python/src/mlx_func.cpp +116 -0
- data/mlx/python/src/mlx_func.h +31 -0
- data/mlx/python/src/ops.cpp +5545 -0
- data/mlx/python/src/random.cpp +516 -0
- data/mlx/python/src/small_vector.h +76 -0
- data/mlx/python/src/stream.cpp +147 -0
- data/mlx/python/src/transforms.cpp +1542 -0
- data/mlx/python/src/trees.cpp +311 -0
- data/mlx/python/src/trees.h +62 -0
- data/mlx/python/src/utils.cpp +98 -0
- data/mlx/python/src/utils.h +78 -0
- data/mlx/python/tests/__main__.py +5 -0
- data/mlx/python/tests/cuda_skip.py +62 -0
- data/mlx/python/tests/mlx_distributed_tests.py +314 -0
- data/mlx/python/tests/mlx_tests.py +116 -0
- data/mlx/python/tests/mpi_test_distributed.py +142 -0
- data/mlx/python/tests/nccl_test_distributed.py +52 -0
- data/mlx/python/tests/ring_test_distributed.py +131 -0
- data/mlx/python/tests/test_array.py +2139 -0
- data/mlx/python/tests/test_autograd.py +880 -0
- data/mlx/python/tests/test_bf16.py +196 -0
- data/mlx/python/tests/test_blas.py +1429 -0
- data/mlx/python/tests/test_compile.py +1277 -0
- data/mlx/python/tests/test_constants.py +41 -0
- data/mlx/python/tests/test_conv.py +1198 -0
- data/mlx/python/tests/test_conv_transpose.py +810 -0
- data/mlx/python/tests/test_device.py +150 -0
- data/mlx/python/tests/test_double.py +306 -0
- data/mlx/python/tests/test_einsum.py +363 -0
- data/mlx/python/tests/test_eval.py +200 -0
- data/mlx/python/tests/test_export_import.py +614 -0
- data/mlx/python/tests/test_fast.py +923 -0
- data/mlx/python/tests/test_fast_sdpa.py +647 -0
- data/mlx/python/tests/test_fft.py +323 -0
- data/mlx/python/tests/test_graph.py +37 -0
- data/mlx/python/tests/test_init.py +139 -0
- data/mlx/python/tests/test_linalg.py +621 -0
- data/mlx/python/tests/test_load.py +447 -0
- data/mlx/python/tests/test_losses.py +427 -0
- data/mlx/python/tests/test_memory.py +77 -0
- data/mlx/python/tests/test_nn.py +1986 -0
- data/mlx/python/tests/test_ops.py +3261 -0
- data/mlx/python/tests/test_optimizers.py +584 -0
- data/mlx/python/tests/test_quantized.py +1160 -0
- data/mlx/python/tests/test_random.py +392 -0
- data/mlx/python/tests/test_reduce.py +223 -0
- data/mlx/python/tests/test_tree.py +96 -0
- data/mlx/python/tests/test_upsample.py +100 -0
- data/mlx/python/tests/test_vmap.py +860 -0
- data/mlx/setup.py +315 -0
- data/mlx/tests/CMakeLists.txt +44 -0
- data/mlx/tests/allocator_tests.cpp +41 -0
- data/mlx/tests/arg_reduce_tests.cpp +204 -0
- data/mlx/tests/array_tests.cpp +663 -0
- data/mlx/tests/autograd_tests.cpp +1399 -0
- data/mlx/tests/blas_tests.cpp +110 -0
- data/mlx/tests/compile_tests.cpp +818 -0
- data/mlx/tests/creations_tests.cpp +239 -0
- data/mlx/tests/custom_vjp_tests.cpp +55 -0
- data/mlx/tests/device_tests.cpp +35 -0
- data/mlx/tests/einsum_tests.cpp +85 -0
- data/mlx/tests/eval_tests.cpp +93 -0
- data/mlx/tests/export_import_tests.cpp +164 -0
- data/mlx/tests/fft_tests.cpp +366 -0
- data/mlx/tests/gpu_tests.cpp +523 -0
- data/mlx/tests/linalg_tests.cpp +639 -0
- data/mlx/tests/load_tests.cpp +270 -0
- data/mlx/tests/ops_tests.cpp +4159 -0
- data/mlx/tests/random_tests.cpp +716 -0
- data/mlx/tests/scheduler_tests.cpp +121 -0
- data/mlx/tests/tests.cpp +26 -0
- data/mlx/tests/utils_tests.cpp +67 -0
- data/mlx/tests/vmap_tests.cpp +547 -0
- metadata +958 -0
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
.. _unified_memory:
|
|
2
|
+
|
|
3
|
+
Unified Memory
|
|
4
|
+
==============
|
|
5
|
+
|
|
6
|
+
.. currentmodule:: mlx.core
|
|
7
|
+
|
|
8
|
+
Apple silicon has a unified memory architecture. The CPU and GPU have direct
|
|
9
|
+
access to the same memory pool. MLX is designed to take advantage of that.
|
|
10
|
+
|
|
11
|
+
Concretely, when you make an array in MLX you don't have to specify its location:
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
.. code-block:: python
|
|
15
|
+
|
|
16
|
+
a = mx.random.normal((100,))
|
|
17
|
+
b = mx.random.normal((100,))
|
|
18
|
+
|
|
19
|
+
Both ``a`` and ``b`` live in unified memory.
|
|
20
|
+
|
|
21
|
+
In MLX, rather than moving arrays to devices, you specify the device when you
|
|
22
|
+
run the operation. Any device can perform any operation on ``a`` and ``b``
|
|
23
|
+
without needing to move them from one memory location to another. For example:
|
|
24
|
+
|
|
25
|
+
.. code-block:: python
|
|
26
|
+
|
|
27
|
+
mx.add(a, b, stream=mx.cpu)
|
|
28
|
+
mx.add(a, b, stream=mx.gpu)
|
|
29
|
+
|
|
30
|
+
In the above, both the CPU and the GPU will perform the same add
|
|
31
|
+
operation. The operations can (and likely will) be run in parallel since
|
|
32
|
+
there are no dependencies between them. See :ref:`using_streams` for more
|
|
33
|
+
information on the semantics of streams in MLX.
|
|
34
|
+
|
|
35
|
+
In the above ``add`` example, there are no dependencies between operations, so
|
|
36
|
+
there is no possibility for race conditions. If there are dependencies, the
|
|
37
|
+
MLX scheduler will automatically manage them. For example:
|
|
38
|
+
|
|
39
|
+
.. code-block:: python
|
|
40
|
+
|
|
41
|
+
c = mx.add(a, b, stream=mx.cpu)
|
|
42
|
+
d = mx.add(a, c, stream=mx.gpu)
|
|
43
|
+
|
|
44
|
+
In the above case, the second ``add`` runs on the GPU but it depends on the
|
|
45
|
+
output of the first ``add`` which is running on the CPU. MLX will
|
|
46
|
+
automatically insert a dependency between the two streams so that the second
|
|
47
|
+
``add`` only starts executing after the first is complete and ``c`` is
|
|
48
|
+
available.
|
|
49
|
+
|
|
50
|
+
A Simple Example
|
|
51
|
+
~~~~~~~~~~~~~~~~
|
|
52
|
+
|
|
53
|
+
Here is a more interesting (albeit slightly contrived) example of how unified
|
|
54
|
+
memory can be helpful. Suppose we have the following computation:
|
|
55
|
+
|
|
56
|
+
.. code-block:: python
|
|
57
|
+
|
|
58
|
+
def fun(a, b, d1, d2):
|
|
59
|
+
x = mx.matmul(a, b, stream=d1)
|
|
60
|
+
for _ in range(500):
|
|
61
|
+
b = mx.exp(b, stream=d2)
|
|
62
|
+
return x, b
|
|
63
|
+
|
|
64
|
+
which we want to run with the following arguments:
|
|
65
|
+
|
|
66
|
+
.. code-block:: python
|
|
67
|
+
|
|
68
|
+
a = mx.random.uniform(shape=(4096, 512))
|
|
69
|
+
b = mx.random.uniform(shape=(512, 4))
|
|
70
|
+
|
|
71
|
+
The first ``matmul`` operation is a good fit for the GPU since it's more
|
|
72
|
+
compute dense. The second sequence of operations is a better fit for the CPU,
|
|
73
|
+
since they are very small and would probably be overhead bound on the GPU.
|
|
74
|
+
|
|
75
|
+
If we time the computation fully on the GPU, we get 2.8 milliseconds. But if we
|
|
76
|
+
run the computation with ``d1=mx.gpu`` and ``d2=mx.cpu``, then the time is only
|
|
77
|
+
about 1.4 milliseconds, about twice as fast. These times were measured on an M1
|
|
78
|
+
Max.
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
.. _using_streams:
|
|
2
|
+
|
|
3
|
+
Using Streams
|
|
4
|
+
=============
|
|
5
|
+
|
|
6
|
+
.. currentmodule:: mlx.core
|
|
7
|
+
|
|
8
|
+
Specifying the :obj:`Stream`
|
|
9
|
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
10
|
+
|
|
11
|
+
All operations (including random number generation) take an optional
|
|
12
|
+
keyword argument ``stream``. The ``stream`` kwarg specifies which
|
|
13
|
+
:obj:`Stream` the operation should run on. If the stream is unspecified then
|
|
14
|
+
the operation is run on the default stream of the default device:
|
|
15
|
+
``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also
|
|
16
|
+
be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is
|
|
17
|
+
run on the default stream of the provided device:
|
|
18
|
+
``mx.default_stream(my_device)``.
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.27)
|
|
2
|
+
|
|
3
|
+
project(example LANGUAGES CXX)
|
|
4
|
+
|
|
5
|
+
set(CMAKE_CXX_STANDARD 17)
|
|
6
|
+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|
7
|
+
|
|
8
|
+
# Comment out the following two commands if only the MLX C++ library is installed and
|
|
9
|
+
# set(MLX_ROOT "/path/to/mlx") directly if needed.
|
|
10
|
+
find_package(
|
|
11
|
+
Python 3.9
|
|
12
|
+
COMPONENTS Interpreter Development.Module
|
|
13
|
+
REQUIRED)
|
|
14
|
+
execute_process(
|
|
15
|
+
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
|
16
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
17
|
+
OUTPUT_VARIABLE MLX_ROOT)
|
|
18
|
+
|
|
19
|
+
find_package(MLX CONFIG REQUIRED)
|
|
20
|
+
|
|
21
|
+
add_executable(example example.cpp)
|
|
22
|
+
target_link_libraries(example PRIVATE mlx)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
## Build and Run
|
|
2
|
+
|
|
3
|
+
Install MLX with Python:
|
|
4
|
+
|
|
5
|
+
```bash
|
|
6
|
+
pip install "mlx>=0.22"
|
|
7
|
+
```
|
|
8
|
+
|
|
9
|
+
Build the C++ example:
|
|
10
|
+
|
|
11
|
+
```bash
|
|
12
|
+
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
|
13
|
+
cmake --build build
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
Run the C++ example:
|
|
17
|
+
|
|
18
|
+
```
|
|
19
|
+
./build/example
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
which should output:
|
|
23
|
+
|
|
24
|
+
```
|
|
25
|
+
array([2, 4, 6], dtype=int32)
|
|
26
|
+
```
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
function(build_example SRCFILE)
|
|
2
|
+
get_filename_component(src_name ${SRCFILE} NAME_WE)
|
|
3
|
+
set(target "${src_name}")
|
|
4
|
+
add_executable(${target} ${SRCFILE})
|
|
5
|
+
target_link_libraries(${target} PRIVATE mlx)
|
|
6
|
+
endfunction(build_example)
|
|
7
|
+
|
|
8
|
+
build_example(tutorial.cpp)
|
|
9
|
+
build_example(linear_regression.cpp)
|
|
10
|
+
build_example(logistic_regression.cpp)
|
|
11
|
+
build_example(metal_capture.cpp)
|
|
12
|
+
build_example(distributed.cpp)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <iostream>
|
|
4
|
+
|
|
5
|
+
#include "mlx/mlx.h"
|
|
6
|
+
|
|
7
|
+
namespace mx = mlx::core;
|
|
8
|
+
|
|
9
|
+
int main() {
|
|
10
|
+
if (!mx::distributed::is_available()) {
|
|
11
|
+
std::cout << "No communication backend found" << std::endl;
|
|
12
|
+
return 1;
|
|
13
|
+
}
|
|
14
|
+
|
|
15
|
+
auto global_group = mx::distributed::init();
|
|
16
|
+
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
|
|
17
|
+
|
|
18
|
+
mx::array x = mx::ones({10});
|
|
19
|
+
mx::array out = mx::distributed::all_sum(x, global_group);
|
|
20
|
+
|
|
21
|
+
std::cout << out << std::endl;
|
|
22
|
+
}
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <chrono>
|
|
4
|
+
#include <cmath>
|
|
5
|
+
#include <iostream>
|
|
6
|
+
|
|
7
|
+
#include "mlx/mlx.h"
|
|
8
|
+
#include "timer.h"
|
|
9
|
+
|
|
10
|
+
/**
|
|
11
|
+
* An example of linear regression with MLX.
|
|
12
|
+
*/
|
|
13
|
+
namespace mx = mlx::core;
|
|
14
|
+
|
|
15
|
+
int main() {
|
|
16
|
+
int num_features = 100;
|
|
17
|
+
int num_examples = 1'000;
|
|
18
|
+
int num_iters = 10'000;
|
|
19
|
+
float learning_rate = 0.01;
|
|
20
|
+
|
|
21
|
+
// True parameters
|
|
22
|
+
auto w_star = mx::random::normal({num_features});
|
|
23
|
+
|
|
24
|
+
// The input examples (design matrix)
|
|
25
|
+
auto X = mx::random::normal({num_examples, num_features});
|
|
26
|
+
|
|
27
|
+
// Noisy labels
|
|
28
|
+
auto eps = 1e-2 * mx::random::normal({num_examples});
|
|
29
|
+
auto y = mx::matmul(X, w_star) + eps;
|
|
30
|
+
|
|
31
|
+
// Initialize random parameters
|
|
32
|
+
mx::array w = 1e-2 * mx::random::normal({num_features});
|
|
33
|
+
|
|
34
|
+
auto loss_fn = [&](mx::array w) {
|
|
35
|
+
auto yhat = mx::matmul(X, w);
|
|
36
|
+
return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
|
|
37
|
+
};
|
|
38
|
+
|
|
39
|
+
auto grad_fn = mx::grad(loss_fn);
|
|
40
|
+
|
|
41
|
+
auto tic = timer::time();
|
|
42
|
+
for (int it = 0; it < num_iters; ++it) {
|
|
43
|
+
auto grads = grad_fn(w);
|
|
44
|
+
w = w - learning_rate * grads;
|
|
45
|
+
mx::eval(w);
|
|
46
|
+
}
|
|
47
|
+
auto toc = timer::time();
|
|
48
|
+
|
|
49
|
+
auto loss = loss_fn(w);
|
|
50
|
+
auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
|
|
51
|
+
auto throughput = num_iters / timer::seconds(toc - tic);
|
|
52
|
+
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
|
|
53
|
+
<< ", Throughput " << throughput << " (it/s)." << std::endl;
|
|
54
|
+
}
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <chrono>
|
|
4
|
+
#include <cmath>
|
|
5
|
+
#include <iostream>
|
|
6
|
+
|
|
7
|
+
#include "mlx/mlx.h"
|
|
8
|
+
#include "timer.h"
|
|
9
|
+
|
|
10
|
+
/**
|
|
11
|
+
* An example of logistic regression with MLX.
|
|
12
|
+
*/
|
|
13
|
+
namespace mx = mlx::core;
|
|
14
|
+
|
|
15
|
+
int main() {
|
|
16
|
+
int num_features = 100;
|
|
17
|
+
int num_examples = 1'000;
|
|
18
|
+
int num_iters = 10'000;
|
|
19
|
+
float learning_rate = 0.1;
|
|
20
|
+
|
|
21
|
+
// True parameters
|
|
22
|
+
auto w_star = mx::random::normal({num_features});
|
|
23
|
+
|
|
24
|
+
// The input examples
|
|
25
|
+
auto X = mx::random::normal({num_examples, num_features});
|
|
26
|
+
|
|
27
|
+
// Labels
|
|
28
|
+
auto y = mx::matmul(X, w_star) > 0;
|
|
29
|
+
|
|
30
|
+
// Initialize random parameters
|
|
31
|
+
mx::array w = 1e-2 * mx::random::normal({num_features});
|
|
32
|
+
|
|
33
|
+
auto loss_fn = [&](mx::array w) {
|
|
34
|
+
auto logits = mx::matmul(X, w);
|
|
35
|
+
auto scale = (1.0f / num_examples);
|
|
36
|
+
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
|
|
37
|
+
};
|
|
38
|
+
|
|
39
|
+
auto grad_fn = mx::grad(loss_fn);
|
|
40
|
+
|
|
41
|
+
auto tic = timer::time();
|
|
42
|
+
for (int it = 0; it < num_iters; ++it) {
|
|
43
|
+
auto grads = grad_fn(w);
|
|
44
|
+
w = w - learning_rate * grads;
|
|
45
|
+
mx::eval(w);
|
|
46
|
+
}
|
|
47
|
+
auto toc = timer::time();
|
|
48
|
+
|
|
49
|
+
auto loss = loss_fn(w);
|
|
50
|
+
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
|
|
51
|
+
auto throughput = num_iters / timer::seconds(toc - tic);
|
|
52
|
+
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
|
|
53
|
+
<< throughput << " (it/s)." << std::endl;
|
|
54
|
+
}
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <cassert>
|
|
4
|
+
#include <iostream>
|
|
5
|
+
|
|
6
|
+
#include "mlx/mlx.h"
|
|
7
|
+
|
|
8
|
+
namespace mx = mlx::core;
|
|
9
|
+
|
|
10
|
+
int main() {
|
|
11
|
+
// To use Metal debugging and profiling:
|
|
12
|
+
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
|
13
|
+
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
|
14
|
+
mx::metal::start_capture("mlx_trace.gputrace");
|
|
15
|
+
|
|
16
|
+
// Start at index two because the default GPU and CPU streams have indices
|
|
17
|
+
// zero and one, respectively. This naming matches the label assigned to each
|
|
18
|
+
// stream's command queue.
|
|
19
|
+
auto s2 = new_stream(mx::Device::gpu);
|
|
20
|
+
auto s3 = new_stream(mx::Device::gpu);
|
|
21
|
+
|
|
22
|
+
auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
|
|
23
|
+
auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
|
|
24
|
+
auto x = mx::add(a, a, s2);
|
|
25
|
+
auto y = mx::add(b, b, s3);
|
|
26
|
+
|
|
27
|
+
// The multiply will happen on the default stream.
|
|
28
|
+
std::cout << mx::multiply(x, y) << std::endl;
|
|
29
|
+
|
|
30
|
+
mx::metal::stop_capture();
|
|
31
|
+
}
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <chrono>
|
|
6
|
+
|
|
7
|
+
namespace timer {
|
|
8
|
+
|
|
9
|
+
using namespace std::chrono;
|
|
10
|
+
|
|
11
|
+
template <typename R, typename P>
|
|
12
|
+
inline double seconds(duration<R, P> x) {
|
|
13
|
+
return duration_cast<nanoseconds>(x).count() / 1e9;
|
|
14
|
+
}
|
|
15
|
+
|
|
16
|
+
inline auto time() {
|
|
17
|
+
return high_resolution_clock::now();
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
} // namespace timer
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <cassert>
|
|
4
|
+
#include <iostream>
|
|
5
|
+
|
|
6
|
+
#include "mlx/mlx.h"
|
|
7
|
+
|
|
8
|
+
namespace mx = mlx::core;
|
|
9
|
+
|
|
10
|
+
void array_basics() {
|
|
11
|
+
// Make a scalar array:
|
|
12
|
+
mx::array x(1.0);
|
|
13
|
+
|
|
14
|
+
// Get the value out of it:
|
|
15
|
+
auto s = x.item<float>();
|
|
16
|
+
assert(s == 1.0);
|
|
17
|
+
|
|
18
|
+
// Scalars have a size of 1:
|
|
19
|
+
size_t size = x.size();
|
|
20
|
+
assert(size == 1);
|
|
21
|
+
|
|
22
|
+
// Scalars have 0 dimensions:
|
|
23
|
+
int ndim = x.ndim();
|
|
24
|
+
assert(ndim == 0);
|
|
25
|
+
|
|
26
|
+
// The shape should be an empty vector:
|
|
27
|
+
auto shape = x.shape();
|
|
28
|
+
assert(shape.empty());
|
|
29
|
+
|
|
30
|
+
// The datatype should be float32:
|
|
31
|
+
auto dtype = x.dtype();
|
|
32
|
+
assert(dtype == mx::float32);
|
|
33
|
+
|
|
34
|
+
// Specify the dtype when constructing the array:
|
|
35
|
+
x = mx::array(1, mx::int32);
|
|
36
|
+
assert(x.dtype() == mx::int32);
|
|
37
|
+
x.item<int>(); // OK
|
|
38
|
+
// x.item<float>(); // Undefined!
|
|
39
|
+
|
|
40
|
+
// Make a multidimensional array:
|
|
41
|
+
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
|
42
|
+
// mlx is row-major by default so the first row of this array
|
|
43
|
+
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
|
44
|
+
|
|
45
|
+
// Make an array of shape {2, 2} filled with ones:
|
|
46
|
+
auto y = mx::ones({2, 2});
|
|
47
|
+
|
|
48
|
+
// Pointwise add x and y:
|
|
49
|
+
auto z = mx::add(x, y);
|
|
50
|
+
|
|
51
|
+
// Same thing:
|
|
52
|
+
z = x + y;
|
|
53
|
+
|
|
54
|
+
// mlx is lazy by default. At this point `z` only
|
|
55
|
+
// has a shape and a type but no actual data:
|
|
56
|
+
assert(z.dtype() == mx::float32);
|
|
57
|
+
assert(z.shape(0) == 2);
|
|
58
|
+
assert(z.shape(1) == 2);
|
|
59
|
+
|
|
60
|
+
// To actually run the computation you must evaluate `z`.
|
|
61
|
+
// Under the hood, mlx records operations in a graph.
|
|
62
|
+
// The variable `z` is a node in the graph which points to its operation
|
|
63
|
+
// and inputs. When `eval` is called on an array (or arrays), the array and
|
|
64
|
+
// all of its dependencies are recursively evaluated to produce the result.
|
|
65
|
+
// Once an array is evaluated, it has data and is detached from its inputs.
|
|
66
|
+
mx::eval(z);
|
|
67
|
+
|
|
68
|
+
// Of course the array can still be an input to other operations. You can
|
|
69
|
+
// even call eval on the array again, this will just be a no-op:
|
|
70
|
+
mx::eval(z); // no-op
|
|
71
|
+
|
|
72
|
+
// Some functions or methods on arrays implicitly evaluate them. For example
|
|
73
|
+
// accessing a value in an array or printing the array implicitly evaluates it:
|
|
74
|
+
z = mx::ones({1});
|
|
75
|
+
z.item<float>(); // implicit evaluation
|
|
76
|
+
|
|
77
|
+
z = mx::ones({2, 2});
|
|
78
|
+
std::cout << z << std::endl; // implicit evaluation
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
void automatic_differentiation() {
|
|
82
|
+
auto fn = [](mx::array x) { return mx::square(x); };
|
|
83
|
+
|
|
84
|
+
// Computing the derivative function of a function
|
|
85
|
+
auto grad_fn = mx::grad(fn);
|
|
86
|
+
// Call grad_fn on the input to get the derivative
|
|
87
|
+
auto x = mx::array(1.5);
|
|
88
|
+
auto dfdx = grad_fn(x);
|
|
89
|
+
// dfdx is 2 * x
|
|
90
|
+
|
|
91
|
+
// Get the second derivative by composing grad with grad
|
|
92
|
+
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
|
|
93
|
+
// d2fdx2 is 2
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
int main() {
|
|
97
|
+
array_basics();
|
|
98
|
+
automatic_differentiation();
|
|
99
|
+
}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.27)
|
|
2
|
+
|
|
3
|
+
project(import_mlx LANGUAGES CXX)
|
|
4
|
+
|
|
5
|
+
set(CMAKE_CXX_STANDARD 17)
|
|
6
|
+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|
7
|
+
|
|
8
|
+
find_package(
|
|
9
|
+
Python 3.9
|
|
10
|
+
COMPONENTS Interpreter Development.Module
|
|
11
|
+
REQUIRED)
|
|
12
|
+
execute_process(
|
|
13
|
+
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
|
14
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
15
|
+
OUTPUT_VARIABLE MLX_ROOT)
|
|
16
|
+
find_package(MLX CONFIG REQUIRED)
|
|
17
|
+
|
|
18
|
+
add_executable(eval_mlp eval_mlp.cpp)
|
|
19
|
+
target_link_libraries(eval_mlp PRIVATE mlx)
|
|
20
|
+
|
|
21
|
+
add_executable(train_mlp train_mlp.cpp)
|
|
22
|
+
target_link_libraries(train_mlp PRIVATE mlx)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
## Setup
|
|
2
|
+
|
|
3
|
+
Install MLX:
|
|
4
|
+
|
|
5
|
+
```bash
|
|
6
|
+
pip install "mlx>=0.22"
|
|
7
|
+
```
|
|
8
|
+
|
|
9
|
+
Build the C++ examples:
|
|
10
|
+
|
|
11
|
+
```bash
|
|
12
|
+
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
|
13
|
+
cmake --build build
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
## Run
|
|
17
|
+
|
|
18
|
+
### Eval MLP
|
|
19
|
+
|
|
20
|
+
Run the Python script to export the eval function:
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
python eval_mlp.py
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
Then run the C++ program to import and run the function:
|
|
27
|
+
|
|
28
|
+
```
|
|
29
|
+
./build/eval_mlp
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
The Python and C++ programs should output the same result.
|
|
33
|
+
|
|
34
|
+
### Train MLP
|
|
35
|
+
|
|
36
|
+
Run the Python script to export the model initialization and training
|
|
37
|
+
functions:
|
|
38
|
+
|
|
39
|
+
```bash
|
|
40
|
+
python train_mlp.py
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
Then run the C++ program to import and run the functions:
|
|
44
|
+
|
|
45
|
+
```
|
|
46
|
+
./build/train_mlp
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
The Python and C++ programs should output the same results.
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <mlx/mlx.h>
|
|
4
|
+
#include <iostream>
|
|
5
|
+
|
|
6
|
+
namespace mx = mlx::core;
|
|
7
|
+
|
|
8
|
+
// Builds a seeded random input, imports the function stored in
// "eval_mlp.mlxfn", runs it on that input, and prints the first output.
int main() {
  int batch_size = 8;
  int input_dim = 32;

  // Seed the RNG before drawing the input batch.
  mx::random::seed(42);
  auto input = mx::random::uniform({batch_size, input_dim});

  // Load the exported function from disk.
  auto imported = mx::import_function("eval_mlp.mlxfn");

  // Run it; only the first returned array is printed.
  auto outputs = imported({input});
  auto result = outputs[0];

  std::cout << result << std::endl;

  return 0;
}
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
import mlx.utils
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MLP(nn.Module):
    """Stack of linear layers with a ReLU after every layer except the last.

    The layer widths run ``input_dim -> hidden_dim (x num_layers) -> output_dim``,
    which yields ``num_layers + 1`` ``nn.Linear`` layers.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # Width at every layer boundary, from network input to network output.
        widths = [input_dim, *([hidden_dim] * num_layers), output_dim]
        # Pair each width with its successor to get (in, out) for each layer.
        self.layers = [
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])
        ]

    def __call__(self, x):
        # The final layer is applied without an activation.
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = nn.relu(layer(x))
        return output_layer(x)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
if __name__ == "__main__":
    # Sizes of the example batch and of the model's input/output.
    batch_size = 8
    input_dim = 32
    output_dim = 10

    # Construct the model; the seed fixes the parameter initialization.
    mx.random.seed(0)
    model = MLP(
        num_layers=5,
        input_dim=input_dim,
        hidden_dim=64,
        output_dim=output_dim,
    )
    mx.eval(model)

    # Note, the model parameters are saved in the export function
    def forward(x):
        return model(x)

    # A separate seed fixes the example input.
    mx.random.seed(42)
    example_x = mx.random.uniform(shape=(batch_size, input_dim))

    export_path = "eval_mlp.mlxfn"
    mx.export_function(export_path, forward, example_x)

    # Re-import in Python and check the result against the original function.
    imported_forward = mx.import_function(export_path)
    reference = forward(example_x)
    [out] = imported_forward(example_x)
    assert mx.allclose(reference, out)
    print(out)
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <mlx/mlx.h>
|
|
4
|
+
#include <iostream>
|
|
5
|
+
|
|
6
|
+
namespace mx = mlx::core;
|
|
7
|
+
|
|
8
|
+
int main() {
|
|
9
|
+
int batch_size = 8;
|
|
10
|
+
int input_dim = 32;
|
|
11
|
+
int output_dim = 10;
|
|
12
|
+
|
|
13
|
+
auto state = mx::import_function("init_mlp.mlxfn")({});
|
|
14
|
+
|
|
15
|
+
// Make the input
|
|
16
|
+
mx::random::seed(42);
|
|
17
|
+
auto example_X = mx::random::normal({batch_size, input_dim});
|
|
18
|
+
auto example_y = mx::random::randint(0, output_dim, {batch_size});
|
|
19
|
+
|
|
20
|
+
// Import the function
|
|
21
|
+
auto step = mx::import_function("train_mlp.mlxfn");
|
|
22
|
+
|
|
23
|
+
// Call the imported function
|
|
24
|
+
for (int it = 0; it < 100; ++it) {
|
|
25
|
+
state.insert(state.end(), {example_X, example_y});
|
|
26
|
+
state = step(state);
|
|
27
|
+
eval(state);
|
|
28
|
+
auto loss = state.back();
|
|
29
|
+
state.pop_back();
|
|
30
|
+
if (it % 10 == 0) {
|
|
31
|
+
std::cout << "Loss " << loss.item<float>() << std::endl;
|
|
32
|
+
}
|
|
33
|
+
}
|
|
34
|
+
return 0;
|
|
35
|
+
}
|