mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
data/lib/mlx/core.rb
ADDED
|
@@ -0,0 +1,1678 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "open3"
|
|
4
|
+
require "tmpdir"
|
|
5
|
+
|
|
6
|
+
module MLX
|
|
7
|
+
module Core
|
|
8
|
+
class NativeUnavailableError < StandardError; end
|
|
9
|
+
|
|
10
|
+
module DeviceType
|
|
11
|
+
module_function
|
|
12
|
+
|
|
13
|
+
def cpu
|
|
14
|
+
:cpu
|
|
15
|
+
end
|
|
16
|
+
|
|
17
|
+
def gpu
|
|
18
|
+
:gpu
|
|
19
|
+
end
|
|
20
|
+
end
|
|
21
|
+
|
|
22
|
+
# Floating-point dtype limits, analogous to numpy.finfo / mx.finfo.
# Values are hard-coded per supported MLX float dtype.
class Finfo
  FLOAT_INFO = {
    "float16" => { min: -65_504.0, max: 65_504.0, eps: 9.765625e-4 },
    "bfloat16" => { min: -3.389531389e38, max: 3.389531389e38, eps: 7.8125e-3 },
    "float32" => { min: -3.4028235e38, max: 3.4028235e38, eps: 1.1920929e-7 },
    "float64" => { min: -Float::MAX, max: Float::MAX, eps: Float::EPSILON },
    "complex64" => { min: -3.4028235e38, max: 3.4028235e38, eps: 1.1920929e-7 }
  }.freeze

  attr_reader :dtype, :min, :max, :eps

  # @param dtype [#name, #to_s] an MLX dtype object or its string/symbol name
  # @raise [ArgumentError] when the dtype is not a supported float type
  def initialize(dtype)
    @dtype = dtype
    key = dtype_name(dtype)
    info = FLOAT_INFO[key]
    raise ArgumentError, "unsupported dtype for finfo: #{key}" if info.nil?

    @min, @max, @eps = info.values_at(:min, :max, :eps)
  end

  private

  # Dtype objects expose #name; strings/symbols fall back to #to_s.
  def dtype_name(dtype)
    dtype.respond_to?(:name) ? dtype.name.to_s : dtype.to_s
  end
end
|
|
53
|
+
|
|
54
|
+
# Integer dtype limits, analogous to numpy.iinfo / mx.iinfo.
class Iinfo
  INT_INFO = {
    "bool_" => { min: 0, max: 1 },
    "uint8" => { min: 0, max: 255 },
    "uint16" => { min: 0, max: 65_535 },
    "uint32" => { min: 0, max: 4_294_967_295 },
    "uint64" => { min: 0, max: 18_446_744_073_709_551_615 },
    "int8" => { min: -128, max: 127 },
    "int16" => { min: -32_768, max: 32_767 },
    "int32" => { min: -2_147_483_648, max: 2_147_483_647 },
    "int64" => { min: -9_223_372_036_854_775_808, max: 9_223_372_036_854_775_807 }
  }.freeze

  attr_reader :dtype, :min, :max

  # @param dtype [#name, #to_s] an MLX dtype object or its string/symbol name
  # @raise [ArgumentError] when the dtype is not a supported integer type
  def initialize(dtype)
    @dtype = dtype
    key = dtype_name(dtype)
    info = INT_INFO[key]
    raise ArgumentError, "unsupported dtype for iinfo: #{key}" if info.nil?

    @min, @max = info.values_at(:min, :max)
  end

  private

  # Dtype objects expose #name; strings/symbols fall back to #to_s.
  def dtype_name(dtype)
    dtype.respond_to?(:name) ? dtype.name.to_s : dtype.to_s
  end
end
|
|
88
|
+
|
|
89
|
+
# Adapter for objects implementing the __mlx__array__ conversion protocol.
class ArrayLike
  attr_reader :object

  # @param object [#__mlx__array__] any object that can convert itself
  # @raise [TypeError] when the protocol method is missing
  def initialize(object)
    raise TypeError, "ArrayLike requires an object that responds to __mlx__array__" unless object.respond_to?(:__mlx__array__)

    @object = object
  end

  # Invoke the conversion and validate the result type.
  # @return [MLX::Core::Array]
  # @raise [TypeError] when the protocol returns something else
  def to_a
    converted = @object.__mlx__array__
    return converted if converted.is_a?(MLX::Core::Array)

    raise TypeError, "__mlx__array__ must return MLX::Core::Array"
  end
end
|
|
106
|
+
|
|
107
|
+
# Python-style iterator over the first axis of an array; relies only on
# the wrapped object's __len__ / __getitem__ duck-typed protocol.
class ArrayIterator
  def initialize(array)
    @array = array
    @index = 0
  end

  # Python iterator protocol: an iterator is its own iterable.
  def __iter__
    self
  end

  # @return the next element
  # @raise [StopIteration] once all elements are consumed
  def __next__
    raise StopIteration unless @index < @array.__len__

    value = @array.__getitem__(@index)
    @index += 1
    value
  end

  alias next __next__
end
|
|
127
|
+
|
|
128
|
+
# Builder backing `array.at[indices].add(v)`-style indexed updates,
# mirroring mlx's `array.at` helper. Operations mutate the wrapped
# array in place via __setitem__.
class ArrayAt
  def initialize(array)
    @array = array
    @indices = nil
  end

  # Record the indices the next operation targets; returns self so the
  # call chains as `at[idx].add(v)`.
  def [](indices)
    @indices = indices
    self
  end

  # One method per supported binary op; each reads the selected slice,
  # combines it with `value` via the matching MLX::Core function, and
  # writes the result back.
  %i[add subtract multiply divide maximum minimum].each do |op|
    define_method(op) do |value|
      apply(value) { |lhs, rhs| MLX::Core.send(op, lhs, rhs) }
    end
  end

  private

  # Shared update path: fetch the current slice, coerce `value` to an
  # array of the slice's dtype when needed, combine, then write back.
  # @raise [ArgumentError] when no indices were supplied via #[]
  def apply(value)
    raise ArgumentError, "must provide indices to array.at first" if @indices.nil?

    current = @array.__getitem__(@indices)
    operand = value.is_a?(MLX::Core::Array) ? value : MLX::Core.array(value, current.dtype)
    @array.__setitem__(@indices, yield(current, operand))
  end
end
|
|
174
|
+
|
|
175
|
+
# Lightweight stand-in for a DLPack capsule: snapshots dtype/shape and
# the originating device, keeping a handle to the source array.
class DLPackCapsule
  attr_reader :array, :dtype, :shape, :device, :stream

  # @param array [MLX::Core::Array] the exported array
  # @param device frozen copy is kept for metadata purposes
  # @param stream [Object, nil] optional stream handle
  # @raise [TypeError] when array is not an MLX::Core::Array
  def initialize(array, device:, stream: nil)
    raise TypeError, "DLPackCapsule requires an MLX::Core::Array" unless array.is_a?(MLX::Core::Array)

    @array = array
    @dtype = array.dtype
    @shape = array.shape.dup.freeze
    @device = device.dup.freeze
    @stream = stream
  end

  # String-keyed metadata view of the capsule.
  def to_h
    dtype_label = dtype.respond_to?(:name) ? dtype.name.to_s : dtype.to_s
    {
      "dtype" => dtype_label,
      "shape" => shape,
      "device" => device,
      "stream" => stream
    }
  end
end
|
|
199
|
+
|
|
200
|
+
# Wraps a callable and lets callers attach custom vjp/jvp/vmap rules,
# mirroring mx.custom_function. The transform wrappers in this module
# consult custom_vjp?/custom_jvp?/custom_vmap? to decide whether to use
# the attached rules instead of the native transforms.
class CustomFunction
  # @param fun [#call] the primal function
  # @raise [TypeError] when fun is not callable
  def initialize(fun)
    raise TypeError, "expected callable object" unless fun.respond_to?(:call)

    @fun = fun
    @vjp = nil
    @jvp = nil
    @vmap = nil
  end

  # Invoke the primal function.
  def call(*args, **kwargs, &block)
    @fun.call(*args, **kwargs, &block)
  end

  # Setter + presence predicate for each custom rule. Matching the
  # original behavior, the rule is stored before validation, and the
  # stored callable is returned.
  %i[vjp jvp vmap].each do |kind|
    define_method(kind) do |fun = nil, &block|
      stored = fun || block
      instance_variable_set("@#{kind}", stored)
      raise ArgumentError, "expected callable object" unless stored.respond_to?(:call)

      stored
    end

    define_method("custom_#{kind}?") do
      !instance_variable_get("@#{kind}").nil?
    end
  end

  # Invoke the attached vjp rule.
  # @raise [ArgumentError] when no rule was attached
  def call_custom_vjp(primals, cotangents, outputs)
    raise ArgumentError, "custom vjp is not defined" unless custom_vjp?

    @vjp.call(primals, cotangents, outputs)
  end

  # Invoke the attached jvp rule.
  # @raise [ArgumentError] when no rule was attached
  def call_custom_jvp(primals, tangents)
    raise ArgumentError, "custom jvp is not defined" unless custom_jvp?

    @jvp.call(primals, tangents)
  end

  # Invoke the attached vmap rule.
  # @raise [ArgumentError] when no rule was attached
  def call_custom_vmap(inputs, axes)
    raise ArgumentError, "custom vmap is not defined" unless custom_vmap?

    @vmap.call(inputs, axes)
  end
