mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
data/mlx/mlx/compile.cpp
ADDED
|
@@ -0,0 +1,1243 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
#include <cstdlib>
|
|
3
|
+
#include <map>
|
|
4
|
+
#include <sstream>
|
|
5
|
+
#include <unordered_map>
|
|
6
|
+
#include <unordered_set>
|
|
7
|
+
|
|
8
|
+
#include "mlx/allocator.h"
|
|
9
|
+
#include "mlx/backend/common/compiled.h"
|
|
10
|
+
#include "mlx/compile.h"
|
|
11
|
+
#include "mlx/compile_impl.h"
|
|
12
|
+
#include "mlx/fast_primitives.h"
|
|
13
|
+
#include "mlx/graph_utils.h"
|
|
14
|
+
#include "mlx/primitives.h"
|
|
15
|
+
#include "mlx/transforms.h"
|
|
16
|
+
#include "mlx/transforms_impl.h"
|
|
17
|
+
#include "mlx/utils.h"
|
|
18
|
+
|
|
19
|
+
namespace mlx::core {
|
|
20
|
+
|
|
21
|
+
// Fixed limits for the compile passes in this file.
// NOTE(review): the code that reads these limits is not visible in this chunk.
// From the names, they presumably cap a traversal depth and the number of
// arrays grouped together. Confirm against their uses before changing them.
constexpr int max_compile_depth{11};
constexpr int max_compile_arrays{24};
|
|
23
|
+
|
|
24
|
+
bool is_unary(const Primitive& p) {
|
|
25
|
+
return (
|
|
26
|
+
typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||
|
|
27
|
+
typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||
|
|
28
|
+
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
|
|
29
|
+
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
|
|
30
|
+
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
|
|
31
|
+
typeid(p) == typeid(Conjugate) || typeid(p) == typeid(Cosh) ||
|
|
32
|
+
typeid(p) == typeid(Remainder) || typeid(p) == typeid(Erf) ||
|
|
33
|
+
typeid(p) == typeid(ErfInv) || typeid(p) == typeid(Exp) ||
|
|
34
|
+
typeid(p) == typeid(Floor) || typeid(p) == typeid(Log) ||
|
|
35
|
+
typeid(p) == typeid(Log1p) || typeid(p) == typeid(LogicalNot) ||
|
|
36
|
+
typeid(p) == typeid(Negative) || typeid(p) == typeid(Round) ||
|
|
37
|
+
typeid(p) == typeid(Sigmoid) || typeid(p) == typeid(Sign) ||
|
|
38
|
+
typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) ||
|
|
39
|
+
typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
|
|
40
|
+
typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
|
|
41
|
+
typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) ||
|
|
42
|
+
typeid(p) == typeid(Imag) || typeid(p) == typeid(BitwiseInvert));
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
bool is_binary(const Primitive& p) {
|
|
46
|
+
return (
|
|
47
|
+
typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||
|
|
48
|
+
typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||
|
|
49
|
+
typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||
|
|
50
|
+
typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||
|
|
51
|
+
typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||
|
|
52
|
+
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
|
|
53
|
+
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
|
|
54
|
+
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
|
|
55
|
+
typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) ||
|
|
56
|
+
typeid(p) == typeid(ArcTan2));
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
bool is_ternary(const Primitive& p) {
|
|
60
|
+
return typeid(p) == typeid(Select);
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
bool is_broadcast(const Primitive& p) {
|
|
64
|
+
return typeid(p) == typeid(Broadcast);
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
bool is_noop(const Primitive& p) {
|
|
68
|
+
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
bool is_reduction(const Primitive& p) {
|
|
72
|
+
return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce);
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
bool is_fusable(const Primitive& p) {
|
|
76
|
+
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
// Builds a Compiled primitive from a fused sub-graph and derives
// `kernel_lib_`, a string that identifies the generated kernel. The string
// encodes:
//   - the tape: each node's name, dtype kind/size, whether it is an output,
//     its primitive name and the names of its inputs;
//   - per input: constant ("C"), scalar ("S") or vector ("V");
//   - the dtype kind/size of every non-constant input;
//   - a hash of the printed constant values.
Compiled::Compiled(
    Stream stream,
    std::vector<array> inputs,
    std::vector<array> outputs,
    std::vector<array> tape,
    std::unordered_set<uintptr_t> constant_ids)
    : Primitive(stream),
      inputs_(std::move(inputs)),
      outputs_(std::move(outputs)),
      tape_(std::move(tape)),
      constant_ids_(std::move(constant_ids)),
      is_constant_([this](size_t i) {
        return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
      }) {
  // Every constructor parameter has been moved into a member by the
  // initializer list above. The body must therefore read only the members
  // (inputs_, outputs_, tape_, constant_ids_): the moved-from parameters are
  // empty, so loops over them would silently do nothing.

  // Build the kernel name.
  NodeNamer namer;
  std::ostringstream os;
  std::ostringstream constant_hasher;

  std::unordered_set<uintptr_t> output_ids;
  for (auto& o : outputs_) {
    output_ids.insert(o.id());
  }

  // Fill the input names. This is not really necessary, I just like having A,
  // B, C, ... as the inputs.
  for (const auto& x : inputs_) {
    namer.get_name(x);
  }

  // The primitives describing the tape. For unary and binary primitives this
  // must be enough to describe the full computation.
  for (const auto& a : tape_) {
    // name and type of output
    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
    // whether or not it's an output
    if (output_ids.find(a.id()) != output_ids.end()) {
      os << "O";
    } else {
      os << "I";
    }
    // computation performed
    os << a.primitive().name();
    // name of inputs to the function
    for (auto& inp : a.inputs()) {
      os << namer.get_name(inp);
    }
  }
  os << "_";

  // Kind of each input: constant, scalar or vector.
  for (const auto& x : inputs_) {
    if (constant_ids_.find(x.id()) != constant_ids_.end()) {
      os << "C";
      print_constant(constant_hasher, x);
    } else {
      os << (is_scalar(x) ? "S" : "V");
    }
  }
  os << "_";

  // Dtype of every non-constant input. The tape only records output dtypes,
  // so without this two kernels differing only in input dtype (e.g. an AsType
  // from float16 vs. bfloat16) would share a name.
  for (const auto& x : inputs_) {
    if (constant_ids_.find(x.id()) != constant_ids_.end()) {
      continue;
    }
    os << kindof(x.dtype()) << x.itemsize();
  }
  os << "_" << std::hash<std::string>{}(constant_hasher.str());

  kernel_lib_ = os.str();
}
|
|
148
|
+
|
|
149
|
+
// Compiled primitives do not implement vector-Jacobian products; this always
// throws std::runtime_error.
std::vector<array> Compiled::vjp(
    const std::vector<array>&,
    const std::vector<array>&,
    const std::vector<int>&,
    const std::vector<array>&) {
  static constexpr const char* kError = "[Compiled] Cannot vjp primitive.";
  throw std::runtime_error(kError);
}
|
|
156
|
+
|
|
157
|
+
// Compiled primitives do not implement Jacobian-vector products; this always
// throws std::runtime_error.
std::vector<array> Compiled::jvp(
    const std::vector<array>&,
    const std::vector<array>&,
    const std::vector<int>&) {
  static constexpr const char* kError = "[Compiled] Cannot jvp primitive.";
  throw std::runtime_error(kError);
}
|
|
163
|
+
|
|
164
|
+
// Compiled primitives cannot be vectorized with vmap; this always throws
// std::runtime_error.
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
    const std::vector<array>&,
    const std::vector<int>&) {
  static constexpr const char* kError = "[Compiled] Cannot vmap primitive.";
  throw std::runtime_error(kError);
}
|
|
169
|
+
|
|
170
|
+
bool Compiled::is_equivalent(const Primitive& other) const {
|
|
171
|
+
const Compiled& a_other = static_cast<const Compiled&>(other);
|
|
172
|
+
return std::equal(
|
|
173
|
+
tape_.begin(),
|
|
174
|
+
tape_.end(),
|
|
175
|
+
a_other.tape_.begin(),
|
|
176
|
+
a_other.tape_.end(),
|
|
177
|
+
[](const array& a1, const array& a2) {
|
|
178
|
+
auto& p1 = a1.primitive();
|
|
179
|
+
auto& p2 = a2.primitive();
|
|
180
|
+
return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);
|
|
181
|
+
});
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
const char* Compiled::name() const {
|
|
185
|
+
if (name_.empty()) {
|
|
186
|
+
std::ostringstream os;
|
|
187
|
+
os << "Compiled";
|
|
188
|
+
for (auto& a : tape_) {
|
|
189
|
+
os << a.primitive().name();
|
|
190
|
+
}
|
|
191
|
+
name_ = os.str();
|
|
192
|
+
}
|
|
193
|
+
return name_.c_str();
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
|
|
197
|
+
size_t nd = 0;
|
|
198
|
+
for (auto& in : inputs) {
|
|
199
|
+
nd = std::max(nd, in.ndim());
|
|
200
|
+
}
|
|
201
|
+
Shape out_shape(nd, 0);
|
|
202
|
+
for (auto& in : inputs) {
|
|
203
|
+
auto dd = nd - in.ndim();
|
|
204
|
+
for (auto i = dd; i < nd; ++i) {
|
|
205
|
+
out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);
|
|
206
|
+
}
|
|
207
|
+
}
|
|
208
|
+
// All outputs have the same shape
|
|
209
|
+
return std::vector<Shape>(outputs_.size(), out_shape);
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
namespace detail {
|
|
213
|
+
|
|
214
|
+
// Returns the process-wide compile mode by mutable reference. On first use it
// is initialized from the environment: if MLX_DISABLE_COMPILE is set (any
// value), compilation is disabled; otherwise it is enabled.
CompileMode& compile_mode() {
  static CompileMode mode = (std::getenv("MLX_DISABLE_COMPILE") != nullptr)
      ? CompileMode::disabled
      : CompileMode::enabled;
  return mode;
}
|
|
225
|
+
|
|
226
|
+
// Helper like below but only merges the two provided arrays. If the src has
|
|
227
|
+
// siblings then these won't be merged to the dst.
|
|
228
|
+
void merge_one(array& dst, array& src, ParentsMap& parents_map) {
|
|
229
|
+
auto src_parents = parents_map.find(src.id());
|
|
230
|
+
if (src_parents == parents_map.end()) {
|
|
231
|
+
return;
|
|
232
|
+
}
|
|
233
|
+
auto& pairs = parents_map[dst.id()];
|
|
234
|
+
for (auto& parent : src_parents->second) {
|
|
235
|
+
parent.first.inputs()[parent.second] = dst;
|
|
236
|
+
pairs.push_back(parent);
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
// If src is a parent of dst, remove it from dst's parents
|
|
240
|
+
for (auto it = pairs.begin(); it != pairs.end();) {
|
|
241
|
+
if (it->first.id() == src.id()) {
|
|
242
|
+
it = pairs.erase(it);
|
|
243
|
+
} else {
|
|
244
|
+
it++;
|
|
245
|
+
}
|
|
246
|
+
}
|
|
247
|
+
// Remove the source from the map to avoid fusing with it again
|
|
248
|
+
parents_map.erase(src_parents);
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
// Helper that merges two arrays in the graph by setting the parents of the
|
|
252
|
+
// source to point to the destination. The arrays are assumed to be coming from
|
|
253
|
+
// equivalent primitives so their siblings are merged as well.
|
|
254
|
+
void merge(array& dst, array& src, ParentsMap& parents_map) {
|
|
255
|
+
// Canonicalize the order of the primitives outputs
|
|
256
|
+
auto sources = src.outputs();
|
|
257
|
+
auto dests = dst.outputs();
|
|
258
|
+
// For each src parent, point it to the corresponding dst
|
|
259
|
+
for (int i = 0; i < sources.size(); ++i) {
|
|
260
|
+
merge_one(dests[i], sources[i], parents_map);
|
|
261
|
+
}
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
// Any parent in the divider will continue to refer to `x` but any parent not
|
|
265
|
+
// in the divider will refer to a copy of the operation.
|
|
266
|
+
array split_one(
|
|
267
|
+
const array& x,
|
|
268
|
+
ParentsMap& parents_map,
|
|
269
|
+
const std::unordered_set<uintptr_t>& divider) {
|
|
270
|
+
array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());
|
|
271
|
+
|
|
272
|
+
auto& x_parents = parents_map[x.id()];
|
|
273
|
+
auto& y_parents = parents_map[y.id()];
|
|
274
|
+
|
|
275
|
+
for (auto it = x_parents.begin(); it != x_parents.end();) {
|
|
276
|
+
if (divider.find(it->first.id()) != divider.end()) {
|
|
277
|
+
it->first.inputs()[it->second] = y;
|
|
278
|
+
y_parents.emplace_back(std::move(*it));
|
|
279
|
+
it = x_parents.erase(it);
|
|
280
|
+
} else {
|
|
281
|
+
it++;
|
|
282
|
+
}
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
return y;
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
// Returns the address of the plain function pointer stored in `fun`, or 0 when
// `fun` is empty or holds any other callable (e.g. a lambda object or a
// functor).
template <typename T, typename... U>
std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
  using FunType = T (*)(U...);
  if (const FunType* stored = fun.template target<FunType>()) {
    return reinterpret_cast<std::uintptr_t>(*stored);
  }
  return 0;
}
|
|
297
|
+
|
|
298
|
+
class CompilerCache {
|
|
299
|
+
public:
|
|
300
|
+
struct CacheEntry {
|
|
301
|
+
CacheEntry(Stream stream, bool shapeless)
|
|
302
|
+
: stream(stream), shapeless(shapeless) {};
|
|
303
|
+
Stream stream;
|
|
304
|
+
bool shapeless;
|
|
305
|
+
std::vector<array> inputs;
|
|
306
|
+
std::vector<array> outputs;
|
|
307
|
+
std::vector<array> tape;
|
|
308
|
+
bool empty{true};
|
|
309
|
+
std::vector<uint64_t> constants;
|
|
310
|
+
std::shared_ptr<void> extra;
|
|
311
|
+
};
|
|
312
|
+
|
|
313
|
+
// Returns a reference to a CacheEntry which can be updated
|
|
314
|
+
// by the caller to avoid copying large tapes / inputs / outputs
|
|
315
|
+
CacheEntry& find(
|
|
316
|
+
std::uintptr_t fun_id,
|
|
317
|
+
const std::vector<array>& inputs,
|
|
318
|
+
bool shapeless,
|
|
319
|
+
const std::vector<uint64_t>& constants) {
|
|
320
|
+
// Find the cache entries for |fun_id|.
|
|
321
|
+
std::vector<CacheEntry>& entries = cache_[fun_id];
|
|
322
|
+
|
|
323
|
+
// Compare if 2 arrays have same shape and dtype.
|
|
324
|
+
auto has_same_shape_and_dtype = [shapeless](
|
|
325
|
+
const std::vector<array>& in1,
|
|
326
|
+
const std::vector<array>& in2) {
|
|
327
|
+
if (in1.size() != in2.size()) {
|
|
328
|
+
return false;
|
|
329
|
+
}
|
|
330
|
+
for (size_t i = 0; i < in1.size(); ++i) {
|
|
331
|
+
if (in1[i].ndim() != in2[i].ndim()) {
|
|
332
|
+
return false;
|
|
333
|
+
}
|
|
334
|
+
if (!shapeless && in1[i].shape() != in2[i].shape()) {
|
|
335
|
+
return false;
|
|
336
|
+
}
|
|
337
|
+
if (in1[i].dtype() != in2[i].dtype()) {
|
|
338
|
+
return false;
|
|
339
|
+
}
|
|
340
|
+
}
|
|
341
|
+
return true;
|
|
342
|
+
};
|
|
343
|
+
// Loop over entries and check:
|
|
344
|
+
// - Default stream and device match the entry's default stream
|
|
345
|
+
// - Inputs match i.e. shapes and types must be equal.
|
|
346
|
+
auto stream = default_stream(default_device());
|
|
347
|
+
for (CacheEntry& entry : entries) {
|
|
348
|
+
// Check that the default stream and device match
|
|
349
|
+
if (entry.stream != stream) {
|
|
350
|
+
continue;
|
|
351
|
+
}
|
|
352
|
+
if (entry.shapeless != shapeless) {
|
|
353
|
+
continue;
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
// Check the inputs match and return if so
|
|
357
|
+
if (has_same_shape_and_dtype(inputs, entry.inputs) &&
|
|
358
|
+
constants == entry.constants) {
|
|
359
|
+
return entry;
|
|
360
|
+
}
|
|
361
|
+
}
|
|
362
|
+
// Otherwise append a new cache entry
|
|
363
|
+
entries.push_back(CacheEntry{stream, shapeless});
|
|
364
|
+
return entries.back();
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
void erase(std::uintptr_t fun_id) {
|
|
368
|
+
cache_.erase(fun_id);
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
void clear() {
|
|
372
|
+
cache_.clear();
|
|
373
|
+
}
|
|
374
|
+
|
|
375
|
+
private:
|
|
376
|
+
CompilerCache() {
|
|
377
|
+
// Make sure the allocator is fully
|
|
378
|
+
// initialized before the compiler cache
|
|
379
|
+
allocator::allocator();
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
friend CompilerCache& compiler_cache();
|
|
383
|
+
std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;
|
|
384
|
+
};
|
|
385
|
+
|
|
386
|
+
// Returns the process-wide CompilerCache, constructed on first use. This
// function is the only friend of CompilerCache's private constructor.
CompilerCache& compiler_cache() {
  static CompilerCache instance;
  return instance;
}
|
|
390
|
+
|
|
391
|
+
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
|
|
392
|
+
compile_trace(
|
|
393
|
+
const ArrayFnWithExtra& fun,
|
|
394
|
+
const std::vector<array>& inputs,
|
|
395
|
+
bool shapeless) {
|
|
396
|
+
// Set the global tracing flag.
|
|
397
|
+
detail::InTracing in_tracing{shapeless};
|
|
398
|
+
|
|
399
|
+
// Run the function on placeholder inputs
|
|
400
|
+
// to get compute graph
|
|
401
|
+
std::vector<array> tracer_inputs;
|
|
402
|
+
for (int i = 0; i < inputs.size(); ++i) {
|
|
403
|
+
array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {});
|
|
404
|
+
in.set_tracer(true);
|
|
405
|
+
tracer_inputs.push_back(std::move(in));
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
auto output = fun(tracer_inputs);
|
|
409
|
+
return {tracer_inputs, output.first, output.second};
|
|
410
|
+
}
|
|
411
|
+
|
|
412
|
+
// Traverses the graph to build a tape and a map of array ids to their parents
|
|
413
|
+
std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|
414
|
+
const std::vector<array>& inputs,
|
|
415
|
+
std::vector<array>& outputs,
|
|
416
|
+
const std::vector<array>& original_inputs) {
|
|
417
|
+
std::vector<array> tape;
|
|
418
|
+
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
|
419
|
+
parents_map;
|
|
420
|
+
{
|
|
421
|
+
std::function<void(const array&)> recurse;
|
|
422
|
+
std::unordered_set<std::uintptr_t> input_set;
|
|
423
|
+
std::unordered_set<std::uintptr_t> original_input_set;
|
|
424
|
+
for (int i = 0; i < inputs.size(); ++i) {
|
|
425
|
+
input_set.insert(inputs[i].id());
|
|
426
|
+
original_input_set.insert(original_inputs[i].id());
|
|
427
|
+
}
|
|
428
|
+
|
|
429
|
+
// DFS the graph to build the tape, and log parents and scalars
|
|
430
|
+
std::unordered_set<std::uintptr_t> cache;
|
|
431
|
+
recurse = [&](const array& a) {
|
|
432
|
+
auto id = a.id();
|
|
433
|
+
if (original_input_set.find(id) != original_input_set.end()) {
|
|
434
|
+
throw std::invalid_argument(
|
|
435
|
+
"[compile] Attempting to compile a function with uncaptured inputs is not allowed.");
|
|
436
|
+
}
|
|
437
|
+
if (cache.find(id) != cache.end()) {
|
|
438
|
+
return;
|
|
439
|
+
}
|
|
440
|
+
for (int i = 0; i < a.inputs().size(); i++) {
|
|
441
|
+
auto& in = a.inputs()[i];
|
|
442
|
+
parents_map[in.id()].push_back({a, i});
|
|
443
|
+
for (auto& s : a.siblings()) {
|
|
444
|
+
parents_map[in.id()].push_back({s, i});
|
|
445
|
+
}
|
|
446
|
+
// Don't recurse on inputs (but add them to the tape for the purpose
|
|
447
|
+
// of future optimizations)
|
|
448
|
+
if (input_set.find(a.id()) == input_set.end()) {
|
|
449
|
+
recurse(in);
|
|
450
|
+
}
|
|
451
|
+
}
|
|
452
|
+
cache.insert(id);
|
|
453
|
+
for (auto& s : a.siblings()) {
|
|
454
|
+
cache.insert(s.id());
|
|
455
|
+
}
|
|
456
|
+
tape.push_back(a);
|
|
457
|
+
};
|
|
458
|
+
for (auto& a : outputs) {
|
|
459
|
+
recurse(a);
|
|
460
|
+
}
|
|
461
|
+
}
|
|
462
|
+
|
|
463
|
+
// Deep copy the tape and parents map while preserving inputs and outputs
|
|
464
|
+
std::vector<array> new_tape;
|
|
465
|
+
std::unordered_set<uintptr_t> io_set;
|
|
466
|
+
std::unordered_map<uintptr_t, array> old_to_new;
|
|
467
|
+
for (auto& o : outputs) {
|
|
468
|
+
old_to_new.insert({o.id(), o});
|
|
469
|
+
io_set.insert(o.id());
|
|
470
|
+
for (auto& s : o.siblings()) {
|
|
471
|
+
old_to_new.insert({s.id(), s});
|
|
472
|
+
io_set.insert(s.id());
|
|
473
|
+
}
|
|
474
|
+
}
|
|
475
|
+
for (auto& i : inputs) {
|
|
476
|
+
io_set.insert(i.id());
|
|
477
|
+
old_to_new.insert({i.id(), i});
|
|
478
|
+
}
|
|
479
|
+
|
|
480
|
+
new_tape.reserve(tape.size());
|
|
481
|
+
for (auto& arr : tape) {
|
|
482
|
+
if (!arr.has_primitive() || (io_set.find(arr.id()) != io_set.end())) {
|
|
483
|
+
old_to_new.insert({arr.id(), arr});
|
|
484
|
+
new_tape.push_back(arr);
|
|
485
|
+
continue;
|
|
486
|
+
}
|
|
487
|
+
std::vector<array> inputs;
|
|
488
|
+
inputs.reserve(arr.inputs().size());
|
|
489
|
+
for (auto& i : arr.inputs()) {
|
|
490
|
+
inputs.push_back(old_to_new.find(i.id())->second);
|
|
491
|
+
}
|
|
492
|
+
if (arr.siblings().size() > 0) {
|
|
493
|
+
std::vector<Dtype> types;
|
|
494
|
+
std::vector<Shape> shapes;
|
|
495
|
+
auto out = arr.outputs();
|
|
496
|
+
for (auto& o : out) {
|
|
497
|
+
types.push_back(o.dtype());
|
|
498
|
+
shapes.push_back(o.shape());
|
|
499
|
+
}
|
|
500
|
+
auto as = array::make_arrays(
|
|
501
|
+
std::move(shapes), types, arr.primitive_ptr(), std::move(inputs));
|
|
502
|
+
for (int i = 0; i < out.size(); ++i) {
|
|
503
|
+
old_to_new.insert({out[i].id(), as[i]});
|
|
504
|
+
}
|
|
505
|
+
new_tape.push_back(as[arr.sibling_position()]);
|
|
506
|
+
} else {
|
|
507
|
+
auto a = array(
|
|
508
|
+
arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs));
|
|
509
|
+
old_to_new.insert({arr.id(), a});
|
|
510
|
+
new_tape.push_back(a);
|
|
511
|
+
}
|
|
512
|
+
}
|
|
513
|
+
io_set.clear();
|
|
514
|
+
for (auto& o : outputs) {
|
|
515
|
+
if (!(io_set.insert(o.id()).second)) {
|
|
516
|
+
continue;
|
|
517
|
+
}
|
|
518
|
+
for (auto& i : o.inputs()) {
|
|
519
|
+
i = old_to_new.find(i.id())->second;
|
|
520
|
+
}
|
|
521
|
+
for (auto& s : o.siblings()) {
|
|
522
|
+
io_set.insert(s.id());
|
|
523
|
+
for (auto& i : s.inputs()) {
|
|
524
|
+
i = old_to_new.find(i.id())->second;
|
|
525
|
+
}
|
|
526
|
+
}
|
|
527
|
+
}
|
|
528
|
+
tape = std::move(new_tape);
|
|
529
|
+
|
|
530
|
+
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
|
531
|
+
new_parents_map;
|
|
532
|
+
for (auto& [id, vec] : parents_map) {
|
|
533
|
+
for (auto& [a, _] : vec) {
|
|
534
|
+
a = old_to_new.find(a.id())->second;
|
|
535
|
+
}
|
|
536
|
+
new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec);
|
|
537
|
+
}
|
|
538
|
+
parents_map = std::move(new_parents_map);
|
|
539
|
+
return {tape, parents_map};
|
|
540
|
+
}
|
|
541
|
+
|
|
542
|
+
// SplitMix64 mixing step: add the golden-ratio increment to `x`, then scramble
// the bits with two xor-shift/multiply rounds and a final xor-shift.
static inline uint64_t splitmix64(uint64_t x) noexcept {
  uint64_t z = x + 0x9e3779b97f4a7c15ull;
  z ^= z >> 30;
  z *= 0xbf58476d1ce4e5b9ull;
  z ^= z >> 27;
  z *= 0x94d049bb133111ebull;
  z ^= z >> 31;
  return z;
}
|
|
548
|
+
|
|
549
|
+
struct VecU64Hash {
|
|
550
|
+
size_t operator()(const std::vector<uint64_t>& s) const noexcept {
|
|
551
|
+
uint64_t h =
|
|
552
|
+
0x243f6a8885a308d3ull ^ (uint64_t)s.size() * 0x9e3779b97f4a7c15ull;
|
|
553
|
+
for (uint64_t x : s) {
|
|
554
|
+
h = splitmix64(x ^ splitmix64(h + 0x9e3779b97f4a7c15ull));
|
|
555
|
+
}
|
|
556
|
+
return (size_t)h;
|
|
557
|
+
}
|
|
558
|
+
};
|
|
559
|
+
|
|
560
|
+
// Simplify the tape. Note, this function modifies in-place both the tape,
|
|
561
|
+
// the parents map to remove orphaned arrays, and potentially the outputs
|
|
562
|
+
void compile_simplify(
|
|
563
|
+
std::vector<array>& tape,
|
|
564
|
+
ParentsMap& parents_map,
|
|
565
|
+
std::vector<array>& outputs,
|
|
566
|
+
int passes) {
|
|
567
|
+
// Helpers to identify identical scalars
|
|
568
|
+
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
|
569
|
+
auto is_scalar = [](const array& a) {
|
|
570
|
+
// Condition for when it's safe to read an array
|
|
571
|
+
return a.is_available() && a.ndim() == 0;
|
|
572
|
+
};
|
|
573
|
+
auto get_scalar_rep = [](const array& a) {
|
|
574
|
+
uint64_t v = 0;
|
|
575
|
+
switch (a.dtype().size()) {
|
|
576
|
+
case 1:
|
|
577
|
+
v = *a.data<uint8_t>();
|
|
578
|
+
break;
|
|
579
|
+
case 2:
|
|
580
|
+
v = *a.data<uint16_t>();
|
|
581
|
+
break;
|
|
582
|
+
case 4:
|
|
583
|
+
v = *a.data<uint32_t>();
|
|
584
|
+
break;
|
|
585
|
+
case 8:
|
|
586
|
+
v = *a.data<uint64_t>();
|
|
587
|
+
break;
|
|
588
|
+
}
|
|
589
|
+
return std::make_pair(v, a.dtype().val());
|
|
590
|
+
};
|
|
591
|
+
|
|
592
|
+
for (auto& a : tape) {
|
|
593
|
+
if (is_scalar(a)) {
|
|
594
|
+
scalars.insert({get_scalar_rep(a), a});
|
|
595
|
+
}
|
|
596
|
+
}
|
|
597
|
+
|
|
598
|
+
// Depth-1 array equivalence check.
|
|
599
|
+
auto array_equivalent = [](const array& a, const array& b) {
|
|
600
|
+
if (!a.has_primitive() || !b.has_primitive()) {
|
|
601
|
+
return false;
|
|
602
|
+
}
|
|
603
|
+
if (a.primitive_id() == b.primitive_id()) {
|
|
604
|
+
return false;
|
|
605
|
+
}
|
|
606
|
+
const auto& pa = a.primitive();
|
|
607
|
+
const auto& pb = b.primitive();
|
|
608
|
+
if (typeid(pa) != typeid(pb)) {
|
|
609
|
+
return false;
|
|
610
|
+
}
|
|
611
|
+
|
|
612
|
+
if (a.inputs().size() != b.inputs().size()) {
|
|
613
|
+
return false;
|
|
614
|
+
}
|
|
615
|
+
|
|
616
|
+
for (int i = 0; i < a.inputs().size(); i++) {
|
|
617
|
+
if (a.inputs()[i].id() != b.inputs()[i].id()) {
|
|
618
|
+
return false;
|
|
619
|
+
}
|
|
620
|
+
}
|
|
621
|
+
|
|
622
|
+
return pa.is_equivalent(pb);
|
|
623
|
+
};
|
|
624
|
+
|
|
625
|
+
// Merge scalars
|
|
626
|
+
std::vector<array> new_tape;
|
|
627
|
+
for (auto& arr : tape) {
|
|
628
|
+
// Check if we can merge scalars
|
|
629
|
+
if (is_scalar(arr)) {
|
|
630
|
+
auto scalar = scalars.find(get_scalar_rep(arr));
|
|
631
|
+
if (scalar->second.id() != arr.id()) {
|
|
632
|
+
merge(scalar->second, arr, parents_map);
|
|
633
|
+
// Don't keep orphaned scalars in the tape
|
|
634
|
+
continue;
|
|
635
|
+
}
|
|
636
|
+
}
|
|
637
|
+
new_tape.push_back(std::move(arr));
|
|
638
|
+
}
|
|
639
|
+
tape = std::move(new_tape);
|
|
640
|
+
|
|
641
|
+
// Remove no-ops
|
|
642
|
+
{
|
|
643
|
+
std::unordered_map<uintptr_t, array> output_map;
|
|
644
|
+
for (auto& o : outputs) {
|
|
645
|
+
output_map.insert({o.id(), o});
|
|
646
|
+
}
|
|
647
|
+
for (auto& arr : tape) {
|
|
648
|
+
if (!arr.has_primitive() || !is_noop(arr.primitive())) {
|
|
649
|
+
new_tape.push_back(std::move(arr));
|
|
650
|
+
continue;
|
|
651
|
+
}
|
|
652
|
+
merge_one(arr.inputs()[0], arr, parents_map);
|
|
653
|
+
if (auto it = output_map.find(arr.id()); it != output_map.end()) {
|
|
654
|
+
it->second = arr.inputs()[0];
|
|
655
|
+
}
|
|
656
|
+
}
|
|
657
|
+
tape = std::move(new_tape);
|
|
658
|
+
for (auto& o : outputs) {
|
|
659
|
+
o = output_map.at(o.id());
|
|
660
|
+
}
|
|
661
|
+
}
|
|
662
|
+
|
|
663
|
+
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
|
|
664
|
+
for (uint32_t i = 0; i < tape.size(); ++i) {
|
|
665
|
+
tape_order.insert({tape[i].id(), i});
|
|
666
|
+
}
|
|
667
|
+
|
|
668
|
+
std::unordered_set<uintptr_t> output_set;
|
|
669
|
+
for (auto& o : outputs) {
|
|
670
|
+
output_set.insert(o.id());
|
|
671
|
+
}
|
|
672
|
+
|
|
673
|
+
// Multi-pass merge only keeping non-orphaned arrays in the tape
|
|
674
|
+
for (int pass = 0; pass < passes; ++pass) {
|
|
675
|
+
for (auto& arr : tape) {
|
|
676
|
+
// Helper to check if we can merge the parents of the
|
|
677
|
+
// given array
|
|
678
|
+
auto maybe_merge_parents = [&](auto& a) {
|
|
679
|
+
auto parents = parents_map.find(a.id());
|
|
680
|
+
if (parents != parents_map.end()) {
|
|
681
|
+
auto N = parents->second.size();
|
|
682
|
+
std::vector<bool> mask(N, false);
|
|
683
|
+
|
|
684
|
+
auto try_merge = [&](int dst_idx, int src_idx) {
|
|
685
|
+
if (tape_order[parents->second[src_idx].first.id()] <
|
|
686
|
+
tape_order[parents->second[dst_idx].first.id()]) {
|
|
687
|
+
std::swap(src_idx, dst_idx);
|
|
688
|
+
}
|
|
689
|
+
auto& src = parents->second[src_idx].first;
|
|
690
|
+
auto& dst = parents->second[dst_idx].first;
|
|
691
|
+
if (src.id() != dst.id() && array_equivalent(src, dst) &&
|
|
692
|
+
output_set.find(src.id()) == output_set.end()) {
|
|
693
|
+
merge(dst, src, parents_map);
|
|
694
|
+
mask[src_idx] = true;
|
|
695
|
+
}
|
|
696
|
+
};
|
|
697
|
+
|
|
698
|
+
if (N > 100) {
|
|
699
|
+
std::unordered_map<
|
|
700
|
+
std::vector<uint64_t>,
|
|
701
|
+
std::vector<int>,
|
|
702
|
+
VecU64Hash>
|
|
703
|
+
dst_map;
|
|
704
|
+
// Find possibly mergeable groups
|
|
705
|
+
for (int i = 0; i < N; i++) {
|
|
706
|
+
// Make the hash key
|
|
707
|
+
std::vector<uint64_t> key;
|
|
708
|
+
auto& curr = parents->second[i].first;
|
|
709
|
+
key.reserve(curr.inputs().size() + 2);
|
|
710
|
+
for (auto& in : curr.inputs()) {
|
|
711
|
+
key.push_back(in.id());
|
|
712
|
+
}
|
|
713
|
+
auto& p = curr.primitive();
|
|
714
|
+
key.push_back(curr.inputs().size());
|
|
715
|
+
key.push_back(typeid(p).hash_code());
|
|
716
|
+
auto it = dst_map.find(key);
|
|
717
|
+
if (it == dst_map.end()) {
|
|
718
|
+
bool _;
|
|
719
|
+
std::tie(it, _) = dst_map.insert({key, std::vector<int>{}});
|
|
720
|
+
}
|
|
721
|
+
it->second.push_back(i);
|
|
722
|
+
}
|
|
723
|
+
for (auto& [_, group] : dst_map) {
|
|
724
|
+
for (int i = 0; i < group.size(); ++i) {
|
|
725
|
+
if (mask[group[i]]) {
|
|
726
|
+
continue;
|
|
727
|
+
}
|
|
728
|
+
for (int j = i + 1; j < group.size(); ++j) {
|
|
729
|
+
if (mask[group[j]]) {
|
|
730
|
+
continue;
|
|
731
|
+
}
|
|
732
|
+
try_merge(group[i], group[j]);
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
}
|
|
736
|
+
} else {
|
|
737
|
+
for (int i = 0; i < N; ++i) {
|
|
738
|
+
if (mask[i]) {
|
|
739
|
+
continue;
|
|
740
|
+
}
|
|
741
|
+
for (int j = i + 1; j < N; ++j) {
|
|
742
|
+
if (mask[j]) {
|
|
743
|
+
continue;
|
|
744
|
+
}
|
|
745
|
+
try_merge(i, j);
|
|
746
|
+
}
|
|
747
|
+
}
|
|
748
|
+
}
|
|
749
|
+
|
|
750
|
+
// Erase orphaned parents so we don't keep fusing with them
|
|
751
|
+
for (int i = N - 1; i >= 0; --i) {
|
|
752
|
+
if (mask[i]) {
|
|
753
|
+
parents->second.erase(parents->second.begin() + i);
|
|
754
|
+
}
|
|
755
|
+
}
|
|
756
|
+
return false;
|
|
757
|
+
} else {
|
|
758
|
+
return output_set.find(a.id()) == output_set.end();
|
|
759
|
+
}
|
|
760
|
+
};
|
|
761
|
+
bool discard = maybe_merge_parents(arr);
|
|
762
|
+
for (auto& s : arr.siblings()) {
|
|
763
|
+
discard &= maybe_merge_parents(s);
|
|
764
|
+
}
|
|
765
|
+
// If an array and its siblings have no parents, and none of them are
|
|
766
|
+
// outputs, it is safe to remove it from the tape
|
|
767
|
+
if (!discard) {
|
|
768
|
+
new_tape.push_back(std::move(arr));
|
|
769
|
+
}
|
|
770
|
+
}
|
|
771
|
+
tape = std::move(new_tape);
|
|
772
|
+
}
|
|
773
|
+
}
|
|
774
|
+
|
|
775
|
+
// Extract sub-graphs of the graph that can be compiled
|
|
776
|
+
// and replace them with a Compiled Primitive.
|
|
777
|
+
void compile_fuse(
|
|
778
|
+
std::vector<array>& tape,
|
|
779
|
+
ParentsMap& parents_map,
|
|
780
|
+
const std::vector<array>& inputs,
|
|
781
|
+
std::vector<array>& outputs) {
|
|
782
|
+
// Track outputs to replace with new compiled outputs
|
|
783
|
+
std::unordered_map<uintptr_t, array> output_map;
|
|
784
|
+
for (auto& o : outputs) {
|
|
785
|
+
output_map.insert({o.id(), o});
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
// Set of inputs to distinguish constants
|
|
789
|
+
std::unordered_set<uintptr_t> input_ids;
|
|
790
|
+
for (auto& in : inputs) {
|
|
791
|
+
input_ids.insert(in.id());
|
|
792
|
+
}
|
|
793
|
+
|
|
794
|
+
// Go through the tape in reverse order and check for fusable sub-graphs
|
|
795
|
+
std::vector<array> new_tape;
|
|
796
|
+
std::unordered_set<uintptr_t> global_cache;
|
|
797
|
+
for (int i = tape.size() - 1; i >= 0; --i) {
|
|
798
|
+
auto& arr = tape[i];
|
|
799
|
+
|
|
800
|
+
// Already compiled
|
|
801
|
+
if (global_cache.find(arr.id()) != global_cache.end()) {
|
|
802
|
+
continue;
|
|
803
|
+
}
|
|
804
|
+
|
|
805
|
+
// Two pass recursion:
|
|
806
|
+
// First pass:
|
|
807
|
+
// - Collect all the primitives which we can fuse with
|
|
808
|
+
// - Keeps a cache of fusable primitives which may be added out of
|
|
809
|
+
// DAG order. We have to determine if all of a fused primitive's
|
|
810
|
+
// outputs are also in the fused section, and this may not be the
|
|
811
|
+
// case the first time we visit it.
|
|
812
|
+
// Second pass:
|
|
813
|
+
// - Collect inputs to the new compiled primitive
|
|
814
|
+
// - Add fusable primitives to a tape in the correct order
|
|
815
|
+
|
|
816
|
+
std::function<void(const array&, int, const Stream&, const Shape&)> recurse;
|
|
817
|
+
std::unordered_set<uintptr_t> cache;
|
|
818
|
+
std::unordered_set<uintptr_t> input_set;
|
|
819
|
+
recurse = [&](const array& a,
|
|
820
|
+
int depth,
|
|
821
|
+
const Stream& s,
|
|
822
|
+
const Shape& shape) {
|
|
823
|
+
if (cache.find(a.id()) != cache.end()) {
|
|
824
|
+
return;
|
|
825
|
+
}
|
|
826
|
+
|
|
827
|
+
// Stop fusing if:
|
|
828
|
+
// - Depth limit exceeded
|
|
829
|
+
// - Constant input
|
|
830
|
+
// - Stream mismatch
|
|
831
|
+
// - Non fusable primitive
|
|
832
|
+
// - Is global output but has a different shape
|
|
833
|
+
if (depth >= max_compile_depth || !a.has_primitive() ||
|
|
834
|
+
a.primitive().stream() != s || !is_fusable(a.primitive()) ||
|
|
835
|
+
(output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
|
|
836
|
+
// Possible input
|
|
837
|
+
input_set.insert(a.id());
|
|
838
|
+
return;
|
|
839
|
+
}
|
|
840
|
+
|
|
841
|
+
bool all_parents_in = true;
|
|
842
|
+
if (depth > 0) {
|
|
843
|
+
// Guaranteed to have a parent since nested in the
|
|
844
|
+
// recursion.
|
|
845
|
+
auto& parents = parents_map.at(a.id());
|
|
846
|
+
for (auto& [p, idx] : parents) {
|
|
847
|
+
auto in_cache = cache.find(p.id()) != cache.end();
|
|
848
|
+
if (!in_cache) {
|
|
849
|
+
all_parents_in = false;
|
|
850
|
+
break;
|
|
851
|
+
}
|
|
852
|
+
}
|
|
853
|
+
}
|
|
854
|
+
|
|
855
|
+
// Arrays with a mix of parents outside the compilable section
|
|
856
|
+
// are not fusable except for broadcast which we can split to avoid
|
|
857
|
+
// stopping fusion
|
|
858
|
+
if (!all_parents_in) {
|
|
859
|
+
if (a.has_primitive() && is_broadcast(a.primitive())) {
|
|
860
|
+
array b = split_one(a, parents_map, cache);
|
|
861
|
+
recurse(b, depth, s, shape);
|
|
862
|
+
} else {
|
|
863
|
+
// Possible input
|
|
864
|
+
input_set.insert(a.id());
|
|
865
|
+
}
|
|
866
|
+
return;
|
|
867
|
+
}
|
|
868
|
+
|
|
869
|
+
if (output_map.find(a.id()) != output_map.end()) {
|
|
870
|
+
input_set.insert(a.id());
|
|
871
|
+
} else {
|
|
872
|
+
// Not an input anymore since fusing it
|
|
873
|
+
input_set.erase(a.id());
|
|
874
|
+
}
|
|
875
|
+
if (input_set.size() >= max_compile_arrays) {
|
|
876
|
+
return;
|
|
877
|
+
}
|
|
878
|
+
cache.insert({a.id()});
|
|
879
|
+
|
|
880
|
+
for (auto& in : a.inputs()) {
|
|
881
|
+
recurse(in, depth + 1, s, shape);
|
|
882
|
+
}
|
|
883
|
+
};
|
|
884
|
+
|
|
885
|
+
// This will be the result of the fused operation so it needs
|
|
886
|
+
// a) to not be already computed ie have a primitive
|
|
887
|
+
// b) that primitive to not be a broadcast since it will unnecessarily
|
|
888
|
+
// cast to a contiguous array potentially blowing up memory
|
|
889
|
+
if (arr.has_primitive() && !is_broadcast(arr.primitive())) {
|
|
890
|
+
Stream s = arr.primitive().stream();
|
|
891
|
+
recurse(arr, 0, s, arr.shape());
|
|
892
|
+
}
|
|
893
|
+
|
|
894
|
+
// Not worth fusing a single primitive
|
|
895
|
+
if (cache.size() <= 1) {
|
|
896
|
+
new_tape.push_back(arr);
|
|
897
|
+
continue;
|
|
898
|
+
}
|
|
899
|
+
|
|
900
|
+
// Recurse a second time to build the tape in the right
|
|
901
|
+
// order and collect the inputs
|
|
902
|
+
input_set.clear();
|
|
903
|
+
std::vector<array> inputs;
|
|
904
|
+
std::vector<array> fused_tape;
|
|
905
|
+
std::unordered_set<uintptr_t> tape_set;
|
|
906
|
+
std::function<void(const array&)> recurse_tape;
|
|
907
|
+
recurse_tape = [&](const array& a) {
|
|
908
|
+
if (cache.find(a.id()) == cache.end()) {
|
|
909
|
+
if (input_set.find(a.id()) == input_set.end()) {
|
|
910
|
+
input_set.insert(a.id());
|
|
911
|
+
inputs.push_back(a);
|
|
912
|
+
}
|
|
913
|
+
return;
|
|
914
|
+
}
|
|
915
|
+
if (tape_set.find(a.id()) != tape_set.end()) {
|
|
916
|
+
return;
|
|
917
|
+
}
|
|
918
|
+
tape_set.insert(a.id());
|
|
919
|
+
for (auto& in : a.inputs()) {
|
|
920
|
+
recurse_tape(in);
|
|
921
|
+
}
|
|
922
|
+
fused_tape.push_back(a);
|
|
923
|
+
};
|
|
924
|
+
recurse_tape(arr);
|
|
925
|
+
|
|
926
|
+
std::vector<array> old_outputs;
|
|
927
|
+
// Add to global cache and add any global outputs to outputs
|
|
928
|
+
// of new primitive
|
|
929
|
+
for (int j = 0; j < fused_tape.size() - 1; ++j) {
|
|
930
|
+
auto& f = fused_tape[j];
|
|
931
|
+
if (output_map.find(f.id()) != output_map.end()) {
|
|
932
|
+
old_outputs.push_back(f);
|
|
933
|
+
// Parents are now siblings, update the parent map
|
|
934
|
+
auto& pairs = parents_map[f.id()];
|
|
935
|
+
pairs.erase(
|
|
936
|
+
std::remove_if(
|
|
937
|
+
pairs.begin(),
|
|
938
|
+
pairs.end(),
|
|
939
|
+
[&](auto& p) {
|
|
940
|
+
return cache.find(p.first.id()) != cache.end();
|
|
941
|
+
}),
|
|
942
|
+
pairs.end());
|
|
943
|
+
} else {
|
|
944
|
+
// Remove inner fused arrays parents from the parents map
|
|
945
|
+
// to keep the parents map in a valid state
|
|
946
|
+
parents_map.erase(f.id());
|
|
947
|
+
}
|
|
948
|
+
global_cache.insert({f.id()});
|
|
949
|
+
}
|
|
950
|
+
old_outputs.push_back(arr);
|
|
951
|
+
|
|
952
|
+
std::vector<Shape> shapes;
|
|
953
|
+
std::vector<Dtype> types;
|
|
954
|
+
for (auto& o : old_outputs) {
|
|
955
|
+
if (o.shape() != old_outputs.back().shape()) {
|
|
956
|
+
throw std::runtime_error(
|
|
957
|
+
"[compile] Compilation failed. Tried to fuse operations with different output shapes");
|
|
958
|
+
}
|
|
959
|
+
shapes.push_back(o.shape());
|
|
960
|
+
types.push_back(o.dtype());
|
|
961
|
+
}
|
|
962
|
+
std::unordered_set<uintptr_t> constant_ids;
|
|
963
|
+
for (auto& in : inputs) {
|
|
964
|
+
// Scalar constant
|
|
965
|
+
if (in.size() == 1 && !in.has_primitive() &&
|
|
966
|
+
input_ids.find(in.id()) == input_ids.end()) {
|
|
967
|
+
constant_ids.insert(in.id());
|
|
968
|
+
}
|
|
969
|
+
}
|
|
970
|
+
auto compiled_outputs = array::make_arrays(
|
|
971
|
+
std::move(shapes),
|
|
972
|
+
types,
|
|
973
|
+
std::make_shared<Compiled>(
|
|
974
|
+
old_outputs.back().primitive().stream(),
|
|
975
|
+
inputs,
|
|
976
|
+
old_outputs,
|
|
977
|
+
std::move(fused_tape),
|
|
978
|
+
std::move(constant_ids)),
|
|
979
|
+
inputs);
|
|
980
|
+
|
|
981
|
+
// One output per primitive
|
|
982
|
+
new_tape.push_back(compiled_outputs.back());
|
|
983
|
+
|
|
984
|
+
// Replace inputs old parents with compiled_outputs
|
|
985
|
+
for (int i = 0; i < inputs.size(); ++i) {
|
|
986
|
+
auto& pairs = parents_map[inputs[i].id()];
|
|
987
|
+
pairs.erase(
|
|
988
|
+
std::remove_if(
|
|
989
|
+
pairs.begin(),
|
|
990
|
+
pairs.end(),
|
|
991
|
+
[&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),
|
|
992
|
+
pairs.end());
|
|
993
|
+
for (auto& o : compiled_outputs) {
|
|
994
|
+
pairs.push_back({o, i});
|
|
995
|
+
}
|
|
996
|
+
}
|
|
997
|
+
|
|
998
|
+
// - Update outputs parents to point to compiled outputs
|
|
999
|
+
// - Update any overall graph outputs to be compiled outputs
|
|
1000
|
+
for (int o = 0; o < old_outputs.size(); ++o) {
|
|
1001
|
+
merge_one(compiled_outputs[o], old_outputs[o], parents_map);
|
|
1002
|
+
if (auto it = output_map.find(old_outputs[o].id());
|
|
1003
|
+
it != output_map.end()) {
|
|
1004
|
+
it->second = compiled_outputs[o];
|
|
1005
|
+
}
|
|
1006
|
+
}
|
|
1007
|
+
}
|
|
1008
|
+
|
|
1009
|
+
std::reverse(new_tape.begin(), new_tape.end());
|
|
1010
|
+
tape = std::move(new_tape);
|
|
1011
|
+
|
|
1012
|
+
// Replace output with potentially compiled output
|
|
1013
|
+
for (auto& o : outputs) {
|
|
1014
|
+
o = output_map.at(o.id());
|
|
1015
|
+
}
|
|
1016
|
+
}
|
|
1017
|
+
|
|
1018
|
+
std::vector<array> compile_replace(
|
|
1019
|
+
const std::vector<array>& tape,
|
|
1020
|
+
const std::vector<array>& trace_inputs,
|
|
1021
|
+
const std::vector<array>& trace_outputs,
|
|
1022
|
+
const std::vector<array>& inputs,
|
|
1023
|
+
bool shapeless) {
|
|
1024
|
+
std::unordered_map<uintptr_t, array> trace_to_real;
|
|
1025
|
+
for (int i = 0; i < inputs.size(); ++i) {
|
|
1026
|
+
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
|
1027
|
+
}
|
|
1028
|
+
|
|
1029
|
+
auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };
|
|
1030
|
+
|
|
1031
|
+
for (auto& a : tape) {
|
|
1032
|
+
// Arrays in the tape without primitives are either:
|
|
1033
|
+
// - inputs, which are already in the map
|
|
1034
|
+
// - constants, which can be used directly
|
|
1035
|
+
// - a load primitive which has no inputs and will become a constant
|
|
1036
|
+
// after the first eval
|
|
1037
|
+
if (!a.has_primitive() || is_load(a.primitive())) {
|
|
1038
|
+
trace_to_real.insert({a.id(), a});
|
|
1039
|
+
} else {
|
|
1040
|
+
// Find real inputs
|
|
1041
|
+
std::vector<array> real_inputs;
|
|
1042
|
+
for (auto& in : a.inputs()) {
|
|
1043
|
+
real_inputs.push_back(trace_to_real.at(in.id()));
|
|
1044
|
+
}
|
|
1045
|
+
if (a.siblings().empty()) {
|
|
1046
|
+
auto shape =
|
|
1047
|
+
shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape();
|
|
1048
|
+
auto real_a = array(
|
|
1049
|
+
std::move(shape),
|
|
1050
|
+
a.dtype(),
|
|
1051
|
+
a.primitive_ptr(),
|
|
1052
|
+
std::move(real_inputs));
|
|
1053
|
+
trace_to_real.insert({a.id(), std::move(real_a)});
|
|
1054
|
+
} else {
|
|
1055
|
+
// Ensure the order is correct for multi-output primitives
|
|
1056
|
+
std::vector<Dtype> types;
|
|
1057
|
+
auto trace_out = a.outputs();
|
|
1058
|
+
for (auto& o : trace_out) {
|
|
1059
|
+
types.push_back(o.dtype());
|
|
1060
|
+
}
|
|
1061
|
+
std::vector<Shape> shapes;
|
|
1062
|
+
if (shapeless) {
|
|
1063
|
+
shapes = a.primitive().output_shapes(real_inputs);
|
|
1064
|
+
} else {
|
|
1065
|
+
for (auto& o : trace_out) {
|
|
1066
|
+
shapes.push_back(o.shape());
|
|
1067
|
+
}
|
|
1068
|
+
}
|
|
1069
|
+
auto real_out = array::make_arrays(
|
|
1070
|
+
std::move(shapes), types, a.primitive_ptr(), real_inputs);
|
|
1071
|
+
for (int i = 0; i < trace_out.size(); ++i) {
|
|
1072
|
+
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
|
|
1073
|
+
}
|
|
1074
|
+
}
|
|
1075
|
+
}
|
|
1076
|
+
}
|
|
1077
|
+
|
|
1078
|
+
std::vector<array> outputs;
|
|
1079
|
+
for (auto& o : trace_outputs) {
|
|
1080
|
+
outputs.push_back(trace_to_real.at(o.id()));
|
|
1081
|
+
}
|
|
1082
|
+
return outputs;
|
|
1083
|
+
}
|
|
1084
|
+
|
|
1085
|
+
bool skip_compile() {
|
|
1086
|
+
return compile_mode() == CompileMode::disabled ||
|
|
1087
|
+
!(compile_available_for_device(default_device()));
|
|
1088
|
+
}
|
|
1089
|
+
|
|
1090
|
+
// Wrap `fun` so that calls go through the compiler cache.
//
// Parameters:
//   fun       - function to compile; must be non-empty.
//   fun_id    - key identifying `fun` in the compiler cache.
//   shapeless - when true, output shapes are recomputed from the real inputs
//               at replay time (see compile_replace).
//   constants - forwarded to the cache lookup and stored on new entries.
//
// Returns `fun` itself when skip_compile() is true. Otherwise returns a
// wrapper that:
//   - calls `fun` directly when any input is a tracer (nested tracing);
//   - on a cache miss, traces `fun`, builds the tape, then simplifies and
//     fuses it according to the current compile mode;
//   - replays the cached tape on the real inputs.
//
// Throws std::invalid_argument when `fun` is empty. Exceptions raised while
// filling a cache entry propagate to the caller, and the entry is left
// marked empty so it is rebuilt on a later call.
ArrayFnWithExtra compile(
    ArrayFnWithExtra fun,
    std::uintptr_t fun_id,
    bool shapeless /* = false */,
    std::vector<uint64_t> constants /* = {} */) {
  if (skip_compile()) {
    return fun;
  }
  if (!fun) {
    throw std::invalid_argument(
        "[compile] Cannot compile a function without a target.");
  }

  return [fun = std::move(fun),
          fun_id,
          shapeless,
          constants = std::move(constants)](const std::vector<array>& inputs) {
    // If the inputs are tracers, trace the original graph
    if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
          return in.is_tracer();
        })) {
      return fun(inputs);
    }

    // Find a cache entry with the correct inputs
    auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants);

    // No matching cache entry existed, so compile
    if (entry.empty) {
      // Mark the entry as not empty since we are about to fill it
      entry.empty = false;
      try {
        // `constants` is a capture of this non-mutable lambda and is
        // therefore const, so the previous `std::move(constants)` was
        // silently a copy. Copy explicitly: the capture must stay intact
        // because every later call passes it to compiler_cache().find().
        entry.constants = constants;

        // Trace to build the graph
        std::tie(entry.inputs, entry.outputs, entry.extra) =
            compile_trace(fun, inputs, shapeless);

        // DFS the graph and get a tape, and a map of array id to (parent,
        // position in parent inputs)
        std::unordered_map<uintptr_t, std::vector<std::pair<array, int>>>
            parents_map;
        std::tie(entry.tape, parents_map) =
            compile_dfs(entry.inputs, entry.outputs, inputs);

        // Simplify the tape
        if (compile_mode() != CompileMode::no_simplify) {
          compile_simplify(
              entry.tape, parents_map, entry.outputs, /* passes */ 3);
        }

        // Kernel fusion to generate Compiled primitives. The tape and
        // new outputs must be updated accordingly
        if (compile_mode() != CompileMode::no_fuse) {
          compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
        }
      } catch (...) {
        // Without this reset, a failed trace/simplify/fuse would leave a
        // half-built entry flagged as compiled, and later matching calls
        // would replay it. Mark it empty so the next call rebuilds it.
        entry.empty = true;
        throw;
      }
    }

    // At this point we must have a tape, now replace the placeholders
    // with real arrays that can be evaluated
    return ArraysAndExtra{
        compile_replace(
            entry.tape, entry.inputs, entry.outputs, inputs, shapeless),
        entry.extra};
  };
}
|
|
1155
|
+
|
|
1156
|
+
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|
1157
|
+
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
|
1158
|
+
std::uintptr_t fun_id,
|
|
1159
|
+
bool shapeless /* = false */,
|
|
1160
|
+
std::vector<uint64_t> constants /* = {} */) {
|
|
1161
|
+
if (skip_compile()) {
|
|
1162
|
+
return fun;
|
|
1163
|
+
}
|
|
1164
|
+
if (!fun) {
|
|
1165
|
+
throw std::invalid_argument(
|
|
1166
|
+
"[compile] Cannot compile a function without a target.");
|
|
1167
|
+
}
|
|
1168
|
+
|
|
1169
|
+
ArrayFnWithExtra fun_with_extra =
|
|
1170
|
+
[fun = std::move(fun)](const std::vector<array>& inputs) {
|
|
1171
|
+
return ArraysAndExtra{fun(inputs), nullptr};
|
|
1172
|
+
};
|
|
1173
|
+
|
|
1174
|
+
auto compiled_fun = compile(
|
|
1175
|
+
std::move(fun_with_extra), fun_id, shapeless, std::move(constants));
|
|
1176
|
+
|
|
1177
|
+
return [compiled_fun =
|
|
1178
|
+
std::move(compiled_fun)](const std::vector<array>& inputs) {
|
|
1179
|
+
return compiled_fun(inputs).first;
|
|
1180
|
+
};
|
|
1181
|
+
}
|
|
1182
|
+
|
|
1183
|
+
void compile_erase(std::uintptr_t fun_id) {
|
|
1184
|
+
detail::compiler_cache().erase(fun_id);
|
|
1185
|
+
}
|
|
1186
|
+
|
|
1187
|
+
void compile_clear_cache() {
|
|
1188
|
+
detail::compiler_cache().clear();
|
|
1189
|
+
}
|
|
1190
|
+
|
|
1191
|
+
} // namespace detail
|
|
1192
|
+
|
|
1193
|
+
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|
1194
|
+
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
|
1195
|
+
bool shapeless /* false */) {
|
|
1196
|
+
if (detail::skip_compile()) {
|
|
1197
|
+
return fun;
|
|
1198
|
+
}
|
|
1199
|
+
auto fun_id = detail::get_function_address(fun);
|
|
1200
|
+
if (fun_id) {
|
|
1201
|
+
// If the function has an addressable target then no need to manage it's
|
|
1202
|
+
// lifetime
|
|
1203
|
+
return detail::compile(std::move(fun), fun_id, shapeless);
|
|
1204
|
+
} else {
|
|
1205
|
+
auto pfun = std::shared_ptr<
|
|
1206
|
+
std::function<std::vector<array>(const std::vector<array>&)>>(
|
|
1207
|
+
new std::function<std::vector<array>(const std::vector<array>&)>{fun},
|
|
1208
|
+
[](auto* p) {
|
|
1209
|
+
detail::compile_erase(reinterpret_cast<std::uintptr_t>(p));
|
|
1210
|
+
delete p;
|
|
1211
|
+
});
|
|
1212
|
+
fun_id = reinterpret_cast<std::uintptr_t>(pfun.get());
|
|
1213
|
+
return detail::compile(
|
|
1214
|
+
[pfun = std::move(pfun)](const auto& inputs) {
|
|
1215
|
+
return (*pfun)(inputs);
|
|
1216
|
+
},
|
|
1217
|
+
fun_id,
|
|
1218
|
+
shapeless);
|
|
1219
|
+
}
|
|
1220
|
+
}
|
|
1221
|
+
|
|
1222
|
+
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|
1223
|
+
std::vector<array> (*fun)(const std::vector<array>&),
|
|
1224
|
+
bool shapeless /* = false */) {
|
|
1225
|
+
if (detail::skip_compile()) {
|
|
1226
|
+
return fun;
|
|
1227
|
+
}
|
|
1228
|
+
return detail::compile(fun, reinterpret_cast<std::uintptr_t>(fun), shapeless);
|
|
1229
|
+
}
|
|
1230
|
+
|
|
1231
|
+
// Turn compilation off globally: while the mode is `disabled`,
// skip_compile() is true and compile() returns functions unchanged.
void disable_compile() {
  auto& mode = detail::compile_mode();
  mode = CompileMode::disabled;
}
|
|
1234
|
+
|
|
1235
|
+
// Turn compilation back on with both simplification and fusion active.
// Compilation is still skipped when the default device does not support it
// (see skip_compile()).
void enable_compile() {
  auto& mode = detail::compile_mode();
  mode = CompileMode::enabled;
}
|
|
1238
|
+
|
|
1239
|
+
// Set the global compile mode. This file checks for `disabled` (compile()
// returns functions unchanged), `no_simplify` (skip the simplification pass)
// and `no_fuse` (skip kernel fusion).
void set_compile_mode(CompileMode mode) {
  auto& current = detail::compile_mode();
  current = mode;
}
|
|
1242
|
+
|
|
1243
|
+
} // namespace mlx::core
|