mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/allocator.h"
|
|
6
|
+
#include "mlx/backend/common/buffer_cache.h"
|
|
7
|
+
#include "mlx/backend/cuda/cuda_utils.h"
|
|
8
|
+
|
|
9
|
+
#include <cuda_runtime.h>
|
|
10
|
+
#include <mutex>
|
|
11
|
+
#include <set>
|
|
12
|
+
#include <utility>
|
|
13
|
+
|
|
14
|
+
namespace mlx::core::cu {
|
|
15
|
+
|
|
16
|
+
class CommandEncoder;
|
|
17
|
+
|
|
18
|
+
using allocator::Buffer;
|
|
19
|
+
|
|
20
|
+
// Stores cuda-managed unified memory.
|
|
21
|
+
// Plain record describing one allocation: a pointer, a size and a device id.
// NOTE(review): the existing summary comment calls this cuda-managed unified
// memory, yet `device` is documented as "-1 for managed", which suggests other
// values denote non-managed memory -- confirm and reword the summary if so.
// This type is a member of union SmallSizePool::Block, so keep it trivial:
// adding constructors or default member initializers here would make that
// union's implicit default constructor deleted.
struct CudaBuffer {
|
|
22
|
+
void* data; // presumably the start address of the allocation -- TODO confirm
|
|
23
|
+
size_t size; // presumably the allocation size in bytes -- TODO confirm
|
|
24
|
+
int device; // -1 for managed
|
|
25
|
+
};
|
|
26
|
+
|
|
27
|
+
// Pool that hands out CudaBuffer objects through malloc() and takes them back
// through free(). Copy construction and copy assignment are deleted.
// NOTE(review): the implementation is not visible in this header; the notes
// below marked "presumably" must be confirmed against allocator.cpp.
class SmallSizePool {
|
|
28
|
+
private:
|
|
29
|
+
// Overlays a `next` pointer with a CudaBuffer in the same storage.
// NOTE(review): presumably `next` links unused blocks while `buf` is the
// payload of a block that is handed out -- confirm.
union Block {
|
|
30
|
+
Block* next;
|
|
31
|
+
CudaBuffer buf;
|
|
32
|
+
};
|
|
33
|
+
|
|
34
|
+
Block* buffer_{nullptr}; // presumably the backing array of blocks -- confirm
|
|
35
|
+
void* data_{nullptr}; // presumably the memory the buffers point into -- confirm
|
|
36
|
+
Block* next_free_{nullptr}; // presumably the head of the free list -- confirm
|
|
37
|
+
|
|
38
|
+
public:
|
|
39
|
+
SmallSizePool();
|
|
40
|
+
~SmallSizePool();
|
|
41
|
+
|
|
42
|
+
// Non-copyable: both copy operations are deleted.
SmallSizePool(const SmallSizePool&) = delete;
|
|
43
|
+
SmallSizePool& operator=(const SmallSizePool&) = delete;
|
|
44
|
+
|
|
45
|
+
// Returns a pointer to a CudaBuffer.
// NOTE(review): behaviour when the pool is exhausted is not visible here.
CudaBuffer* malloc();
|
|
46
|
+
// Takes back |buf|.
// NOTE(review): presumably |buf| must come from this pool's malloc() -- confirm.
void free(CudaBuffer* buf);
|
|
47
|
+
// NOTE(review): presumably reports whether |buf| belongs to this pool -- confirm.
bool in_pool(CudaBuffer* buf);
|
|
48
|
+
};
|
|
49
|
+
|
|
50
|
+
class CudaAllocator : public allocator::Allocator {
|
|
51
|
+
public:
|
|
52
|
+
Buffer malloc(size_t size) override;
|
|
53
|
+
Buffer malloc_async(size_t size, int device, cudaStream_t stream);
|
|
54
|
+
void free(Buffer buffer) override;
|
|
55
|
+
size_t size(Buffer buffer) const override;
|
|
56
|
+
|
|
57
|
+
// Replace the memory of |buf| with unified memory (managed memory or pinned
|
|
58
|
+
// host memory), and copy the data over. Pass |stream| to copy asynchronously.
|
|
59
|
+
void move_to_unified_memory(CudaBuffer& buf, cudaStream_t stream = nullptr);
|
|
60
|
+
|
|
61
|
+
size_t get_active_memory() const;
|
|
62
|
+
size_t get_peak_memory() const;
|
|
63
|
+
void reset_peak_memory();
|
|
64
|
+
size_t get_memory_limit();
|
|
65
|
+
size_t set_memory_limit(size_t limit);
|
|
66
|
+
size_t get_cache_memory() const;
|
|
67
|
+
size_t set_cache_limit(size_t limit);
|
|
68
|
+
void clear_cache();
|
|
69
|
+
|
|
70
|
+
private:
|
|
71
|
+
void free_cuda_buffer(CudaBuffer* buf);
|
|
72
|
+
void free_async(CudaBuffer& buf, cudaStream_t stream = nullptr);
|
|
73
|
+
|
|
74
|
+
CudaAllocator();
|
|
75
|
+
friend CudaAllocator& allocator();
|
|
76
|
+
|
|
77
|
+
std::mutex mutex_;
|
|
78
|
+
size_t memory_limit_;
|
|
79
|
+
size_t free_limit_;
|
|
80
|
+
size_t total_memory_;
|
|
81
|
+
size_t max_pool_size_;
|
|
82
|
+
BufferCache<CudaBuffer> buffer_cache_;
|
|
83
|
+
size_t active_memory_{0};
|
|
84
|
+
size_t peak_memory_{0};
|
|
85
|
+
std::vector<CudaStream> free_streams_;
|
|
86
|
+
std::vector<cudaMemPool_t> mem_pools_;
|
|
87
|
+
SmallSizePool scalar_pool_;
|
|
88
|
+
};
|
|
89
|
+
|
|
90
|
+
CudaAllocator& allocator();
|
|
91
|
+
|
|
92
|
+
Buffer malloc_async(size_t size, CommandEncoder& encoder);
|
|
93
|
+
|
|
94
|
+
} // namespace mlx::core::cu
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include "mlx/backend/cuda/device.h"
|
|
4
|
+
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
|
5
|
+
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
6
|
+
#include "mlx/dtype_utils.h"
|
|
7
|
+
#include "mlx/primitives.h"
|
|
8
|
+
|
|
9
|
+
#include <cooperative_groups.h>
|
|
10
|
+
#include <nvtx3/nvtx3.hpp>
|
|
11
|
+
|
|
12
|
+
namespace mlx::core {
|
|
13
|
+
|
|
14
|
+
namespace cu {
|
|
15
|
+
|
|
16
|
+
namespace cg = cooperative_groups;
|
|
17
|
+
|
|
18
|
+
// Writes out[i] = start + i * step for every i in [0, size).
//
// Each thread is responsible for N_WRITES consecutive elements starting at
// thread_rank * N_WRITES. Threads whose whole range fits inside the output
// build the values in an AlignedVector and hand it to store_vector; the
// thread straddling the end (and any thread past it) falls back to a scalar
// loop bounded by |size|.
template <typename T, typename IdxT, int N_WRITES>
__global__ void arange(T* out, IdxT size, T start, T step) {
  IdxT index = cg::this_grid().thread_rank();
  // First element owned by this thread.
  IdxT first = index * N_WRITES;

  if (first + N_WRITES <= size) {
    // Full chunk: compute all N_WRITES values, then store them together.
    AlignedVector<T, N_WRITES> out_vec;
#pragma unroll
    for (int i = 0; i < N_WRITES; ++i) {
      out_vec[i] = start + (first + i) * step;
    }

    store_vector<N_WRITES>(out, index, out_vec);
  } else {
    // Partial (or empty) chunk at the tail: write element by element.
    for (IdxT i = first; i < size; ++i) {
      out[i] = start + i * step;
    }
  }
}
|
|
36
|
+
|
|
37
|
+
} // namespace cu
|
|
38
|
+
|
|
39
|
+
// GPU evaluation of Arange: allocates |out| and launches cu::arange to fill
// it. |inputs| is not used.
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Arange::eval_gpu");