end
|
|
265
|
+
|
|
266
|
+
# Context-manager-style helper returned by MLX::Core.stream when no
# block is given: #enter switches to the target stream after saving the
# current defaults, #exit restores them. #exit before #enter is a no-op.
class StreamContext
  def initialize(target)
    @target = target
    @saved_device = nil
    @saved_stream = nil
  end

  # Save the current default device/stream, then activate the target.
  # NOTE(review): assumes native_stream(target) switches the default
  # stream as a side effect when called without a block — confirm
  # against the native binding.
  # @return [self]
  def enter
    @saved_device = MLX::Core.default_device
    @saved_stream = MLX::Core.default_stream(@saved_device)
    MLX::Core.native_stream(@target)
    self
  end

  # Restore the defaults captured by #enter; idempotent.
  # @return [self]
  def exit(*)
    return self if @saved_device.nil?

    MLX::Core.set_default_device(@saved_device)
    MLX::Core.set_default_stream(@saved_stream)
    @saved_device = nil
    @saved_stream = nil
    self
  end
end
|
|
290
|
+
|
|
291
|
+
# Python one-liner payload run via `python3 -c` to unpack a .npz (zip)
# archive of .npy entries into a scratch directory.
# argv: [1] source archive path, [2] destination directory.
# NOTE(review): npz handling shells out to an external python3 at
# runtime (see python_bin / run_python!) — a real deployment dependency.
PY_EXTRACT_NPZ = <<~PY.freeze
  import os, sys, zipfile
  src = sys.argv[1]
  out_dir = sys.argv[2]
  with zipfile.ZipFile(src, "r") as zf:
      zf.extractall(out_dir)
PY

# Python payload that zips a directory of .npy files back into a .npz
# archive. argv: [1] destination archive, [2] source directory,
# [3] "1" for deflate compression, anything else for stored (no
# compression). allowZip64 permits archives over 4 GB.
PY_BUILD_NPZ = <<~PY.freeze
  import os, sys, zipfile
  out_path = sys.argv[1]
  in_dir = sys.argv[2]
  compressed = sys.argv[3] == "1"
  mode = zipfile.ZIP_DEFLATED if compressed else zipfile.ZIP_STORED
  with zipfile.ZipFile(out_path, "w", compression=mode, allowZip64=True) as zf:
      for name in sorted(os.listdir(in_dir)):
          zf.write(os.path.join(in_dir, name), arcname=name)
PY
|
|
309
|
+
|
|
310
|
+
module_function

# Guard used by every wrapper that needs the compiled extension: no-op
# when the native bindings are loaded, raises otherwise.
# @raise [NativeUnavailableError] when the C extension is missing
def ensure_native!
  raise NativeUnavailableError, "MLX native extension is unavailable. Build ext/mlx first." unless MLX.native_available?
end

# @return [Boolean] whether the native extension is loaded
def available?
  MLX.native_available?
