mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <algorithm>
|
|
6
|
+
#include <cmath>
|
|
7
|
+
#include <cstdint>
|
|
8
|
+
#include <vector>
|
|
9
|
+
|
|
10
|
+
// Half-precision bit pattern produced for NaN inputs by the float -> half
// conversion: exponent field all ones (0x7C00) with a nonzero mantissa
// (0x0100), i.e. a NaN encoding. The sign bit is left clear so the converter
// can OR in the sign of the input.
// NOTE(review): the mantissa MSB (0x0200), conventionally the "quiet" bit, is
// not set in this pattern; 0x7E00 is the usual canonical quiet NaN — confirm
// this choice is intentional.
#define __MLX_HALF_NAN__ 0x7D00
|
|
11
|
+
|
|
12
|
+
namespace mlx::core {
|
|
13
|
+
|
|
14
|
+
namespace {

// Reinterpretation helper: the same 32 bits viewed either as an IEEE 754
// single-precision value (`f`) or as its raw bit pattern (`u`).
// NOTE(review): reading the member that was not most recently written is
// undefined behavior in ISO C++, although mainstream compilers (e.g. GCC)
// support it as an extension; std::memcpy (or C++20 std::bit_cast) is the
// portable alternative.
union float_bits_fp16 {
  float f;    // floating-point view
  uint32_t u; // bit-pattern view
};

} // namespace
|
|
20
|
+
|
|
21
|
+
struct _MLX_Float16 {
|
|
22
|
+
// Raw half-precision bit pattern: sign in bit 15, 5-bit exponent field in
// bits 10-14, 10-bit mantissa in bits 0-9 (as assembled by the float
// constructor below).
uint16_t bits_;
|
|
23
|
+
|
|
24
|
+
// Default constructor
|
|
25
|
+
_MLX_Float16() = default;
|
|
26
|
+
|
|
27
|
+
// Default copy constructor
|
|
28
|
+
_MLX_Float16(_MLX_Float16 const&) = default;
|
|
29
|
+
|
|
30
|
+
// Appease std::vector<bool> for being special
|
|
31
|
+
_MLX_Float16& operator=(std::vector<bool>::reference x) {
|
|
32
|
+
bits_ = x;
|
|
33
|
+
return *this;
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
_MLX_Float16& operator=(const float& x) {
  // Assign from a float: convert via the float constructor, then copy the
  // resulting half-precision bits into this object.
  const _MLX_Float16 converted(x);
  *this = converted;
  return *this;
}
|
|
39
|
+
|
|
40
|
+
// From float32
|
|
41
|
+
_MLX_Float16(const float& x) : bits_(0) {
|
|
42
|
+
// Conversion following
|
|
43
|
+
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
|
|
44
|
+
|
|
45
|
+
// Union
|
|
46
|
+
float_bits_fp16 in;
|
|
47
|
+
|
|
48
|
+
// Take fp32 bits
|
|
49
|
+
in.f = x;
|
|
50
|
+
|
|
51
|
+
// Find and take sign bit
|
|
52
|
+
uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
|
|
53
|
+
uint16_t x_sign_16 = (x_sign_32 >> 16);
|
|
54
|
+
|
|
55
|
+
if (std::isnan(x)) {
|
|
56
|
+
bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
|
|
57
|
+
} else {
|
|
58
|
+
// Union
|
|
59
|
+
float_bits_fp16 inf_scale, zero_scale, magic_bits;
|
|
60
|
+
|
|
61
|
+
// Find exponent bits and take the max supported by half
|
|
62
|
+
uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
|
|
63
|
+
uint32_t max_expo_32 = uint32_t(0x38800000);
|
|
64
|
+
x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
|
|
65
|
+
x_expo_32 += uint32_t(15) << 23;
|
|
66
|
+
|
|
67
|
+
// Handle scaling to inf as needed
|
|
68
|
+
inf_scale.u = uint32_t(0x77800000);
|
|
69
|
+
zero_scale.u = uint32_t(0x08800000);
|
|
70
|
+
|
|
71
|
+
// Combine with magic and let addition do rounding
|
|
72
|
+
magic_bits.u = x_expo_32;
|
|
73
|
+
magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
|
|
74
|
+
|
|
75
|
+
// Take the lower 5 bits of the exponent
|
|
76
|
+
uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));
|
|
77
|
+
|
|
78
|
+
// Collect the lower 12 bits which have the mantissa
|
|
79
|
+
uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);
|
|
80
|
+
|
|
81
|
+
// Combine sign, exp and mantissa
|
|
82
|
+
bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
// To float32
|
|
87
|
+
operator float() const {
|
|
88
|
+
// Conversion following
|
|
89
|
+
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
|
|
90
|
+
|
|
91
|
+
// Union
|
|
92
|
+
float_bits_fp16 out;
|
|
93
|
+
|
|
94
|
+
uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
|
|
95
|
+
uint32_t base = (bits_ << 16);
|
|
96
|
+
uint32_t two_base = base + base;
|
|
97
|
+
|
|
98
|
+
uint32_t denorm_max = 1u << 27;
|
|
99
|
+
if (two_base < denorm_max) {
|
|
100
|
+
out.u = uint32_t(126) << 23; // magic mask
|
|
101
|
+
out.u |= (two_base >> 17); // Bits from fp16
|
|
102
|
+
out.f -= 0.5f; // magic bias
|
|
103
|
+
} else {
|
|
104
|
+
out.u = uint32_t(0xE0) << 23; // exponent offset
|
|
105
|
+
out.u += (two_base >> 4); // Bits from fp16
|
|
106
|
+
float out_unscaled = out.f; // Store value
|
|
107
|
+
out.u = uint32_t(0x7800000); // exponent scale
|
|
108
|
+
out.f *= out_unscaled;
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
// Add sign
|
|
112
|
+
out.u |= x_sign_32;
|
|
113
|
+
|
|
114
|
+
return out.f;
|
|
115
|
+
}
|
|
116
|
+
};
|
|
117
|
+
|
|
118
|
+
// Defines `inline OUT_T NAME(LHS_T lhs, RHS_T rhs)`. Both operands are first
// converted to CALC_T, OP is applied in that type, and the result is returned
// as OUT_T.
#define half_binop_base(OP, NAME, OUT_T, LHS_T, RHS_T, CALC_T) \
  inline OUT_T NAME(LHS_T lhs, RHS_T rhs) {                    \
    const CALC_T lhs_c = static_cast<CALC_T>(lhs);             \
    const CALC_T rhs_c = static_cast<CALC_T>(rhs);             \
    return lhs_c OP rhs_c;                                     \
  }
|
|
122
|
+
|
|
123
|
+
// Defines the two mixed-type overloads NAME(_MLX_Float16, OTHER_T) and
// NAME(OTHER_T, _MLX_Float16). In each, both operands are converted to CALC_T
// before OP is applied, and the result is returned as OUT_T.
#define half_binop_helper(OP, NAME, OUT_T, OTHER_T, CALC_T) \
  inline OUT_T NAME(_MLX_Float16 lhs, OTHER_T rhs) {        \
    const CALC_T lhs_c = static_cast<CALC_T>(lhs);          \
    const CALC_T rhs_c = static_cast<CALC_T>(rhs);          \
    return lhs_c OP rhs_c;                                  \
  }                                                         \
  inline OUT_T NAME(OTHER_T lhs, _MLX_Float16 rhs) {        \
    const CALC_T lhs_c = static_cast<CALC_T>(lhs);          \
    const CALC_T rhs_c = static_cast<CALC_T>(rhs);          \
    return lhs_c OP rhs_c;                                  \
  }
|
|
130
|
+
|
|
131
|
+
// Operators
|
|
132
|
+
#define half_binop(__op__, __operator__) \
|
|
133
|
+
half_binop_base( \
|
|
134
|
+
__op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
|
|
135
|
+
half_binop_helper(__op__, __operator__, float, float, float); \
|
|
136
|
+
half_binop_helper(__op__, __operator__, double, double, double); \
|
|
137
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \
|
|
138
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \
|
|
139
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \
|
|
140
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \
|
|
141
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);
|
|
142
|
+
|
|
143
|
+
half_binop(+, operator+);
|
|
144
|
+
half_binop(-, operator-);
|
|
145
|
+
half_binop(*, operator*);
|
|
146
|
+
half_binop(/, operator/);
|
|
147
|
+
|
|
148
|
+
#undef half_binop
|
|
149
|
+
|
|
150
|
+
// Comparison ops
|
|
151
|
+
#define half_compop(__op__, __operator__) \
|
|
152
|
+
half_binop_base( \
|
|
153
|
+
__op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
|
|
154
|
+
half_binop_helper(__op__, __operator__, bool, float, float); \
|
|
155
|
+
half_binop_helper(__op__, __operator__, bool, double, double); \
|
|
156
|
+
half_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
|
157
|
+
half_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
|
158
|
+
half_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
|
159
|
+
half_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
|
160
|
+
|
|
161
|
+
half_compop(>, operator>);
|
|
162
|
+
half_compop(<, operator<);
|
|
163
|
+
half_compop(>=, operator>=);
|
|
164
|
+
half_compop(<=, operator<=);
|
|
165
|
+
half_compop(==, operator==);
|
|
166
|
+
half_compop(!=, operator!=);
|
|
167
|
+
|
|
168
|
+
#undef half_compop
|
|
169
|
+
|
|
170
|
+
// Negative
|
|
171
|
+
// Unary minus: negate in float, then convert back to half.
inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
  const float negated = -static_cast<float>(lhs);
  return _MLX_Float16(negated);
}
|
|
174
|
+
|
|
175
|
+
// Inplace ops
|
|
176
|
+
#define half_inplace_op(__op__, __operator__) \
|
|
177
|
+
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
|
|
178
|
+
lhs = lhs __op__ rhs; \
|
|
179
|
+
return lhs; \
|
|
180
|
+
} \
|
|
181
|
+
inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \
|
|
182
|
+
lhs = lhs __op__ rhs; \
|
|
183
|
+
return lhs; \
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
half_inplace_op(+, operator+=);
|
|
187
|
+
half_inplace_op(-, operator-=);
|
|
188
|
+
half_inplace_op(*, operator*=);
|
|
189
|
+
half_inplace_op(/, operator/=);
|
|
190
|
+
|
|
191
|
+
#undef half_inplace_op
|
|
192
|
+
|
|
193
|
+
// Bitwise ops
|
|
194
|
+
|
|
195
|
+
#define half_bitop(__op__, __operator__) \
|
|
196
|
+
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
|
|
197
|
+
_MLX_Float16 out; \
|
|
198
|
+
out.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
|
199
|
+
return out; \
|
|
200
|
+
} \
|
|
201
|
+
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \
|
|
202
|
+
_MLX_Float16 out; \
|
|
203
|
+
out.bits_ = lhs.bits_ __op__ rhs; \
|
|
204
|
+
return out; \
|
|
205
|
+
} \
|
|
206
|
+
inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \
|
|
207
|
+
_MLX_Float16 out; \
|
|
208
|
+
out.bits_ = lhs __op__ rhs.bits_; \
|
|
209
|
+
return out; \
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
half_bitop(|, operator|);
|
|
213
|
+
half_bitop(&, operator&);
|
|
214
|
+
half_bitop(^, operator^);
|
|
215
|
+
|
|
216
|
+
#undef half_bitop
|
|
217
|
+
|
|
218
|
+
#define half_inplace_bitop(__op__, __operator__) \
|
|
219
|
+
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
|
|
220
|
+
lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
|
221
|
+
return lhs; \
|
|
222
|
+
} \
|
|
223
|
+
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \
|
|
224
|
+
lhs.bits_ = lhs.bits_ __op__ rhs; \
|
|
225
|
+
return lhs; \
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
half_inplace_bitop(|, operator|=);
|
|
229
|
+
half_inplace_bitop(&, operator&=);
|
|
230
|
+
half_inplace_bitop(^, operator^=);
|
|
231
|
+
|
|
232
|
+
#undef half_inplace_bitop
|
|
233
|
+
|
|
234
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)