  // Nothing to allocate or launch for an empty output.
  if (out.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(stream());
  out.set_data(cu::malloc_async(out.nbytes(), encoder));
  encoder.set_output_array(out);

  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
    using HostT = MLX_GET_TYPE(type_tag);
    using DeviceT = cuda_type_t<HostT>;
    // Number of elements that fit in 16 bytes for this element type.
    constexpr int N_WRITES = 16 / sizeof(DeviceT);

    // Use 64-bit indexing in the kernel only when 32 bits cannot address the
    // whole output.
    dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
      using IdxT = std::conditional_t<large(), int64_t, int32_t>;
      auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
      encoder.add_kernel_node(
          cu::arange<DeviceT, IdxT, N_WRITES>,
          num_blocks,
          block_dims,
          0,
          gpu_ptr<DeviceT>(out),
          out.data_size(),
          static_cast<HostT>(start_),
          // The step is derived from the already-cast endpoints rather than
          // by casting step_ directly.
          static_cast<HostT>(start_ + step_) - static_cast<HostT>(start_));
    });
  });
}
|
|
67
|
+
|
|
68
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include "mlx/backend/common/utils.h"
|
|
4
|
+
#include "mlx/backend/cuda/device.h"
|
|
5
|
+
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
|
6
|
+
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
7
|
+
#include "mlx/dtype_utils.h"
|
|
8
|
+
#include "mlx/primitives.h"
|
|
9
|
+
|
|
10
|
+
#include <cooperative_groups.h>
|
|
11
|
+
#include <nvtx3/nvtx3.hpp>
|
|
12
|
+
#include <cub/block/block_load.cuh>
|
|
13
|
+
#include <cub/block/block_reduce.cuh>
|
|
14
|
+
|
|
15
|
+
#include <cassert>
|
|
16
|
+
|
|
17
|
+
namespace mlx::core {
|
|
18
|
+
|
|
19
|
+
namespace cu {
|
|
20
|
+
|
|
21
|
+
namespace cg = cooperative_groups;
|
|
22
|
+
|
|
23
|
+
// A value together with the 32-bit position it was found at. Used as the
// element type of the arg-reduction.
template <typename T>
struct IndexValPair {
  // Position of |val| along the reduced axis.
  uint32_t index;
  // The value itself.
  T val;
};
|
|
28
|
+
|
|
29
|
+
// Reduction operator selecting the smallest value; ties are resolved in
// favour of the lower index.
template <typename T>
struct ArgMin {
  // Starting value of the reduction.
  constexpr __device__ T init() {
    return Limits<T>::max();
  }

