mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
data/mlx/mlx/fft.cpp
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
#include <numeric>
|
|
3
|
+
#include <set>
|
|
4
|
+
|
|
5
|
+
#include "mlx/fft.h"
|
|
6
|
+
#include "mlx/ops.h"
|
|
7
|
+
#include "mlx/primitives.h"
|
|
8
|
+
#include "mlx/utils.h"
|
|
9
|
+
|
|
10
|
+
namespace mlx::core::fft {
|
|
11
|
+
|
|
12
|
+
// Core FFT builder. Validates `axes` and the requested sizes `n`, crops or
// zero-pads the input so every transformed axis has the length the transform
// expects, and returns an array backed by the FFT primitive.
//
// Parameters:
//   a       - input array; must have at least one dimension.
//   n       - one size per entry of `axes`. For irfftn (real && inverse) these
//             are the OUTPUT lengths; otherwise they are the lengths the
//             input is cropped / zero-padded to.
//   axes    - axes to transform; negative values count from the end.
//   real    - true for rfftn / irfftn.
//   inverse - true for ifftn / irfftn.
//   s       - stream or device every op here is scheduled on.
// Returns `a` unchanged when `axes` is empty.
// Throws std::invalid_argument for 0-d input, mismatched n/axes lengths,
// duplicated or out-of-range axes, and non-positive sizes.
array fft_impl(
    const array& a,
    Shape n,
    const std::vector<int>& axes,
    bool real,
    bool inverse,
    StreamOrDevice s) {
  if (a.ndim() < 1) {
    throw std::invalid_argument(
        "[fftn] Requires array with at least one dimension.");
  }
  if (n.size() != axes.size()) {
    throw std::invalid_argument("[fftn] Shape and axes have different sizes.");
  }
  if (axes.empty()) {
    return a;
  }

  // Normalize negative axes with signed arithmetic so out-of-range values
  // stay negative rather than wrapping around as size_t.
  const int ndim = static_cast<int>(a.ndim());
  std::vector<int> norm_axes;
  norm_axes.reserve(axes.size());
  for (int ax : axes) {
    norm_axes.push_back(ax < 0 ? ax + ndim : ax);
  }
  std::set<int> unique_axes(norm_axes.begin(), norm_axes.end());
  if (unique_axes.size() != axes.size()) {
    std::ostringstream msg;
    msg << "[fftn] Duplicated axis received " << axes;
    throw std::invalid_argument(msg.str());
  }
  if (*unique_axes.begin() < 0 || *unique_axes.rbegin() >= ndim) {
    std::ostringstream msg;
    msg << "[fftn] Invalid axis received for array with " << a.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }
  // All axes are now known to be in [0, ndim); the FFT primitive takes them
  // as size_t.
  std::vector<size_t> valid_axes(norm_axes.begin(), norm_axes.end());

  // Shape bookkeeping covers three cases:
  //  1. Complex-to-complex (fftn / ifftn): input and output shapes match.
  //  2. Real-to-complex (rfftn): n gives the input dims, and the last
  //     transformed output dim is n.back() / 2 + 1.
  //  3. Complex-to-real (irfftn): n gives the output dims, and the last
  //     transformed input dim is n.back() / 2 + 1.
  if (std::any_of(n.begin(), n.end(), [](auto i) { return i <= 0; })) {
    std::ostringstream msg;
    msg << "[fftn] Invalid FFT output size requested " << n;
    throw std::invalid_argument(msg.str());
  }

  const auto& orig_shape = a.shape();
  auto in_shape = orig_shape;
  for (size_t i = 0; i < valid_axes.size(); ++i) {
    in_shape[valid_axes[i]] = n[i];
  }
  if (real && inverse) {
    in_shape[valid_axes.back()] = n.back() / 2 + 1;
  }

  bool any_greater = false;
  bool any_less = false;
  for (size_t i = 0; i < in_shape.size(); ++i) {
    any_greater |= in_shape[i] > orig_shape[i];
    any_less |= in_shape[i] < orig_shape[i];
  }

  auto in = a;
  if (any_less) {
    // Crop axes that are longer than requested.
    in = slice(in, Shape(in.ndim(), 0), in_shape, s);
  }
  if (any_greater) {
    // Zero-pad axes that are shorter than requested. The update must run on
    // the same stream/device `s` as the rest of the transform.
    auto tmp = zeros(in_shape, a.dtype(), s);
    in = slice_update(tmp, in, Shape(in.ndim(), 0), in.shape(), s);
  }

  auto out_shape = in_shape;
  if (real) {
    auto ax = valid_axes.back();
    out_shape[ax] = inverse ? n.back() : out_shape[ax] / 2 + 1;
  }

  auto in_type = real && !inverse ? float32 : complex64;
  auto out_type = real && inverse ? float32 : complex64;
  return array(
      out_shape,
      out_type,
      std::make_shared<FFT>(to_stream(s), valid_axes, inverse, real),
      {astype(in, in_type, s)});
}
|
|
100
|
+
|
|
101
|
+
// Overload of fft_impl that takes the transform sizes from the input's own
// shape along `axes`.
//
// For the real inverse transform (irfftn), the input holds only the
// non-redundant half of the spectrum along the last transformed axis. The
// output length along that axis is therefore taken as (m - 1) * 2, where m is
// the input length there.
array fft_impl(
    const array& a,
    const std::vector<int>& axes,
    bool real,
    bool inverse,
    StreamOrDevice s) {
  Shape n;
  n.reserve(axes.size());
  for (auto ax : axes) {
    n.push_back(a.shape(ax));
  }
  // Guard on `n` itself. With an empty `axes`, `n` is empty and `n.back()`
  // would be undefined behavior even when `a` has dimensions.
  if (real && inverse && !n.empty()) {
    n.back() = (n.back() - 1) * 2;
  }
  return fft_impl(a, n, axes, real, inverse, s);
}
|
|
116
|
+
|
|
117
|
+
// Transforms over every axis of `a`, in order 0 .. ndim-1, with sizes taken
// from a's own shape.
array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) {
  std::vector<int> all_axes;
  all_axes.reserve(a.ndim());
  for (int i = 0; i < static_cast<int>(a.ndim()); ++i) {
    all_axes.push_back(i);
  }
  return fft_impl(a, all_axes, real, inverse, s);
}
|
|
122
|
+
|
|
123
|
+
// Public transforms. Every overload forwards to fft_impl, and the two flags
// select which transform runs:
//   fftn   : real = false, inverse = false
//   ifftn  : real = false, inverse = true
//   rfftn  : real = true,  inverse = false
//   irfftn : real = true,  inverse = true
// Overloads without `n` take sizes from the input. Overloads without `axes`
// transform over every axis.