// The target has scalar fp16 arithmetic: re-export the compiler's float16_t.
#include <arm_fp16.h>
namespace mlx::core {
using ::float16_t;
} // namespace mlx::core

#else

// No native scalar fp16: use the software _MLX_Float16 and request the mixed
// fp16/bf16 operators guarded by ADD_HALF_BINOPS further down this header.
#define ADD_HALF_BINOPS
#include "mlx/types/fp16.h"
namespace mlx::core {
using float16_t = _MLX_Float16;
} // namespace mlx::core

#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
|
|
21
|
+
|
|
22
|
+
#if defined(__ARM_FEATURE_BF16)

// The target has native bfloat16: re-export the compiler's bfloat16_t.
#include <arm_bf16.h>
namespace mlx::core {
using ::bfloat16_t;
} // namespace mlx::core

#else

// No native bfloat16: use the software _MLX_BFloat16 and request the mixed
// fp16/bf16 operators guarded by ADD_HALF_BINOPS further down this header.
#define ADD_HALF_BINOPS
#include "mlx/types/bf16.h"
namespace mlx::core {
using bfloat16_t = _MLX_BFloat16;
} // namespace mlx::core

#endif // __ARM_FEATURE_BF16
|
|
38
|
+
|
|
39
|
+
#ifdef ADD_HALF_BINOPS
namespace mlx::core {

// Arithmetic between float16_t and bfloat16_t, in either operand order: both
// sides are converted to float and the float result is returned.
// clang-format off
#define fp16_bf16_binop_helper(OP, NAME)                 \
  inline float NAME(float16_t lhs, bfloat16_t rhs) {     \
    const float lhs_f = static_cast<float>(lhs);         \
    const float rhs_f = static_cast<float>(rhs);         \
    return lhs_f OP rhs_f;                               \
  }                                                      \
  inline float NAME(bfloat16_t lhs, float16_t rhs) {     \
    const float lhs_f = static_cast<float>(lhs);         \
    const float rhs_f = static_cast<float>(rhs);         \
    return lhs_f OP rhs_f;                               \
  }

fp16_bf16_binop_helper(+, operator+)
fp16_bf16_binop_helper(-, operator-)
fp16_bf16_binop_helper(*, operator*)
fp16_bf16_binop_helper(/, operator/)
// clang-format on

} // namespace mlx::core
#endif
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
#pragma once
|
|
3
|
+
|
|
4
|
+
#include <limits>
|
|
5
|
+
#include "mlx/types/half_types.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core {
|
|
8
|
+
|
|
9
|
+
// Numeric traits for MLX element types. Only the explicit specializations are
// defined; the primary template is intentionally left incomplete.
template <typename T>
struct numeric_limits;