  // Combine two candidates: |current| wins when it is strictly smaller, or
  // equal with a lower index.
  __device__ IndexValPair<T> operator()(
      const IndexValPair<T>& best,
      const IndexValPair<T>& current) {
    const bool strictly_better = best.val > current.val;
    const bool earlier_tie =
        best.val == current.val && best.index > current.index;
    return (strictly_better || earlier_tie) ? current : best;
  }

  // Fold N values into |best|. Element i of |vals| is assigned the index
  // offset + i. Only a strictly smaller value replaces the running best.
  template <int N>
  __device__ IndexValPair<T> reduce_many(
      IndexValPair<T> best,
      const AlignedVector<T, N>& vals,
      uint32_t offset) {
#pragma unroll
    for (int lane = 0; lane < N; lane++) {
      if (vals[lane] < best.val) {
        best.val = vals[lane];
        best.index = offset + lane;
      }
    }
    return best;
  }
};
|
|
61
|
+
|
|
62
|
+
// Reduction operator selecting the largest value; ties are resolved in
// favour of the lower index.
template <typename T>
struct ArgMax {
  // Starting value of the reduction.
  constexpr __device__ T init() {
    return Limits<T>::min();
  }

  // Combine two candidates: |current| wins when it is strictly larger, or
  // equal with a lower index.
  __device__ IndexValPair<T> operator()(
      const IndexValPair<T>& best,
      const IndexValPair<T>& current) {
    const bool strictly_better = best.val < current.val;
    const bool earlier_tie =
        best.val == current.val && best.index > current.index;
    return (strictly_better || earlier_tie) ? current : best;
  }

