mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
data/mlx/CMakeLists.txt
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.25)
|
|
2
|
+
|
|
3
|
+
# Resolve the version used for the project.
#
# * MLX_VERSION not provided: read the MLX_VERSION_MAJOR/MINOR/PATCH macros
#   from mlx/version.h and set both MLX_PROJECT_VERSION and MLX_VERSION to
#   "<major>.<minor>.<patch>".
# * MLX_VERSION provided: keep it as is and store only its leading
#   "<digits>.<digits>.<digits>" part in MLX_PROJECT_VERSION; anything that
#   follows the third numeric component is dropped.
if(NOT MLX_VERSION)
  file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
  string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_major ${CMAKE_MATCH_1})
  string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_minor ${CMAKE_MATCH_1})
  string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
  set(_patch ${CMAKE_MATCH_1})
  set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
  set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
  # The dots are written as "\\." so that the regex engine receives "\." (a
  # literal dot). A single backslash is consumed by CMake's quoted-argument
  # escaping and would leave a bare "." that matches any character.
  string(REGEX REPLACE "^([0-9]+\\.[0-9]+\\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
                       "${MLX_VERSION}")
endif()
|
|
17
|
+
|
|
18
|
+
project(
|
|
19
|
+
mlx
|
|
20
|
+
LANGUAGES C CXX
|
|
21
|
+
VERSION ${MLX_PROJECT_VERSION})
|
|
22
|
+
|
|
23
|
+
# ----------------------------- Setup -----------------------------
|
|
24
|
+
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
|
25
|
+
set(CMAKE_CXX_STANDARD 20)
|
|
26
|
+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|
27
|
+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
|
28
|
+
set(CMAKE_INSTALL_MESSAGE NEVER)
|
|
29
|
+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|
30
|
+
|
|
31
|
+
# ----------------------------- Configuration -----------------------------
|
|
32
|
+
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
|
|
33
|
+
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
|
34
|
+
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
|
35
|
+
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
|
36
|
+
option(MLX_BUILD_METAL "Build metal backend" ON)
|
|
37
|
+
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
|
38
|
+
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
|
39
|
+
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
|
40
|
+
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
|
41
|
+
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
|
42
|
+
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
|
43
|
+
option(MLX_BUILD_PYTHON_STUBS "Build stub files for python bindings" ON)
|
|
44
|
+
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
|
45
|
+
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
|
|
46
|
+
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
|
47
|
+
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
|
48
|
+
option(USE_ASAN "Enable AddressSanitizer (ASan)" OFF)
|
|
49
|
+
option(USE_UBSAN "Enable UndefinedBehaviorSanitizer (UBSan)" OFF)
|
|
50
|
+
option(USE_TSAN "Enable ThreadSanitizer (TSan)" OFF)
|
|
51
|
+
|
|
52
|
+
# --------------------- Processor tests -------------------------
|
|
53
|
+
message(
|
|
54
|
+
STATUS
|
|
55
|
+
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
|
59
|
+
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
|
60
|
+
if(NOT MLX_ENABLE_X64_MAC)
|
|
61
|
+
message(
|
|
62
|
+
FATAL_ERROR
|
|
63
|
+
"Building for x86_64 on macOS is not supported."
|
|
64
|
+
" If you are on an Apple silicon system, check the build"
|
|
65
|
+
" documentation for possible fixes: "
|
|
66
|
+
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
|
|
67
|
+
)
|
|
68
|
+
else()
|
|
69
|
+
set(MLX_BUILD_METAL OFF)
|
|
70
|
+
message(WARNING "Building for x86_64 arch is not officially supported.")
|
|
71
|
+
endif()
|
|
72
|
+
endif()
|
|
73
|
+
else()
|
|
74
|
+
set(MLX_BUILD_METAL OFF)
|
|
75
|
+
endif()
|
|
76
|
+
|
|
77
|
+
if(MLX_USE_CCACHE)
|
|
78
|
+
find_program(CCACHE_PROGRAM ccache)
|
|
79
|
+
if(CCACHE_PROGRAM)
|
|
80
|
+
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
|
|
81
|
+
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
|
82
|
+
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
|
83
|
+
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
|
84
|
+
endif()
|
|
85
|
+
endif()
|
|
86
|
+
|
|
87
|
+
if(USE_ASAN AND USE_TSAN)
|
|
88
|
+
message(
|
|
89
|
+
FATAL_ERROR
|
|
90
|
+
"AddressSanitizer (ASan) and ThreadSanitizer (TSan) are mutually exclusive and cannot be enabled at the same time."
|
|
91
|
+
)
|
|
92
|
+
endif()
|
|
93
|
+
|
|
94
|
+
set(SANITIZER_COMPILE_FLAGS "")
|
|
95
|
+
set(SANITIZER_LINK_FLAGS "")
|
|
96
|
+
|
|
97
|
+
if(USE_ASAN)
|
|
98
|
+
if(WIN32 AND MSVC)
|
|
99
|
+
list(APPEND SANITIZER_COMPILE_FLAGS /fsanitize=address)
|
|
100
|
+
list(APPEND SANITIZER_LINK_FLAGS /fsanitize=address)
|
|
101
|
+
else()
|
|
102
|
+
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=address)
|
|
103
|
+
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=address)
|
|
104
|
+
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
|
105
|
+
list(APPEND SANITIZER_LINK_FLAGS -lpthread)
|
|
106
|
+
endif()
|
|
107
|
+
endif()
|
|
108
|
+
endif()
|
|
109
|
+
|
|
110
|
+
if(USE_UBSAN)
|
|
111
|
+
if(WIN32 AND MSVC)
|
|
112
|
+
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
|
113
|
+
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)
|
|
114
|
+
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)
|
|
115
|
+
else()
|
|
116
|
+
message(
|
|
117
|
+
WARNING
|
|
118
|
+
"UndefinedBehaviorSanitizer (UBSan) is not directly supported via a simple flag in MSVC."
|
|
119
|
+
)
|
|
120
|
+
endif()
|
|
121
|
+
else()
|
|
122
|
+
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)
|
|
123
|
+
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)
|
|
124
|
+
endif()
|
|
125
|
+
endif()
|
|
126
|
+
|
|
127
|
+
if(USE_TSAN)
|
|
128
|
+
if(WIN32 AND MSVC)
|
|
129
|
+
message(
|
|
130
|
+
FATAL_ERROR
|
|
131
|
+
"ThreadSanitizer (TSan) is not supported by the MSVC compiler. Please use Clang or GCC."
|
|
132
|
+
)
|
|
133
|
+
elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
|
134
|
+
message(FATAL_ERROR "ThreadSanitizer (TSan) is not supported on macOS.")
|
|
135
|
+
else()
|
|
136
|
+
list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=thread)
|
|
137
|
+
list(APPEND SANITIZER_LINK_FLAGS -fsanitize=thread)
|
|
138
|
+
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
|
139
|
+
list(APPEND SANITIZER_LINK_FLAGS -lpthread)
|
|
140
|
+
endif()
|
|
141
|
+
endif()
|
|
142
|
+
endif()
|
|
143
|
+
|
|
144
|
+
# ----------------------------- Lib -----------------------------
|
|
145
|
+
|
|
146
|
+
include(FetchContent)
|
|
147
|
+
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
|
|
148
|
+
cmake_policy(SET CMP0135 NEW)
|
|
149
|
+
|
|
150
|
+
add_library(mlx)
|
|
151
|
+
|
|
152
|
+
target_compile_options(mlx PUBLIC ${SANITIZER_COMPILE_FLAGS})
|
|
153
|
+
target_link_options(mlx PUBLIC ${SANITIZER_LINK_FLAGS})
|
|
154
|
+
|
|
155
|
+
if(MLX_BUILD_CUDA)
|
|
156
|
+
enable_language(CUDA)
|
|
157
|
+
find_package(CUDAToolkit REQUIRED)
|
|
158
|
+
find_package(CUDNN REQUIRED)
|
|
159
|
+
endif()
|
|
160
|
+
|
|
161
|
+
if(MLX_BUILD_METAL)
|
|
162
|
+
find_library(METAL_LIB Metal)
|
|
163
|
+
find_library(FOUNDATION_LIB Foundation)
|
|
164
|
+
find_library(QUARTZ_LIB QuartzCore)
|
|
165
|
+
if(METAL_LIB)
|
|
166
|
+
message(STATUS "Metal found ${METAL_LIB}")
|
|
167
|
+
else()
|
|
168
|
+
message(
|
|
169
|
+
FATAL_ERROR
|
|
170
|
+
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
|
|
171
|
+
endif()
|
|
172
|
+
|
|
173
|
+
if(MLX_METAL_DEBUG)
|
|
174
|
+
add_compile_definitions(MLX_METAL_DEBUG)
|
|
175
|
+
endif()
|
|
176
|
+
|
|
177
|
+
# Throw an error if xcrun not found
|
|
178
|
+
execute_process(
|
|
179
|
+
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
|
180
|
+
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
|
181
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
|
182
|
+
|
|
183
|
+
# Refuse to build the Metal backend against a macOS SDK older than 14.0.
# VERSION_LESS compares the value component by component as a version string,
# and quoting it keeps the if() well-formed even when the variable is empty.
if("${MACOS_SDK_VERSION}" VERSION_LESS 14.0)
  message(
    FATAL_ERROR
      "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif()
|
|
188
|
+
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
|
189
|
+
|
|
190
|
+
set(METAL_CPP_URL
|
|
191
|
+
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
|
|
192
|
+
|
|
193
|
+
# When a deployment target is set, require it to be at least 14.0 and put the
# matching -mmacosx-version-min flag in XCRUN_FLAGS. The variable is expanded
# inside quotes so that an undefined or empty deployment target simply skips
# this section instead of producing a malformed comparison.
if(NOT "${CMAKE_OSX_DEPLOYMENT_TARGET}" STREQUAL "")
  if("${CMAKE_OSX_DEPLOYMENT_TARGET}" VERSION_LESS 14.0)
    message(FATAL_ERROR "MLX requires macOS >= 14.0")
  endif()
  set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
|
|
199
|
+
execute_process(
|
|
200
|
+
COMMAND
|
|
201
|
+
zsh "-c"
|
|
202
|
+
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
|
203
|
+
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
|
204
|
+
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
|
205
|
+
FetchContent_MakeAvailable(metal_cpp)
|
|
206
|
+
target_include_directories(
|
|
207
|
+
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
|
208
|
+
$<INSTALL_INTERFACE:include/metal_cpp>)
|
|
209
|
+
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
|
210
|
+
endif()
|
|
211
|
+
|
|
212
|
+
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
|
213
|
+
# With newer clang/gcc versions following libs are implicitly linked, but when
|
|
214
|
+
# building on old distributions they need to be explicitly listed.
|
|
215
|
+
target_link_libraries(mlx PRIVATE dl pthread)
|
|
216
|
+
endif()
|
|
217
|
+
|
|
218
|
+
if(WIN32)
|
|
219
|
+
if(MSVC)
|
|
220
|
+
# GGUF does not build with MSVC.
|
|
221
|
+
set(MLX_BUILD_GGUF OFF)
|
|
222
|
+
endif()
|
|
223
|
+
# Generate DLL and EXE in the same dir, otherwise EXE will not be able to run.
|
|
224
|
+
# This is only done when MLX is built as the top project.
|
|
225
|
+
if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
|
|
226
|
+
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
|
227
|
+
endif()
|
|
228
|
+
# Windows implementation of dlfcn.h APIs.
|
|
229
|
+
FetchContent_Declare(
|
|
230
|
+
dlfcn-win32
|
|
231
|
+
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
|
|
232
|
+
GIT_TAG v1.4.2
|
|
233
|
+
EXCLUDE_FROM_ALL)
|
|
234
|
+
block()
|
|
235
|
+
set(BUILD_SHARED_LIBS OFF)
|
|
236
|
+
FetchContent_MakeAvailable(dlfcn-win32)
|
|
237
|
+
endblock()
|
|
238
|
+
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
|
|
239
|
+
target_link_libraries(mlx PRIVATE dl)
|
|
240
|
+
endif()
|
|
241
|
+
|
|
242
|
+
if(MLX_BUILD_CPU)
|
|
243
|
+
find_library(ACCELERATE_LIBRARY Accelerate)
|
|
244
|
+
if(ACCELERATE_LIBRARY)
|
|
245
|
+
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
|
246
|
+
set(MLX_BUILD_ACCELERATE ON)
|
|
247
|
+
else()
|
|
248
|
+
message(STATUS "Accelerate not found, using default backend.")
|
|
249
|
+
set(MLX_BUILD_ACCELERATE OFF)
|
|
250
|
+
endif()
|
|
251
|
+
|
|
252
|
+
if(MLX_BUILD_ACCELERATE)
|
|
253
|
+
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
|
254
|
+
add_compile_definitions(MLX_USE_ACCELERATE)
|
|
255
|
+
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
|
256
|
+
elseif(WIN32)
|
|
257
|
+
# Download and link prebuilt binaries of OpenBLAS. Note that we can only
|
|
258
|
+
# link with the dynamic library, the prebuilt binaries were built with MinGW
|
|
259
|
+
# so static-linking would require linking with MinGW's runtime.
|
|
260
|
+
FetchContent_Declare(
|
|
261
|
+
openblas
|
|
262
|
+
URL "https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.31/OpenBLAS-0.3.31-x64.zip"
|
|
263
|
+
)
|
|
264
|
+
FetchContent_MakeAvailable(openblas)
|
|
265
|
+
target_link_libraries(mlx
|
|
266
|
+
PRIVATE "${openblas_SOURCE_DIR}/lib/libopenblas.lib")
|
|
267
|
+
target_include_directories(mlx PRIVATE "${openblas_SOURCE_DIR}/include")
|
|
268
|
+
# Make sure the DLL file is placed in the same dir with executables.
|
|
269
|
+
set(OPENBLAS_DLL_FILE "${openblas_SOURCE_DIR}/bin/libopenblas.dll")
|
|
270
|
+
add_custom_command(
|
|
271
|
+
TARGET mlx
|
|
272
|
+
POST_BUILD
|
|
273
|
+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${OPENBLAS_DLL_FILE}
|
|
274
|
+
${CMAKE_BINARY_DIR})
|
|
275
|
+
else()
|
|
276
|
+
if(${CMAKE_HOST_APPLE})
|
|
277
|
+
# The blas shipped in macOS SDK is not supported, search homebrew for
|
|
278
|
+
# openblas instead.
|
|
279
|
+
set(BLA_VENDOR OpenBLAS)
|
|
280
|
+
set(LAPACK_ROOT
|
|
281
|
+
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
|
282
|
+
endif()
|
|
283
|
+
# Search and link with lapack.
|
|
284
|
+
find_package(LAPACK REQUIRED)
|
|
285
|
+
if(NOT LAPACK_FOUND)
|
|
286
|
+
message(FATAL_ERROR "Must have LAPACK installed")
|
|
287
|
+
endif()
|
|
288
|
+
find_path(
|
|
289
|
+
LAPACK_INCLUDE_DIRS lapacke.h
|
|
290
|
+
/usr/include
|
|
291
|
+
/usr/include/lapacke
|
|
292
|
+
/usr/include/x86_64-linux-gnu
|
|
293
|
+
/usr/include/x86_64-linux-gnu/lapacke
|
|
294
|
+
/usr/local/include
|
|
295
|
+
/usr/local/include/lapacke
|
|
296
|
+
/usr/local/opt/openblas/include)
|
|
297
|
+
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
|
298
|
+
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
|
299
|
+
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
|
300
|
+
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
|
|
301
|
+
# List blas after lapack otherwise we may accidentally include an old
|
|
302
|
+
# version of lapack.h from the include dirs of blas.
|
|
303
|
+
find_package(BLAS REQUIRED)
|
|
304
|
+
if(NOT BLAS_FOUND)
|
|
305
|
+
message(FATAL_ERROR "Must have BLAS installed")
|
|
306
|
+
endif()
|
|
307
|
+
# TODO find a cleaner way to do this
|
|
308
|
+
find_path(
|
|
309
|
+
BLAS_INCLUDE_DIRS cblas.h
|
|
310
|
+
/usr/include
|
|
311
|
+
/usr/include/x86_64-linux-gnu
|
|
312
|
+
/usr/local/include
|
|
313
|
+
/usr/local/include/openblas
|
|
314
|
+
$ENV{BLAS_HOME}/include)
|
|
315
|
+
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
|
316
|
+
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
|
317
|
+
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
|
318
|
+
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
|
|
319
|
+
endif()
|
|
320
|
+
else()
|
|
321
|
+
set(MLX_BUILD_ACCELERATE OFF)
|
|
322
|
+
endif()
|
|
323
|
+
|
|
324
|
+
message(STATUS "Downloading json")
|
|
325
|
+
FetchContent_Declare(
|
|
326
|
+
json
|
|
327
|
+
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
|
328
|
+
FetchContent_MakeAvailable(json)
|
|
329
|
+
target_include_directories(
|
|
330
|
+
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
|
|
331
|
+
|
|
332
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
|
333
|
+
|
|
334
|
+
target_include_directories(
|
|
335
|
+
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
|
336
|
+
$<INSTALL_INTERFACE:include>)
|
|
337
|
+
|
|
338
|
+
if(USE_SYSTEM_FMT)
|
|
339
|
+
find_package(fmt REQUIRED)
|
|
340
|
+
else()
|
|
341
|
+
FetchContent_Declare(
|
|
342
|
+
fmt
|
|
343
|
+
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
|
344
|
+
GIT_TAG 12.1.0
|
|
345
|
+
EXCLUDE_FROM_ALL)
|
|
346
|
+
FetchContent_MakeAvailable(fmt)
|
|
347
|
+
endif()
|
|
348
|
+
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
|
349
|
+
|
|
350
|
+
if(MLX_BUILD_PYTHON_BINDINGS)
|
|
351
|
+
message(STATUS "Building Python bindings.")
|
|
352
|
+
find_package(
|
|
353
|
+
Python 3.10
|
|
354
|
+
COMPONENTS Interpreter Development.Module
|
|
355
|
+
REQUIRED)
|
|
356
|
+
FetchContent_Declare(
|
|
357
|
+
nanobind
|
|
358
|
+
GIT_REPOSITORY https://github.com/wjakob/nanobind.git
|
|
359
|
+
GIT_TAG v2.10.2
|
|
360
|
+
GIT_SHALLOW TRUE
|
|
361
|
+
EXCLUDE_FROM_ALL)
|
|
362
|
+
FetchContent_MakeAvailable(nanobind)
|
|
363
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
|
364
|
+
endif()
|
|
365
|
+
|
|
366
|
+
if(MLX_BUILD_TESTS)
|
|
367
|
+
include(CTest)
|
|
368
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
|
|
369
|
+
endif()
|
|
370
|
+
|
|
371
|
+
if(MLX_BUILD_EXAMPLES)
|
|
372
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
|
|
373
|
+
endif()
|
|
374
|
+
|
|
375
|
+
if(MLX_BUILD_BENCHMARKS)
|
|
376
|
+
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
|
377
|
+
endif()
|
|
378
|
+
|
|
379
|
+
# ----------------------------- Installation -----------------------------
|
|
380
|
+
include(GNUInstallDirs)
|
|
381
|
+
|
|
382
|
+
if(WIN32)
|
|
383
|
+
# Install DLLs to the same dir with extension file (core.pyd) on Windows.
|
|
384
|
+
set(CMAKE_INSTALL_BINDIR ".")
|
|
385
|
+
if(MLX_BUILD_CPU)
|
|
386
|
+
# Install OpenBLAS.
|
|
387
|
+
install(FILES ${OPENBLAS_DLL_FILE} TYPE BIN)
|
|
388
|
+
endif()
|
|
389
|
+
endif()
|
|
390
|
+
|
|
391
|
+
# Install library
|
|
392
|
+
install(
|
|
393
|
+
TARGETS mlx
|
|
394
|
+
EXPORT MLXTargets
|
|
395
|
+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
|
396
|
+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
|
397
|
+
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
|
398
|
+
INCLUDES
|
|
399
|
+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
|
400
|
+
|
|
401
|
+
# Install headers
|
|
402
|
+
install(
|
|
403
|
+
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
|
404
|
+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
|
405
|
+
COMPONENT headers
|
|
406
|
+
FILES_MATCHING
|
|
407
|
+
PATTERN "*.h"
|
|
408
|
+
PATTERN "backend/metal/kernels.h" EXCLUDE)
|
|
409
|
+
|
|
410
|
+
# Install metal dependencies
|
|
411
|
+
if(MLX_BUILD_METAL)
|
|
412
|
+
|
|
413
|
+
# Install metal cpp
|
|
414
|
+
install(
|
|
415
|
+
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
|
416
|
+
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
|
417
|
+
COMPONENT metal_cpp_source)
|
|
418
|
+
|
|
419
|
+
endif()
|
|
420
|
+
|
|
421
|
+
# Install cmake config
|
|
422
|
+
set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
|
|
423
|
+
set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
|
|
424
|
+
set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
|
|
425
|
+
|
|
426
|
+
install(
|
|
427
|
+
EXPORT MLXTargets
|
|
428
|
+
FILE MLXTargets.cmake
|
|
429
|
+
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
|
430
|
+
|
|
431
|
+
include(CMakePackageConfigHelpers)
|
|
432
|
+
|
|
433
|
+
write_basic_package_version_file(
|
|
434
|
+
${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
|
435
|
+
COMPATIBILITY SameMajorVersion
|
|
436
|
+
VERSION ${MLX_VERSION})
|
|
437
|
+
|
|
438
|
+
configure_package_config_file(
|
|
439
|
+
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
|
|
440
|
+
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
|
441
|
+
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
|
442
|
+
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
|
|
443
|
+
MLX_CMAKE_INSTALL_MODULE_DIR)
|
|
444
|
+
|
|
445
|
+
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
|
446
|
+
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
|
447
|
+
|
|
448
|
+
install(DIRECTORY ${CMAKE_MODULE_PATH}/
|
|
449
|
+
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
4
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
5
|
+
# in the Software without restriction, including without limitation the rights
|
|
6
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
7
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
8
|
+
# furnished to do so, subject to the following conditions:
|
|
9
|
+
#
|
|
10
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
11
|
+
# copies or substantial portions of the Software.
|
|
12
|
+
#
|
|
13
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
15
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
16
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
17
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
18
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
19
|
+
# SOFTWARE.
|
|
20
|
+
|
|
21
|
+
# Modified from
|
|
22
|
+
# https://github.com/NVIDIA/cudnn-frontend/blob/main/cmake/cuDNN.cmake
|
|
23
|
+
|
|
24
|
+
# Return the last file matching the pattern.
#
# Arguments:
#   VAR     - name of the variable to set in the caller's scope.
#   PATTERN - globbing expression passed to file(GLOB).
#
# If at least one path matches, VAR is set to the final entry of the glob
# result. If nothing matches, VAR is left untouched.
function(find_file_glob VAR PATTERN)
  file(GLOB _RESULT "${PATTERN}")
  if(_RESULT)
    # list() sub-commands operate on a list variable *name*, so _RESULT is
    # passed unexpanded. Index -1 addresses the last element.
    list(GET _RESULT -1 _LAST_MATCH)
    set(${VAR}
        "${_LAST_MATCH}"
        PARENT_SCOPE)
  endif()
endfunction()
|
|
37
|
+
|
|
38
|
+
# Find the dir including the "cudnn.h" file.
|
|
39
|
+
find_path(
|
|
40
|
+
CUDNN_INCLUDE_DIR cudnn.h
|
|
41
|
+
HINTS ${CUDNN_INCLUDE_PATH} ${CUDAToolkit_INCLUDE_DIRS}
|
|
42
|
+
PATH_SUFFIXES include OPTIONAL)
|
|
43
|
+
|
|
44
|
+
# Glob searching "cudnn.h" for Windows.
|
|
45
|
+
if(WIN32 AND NOT CUDNN_INCLUDE_DIR)
|
|
46
|
+
find_file_glob(
|
|
47
|
+
CUDNN_H_PATH
|
|
48
|
+
"C:/Program Files/NVIDIA/CUDNN/*/include/${CUDAToolkit_VERSION_MAJOR}.*/cudnn.h"
|
|
49
|
+
)
|
|
50
|
+
if(CUDNN_H_PATH)
|
|
51
|
+
get_filename_component(CUDNN_INCLUDE_DIR "${CUDNN_H_PATH}" DIRECTORY)
|
|
52
|
+
endif()
|
|
53
|
+
endif()
|
|
54
|
+
|
|
55
|
+
if(NOT CUDNN_INCLUDE_DIR)
|
|
56
|
+
message(
|
|
57
|
+
FATAL_ERROR
|
|
58
|
+
"Unable to find cudnn.h, please make sure cuDNN is installed and pass CUDNN_INCLUDE_PATH to cmake."
|
|
59
|
+
)
|
|
60
|
+
endif()
|
|
61
|
+
|
|
62
|
+
# Get cudnn version.
#
# CUDNN_MAJOR_VERSION is extracted from the "#define CUDNN_MAJOR <n>" line of
# cudnn_version.h. The digit class is [0-9] so that a major version containing
# a zero (for example 10) is captured in full rather than cut at the zero.
file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
string(REGEX MATCH "#define CUDNN_MAJOR [0-9]+" macrodef
             "${cudnn_version_header}")
string(REGEX MATCH "[0-9]+" CUDNN_MAJOR_VERSION "${macrodef}")
|
|
67
|
+
|
|
68
|
+
# Function for searching library files.
#
# Usage: find_cudnn_library(<name> [OPTIONAL])
#
# Searches for <name> with find_library(), using CUDNN_LIBRARY_PATH and
# CUDAToolkit_LIBRARY_DIR as hints. On Windows, when that fails, it falls back
# to a glob below "C:/Program Files/NVIDIA/CUDNN".
#
# Found:     an imported target CUDNN::<name> is made available and
#            <name>_LIBRARY is set in the caller's scope.
# Not found: a FATAL_ERROR is raised, unless the second argument is OPTIONAL,
#            in which case only a status message is printed.
function(find_cudnn_library NAME)
  if("${ARGV1}" STREQUAL "OPTIONAL")
    set(_CUDNN_REQUIRED FALSE)
  else()
    set(_CUDNN_REQUIRED TRUE)
  endif()

  # NOTE(review): the trailing OPTIONAL sits in the PATH_SUFFIXES list; on
  # CMake releases whose find_library() has no OPTIONAL keyword it is
  # presumably parsed as one more suffix - confirm the intent.
  find_library(
    ${NAME}_LIBRARY
    NAMES ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}" NAMES_PER_DIR
    HINTS ${CUDNN_LIBRARY_PATH} ${CUDAToolkit_LIBRARY_DIR}
    PATH_SUFFIXES lib64 lib/x64 lib OPTIONAL)

  if(WIN32 AND NOT ${NAME}_LIBRARY)
    find_file_glob(
      ${NAME}_LIBRARY
      "C:/Program Files/NVIDIA/CUDNN/*/lib/${CUDAToolkit_VERSION_MAJOR}.*/x64/${NAME}.lib"
    )
  endif()

  if(_CUDNN_REQUIRED AND NOT ${NAME}_LIBRARY)
    message(
      FATAL_ERROR
        "Unable to find ${NAME}, please make sure cuDNN is installed and pass CUDNN_LIBRARY_PATH to cmake."
    )
  endif()

  if(${NAME}_LIBRARY)
    # Creating a target that already exists is an error, so only add it when
    # this is the first time the library is registered in this scope.
    if(NOT TARGET CUDNN::${NAME})
      add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
    endif()
    # Values are quoted so that each one is passed as exactly one argument.
    set_target_properties(
      CUDNN::${NAME}
      PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${CUDNN_INCLUDE_DIR}"
                 IMPORTED_LOCATION "${${NAME}_LIBRARY}")
    set(${NAME}_LIBRARY
        "${${NAME}_LIBRARY}"
        PARENT_SCOPE)
  else()
    message(STATUS "${NAME} not found.")
  endif()
endfunction()
|
|
109
|
+
|
|
110
|
+
# Search for the main cudnn library.
|
|
111
|
+
find_cudnn_library(cudnn)
|
|
112
|
+
|
|
113
|
+
include(FindPackageHandleStandardArgs)
|
|
114
|
+
find_package_handle_standard_args(CUDNN REQUIRED_VARS CUDNN_INCLUDE_DIR
|
|
115
|
+
cudnn_LIBRARY)
|
|
116
|
+
|
|
117
|
+
if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)
|
|
118
|
+
set(CUDNN_FOUND
|
|
119
|
+
ON
|
|
120
|
+
CACHE INTERNAL "cuDNN Library Found")
|
|
121
|
+
else()
|
|
122
|
+
set(CUDNN_FOUND
|
|
123
|
+
OFF
|
|
124
|
+
CACHE INTERNAL "cuDNN Library Not Found")
|
|
125
|
+
endif()
|
|
126
|
+
|
|
127
|
+
# Find out all the DLL files for Windows.
|
|
128
|
+
if(WIN32 AND cudnn_LIBRARY)
|
|
129
|
+
get_filename_component(CUDNN_BIN_DIR "${cudnn_LIBRARY}" DIRECTORY)
|
|
130
|
+
string(REPLACE "/lib/" "/bin/" CUDNN_BIN_DIR "${CUDNN_BIN_DIR}")
|
|
131
|
+
file(
|
|
132
|
+
GLOB CUDNN_DLL_NAMES
|
|
133
|
+
RELATIVE "${CUDNN_BIN_DIR}"
|
|
134
|
+
"${CUDNN_BIN_DIR}/*.dll")
|
|
135
|
+
endif()
|
|
136
|
+
|
|
137
|
+
# Create an interface library that users can link with.
|
|
138
|
+
add_library(CUDNN::cudnn_all INTERFACE IMPORTED)
|
|
139
|
+
target_link_libraries(CUDNN::cudnn_all INTERFACE CUDNN::cudnn)
|
|
140
|
+
target_include_directories(
|
|
141
|
+
CUDNN::cudnn_all INTERFACE $<INSTALL_INTERFACE:include>
|
|
142
|
+
$<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>)
|
|
143
|
+
|
|
144
|
+
# Add other components of cudnn.
#
# cuDNN 8: every listed component is required and linked into
# CUDNN::cudnn_all.
# cuDNN 9: cudnn_graph and cudnn_engines_runtime_compiled are required; the
# remaining components are searched as OPTIONAL and are linked only when their
# imported target exists. A "CUDNN::<name>" entry that names no target would
# otherwise be left dangling on CUDNN::cudnn_all (CMake rejects "::" link
# items that are not targets, see policy CMP0028).
if(CUDNN_MAJOR_VERSION EQUAL 8)
  find_cudnn_library(cudnn_adv_infer)
  find_cudnn_library(cudnn_adv_train)
  find_cudnn_library(cudnn_cnn_infer)
  find_cudnn_library(cudnn_cnn_train)
  find_cudnn_library(cudnn_ops_infer)
  find_cudnn_library(cudnn_ops_train)

  target_link_libraries(
    CUDNN::cudnn_all
    INTERFACE CUDNN::cudnn_adv_train CUDNN::cudnn_ops_train
              CUDNN::cudnn_cnn_train CUDNN::cudnn_adv_infer
              CUDNN::cudnn_cnn_infer CUDNN::cudnn_ops_infer)

elseif(CUDNN_MAJOR_VERSION EQUAL 9)
  find_cudnn_library(cudnn_graph)
  find_cudnn_library(cudnn_engines_runtime_compiled)

  target_link_libraries(
    CUDNN::cudnn_all INTERFACE CUDNN::cudnn_graph
                               CUDNN::cudnn_engines_runtime_compiled)

  # Same search and link order as a single combined link call would use.
  foreach(_cudnn_optional IN ITEMS cudnn_ops cudnn_cnn cudnn_adv
                                   cudnn_engines_precompiled cudnn_heuristic)
    find_cudnn_library(${_cudnn_optional} OPTIONAL)
    if(TARGET CUDNN::${_cudnn_optional})
      target_link_libraries(CUDNN::cudnn_all
                            INTERFACE CUDNN::${_cudnn_optional})
    endif()
  endforeach()
endif()
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
|
|
2
|
+
# directories.
|
|
3
|
+
|
|
4
|
+
set(NCCL_ROOT_DIR
|
|
5
|
+
$ENV{NCCL_ROOT_DIR}
|
|
6
|
+
CACHE PATH "Folder contains NVIDIA NCCL")
|
|
7
|
+
|
|
8
|
+
find_path(
|
|
9
|
+
NCCL_INCLUDE_DIRS
|
|
10
|
+
NAMES nccl.h
|
|
11
|
+
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
|
12
|
+
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
|
13
|
+
|
|
14
|
+
if($ENV{USE_STATIC_NCCL})
|
|
15
|
+
message(
|
|
16
|
+
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
|
17
|
+
set(NCCL_LIBNAME "libnccl_static.a")
|
|
18
|
+
else()
|
|
19
|
+
set(NCCL_LIBNAME "nccl")
|
|
20
|
+
endif()
|
|
21
|
+
|
|
22
|
+
find_library(
|
|
23
|
+
NCCL_LIBRARIES
|
|
24
|
+
NAMES ${NCCL_LIBNAME}
|
|
25
|
+
HINTS ${NCCL_LIB_DIR}
|
|
26
|
+
${NCCL_ROOT_DIR}
|
|
27
|
+
${NCCL_ROOT_DIR}/lib
|
|
28
|
+
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
|
29
|
+
${NCCL_ROOT_DIR}/lib64
|
|
30
|
+
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
|
31
|
+
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
|
32
|
+
|
|
33
|
+
include(FindPackageHandleStandardArgs)
|
|
34
|
+
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
|
35
|
+
NCCL_LIBRARIES)
|
|
36
|
+
|
|
37
|
+
# Once NCCL has been found, read its major version from nccl.h and report the
# include directory and library that were located.
if(NCCL_FOUND)
  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
  message(
    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
  # Only the first "#define NCCL_MAJOR <n>" line is needed.
  file(
    STRINGS "${NCCL_HEADER_FILE}" NCCL_MAJOR_VERSION_DEFINED
    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
    LIMIT_COUNT 1)
  if(NCCL_MAJOR_VERSION_DEFINED)
    # Keep just the digits: the capture group drops the "#define NCCL_MAJOR"
    # prefix as well as anything that follows the number on that line.
    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+).*$"
                         "\\1" NCCL_MAJOR_VERSION
                         "${NCCL_MAJOR_VERSION_DEFINED}")
    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
  endif()
  message(
    STATUS
      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()
|