// float and double forward everything to the standard library's traits
// (struct inheritance is public by default).
template <>
struct numeric_limits<float> : std::numeric_limits<float> {};

template <>
struct numeric_limits<double> : std::numeric_limits<double> {};
|
|
17
|
+
|
|
18
|
+
template <>
|
|
19
|
+
struct numeric_limits<float16_t> {
|
|
20
|
+
private:
|
|
21
|
+
union half_or_bits {
|
|
22
|
+
uint16_t bits;
|
|
23
|
+
float16_t value;
|
|
24
|
+
};
|
|
25
|
+
constexpr static float16_t bits_to_half(uint16_t v) {
|
|
26
|
+
return half_or_bits{v}.value;
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
public:
|
|
30
|
+
constexpr static float16_t lowest() {
|
|
31
|
+
return bits_to_half(0xFBFF);
|
|
32
|
+
}
|
|
33
|
+
static constexpr float16_t max() {
|
|
34
|
+
return bits_to_half(0x7BFF);
|
|
35
|
+
}
|
|
36
|
+
static constexpr float16_t epsilon() {
|
|
37
|
+
return bits_to_half(0x1400);
|
|
38
|
+
}
|
|
39
|
+
static constexpr float16_t infinity() {
|
|
40
|
+
return bits_to_half(0x7C00);
|
|
41
|
+
}
|
|
42
|
+
};
|
|
43
|
+
|
|
44
|
+
template <>
|
|
45
|
+
struct numeric_limits<bfloat16_t> {
|
|
46
|
+
private:
|
|
47
|
+
union bfloat_or_bits {
|
|
48
|
+
uint16_t bits;
|
|
49
|
+
bfloat16_t value;
|
|
50
|
+
};
|
|
51
|
+
constexpr static bfloat16_t bits_to_bfloat(uint16_t v) {
|
|
52
|
+
return bfloat_or_bits{v}.value;
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
public:
|
|
56
|
+
constexpr static bfloat16_t lowest() {
|
|
57
|
+
return bits_to_bfloat(0xFF7F);
|
|
58
|
+
}
|
|
59
|
+
static constexpr bfloat16_t max() {
|
|
60
|
+
return bits_to_bfloat(0x7F7F);
|
|
61
|
+
}
|
|
62
|
+
static constexpr bfloat16_t epsilon() {
|
|
63
|
+
return bits_to_bfloat(0x3C00);
|
|
64
|
+
}
|
|
65
|
+
static constexpr bfloat16_t infinity() {
|
|
66
|
+
return bits_to_bfloat(0x7F80);
|
|
67
|
+
}
|
|
68
|
+
};
|
|
69
|
+
|
|
70
|
+
} // namespace mlx::core
|
data/mlx/mlx/utils.cpp
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <cstdlib>
|
|
4
|
+
#include <iostream>
|
|
5
|
+
#include <sstream>
|
|
6
|
+
#include <vector>
|
|
7
|
+
|
|
8
|
+
#include "mlx/dtype_utils.h"
|
|
9
|
+
#include "mlx/types/limits.h"
|
|
10
|
+
#include "mlx/utils.h"
|
|
11
|
+
|
|
12
|
+
namespace mlx::core {
|
|
13
|
+
|
|
14
|
+
// Resolves a StreamOrDevice to a concrete Stream:
//   - empty (monostate) -> default stream of default_device()
//   - a Device           -> that device's default stream
//   - a Stream           -> returned unchanged
Stream to_stream(StreamOrDevice s) {
  if (std::holds_alternative<std::monostate>(s)) {
    return default_stream(default_device());
  }
  if (!std::holds_alternative<Device>(s)) {
    return std::get<Stream>(s);
  }
  return default_stream(std::get<Device>(s));
}
|
|
23
|
+
|
|
24
|
+
// Same as to_stream(s), except that an empty (monostate) argument resolves to
// the default stream of `default_` instead of default_device().
Stream to_stream(StreamOrDevice s, Device default_) {
  if (std::holds_alternative<std::monostate>(s)) {
    return default_stream(default_);
  }
  if (!std::holds_alternative<Device>(s)) {
    return std::get<Stream>(s);
  }
  return default_stream(std::get<Device>(s));
}
|
|
33
|
+
|
|
34
|
+
void PrintFormatter::print(std::ostream& os, bool val) {
|
|
35
|
+
if (capitalize_bool) {
|
|
36
|
+
os << (val ? "True" : "False");
|
|
37
|
+
} else {
|
|
38
|
+
os << val;
|
|
39
|
+
}
|
|
40
|
+
}
|
|
41
|
+
inline void PrintFormatter::print(std::ostream& os, int16_t val) {
|
|
42
|
+
os << val;
|
|
43
|
+
}
|
|
44
|
+
inline void PrintFormatter::print(std::ostream& os, uint16_t val) {
|
|
45
|
+
os << val;
|
|
46
|
+
}
|
|
47
|
+
inline void PrintFormatter::print(std::ostream& os, int32_t val) {
|
|
48
|
+
os << val;
|
|
49
|
+
}
|
|
50
|
+
inline void PrintFormatter::print(std::ostream& os, uint32_t val) {
|
|
51
|
+
os << val;
|
|
52
|
+
}
|
|
53
|
+
inline void PrintFormatter::print(std::ostream& os, int64_t val) {
|
|
54
|
+
os << val;
|
|
55
|
+
}
|
|
56
|
+
inline void PrintFormatter::print(std::ostream& os, uint64_t val) {
|
|
57
|
+
os << val;
|
|
58
|
+
}
|
|
59
|
+
inline void PrintFormatter::print(std::ostream& os, float16_t val) {
|
|
60
|
+
os << val;
|
|
61
|
+
}
|
|
62
|
+
inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) {
|
|
63
|
+
os << val;
|
|
64
|
+
}
|
|
65
|
+
inline void PrintFormatter::print(std::ostream& os, float val) {
|
|
66
|
+
os << val;
|
|
67
|
+
}
|
|
68
|
+
inline void PrintFormatter::print(std::ostream& os, double val) {
|
|
69
|
+
os << val;
|
|
70
|
+
}
|
|
71
|
+
inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
|
|
72
|
+
os << val.real();
|
|
73
|
+
if (val.imag() >= 0 || std::isnan(val.imag())) {
|
|
74
|
+
os << "+" << val.imag() << "j";
|
|
75
|
+
} else {
|
|
76
|
+
os << "-" << -val.imag() << "j";
|
|
77
|
+
}
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
// Returns the process-wide formatter used when printing arrays. It is
// created on the first call, and every caller gets a reference to that one
// object.
PrintFormatter& get_global_formatter() {
  static PrintFormatter shared_formatter;
  return shared_formatter;
}
|
|
84
|
+
|
|
85
|
+
// Reports `error` on std::cerr and terminates the process with std::abort().
void abort_with_exception(const std::exception& error) {
  // Assemble the whole message first so it is inserted into cerr as one string.
  std::ostringstream report;
  report << "Terminating due to uncaught exception: " << error.what();
  const std::string line = report.str();
  std::cerr << line << std::endl;
  std::abort();
}
|
|
91
|
+
|
|
92
|
+
// Folds promote_types over the dtypes of `arrays`, starting from bool_
// (which is what an empty list returns).
Dtype result_type(const std::vector<array>& arrays) {
  Dtype promoted = bool_;
  for (size_t i = 0; i < arrays.size(); ++i) {
    promoted = promote_types(promoted, arrays[i].dtype());
  }
  return promoted;
}
|
|
99
|
+
|
|
100
|
+
// Computes the broadcast of two shapes using NumPy's rules
// (https://numpy.org/doc/1.20/user/theory.broadcasting.html): shapes are
// aligned on their trailing axes, and each aligned pair must be equal or
// contain a 1. Axes present only in the longer shape are copied through.
// Throws std::invalid_argument when the shapes are incompatible.
Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
  const int rank1 = s1.size();
  const int rank2 = s2.size();
  const bool first_longer = rank1 > rank2;
  const Shape& longer = first_longer ? s1 : s2;
  const Shape& shorter = first_longer ? s2 : s1;
  const int out_rank = first_longer ? rank1 : rank2;
  const int lead = first_longer ? rank1 - rank2 : rank2 - rank1;