end
|
|
322
|
+
|
|
323
|
+
class << self
|
|
324
|
+
alias_method :native_load, :load if method_defined?(:load)
|
|
325
|
+
alias_method :native_grad, :grad if method_defined?(:grad) && !method_defined?(:native_grad)
|
|
326
|
+
alias_method :native_value_and_grad,
|
|
327
|
+
:value_and_grad if method_defined?(:value_and_grad) && !method_defined?(:native_value_and_grad)
|
|
328
|
+
alias_method :native_compile, :compile if method_defined?(:compile) && !method_defined?(:native_compile)
|
|
329
|
+
alias_method :native_checkpoint,
|
|
330
|
+
:checkpoint if method_defined?(:checkpoint) && !method_defined?(:native_checkpoint)
|
|
331
|
+
alias_method :native_stream, :stream if method_defined?(:stream) && !method_defined?(:native_stream)
|
|
332
|
+
alias_method :native_jvp, :jvp if method_defined?(:jvp) && !method_defined?(:native_jvp)
|
|
333
|
+
alias_method :native_vjp, :vjp if method_defined?(:vjp) && !method_defined?(:native_vjp)
|
|
334
|
+
alias_method :native_vmap, :vmap if method_defined?(:vmap) && !method_defined?(:native_vmap)
|
|
335
|
+
alias_method :native_export_to_dot,
|
|
336
|
+
:export_to_dot if method_defined?(:export_to_dot) && !method_defined?(:native_export_to_dot)
|
|
337
|
+
|
|
338
|
+
ARRAY_LEAF = :__mlx_array_leaf__
|
|
339
|
+
|
|
340
|
+
def load(file, format = nil, return_metadata = false)
|
|
341
|
+
ensure_native!
|
|
342
|
+
format_name = (format || infer_format(file)).to_s
|
|
343
|
+
if format_name == "npz"
|
|
344
|
+
raise ArgumentError, "metadata not supported for format npz" if return_metadata
|
|
345
|
+
|
|
346
|
+
return load_npz(file)
|
|
347
|
+
end
|
|
348
|
+
|
|
349
|
+
native_load(file, format, return_metadata)
|
|
350
|
+
end
|
|
351
|
+
|
|
352
|
+
def savez(file, *args, **kwargs)
|
|
353
|
+
ensure_native!
|
|
354
|
+
save_npz(file, args, kwargs, false)
|
|
355
|
+
end
|
|
356
|
+
|
|
357
|
+
def savez_compressed(file, *args, **kwargs)
|
|
358
|
+
ensure_native!
|
|
359
|
+
save_npz(file, args, kwargs, true)
|
|
360
|
+
end
|
|
361
|
+
|
|
362
|
+
def export_to_dot(target, *outputs)
|
|
363
|
+
ensure_native!
|
|
364
|
+
raise ArgumentError, "export_to_dot expects at least one output" if outputs.empty?
|
|
365
|
+
|
|
366
|
+
if target.respond_to?(:write)
|
|
367
|
+
Dir.mktmpdir do |dir|
|
|
368
|
+
path = File.join(dir, "graph.dot")
|
|
369
|
+
native_export_to_dot(path, *outputs)
|
|
370
|
+
content = File.binread(path)
|
|
371
|
+
target.write(content)
|
|
372
|
+
target.rewind if target.respond_to?(:rewind)
|
|
373
|
+
content
|
|
374
|
+
end
|
|
375
|
+
else
|
|
376
|
+
native_export_to_dot(target, *outputs)
|
|
377
|
+
end
|
|
378
|
+
end
|
|
379
|
+
|
|
380
|
+
def full_like(array, fill_value, dtype = nil)
|
|
381
|
+
ensure_native!
|
|
382
|
+
raise TypeError, "full_like expects an MLX::Core::Array" unless array.is_a?(MLX::Core::Array)
|
|
383
|
+
|
|
384
|
+
target_dtype = dtype || array.dtype
|
|
385
|
+
full(array.shape, fill_value, target_dtype)
|
|
386
|
+
end
|
|
387
|
+
|
|
388
|
+
def grad(fun, argnums = nil, argnames = nil)
|
|
389
|
+
ensure_native!
|
|
390
|
+
if fun.is_a?(CustomFunction) && fun.custom_vjp?
|
|
391
|
+
return build_custom_vjp_grad_function(fun)
|
|
392
|
+
end
|
|
393
|
+
|
|
394
|
+
argnums_v, argnames_v = normalize_diff_targets(argnums, argnames)
|
|
395
|
+
build_grad_like_function(fun, argnums_v, argnames_v, false)
|
|
396
|
+
end
|
|
397
|
+
|
|
398
|
+
def value_and_grad(fun, argnums = nil, argnames = nil)
|
|
399
|
+
ensure_native!
|
|
400
|
+
if fun.is_a?(CustomFunction) && fun.custom_vjp?
|
|
401
|
+
return build_custom_vjp_value_and_grad_function(fun)
|
|
402
|
+
end
|
|
403
|
+
|
|
404
|
+
argnums_v, argnames_v = normalize_diff_targets(argnums, argnames)
|
|
405
|
+
build_grad_like_function(fun, argnums_v, argnames_v, true)
|
|
406
|
+
end
|
|
407
|
+
|
|
408
|
+
def compile(fun, inputs = nil, outputs = nil, shapeless = false)
|
|
409
|
+
ensure_native!
|
|
410
|
+
cache = {}
|
|
411
|
+
|
|
412
|
+
lambda do |*args, **kwargs|
|
|
413
|
+
flat_inputs = []
|
|
414
|
+
input_spec = flatten_tree_spec([args, kwargs], flat_inputs, false)
|
|
415
|
+
key = structure_cache_key(input_spec)
|
|
416
|
+
|
|
417
|
+
entry = cache[key]
|
|
418
|
+
unless entry
|
|
419
|
+
output_spec = nil
|
|
420
|
+
lifted = lambda do |*flat_vars|
|
|
421
|
+
rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0)
|
|
422
|
+
unless cursor == flat_vars.length
|
|
423
|
+
raise RuntimeError, "internal input reconstruction mismatch"
|
|
424
|
+
end
|
|
425
|
+
|
|
426
|
+
call_args = rebuilt[0]
|
|
427
|
+
call_kwargs = rebuilt[1]
|
|
428
|
+
raw_output = fun.call(*call_args, **call_kwargs)
|
|
429
|
+
|
|
430
|
+
flat_output = []
|
|
431
|
+
output_spec = flatten_tree_spec(raw_output, flat_output, false)
|
|
432
|
+
flat_output
|
|
433
|
+
end
|
|
434
|
+
|
|
435
|
+
compiled = native_compile(lifted, inputs, outputs, shapeless)
|
|
436
|
+
entry = { fn: compiled, output_spec: -> { output_spec } }
|
|
437
|
+
cache[key] = entry
|
|
438
|
+
end
|
|
439
|
+
|
|
440
|
+
flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "compiled output")
|
|
441
|
+
spec = entry[:output_spec].call
|
|
442
|
+
raise RuntimeError, "missing output structure from compiled function" if spec.nil?
|
|
443
|
+
|
|
444
|
+
rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0)
|
|
445
|
+
unless cursor == flat_output.length
|
|
446
|
+
raise RuntimeError, "internal output reconstruction mismatch"
|
|
447
|
+
end
|
|
448
|
+
rebuilt
|
|
449
|
+
end
|
|
450
|
+
end
|
|
451
|
+
|
|
452
|
+
def checkpoint(fun)
|
|
453
|
+
ensure_native!
|
|
454
|
+
cache = {}
|
|
455
|
+
|
|
456
|
+
lambda do |*args, **kwargs|
|
|
457
|
+
flat_inputs = []
|
|
458
|
+
input_spec = flatten_tree_spec([args, kwargs], flat_inputs, false)
|
|
459
|
+
key = structure_cache_key(input_spec)
|
|
460
|
+
|
|
461
|
+
entry = cache[key]
|
|
462
|
+
unless entry
|
|
463
|
+
output_spec = nil
|
|
464
|
+
lifted = lambda do |*flat_vars|
|
|
465
|
+
rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0)
|
|
466
|
+
unless cursor == flat_vars.length
|
|
467
|
+
raise RuntimeError, "internal input reconstruction mismatch"
|
|
468
|
+
end
|
|
469
|
+
|
|
470
|
+
call_args = rebuilt[0]
|
|
471
|
+
call_kwargs = rebuilt[1]
|
|
472
|
+
raw_output = fun.call(*call_args, **call_kwargs)
|
|
473
|
+
|
|
474
|
+
flat_output = []
|
|
475
|
+
output_spec = flatten_tree_spec(raw_output, flat_output, false)
|
|
476
|
+
flat_output
|
|
477
|
+
end
|
|
478
|
+
|
|
479
|
+
checkpointed = native_checkpoint(lifted)
|
|
480
|
+
entry = { fn: checkpointed, output_spec: -> { output_spec } }
|
|
481
|
+
cache[key] = entry
|
|
482
|
+
end
|
|
483
|
+
|
|
484
|
+
flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "checkpoint output")
|
|
485
|
+
spec = entry[:output_spec].call
|
|
486
|
+
raise RuntimeError, "missing output structure from checkpoint function" if spec.nil?
|
|
487
|
+
|
|
488
|
+
rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0)
|
|
489
|
+
unless cursor == flat_output.length
|
|
490
|
+
raise RuntimeError, "internal output reconstruction mismatch"
|
|
491
|
+
end
|
|
492
|
+
rebuilt
|
|
493
|
+
end
|
|
494
|
+
end
|
|
495
|
+
|
|
496
|
+
def stream(stream_or_device, &block)
|
|
497
|
+
ensure_native!
|
|
498
|
+
if block_given?
|
|
499
|
+
native_stream(stream_or_device, &block)
|
|
500
|
+
else
|
|
501
|
+
StreamContext.new(stream_or_device)
|
|
502
|
+
end
|
|
503
|
+
end
|
|
504
|
+
|
|
505
|
+
def jvp(fun, primals, tangents)
|
|
506
|
+
ensure_native!
|
|
507
|
+
if fun.is_a?(CustomFunction) && fun.custom_jvp?
|
|
508
|
+
return custom_jvp(fun, primals, tangents)
|
|
509
|
+
end
|
|
510
|
+
native_jvp(fun, primals, tangents)
|
|
511
|
+
end
|
|
512
|
+
|
|
513
|
+
def vjp(fun, primals, cotangents)
|
|
514
|
+
ensure_native!
|
|
515
|
+
if fun.is_a?(CustomFunction) && fun.custom_vjp?
|
|
516
|
+
return custom_vjp(fun, primals, cotangents)
|
|
517
|
+
end
|
|
518
|
+
native_vjp(fun, primals, cotangents)
|
|
519
|
+
end
|
|
520
|
+
|
|
521
|
+
def vmap(fun, in_axes = nil, out_axes = nil)
|
|
522
|
+
ensure_native!
|
|
523
|
+
if fun.is_a?(CustomFunction) && fun.custom_vmap?
|
|
524
|
+
return custom_vmap_callable(fun, in_axes, out_axes)
|
|
525
|
+
end
|
|
526
|
+
native_vmap(fun, in_axes, out_axes)
|
|
527
|
+
end
|
|
528
|
+
|
|
529
|
+
def custom_function(fun = nil, &block)
|
|
530
|
+
callable = fun || block
|
|
531
|
+
raise ArgumentError, "custom_function requires a callable" if callable.nil?
|
|
532
|
+
|
|
533
|
+
CustomFunction.new(callable)
|
|
534
|
+
end
|
|
535
|
+
|
|
536
|
+
def finfo(dtype)
|
|
537
|
+
Finfo.new(dtype)
|
|
538
|
+
end
|
|
539
|
+
|
|
540
|
+
def iinfo(dtype)
|
|
541
|
+
Iinfo.new(dtype)
|
|
542
|
+
end
|
|
543
|
+
|
|
544
|
+
def from_dlpack(dlpack_value)
|
|
545
|
+
case dlpack_value
|
|
546
|
+
when MLX::Core::DLPackCapsule
|
|
547
|
+
dlpack_value.array
|
|
548
|
+
when MLX::Core::Array
|
|
549
|
+
dlpack_value
|
|
550
|
+
else
|
|
551
|
+
raise TypeError, "from_dlpack expects MLX::Core::DLPackCapsule or MLX::Core::Array"
|
|
552
|
+
end
|
|
553
|
+
end
|
|
554
|
+
|
|
555
|
+
private
|
|
556
|
+
|
|
557
|
+
def infer_format(file)
|
|
558
|
+
path = file_path(file)
|
|
559
|
+
ext = File.extname(path).delete_prefix(".")
|
|
560
|
+
raise ArgumentError, "could not infer load format from file extension" if ext.empty?
|
|
561
|
+
|
|
562
|
+
ext
|
|
563
|
+
end
|
|
564
|
+
|
|
565
|
+
def file_path(file)
|
|
566
|
+
if file.respond_to?(:to_path)
|
|
567
|
+
file.to_path.to_s
|
|
568
|
+
else
|
|
569
|
+
file.to_s
|
|
570
|
+
end
|
|
571
|
+
end
|
|
572
|
+
|
|
573
|
+
def python_bin
|
|
574
|
+
ENV.fetch("PYTHON", "python3")
|
|
575
|
+
end
|
|
576
|
+
|
|
577
|
+
def run_python!(*argv)
|
|
578
|
+
stdout, stderr, status = Open3.capture3(*argv)
|
|
579
|
+
return if status.success?
|
|
580
|
+
|
|
581
|
+
raise RuntimeError, <<~MSG
|
|
582
|
+
python command failed: #{argv.join(" ")}
|
|
583
|
+
stdout:
|
|
584
|
+
#{stdout}
|
|
585
|
+
stderr:
|
|
586
|
+
#{stderr}
|
|
587
|
+
MSG
|
|
588
|
+
end
|
|
589
|
+
|
|
590
|
+
def load_npz(file)
|
|
591
|
+
path = file_path(file)
|
|
592
|
+
Dir.mktmpdir("mlx-ruby-npz-load") do |dir|
|
|
593
|
+
run_python!(python_bin, "-c", PY_EXTRACT_NPZ, path, dir)
|
|
594
|
+
out = {}
|
|
595
|
+
Dir.glob(File.join(dir, "**", "*.npy")).sort.each do |npy_path|
|
|
596
|
+
rel = npy_path.delete_prefix(dir + File::SEPARATOR)
|
|
597
|
+
key = rel.end_with?(".npy") ? rel[0...-4] : rel
|
|
598
|
+
out[key] = native_load(npy_path, "npy", false)
|
|
599
|
+
end
|
|
600
|
+
out
|
|
601
|
+
end
|
|
602
|
+
end
|
|
603
|
+
|
|
604
|
+
def save_npz(file, args, kwargs, compressed)
|
|
605
|
+
path = file_path(file)
|
|
606
|
+
path = "#{path}.npz" unless path.end_with?(".npz")
|
|
607
|
+
|
|
608
|
+
arrays = kwargs.transform_keys(&:to_s)
|
|
609
|
+
args.each_with_index do |value, i|
|
|
610
|
+
key = "arr_#{i}"
|
|
611
|
+
if arrays.key?(key)
|
|
612
|
+
raise ArgumentError, "Cannot use un-named variables and keyword #{key}"
|
|
613
|
+
end
|
|
614
|
+
arrays[key] = value
|
|
615
|
+
end
|
|
616
|
+
|
|
617
|
+
Dir.mktmpdir("mlx-ruby-npz-save") do |dir|
|
|
618
|
+
arrays.each do |name, value|
|
|
619
|
+
array_value = value.is_a?(MLX::Core::Array) ? value : MLX::Core.array(value)
|
|
620
|
+
save(File.join(dir, "#{name}.npy"), array_value)
|
|
621
|
+
end
|
|
622
|
+
run_python!(python_bin, "-c", PY_BUILD_NPZ, path, dir, compressed ? "1" : "0")
|
|
623
|
+
end
|
|
624
|
+
|
|
625
|
+
nil
|
|
626
|
+
end
|
|
627
|
+
|
|
628
|
+
def normalize_diff_targets(argnums, argnames)
|
|
629
|
+
argnames_v = normalize_argnames(argnames)
|
|
630
|
+
argnums_v = normalize_argnums(argnums, argnames_v)
|
|
631
|
+
if argnums_v.empty? && argnames_v.empty?
|
|
632
|
+
raise ArgumentError, "Gradient wrt no argument requested"
|
|
633
|
+
end
|
|
634
|
+
[argnums_v, argnames_v]
|
|
635
|
+
end
|
|
636
|
+
|
|
637
|
+
def normalize_argnums(argnums, argnames)
|
|
638
|
+
if argnums.nil?
|
|
639
|
+
return argnames.empty? ? [0] : []
|
|
640
|
+
end
|
|
641
|
+
values = if argnums.is_a?(::Integer)
|
|
642
|
+
[argnums]
|
|
643
|
+
elsif argnums.is_a?(::Array)
|
|
644
|
+
argnums
|
|
645
|
+
else
|
|
646
|
+
raise TypeError, "argnums must be an Integer, an Array of Integer, or nil"
|
|
647
|
+
end
|
|
648
|
+
out = values.map do |value|
|
|
649
|
+
raise TypeError, "argnums entries must be Integer" unless value.is_a?(::Integer)
|
|
650
|
+
raise ArgumentError, "argnums cannot contain negative values" if value.negative?
|
|
651
|
+
value
|
|
652
|
+
end
|
|
653
|
+
raise ArgumentError, "duplicate argnums are not allowed" if out.uniq.length != out.length
|
|
654
|
+
|
|
655
|
+
out
|
|
656
|
+
end
|
|
657
|
+
|
|
658
|
+
def normalize_argnames(argnames)
|
|
659
|
+
return [] if argnames.nil?
|
|
660
|
+
values = if argnames.is_a?(::String) || argnames.is_a?(::Symbol)
|
|
661
|
+
[argnames]
|
|
662
|
+
elsif argnames.is_a?(::Array)
|
|
663
|
+
argnames
|
|
664
|
+
else
|
|
665
|
+
raise TypeError, "argnames must be a String, Symbol, Array, or nil"
|
|
666
|
+
end
|
|
667
|
+
out = values.map(&:to_s)
|
|
668
|
+
raise ArgumentError, "duplicate argnames are not allowed" if out.uniq.length != out.length
|
|
669
|
+
|
|
670
|
+
out
|
|
671
|
+
end
|
|
672
|
+
|
|
673
|
+
def build_grad_like_function(fun, argnums, argnames, with_value)
|
|
674
|
+
lambda do |*args, **kwargs|
|
|
675
|
+
selections, flat_inputs = build_target_selections(args, kwargs, argnums, argnames)
|
|
676
|
+
native_argnums = (0...flat_inputs.length).to_a
|
|
677
|
+
captured_value = nil
|
|
678
|
+
lifted = lambda do |*flat_vars|
|
|
679
|
+
call_args, call_kwargs = apply_flat_vars_to_targets(args, kwargs, selections, flat_vars)
|
|
680
|
+
raw_value = fun.call(*call_args, **call_kwargs)
|
|
681
|
+
captured_value = raw_value
|
|
682
|
+
extract_loss(raw_value)
|
|
683
|
+
end
|
|
684
|
+
|
|
685
|
+
if with_value
|
|
686
|
+
native_fn = native_value_and_grad(lifted, native_argnums)
|
|
687
|
+
_loss, raw_grads = native_fn.call(*flat_inputs)
|
|
688
|
+
value = captured_value.nil? ? fun.call(*args, **kwargs) : captured_value
|
|
689
|
+
[value, rebuild_grad_result(raw_grads, selections, argnames)]
|
|
690
|
+
else
|
|
691
|
+
native_fn = native_grad(lifted, native_argnums)
|
|
692
|
+
raw_grads = native_fn.call(*flat_inputs)
|
|
693
|
+
rebuild_grad_result(raw_grads, selections, argnames)
|
|
694
|
+
end
|
|
695
|
+
end
|
|
696
|
+
end
|
|
697
|
+
|
|
698
|
+
def build_custom_vjp_grad_function(fun)
|
|
699
|
+
lambda do |*args, **kwargs|
|
|
700
|
+
unless kwargs.empty?
|
|
701
|
+
raise ArgumentError, "custom-function grad currently supports positional arguments only"
|
|
702
|
+
end
|
|
703
|
+
outputs = normalize_array_output(fun.call(*args), "custom_function output")
|
|
704
|
+
cotangents = outputs.map { |out| MLX::Core.ones_like(out) }
|
|
705
|
+
output_arg = outputs.length == 1 ? outputs[0] : outputs
|
|
706
|
+
grads = normalize_array_output(
|
|
707
|
+
fun.call_custom_vjp(args, cotangents, output_arg),
|
|
708
|
+
"custom_function vjp output"
|
|
709
|
+
)
|
|
710
|
+
grads.length == 1 ? grads[0] : grads
|
|
711
|
+
end
|
|
712
|
+
end
|
|
713
|
+
|
|
714
|
+
def build_custom_vjp_value_and_grad_function(fun)
|
|
715
|
+
grad_fn = build_custom_vjp_grad_function(fun)
|
|
716
|
+
lambda do |*args, **kwargs|
|
|
717
|
+
value = fun.call(*args, **kwargs)
|
|
718
|
+
[value, grad_fn.call(*args, **kwargs)]
|
|
719
|
+
end
|
|
720
|
+
end
|
|
721
|
+
|
|
722
|
+
def custom_jvp(fun, primals, tangents)
|
|
723
|
+
primals_list = normalize_array_output(primals, "primals")
|
|
724
|
+
tangents_list = normalize_array_output(tangents, "tangents")
|
|
725
|
+
outputs = normalize_array_output(fun.call(*primals_list), "custom_function output")
|
|
726
|
+
jvps = normalize_array_output(
|
|
727
|
+
fun.call_custom_jvp(primals_list, tangents_list),
|
|
728
|
+
"custom_function jvp output"
|
|
729
|
+
)
|
|
730
|
+
[outputs, jvps]
|
|
731
|
+
end
|
|
732
|
+
|
|
733
|
+
def custom_vjp(fun, primals, cotangents)
|
|
734
|
+
primals_list = normalize_array_output(primals, "primals")
|
|
735
|
+
cotangents_list = normalize_array_output(cotangents, "cotangents")
|
|
736
|
+
outputs = normalize_array_output(fun.call(*primals_list), "custom_function output")
|
|
737
|
+
output_arg = outputs.length == 1 ? outputs[0] : outputs
|
|
738
|
+
vjps = normalize_array_output(
|
|
739
|
+
fun.call_custom_vjp(primals_list, cotangents_list, output_arg),
|
|
740
|
+
"custom_function vjp output"
|
|
741
|
+
)
|
|
742
|
+
[outputs, vjps]
|
|
743
|
+
end
|
|
744
|
+
|
|
745
|
+
def custom_vmap_callable(fun, in_axes, _out_axes)
|
|
746
|
+
lambda do |*args|
|
|
747
|
+
input_axes = if in_axes.nil?
|
|
748
|
+
::Array.new(args.length, 0)
|
|
749
|
+
elsif in_axes.is_a?(::Integer)
|
|
750
|
+
::Array.new(args.length, in_axes)
|
|
751
|
+
elsif in_axes.is_a?(::Array)
|
|
752
|
+
in_axes
|
|
753
|
+
else
|
|
754
|
+
raise TypeError, "in_axes must be Integer, Array, or nil"
|
|
755
|
+
end
|
|
756
|
+
out = fun.call_custom_vmap(args, input_axes)
|
|
757
|
+
if out.is_a?(::Array) && out.length == 2
|
|
758
|
+
out[0]
|
|
759
|
+
else
|
|
760
|
+
out
|
|
761
|
+
end
|
|
762
|
+
end
|
|
763
|
+
end
|
|
764
|
+
|
|
765
|
+
def extract_loss(output)
|
|
766
|
+
return output if output.is_a?(MLX::Core::Array)
|
|
767
|
+
|
|
768
|
+
if output.is_a?(::Array) && !output.empty? && output[0].is_a?(MLX::Core::Array)
|
|
769
|
+
return output[0]
|
|
770
|
+
end
|
|
771
|
+
|
|
772
|
+
raise ArgumentError,
|
|
773
|
+
"function must return an MLX::Core::Array or an Array whose first element is an MLX::Core::Array"
|
|
774
|
+
end
|
|
775
|
+
|
|
776
|
+
def build_target_selections(args, kwargs, argnums, argnames)
|
|
777
|
+
positional = []
|
|
778
|
+
keyword = []
|
|
779
|
+
flat_inputs = []
|
|
780
|
+
|
|
781
|
+
argnums.each do |index|
|
|
782
|
+
if index >= args.length
|
|
783
|
+
raise ArgumentError,
|
|
784
|
+
"Can't compute gradient for positional argument #{index} when #{args.length} positional arguments were provided"
|
|
785
|
+
end
|
|
786
|
+
spec = flatten_tree_spec(args[index], flat_inputs, true)
|
|
787
|
+
positional << { index: index, spec: spec }
|
|
788
|
+
end
|
|
789
|
+
|
|
790
|
+
argnames.each do |name|
|
|
791
|
+
key = kwarg_key_for_name(kwargs, name)
|
|
792
|
+
unless key
|
|
793
|
+
raise ArgumentError,
|
|
794
|
+
"Can't compute gradient for keyword argument '#{name}' because it was not provided"
|
|
795
|
+
end
|
|
796
|
+
spec = flatten_tree_spec(kwargs[key], flat_inputs, true)
|
|
797
|
+
keyword << { key: key, name: name, spec: spec }
|
|
798
|
+
end
|
|
799
|
+
|
|
800
|
+
[{ positional: positional, keyword: keyword }, flat_inputs]
|
|
801
|
+
end
|
|
802
|
+
|
|
803
|
+
def flatten_tree_spec(value, arrays, strict_arrays)
|
|
804
|
+
if value.is_a?(MLX::Core::Array)
|
|
805
|
+
arrays << value
|
|
806
|
+
return ARRAY_LEAF
|
|
807
|
+
end
|
|
808
|
+
if value.is_a?(::Array)
|
|
809
|
+
return [:array, value.map { |item| flatten_tree_spec(item, arrays, strict_arrays) }]
|
|
810
|
+
end
|
|
811
|
+
if value.is_a?(::Hash)
|
|
812
|
+
return [:hash, value.map { |k, v| [k, flatten_tree_spec(v, arrays, strict_arrays)] }]
|
|
813
|
+
end
|
|
814
|
+
if strict_arrays
|
|
815
|
+
raise TypeError, "[tree_flatten] The argument should contain only arrays"
|
|
816
|
+
end
|
|
817
|
+
if value.nil? || value.is_a?(::Numeric) || value.is_a?(::String) ||
|
|
818
|
+
value.is_a?(::Symbol) || value == true || value == false
|
|
819
|
+
return [:const, value]
|
|
820
|
+
end
|
|
821
|
+
raise TypeError,
|
|
822
|
+
"[compile] Function arguments and outputs must be trees of arrays or constants (Numeric, String, Symbol, true/false, nil)"
|
|
823
|
+
end
|
|
824
|
+
|
|
825
|
+
def structure_cache_key(spec)
|
|
826
|
+
return "A" if spec == ARRAY_LEAF
|
|
827
|
+
|
|
828
|
+
tag, payload = spec
|
|
829
|
+
case tag
|
|
830
|
+
when :array
|
|
831
|
+
"L[#{payload.map { |entry| structure_cache_key(entry) }.join(",")}]"
|
|
832
|
+
when :hash
|
|
833
|
+
pairs = payload.map do |key, child|
|
|
834
|
+
"#{key.inspect}:#{structure_cache_key(child)}"
|
|
835
|
+
end
|
|
836
|
+
"H{#{pairs.join(",")}}"
|
|
837
|
+
when :const
|
|
838
|
+
"C(#{payload.class}:#{payload.inspect})"
|
|
839
|
+
else
|
|
840
|
+
raise ArgumentError, "invalid tree specification"
|
|
841
|
+
end
|
|
842
|
+
end
|
|
843
|
+
|
|
844
|
+
def inflate_tree_from_arrays(spec, arrays, cursor)
|
|
845
|
+
return [arrays.fetch(cursor), cursor + 1] if spec == ARRAY_LEAF
|
|
846
|
+
|
|
847
|
+
tag, payload = spec
|
|
848
|
+
case tag
|
|
849
|
+
when :array
|
|
850
|
+
out = []
|
|
851
|
+
payload.each do |child_spec|
|
|
852
|
+
item, cursor = inflate_tree_from_arrays(child_spec, arrays, cursor)
|
|
853
|
+
out << item
|
|
854
|
+
end
|
|
855
|
+
[out, cursor]
|
|
856
|
+
when :hash
|
|
857
|
+
out = {}
|
|
858
|
+
payload.each do |key, child_spec|
|
|
859
|
+
item, cursor = inflate_tree_from_arrays(child_spec, arrays, cursor)
|
|
860
|
+
out[key] = item
|
|
861
|
+
end
|
|
862
|
+
[out, cursor]
|
|
863
|
+
when :const
|
|
864
|
+
[payload, cursor]
|
|
865
|
+
else
|
|
866
|
+
raise ArgumentError, "invalid tree specification"
|
|
867
|
+
end
|
|
868
|
+
end
|
|
869
|
+
|
|
870
|
+
def normalize_raw_grads(raw)
|
|
871
|
+
normalize_array_sequence(raw, "gradient")
|
|
872
|
+
end
|
|
873
|
+
|
|
874
|
+
def normalize_array_sequence(raw, context)
|
|
875
|
+
return [raw] if raw.is_a?(MLX::Core::Array)
|
|
876
|
+
|
|
877
|
+
if raw.is_a?(::Array) && raw.all? { |item| item.is_a?(MLX::Core::Array) }
|
|
878
|
+
return raw
|
|
879
|
+
end
|
|
880
|
+
raise TypeError, "unexpected #{context} return type"
|
|
881
|
+
end
|
|
882
|
+
|
|
883
|
+
def normalize_array_output(raw, context)
|
|
884
|
+
if raw.is_a?(MLX::Core::Array)
|
|
885
|
+
[raw]
|
|
886
|
+
elsif raw.is_a?(::Array) && raw.all? { |item| item.is_a?(MLX::Core::Array) }
|
|
887
|
+
raw
|
|
888
|
+
else
|
|
889
|
+
raise TypeError, "unexpected #{context} type"
|
|
890
|
+
end
|
|
891
|
+
end
|
|
892
|
+
|
|
893
|
+
def rebuild_grad_result(raw_grads, selections, argnames)
|
|
894
|
+
grad_arrays = normalize_raw_grads(raw_grads)
|
|
895
|
+
cursor = 0
|
|
896
|
+
|
|
897
|
+
positional_grads = selections[:positional].map do |entry|
|
|
898
|
+
value, cursor = inflate_tree_from_arrays(entry[:spec], grad_arrays, cursor)
|
|
899
|
+
value
|
|
900
|
+
end
|
|
901
|
+
keyword_grads = {}
|
|
902
|
+
selections[:keyword].each do |entry|
|
|
903
|
+
value, cursor = inflate_tree_from_arrays(entry[:spec], grad_arrays, cursor)
|
|
904
|
+
keyword_grads[entry[:name]] = value
|
|
905
|
+
end
|
|
906
|
+
unless cursor == grad_arrays.length
|
|
907
|
+
raise RuntimeError, "internal gradient reconstruction mismatch"
|
|
908
|
+
end
|
|
909
|
+
|
|
910
|
+
if argnames.empty?
|
|
911
|
+
return positional_grads[0] if positional_grads.length == 1
|
|
912
|
+
return positional_grads
|
|
913
|
+
end
|
|
914
|
+
|
|
915
|
+
positional_out = if positional_grads.empty?
|
|
916
|
+
nil
|
|
917
|
+
elsif positional_grads.length == 1
|
|
918
|
+
positional_grads[0]
|
|
919
|
+
else
|
|
920
|
+
positional_grads
|
|
921
|
+
end
|
|
922
|
+
[positional_out, keyword_grads]
|
|
923
|
+
end
|
|
924
|
+
|
|
925
|
+
# Writes a flat list of arrays back into (shallow) copies of the original
# call arguments, replacing each selected positional/keyword target with the
# tree rebuilt from +flat_vars+.
#
# selections - hash with :positional entries (:spec, :index) and :keyword
#              entries (:spec, :key) describing where each rebuilt tree goes.
#
# Returns [rebuilt_args, rebuilt_kwargs]. Raises RuntimeError if flat_vars
# is not consumed exactly.
def apply_flat_vars_to_targets(args, kwargs, selections, flat_vars)
  rebuilt_args = args.dup
  rebuilt_kwargs = kwargs.dup
  cursor = 0

  # The multiple assignment inside each block rebinds the outer `cursor`
  # local, so consumption of flat_vars continues across entries in order.
  selections[:positional].each do |entry|
    value, cursor = inflate_tree_from_arrays(entry[:spec], flat_vars, cursor)
    rebuilt_args[entry[:index]] = value
  end

  selections[:keyword].each do |entry|
    value, cursor = inflate_tree_from_arrays(entry[:spec], flat_vars, cursor)
    rebuilt_kwargs[entry[:key]] = value
  end

  # Leftover (or over-consumed) values indicate an internal bookkeeping bug.
  unless cursor == flat_vars.length
    raise RuntimeError, "internal target reconstruction mismatch"
  end
  [rebuilt_args, rebuilt_kwargs]
