mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/cublas_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)

# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
  target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
  target_sources(mlx
                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/no_qqmm_impl.cpp)
else()
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_impl.cpp
                ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp)
endif()

# Pick the batched GEMM implementation matching the toolkit version.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
else()
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
endif()

# Embed kernel sources in binary for JIT compilation.
file(
  GLOB MLX_JIT_SOURCES
  RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
add_custom_command(
  OUTPUT gen/cuda_jit_sources.h
  COMMAND
    ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
    -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
    "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
  DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
add_dependencies(mlx cuda_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")

# ------------------------ Compilation configs ------------------------

target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

# Enable defining device lambda functions.
target_compile_options(mlx
                       PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
  mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")

# Suppress nvcc warnings on C++ headers.
target_compile_options(
  mlx
  PRIVATE
    $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe="--diag_suppress=27,997,1394,20011,20208">
)

# Ignore some valid nvcc warnings, we might want to fix them in future.
target_compile_options(
  mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe="--diag_suppress=177,550">)

# Use stronger binaries compression. This feature was introduced in CUDA 12.8
# and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
  target_compile_options(
    mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()

# Use native CUDA arch by default.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
  execute_process(
    COMMAND __nvcc_device_query
    OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
    OUTPUT_STRIP_TRAILING_WHITESPACE)
  set(UPGRADABLE_ARCHITECTURES "90;100;121")
  if(MLX_CUDA_ARCHITECTURES STREQUAL "")
    message(
      FATAL_ERROR
        "Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
  elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
    # Use arch-specific compute capability whenever possible.
    set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
  endif()
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
                                     "${MLX_CUDA_ARCHITECTURES}")

# Search CUDA libs from installed python packages.
if(WIN32)
  # Resolve paths of unfound DLL at runtime.
  if(BUILD_SHARED_LIBS)
    target_link_libraries(mlx PRIVATE "delayimp.lib")
    target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/delayload.cpp)
  else()
    # For static library the delayload must be compiled into final executables.
    target_link_libraries(mlx PUBLIC "delayimp.lib")
    target_sources(
      mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/delayload.cpp>)
  endif()
  # Get all the CUDA DLLs we could link with.
  file(
    GLOB CUDA_DLL_NAMES
    RELATIVE "${CUDAToolkit_BIN_DIR}/x64"
    "${CUDAToolkit_BIN_DIR}/x64/*.dll")
  # Delay load CUDA and cuDNN libs.
  foreach(CUDA_DLL ${CUDA_DLL_NAMES} ${CUDNN_DLL_NAMES})
    target_link_options(mlx PUBLIC "/DELAYLOAD:${CUDA_DLL}")
  endforeach()
  # Pass the locations where CUDA DLLs are placed.
  if(NOT MLX_LOAD_CUDA_LIBS_FROM_PYTHON)
    target_compile_definitions(
      mlx PUBLIC MLX_CUDA_BIN_DIR="${CUDAToolkit_BIN_DIR}/x64"
                 MLX_CUDNN_BIN_DIR="${CUDNN_BIN_DIR}")
  endif()
else()
  # For POSIX we rely on RPATH to search for CUDA libs.
  if(MLX_LOAD_CUDA_LIBS_FROM_PYTHON)
    set_property(
      TARGET mlx
      APPEND
      PROPERTY INSTALL_RPATH
               # The paths here should match the install_requires in setup.py.
               "$ORIGIN/../../nvidia/cublas/lib"
               "$ORIGIN/../../nvidia/cuda_nvrtc/lib"
               "$ORIGIN/../../nvidia/cudnn/lib"
               "$ORIGIN/../../nvidia/nccl/lib")
  endif()
endif()

# ------------------------ Dependencies ------------------------

# Use fixed version of CCCL.
FetchContent_Declare(
  cccl
  URL "https://github.com/NVIDIA/cccl/releases/download/v3.1.3/cccl-v3.1.3.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")

# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
install(DIRECTORY ${cccl_SOURCE_DIR}/include/nv
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

# The binary of C++ tests will not be installed so it can not find the CCCL
# headers, and we have to hard-code the path.
if(MLX_BUILD_TESTS)
  target_compile_definitions(mlx
                             PRIVATE MLX_CCCL_DIR="${cccl_SOURCE_DIR}/include")
endif()

# Use fixed version of NVTX.
FetchContent_Declare(
  nvtx3
  GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
  GIT_TAG v3.1.1
  GIT_SHALLOW TRUE
  SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)

# Make cuda runtime APIs available in non-cuda files.
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})

# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)

# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)

# Use the frontend APIs of cuDNN.
FetchContent_Declare(
  cudnn
  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
  GIT_TAG v1.16.0
  GIT_SHALLOW TRUE
  EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
FetchContent_MakeAvailable(cudnn)
target_link_libraries(mlx PRIVATE cudnn_frontend)
# Link with the actual cuDNN libraries.
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)

# Use header-only CUTLASS.
FetchContent_Declare(
  cutlass
  GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
  GIT_TAG v4.3.5
  GIT_SHALLOW TRUE
  SOURCE_SUBDIR include EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(cutlass)
target_include_directories(
  mlx SYSTEM PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include "mlx/backend/cuda/allocator.h"
|
|
4
|
+
#include "mlx/backend/cuda/device.h"
|
|
5
|
+
#include "mlx/backend/cuda/utils.h"
|
|
6
|
+
#include "mlx/backend/gpu/device_info.h"
|
|
7
|
+
#include "mlx/memory.h"
|
|
8
|
+
#include "mlx/scheduler.h"
|
|
9
|
+
#include "mlx/utils.h"
|
|
10
|
+
|
|
11
|
+
#include <cuda_runtime.h>
|
|
12
|
+
#include <fmt/format.h>
|
|
13
|
+
|
|
14
|
+
#include <cassert>
|
|
15
|
+
#include <fstream>
|
|
16
|
+
#include <string>
|
|
17
|
+
|
|
18
|
+
namespace mlx::core {
|
|
19
|
+
|
|
20
|
+
namespace cu {
|
|
21
|
+
|
|
22
|
+
// Allocation granularity: requests of at least this size are rounded up to a
// multiple of it (see CudaAllocator::malloc_async).
constexpr int page_size = 16384;

// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;

// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;

// Check if running on Windows or Windows Subsystem for Linux
bool is_windows() {
#if defined(_WIN32)
  return true;
#elif defined(__linux__)
  // WSL kernels contain "microsoft" or "WSL" in /proc/version
  static bool is_wsl = []() {
    std::ifstream version("/proc/version");
    if (version.is_open()) {
      std::string line;
      std::getline(version, line);
      return line.find("microsoft") != std::string::npos ||
          line.find("Microsoft") != std::string::npos ||
          line.find("WSL") != std::string::npos;
    }
    return false;
  }();
  return is_wsl;
#else
  return false;
#endif
}

// True when every visible CUDA device reports managed-memory support (and,
// on Windows/WSL, concurrent managed access). Computed once and cached.
bool supports_managed_memory() {
  static bool managed_memory = []() {
    int device_count = gpu::device_count();
    for (int i = 0; i < device_count; ++i) {
      auto& d = cu::device(i);
      if (!d.managed_memory()) {
        return false;
      }
      // Empirically on Windows (and WSL) if there is no concurrentManagedAccess
      // the managed memory also does not work.
      if (is_windows() && !d.concurrent_managed_access()) {
        return false;
      }
    }
    return true;
  }();
  return managed_memory;
}

// Allocate host-visible memory: managed memory when supported, otherwise
// pinned host memory via cudaMallocHost.
inline void* unified_malloc(size_t size) {
  void* data = nullptr;
  if (supports_managed_memory()) {
    CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
  } else {
    CHECK_CUDA_ERROR(cudaMallocHost(&data, size));
  }
  return data;
}

// Release memory obtained from unified_malloc with the matching free call.
inline void unified_free(void* data) {
  if (supports_managed_memory()) {
    CHECK_CUDA_ERROR(cudaFree(data));
  } else {
    CHECK_CUDA_ERROR(cudaFreeHost(data));
  }
}

// cudaMemAdvise takes a cudaMemLocation since CUDA 13, and a plain device id
// before that; this shim builds the right argument type for either runtime.
#if CUDART_VERSION >= 13000
inline cudaMemLocation cuda_mem_loc(int i) {
  cudaMemLocation loc;
  loc.type = cudaMemLocationTypeDevice;
  loc.id = i;
  return loc;
}
#else
inline int cuda_mem_loc(int i) {
  return i;
}
#endif // CUDART_VERSION >= 13000
|
|
103
|
+
|
|
104
|
+
// Fixed-size free-list pool backed by a single unified-memory slab; serves
// allocations of at most small_block_size bytes.
SmallSizePool::SmallSizePool() {
  auto num_blocks = small_pool_size / small_block_size;
  buffer_ = new Block[num_blocks];
  next_free_ = buffer_;

  data_ = unified_malloc(small_pool_size);
  if (supports_managed_memory()) {
    // Advise the driver that every capable device will access the slab, so
    // small buffers stay cheap to touch from any GPU.
    int device_count = gpu::device_count();
    for (int i = 0; i < device_count; ++i) {
      if (device(i).concurrent_managed_access()) {
        auto loc = cuda_mem_loc(i);
        CHECK_CUDA_ERROR(cudaMemAdvise(
            data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
      }
    }
  }

  // Thread the free list through every block.
  auto curr = next_free_;
  for (size_t i = 1; i < num_blocks; ++i) {
    curr->next = buffer_ + i;
    curr = curr->next;
  }
  curr->next = nullptr;
}

SmallSizePool::~SmallSizePool() {
  unified_free(data_);
  delete[] buffer_;
}

// Pop a block off the free list; returns nullptr when the pool is exhausted
// (caller falls back to a regular allocation).
CudaBuffer* SmallSizePool::malloc() {
  if (next_free_ == nullptr) {
    return nullptr;
  }
  Block* b = next_free_;
  uint64_t i = next_free_ - buffer_;
  next_free_ = next_free_->next;
  // The i-th block owns the i-th small_block_size-byte slice of the slab.
  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
  b->buf.size = small_block_size;
  b->buf.device = -1; // unified/host-visible, not bound to a device
  return &b->buf;
}

// Push a block back onto the free list.
// NOTE(review): the cast assumes CudaBuffer is the first member of Block —
// confirm against the Block definition in allocator.h.
void SmallSizePool::free(CudaBuffer* buf) {
  auto b = reinterpret_cast<Block*>(buf);
  b->next = next_free_;
  next_free_ = b;
}

// True when buf points into this pool's block array.
bool SmallSizePool::in_pool(CudaBuffer* buf) {
  constexpr int num_blocks = (small_pool_size / small_block_size);
  auto b = reinterpret_cast<Block*>(buf);
  int64_t block_num = b - buffer_;
  return block_num >= 0 && block_num < num_blocks;
}
|
|
159
|
+
|
|
160
|
+
CudaAllocator::CudaAllocator()
    : buffer_cache_(
          page_size,
          [](CudaBuffer* buf) { return buf->size; },
          [this](CudaBuffer* buf) { free_cuda_buffer(buf); }) {
  // Derive default limits from the total memory reported by the runtime:
  // up to 95% may be in use; the remaining 5% is the slack (free_limit_)
  // used as the pressure threshold for trimming the cache.
  size_t free;
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
  memory_limit_ = total_memory_ * 0.95;
  free_limit_ = total_memory_ - memory_limit_;
  max_pool_size_ = memory_limit_;

  // For every device that supports stream-ordered allocation, keep a
  // dedicated stream for async frees and a handle to its default mem pool.
  int device_count = gpu::device_count();
  free_streams_.resize(device_count);
  mem_pools_.resize(device_count);
  for (int i = 0; i < device_count; ++i) {
    auto& d = device(i);
    if (d.memory_pools()) {
      free_streams_[i] = CudaStream(d);
      CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pools_[i], i));
    }
  }
}
|
|
182
|
+
|
|
183
|
+
// Allocate `size` bytes for use on `device`, ordered on `stream` when the
// device supports stream-ordered allocation.
//
// - A device of -1 (or a null stream) places the buffer in unified /
//   host-visible memory.
// - Sizes are bucketed: tiny requests go to the small pool, sub-page
//   requests round to the next power of two, larger ones to page multiples.
// - Buffers are served from an internal cache when possible; the cache is
//   trimmed under memory pressure and kept below max_pool_size_.
//
// Throws std::runtime_error when the allocation cannot be satisfied.
Buffer
CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
  if (size == 0) {
    // Zero-sized buffers carry no backing memory at all.
    return Buffer{new CudaBuffer{nullptr, 0, -1}};
  }

  // Normalize the request into a size class.
  if (size <= small_block_size) {
    // Fix: use the named constant instead of a magic 8 so this stays in sync
    // with the small-pool block size.
    size = small_block_size;
  } else if (size < page_size) {
    size = next_power_of_2(size);
  } else {
    size = page_size * ((size + page_size - 1) / page_size);
  }

  // Tiny or stream-less allocations always live in unified memory.
  if (size <= small_block_size || stream == nullptr) {
    device = -1;
  }

  // Find available buffer from cache.
  std::unique_lock lock(mutex_);
  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
  if (!buf) {
    // If we have a lot of memory pressure try to reclaim memory from the cache.
    int64_t mem_to_free =
        get_active_memory() + get_cache_memory() + size - memory_limit_;
    if (mem_to_free > 0) {
      buffer_cache_.release_cached_buffers(mem_to_free);
    }

    // Try the scalar pool first
    if (size <= small_block_size) {
      buf = scalar_pool_.malloc();
    }
    // Drop the lock across the (potentially slow) driver allocation.
    lock.unlock();
    if (!buf) {
      void* data = nullptr;
      if (device == -1) {
        data = unified_malloc(size);
      } else {
        cu::device(device).make_current();
        if (mem_pools_[device]) { // supports memory pools
          CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
        } else {
          CHECK_CUDA_ERROR(cudaMalloc(&data, size));
        }
      }
      if (!data) {
        std::ostringstream msg;
        msg << "[malloc] Unable to allocate " << size << " bytes.";
        throw std::runtime_error(msg.str());
      }
      buf = new CudaBuffer{data, size, device};
    }
    lock.lock();

    // If any cuda memory pool has too much reserved memory, clear some
    // memory from the cache. This prevents graph / kernel execution failing
    // from OOM
    if (get_cache_memory() > 0) {
      for (auto p : mem_pools_) {
        if (p) {
          size_t used = 0;
          CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
              p, cudaMemPoolAttrReservedMemCurrent, &used));
          if (used > (total_memory_ - free_limit_)) {
            buffer_cache_.release_cached_buffers(free_limit_);
            break;
          }
        }
      }
    }
  }
  active_memory_ += buf->size;
  peak_memory_ = std::max(active_memory_, peak_memory_);

  // Maintain the cache below the requested limit.
  if (get_cache_memory() > max_pool_size_) {
    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
  }
  lock.unlock();
  // Copy to unified memory here if the buffer is not on the right device.
  if (buf->device >= 0 && buf->device != device) {
    move_to_unified_memory(*buf, stream);
  }
  return Buffer{buf};
}
|
|
269
|
+
|
|
270
|
+
// Synchronous allocation with no target device or stream; the buffer ends up
// in unified (host-visible) memory.
Buffer CudaAllocator::malloc(size_t size) {
  return malloc_async(size, -1, nullptr);
}
|
|
273
|
+
|
|
274
|
+
// Return a buffer to the allocator: recycle it into the cache while the
// cache is under its limit, otherwise release the memory immediately.
void CudaAllocator::free(Buffer buffer) {
  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
  if (!buf) {
    return;
  }
  // Zero-sized buffers own no memory (see malloc_async); just delete the
  // bookkeeping struct.
  if (buf->size == 0) {
    delete buf;
    return;
  }

  std::unique_lock lock(mutex_);
  active_memory_ -= buf->size;
  if (get_cache_memory() < max_pool_size_) {
    buffer_cache_.recycle_to_cache(buf);
  } else {
    free_cuda_buffer(buf);
  }
}
|
|
292
|
+
|
|
293
|
+
// Size in bytes of the allocation backing `buffer`; 0 for a null buffer.
size_t CudaAllocator::size(Buffer buffer) const {
  const auto* cuda_buf = static_cast<const CudaBuffer*>(buffer.ptr());
  return cuda_buf == nullptr ? 0 : cuda_buf->size;
}
|
|
300
|
+
|
|
301
|
+
// Migrate a device-resident buffer into unified (host-visible) memory so the
// host can read its contents; no-op when the buffer is already unified.
// The old device allocation is released (asynchronously when possible).
void CudaAllocator::move_to_unified_memory(
    CudaBuffer& buf,
    cudaStream_t stream) {
  if (buf.device == -1) {
    return;
  }
  void* data = unified_malloc(buf.size);
  // Without managed memory the destination is pinned host memory, so the
  // copy direction must be an explicit device-to-host.
  cudaMemcpyKind kind =
      supports_managed_memory() ? cudaMemcpyDefault : cudaMemcpyDeviceToHost;
  if (stream && mem_pools_[buf.device]) {
    // Stream-ordered path: the async free is safe because it is ordered
    // after the copy on the same stream.
    CHECK_CUDA_ERROR(cudaMemcpyAsync(data, buf.data, buf.size, kind, stream));
    free_async(buf, stream);
  } else {
    CHECK_CUDA_ERROR(cudaMemcpy(data, buf.data, buf.size, kind));
    free_async(buf);
  }
  buf.data = data;
  buf.device = -1;
}
|
|
320
|
+
|
|
321
|
+
// This must be called with mutex_ acquired.
// Returns a small-pool buffer to its pool; otherwise releases the memory and
// destroys the CudaBuffer bookkeeping struct.
void CudaAllocator::free_cuda_buffer(CudaBuffer* buf) {
  if (scalar_pool_.in_pool(buf)) {
    scalar_pool_.free(buf);
  } else {
    free_async(*buf);
    delete buf;
  }
}
|
|
330
|
+
|
|
331
|
+
// Release the memory behind `buf` (bookkeeping struct is NOT deleted here).
// Unified buffers are freed synchronously; device buffers are freed
// stream-ordered when the device supports memory pools, defaulting to the
// per-device free stream when no stream is supplied.
void CudaAllocator::free_async(CudaBuffer& buf, cudaStream_t stream) {
  if (buf.device == -1) {
    unified_free(buf.data);
  } else {
    // Free asynchronously when memory pools is supported.
    if (mem_pools_[buf.device]) {
      if (!stream) {
        stream = free_streams_[buf.device];
      }
      CHECK_CUDA_ERROR(cudaFreeAsync(buf.data, stream));
    } else {
      CHECK_CUDA_ERROR(cudaFree(buf.data));
    }
  }
}
|
|
346
|
+
|
|
347
|
+
// Bytes currently handed out to callers (excludes cached buffers).
size_t CudaAllocator::get_active_memory() const {
  return active_memory_;
}

// High-water mark of active memory since construction or the last reset.
size_t CudaAllocator::get_peak_memory() const {
  return peak_memory_;
}

void CudaAllocator::reset_peak_memory() {
  std::lock_guard lock(mutex_);
  peak_memory_ = 0;
}

size_t CudaAllocator::get_memory_limit() {
  return memory_limit_;
}

// Set a new memory limit; returns the previous one.
size_t CudaAllocator::set_memory_limit(size_t limit) {
  std::lock_guard lock(mutex_);
  std::swap(limit, memory_limit_);
  return limit;
}

// Bytes held in the buffer cache, available for reuse.
size_t CudaAllocator::get_cache_memory() const {
  return buffer_cache_.cache_size();
}

// Set a new cache limit; returns the previous one.
size_t CudaAllocator::set_cache_limit(size_t limit) {
  std::lock_guard lk(mutex_);
  std::swap(limit, max_pool_size_);
  return limit;
}

// Release every cached buffer back to the driver.
void CudaAllocator::clear_cache() {
  std::lock_guard lk(mutex_);
  buffer_cache_.clear();
}
|
|
384
|
+
|
|
385
|
+
// Process-wide CUDA allocator singleton.
CudaAllocator& allocator() {
  static auto* allocator_ = []() {
    // Ensure scheduler is created before allocator.
    scheduler::scheduler();
    // By creating the |allocator_| on heap, the destructor of CudaAllocator
    // will not be called on exit and buffers in the cache will be leaked. This
    // can save some time at program exit.
    return new CudaAllocator();
  }();
  return *allocator_;
}

// Stream-ordered allocation on the encoder's device and stream.
Buffer malloc_async(size_t size, CommandEncoder& encoder) {
  return allocator().malloc_async(
      size, encoder.device().cuda_device(), encoder.stream());
}
|
|
401
|
+
|
|
402
|
+
} // namespace cu
|
|
403
|
+
|
|
404
|
+
namespace allocator {
|
|
405
|
+
|
|
406
|
+
// The generic allocator interface is backed by the CUDA allocator.
Allocator& allocator() {
  return cu::allocator();
}

// Returns a host-accessible pointer to the buffer's contents, migrating a
// device-resident buffer into unified memory first.
void* Buffer::raw_ptr() {
  if (!ptr_) {
    return nullptr;
  }
  auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
  cu::allocator().move_to_unified_memory(cbuf);
  return cbuf.data;
}
|
|
418
|
+
|
|
419
|
+
} // namespace allocator
|
|
420
|
+
|
|
421
|
+
// Public memory API: thin forwards to the CUDA allocator singleton.
size_t get_active_memory() {
  return cu::allocator().get_active_memory();
}
size_t get_peak_memory() {
  return cu::allocator().get_peak_memory();
}
void reset_peak_memory() {
  return cu::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
  return cu::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
  return cu::allocator().get_memory_limit();
}
size_t get_cache_memory() {
  return cu::allocator().get_cache_memory();
}
size_t set_cache_limit(size_t limit) {
  return cu::allocator().set_cache_limit(limit);
}
void clear_cache() {
  cu::allocator().clear_cache();
}

// Not supported in CUDA.
size_t set_wired_limit(size_t) {
  return 0;
}
|
|
450
|
+
|
|
451
|
+
} // namespace mlx::core
|