  // Fold N values into |best|. Element i of |vals| is assigned the index
  // offset + i. Only a strictly larger value replaces the running best.
  template <int N>
  __device__ IndexValPair<T> reduce_many(
      IndexValPair<T> best,
      const AlignedVector<T, N>& vals,
      uint32_t offset) {
#pragma unroll
    for (int lane = 0; lane < N; lane++) {
      if (vals[lane] > best.val) {
        best.val = vals[lane];
        best.index = offset + lane;
      }
    }
    return best;
  }
};
|
|
94
|
+
|
|
95
|
+
// Arg-reduction kernel: one thread block produces one output element.
//
// The block rank selects the output element; its input/output offsets are
// obtained with elem_to_loc from |shape| and the respective strides. Every
// thread then folds its share of the |axis_size| values along the reduced
// axis with Op::reduce_many, the per-thread results are combined with
// cub::BlockReduce, and thread 0 writes the winning index.
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
    const T* in,
    uint32_t* out,
    size_t size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides in_strides,
    const __grid_constant__ Strides out_strides,
    int32_t ndim,
    int64_t axis_stride,
    int32_t axis_size) {
  using BlockReduceT = cub::BlockReduce<IndexValPair<T>, BLOCK_DIM>;

  auto block = cg::this_thread_block();

  // Blocks beyond the number of outputs have nothing to do. The whole block
  // leaves together since the condition depends on the block rank only.
  int64_t index = cg::this_grid().block_rank();
  if (index >= size) {
    return;
  }

  int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
  int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
  in += in_idx;

  Op op;
  T init = op.init();
  IndexValPair<T> best{0, init};

  // Each pass covers BLOCK_DIM * N_READS elements of the axis.
  const auto n_passes = cuda::ceil_div(axis_size, BLOCK_DIM * N_READS);
  for (int r = 0; r < n_passes; ++r) {
    auto tid = r * BLOCK_DIM + block.thread_index().x;
    // |init| is passed to load_vector as well (presumably as the fill value
    // for lanes past the end of the axis -- confirm in load_vector).
    auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
    best = op.reduce_many(best, vals, tid * N_READS);
  }

  __shared__ typename BlockReduceT::TempStorage temp;
  best = BlockReduceT(temp).Reduce(best, op);

  if (block.thread_rank() == 0) {
    out[out_idx] = best.index;
  }
}
|
|
136
|
+
|
|
137
|
+
} // namespace cu
|
|
138
|
+
|
|
139
|
+
// GPU evaluation of ArgReduce: allocates |out| and launches
// cu::arg_reduce_general with either the ArgMax or the ArgMin operator,
// reducing the single input along axis_.
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgReduce::eval_gpu");
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  auto& encoder = cu::get_command_encoder(stream());
  out.set_data(cu::malloc_async(out.nbytes(), encoder));

  // Describe the non-reduced dimensions: drop axis_ from the input shape and
  // strides, and from the output strides when the output kept that dimension.
  Shape shape = remove_index(in.shape(), axis_);
  Strides in_strides = remove_index(in.strides(), axis_);
  Strides out_strides = out.strides();
  if (out.ndim() == in.ndim()) {
    out_strides = remove_index(out.strides(), axis_);
  }
  int64_t axis_stride = in.strides()[axis_];
  int32_t axis_size = in.shape()[axis_];
  int32_t ndim = shape.size();

  // Register the arrays with the encoder, then launch.
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    constexpr uint32_t N_READS = 4;
    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
      dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
      // ArgMax unless ArgMin was requested.
      auto kernel = (reduce_type_ == ArgReduce::ArgMin)
          ? cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>
          : cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
      encoder.add_kernel_node(
          kernel,
          num_blocks,
          block_dim(),
          0,
          gpu_ptr<T>(in),
          gpu_ptr<uint32_t>(out),
          out.size(),
          const_param(shape),
          const_param(in_strides),
          const_param(out_strides),
          ndim,
          axis_stride,
          axis_size);
    });
  });
}
|
|
188
|
+
|
|
189
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
# Based on: https://github.com/sivachandran/cmake-bin2h
|
|
2
|
+
#
|
|
3
|
+
# Copyright 2020 Sivachandran Paramasivam
|
|
4
|
+
#
|
|
5
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
# in the Software without restriction, including without limitation the rights
|
|
8
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
# furnished to do so, subject to the following conditions:
|
|
11
|
+
#
|
|
12
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
# copies or substantial portions of the Software.
|
|
14
|
+
#
|
|
15
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
# SOFTWARE.
|
|
22
|
+
|
|
23
|
+
include(CMakeParseArguments)
|
|
24
|
+
|
|
25
|
+
# Function to wrap a given string into multiple lines at the given column
# position.
#
# The variable named by VARIABLE is overwritten in the caller's scope. Every
# chunk of at most AT_COLUMN characters is emitted as a newline, an indent and
# the chunk, so a non-empty result starts with a newline. An empty input
# produces an empty result.
#
# Parameters:
#
# * VARIABLE - The name of the CMake variable holding the string.
# * AT_COLUMN - The column position at which string will be wrapped. Must be a
#   positive integer.
function(WRAP_STRING)
  set(oneValueArgs VARIABLE AT_COLUMN)
  # This function has no option keywords; pass an explicit empty list instead
  # of an `options` variable inherited from whoever called us.
  cmake_parse_arguments(WRAP_STRING "" "${oneValueArgs}" "" ${ARGN})

  if(NOT DEFINED WRAP_STRING_VARIABLE)
    message(FATAL_ERROR "WRAP_STRING: VARIABLE is required")
  endif()
  # A zero or negative width would never consume the input.
  if(NOT "${WRAP_STRING_AT_COLUMN}" MATCHES "^[1-9][0-9]*$")
    message(
      FATAL_ERROR
        "WRAP_STRING: AT_COLUMN must be a positive integer, got '${WRAP_STRING_AT_COLUMN}'"
    )
  endif()

  # Copy the input first (the caller's variable may share a name with one of
  # the locals below), and quote it so empty strings and ';' are handled.
  set(input "${${WRAP_STRING_VARIABLE}}")
  string(LENGTH "${input}" stringLength)
  set(offset 0)
  # Start from a clean accumulator; function scopes inherit caller variables.
  set(lines "")

  while(stringLength GREATER 0)
    if(stringLength GREATER WRAP_STRING_AT_COLUMN)
      set(length "${WRAP_STRING_AT_COLUMN}")
    else()
      set(length "${stringLength}")
    endif()

    string(SUBSTRING "${input}" ${offset} ${length} line)
    string(APPEND lines "\n ${line}")

    math(EXPR stringLength "${stringLength} - ${length}")
    math(EXPR offset "${offset} + ${length}")
  endwhile()

  set(${WRAP_STRING_VARIABLE}
      "${lines}"
      PARENT_SCOPE)
