mlx 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlx might be problematic. Refer to the registry's page for this release for more details.
- checksums.yaml +7 -0
- data/ext/mlx/CMakeLists.txt +7 -0
- data/ext/mlx/Makefile +273 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/mkmf.log +44 -0
- data/ext/mlx/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
- data/ext/mlx/native.cpp +8027 -0
- data/ext/mlx/native.o +0 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version +1 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/.clang-format +87 -0
- data/mlx/.git +1 -0
- data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
- data/mlx/.github/actions/build-docs/action.yml +38 -0
- data/mlx/.github/actions/build-linux/action.yml +38 -0
- data/mlx/.github/actions/build-linux-release/action.yml +42 -0
- data/mlx/.github/actions/build-macos/action.yml +80 -0
- data/mlx/.github/actions/build-macos-release/action.yml +36 -0
- data/mlx/.github/actions/build-windows/action.yml +26 -0
- data/mlx/.github/actions/setup-linux/action.yml +93 -0
- data/mlx/.github/actions/setup-macos/action.yml +24 -0
- data/mlx/.github/actions/setup-windows/action.yml +42 -0
- data/mlx/.github/actions/test-linux/action.yml +69 -0
- data/mlx/.github/actions/test-windows/action.yml +20 -0
- data/mlx/.github/dependabot.yml +6 -0
- data/mlx/.github/pull_request_template.md +12 -0
- data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
- data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
- data/mlx/.github/workflows/build_and_test.yml +152 -0
- data/mlx/.github/workflows/documentation.yml +28 -0
- data/mlx/.github/workflows/nightly.yml +104 -0
- data/mlx/.github/workflows/release.yml +256 -0
- data/mlx/.gitignore +81 -0
- data/mlx/.pre-commit-config.yaml +27 -0
- data/mlx/ACKNOWLEDGMENTS.md +268 -0
- data/mlx/CITATION.cff +24 -0
- data/mlx/CMakeLists.txt +437 -0
- data/mlx/CODE_OF_CONDUCT.md +132 -0
- data/mlx/CONTRIBUTING.md +38 -0
- data/mlx/LICENSE +21 -0
- data/mlx/MANIFEST.in +6 -0
- data/mlx/README.md +121 -0
- data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- data/mlx/benchmarks/cpp/autograd.cpp +39 -0
- data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- data/mlx/benchmarks/cpp/time_utils.h +39 -0
- data/mlx/benchmarks/numpy/single_ops.py +39 -0
- data/mlx/benchmarks/numpy/time_utils.py +20 -0
- data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
- data/mlx/benchmarks/python/comparative/README.md +15 -0
- data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- data/mlx/benchmarks/python/comparative/compare.py +284 -0
- data/mlx/benchmarks/python/compile_bench.py +107 -0
- data/mlx/benchmarks/python/conv1d_bench.py +123 -0
- data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- data/mlx/benchmarks/python/conv_bench.py +135 -0
- data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- data/mlx/benchmarks/python/distributed_bench.py +66 -0
- data/mlx/benchmarks/python/einsum_bench.py +84 -0
- data/mlx/benchmarks/python/fft_bench.py +118 -0
- data/mlx/benchmarks/python/gather_bench.py +52 -0
- data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- data/mlx/benchmarks/python/hadamard_bench.py +70 -0
- data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
- data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
- data/mlx/benchmarks/python/masked_scatter.py +212 -0
- data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
- data/mlx/benchmarks/python/rope_bench.py +35 -0
- data/mlx/benchmarks/python/scatter_bench.py +96 -0
- data/mlx/benchmarks/python/sdpa_bench.py +223 -0
- data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
- data/mlx/benchmarks/python/single_ops.py +132 -0
- data/mlx/benchmarks/python/synchronize_bench.py +55 -0
- data/mlx/benchmarks/python/time_utils.py +38 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/docs/.clang-format +2 -0
- data/mlx/docs/.gitignore +3 -0
- data/mlx/docs/.nojekyll +0 -0
- data/mlx/docs/Doxyfile +51 -0
- data/mlx/docs/Makefile +18 -0
- data/mlx/docs/README.md +54 -0
- data/mlx/docs/index.html +1 -0
- data/mlx/docs/requirements.txt +5 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
- data/mlx/docs/src/_static/mlx_logo.png +0 -0
- data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
- data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
- data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
- data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
- data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
- data/mlx/docs/src/_templates/module-base-class.rst +33 -0
- data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
- data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
- data/mlx/docs/src/conf.py +99 -0
- data/mlx/docs/src/cpp/ops.rst +7 -0
- data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
- data/mlx/docs/src/dev/extensions.rst +811 -0
- data/mlx/docs/src/dev/metal_debugger.rst +68 -0
- data/mlx/docs/src/dev/metal_logging.rst +40 -0
- data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
- data/mlx/docs/src/examples/data_parallelism.rst +91 -0
- data/mlx/docs/src/examples/linear_regression.rst +77 -0
- data/mlx/docs/src/examples/llama-inference.rst +382 -0
- data/mlx/docs/src/examples/mlp.rst +134 -0
- data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
- data/mlx/docs/src/index.rst +96 -0
- data/mlx/docs/src/install.rst +340 -0
- data/mlx/docs/src/python/array.rst +65 -0
- data/mlx/docs/src/python/cuda.rst +9 -0
- data/mlx/docs/src/python/data_types.rst +78 -0
- data/mlx/docs/src/python/devices_and_streams.rst +21 -0
- data/mlx/docs/src/python/distributed.rst +22 -0
- data/mlx/docs/src/python/export.rst +14 -0
- data/mlx/docs/src/python/fast.rst +16 -0
- data/mlx/docs/src/python/fft.rst +24 -0
- data/mlx/docs/src/python/linalg.rst +27 -0
- data/mlx/docs/src/python/memory_management.rst +16 -0
- data/mlx/docs/src/python/metal.rst +12 -0
- data/mlx/docs/src/python/nn/distributed.rst +30 -0
- data/mlx/docs/src/python/nn/functions.rst +40 -0
- data/mlx/docs/src/python/nn/init.rst +45 -0
- data/mlx/docs/src/python/nn/layers.rst +74 -0
- data/mlx/docs/src/python/nn/losses.rst +25 -0
- data/mlx/docs/src/python/nn/module.rst +38 -0
- data/mlx/docs/src/python/nn.rst +186 -0
- data/mlx/docs/src/python/ops.rst +184 -0
- data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
- data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
- data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
- data/mlx/docs/src/python/optimizers.rst +78 -0
- data/mlx/docs/src/python/random.rst +48 -0
- data/mlx/docs/src/python/transforms.rst +22 -0
- data/mlx/docs/src/python/tree_utils.rst +23 -0
- data/mlx/docs/src/usage/compile.rst +516 -0
- data/mlx/docs/src/usage/distributed.rst +572 -0
- data/mlx/docs/src/usage/export.rst +288 -0
- data/mlx/docs/src/usage/function_transforms.rst +191 -0
- data/mlx/docs/src/usage/indexing.rst +194 -0
- data/mlx/docs/src/usage/launching_distributed.rst +234 -0
- data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
- data/mlx/docs/src/usage/numpy.rst +124 -0
- data/mlx/docs/src/usage/quick_start.rst +67 -0
- data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
- data/mlx/docs/src/usage/unified_memory.rst +78 -0
- data/mlx/docs/src/usage/using_streams.rst +18 -0
- data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
- data/mlx/examples/cmake_project/README.md +26 -0
- data/mlx/examples/cmake_project/example.cpp +14 -0
- data/mlx/examples/cpp/CMakeLists.txt +12 -0
- data/mlx/examples/cpp/distributed.cpp +22 -0
- data/mlx/examples/cpp/linear_regression.cpp +54 -0
- data/mlx/examples/cpp/logistic_regression.cpp +54 -0
- data/mlx/examples/cpp/metal_capture.cpp +31 -0
- data/mlx/examples/cpp/timer.h +20 -0
- data/mlx/examples/cpp/tutorial.cpp +99 -0
- data/mlx/examples/export/CMakeLists.txt +22 -0
- data/mlx/examples/export/README.md +49 -0
- data/mlx/examples/export/eval_mlp.cpp +25 -0
- data/mlx/examples/export/eval_mlp.py +52 -0
- data/mlx/examples/export/train_mlp.cpp +35 -0
- data/mlx/examples/export/train_mlp.py +76 -0
- data/mlx/examples/extensions/CMakeLists.txt +78 -0
- data/mlx/examples/extensions/README.md +24 -0
- data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
- data/mlx/examples/extensions/axpby/axpby.h +90 -0
- data/mlx/examples/extensions/axpby/axpby.metal +47 -0
- data/mlx/examples/extensions/bindings.cpp +39 -0
- data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
- data/mlx/examples/extensions/pyproject.toml +8 -0
- data/mlx/examples/extensions/requirements.txt +4 -0
- data/mlx/examples/extensions/setup.py +18 -0
- data/mlx/examples/extensions/test.py +12 -0
- data/mlx/examples/python/linear_regression.py +46 -0
- data/mlx/examples/python/logistic_regression.py +49 -0
- data/mlx/examples/python/qqmm.py +117 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- data/mlx/pyproject.toml +7 -0
- data/mlx/python/mlx/__main__.py +27 -0
- data/mlx/python/mlx/_distributed_utils/common.py +135 -0
- data/mlx/python/mlx/_distributed_utils/config.py +631 -0
- data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
- data/mlx/python/mlx/_reprlib_fix.py +16 -0
- data/mlx/python/mlx/_stub_patterns.txt +36 -0
- data/mlx/python/mlx/extension.py +88 -0
- data/mlx/python/mlx/nn/__init__.py +5 -0
- data/mlx/python/mlx/nn/init.py +441 -0
- data/mlx/python/mlx/nn/layers/__init__.py +105 -0
- data/mlx/python/mlx/nn/layers/activations.py +661 -0
- data/mlx/python/mlx/nn/layers/base.py +675 -0
- data/mlx/python/mlx/nn/layers/containers.py +24 -0
- data/mlx/python/mlx/nn/layers/convolution.py +232 -0
- data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
- data/mlx/python/mlx/nn/layers/distributed.py +601 -0
- data/mlx/python/mlx/nn/layers/dropout.py +137 -0
- data/mlx/python/mlx/nn/layers/embedding.py +53 -0
- data/mlx/python/mlx/nn/layers/linear.py +180 -0
- data/mlx/python/mlx/nn/layers/normalization.py +363 -0
- data/mlx/python/mlx/nn/layers/pooling.py +398 -0
- data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
- data/mlx/python/mlx/nn/layers/quantized.py +426 -0
- data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
- data/mlx/python/mlx/nn/layers/transformer.py +354 -0
- data/mlx/python/mlx/nn/layers/upsample.py +277 -0
- data/mlx/python/mlx/nn/losses.py +610 -0
- data/mlx/python/mlx/nn/utils.py +165 -0
- data/mlx/python/mlx/optimizers/__init__.py +4 -0
- data/mlx/python/mlx/optimizers/optimizers.py +976 -0
- data/mlx/python/mlx/optimizers/schedulers.py +158 -0
- data/mlx/python/mlx/py.typed +1 -0
- data/mlx/python/mlx/utils.py +325 -0
- data/mlx/python/src/CMakeLists.txt +96 -0
- data/mlx/python/src/array.cpp +1525 -0
- data/mlx/python/src/buffer.h +124 -0
- data/mlx/python/src/constants.cpp +15 -0
- data/mlx/python/src/convert.cpp +504 -0
- data/mlx/python/src/convert.h +50 -0
- data/mlx/python/src/cuda.cpp +19 -0
- data/mlx/python/src/device.cpp +98 -0
- data/mlx/python/src/distributed.cpp +352 -0
- data/mlx/python/src/export.cpp +356 -0
- data/mlx/python/src/fast.cpp +627 -0
- data/mlx/python/src/fft.cpp +514 -0
- data/mlx/python/src/indexing.cpp +1016 -0
- data/mlx/python/src/indexing.h +41 -0
- data/mlx/python/src/linalg.cpp +663 -0
- data/mlx/python/src/load.cpp +531 -0
- data/mlx/python/src/load.h +51 -0
- data/mlx/python/src/memory.cpp +125 -0
- data/mlx/python/src/metal.cpp +98 -0
- data/mlx/python/src/mlx.cpp +51 -0
- data/mlx/python/src/mlx_func.cpp +116 -0
- data/mlx/python/src/mlx_func.h +31 -0
- data/mlx/python/src/ops.cpp +5545 -0
- data/mlx/python/src/random.cpp +516 -0
- data/mlx/python/src/small_vector.h +76 -0
- data/mlx/python/src/stream.cpp +147 -0
- data/mlx/python/src/transforms.cpp +1542 -0
- data/mlx/python/src/trees.cpp +311 -0
- data/mlx/python/src/trees.h +62 -0
- data/mlx/python/src/utils.cpp +98 -0
- data/mlx/python/src/utils.h +78 -0
- data/mlx/python/tests/__main__.py +5 -0
- data/mlx/python/tests/cuda_skip.py +62 -0
- data/mlx/python/tests/mlx_distributed_tests.py +314 -0
- data/mlx/python/tests/mlx_tests.py +116 -0
- data/mlx/python/tests/mpi_test_distributed.py +142 -0
- data/mlx/python/tests/nccl_test_distributed.py +52 -0
- data/mlx/python/tests/ring_test_distributed.py +131 -0
- data/mlx/python/tests/test_array.py +2139 -0
- data/mlx/python/tests/test_autograd.py +880 -0
- data/mlx/python/tests/test_bf16.py +196 -0
- data/mlx/python/tests/test_blas.py +1429 -0
- data/mlx/python/tests/test_compile.py +1277 -0
- data/mlx/python/tests/test_constants.py +41 -0
- data/mlx/python/tests/test_conv.py +1198 -0
- data/mlx/python/tests/test_conv_transpose.py +810 -0
- data/mlx/python/tests/test_device.py +150 -0
- data/mlx/python/tests/test_double.py +306 -0
- data/mlx/python/tests/test_einsum.py +363 -0
- data/mlx/python/tests/test_eval.py +200 -0
- data/mlx/python/tests/test_export_import.py +614 -0
- data/mlx/python/tests/test_fast.py +923 -0
- data/mlx/python/tests/test_fast_sdpa.py +647 -0
- data/mlx/python/tests/test_fft.py +323 -0
- data/mlx/python/tests/test_graph.py +37 -0
- data/mlx/python/tests/test_init.py +139 -0
- data/mlx/python/tests/test_linalg.py +621 -0
- data/mlx/python/tests/test_load.py +447 -0
- data/mlx/python/tests/test_losses.py +427 -0
- data/mlx/python/tests/test_memory.py +77 -0
- data/mlx/python/tests/test_nn.py +1986 -0
- data/mlx/python/tests/test_ops.py +3261 -0
- data/mlx/python/tests/test_optimizers.py +584 -0
- data/mlx/python/tests/test_quantized.py +1160 -0
- data/mlx/python/tests/test_random.py +392 -0
- data/mlx/python/tests/test_reduce.py +223 -0
- data/mlx/python/tests/test_tree.py +96 -0
- data/mlx/python/tests/test_upsample.py +100 -0
- data/mlx/python/tests/test_vmap.py +860 -0
- data/mlx/setup.py +315 -0
- data/mlx/tests/CMakeLists.txt +44 -0
- data/mlx/tests/allocator_tests.cpp +41 -0
- data/mlx/tests/arg_reduce_tests.cpp +204 -0
- data/mlx/tests/array_tests.cpp +663 -0
- data/mlx/tests/autograd_tests.cpp +1399 -0
- data/mlx/tests/blas_tests.cpp +110 -0
- data/mlx/tests/compile_tests.cpp +818 -0
- data/mlx/tests/creations_tests.cpp +239 -0
- data/mlx/tests/custom_vjp_tests.cpp +55 -0
- data/mlx/tests/device_tests.cpp +35 -0
- data/mlx/tests/einsum_tests.cpp +85 -0
- data/mlx/tests/eval_tests.cpp +93 -0
- data/mlx/tests/export_import_tests.cpp +164 -0
- data/mlx/tests/fft_tests.cpp +366 -0
- data/mlx/tests/gpu_tests.cpp +523 -0
- data/mlx/tests/linalg_tests.cpp +639 -0
- data/mlx/tests/load_tests.cpp +270 -0
- data/mlx/tests/ops_tests.cpp +4159 -0
- data/mlx/tests/random_tests.cpp +716 -0
- data/mlx/tests/scheduler_tests.cpp +121 -0
- data/mlx/tests/tests.cpp +26 -0
- data/mlx/tests/utils_tests.cpp +67 -0
- data/mlx/tests/vmap_tests.cpp +547 -0
- metadata +958 -0
data/mlx/CMakeLists.txt
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.25)
|
|
2
|
+
|
|
3
|
+
if(NOT MLX_VERSION)
|
|
4
|
+
file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
|
|
5
|
+
string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
|
|
6
|
+
set(_major ${CMAKE_MATCH_1})
|
|
7
|
+
string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
|
|
8
|
+
set(_minor ${CMAKE_MATCH_1})
|
|
9
|
+
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
|
|
10
|
+
set(_patch ${CMAKE_MATCH_1})
|
|
11
|
+
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
|
|
12
|
+
set(MLX_VERSION ${MLX_PROJECT_VERSION})
|
|
13
|
+
else()
|
|
14
|
+
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
|
|
15
|
+
${MLX_VERSION})
|
|
16
|
+
endif()
|
|
17
|
+
|
|
18
|
+
project(
|
|
19
|
+
mlx
|
|
20
|
+
LANGUAGES C CXX
|
|
21
|
+
VERSION ${MLX_PROJECT_VERSION})
|
|
22
|
+
|
|
23
|
+
# ----------------------------- Setup -----------------------------
|
|
24
|
+
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
|
25
|
+
set(CMAKE_CXX_STANDARD 20)
|
|
26
|
+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|
27
|
+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
|
28
|
+
set(CMAKE_INSTALL_MESSAGE NEVER)
|
|
29
|
+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|
30
|
+
|
|
31
|
+
# ----------------------------- Configuration -----------------------------
|
|
32
|
+
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
|
|
33
|
+
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
|
34
|
+
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
|
35
|
+
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
|
36
|
+
option(MLX_BUILD_METAL "Build metal backend" ON)
|
|
37
|
+
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
|
38
|
+
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
|
39
|
+
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
|
40
|
+
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
|
41
|
+
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
|
42
|
+
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
|
43
|
+
option(MLX_BUILD_PYTHON_STUBS "Build stub files for python bindings" ON)
|
|
44
|
+
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
|
45
|
+
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
|
|
46
|
+
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
|
47
|
+
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
|
48
|
+
option(USE_ASAN "Enable AddressSanitizer (ASan)" OFF)
|
|
49
|
+
option(USE_UBSAN "Enable UndefinedBehaviorSanitizer (UBSan)" OFF)
|
|
50
|
+
option(USE_TSAN "Enable ThreadSanitizer (TSan)" OFF)
|
|
51
|
+
|
|
52
|
+
# --------------------- Processor tests -------------------------
|
|
53
|
+
message(
|
|
54
|
+
STATUS
|
|
55
|
+
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
|
59
|
+
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
|
60
|
+
if(NOT MLX_ENABLE_X64_MAC)
|
|
61
|
+
message(
|
|
62
|
+
FATAL_ERROR
|
|
63
|
+
"Building for x86_64 on macOS is not supported."
|
|
64
|
+
" If you are on an Apple silicon system, check the build"
|
|
65
|
+
" documentation for possible fixes: "
|
|
66
|
+
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
|
|
67
|
+
)
|
|
68
|
+
else()
|
|
69
|
+
set(MLX_BUILD_METAL OFF)
|
|
70
|
+
message(WARNING "Building for x86_64 arch is not officially supported.")
|
|
71
|
+
endif()
|
|
72
|
+
endif()
|
|
73
|
+
else()
|
|
74
|
+
set(MLX_BUILD_METAL OFF)
|
|
75
|
+
endif()
|
|
76
|
+
|
|
77
|
+
if(MLX_USE_CCACHE)
|
|
78
|
+
find_program(CCACHE_PROGRAM ccache)
|
|
79
|
+
if(CCACHE_PROGRAM)
|
|
80
|
+
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
|
|
81
|
+
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
|
82
|
+
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
|
83
|
+
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
|
84
|
+
endif()
|
|
85
|
+
endif()
|
|
86
|
+
|
|
87
|
+
if(USE_ASAN AND USE_TSAN)
|
|
88
|
+
message(
|
|
89
|
+
FATAL_ERROR
|
|
90
|
+
"AddressSanitizer (ASan) and ThreadSanitizer (TSan) are mutually exclusive and cannot be enabled at the same time."
|
|
91
|
+
)
|
|
92
|
+
endif()
|
|
93
|
+
|
|
94
|
+
set(SANITIZER_COMPILE_FLAGS "")
|
|
95
|
+
set(SANITIZER_LINK_FLAGS "")
|
|
96
|
+
|
|
97
|
+
if(USE_ASAN)
|
|
98
|
+
if(WIN32 AND MSVC)
|
|
99
|
+
list(APPEND SANITIZER_COMPILE_FLAGS /fsanitize=address)
|
|
100
|
+
list(APPEND SANITIZER_LINK_FLAGS /fsanitize=address)
|
|
101
|
+
else()
|
|
102
|
+
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=address)
|
|
103
|
+
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=address)
|
|
104
|
+
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
|
105
|
+
list(APPEND SANITIZER_LINK_FLAGS -lpthread)
|
|
106
|
+
endif()
|
|
107
|
+
endif()
|
|
108
|
+
endif()
|
|
109
|
+
|
|
110
|
+
if(USE_UBSAN)
|
|
111
|
+
if(WIN32 AND MSVC)
|
|
112
|
+
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
|
113
|
+
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)
|
|
114
|
+
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)
|
|
115
|
+
else()
|
|
116
|
+
message(
|
|
117
|
+
WARNING
|
|
118
|
+
"UndefinedBehaviorSanitizer (UBSan) is not directly supported via a simple flag in MSVC."
|
|
119
|
+
)
|
|
120
|
+
endif()
|
|
121
|
+
else()
|
|
122
|
+
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)
|
|
123
|
+
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)
|
|
124
|
+
endif()
|
|
125
|
+
endif()
|
|
126
|
+
|
|
127
|
+
if(USE_TSAN)
|
|
128
|
+
if(WIN32 AND MSVC)
|
|
129
|
+
message(
|
|
130
|
+
FATAL_ERROR
|
|
131
|
+
"ThreadSanitizer (TSan) is not supported by the MSVC compiler. Please use Clang or GCC."
|
|
132
|
+
)
|
|
133
|
+
elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
|
134
|
+
message(FATAL_ERROR "ThreadSanitizer (TSan) is not supported on macOS.")
|
|
135
|
+
else()
|
|
136
|
+
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=thread)
|
|
137
|
+
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=thread)
|
|
138
|
+
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
|
139
|
+
list(APPEND SANITIZER_LINK_FLAGS -lpthread)
|
|
140
|
+
endif()
|
|
141
|
+
endif()
|
|
142
|
+
endif()
|
|
143
|
+
|
|
144
|
+
# ----------------------------- Lib -----------------------------
|
|
145
|
+
|
|
146
|
+
include(FetchContent)
|
|
147
|
+
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
|
|
148
|
+
cmake_policy(SET CMP0135 NEW)
|
|
149
|
+
|
|
150
|
+
add_library(mlx)
|
|
151
|
+
|
|
152
|
+
target_compile_options(mlx PUBLIC ${SANITIZER_COMPILE_FLAGS})
|
|
153
|
+
target_link_options(mlx PUBLIC ${SANITIZER_LINK_FLAGS})
|
|
154
|
+
|
|
155
|
+
if(MLX_BUILD_CUDA)
|
|
156
|
+
enable_language(CUDA)
|
|
157
|
+
find_package(CUDAToolkit REQUIRED)
|
|
158
|
+
find_package(CUDNN REQUIRED)
|
|
159
|
+
endif()
|
|
160
|
+
|
|
161
|
+
if(MLX_BUILD_METAL)
|
|
162
|
+
find_library(METAL_LIB Metal)
|
|
163
|
+
find_library(FOUNDATION_LIB Foundation)
|
|
164
|
+
find_library(QUARTZ_LIB QuartzCore)
|
|
165
|
+
if(METAL_LIB)
|
|
166
|
+
message(STATUS "Metal found ${METAL_LIB}")
|
|
167
|
+
else()
|
|
168
|
+
message(
|
|
169
|
+
FATAL_ERROR
|
|
170
|
+
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
|
|
171
|
+
endif()
|
|
172
|
+
|
|
173
|
+
if(MLX_METAL_DEBUG)
|
|
174
|
+
add_compile_definitions(MLX_METAL_DEBUG)
|
|
175
|
+
endif()
|
|
176
|
+
|
|
177
|
+
# Throw an error if xcrun not found
|
|
178
|
+
execute_process(
|
|
179
|
+
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
|
180
|
+
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
|
181
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
|
182
|
+
|
|
183
|
+
if(${MACOS_SDK_VERSION} LESS 14.0)
|
|
184
|
+
message(
|
|
185
|
+
FATAL_ERROR
|
|
186
|
+
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
|
|
187
|
+
endif()
|
|
188
|
+
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
|
189
|
+
|
|
190
|
+
set(METAL_CPP_URL
|
|
191
|
+
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
|
|
192
|
+
|
|
193
|
+
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
|
194
|
+
if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
|
|
195
|
+
message(FATAL_ERROR "MLX requires macOS >= 14.0")
|
|
196
|
+
endif()
|
|
197
|
+
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
|
198
|
+
endif()
|
|
199
|
+
execute_process(
|
|
200
|
+
COMMAND
|
|
201
|
+
zsh "-c"
|
|
202
|
+
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
|
203
|
+
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
|
204
|
+
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
|
205
|
+
FetchContent_MakeAvailable(metal_cpp)
|
|
206
|
+
target_include_directories(
|
|
207
|
+
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
|
208
|
+
$<INSTALL_INTERFACE:include/metal_cpp>)
|
|
209
|
+
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
|
210
|
+
endif()
|
|
211
|
+
|
|
212
|
+
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
|
213
|
+
# With newer clang/gcc versions following libs are implicitly linked, but when
|
|
214
|
+
# building on old distributions they need to be explicitly listed.
|
|
215
|
+
target_link_libraries(mlx PRIVATE dl pthread)
|
|
216
|
+
endif()
|
|
217
|
+
|
|
218
|
+
if(WIN32)
|
|
219
|
+
if(MSVC)
|
|
220
|
+
# GGUF does not build with MSVC.
|
|
221
|
+
set(MLX_BUILD_GGUF OFF)
|
|
222
|
+
endif()
|
|
223
|
+
# Generate DLL and EXE in the same dir, otherwise EXE will not be able to run.
|
|
224
|
+
# This is only done when MLX is built as the top project.
|
|
225
|
+
if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
|
|
226
|
+
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
|
227
|
+
endif()
|
|
228
|
+
# Windows implementation of dlfcn.h APIs.
|
|
229
|
+
FetchContent_Declare(
|
|
230
|
+
dlfcn-win32
|
|
231
|
+
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
|
|
232
|
+
GIT_TAG v1.4.2
|
|
233
|
+
EXCLUDE_FROM_ALL)
|
|
234
|
+
block()
|
|
235
|
+
set(BUILD_SHARED_LIBS OFF)
|
|
236
|
+
FetchContent_MakeAvailable(dlfcn-win32)
|
|
237
|
+
endblock()
|
|
238
|
+
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
|
|
239
|
+
target_link_libraries(mlx PRIVATE dl)
|
|
240
|
+
endif()
|
|
241
|
+
|
|
242
|
+
if(MLX_BUILD_CPU)
|
|
243
|
+
find_library(ACCELERATE_LIBRARY Accelerate)
|
|
244
|
+
if(ACCELERATE_LIBRARY)
|
|
245
|
+
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
|
246
|
+
set(MLX_BUILD_ACCELERATE ON)
|
|
247
|
+
else()
|
|
248
|
+
message(STATUS "Accelerate not found, using default backend.")
|
|
249
|
+
set(MLX_BUILD_ACCELERATE OFF)
|
|
250
|
+
endif()
|
|
251
|
+
|
|
252
|
+
if(MLX_BUILD_ACCELERATE)
|
|
253
|
+
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
|
254
|
+
add_compile_definitions(MLX_USE_ACCELERATE)
|
|
255
|
+
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
|
256
|
+
elseif(WIN32)
|
|
257
|
+
# Download and link prebuilt binaries of OpenBLAS. Note that we can only
|
|
258
|
+
# link with the dynamic library, the prebuilt binaries were built with MinGW
|
|
259
|
+
# so static-linking would require linking with MinGW's runtime.
|
|
260
|
+
FetchContent_Declare(
|
|
261
|
+
openblas
|
|
262
|
+
URL "https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.31/OpenBLAS-0.3.31-x64.zip"
|
|
263
|
+
)
|
|
264
|
+
FetchContent_MakeAvailable(openblas)
|
|
265
|
+
target_link_libraries(mlx
|
|
266
|
+
PRIVATE "${openblas_SOURCE_DIR}/lib/libopenblas.lib")
|
|
267
|
+
target_include_directories(mlx PRIVATE "${openblas_SOURCE_DIR}/include")
|
|
268
|
+
# Make sure the DLL file is placed in the same dir with executables.
|
|
269
|
+
set(OPENBLAS_DLL_FILE "${openblas_SOURCE_DIR}/bin/libopenblas.dll")
|
|
270
|
+
add_custom_command(
|
|
271
|
+
TARGET mlx
|
|
272
|
+
POST_BUILD
|
|
273
|
+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${OPENBLAS_DLL_FILE}
|
|
274
|
+
${CMAKE_BINARY_DIR})
|
|
275
|
+
else()
|
|
276
|
+
if(${CMAKE_HOST_APPLE})
|
|
277
|
+
# The blas shipped in macOS SDK is not supported, search homebrew for
|
|
278
|
+
# openblas instead.
|
|
279
|
+
set(BLA_VENDOR OpenBLAS)
|
|
280
|
+
set(LAPACK_ROOT
|
|
281
|
+
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
|
282
|
+
endif()
|
|
283
|
+
# Search and link with lapack.
|
|
284
|
+
find_package(LAPACK REQUIRED)
|
|
285
|
+
if(NOT LAPACK_FOUND)
|
|
286
|
+
message(FATAL_ERROR "Must have LAPACK installed")
|
|
287
|
+
endif()
|
|
288
|
+
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
|
|
289
|
+
/usr/local/opt/openblas/include)
|
|
290
|
+
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
|
291
|
+
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
|
292
|
+
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
|
293
|
+
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
|
|
294
|
+
# List blas after lapack otherwise we may accidentally incldue an old
|
|
295
|
+
# version of lapack.h from the include dirs of blas.
|
|
296
|
+
find_package(BLAS REQUIRED)
|
|
297
|
+
if(NOT BLAS_FOUND)
|
|
298
|
+
message(FATAL_ERROR "Must have BLAS installed")
|
|
299
|
+
endif()
|
|
300
|
+
# TODO find a cleaner way to do this
|
|
301
|
+
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
|
|
302
|
+
$ENV{BLAS_HOME}/include)
|
|
303
|
+
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
|
304
|
+
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
|
305
|
+
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
|
306
|
+
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
|
|
307
|
+
endif()
|
|
308
|
+
else()
|
|
309
|
+
set(MLX_BUILD_ACCELERATE OFF)
|
|
310
|
+
endif()
|
|
311
|
+
|
|
312
|
+
message(STATUS "Downloading json")
|
|
313
|
+
FetchContent_Declare(
|
|
314
|
+
json
|
|
315
|
+
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
|
316
|
+
FetchContent_MakeAvailable(json)
|
|
317
|
+
target_include_directories(
|
|
318
|
+
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
|
|
319
|
+
|
|
320
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
|
321
|
+
|
|
322
|
+
target_include_directories(
|
|
323
|
+
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
|
324
|
+
$<INSTALL_INTERFACE:include>)
|
|
325
|
+
|
|
326
|
+
if(USE_SYSTEM_FMT)
|
|
327
|
+
find_package(fmt REQUIRED)
|
|
328
|
+
else()
|
|
329
|
+
FetchContent_Declare(
|
|
330
|
+
fmt
|
|
331
|
+
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
|
332
|
+
GIT_TAG 12.1.0
|
|
333
|
+
EXCLUDE_FROM_ALL)
|
|
334
|
+
FetchContent_MakeAvailable(fmt)
|
|
335
|
+
endif()
|
|
336
|
+
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
|
337
|
+
|
|
338
|
+
if(MLX_BUILD_PYTHON_BINDINGS)
|
|
339
|
+
message(STATUS "Building Python bindings.")
|
|
340
|
+
find_package(
|
|
341
|
+
Python 3.10
|
|
342
|
+
COMPONENTS Interpreter Development.Module
|
|
343
|
+
REQUIRED)
|
|
344
|
+
FetchContent_Declare(
|
|
345
|
+
nanobind
|
|
346
|
+
GIT_REPOSITORY https://github.com/wjakob/nanobind.git
|
|
347
|
+
GIT_TAG v2.10.2
|
|
348
|
+
GIT_SHALLOW TRUE
|
|
349
|
+
EXCLUDE_FROM_ALL)
|
|
350
|
+
FetchContent_MakeAvailable(nanobind)
|
|
351
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
|
352
|
+
endif()
|
|
353
|
+
|
|
354
|
+
if(MLX_BUILD_TESTS)
|
|
355
|
+
include(CTest)
|
|
356
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
|
|
357
|
+
endif()
|
|
358
|
+
|
|
359
|
+
if(MLX_BUILD_EXAMPLES)
|
|
360
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
|
|
361
|
+
endif()
|
|
362
|
+
|
|
363
|
+
if(MLX_BUILD_BENCHMARKS)
|
|
364
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
|
365
|
+
endif()
|
|
366
|
+
|
|
367
|
+
# ----------------------------- Installation -----------------------------
|
|
368
|
+
include(GNUInstallDirs)
|
|
369
|
+
|
|
370
|
+
if(WIN32)
|
|
371
|
+
# Install DLLs to the same dir with extension file (core.pyd) on Windows.
|
|
372
|
+
set(CMAKE_INSTALL_BINDIR ".")
|
|
373
|
+
if(MLX_BUILD_CPU)
|
|
374
|
+
# Install OpenBLAS.
|
|
375
|
+
install(FILES ${OPENBLAS_DLL_FILE} TYPE BIN)
|
|
376
|
+
endif()
|
|
377
|
+
endif()
|
|
378
|
+
|
|
379
|
+
# Install library
|
|
380
|
+
install(
|
|
381
|
+
TARGETS mlx
|
|
382
|
+
EXPORT MLXTargets
|
|
383
|
+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
|
384
|
+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
|
385
|
+
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
|
386
|
+
INCLUDES
|
|
387
|
+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
|
388
|
+
|
|
389
|
+
# Install headers
|
|
390
|
+
install(
|
|
391
|
+
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
|
392
|
+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
|
393
|
+
COMPONENT headers
|
|
394
|
+
FILES_MATCHING
|
|
395
|
+
PATTERN "*.h"
|
|
396
|
+
PATTERN "backend/metal/kernels.h" EXCLUDE)
|
|
397
|
+
|
|
398
|
+
# Install metal dependencies
|
|
399
|
+
if(MLX_BUILD_METAL)
|
|
400
|
+
|
|
401
|
+
# Install metal cpp
|
|
402
|
+
install(
|
|
403
|
+
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
|
404
|
+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
|
405
|
+
COMPONENT metal_cpp_source)
|
|
406
|
+
|
|
407
|
+
endif()
|
|
408
|
+
|
|
409
|
+
# Install cmake config
|
|
410
|
+
set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
|
|
411
|
+
set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
|
|
412
|
+
set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
|
|
413
|
+
|
|
414
|
+
install(
|
|
415
|
+
EXPORT MLXTargets
|
|
416
|
+
FILE MLXTargets.cmake
|
|
417
|
+
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
|
418
|
+
|
|
419
|
+
include(CMakePackageConfigHelpers)
|
|
420
|
+
|
|
421
|
+
write_basic_package_version_file(
|
|
422
|
+
${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
|
423
|
+
COMPATIBILITY SameMajorVersion
|
|
424
|
+
VERSION ${MLX_VERSION})
|
|
425
|
+
|
|
426
|
+
configure_package_config_file(
|
|
427
|
+
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
|
|
428
|
+
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
|
429
|
+
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
|
430
|
+
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
|
|
431
|
+
MLX_CMAKE_INSTALL_MODULE_DIR)
|
|
432
|
+
|
|
433
|
+
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
|
434
|
+
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
|
435
|
+
|
|
436
|
+
install(DIRECTORY ${CMAKE_MODULE_PATH}/
|
|
437
|
+
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# Contributor Covenant Code of Conduct
|
|
2
|
+
|
|
3
|
+
## Our Pledge
|
|
4
|
+
|
|
5
|
+
We as members, contributors, and leaders pledge to make participation in our
|
|
6
|
+
community a harassment-free experience for everyone, regardless of age, body
|
|
7
|
+
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
|
8
|
+
identity and expression, level of experience, education, socio-economic status,
|
|
9
|
+
nationality, personal appearance, race, caste, color, religion, or sexual
|
|
10
|
+
identity and orientation.
|
|
11
|
+
|
|
12
|
+
We pledge to act and interact in ways that contribute to an open, welcoming,
|
|
13
|
+
diverse, inclusive, and healthy community.
|
|
14
|
+
|
|
15
|
+
## Our Standards
|
|
16
|
+
|
|
17
|
+
Examples of behavior that contributes to a positive environment for our
|
|
18
|
+
community include:
|
|
19
|
+
|
|
20
|
+
* Demonstrating empathy and kindness toward other people
|
|
21
|
+
* Being respectful of differing opinions, viewpoints, and experiences
|
|
22
|
+
* Giving and gracefully accepting constructive feedback
|
|
23
|
+
* Accepting responsibility and apologizing to those affected by our mistakes,
|
|
24
|
+
and learning from the experience
|
|
25
|
+
* Focusing on what is best not just for us as individuals, but for the overall
|
|
26
|
+
community
|
|
27
|
+
|
|
28
|
+
Examples of unacceptable behavior include:
|
|
29
|
+
|
|
30
|
+
* The use of sexualized language or imagery, and sexual attention or advances of
|
|
31
|
+
any kind
|
|
32
|
+
* Trolling, insulting or derogatory comments, and personal or political attacks
|
|
33
|
+
* Public or private harassment
|
|
34
|
+
* Publishing others' private information, such as a physical or email address,
|
|
35
|
+
without their explicit permission
|
|
36
|
+
* Other conduct which could reasonably be considered inappropriate in a
|
|
37
|
+
professional setting
|
|
38
|
+
|
|
39
|
+
## Enforcement Responsibilities
|
|
40
|
+
|
|
41
|
+
Community leaders are responsible for clarifying and enforcing our standards of
|
|
42
|
+
acceptable behavior and will take appropriate and fair corrective action in
|
|
43
|
+
response to any behavior that they deem inappropriate, threatening, offensive,
|
|
44
|
+
or harmful.
|
|
45
|
+
|
|
46
|
+
Community leaders have the right and responsibility to remove, edit, or reject
|
|
47
|
+
comments, commits, code, wiki edits, issues, and other contributions that are
|
|
48
|
+
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
|
49
|
+
decisions when appropriate.
|
|
50
|
+
|
|
51
|
+
## Scope
|
|
52
|
+
|
|
53
|
+
This Code of Conduct applies within all community spaces, and also applies when
|
|
54
|
+
an individual is officially representing the community in public spaces.
|
|
55
|
+
Examples of representing our community include using an official e-mail address,
|
|
56
|
+
posting via an official social media account, or acting as an appointed
|
|
57
|
+
representative at an online or offline event.
|
|
58
|
+
|
|
59
|
+
## Enforcement
|
|
60
|
+
|
|
61
|
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
|
62
|
+
reported to the community leaders responsible for enforcement at
|
|
63
|
+
[opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com).
|
|
64
|
+
All complaints will be reviewed and investigated promptly and fairly.
|
|
65
|
+
|
|
66
|
+
All community leaders are obligated to respect the privacy and security of the
|
|
67
|
+
reporter of any incident.
|
|
68
|
+
|
|
69
|
+
## Enforcement Guidelines
|
|
70
|
+
|
|
71
|
+
Community leaders will follow these Community Impact Guidelines in determining
|
|
72
|
+
the consequences for any action they deem in violation of this Code of Conduct:
|
|
73
|
+
|
|
74
|
+
### 1. Correction
|
|
75
|
+
|
|
76
|
+
**Community Impact**: Use of inappropriate language or other behavior deemed
|
|
77
|
+
unprofessional or unwelcome in the community.
|
|
78
|
+
|
|
79
|
+
**Consequence**: A private, written warning from community leaders, providing
|
|
80
|
+
clarity around the nature of the violation and an explanation of why the
|
|
81
|
+
behavior was inappropriate. A public apology may be requested.
|
|
82
|
+
|
|
83
|
+
### 2. Warning
|
|
84
|
+
|
|
85
|
+
**Community Impact**: A violation through a single incident or series of
|
|
86
|
+
actions.
|
|
87
|
+
|
|
88
|
+
**Consequence**: A warning with consequences for continued behavior. No
|
|
89
|
+
interaction with the people involved, including unsolicited interaction with
|
|
90
|
+
those enforcing the Code of Conduct, for a specified period of time. This
|
|
91
|
+
includes avoiding interactions in community spaces as well as external channels
|
|
92
|
+
like social media. Violating these terms may lead to a temporary or permanent
|
|
93
|
+
ban.
|
|
94
|
+
|
|
95
|
+
### 3. Temporary Ban
|
|
96
|
+
|
|
97
|
+
**Community Impact**: A serious violation of community standards, including
|
|
98
|
+
sustained inappropriate behavior.
|
|
99
|
+
|
|
100
|
+
**Consequence**: A temporary ban from any sort of interaction or public
|
|
101
|
+
communication with the community for a specified period of time. No public or
|
|
102
|
+
private interaction with the people involved, including unsolicited interaction
|
|
103
|
+
with those enforcing the Code of Conduct, is allowed during this period.
|
|
104
|
+
Violating these terms may lead to a permanent ban.
|
|
105
|
+
|
|
106
|
+
### 4. Permanent Ban
|
|
107
|
+
|
|
108
|
+
**Community Impact**: Demonstrating a pattern of violation of community
|
|
109
|
+
standards, including sustained inappropriate behavior, harassment of an
|
|
110
|
+
individual, or aggression toward or disparagement of classes of individuals.
|
|
111
|
+
|
|
112
|
+
**Consequence**: A permanent ban from any sort of public interaction within the
|
|
113
|
+
community.
|
|
114
|
+
|
|
115
|
+
## Attribution
|
|
116
|
+
|
|
117
|
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
|
118
|
+
version 2.1, available at
|
|
119
|
+
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
|
|
120
|
+
|
|
121
|
+
Community Impact Guidelines were inspired by
|
|
122
|
+
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
|
|
123
|
+
|
|
124
|
+
For answers to common questions about this code of conduct, see the FAQ at
|
|
125
|
+
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
|
|
126
|
+
[https://www.contributor-covenant.org/translations][translations].
|
|
127
|
+
|
|
128
|
+
[homepage]: https://www.contributor-covenant.org
|
|
129
|
+
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
|
|
130
|
+
[Mozilla CoC]: https://github.com/mozilla/diversity
|
|
131
|
+
[FAQ]: https://www.contributor-covenant.org/faq
|
|
132
|
+
[translations]: https://www.contributor-covenant.org/translations
|
data/mlx/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Contributing to MLX
|
|
2
|
+
|
|
3
|
+
We want to make contributing to this project as easy and transparent as
|
|
4
|
+
possible.
|
|
5
|
+
|
|
6
|
+
## Pull Requests
|
|
7
|
+
|
|
8
|
+
1. Fork and submit pull requests to the repo.
|
|
9
|
+
2. If you've added code that should be tested, add tests.
|
|
10
|
+
3. If a change is likely to impact efficiency, run some of the benchmarks before
|
|
11
|
+
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
|
|
12
|
+
4. If you've changed APIs, update the documentation.
|
|
13
|
+
5. Every PR should have passing tests and at least one review.
|
|
14
|
+
6. For code formatting, install `pre-commit` (e.g. with `pip install pre-commit`) and run `pre-commit install`.
|
|
15
|
+
This should install hooks for running `black` and `clang-format` to ensure
|
|
16
|
+
consistent style for C++ and Python code.
|
|
17
|
+
|
|
18
|
+
You can also run the formatters manually as follows:
|
|
19
|
+
|
|
20
|
+
```shell
|
|
21
|
+
clang-format -i file.cpp
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
```shell
|
|
25
|
+
black file.py
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
or run `pre-commit run --all-files` to check all files in the repo.
|
|
29
|
+
|
|
30
|
+
## Issues
|
|
31
|
+
|
|
32
|
+
We use GitHub issues to track public bugs. Please ensure your description is
|
|
33
|
+
clear and has sufficient instructions to be able to reproduce the issue.
|
|
34
|
+
|
|
35
|
+
## License
|
|
36
|
+
|
|
37
|
+
By contributing to MLX, you agree that your contributions will be licensed
|
|
38
|
+
under the LICENSE file in the root directory of this source tree.
|
data/mlx/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright © 2023 Apple Inc.
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|