  Shape out_shape(out_rank);

  // Leading axes that exist only in the longer shape.
  for (int i = 0; i < lead; ++i) {
    out_shape[i] = longer[i];
  }

  // Right-aligned axes shared by both shapes.
  for (int i = lead; i < out_rank; ++i) {
    const auto dim_long = longer[i];
    const auto dim_short = shorter[i - lead];
    if (dim_long == dim_short) {
      out_shape[i] = dim_long;
    } else if (dim_long == 1 || dim_short == 1) {
      // The product is the other size (or 0 when that size is 0).
      out_shape[i] = dim_long * dim_short;
    } else {
      std::ostringstream msg;
      msg << "[broadcast_shapes] Shapes " << s1 << " and " << s2
          << " cannot be broadcast.";
      throw std::invalid_argument(msg.str());
    }
  }
  return out_shape;
}
|
|
132
|
+
|
|
133
|
+
// Maps an axis in [-ndim, ndim) to its non-negative equivalent in [0, ndim).
// Throws std::invalid_argument, with `msg_prefix` prepended to the message,
// when the axis is out of range.
int normalize_axis_index(
    int axis,
    int ndim,
    const std::string& msg_prefix /* = "" */) {
  const bool in_range = axis >= -ndim && axis < ndim;
  if (!in_range) {
    std::ostringstream msg;
    msg << msg_prefix << "Axis " << axis << " is out of bounds for array with "
        << ndim << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  if (axis < 0) {
    return axis + ndim;
  }
  return axis;
}
|
|
145
|
+
|
|
146
|
+
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
|
147
|
+
os << "Device(";
|
|
148
|
+
switch (d.type) {
|
|
149
|
+
case Device::cpu:
|
|
150
|
+
os << "cpu";
|
|
151
|
+
break;
|
|
152
|
+
case Device::gpu:
|
|
153
|
+
os << "gpu";
|
|
154
|
+
break;
|
|
155
|
+
}
|
|
156
|
+
os << ", " << d.index << ")";
|
|
157
|
+
return os;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
std::ostream& operator<<(std::ostream& os, const Stream& s) {
|
|
161
|
+
os << "Stream(";
|
|
162
|
+
os << s.device;
|
|
163
|
+
os << ", " << s.index << ")";
|
|
164
|
+
return os;
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
// 8-bit integers are printed as numbers rather than as characters, which is
// what the standard streams do for (signed/unsigned) char.
std::ostream& operator<<(std::ostream& os, int8_t x) {
  return os << static_cast<int>(x);
}

std::ostream& operator<<(std::ostream& os, uint8_t x) {
  return os << static_cast<unsigned int>(x);
}
|
|
176
|
+
|
|
177
|
+
namespace {
|
|
178
|
+
|
|
179
|
+
template <typename T>
|
|
180
|
+
void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
|
|
181
|
+
int num_print = 3;
|
|
182
|
+
int n = a.shape(dim);
|
|
183
|
+
size_t s = a.strides()[dim];
|
|
184
|
+
bool is_last = dim == a.ndim() - 1;
|
|
185
|
+
auto prefix = is_last ? "" : std::string(7 + dim, ' ');
|
|
186
|
+
auto postfix = is_last ? ", " : ",\n";
|
|
187
|
+
os << "[";
|
|
188
|
+
for (int i = 0; i < n; ++i) {
|
|
189
|
+
os << (i == 0 ? "" : prefix);
|
|
190
|
+
if (i == num_print && n > 2 * num_print) {
|
|
191
|
+
os << "...";
|
|
192
|
+
i = n - num_print - 1;
|
|
193
|
+
index += s * (n - 2 * num_print - 1);
|
|
194
|
+
} else if (is_last) {
|
|
195
|
+
get_global_formatter().print(os, a.data<T>()[index]);
|
|
196
|
+
} else {
|
|
197
|
+
print_subarray<T>(os, a, index, dim + 1);
|
|
198
|
+
}
|
|
199
|
+
os << (i == n - 1 ? "" : postfix);
|
|
200
|
+
index += s;
|
|
201
|
+
}
|
|
202
|
+
os << "]";
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
template <typename T>
|
|
206
|
+
void print_array(std::ostream& os, const array& a) {
|
|
207
|
+
os << std::boolalpha;
|
|
208
|
+
os << "array(";
|
|
209
|
+
if (a.ndim() == 0) {
|
|
210
|
+
auto data = a.data<T>();
|
|
211
|
+
get_global_formatter().print(os, data[0]);
|
|
212
|
+
} else {
|
|
213
|
+
print_subarray<T>(os, a, 0, 0);
|
|
214
|
+
}
|
|
215
|
+
os << ", dtype=" << a.dtype() << ")";
|
|
216
|
+
os << std::noboolalpha;
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
} // namespace
|
|
220
|
+
|
|
221
|
+
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
|
|
222
|
+
return os << dtype_to_string(dtype);
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
// Writes the one-letter code of a dtype kind. A value outside the listed
// enumerators writes nothing.
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
  const char* code = nullptr;
  switch (k) {
    case Dtype::Kind::b:
      code = "b";
      break;
    case Dtype::Kind::i:
      code = "i";
      break;
    case Dtype::Kind::u:
      code = "u";
      break;
    case Dtype::Kind::f:
      code = "f";
      break;
    case Dtype::Kind::c:
      code = "c";
      break;
    case Dtype::Kind::V:
      code = "V";
      break;
  }
  if (code != nullptr) {
    os << code;
  }
  return os;
}
|
|
242
|
+
|
|
243
|
+
// Prints an array. It is evaluated first so its data can be read, and the
// element type is then chosen from its dtype via dispatch_all_types.
std::ostream& operator<<(std::ostream& os, array a) {
  a.eval();
  auto print_typed = [&](auto tag) {
    print_array<MLX_GET_TYPE(tag)>(os, a);
  };
  dispatch_all_types(a.dtype(), print_typed);
  return os;
}
|
|
250
|
+
|
|
251
|
+
namespace env {

// Reads the environment variable `name` as a base-10 int.
//
// Returns `default_value` when the variable is unset. Otherwise the text is
// parsed like std::atoi: leading whitespace is skipped, parsing stops at the
// first non-digit, and text with no leading number yields 0. The one
// difference is that values outside int's range are clamped to INT_MIN /
// INT_MAX. std::atoi has undefined behavior for those values.
int get_var(const char* name, int default_value) {
  const char* buff_str = std::getenv(name);
  if (buff_str == nullptr) {
    return default_value;
  }
  // On overflow std::strtol returns LONG_MIN / LONG_MAX, which the clamps
  // below map onto the int limits.
  const long parsed = std::strtol(buff_str, nullptr, 10);
  if (parsed > std::numeric_limits<int>::max()) {
    return std::numeric_limits<int>::max();
  }
  if (parsed < std::numeric_limits<int>::min()) {
    return std::numeric_limits<int>::min();
  }
  return static_cast<int>(parsed);
}

} // namespace env
|
|
262
|
+
|
|
263
|
+
template <typename T>
|
|
264
|
+
void set_finfo_limits(double& min, double& max, double& eps) {
|
|
265
|
+
min = numeric_limits<T>::lowest();
|
|
266
|
+
max = numeric_limits<T>::max();
|
|
267
|
+
eps = numeric_limits<T>::epsilon();
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
// Floating-point limits for an inexact dtype. complex64 reports the limits
// of its float32 component and records float32 as the dtype. A non-inexact
// dtype throws std::invalid_argument.
finfo::finfo(Dtype dtype) : dtype(dtype) {
  if (!issubdtype(dtype, inexact)) {
    std::ostringstream msg;
    msg << "[finfo] dtype " << dtype << " is not inexact.";
    throw std::invalid_argument(msg.str());
  }

  if (dtype == complex64) {
    this->dtype = float32;
    set_finfo_limits<float>(min, max, eps);
  } else if (dtype == float64) {
    set_finfo_limits<double>(min, max, eps);
  } else if (dtype == bfloat16) {
    set_finfo_limits<bfloat16_t>(min, max, eps);
  } else if (dtype == float16) {
    set_finfo_limits<float16_t>(min, max, eps);
  } else if (dtype == float32) {
    set_finfo_limits<float>(min, max, eps);
  }
}
|
|
289
|
+
|
|
290
|
+
// Stores the integer range of T into `min` (signed) and `max` (unsigned), so
// every MLX integer type's range fits without overflow.
template <typename T>
void set_iinfo_limits(int64_t& min, uint64_t& max) {
  using traits = std::numeric_limits<T>;
  max = traits::max();
  min = traits::min();
}
|
|
295
|
+
|
|
296
|
+
// Integer limits for an integer dtype. dispatch_int_types calls the callback
// with a tag for the concrete integer type. Presumably it reports non-integer
// dtypes using the "[iinfo]" prefix; confirm in dtype_utils.h.
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
  auto record_limits = [&](auto tag) {
    set_iinfo_limits<MLX_GET_TYPE(tag)>(min, max);
  };
  dispatch_int_types(dtype, "[iinfo]", record_limits);
}
|
|
301
|
+
|
|
302
|
+
} // namespace mlx::core
|