endfunction()
|
|
57
|
+
|
|
58
|
+
# Function to embed contents of files as byte arrays in a C/C++ header file(.h).
# For every source file the header receives one `constexpr char` array holding
# the file's bytes. No separate size variable is emitted.
#
# Parameters:
#
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
#   the header file.
# * VARIABLE_NAME - The prefix of the array names. Each array is called
#   "<VARIABLE_NAME>_<name>", where <name> is the source file name without
#   extension, converted to a valid C identifier.
# * HEADER_FILE - The path of header file (required).
# * APPEND - If specified appends to the header file instead of overwriting it
# * HEADER_NAMESPACE - The namespace, where the array should be located in. If
#   omitted the arrays are emitted at global scope.
# * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte
#   array.
#
# Usage:
#
# bin2h(SOURCE_FILES "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO")
function(BIN2H)
  set(options APPEND NULL_TERMINATE)
  set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
  set(multiValueArgs SOURCE_FILES)
  cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN})

  # Without a target path file(WRITE) would treat the content as the file name.
  if(NOT DEFINED BIN2H_HEADER_FILE)
    message(FATAL_ERROR "BIN2H: HEADER_FILE is required")
  endif()

  set(arrayDefinition "")
  foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
    # get filename without extension
    get_filename_component(FILE_NAME_WE "${SOURCE_FILE}" NAME_WE)
    # convert the filename to a valid C identifier
    string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)

    # reads source file contents as hex string
    file(READ "${SOURCE_FILE}" hexString HEX)

    # append null
    if(BIN2H_NULL_TERMINATE)
      string(APPEND hexString "00")
    endif()

    # wraps the hex string into multiple lines
    wrap_string(VARIABLE hexString AT_COLUMN 24)

    # turn every pair of hex digits into one array element
    string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
                         "${hexString}")

    # strip the © in source code: replace the byte pair c2 a9 by two spaces.
    # This runs on whole byte tokens, after the conversion above, so it can
    # neither match across a nibble boundary nor miss a pair that the line
    # wrapping split over two lines.
    string(REGEX REPLACE " 0xc2,([\n ]*) 0xa9," " 0x20,\\1 0x20," arrayValues
                         "${arrayValues}")

    # make a full variable name for the array
    set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")

    # declares the byte array
    string(APPEND arrayDefinition
           "constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
  endforeach()

  # add namespace wrapper if defined, otherwise emit the arrays as they are
  if(DEFINED BIN2H_HEADER_NAMESPACE)
    set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
    set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
    set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
  else()
    set(declarations "${arrayDefinition}")
  endif()

  set(arrayIncludes "#pragma once")
  string(PREPEND declarations "${arrayIncludes}\n\n")

  if(BIN2H_APPEND)
    file(APPEND "${BIN2H_HEADER_FILE}" "${declarations}")
  else()
    file(WRITE "${BIN2H_HEADER_FILE}" "${declarations}")
  endif()
endfunction()
|
|
133
|
+
|
|
134
|
+
# ----------------------------- CLI args -----------------------------
#
# Inputs read by this section (they are not set anywhere in this file, so they
# are expected to be supplied by whoever runs it):
#
# * MLX_JIT_SOURCES - ':'-separated list of source paths relative to
#   MLX_SOURCE_ROOT.
# * MLX_SOURCE_ROOT - directory the source paths are relative to.
#
# Output: ${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h

if("${MLX_JIT_SOURCES}" STREQUAL "")
  message(
    FATAL_ERROR "MLX_JIT_SOURCES must be set to a ':'-separated list of files")
endif()

# Quoted so the value is always exactly one input argument.
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST "${MLX_JIT_SOURCES}")
foreach(source IN LISTS MLX_JIT_SOURCES_LIST)
  # Skip empty entries (e.g. from a trailing ':').
  if(NOT "${source}" STREQUAL "")
    list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
  endif()
endforeach()

bin2h(
  SOURCE_FILES
  ${MLX_JIT_SOURCES_ABS}
  NULL_TERMINATE
  VARIABLE_NAME
  "jit_source"
  HEADER_NAMESPACE
  "mlx::core"
  HEADER_FILE
  "${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Add the .cu files of this directory as private sources of the mlx target.
target_sources(
  mlx
  PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
    ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)
|