end
|
|
945
|
+
|
|
946
|
+
# Resolves which concrete key in +kwargs+ corresponds to +name+, preferring
# the Symbol form over the String form; returns nil when neither is present.
def kwarg_key_for_name(kwargs, name)
  sym_key = name.to_sym
  if kwargs.key?(sym_key)
    sym_key
  elsif kwargs.key?(name)
    name
  end
end
|
|
953
|
+
end
|
|
954
|
+
|
|
955
|
+
# Extends the native Device with symbol/string comparison so callers can
# write `device == :cpu` or `device == "gpu"`.
class Device
  # Keep a handle on the native equality operator (once) so non-shorthand
  # comparisons can be delegated to it below.
  alias_method :native_equal, :== if method_defined?(:==) && !method_defined?(:native_equal)

  # Compare against :cpu / "gpu"-style shorthands by device type; defer to
  # the native comparison for every other operand.
  def ==(other)
    case other
    when ::Symbol, ::String
      type == other.to_sym
    else
      native_equal(other)
    end
  end

  alias eql? ==
end
|
|
968
|
+
|
|
969
|
+
# Ruby-side convenience layer reopening the native MLX array class: NumPy-style
# reductions/shape helpers delegating to MLX::Core module functions, plus
# Python-dunder-named aliases kept for parity with the Python bindings.
class Array
  # Machine epsilon per dtype name; consulted by #eps, which falls back to
  # Float::EPSILON for any dtype not listed here.
  EPSILON_BY_DTYPE = {
    "float16" => 9.765625e-4,
    "bfloat16" => 7.8125e-3,
    "float32" => 1.1920929e-7,
    "float64" => Float::EPSILON,
    "complex64" => 1.1920929e-7
  }.freeze

  # NumPy-style `.T` transpose shorthand.
  def T
    transpose
  end

  # Wraps self in an ArrayAt helper (defined elsewhere in this file);
  # presumably the `.at[...]` functional-update API — confirm against ArrayAt.
  def at
    ArrayAt.new(self)
  end

  def real
    MLX::Core.real(self)
  end

  def imag
    MLX::Core.imag(self)
  end

  # Bytes per element, taken from the dtype.
  def itemsize
    dtype.size
  end

  # Total bytes: element count times element size.
  def nbytes
    size * itemsize
  end

  # --- element-wise arithmetic (delegating to MLX::Core) -------------------

  def add(other)
    MLX::Core.add(self, other)
  end

  def subtract(other)
    MLX::Core.subtract(self, other)
  end

  def multiply(other)
    MLX::Core.multiply(self, other)
  end

  def divide(other)
    MLX::Core.divide(self, other)
  end

  def exp
    MLX::Core.exp(self)
  end

  def sin
    MLX::Core.sin(self)
  end

  def cos
    MLX::Core.cos(self)
  end

  # --- reductions ----------------------------------------------------------

  def mean(axis = nil)
    MLX::Core.mean(self, axis)
  end

  def sum(axis = nil)
    MLX::Core.sum(self, axis)
  end

  def var(axis = nil, keepdims = nil, ddof = nil)
    MLX::Core.var(self, axis, keepdims, ddof)
  end

  def std(axis = nil, keepdims = nil, ddof = nil)
    MLX::Core.std(self, axis, keepdims, ddof)
  end

  def max(axis = nil, keepdims = nil)
    MLX::Core.max(self, axis, keepdims)
  end

  def min(axis = nil, keepdims = nil)
    MLX::Core.min(self, axis, keepdims)
  end

  # --- shape manipulation --------------------------------------------------

  # Accepts either a single shape array (`reshape([2, 3])`) or splatted
  # dimensions (`reshape(2, 3)`).
  def reshape(*shape)
    target = shape.length == 1 ? shape[0] : shape
    MLX::Core.reshape(self, target)
  end

  def transpose(axes = nil)
    MLX::Core.transpose(self, axes)
  end

  def squeeze(axis = nil)
    MLX::Core.squeeze(self, axis)
  end

  def square
    MLX::Core.square(self)
  end

  def sqrt
    MLX::Core.sqrt(self)
  end

  def rsqrt
    MLX::Core.rsqrt(self)
  end

  def reciprocal
    MLX::Core.reciprocal(self)
  end

  def abs
    MLX::Core.abs(self)
  end

  # --- boolean / index reductions -----------------------------------------

  def all(axis = nil, keepdims = nil)
    MLX::Core.all(self, axis, keepdims)
  end

  def any(axis = nil, keepdims = nil)
    MLX::Core.any(self, axis, keepdims)
  end

  def argmax(axis = nil, keepdims = nil)
    MLX::Core.argmax(self, axis, keepdims)
  end

  def argmin(axis = nil, keepdims = nil)
    MLX::Core.argmin(self, axis, keepdims)
  end

  # Cast to +dtype+, optionally on a specific +stream+ (the stream argument
  # is only forwarded when given).
  def astype(dtype, stream = nil)
    if stream.nil?
      MLX::Core.astype(self, dtype)
    else
      MLX::Core.astype(self, dtype, stream)
    end
  end

  def conj
    MLX::Core.conj(self)
  end

  # --- cumulative / diagonal ops (extra args forwarded unchanged) ----------

  def cummax(*args)
    MLX::Core.cummax(self, *args)
  end

  def cummin(*args)
    MLX::Core.cummin(self, *args)
  end

  def cumprod(*args)
    MLX::Core.cumprod(self, *args)
  end

  def cumsum(*args)
    MLX::Core.cumsum(self, *args)
  end

  def diag(*args)
    MLX::Core.diag(self, *args)
  end

  def diagonal(*args)
    MLX::Core.diagonal(self, *args)
  end

  def flatten(start_axis = 0, end_axis = -1)
    MLX::Core.flatten(self, start_axis, end_axis)
  end

  # --- logarithms ----------------------------------------------------------

  def log
    MLX::Core.log(self)
  end

  def log10
    MLX::Core.log10(self)
  end

  def log1p
    MLX::Core.log1p(self)
  end

  def log2
    MLX::Core.log2(self)
  end

  def logcumsumexp(*args)
    MLX::Core.logcumsumexp(self, *args)
  end

  def logsumexp(*args)
    MLX::Core.logsumexp(self, *args)
  end

  def maximum(other)
    MLX::Core.maximum(self, other)
  end

  def minimum(other)
    MLX::Core.minimum(self, other)
  end

  def moveaxis(source, destination)
    MLX::Core.moveaxis(self, source, destination)
  end

  def prod(axis = nil, keepdims = nil)
    MLX::Core.prod(self, axis, keepdims)
  end

  def round(decimals = 0)
    MLX::Core.round(self, decimals)
  end

  def split(indices_or_sections, axis = 0)
    MLX::Core.split(self, indices_or_sections, axis)
  end

  def swapaxes(axis1, axis2)
    MLX::Core.swapaxes(self, axis1, axis2)
  end

  # Reinterpret the buffer as +dtype+ (no element conversion).
  def view(dtype)
    MLX::Core.view(self, dtype)
  end

  # Machine epsilon for this array's dtype; unknown dtypes fall back to
  # Ruby's Float::EPSILON.
  def eps
    dtype_name = if dtype.respond_to?(:name)
      dtype.name.to_s
    else
      dtype.to_s
    end
    EPSILON_BY_DTYPE.fetch(dtype_name, Float::EPSILON)
  end

  # NumPy-compatible alias for #to_a.
  def tolist
    to_a
  end

  # --- Python-dunder aliases (API parity with the Python bindings) ---------

  def __add__(other)
    add(other)
  end

  def __sub__(other)
    subtract(other)
  end

  def __mul__(other)
    multiply(other)
  end

  def __truediv__(other)
    divide(other)
  end

  def __div__(other)
    __truediv__(other)
  end

  def __matmul__(other)
    MLX::Core.matmul(self, other)
  end

  # NOTE: not actually in-place — returns a new array like __matmul__.
  def __imatmul__(other)
    __matmul__(other)
  end

  # Length of the leading axis; 0 for a 0-d array (empty shape).
  def __len__
    shape.first || 0
  end

  # Fresh iterator over the leading axis.
  def __iter__
    ArrayIterator.new(self)
  end

  # Stateful: memoizes one iterator on this instance and advances it on
  # each call.
  def __next__
    @__mlx_array_iterator ||= __iter__
    @__mlx_array_iterator.__next__
  end

  # No-op constructor shim; the native allocation happens elsewhere.
  def __init__(*_)
    self
  end

  def __repr__
    inspect
  end

  # Truthiness of a size-1 array; ambiguous (and an error) otherwise,
  # matching NumPy semantics.
  def __bool__
    raise ArgumentError, "The truth value of an array with more than one element is ambiguous" if size != 1

    !!item
  end

  # Scalar conversion; only valid for size-1 arrays.
  def __int__
    raise ArgumentError, "only size-1 arrays can be converted to Integer" if size != 1

    Integer(item)
  end

  # Scalar conversion; only valid for size-1 arrays.
  def __float__
    raise ArgumentError, "only size-1 arrays can be converted to Float" if size != 1

    Float(item)
  end

  # Identity-based hash (not content-based).
  def __hash__
    object_id.hash
  end

  # Array-API namespace hook: operations live on MLX::Core.
  def __array_namespace__
    MLX::Core
  end

  # Element-wise comparisons (return arrays, unlike Ruby's ==).

  def __eq__(other)
    MLX::Core.equal(self, other)
  end

  def __ne__(other)
    MLX::Core.not_equal(self, other)
  end

  def __abs__
    MLX::Core.abs(self)
  end

  def __neg__
    MLX::Core.negative(self)
  end

  def __pow__(other)
    MLX::Core.power(self, other)
  end

  # Reflected variants swap operand order.

  def __rpow__(other)
    MLX::Core.power(other, self)
  end

  def __floordiv__(other)
    MLX::Core.floor_divide(self, other)
  end

  def __mod__(other)
    MLX::Core.remainder(self, other)
  end

  def __rmod__(other)
    MLX::Core.remainder(other, self)
  end

  def __radd__(other)
    MLX::Core.add(other, self)
  end

  def __rsub__(other)
    MLX::Core.subtract(other, self)
  end

  def __rmul__(other)
    MLX::Core.multiply(other, self)
  end

  def __rtruediv__(other)
    MLX::Core.divide(other, self)
  end

  def __rdiv__(other)
    __rtruediv__(other)
  end

  # Bitwise operators.

  def __and__(other)
    MLX::Core.bitwise_and(self, other)
  end

  def __or__(other)
    MLX::Core.bitwise_or(self, other)
  end

  def __xor__(other)
    MLX::Core.bitwise_xor(self, other)
  end

  def __invert__
    MLX::Core.bitwise_invert(self)
  end

  def __lshift__(other)
    MLX::Core.left_shift(self, other)
  end

  def __rshift__(other)
    MLX::Core.right_shift(self, other)
  end

  # Element-wise ordering comparisons.

  def __lt__(other)
    MLX::Core.less(self, other)
  end

  def __le__(other)
    MLX::Core.less_equal(self, other)
  end

  def __gt__(other)
    MLX::Core.greater(self, other)
  end

  def __ge__(other)
    MLX::Core.greater_equal(self, other)
  end

  # "In-place" dunders: arrays are immutable here, so these all return new
  # arrays by delegating to the non-in-place forms.

  def __iadd__(other)
    __add__(other)
  end

  def __isub__(other)
    __sub__(other)
  end

  def __imul__(other)
    __mul__(other)
  end

  def __itruediv__(other)
    __truediv__(other)
  end

  def __ifloordiv__(other)
    __floordiv__(other)
  end

  def __imod__(other)
    __mod__(other)
  end

  def __ipow__(other)
    __pow__(other)
  end

  def __iand__(other)
    __and__(other)
  end

  def __ior__(other)
    __or__(other)
  end

  def __ixor__(other)
    __xor__(other)
  end

  def __ilshift__(other)
    __lshift__(other)
  end

  def __irshift__(other)
    __rshift__(other)
  end

  def __rfloordiv__(other)
    MLX::Core.floor_divide(other, self)
  end

  def __getitem__(index)
    self[index]
  end

  # Functional item assignment: returns a NEW array with the update applied
  # (self is not mutated). Tries a device-side fast path for 1-D arrays,
  # otherwise round-trips through Ruby arrays.
  def __setitem__(index, value)
    fast_path = __setitem_1d_device_fast_path(index, value)
    return fast_path unless fast_path.nil?

    copy = __ruby_deep_copy(to_a)
    replacement = value.is_a?(MLX::Core::Array) ? value.to_a : value
    __apply_setitem!(copy, index, replacement)
    MLX::Core.array(copy, dtype)
  end

  # Copy by round-tripping through Ruby arrays, preserving dtype.
  def __copy__
    MLX::Core.array(to_a, dtype)
  end

  def __deepcopy__(_memo = nil)
    __copy__
  end

  # Pickle-style serialization: values plus dtype name.
  def __getstate__
    dtype_name = if dtype.respond_to?(:name)
      dtype.name.to_s
    else
      dtype.to_s
    end
    {
      "values" => to_a,
      "dtype" => dtype_name
    }
  end

  # Rebuilds an array from a #__getstate__ hash (string or symbol keys);
  # NOTE(review): returns a new array rather than restoring into self.
  def __setstate__(state)
    values = state["values"] || state[:values]
    dtype_name = state["dtype"] || state[:dtype]
    if !dtype_name.nil? && MLX::Core.respond_to?(dtype_name.to_sym)
      MLX::Core.array(values, MLX::Core.public_send(dtype_name.to_sym))
    else
      MLX::Core.array(values)
    end
  end

  # Formats a size-1 array through Kernel#format (so format_spec must be a
  # Ruby format string, e.g. "%.2f"); everything else stringifies to_a.
  def __format__(format_spec = "")
    if size == 1 && !format_spec.to_s.empty?
      kernel = Kernel.format(format_spec, item)
      return kernel
    end
    to_a.to_s
  end

  # DLPack export; +stream+ must be nil or an Integer per the protocol.
  def __dlpack__(stream = nil)
    unless stream.nil? || stream.is_a?(::Integer)
      raise ArgumentError, "__dlpack__ stream must be nil or Integer"
    end

    MLX::Core::DLPackCapsule.new(self, device: __dlpack_device, stream: stream)
  end

  # [device_type, device_id] pair for DLPack. Presumably the codes map to
  # DLDeviceType: 1 = kDLCPU, 8 = kDLMetal, 13 for the non-Metal GPU case —
  # TODO confirm against dlpack.h. Unknown device types pass through as-is.
  def __dlpack_device
    device = MLX::Core.default_device
    type_id = case device.type
    when :cpu
      1
    when :gpu
      MLX::Core.metal_is_available ? 8 : 13
    else
      device.type
    end
    [type_id, device.index]
  end

  alias __dlpack_device__ __dlpack_device

  private

  # Attempts an on-device update for 1-D arrays (integer index, integer
  # index list, integer-dtype index array, or boolean mask) via
  # put_along_axis / where. Returns nil when the fast path does not apply,
  # and — via the method-level rescue — also on ANY StandardError, which
  # silently falls back to the Ruby-side slow path.
  def __setitem_1d_device_fast_path(index, replacement)
    return nil unless ndim == 1

    if index.is_a?(::Integer)
      normalized = __normalize_1d_index(index)
      index_array = MLX::Core.array([normalized], MLX::Core.int32)
      values_array = __coerce_setitem_values_1d(replacement, 1)
      return MLX::Core.put_along_axis(self, index_array, values_array, 0)
    end

    if index.is_a?(MLX::Core::Array) && index.ndim == 1
      case __dtype_name(index.dtype)
      when "bool_"
        # Boolean mask must cover the whole axis; otherwise fall back.
        return nil unless index.shape[0] == shape[0]

        replacement_array = __coerce_setitem_mask_values_1d(replacement, shape[0])
        return MLX::Core.where(index, replacement_array, self)
      when "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"
        index_array = index.astype(MLX::Core.int32)
        values_array = __coerce_setitem_values_1d(replacement, index_array.size)
        return MLX::Core.put_along_axis(self, index_array, values_array, 0)
      end
    end

    return nil unless index.is_a?(::Array) && index.all? { |entry| entry.is_a?(::Integer) }

    normalized = index.map { |entry| __normalize_1d_index(entry) }
    index_array = MLX::Core.array(normalized, MLX::Core.int32)
    values_array = __coerce_setitem_values_1d(replacement, normalized.length)
    MLX::Core.put_along_axis(self, index_array, values_array, 0)
  rescue StandardError
    nil
  end

  # Converts a possibly-negative 1-D index to its non-negative form and
  # bounds-checks it against the leading dimension.
  def __normalize_1d_index(index)
    size_1d = shape[0]
    normalized = index
    normalized += size_1d if normalized.negative?
    if normalized.negative? || normalized >= size_1d
      raise IndexError, "index out of range"
    end

    normalized
  end

  # Coerces replacement values for an integer-index update into a [count]
  # array of this array's dtype; a single value is broadcast to count.
  def __coerce_setitem_values_1d(values, count)
    case values
    when MLX::Core::Array
      return MLX::Core.full([count], values, dtype) if values.size == 1 && count > 1
      raise ArgumentError, "__setitem__ replacement values must match index list length" if values.size != count

      MLX::Core.reshape(values.astype(dtype), [count])
    when ::Array
      value_array = MLX::Core.array(values, dtype)
      return MLX::Core.full([count], value_array, dtype) if value_array.size == 1 && count > 1
      raise ArgumentError, "__setitem__ replacement values must match index list length" if value_array.size != count

      MLX::Core.reshape(value_array, [count])
    else
      MLX::Core.full([count], values, dtype)
    end
  end

  # Like __coerce_setitem_values_1d, but for boolean-mask updates: a
  # single value is always broadcast to the full mask length (no count > 1
  # restriction).
  def __coerce_setitem_mask_values_1d(values, count)
    case values
    when MLX::Core::Array
      return MLX::Core.full([count], values, dtype) if values.size == 1
      raise ArgumentError, "__setitem__ replacement values must match mask length" if values.size != count

      MLX::Core.reshape(values.astype(dtype), [count])
    when ::Array
      value_array = MLX::Core.array(values, dtype)
      return MLX::Core.full([count], value_array, dtype) if value_array.size == 1
      raise ArgumentError, "__setitem__ replacement values must match mask length" if value_array.size != count

      MLX::Core.reshape(value_array, [count])
    else
      MLX::Core.full([count], values, dtype)
    end
  end

  # Dtype name as a String, whether the dtype object exposes #name or only
  # #to_s.
  def __dtype_name(dtype_obj)
    if dtype_obj.respond_to?(:name)
      dtype_obj.name.to_s
    else
      dtype_obj.to_s
    end
  end

  # Slow-path dispatcher: mutates the Ruby-array +data+ in place according
  # to the index kind (Integer, integer list, or boolean mask).
  def __apply_setitem!(data, index, replacement)
    if index.is_a?(::Integer)
      data[index] = replacement
      return
    end

    normalized = if index.is_a?(MLX::Core::Array)
      index.to_a
    elsif index.is_a?(::Array)
      index
    else
      raise ArgumentError, "__setitem__ supports Integer, Integer list, or boolean mask indices"
    end

    unless data.is_a?(::Array)
      raise ArgumentError, "__setitem__ list/mask indices require array values"
    end

    # NOTE(review): an empty index list satisfies this all? check and is
    # routed through the mask path — harmless, as both paths no-op.
    if normalized.all? { |v| v == true || v == false }
      __apply_boolean_mask_setitem!(data, normalized, replacement)
      return
    end

    unless normalized.all? { |v| v.is_a?(::Integer) }
      raise ArgumentError, "__setitem__ list indices must be all Integers or all booleans"
    end

    __apply_integer_list_setitem!(data, normalized, replacement)
  end

  # Boolean-mask in-place update: a list replacement is consumed one value
  # per true mask position; a scalar replacement fills every true position.
  def __apply_boolean_mask_setitem!(data, mask, replacement)
    if mask.length != data.length
      raise ArgumentError, "__setitem__ boolean mask must match array length"
    end

    replacement_values = replacement.is_a?(::Array) ? replacement.flatten : nil
    replacement_index = 0

    mask.each_with_index do |flag, i|
      next unless flag

      if replacement_values
        if replacement_index >= replacement_values.length
          raise ArgumentError, "__setitem__ replacement values shorter than mask true count"
        end
        data[i] = replacement_values[replacement_index]
        replacement_index += 1
      else
        data[i] = replacement
      end
    end
  end

  # Integer-list in-place update: a length-1 list or a scalar is broadcast
  # to every index; otherwise lengths must match element-for-element.
  def __apply_integer_list_setitem!(data, indices, replacement)
    if replacement.is_a?(::Array)
      values = replacement.flatten
      if values.length == 1
        indices.each { |i| data[i] = values[0] }
        return
      end
      if values.length != indices.length
        raise ArgumentError, "__setitem__ replacement values must match index list length"
      end

      indices.each_with_index { |i, offset| data[i] = values[offset] }
    else
      indices.each { |i| data[i] = replacement }
    end
  end

  # Recursively copies nested Ruby arrays (leaves are shared, which is fine
  # for the numeric scalars produced by #to_a).
  def __ruby_deep_copy(value)
    return value.map { |item| __ruby_deep_copy(item) } if value.is_a?(::Array)

    value
  end
end
|
|
1677
|
+
end
|
|
1678
|
+
end
|