array fftn(
    const array& a,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, /* real= */ false, /* inverse= */ false, s);
}
array fftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, axes, /* real= */ false, /* inverse= */ false, s);
}
array fftn(const array& a, StreamOrDevice s /* = {} */) {
  return fft_impl(a, /* real= */ false, /* inverse= */ false, s);
}

array ifftn(
    const array& a,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, /* real= */ false, /* inverse= */ true, s);
}
array ifftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, axes, /* real= */ false, /* inverse= */ true, s);
}
array ifftn(const array& a, StreamOrDevice s /* = {} */) {
  return fft_impl(a, /* real= */ false, /* inverse= */ true, s);
}

array rfftn(
    const array& a,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, /* real= */ true, /* inverse= */ false, s);
}
array rfftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, axes, /* real= */ true, /* inverse= */ false, s);
}
array rfftn(const array& a, StreamOrDevice s /* = {} */) {
  return fft_impl(a, /* real= */ true, /* inverse= */ false, s);
}

array irfftn(
    const array& a,
    const Shape& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, /* real= */ true, /* inverse= */ true, s);
}
array irfftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, axes, /* real= */ true, /* inverse= */ true, s);
}
array irfftn(const array& a, StreamOrDevice s /* = {} */) {
  return fft_impl(a, /* real= */ true, /* inverse= */ true, s);
}
|
|
191
|
+
|
|
192
|
+
array fftshift(
|
|
193
|
+
const array& a,
|
|
194
|
+
const std::vector<int>& axes,
|
|
195
|
+
StreamOrDevice s /* = {} */) {
|
|
196
|
+
if (axes.empty()) {
|
|
197
|
+
return a;
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
Shape shifts;
|
|
201
|
+
for (int ax : axes) {
|
|
202
|
+
// Convert negative axes to positive
|
|
203
|
+
int axis = ax < 0 ? ax + a.ndim() : ax;
|
|
204
|
+
if (axis < 0 || axis >= a.ndim()) {
|
|
205
|
+
std::ostringstream msg;
|
|
206
|
+
msg << "[fftshift] Invalid axis " << ax << " for array with " << a.ndim()
|
|
207
|
+
<< " dimensions.";
|
|
208
|
+
throw std::invalid_argument(msg.str());
|
|
209
|
+
}
|
|
210
|
+
// Match NumPy's implementation
|
|
211
|
+
shifts.push_back(a.shape(axis) / 2);
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
return roll(a, shifts, axes, s);
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
array ifftshift(
|
|
218
|
+
const array& a,
|
|
219
|
+
const std::vector<int>& axes,
|
|
220
|
+
StreamOrDevice s /* = {} */) {
|
|
221
|
+
if (axes.empty()) {
|
|
222
|
+
return a;
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
Shape shifts;
|
|
226
|
+
for (int ax : axes) {
|
|
227
|
+
// Convert negative axes to positive
|
|
228
|
+
int axis = ax < 0 ? ax + a.ndim() : ax;
|
|
229
|
+
if (axis < 0 || axis >= a.ndim()) {
|
|
230
|
+
std::ostringstream msg;
|
|
231
|
+
msg << "[ifftshift] Invalid axis " << ax << " for array with " << a.ndim()
|
|
232
|
+
<< " dimensions.";
|
|
233
|
+
throw std::invalid_argument(msg.str());
|
|
234
|
+
}
|
|
235
|
+
// Match NumPy's implementation
|
|
236
|
+
int size = a.shape(axis);
|
|
237
|
+
shifts.push_back(-(size / 2));
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
return roll(a, shifts, axes, s);
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
// Default versions that operate on all axes
|
|
244
|
+
array fftshift(const array& a, StreamOrDevice s /* = {} */) {
|
|
245
|
+
if (a.ndim() < 1) {
|
|
246
|
+
return a;
|
|
247
|
+
}
|
|
248
|
+
std::vector<int> axes(a.ndim());
|
|
249
|
+
std::iota(axes.begin(), axes.end(), 0);
|
|
250
|
+
return fftshift(a, axes, s);
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
array ifftshift(const array& a, StreamOrDevice s /* = {} */) {
|
|
254
|
+
if (a.ndim() < 1) {
|
|
255
|
+
return a;
|
|
256
|
+
}
|
|
257
|
+
std::vector<int> axes(a.ndim());
|
|
258
|
+
std::iota(axes.begin(), axes.end(), 0);
|
|
259
|
+
return ifftshift(a, axes, s);
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
} // namespace mlx::core::fft
|
data/mlx/mlx/fft.h
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <variant>
|
|
6
|
+
|
|
7
|
+
#include "array.h"
|
|
8
|
+
#include "device.h"
|
|
9
|
+
#include "mlx/api.h"
|
|
10
|
+
#include "utils.h"
|
|
11
|
+
|
|
12
|
+
namespace mlx::core::fft {
|
|
13
|
+
|
|
14
|
+
/** Compute the n-dimensional Fourier Transform. */
|
|
15
|
+
MLX_API array fftn(
|
|
16
|
+
const array& a,
|
|
17
|
+
const Shape& n,
|
|
18
|
+
const std::vector<int>& axes,
|
|
19
|
+
StreamOrDevice s = {});
|
|
20
|
+
MLX_API array
|
|
21
|
+
fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
|
|
22
|
+
MLX_API array fftn(const array& a, StreamOrDevice s = {});
|
|
23
|
+
|
|
24
|
+
/** Compute the n-dimensional inverse Fourier Transform. */
|
|
25
|
+
MLX_API array ifftn(
|
|
26
|
+
const array& a,
|
|
27
|
+
const Shape& n,
|
|
28
|
+
const std::vector<int>& axes,
|
|
29
|
+
StreamOrDevice s = {});
|
|
30
|
+
MLX_API array
|
|
31
|
+
ifftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
|
|
32
|
+
MLX_API array ifftn(const array& a, StreamOrDevice s = {});
|
|
33
|
+
|
|
34
|
+
/** Compute the one-dimensional Fourier Transform. */
|
|
35
|
+
inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
|
36
|
+
return fftn(a, {n}, {axis}, s);
|
|
37
|
+
}
|
|
38
|
+
inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
|
39
|
+
return fftn(a, {axis}, s);
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
/** Compute the one-dimensional inverse Fourier Transform. */
|
|
43
|
+
inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
|
44
|
+
return ifftn(a, {n}, {axis}, s);
|
|
45
|
+
}
|
|
46
|
+
inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
|
47
|
+
return ifftn(a, {axis}, s);
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
/** Compute the two-dimensional Fourier Transform. */
|
|
51
|
+
inline array fft2(
|
|
52
|
+
const array& a,
|
|
53
|
+
const Shape& n,
|
|
54
|
+
const std::vector<int>& axes,
|
|
55
|
+
StreamOrDevice s = {}) {
|
|
56
|
+
return fftn(a, n, axes, s);
|
|
57
|
+
}
|
|
58
|
+
inline array fft2(
|
|
59
|
+
const array& a,
|
|
60
|
+
const std::vector<int>& axes = {-2, -1},
|
|
61
|
+
StreamOrDevice s = {}) {
|
|
62
|
+
return fftn(a, axes, s);
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
/** Compute the two-dimensional inverse Fourier Transform. */
|
|
66
|
+
inline array ifft2(
|
|
67
|
+
const array& a,
|
|
68
|
+
const Shape& n,
|
|
69
|
+
const std::vector<int>& axes,
|
|
70
|
+
StreamOrDevice s = {}) {
|
|
71
|
+
return ifftn(a, n, axes, s);
|
|
72
|
+
}
|
|
73
|
+
inline array ifft2(
|
|
74
|
+
const array& a,
|
|
75
|
+
const std::vector<int>& axes = {-2, -1},
|
|
76
|
+
StreamOrDevice s = {}) {
|
|
77
|
+
return ifftn(a, axes, s);
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
/** Compute the n-dimensional Fourier Transform on a real input. */
|
|
81
|
+
MLX_API array rfftn(
|
|
82
|
+
const array& a,
|
|
83
|
+
const Shape& n,
|
|
84
|
+
const std::vector<int>& axes,
|
|
85
|
+
StreamOrDevice s = {});
|
|
86
|
+
MLX_API array
|
|
87
|
+
rfftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
|
|
88
|
+
MLX_API array rfftn(const array& a, StreamOrDevice s = {});
|
|
89
|
+
|
|
90
|
+
/** Compute the n-dimensional inverse of `rfftn`. */
|
|
91
|
+
MLX_API array irfftn(
|
|
92
|
+
const array& a,
|
|
93
|
+
const Shape& n,
|
|
94
|
+
const std::vector<int>& axes,
|
|
95
|
+
StreamOrDevice s = {});
|
|
96
|
+
MLX_API array
|
|
97
|
+
irfftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
|
|
98
|
+
MLX_API array irfftn(const array& a, StreamOrDevice s = {});
|
|
99
|
+
|
|
100
|
+
/** Compute the one-dimensional Fourier Transform on a real input. */
|
|
101
|
+
inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
|
102
|
+
return rfftn(a, {n}, {axis}, s);
|
|
103
|
+
}
|
|
104
|
+
inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
|
105
|
+
return rfftn(a, {axis}, s);
|
|
106
|
+
}
|
|
107
|
+
/** Compute the one-dimensional inverse of `rfft`. */
|
|
108
|
+
inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
|
109
|
+
return irfftn(a, {n}, {axis}, s);
|
|
110
|
+
}
|
|
111
|
+
inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
|
112
|
+
return irfftn(a, {axis}, s);
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
/** Compute the two-dimensional Fourier Transform on a real input. */
|
|
116
|
+
inline array rfft2(
|
|
117
|
+
const array& a,
|
|
118
|
+
const Shape& n,
|
|
119
|
+
const std::vector<int>& axes,
|
|
120
|
+
StreamOrDevice s = {}) {
|
|
121
|
+
return rfftn(a, n, axes, s);
|
|
122
|
+
}
|
|
123
|
+
inline array rfft2(
|
|
124
|
+
const array& a,
|
|
125
|
+
const std::vector<int>& axes = {-2, -1},
|
|
126
|
+
StreamOrDevice s = {}) {
|
|
127
|
+
return rfftn(a, axes, s);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
/** Compute the two-dimensional inverse of `rfft2`. */
|
|
131
|
+
inline array irfft2(
|
|
132
|
+
const array& a,
|
|
133
|
+
const Shape& n,
|
|
134
|
+
const std::vector<int>& axes,
|
|
135
|
+
StreamOrDevice s = {}) {
|
|
136
|
+
return irfftn(a, n, axes, s);
|
|
137
|
+
}
|
|
138
|
+
inline array irfft2(
|
|
139
|
+
const array& a,
|
|
140
|
+
const std::vector<int>& axes = {-2, -1},
|
|
141
|
+
StreamOrDevice s = {}) {
|
|
142
|
+
return irfftn(a, axes, s);
|
|
143
|
+
}
|
|
144
|
+
/** Shift the zero-frequency component to the center of the spectrum. */
|
|
145
|
+
MLX_API array fftshift(const array& a, StreamOrDevice s = {});
|
|
146
|
+
|
|
147
|
+
/** Shift the zero-frequency component to the center of the spectrum along
|
|
148
|
+
* specified axes. */
|
|
149
|
+
MLX_API array
|
|
150
|
+
fftshift(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
|
|
151
|
+
|
|
152
|
+
/** The inverse of fftshift. */
|
|
153
|
+
MLX_API array ifftshift(const array& a, StreamOrDevice s = {});
|
|
154
|
+
|
|
155
|
+
/** The inverse of fftshift along specified axes. */
|
|
156
|
+
MLX_API array
|
|
157
|
+
ifftshift(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
|
|
158
|
+
|
|
159
|
+
} // namespace mlx::core::fft
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <functional>
|
|
4
|
+
#include <optional>
|
|
5
|
+
#include <sstream>
|
|
6
|
+
#include <unordered_map>
|
|
7
|
+
#include <unordered_set>
|
|
8
|
+
|
|
9
|
+
#include "mlx/graph_utils.h"
|
|
10
|
+
#include "mlx/primitives.h"
|
|
11
|
+
#include "mlx/utils.h"
|
|
12
|
+
|
|
13
|
+
namespace mlx::core {
|
|
14
|
+
|
|
15
|
+
const std::string& NodeNamer::get_name(const array& x) {
|
|
16
|
+
auto it = names.find(x.id());
|
|
17
|
+
if (it == names.end()) {
|
|
18
|
+
// Get the next name in the sequence
|
|
19
|
+
// [A, B, ..., Z, AA, AB, ...]
|
|
20
|
+
std::vector<char> letters;
|
|
21
|
+
auto var_num = names.size() + 1;
|
|
22
|
+
while (var_num > 0) {
|
|
23
|
+
letters.push_back('A' + (var_num - 1) % 26);
|
|
24
|
+
var_num = (var_num - 1) / 26;
|
|
25
|
+
}
|
|
26
|
+
names.emplace(x.id(), std::string(letters.rbegin(), letters.rend()));
|
|
27
|
+
|
|
28
|
+
return get_name(x);
|
|
29
|
+
}
|
|
30
|
+
return it->second;
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
void NodeNamer::set_name(const array& x, std::string n) {
|
|
34
|
+
names[x.id()] = std::move(n);
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
void depth_first_traversal(
|
|
38
|
+
std::function<void(array)> callback,
|
|
39
|
+
const std::vector<array>& outputs) {
|
|
40
|
+
std::function<void(const array&)> recurse;
|
|
41
|
+
std::unordered_set<std::uintptr_t> cache;
|
|
42
|
+
recurse = [&](const array& x) {
|
|
43
|
+
auto id = x.id();
|
|
44
|
+
if (cache.find(id) != cache.end()) {
|
|
45
|
+
return;
|
|
46
|
+
}
|
|
47
|
+
cache.insert(id);
|
|
48
|
+
for (auto& s : x.siblings()) {
|
|
49
|
+
cache.insert(s.id());
|
|
50
|
+
}
|
|
51
|
+
for (auto& in : x.inputs()) {
|
|
52
|
+
recurse(in);
|
|
53
|
+
}
|
|
54
|
+
callback(x);
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
for (auto& o : outputs) {
|
|
58
|
+
recurse(o);
|
|
59
|
+
}
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
void print_graph(
|
|
63
|
+
std::ostream& os,
|
|
64
|
+
NodeNamer namer,
|
|
65
|
+
const std::vector<array>& outputs) {
|
|
66
|
+
std::vector<array> tape;
|
|
67
|
+
std::vector<array> inputs;
|
|
68
|
+
|
|
69
|
+
depth_first_traversal(
|
|
70
|
+
[&](const array& x) {
|
|
71
|
+
if (x.has_primitive()) {
|
|
72
|
+
tape.push_back(x);
|
|
73
|
+
} else {
|
|
74
|
+
inputs.push_back(x);
|
|
75
|
+
}
|
|
76
|
+
},
|
|
77
|
+
outputs);
|
|
78
|
+
|
|
79
|
+
auto print_arrs = [&namer, &os](std::vector<array> arrs) {
|
|
80
|
+
for (auto& arr : arrs) {
|
|
81
|
+
os << namer.get_name(arr);
|
|
82
|
+
os << " [" << arr.shape() << ", " << arr.dtype() << "]";
|
|
83
|
+
if (&arr != &arrs.back()) {
|
|
84
|
+
os << ", ";
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
};
|
|
88
|
+
|
|
89
|
+
os << "Inputs: ";
|
|
90
|
+
print_arrs(inputs);
|
|
91
|
+
os << "\nOutputs: ";
|
|
92
|
+
print_arrs(outputs);
|
|
93
|
+
os << "\n";
|
|
94
|
+
|
|
95
|
+
for (auto& arr : tape) {
|
|
96
|
+
os << arr.primitive().name();
|
|
97
|
+
os << " ";
|
|
98
|
+
print_arrs(arr.inputs());
|
|
99
|
+
os << " -> ";
|
|
100
|
+
print_arrs(arr.outputs());
|
|
101
|
+
os << "\n";
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
void export_to_dot(
|
|
106
|
+
std::ostream& os,
|
|
107
|
+
NodeNamer namer,
|
|
108
|
+
const std::vector<array>& nodes) {
|
|
109
|
+
// Perform one DFS to mark arrays as intermediate if they are used as inputs
|
|
110
|
+
// to other arrays.
|
|
111
|
+
std::unordered_set<std::uintptr_t> intermediate_set;
|
|
112
|
+
depth_first_traversal(
|
|
113
|
+
[&](const array& x) {
|
|
114
|
+
// No primitive so it is an input
|
|
115
|
+
if (!x.has_primitive()) {
|
|
116
|
+
return;
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
for (auto& a : x.inputs()) {
|
|
120
|
+
intermediate_set.insert(a.id());
|
|
121
|
+
}
|
|
122
|
+
},
|
|
123
|
+
nodes);
|
|
124
|
+
|
|
125
|
+
// Now we got everything we need to make the graph. Arrays can be one of 3
|
|
126
|
+
// things:
|
|
127
|
+
// 1. Inputs, when they have no primitive ie are evaluated
|
|
128
|
+
// 2. Intermediates, when they are the intermediate set
|
|
129
|
+
// 3. Outputs, if they are not inputs and not intermediates
|
|
130
|
+
|
|
131
|
+
os << "digraph {" << std::endl;
|
|
132
|
+
|
|
133
|
+
depth_first_traversal(
|
|
134
|
+
[&](const array& x) {
|
|
135
|
+
if (!x.has_primitive()) {
|
|
136
|
+
os << "{ rank=source; \"" << namer.get_name(x) << "\"; }"
|
|
137
|
+
<< std::endl;
|
|
138
|
+
return;
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
// Node for primitive
|
|
142
|
+
if (x.has_primitive()) {
|
|
143
|
+
os << "{ ";
|
|
144
|
+
os << x.primitive_id();
|
|
145
|
+
os << " [label =\"";
|
|
146
|
+
os << x.primitive().name();
|
|
147
|
+
os << "\", shape=rectangle]";
|
|
148
|
+
os << "; }" << std::endl;
|
|
149
|
+
// Arrows to primitive's inputs
|
|
150
|
+
for (auto& a : x.inputs()) {
|
|
151
|
+
os << '"' << namer.get_name(a) << "\" -> " << x.primitive_id()
|
|
152
|
+
<< std::endl;
|
|
153
|
+
}
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
// Point outputs to their primitive
|
|
157
|
+
for (auto& a : x.outputs()) {
|
|
158
|
+
os << "{ ";
|
|
159
|
+
if (intermediate_set.find(a.id()) == intermediate_set.end()) {
|
|
160
|
+
os << "rank=sink; ";
|
|
161
|
+
}
|
|
162
|
+
os << '"' << namer.get_name(a);
|
|
163
|
+
os << "\"; }" << std::endl;
|
|
164
|
+
if (x.has_primitive()) {
|
|
165
|
+
os << x.primitive_id() << " -> \"" << namer.get_name(a) << '"'
|
|
166
|
+
<< std::endl;
|
|
167
|
+
}
|
|
168
|
+
}
|
|
169
|
+
},
|
|
170
|
+
nodes);
|
|
171
|
+
|
|
172
|
+
os << "}";
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <unordered_map>
|
|
6
|
+
|
|
7
|
+
#include "mlx/api.h"
|
|
8
|
+
#include "mlx/array.h"
|
|
9
|
+
|
|
10
|
+
namespace mlx::core {
|
|
11
|
+
|
|
12
|
+
struct MLX_API NodeNamer {
|
|
13
|
+
std::unordered_map<std::uintptr_t, std::string> names;
|
|
14
|
+
|
|
15
|
+
const std::string& get_name(const array& x);
|
|
16
|
+
void set_name(const array& x, std::string n);
|
|
17
|
+
};
|
|
18
|
+
|
|
19
|
+
MLX_API void print_graph(
|
|
20
|
+
std::ostream& os,
|
|
21
|
+
NodeNamer namer,
|
|
22
|
+
const std::vector<array>& outputs);
|
|
23
|
+
|
|
24
|
+
inline void print_graph(std::ostream& os, const std::vector<array>& outputs) {
|
|
25
|
+
print_graph(os, NodeNamer{}, outputs);
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
|
29
|
+
inline void print_graph(std::ostream& os, Arrays&&... outputs) {
|
|
30
|
+
print_graph(
|
|
31
|
+
os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
|
35
|
+
inline void
|
|
36
|
+
print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
|
|
37
|
+
print_graph(
|
|
38
|
+
os,
|
|
39
|
+
std::move(namer),
|
|
40
|
+
std::vector<array>{std::forward<Arrays>(outputs)...});
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
MLX_API void export_to_dot(
|
|
44
|
+
std::ostream& os,
|
|
45
|
+
NodeNamer namer,
|
|
46
|
+
const std::vector<array>& outputs);
|
|
47
|
+
|
|
48
|
+
inline void export_to_dot(std::ostream& os, const std::vector<array>& outputs) {
|
|
49
|
+
export_to_dot(os, NodeNamer{}, outputs);
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
|
53
|
+
inline void export_to_dot(std::ostream& os, Arrays&&... outputs) {
|
|
54
|
+
export_to_dot(
|
|
55
|
+
os, NodeNamer{}, std::vector<array>{std::forward<Arrays>(outputs)...});
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
|
59
|
+
inline void
|
|
60
|
+
export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) {
|
|
61
|
+
export_to_dot(
|
|
62
|
+
os,
|
|
63
|
+
std::move(namer),
|
|
64
|
+
std::vector<array>{std::forward<Arrays>(outputs)...});
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Core loading entry points are always built.
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp)

# Safetensors: compile the stub when support is disabled, otherwise the real
# implementation.
if(NOT MLX_BUILD_SAFETENSORS)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_safetensors.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/safetensors.cpp)
endif()
|
|
8
|
+
|
|
9
|
+
# GGUF support: fetch gguf-tools at a pinned commit, build its C sources as a
# static library, and link that library privately into mlx. When GGUF support
# is disabled, a stub source is compiled instead.
if(MLX_BUILD_GGUF)
  # FetchContent_* commands require the module. Including it again is a no-op
  # if a parent scope already did so.
  include(FetchContent)
  message(STATUS "Downloading gguflib")
  FetchContent_Declare(
    gguflib
    GIT_REPOSITORY https://github.com/antirez/gguf-tools/
    GIT_TAG 8fa6eb65236618e28fd7710a0fba565f7faa1848)
  FetchContent_MakeAvailable(gguflib)
  target_include_directories(mlx
                             PRIVATE $<BUILD_INTERFACE:${gguflib_SOURCE_DIR}>)
  add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c
                             ${gguflib_SOURCE_DIR}/gguflib.c)
  # This archive is linked into mlx, which may be built as a shared library.
  # Position-independent code lets that link succeed on platforms that require
  # PIC inside shared objects.
  set_target_properties(gguflib PROPERTIES POSITION_INDEPENDENT_CODE ON)
  target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:gguflib>)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_gguf.cpp)
endif()
|