mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
data/mlx/mlx/random.h
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <chrono>
|
|
6
|
+
#include <optional>
|
|
7
|
+
|
|
8
|
+
#include "mlx/api.h"
|
|
9
|
+
#include "mlx/array.h"
|
|
10
|
+
#include "mlx/stream.h"
|
|
11
|
+
#include "mlx/utils.h"
|
|
12
|
+
|
|
13
|
+
namespace mlx::core::random {
|
|
14
|
+
|
|
15
|
+
class MLX_API KeySequence {
|
|
16
|
+
public:
|
|
17
|
+
explicit KeySequence(uint64_t seed);
|
|
18
|
+
|
|
19
|
+
void seed(uint64_t seed);
|
|
20
|
+
array next();
|
|
21
|
+
|
|
22
|
+
// static default
|
|
23
|
+
static KeySequence& default_() {
|
|
24
|
+
static KeySequence ks(get_current_time_seed());
|
|
25
|
+
return ks;
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
private:
|
|
29
|
+
array key_;
|
|
30
|
+
static uint64_t get_current_time_seed() {
|
|
31
|
+
auto now = std::chrono::system_clock::now();
|
|
32
|
+
return std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
33
|
+
now.time_since_epoch())
|
|
34
|
+
.count();
|
|
35
|
+
}
|
|
36
|
+
};
|
|
37
|
+
|
|
38
|
+
/** Get a PRNG key from a seed. */
|
|
39
|
+
MLX_API array key(uint64_t seed);
|
|
40
|
+
|
|
41
|
+
/** Seed the default PRNG key. */
|
|
42
|
+
MLX_API void seed(uint64_t seed);
|
|
43
|
+
|
|
44
|
+
/** Generate an array with type uint32 filled with random bits. */
|
|
45
|
+
MLX_API array bits(
|
|
46
|
+
const Shape& shape,
|
|
47
|
+
int width,
|
|
48
|
+
const std::optional<array>& key = std::nullopt,
|
|
49
|
+
StreamOrDevice s = {});
|
|
50
|
+
inline array bits(
|
|
51
|
+
const Shape& shape,
|
|
52
|
+
const std::optional<array>& key = std::nullopt,
|
|
53
|
+
StreamOrDevice s = {}) {
|
|
54
|
+
return bits(shape, 4, key, s);
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
/** Split the rng key into a pair of keys. */
|
|
58
|
+
MLX_API std::pair<array, array> split(const array& key, StreamOrDevice s = {});
|
|
59
|
+
|
|
60
|
+
/** Split the rng key into `num` keys. */
|
|
61
|
+
MLX_API array split(const array& key, int num, StreamOrDevice s = {});
|
|
62
|
+
|
|
63
|
+
/** Generate uniform random numbers between low and high. */
|
|
64
|
+
MLX_API array uniform(
|
|
65
|
+
const array& low,
|
|
66
|
+
const array& high,
|
|
67
|
+
const Shape& shape,
|
|
68
|
+
Dtype dtype = float32,
|
|
69
|
+
const std::optional<array>& key = std::nullopt,
|
|
70
|
+
StreamOrDevice s = {});
|
|
71
|
+
|
|
72
|
+
template <typename T, typename U>
|
|
73
|
+
array uniform(
|
|
74
|
+
T low,
|
|
75
|
+
U high,
|
|
76
|
+
const Shape& shape,
|
|
77
|
+
Dtype dtype = float32,
|
|
78
|
+
const std::optional<array>& key = std::nullopt,
|
|
79
|
+
StreamOrDevice s = {}) {
|
|
80
|
+
return uniform(array(low), array(high), shape, dtype, key, to_stream(s));
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/** Generate uniform random numbers between 0 and 1. */
|
|
84
|
+
MLX_API array uniform(
|
|
85
|
+
const Shape& shape,
|
|
86
|
+
Dtype dtype,
|
|
87
|
+
const std::optional<array>& key = std::nullopt,
|
|
88
|
+
StreamOrDevice s = {});
|
|
89
|
+
inline array uniform(
|
|
90
|
+
const Shape& shape,
|
|
91
|
+
const std::optional<array>& key = std::nullopt,
|
|
92
|
+
StreamOrDevice s = {}) {
|
|
93
|
+
return uniform(shape, float32, key, s);
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
/** Generate samples from the standard normal distribution. */
|
|
97
|
+
MLX_API array normal(
|
|
98
|
+
const Shape& shape,
|
|
99
|
+
Dtype dtype,
|
|
100
|
+
const std::optional<array>& loc,
|
|
101
|
+
const std::optional<array>& scale,
|
|
102
|
+
const std::optional<array>& key,
|
|
103
|
+
StreamOrDevice s = {});
|
|
104
|
+
inline array normal(
|
|
105
|
+
const Shape& shape,
|
|
106
|
+
Dtype dtype,
|
|
107
|
+
const float loc,
|
|
108
|
+
const float scale,
|
|
109
|
+
const std::optional<array>& key = std::nullopt,
|
|
110
|
+
StreamOrDevice s = {}) {
|
|
111
|
+
auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
|
|
112
|
+
auto scale_ =
|
|
113
|
+
scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
|
|
114
|
+
return normal(shape, dtype, loc_, scale_, key, s);
|
|
115
|
+
}
|
|
116
|
+
inline array normal(
|
|
117
|
+
const Shape& shape,
|
|
118
|
+
const float loc,
|
|
119
|
+
const float scale,
|
|
120
|
+
const std::optional<array>& key = std::nullopt,
|
|
121
|
+
StreamOrDevice s = {}) {
|
|
122
|
+
return normal(shape, float32, loc, scale, key, s);
|
|
123
|
+
}
|
|
124
|
+
inline array normal(
|
|
125
|
+
const Shape& shape,
|
|
126
|
+
const Dtype dtype,
|
|
127
|
+
const std::optional<array>& key = std::nullopt,
|
|
128
|
+
StreamOrDevice s = {}) {
|
|
129
|
+
return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
|
|
130
|
+
}
|
|
131
|
+
inline array normal(
|
|
132
|
+
const Shape& shape,
|
|
133
|
+
const std::optional<array>& key = std::nullopt,
|
|
134
|
+
StreamOrDevice s = {}) {
|
|
135
|
+
return normal(shape, float32, std::nullopt, std::nullopt, key, s);
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
/** Generate samples from a multivariate normal distribution. **/
|
|
139
|
+
MLX_API array multivariate_normal(
|
|
140
|
+
const array& mean,
|
|
141
|
+
const array& cov,
|
|
142
|
+
const Shape& shape,
|
|
143
|
+
Dtype dtype,
|
|
144
|
+
const std::optional<array>& key = std::nullopt,
|
|
145
|
+
StreamOrDevice s = {});
|
|
146
|
+
|
|
147
|
+
/** Generate integer samples uniformly at random */
|
|
148
|
+
MLX_API array randint(
|
|
149
|
+
const array& low,
|
|
150
|
+
const array& high,
|
|
151
|
+
const Shape& shape,
|
|
152
|
+
Dtype dtype = int32,
|
|
153
|
+
const std::optional<array>& key = std::nullopt,
|
|
154
|
+
StreamOrDevice s = {});
|
|
155
|
+
|
|
156
|
+
template <typename T, typename U>
|
|
157
|
+
array randint(
|
|
158
|
+
T low,
|
|
159
|
+
U high,
|
|
160
|
+
const Shape& shape,
|
|
161
|
+
Dtype dtype = int32,
|
|
162
|
+
const std::optional<array>& key = std::nullopt,
|
|
163
|
+
StreamOrDevice s = {}) {
|
|
164
|
+
return randint(array(low), array(high), shape, dtype, key, to_stream(s));
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
/** Generate binary variables with probability to be true equal to p */
|
|
168
|
+
MLX_API array bernoulli(
|
|
169
|
+
const array& p,
|
|
170
|
+
const Shape& shape,
|
|
171
|
+
const std::optional<array>& key = std::nullopt,
|
|
172
|
+
StreamOrDevice s = {});
|
|
173
|
+
MLX_API array bernoulli(
|
|
174
|
+
const array& p,
|
|
175
|
+
const std::optional<array>& key = std::nullopt,
|
|
176
|
+
StreamOrDevice s = {});
|
|
177
|
+
|
|
178
|
+
template <typename T>
|
|
179
|
+
array bernoulli(
|
|
180
|
+
T p,
|
|
181
|
+
const std::optional<array>& key = std::nullopt,
|
|
182
|
+
StreamOrDevice s = {}) {
|
|
183
|
+
return bernoulli(array(p), key, s);
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
template <typename T>
|
|
187
|
+
array bernoulli(
|
|
188
|
+
T p,
|
|
189
|
+
const Shape& shape,
|
|
190
|
+
const std::optional<array>& key = std::nullopt,
|
|
191
|
+
StreamOrDevice s = {}) {
|
|
192
|
+
return bernoulli(array(p), shape, key, s);
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
MLX_API array bernoulli(
|
|
196
|
+
const std::optional<array>& key = std::nullopt,
|
|
197
|
+
StreamOrDevice s = {});
|
|
198
|
+
|
|
199
|
+
MLX_API array truncated_normal(
|
|
200
|
+
const array& lower,
|
|
201
|
+
const array& upper,
|
|
202
|
+
const Shape& shape,
|
|
203
|
+
Dtype dtype = float32,
|
|
204
|
+
const std::optional<array>& key = std::nullopt,
|
|
205
|
+
StreamOrDevice s = {});
|
|
206
|
+
|
|
207
|
+
MLX_API array truncated_normal(
|
|
208
|
+
const array& lower,
|
|
209
|
+
const array& upper,
|
|
210
|
+
Dtype dtype = float32,
|
|
211
|
+
const std::optional<array>& key = std::nullopt,
|
|
212
|
+
StreamOrDevice s = {});
|
|
213
|
+
|
|
214
|
+
MLX_API array gumbel(
|
|
215
|
+
const Shape& shape,
|
|
216
|
+
Dtype dtype = float32,
|
|
217
|
+
const std::optional<array>& key = std::nullopt,
|
|
218
|
+
StreamOrDevice s = {});
|
|
219
|
+
|
|
220
|
+
MLX_API array categorical(
|
|
221
|
+
const array& logits,
|
|
222
|
+
int axis,
|
|
223
|
+
const Shape& shape,
|
|
224
|
+
const std::optional<array>& key = std::nullopt,
|
|
225
|
+
StreamOrDevice s = {});
|
|
226
|
+
|
|
227
|
+
MLX_API array categorical(
|
|
228
|
+
const array& logits_,
|
|
229
|
+
int axis,
|
|
230
|
+
int num_samples,
|
|
231
|
+
const std::optional<array>& key = std::nullopt,
|
|
232
|
+
StreamOrDevice s = {});
|
|
233
|
+
|
|
234
|
+
MLX_API array categorical(
|
|
235
|
+
const array& logits,
|
|
236
|
+
int axis = -1,
|
|
237
|
+
const std::optional<array>& key = std::nullopt,
|
|
238
|
+
StreamOrDevice s = {});
|
|
239
|
+
|
|
240
|
+
/** Generate samples from the laplace distribution. */
|
|
241
|
+
MLX_API array laplace(
|
|
242
|
+
const Shape& shape,
|
|
243
|
+
Dtype dtype,
|
|
244
|
+
const float loc,
|
|
245
|
+
const float scale,
|
|
246
|
+
const std::optional<array>& key = std::nullopt,
|
|
247
|
+
StreamOrDevice s = {});
|
|
248
|
+
inline array laplace(
|
|
249
|
+
const Shape& shape,
|
|
250
|
+
const float loc,
|
|
251
|
+
const float scale,
|
|
252
|
+
const std::optional<array>& key = std::nullopt,
|
|
253
|
+
StreamOrDevice s = {}) {
|
|
254
|
+
return laplace(shape, float32, loc, scale, key, s);
|
|
255
|
+
}
|
|
256
|
+
inline array laplace(
|
|
257
|
+
const Shape& shape,
|
|
258
|
+
const Dtype dtype,
|
|
259
|
+
const std::optional<array>& key = std::nullopt,
|
|
260
|
+
StreamOrDevice s = {}) {
|
|
261
|
+
return laplace(shape, dtype, 0.0, 1.0, key, s);
|
|
262
|
+
}
|
|
263
|
+
inline array laplace(
|
|
264
|
+
const Shape& shape,
|
|
265
|
+
const std::optional<array>& key = std::nullopt,
|
|
266
|
+
StreamOrDevice s = {}) {
|
|
267
|
+
return laplace(shape, float32, 0.0, 1.0, key, s);
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
/* Randomly permute the elements of x along the given axis. */
|
|
271
|
+
MLX_API array permutation(
|
|
272
|
+
const array& x,
|
|
273
|
+
int axis = 0,
|
|
274
|
+
const std::optional<array>& key = std::nullopt,
|
|
275
|
+
StreamOrDevice s = {});
|
|
276
|
+
|
|
277
|
+
/* A random permutation of `arange(x)` */
|
|
278
|
+
MLX_API array permutation(
|
|
279
|
+
int x,
|
|
280
|
+
const std::optional<array>& key = std::nullopt,
|
|
281
|
+
StreamOrDevice s = {});
|
|
282
|
+
|
|
283
|
+
} // namespace mlx::core::random
|
|
data/mlx/mlx/scheduler.cpp
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include "mlx/scheduler.h"
|
|
4
|
+
#include "mlx/backend/gpu/device_info.h"
|
|
5
|
+
#include "mlx/backend/gpu/eval.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core {
|
|
8
|
+
|
|
9
|
+
// Return the scheduler's default stream for device `d`.
// Throws std::invalid_argument when `d` is the GPU but no GPU backend is
// available.
Stream default_stream(Device d) {
  const bool have_gpu = gpu::is_available();
  if (!have_gpu && d == Device::gpu) {
    throw std::invalid_argument(
        "[default_stream] Cannot get gpu stream without gpu backend.");
  }
  auto& sched = scheduler::scheduler();
  return sched.get_default_stream(d);
}
|
|
16
|
+
|
|
17
|
+
// Make `s` the default stream for its device.
// Throws std::invalid_argument when `s` is a GPU stream but no GPU backend is
// available.
void set_default_stream(Stream s) {
  const bool have_gpu = gpu::is_available();
  if (!have_gpu && s.device == Device::gpu) {
    throw std::invalid_argument(
        "[set_default_stream] Cannot set gpu stream without gpu backend.");
  }
  auto& sched = scheduler::scheduler();
  sched.set_default_stream(s);
}
|
|
24
|
+
|
|
25
|
+
// Look up a stream by its index in the scheduler's stream list.
Stream get_stream(int index) {
  const auto& sched = scheduler::scheduler();
  return sched.get_stream(index);
}
|
|
28
|
+
|
|
29
|
+
// Create a new stream on device `d`.
// Throws std::invalid_argument when `d` is the GPU but no GPU backend is
// available.
Stream new_stream(Device d) {
  const bool have_gpu = gpu::is_available();
  if (!have_gpu && d == Device::gpu) {
    throw std::invalid_argument(
        "[new_stream] Cannot make gpu stream without gpu backend.");
  }
  auto& sched = scheduler::scheduler();
  return sched.new_stream(d);
}
|
|
36
|
+
|
|
37
|
+
// Create a new stream on the current default device.
//
// Delegates to new_stream(Device) so that the same GPU-availability check
// (and std::invalid_argument) applies, instead of calling the scheduler
// directly and bypassing it.
Stream new_stream() {
  return new_stream(default_device());
}
|
|
40
|
+
|
|
41
|
+
// Block until the work previously submitted to `s` has finished.
//
// CPU streams: a no-op task is queued behind the existing work, and we wait
// for it to run. The stream's worker drains its queue in FIFO order. The
// promise is held through a shared_ptr because std::function (the queued task
// type) requires a copyable callable, and std::promise is move-only.
// GPU streams defer to gpu::synchronize.
void synchronize(Stream s) {
  if (s.device == mlx::core::Device::cpu) {
    auto signal = std::make_shared<std::promise<void>>();
    auto finished = signal->get_future();
    scheduler::enqueue(
        s, [signal = std::move(signal)]() { signal->set_value(); });
    finished.wait();
    return;
  }
  gpu::synchronize(s);
}
|
|
51
|
+
|
|
52
|
+
void synchronize() {
|
|
53
|
+
synchronize(default_stream(default_device()));
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
namespace scheduler {
|
|
57
|
+
|
|
58
|
+
/** A singleton scheduler to manage devices, streams, and task execution. */
|
|
59
|
+
Scheduler& scheduler() {
|
|
60
|
+
// Leak the scheduler on Windows to avoid joining threads on exit, can be
|
|
61
|
+
// removed after Visual Studio fixes bug:
|
|
62
|
+
// https://developercommunity.visualstudio.com/t/1654756
|
|
63
|
+
#ifdef _WIN32
|
|
64
|
+
static Scheduler* scheduler = new Scheduler;
|
|
65
|
+
return *scheduler;
|
|
66
|
+
#else
|
|
67
|
+
static Scheduler scheduler;
|
|
68
|
+
return scheduler;
|
|
69
|
+
#endif
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
} // namespace scheduler
|
|
73
|
+
} // namespace mlx::core
|
data/mlx/mlx/scheduler.h
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <unordered_map>
#include <vector>

#include "mlx/api.h"
#include "mlx/backend/gpu/eval.h"
#include "mlx/device.h"
#include "mlx/stream.h"
|
|
15
|
+
|
|
16
|
+
namespace mlx::core::scheduler {
|
|
17
|
+
|
|
18
|
+
/**
 * A single worker thread that runs queued tasks one at a time in FIFO order.
 *
 * Destruction sets `stop`, wakes the worker and joins it. The worker keeps
 * running tasks until the queue is empty before it exits.
 *
 * Member order matters: `thread` is declared last, so the mutex, queue,
 * condition variable and `stop` flag are all constructed before the worker
 * starts running `thread_fn`.
 */
struct StreamThread {
  std::mutex mtx;
  std::queue<std::function<void()>> q;
  std::condition_variable cond;
  bool stop{false};
  std::thread thread;

  StreamThread() : thread(&StreamThread::thread_fn, this) {}

  ~StreamThread() {
    std::unique_lock<std::mutex> guard(mtx);
    stop = true;
    guard.unlock();
    cond.notify_one();
    thread.join();
  }

  /** Worker loop: run tasks until a stop is requested and the queue is empty. */
  void thread_fn() {
    for (;;) {
      // Scoped per iteration so each finished task is destroyed before the
      // worker waits for the next one.
      std::function<void()> job;
      if (!next_task(job)) {
        return;
      }
      job();
    }
  }

  /**
   * Append `f` to the queue and wake the worker.
   * Throws std::runtime_error if a stop has already been requested.
   */
  template <typename F>
  void enqueue(F&& f) {
    std::unique_lock<std::mutex> guard(mtx);
    if (stop) {
      throw std::runtime_error(
          "Cannot enqueue work after stream is stopped.");
    }
    q.emplace(std::forward<F>(f));
    guard.unlock();
    cond.notify_one();
  }

 private:
  /**
   * Wait until there is work or a stop request. On work, move the oldest task
   * into `out` and return true. Return false once `stop` is set and the queue
   * is empty.
   */
  bool next_task(std::function<void()>& out) {
    std::unique_lock<std::mutex> guard(mtx);
    cond.wait(guard, [this] { return stop || !q.empty(); });
    if (q.empty()) {
      // The wait predicate guarantees `stop` is set here.
      return false;
    }
    out = std::move(q.front());
    q.pop();
    return true;
  }
};
|
|
66
|
+
|
|
67
|
+
class Scheduler {
|
|
68
|
+
public:
|
|
69
|
+
Scheduler() : n_active_tasks_(0) {
|
|
70
|
+
if (is_available(Device::gpu)) {
|
|
71
|
+
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
|
72
|
+
}
|
|
73
|
+
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
// Not copyable or moveable
|
|
77
|
+
Scheduler(const Scheduler&) = delete;
|
|
78
|
+
Scheduler(Scheduler&&) = delete;
|
|
79
|
+
Scheduler& operator=(const Scheduler&) = delete;
|
|
80
|
+
Scheduler& operator=(Scheduler&&) = delete;
|
|
81
|
+
|
|
82
|
+
Stream new_stream(const Device& d) {
|
|
83
|
+
streams_.emplace_back(streams_.size(), d);
|
|
84
|
+
if (d == Device::gpu) {
|
|
85
|
+
threads_.push_back(nullptr);
|
|
86
|
+
gpu::new_stream(streams_.back());
|
|
87
|
+
} else {
|
|
88
|
+
threads_.push_back(new StreamThread{});
|
|
89
|
+
}
|
|
90
|
+
return streams_.back();
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
template <typename F>
|
|
94
|
+
void enqueue(const Stream& stream, F&& f);
|
|
95
|
+
|
|
96
|
+
Stream get_default_stream(const Device& d) const {
|
|
97
|
+
return default_streams_.at(d.type);
|
|
98
|
+
}
|
|
99
|
+
Stream get_stream(int index) const {
|
|
100
|
+
return streams_.at(index);
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
void set_default_stream(const Stream& s) {
|
|
104
|
+
default_streams_.at(s.device.type) = s;
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
void notify_new_task(const Stream& stream) {
|
|
108
|
+
{
|
|
109
|
+
std::lock_guard<std::mutex> lk(mtx);
|
|
110
|
+
n_active_tasks_++;
|
|
111
|
+
}
|
|
112
|
+
completion_cv.notify_all();
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
void notify_task_completion(const Stream& stream) {
|
|
116
|
+
{
|
|
117
|
+
std::lock_guard<std::mutex> lk(mtx);
|
|
118
|
+
n_active_tasks_--;
|
|
119
|
+
}
|
|
120
|
+
completion_cv.notify_all();
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
int n_active_tasks() const {
|
|
124
|
+
return n_active_tasks_;
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
void wait_for_one() {
|
|
128
|
+
std::unique_lock<std::mutex> lk(mtx);
|
|
129
|
+
int n_tasks_old = n_active_tasks();
|
|
130
|
+
if (n_tasks_old > 1) {
|
|
131
|
+
completion_cv.wait(lk, [this, n_tasks_old] {
|
|
132
|
+
return this->n_active_tasks() < n_tasks_old;
|
|
133
|
+
});
|
|
134
|
+
}
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
~Scheduler() {
|
|
138
|
+
for (auto s : streams_) {
|
|
139
|
+
try {
|
|
140
|
+
synchronize(s);
|
|
141
|
+
} catch (const std::runtime_error&) {
|
|
142
|
+
// ignore errors if synch fails
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
for (auto t : threads_) {
|
|
146
|
+
if (t != nullptr) {
|
|
147
|
+
delete t;
|
|
148
|
+
}
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
private:
|
|
153
|
+
int n_active_tasks_;
|
|
154
|
+
std::vector<StreamThread*> threads_;
|
|
155
|
+
std::vector<Stream> streams_;
|
|
156
|
+
std::unordered_map<Device::DeviceType, Stream> default_streams_;
|
|
157
|
+
std::condition_variable completion_cv;
|
|
158
|
+
std::mutex mtx;
|
|
159
|
+
};
|
|
160
|
+
|
|
161
|
+
template <typename F>
|
|
162
|
+
void Scheduler::enqueue(const Stream& stream, F&& f) {
|
|
163
|
+
threads_[stream.index]->enqueue(std::forward<F>(f));
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
MLX_API Scheduler& scheduler();
|
|
167
|
+
|
|
168
|
+
template <typename F>
|
|
169
|
+
void enqueue(const Stream& stream, F&& f) {
|
|
170
|
+
scheduler().enqueue(stream, std::forward<F>(f));
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
inline int n_active_tasks() {
|
|
174
|
+
return scheduler().n_active_tasks();
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
inline void notify_new_task(const Stream& stream) {
|
|
178
|
+
scheduler().notify_new_task(stream);
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
inline void notify_task_completion(const Stream& stream) {
|
|
182
|
+
scheduler().notify_task_completion(stream);
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
inline void wait_for_one() {
|
|
186
|
+
scheduler().wait_for_one();
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
} // namespace mlx::core::scheduler
|