mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
|
@@ -0,0 +1,1065 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
#include <algorithm>
|
|
3
|
+
#include <deque>
|
|
4
|
+
#include <future>
|
|
5
|
+
#include <numeric>
|
|
6
|
+
#include <set>
|
|
7
|
+
#include <sstream>
|
|
8
|
+
#include <stack>
|
|
9
|
+
#include <unordered_map>
|
|
10
|
+
#include <unordered_set>
|
|
11
|
+
|
|
12
|
+
#include "mlx/backend/cpu/eval.h"
|
|
13
|
+
#include "mlx/backend/gpu/eval.h"
|
|
14
|
+
#include "mlx/fence.h"
|
|
15
|
+
#include "mlx/memory.h"
|
|
16
|
+
#include "mlx/ops.h"
|
|
17
|
+
#include "mlx/primitives.h"
|
|
18
|
+
#include "mlx/scheduler.h"
|
|
19
|
+
#include "mlx/transforms.h"
|
|
20
|
+
#include "mlx/transforms_impl.h"
|
|
21
|
+
#include "mlx/utils.h"
|
|
22
|
+
|
|
23
|
+
namespace mlx::core {
|
|
24
|
+
|
|
25
|
+
// Compile-time limit with value 10. Its use site is outside this view;
// presumably it caps how many tasks eval_impl keeps in flight at once —
// TODO confirm against the code that reads it.
static constexpr int MAX_ACTIVE_TASKS{10};
|
|
26
|
+
|
|
27
|
+
/* This class is only meant to be used in eval
|
|
28
|
+
* for synchronizing with the main thread. */
|
|
29
|
+
class Synchronizer : public Primitive {
|
|
30
|
+
public:
|
|
31
|
+
explicit Synchronizer(Stream stream) : Primitive(stream) {}
|
|
32
|
+
|
|
33
|
+
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}
|
|
34
|
+
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}
|
|
35
|
+
|
|
36
|
+
DEFINE_NAME(Synchronize);
|
|
37
|
+
};
|
|
38
|
+
|
|
39
|
+
// Initialize the static tracing members from transforms_impl.h
|
|
40
|
+
//
|
|
41
|
+
// These are used to implement the in_tracing() function the returns true if we
|
|
42
|
+
// are currently under a function transformation and the retain_graph()
|
|
43
|
+
// function which returns true if we are forced to retain the graph during
|
|
44
|
+
// evaluation.
|
|
45
|
+
std::vector<std::pair<char, char>>& detail::InTracing::trace_stack() {
|
|
46
|
+
static std::vector<std::pair<char, char>> trace_stack_;
|
|
47
|
+
return trace_stack_;
|
|
48
|
+
}
|
|
49
|
+
int detail::InTracing::grad_counter{0};
|
|
50
|
+
int detail::RetainGraph::tracing_counter{0};
|
|
51
|
+
|
|
52
|
+
// Dispatch the computation of `outputs`, and of every unscheduled array they
// depend on, to the CPU/GPU backends.
//
// A Synchronizer array whose inputs are `outputs` is used as the graph root.
// Scheduling then proceeds in three phases:
//   1. A DFS from the root counts how many times each reached array is
//      consumed (`cache`, shared by siblings). It also records unscheduled
//      arrays that are consumed on a different stream than the one they are
//      computed on (`needs_fence`), and detaches arrays that are already
//      scheduled/evaluated and are not tracers.
//   2. A width-limited BFS builds `tape`. An input is appended only after its
//      consumer count drops to zero, so every consumer precedes its inputs.
//   3. The tape is consumed from the back, so inputs come before consumers.
//      For each array, cross-stream inputs are waited on, the array is handed
//      to gpu::eval or cpu::eval, and it is marked evaluated.
//
// async: when true, the event of the array's stream is attached to each
//   dispatched array and its siblings, and tracer inputs are rejected.
// Returns the Synchronizer array, which carries an event on the chosen output
//   stream. At the end, every stream touched here signals its event (if it
//   has one), and GPU streams are finalized. eval() blocks on the result via
//   wait().
// Throws std::invalid_argument / std::runtime_error when an unscheduled input
//   cannot be evaluated (see the messages below).
array eval_impl(std::vector<array> outputs, bool async) {
  // std::deque: push_back keeps references to existing elements valid. The
  // BFS below relies on this because it holds `a = tape[i]` while appending.
  std::deque<array> tape;

  // Make an effort to choose a good output stream
  // (the stream of the first unscheduled output with a primitive, else the
  // default stream of the default device).
  Stream stream = default_stream(default_device());
  for (auto& o : outputs) {
    if (o.status() == array::Status::unscheduled && o.has_primitive()) {
      stream = o.primitive().stream();
      break;
    }
  }

  // Map of array id that needs fence and stream it's computed on
  // (value: {producer stream index, whether any consumer is on another
  // device}).
  std::unordered_map<uintptr_t, std::pair<uint32_t, bool>> needs_fence;

  auto synchronizer = array(
      {}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));

  // Stream fences for inter-stream synchronization
  // (keyed by stream index; created lazily by maybe_update_fence below).
  std::unordered_map<uint32_t, Fence> fences;

  // Stream events for synchronization after eval
  std::unordered_map<uint32_t, Event> events;
  {
    auto e = Event{stream};
    e.set_value(1);
    synchronizer.attach_event(e);
    events.emplace(stream.index, std::move(e));
  }

  {
    // Record the degree of each input
    std::unordered_map<std::uintptr_t, int> cache;

    // Iterative DFS: each stack entry is (array, index of next input to visit).
    std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
    dfs.emplace(synchronizer, 0);
    while (!dfs.empty()) {
      auto& [a_ref, idx] = dfs.top();
      auto& a = a_ref.get();

      if (idx < a.inputs().size()) {
        // Add an input, and continue
        auto& in = a.inputs()[idx++];

        if (in.status() == array::Status::unscheduled) {
          if (async && in.is_tracer()) {
            throw std::invalid_argument(
                "[async_eval] Not allowed inside a graph transformation.");
          }
          if (!in.has_primitive()) {
            if (in.is_tracer()) {
              throw std::invalid_argument(
                  "[eval] Attempting to eval an array during function"
                  " transformations like compile or vmap is not allowed.");
            }
            throw std::runtime_error(
                "[eval] Attempting to eval an array without a primitive.\n"
                "If you are compiling a function, make sure all the inputs "
                "and outputs are captured:\n"
                "https://ml-explore.github.io/mlx/build/html/usage/compile.html#pure-functions.\n"
                "If you are not using compile, this may be a bug. "
                "Please file an issue here:\n"
                "https://github.com/ml-explore/mlx/issues.");
          }
          // The consumer runs on a different stream than the producer, so
          // the producer must update a fence that the consumer waits on.
          if (a.primitive().stream() != in.primitive().stream()) {
            bool device_switch =
                a.primitive().stream().device != in.primitive().stream().device;
            auto [it, inserted] = needs_fence.emplace(
                in.id(),
                std::make_pair(in.primitive().stream().index, device_switch));
            if (!inserted) {
              it->second.second |= device_switch;
            }
          }
        }

        // All siblings have the same degree
        auto cache_it = cache.find(in.id());
        if (cache_it == cache.end()) {
          // First visit: descend into this input.
          dfs.emplace(in, 0);
          cache.insert({in.id(), 1});
          for (auto& s : in.siblings()) {
            cache.insert({s.id(), 1});
          }
        } else {
          cache_it->second++;
          for (auto& s : in.siblings()) {
            cache[s.id()]++;
          }
        }
        continue;
      }
      if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
          a.has_primitive()) {
        // If the array is evaluated and is no longer a tracer, detach it
        a.detach();
      }
      dfs.pop();
    }

    // Build the tape in BFS order with a width limit
    // `i` indexes the next tape entry to expand; `tape.size() - i` is the
    // current frontier width. Arrays whose remaining inputs were deferred
    // by the width limit are resumed from `dfs` (at the saved input index)
    // once every array already on the tape has been expanded. The loop
    // ends when nothing is left to expand, or earlier once `cache` is empty.
    int max_width = env::bfs_max_width();
    dfs = std::stack<std::pair<std::reference_wrapper<array>, int>>();
    tape.push_back(synchronizer);
    for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) {
      auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i];
      int j = 0;
      if (i >= tape.size()) {
        j = dfs.top().second;
        dfs.pop();
      } else {
        i++;
      }
      for (; j < a.inputs().size(); ++j) {
        auto& in = a.inputs()[j];
        if (in.status() != array::Status::unscheduled) {
          continue;
        }

        // If the width limit is exceeded, push the array on the stack
        // and go down a level
        if ((tape.size() - i) >= max_width) {
          dfs.emplace(a, j);
          break;
        }

        // One fewer pending consumer; the input is appended only after its
        // last consumer has been seen.
        auto it = cache.find(in.id());
        it->second -= 1;

        if (it->second != 0) {
          for (auto& s : in.siblings()) {
            cache[s.id()] -= 1;
          }
          continue;
        }

        // Remove input and siblings from cache
        cache.erase(it);
        for (auto& s : in.siblings()) {
          cache.erase(s.id());
        }

        tape.push_back(in);
      }
    }
  }

  // Streams that had at least one array dispatched in this call.
  std::unordered_set<int> open_streams;
  // Consume the tape from the back: inputs before the arrays using them.
  while (!tape.empty()) {
    auto arr = std::move(tape.back());
    tape.pop_back();

    auto stream = arr.primitive().stream();
    open_streams.insert(stream.index);

    if (async) {
      // Lookup corresponding event
      auto e = events.find(stream.index);
      if (e == events.end()) {
        e = events.emplace(stream.index, Event{stream}).first;
      }
      e->second.set_value(1);
      arr.attach_event(e->second);
      for (auto& s : arr.siblings()) {
        s.attach_event(e->second);
      }
    }

    for (auto& in : arr.inputs()) {
      if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {
        // Use fence to wait within a single eval
        // Get the input array's stream fence and wait on the
        // output arrays stream
        fences[it->second.first].wait(stream, in);
      } else if (in.event().valid()) {
        // Input carries an event from an earlier (async) eval.
        if (in.event().is_signaled()) {
          in.detach_event();
        } else if (in.event().stream() != stream) {
          // Use event to wait across async eval
          in.event().wait(stream);
        }
      }
    }

    if (arr.primitive().device() == Device::gpu) {
      gpu::eval(arr);
    } else {
      cpu::eval(arr);
    }

    // Back-pressure: with too many tasks in flight, or active memory above
    // the limit while tasks are pending, commit open GPU streams and block
    // until work drains.
    if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
        (get_active_memory() > get_memory_limit() &&
         scheduler::n_active_tasks() > 0)) {
      // Commit any open streams
      for (auto i : open_streams) {
        auto s = get_stream(i);
        if (s.device == Device::gpu) {
          gpu::finalize(s);
        }
      }
      scheduler::wait_for_one();
      while (get_active_memory() > get_memory_limit() &&
             scheduler::n_active_tasks() > 0) {
        scheduler::wait_for_one();
      }
    }

    // If `a` is consumed on another stream, update this stream's fence so
    // consumers can wait on it (created lazily per stream index).
    auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
      if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {
        auto it = fences.find(stream.index);
        if (it == fences.end()) {
          it = fences.emplace(stream.index, Fence{stream}).first;
        }
        it->second.update(stream, a, nf->second.second);
      }
    };

    // NOTE(review): the status is set right after dispatch; presumably the
    // backend work may still be in flight at this point.
    arr.set_status(array::Status::evaluated);
    // TODO Maybe always want the fence coherent kernel in the same cbuf
    // as the other kernels?
    maybe_update_fence(arr);
    for (auto& sib : arr.siblings()) {
      sib.set_status(array::Status::evaluated);
      maybe_update_fence(sib);
    }
    if (!arr.is_tracer()) {
      arr.detach();
    }
  }

  // Signal the event in its stream
  // (and commit GPU streams) for every stream used in this call.
  for (auto i : open_streams) {
    auto s = get_stream(i);
    if (auto e = events.find(i); e != events.end()) {
      e->second.signal(s);
    }
    if (s.device == Device::gpu) {
      gpu::finalize(s);
    }
  }

  return synchronizer;
}
|
|
295
|
+
|
|
296
|
+
void async_eval(std::vector<array> outputs) {
|
|
297
|
+
if (outputs.empty()) {
|
|
298
|
+
return;
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
if (std::none_of(outputs.begin(), outputs.end(), [](array& x) {
|
|
302
|
+
return x.status() == array::Status::unscheduled;
|
|
303
|
+
})) {
|
|
304
|
+
return;
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
eval_impl(std::move(outputs), true);
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
void eval(std::vector<array> outputs) {
|
|
311
|
+
if (outputs.empty()) {
|
|
312
|
+
return;
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
if (std::none_of(outputs.begin(), outputs.end(), [](array& x) {
|
|
316
|
+
return x.status() == array::Status::unscheduled;
|
|
317
|
+
})) {
|
|
318
|
+
for (auto& x : outputs) {
|
|
319
|
+
x.wait();
|
|
320
|
+
}
|
|
321
|
+
return;
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
eval_impl(std::move(outputs), false).wait();
|
|
325
|
+
}
|
|
326
|
+
|
|
327
|
+
// Vector-Jacobian product of `fun` at `primals`.
//
// `fun` is traced on shallow copies of `primals` marked as tracers. Outputs
// produced by StopGradient are skipped. Every other output is paired, in
// order, with the next entry of `cotans`; shapes must match, and the
// cotangent is cast to the output's dtype. The graph is then sorted
// topologically (post-order: inputs before consumers) and walked in reverse.
// Each primitive's vjp is accumulated into the cotangents of its inputs.
//
// argnums: indices into `primals` to differentiate with respect to. The
//   marking loop below advances a single cursor through them, so they are
//   assumed to be sorted ascending, unique and in range. NOTE(review): this is
//   not validated here; value_and_grad passes a sorted, bounds-checked list.
// Returns {outputs of fun, one gradient per entry of argnums}. Primals that
//   received no cotangent get zeros.
// Throws std::invalid_argument if there are fewer cotangents than
//   non-stop-gradient outputs, or if an output and its cotangent differ in
//   shape.
std::pair<std::vector<array>, std::vector<array>> vjp(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& primals,
    const std::vector<array>& cotans,
    const std::vector<int>& argnums) {
  // Set the global tracing flag.
  detail::InTracing in_tracing{false, true};

  // Make tracers from given primals
  std::vector<array> primals_;
  for (auto& p : primals) {
    auto s = p.has_primitive() ? p.primitive().stream()
                               : default_stream(default_device());
    primals_.push_back(copy(p, s)); // Does not do a deep copy
    primals_.back().set_tracer(true);
  }

  // Pass tracer primals through the function
  // Any variables that depend on the primals are marked as tracers
  auto outputs = fun(primals_);

  // Map outputs to passed cotans while ignoring the outputs
  // that have stop_gradient called on them
  int cotan_index = 0;
  std::vector<std::pair<int, int>> output_cotan_pairs;
  for (int i = 0; i < outputs.size(); ++i) {
    auto& out = outputs[i];
    if (out.has_primitive()) {
      if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {
        continue;
      }
    }
    if (cotan_index >= cotans.size()) {
      std::ostringstream msg;
      msg << "[vjp] Number of outputs to compute gradients for ("
          << outputs.size() << ") does not match number of cotangents ("
          << cotans.size() << ").";
      throw std::invalid_argument(msg.str());
    }
    if (out.shape() != cotans[cotan_index].shape()) {
      std::ostringstream msg;
      msg << "[vjp] Output shape " << out.shape()
          << " does not match cotangent shape " << cotans[cotan_index].shape()
          << ".";
      if (outputs.size() == 1 && out.size() == 1) {
        msg << " If you are using grad your function must return a scalar.";
      }
      throw std::invalid_argument(msg.str());
    }
    output_cotan_pairs.emplace_back(i, cotan_index++);
  }

  // Topologically sort the compute graph, add graph nodes
  // to the tape which need a gradient.
  // cache: ids already visited. calc_grad: ids whose gradient is needed.
  std::unordered_set<std::uintptr_t> cache;
  std::unordered_set<std::uintptr_t> calc_grad;
  for (int i = 0, j = 0; i < primals_.size(); ++i) {
    auto& primal = primals_[i];
    primal.set_tracer(false);
    cache.insert(primal.id());
    if (j < argnums.size() && argnums[j] == i) {
      j++;
      calc_grad.insert(primal.id());
    }
  }

  std::vector<array> tape;

  // Post-order DFS: an array is pushed on the tape after its inputs, and only
  // if one of its inputs needs a gradient. Visited arrays stop being tracers.
  std::function<void(array&)> recurse;
  recurse = [&](auto& a) {
    // Check if visited and add to cache if not
    if (auto inserted = cache.insert(a.id()); !inserted.second) {
      return;
    }
    a.set_tracer(false);
    for (auto& s : a.siblings()) {
      s.set_tracer(false);
      cache.insert(s.id());
    }

    for (auto& input : a.inputs()) {
      recurse(input);
    }

    // Stop grad
    if (a.has_primitive()) {
      if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {
        return;
      }
    }

    // Calculate gradient if any inputs require gradient
    for (auto& input : a.inputs()) {
      if (calc_grad.find(input.id()) != calc_grad.end()) {
        tape.push_back(a);
        calc_grad.insert(a.id());
        for (auto& s : a.siblings()) {
          calc_grad.insert(s.id());
        }
        break;
      }
    }
  };

  for (auto out : outputs) {
    recurse(out);
  }

  // Run the tape backwards, computing vector-jacobian
  // products for each primitive
  // Seed: each paired output's cotangent, cast to the output dtype.
  std::unordered_map<std::uintptr_t, array> cotan_map;
  for (auto [out_idx, cotan_idx] : output_cotan_pairs) {
    auto& o = outputs[out_idx];
    auto s = o.has_primitive() ? o.primitive().stream()
                               : default_stream(default_device());
    cotan_map.insert({o.id(), astype(cotans[cotan_idx], o.dtype(), s)});
  }
  for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
    auto& a = *it;

    // Get the arguments whose gradients are needed
    std::vector<int> argnums;
    for (int i = 0; i < a.inputs().size(); ++i) {
      if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {
        argnums.push_back(i);
      }
    }

    // Check if any of the array or its siblings have cotangents,
    // if not, we can skip this primitive
    auto outputs = a.outputs();
    bool has_cotans =
        std::any_of(outputs.cbegin(), outputs.cend(), [&cotan_map](auto& s) {
          return cotan_map.find(s.id()) != cotan_map.end();
        });
    if (!has_cotans) {
      continue;
    }

    // Collect one cotangent per primitive output. Existing ones are removed
    // from the map as they are handed over; missing ones are zeros.
    auto s = a.primitive().stream();
    std::vector<array> cotangents{};
    for (auto& o : outputs) {
      if (auto cotan_it = cotan_map.find(o.id()); cotan_it != cotan_map.end()) {
        cotangents.push_back(cotan_map.extract(cotan_it).mapped());
      } else {
        cotangents.push_back(zeros_like(o, s));
      }
    }

    std::vector<array> vjps;
    {
      // Build the vjp graph with graph retention forced on (scoped guard).
      detail::RetainGraph retain;
      vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
    }
    // Accumulate the vector-jacobian products for each input
    for (int i = 0; i < argnums.size(); ++i) {
      auto in_id = a.inputs()[argnums[i]].id();
      if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {
        cotan_it->second = add(cotan_it->second, vjps[i], s);
      } else {
        cotan_map.insert({in_id, vjps[i]});
      }
    }
  }
  // Gather one gradient per requested primal (zeros if none reached it).
  std::vector<array> vjps;
  for (auto arg : argnums) {
    auto& primal = primals_[arg];
    if (auto cotan_it = cotan_map.find(primal.id());
        cotan_it != cotan_map.end()) {
      vjps.push_back(cotan_it->second);
    } else {
      auto s = primal.has_primitive() ? primal.primitive().stream()
                                      : default_stream(default_device());
      vjps.push_back(zeros_like(primal, s));
    }
  }
  return {outputs, vjps};
}
|
|
505
|
+
|
|
506
|
+
std::pair<std::vector<array>, std::vector<array>> vjp(
|
|
507
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
508
|
+
const std::vector<array>& primals,
|
|
509
|
+
const std::vector<array>& cotans) {
|
|
510
|
+
std::vector<int> argnums(primals.size());
|
|
511
|
+
std::iota(argnums.begin(), argnums.end(), 0);
|
|
512
|
+
return vjp(fun, primals, cotans, argnums);
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
std::pair<array, array> vjp(
|
|
516
|
+
const std::function<array(const array&)>& fun,
|
|
517
|
+
const array& primal,
|
|
518
|
+
const array& cotan) {
|
|
519
|
+
auto vec_fun = [fun](const std::vector<array>& inputs) {
|
|
520
|
+
return std::vector<array>{fun(inputs[0])};
|
|
521
|
+
};
|
|
522
|
+
auto [outputs, vjps] = vjp(vec_fun, {primal}, {cotan});
|
|
523
|
+
return {outputs[0], vjps[0]};
|
|
524
|
+
}
|
|
525
|
+
|
|
526
|
+
std::pair<std::vector<array>, std::vector<array>> jvp(
|
|
527
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
528
|
+
const std::vector<array>& primals,
|
|
529
|
+
const std::vector<array>& tangents) {
|
|
530
|
+
// Set the global tracing flag.
|
|
531
|
+
detail::InTracing in_tracing{false, true};
|
|
532
|
+
|
|
533
|
+
if (primals.size() != tangents.size()) {
|
|
534
|
+
throw std::invalid_argument(
|
|
535
|
+
"[jvp] Number of inputs does not match number of tangents.");
|
|
536
|
+
}
|
|
537
|
+
for (int i = 0; i < primals.size(); ++i) {
|
|
538
|
+
if (primals[i].shape() != tangents[i].shape()) {
|
|
539
|
+
throw std::invalid_argument(
|
|
540
|
+
"[jvp] Input shape does not match shape of tangent.");
|
|
541
|
+
}
|
|
542
|
+
}
|
|
543
|
+
|
|
544
|
+
std::vector<array> primals_;
|
|
545
|
+
for (auto& p : primals) {
|
|
546
|
+
auto s = p.has_primitive() ? p.primitive().stream()
|
|
547
|
+
: default_stream(default_device());
|
|
548
|
+
primals_.push_back(copy(p, s)); // Does not do a deep copy
|
|
549
|
+
primals_.back().set_tracer(true);
|
|
550
|
+
}
|
|
551
|
+
auto outputs = fun(primals_);
|
|
552
|
+
|
|
553
|
+
// Topologically sort the compute graph, record outputs
|
|
554
|
+
// in the tape if a gradient is needed.
|
|
555
|
+
std::unordered_set<std::uintptr_t> cache;
|
|
556
|
+
std::unordered_set<std::uintptr_t> calc_grad;
|
|
557
|
+
for (auto& primal : primals_) {
|
|
558
|
+
primal.set_tracer(false);
|
|
559
|
+
calc_grad.insert(primal.id());
|
|
560
|
+
cache.insert(primal.id());
|
|
561
|
+
}
|
|
562
|
+
|
|
563
|
+
std::vector<array> tape;
|
|
564
|
+
|
|
565
|
+
std::function<void(array&)> recurse;
|
|
566
|
+
recurse = [&](auto& a) {
|
|
567
|
+
// Check if visited and add to cache if not
|
|
568
|
+
if (auto inserted = cache.insert(a.id()); !inserted.second) {
|
|
569
|
+
return;
|
|
570
|
+
}
|
|
571
|
+
a.set_tracer(false);
|
|
572
|
+
for (auto& s : a.siblings()) {
|
|
573
|
+
s.set_tracer(false);
|
|
574
|
+
cache.insert(s.id());
|
|
575
|
+
}
|
|
576
|
+
|
|
577
|
+
for (auto input : a.inputs()) {
|
|
578
|
+
recurse(input);
|
|
579
|
+
}
|
|
580
|
+
|
|
581
|
+
// Stop grad
|
|
582
|
+
if (a.has_primitive()) {
|
|
583
|
+
if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {
|
|
584
|
+
return;
|
|
585
|
+
}
|
|
586
|
+
}
|
|
587
|
+
|
|
588
|
+
// Calculate gradient if any inputs require gradient
|
|
589
|
+
for (auto& input : a.inputs()) {
|
|
590
|
+
if (calc_grad.find(input.id()) != calc_grad.end()) {
|
|
591
|
+
tape.push_back(a);
|
|
592
|
+
calc_grad.insert(a.id());
|
|
593
|
+
for (auto& s : a.siblings()) {
|
|
594
|
+
calc_grad.insert(s.id());
|
|
595
|
+
}
|
|
596
|
+
break;
|
|
597
|
+
}
|
|
598
|
+
}
|
|
599
|
+
};
|
|
600
|
+
|
|
601
|
+
for (auto out : outputs) {
|
|
602
|
+
recurse(out);
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
std::unordered_map<std::uintptr_t, array> tan_map;
|
|
606
|
+
for (int i = 0; i < primals_.size(); ++i) {
|
|
607
|
+
tan_map.insert({primals_[i].id(), tangents[i]});
|
|
608
|
+
}
|
|
609
|
+
|
|
610
|
+
for (auto& a : tape) {
|
|
611
|
+
// Get the arguments used in the jvp
|
|
612
|
+
std::vector<int> argnums;
|
|
613
|
+
std::vector<array> tangents;
|
|
614
|
+
for (int i = 0; i < a.inputs().size(); ++i) {
|
|
615
|
+
if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {
|
|
616
|
+
argnums.push_back(i);
|
|
617
|
+
tangents.push_back(it->second);
|
|
618
|
+
}
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);
|
|
622
|
+
auto outputs = a.outputs();
|
|
623
|
+
for (int i = 0; i < jvps.size(); ++i) {
|
|
624
|
+
tan_map.insert({outputs[i].id(), jvps[i]});
|
|
625
|
+
}
|
|
626
|
+
}
|
|
627
|
+
|
|
628
|
+
std::vector<array> jvps;
|
|
629
|
+
for (auto& out : outputs) {
|
|
630
|
+
if (auto it = tan_map.find(out.id()); it != tan_map.end()) {
|
|
631
|
+
jvps.push_back(it->second);
|
|
632
|
+
} else {
|
|
633
|
+
auto s = out.has_primitive() ? out.primitive().stream()
|
|
634
|
+
: default_stream(default_device());
|
|
635
|
+
jvps.push_back(zeros_like(out, s));
|
|
636
|
+
}
|
|
637
|
+
}
|
|
638
|
+
return {outputs, jvps};
|
|
639
|
+
}
|
|
640
|
+
|
|
641
|
+
std::pair<array, array> jvp(
|
|
642
|
+
const std::function<array(const array&)>& fun,
|
|
643
|
+
const array& primal,
|
|
644
|
+
const array& tangent) {
|
|
645
|
+
auto vec_fun = [fun](const std::vector<array>& inputs) {
|
|
646
|
+
return std::vector<array>{fun(inputs[0])};
|
|
647
|
+
};
|
|
648
|
+
auto [outputs, jvps] = jvp(vec_fun, {primal}, {tangent});
|
|
649
|
+
return {outputs[0], jvps[0]};
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
// Build a function returning {outputs of fun, gradients of outputs[0]}.
//
// Only the first output is differentiated. Any further outputs are wrapped in
// stop_gradient and returned as auxiliary values.
//
// argnums: indices of the inputs to differentiate with respect to. Negative
//   values count from the end of the inputs. They are validated on each call.
// Throws std::invalid_argument immediately if argnums is empty. The returned
//   function throws std::invalid_argument on repeated or out-of-range
//   argument numbers.
ValueAndGradFn value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  if (argnums.empty()) {
    throw std::invalid_argument("[grad] Must specify at least one argument.");
  }
  return [fun, argnums](const std::vector<array>& inputs) {
    // Do the index arithmetic in signed ints. Mixing `int` with
    // `inputs.size()` (size_t) would route negative indices through unsigned
    // wraparound and an implementation-defined narrowing back to int.
    const int num_inputs = static_cast<int>(inputs.size());

    // Normalize negative argument numbers; std::set sorts and deduplicates.
    std::set<int> args;
    for (int arg : argnums) {
      args.insert(arg < 0 ? arg + num_inputs : arg);
    }
    if (args.size() != argnums.size()) {
      throw std::invalid_argument(
          "[grad] Repeat argument number not allowed in grad.");
    }
    if (*args.begin() < 0 || *args.rbegin() >= num_inputs) {
      std::ostringstream msg;
      msg << "[grad] Invalid argument number for function with "
          << inputs.size() << " inputs.";
      throw std::invalid_argument(msg.str());
    }
    std::vector<int> sorted_argnums(args.begin(), args.end());

    // Differentiate only the first output; stop gradients through the rest.
    auto gfun = [&fun](const std::vector<array>& inputs) {
      auto outputs = fun(inputs);
      for (size_t i = 1; i < outputs.size(); i++) {
        auto& out = outputs[i];
        auto s = out.has_primitive() ? out.primitive().stream()
                                     : default_stream(default_device());
        outputs[i] = stop_gradient(out, s);
      }
      return outputs;
    };

    // Set the incoming gradient to float32, vjp will cast it to the output type
    auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);
    return std::make_pair(outputs, grads);
  };
}
|
|
691
|
+
|
|
692
|
+
namespace detail {
|
|
693
|
+
|
|
694
|
+
std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
|
695
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
696
|
+
const std::vector<array>& inputs,
|
|
697
|
+
const std::vector<int>& in_axes) {
|
|
698
|
+
// Set the global tracing flag.
|
|
699
|
+
detail::InTracing in_tracing;
|
|
700
|
+
|
|
701
|
+
if (in_axes.size() != inputs.size()) {
|
|
702
|
+
std::stringstream ss;
|
|
703
|
+
ss << "[vmap] The number of in axes (" << in_axes.size()
|
|
704
|
+
<< ") must match the number of inputs (" << inputs.size() << ").";
|
|
705
|
+
throw std::invalid_argument(ss.str());
|
|
706
|
+
}
|
|
707
|
+
|
|
708
|
+
// Some error checking and get the vmap axis size
|
|
709
|
+
size_t vmap_ax_size;
|
|
710
|
+
for (int i = 0; i < inputs.size(); ++i) {
|
|
711
|
+
if (in_axes[i] != -1) {
|
|
712
|
+
if (inputs[i].ndim() == 0) {
|
|
713
|
+
throw std::invalid_argument(
|
|
714
|
+
"[vmap] Cannot vmap an input with zero dimensions.");
|
|
715
|
+
}
|
|
716
|
+
if (in_axes[i] > inputs[i].ndim()) {
|
|
717
|
+
std::ostringstream msg;
|
|
718
|
+
msg << "[vmap] Axis " << in_axes[i] << " invalid for input with "
|
|
719
|
+
<< inputs[i].ndim() << " dimensions.";
|
|
720
|
+
throw std::invalid_argument(msg.str());
|
|
721
|
+
}
|
|
722
|
+
vmap_ax_size = inputs[i].shape(in_axes[i]);
|
|
723
|
+
}
|
|
724
|
+
}
|
|
725
|
+
// Check that all vmapped axes have the same size
|
|
726
|
+
for (int i = 0; i < inputs.size(); ++i) {
|
|
727
|
+
if (in_axes[i] != -1) {
|
|
728
|
+
if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {
|
|
729
|
+
std::ostringstream msg;
|
|
730
|
+
msg << "[vmap] Inconsistent axis sizes: " << in_ax << " and "
|
|
731
|
+
<< vmap_ax_size << ".";
|
|
732
|
+
throw std::invalid_argument(msg.str());
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
}
|
|
736
|
+
|
|
737
|
+
// Run the function on placeholder inputs
|
|
738
|
+
// to get the original graph
|
|
739
|
+
std::vector<array> s_inputs;
|
|
740
|
+
for (int i = 0; i < inputs.size(); ++i) {
|
|
741
|
+
if (in_axes[i] != -1) {
|
|
742
|
+
auto shape = inputs[i].shape();
|
|
743
|
+
shape.erase(shape.begin() + in_axes[i]);
|
|
744
|
+
array in(shape, inputs[i].dtype(), nullptr, {});
|
|
745
|
+
s_inputs.push_back(in);
|
|
746
|
+
s_inputs.back().set_tracer(true);
|
|
747
|
+
} else {
|
|
748
|
+
s_inputs.push_back(inputs[i]);
|
|
749
|
+
}
|
|
750
|
+
}
|
|
751
|
+
return {s_inputs, fun(s_inputs)};
|
|
752
|
+
}
|
|
753
|
+
|
|
754
|
+
std::vector<array> vmap_replace(
|
|
755
|
+
const std::vector<array>& inputs,
|
|
756
|
+
const std::vector<array>& s_inputs,
|
|
757
|
+
const std::vector<array>& s_outputs,
|
|
758
|
+
const std::vector<int>& in_axes,
|
|
759
|
+
const std::vector<int>& out_axes) {
|
|
760
|
+
if (out_axes.size() != s_outputs.size()) {
|
|
761
|
+
std::stringstream msg;
|
|
762
|
+
msg << "[vmap] The number of out axes (" << out_axes.size()
|
|
763
|
+
<< ") must match the number of outputs (" << s_outputs.size() << ").";
|
|
764
|
+
throw std::invalid_argument(msg.str());
|
|
765
|
+
}
|
|
766
|
+
|
|
767
|
+
int vmap_size = -1;
|
|
768
|
+
for (int i = 0; i < inputs.size(); ++i) {
|
|
769
|
+
if (in_axes[i] >= 0) {
|
|
770
|
+
vmap_size = inputs[i].shape(in_axes[i]);
|
|
771
|
+
break;
|
|
772
|
+
}
|
|
773
|
+
}
|
|
774
|
+
if (vmap_size == -1) {
|
|
775
|
+
throw std::invalid_argument("At least one of in_axes must be non-None.");
|
|
776
|
+
}
|
|
777
|
+
|
|
778
|
+
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
|
|
779
|
+
std::unordered_set<std::uintptr_t> needs_vmap;
|
|
780
|
+
std::unordered_set<std::uintptr_t> cache;
|
|
781
|
+
for (int i = 0; i < s_inputs.size(); ++i) {
|
|
782
|
+
auto in = s_inputs[i];
|
|
783
|
+
if (in_axes[i] != -1) {
|
|
784
|
+
tmap.insert({in.id(), {inputs[i], in_axes[i]}});
|
|
785
|
+
needs_vmap.insert(in.id());
|
|
786
|
+
in.set_tracer(false);
|
|
787
|
+
}
|
|
788
|
+
cache.insert(in.id());
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
// Topologically sort the graph
|
|
792
|
+
std::vector<array> tape;
|
|
793
|
+
|
|
794
|
+
std::function<void(const array&)> recurse;
|
|
795
|
+
|
|
796
|
+
recurse = [&](const array& a) {
|
|
797
|
+
auto id = a.id();
|
|
798
|
+
if (cache.find(id) != cache.end()) {
|
|
799
|
+
return;
|
|
800
|
+
}
|
|
801
|
+
cache.insert(id);
|
|
802
|
+
for (auto& s : a.siblings()) {
|
|
803
|
+
cache.insert(s.id());
|
|
804
|
+
}
|
|
805
|
+
|
|
806
|
+
// Recurse on inputs
|
|
807
|
+
for (auto& input : a.inputs()) {
|
|
808
|
+
recurse(input);
|
|
809
|
+
}
|
|
810
|
+
// If any input needs a vmap, then the outputs also need
|
|
811
|
+
// a vmap
|
|
812
|
+
for (auto& input : a.inputs()) {
|
|
813
|
+
if (needs_vmap.find(input.id()) != needs_vmap.end()) {
|
|
814
|
+
tape.push_back(a);
|
|
815
|
+
tape.back().set_tracer(false);
|
|
816
|
+
needs_vmap.insert(a.id());
|
|
817
|
+
for (auto s : a.siblings()) {
|
|
818
|
+
needs_vmap.insert(s.id());
|
|
819
|
+
s.set_tracer(false);
|
|
820
|
+
}
|
|
821
|
+
break;
|
|
822
|
+
}
|
|
823
|
+
}
|
|
824
|
+
};
|
|
825
|
+
|
|
826
|
+
for (auto& out : s_outputs) {
|
|
827
|
+
if (out.has_primitive()) {
|
|
828
|
+
recurse(out);
|
|
829
|
+
}
|
|
830
|
+
}
|
|
831
|
+
|
|
832
|
+
// Transform each primitive in the graph with
|
|
833
|
+
// its vmap implementation
|
|
834
|
+
for (auto& a : tape) {
|
|
835
|
+
std::vector<array> v_inputs;
|
|
836
|
+
std::vector<int> v_axes;
|
|
837
|
+
for (auto& in : a.inputs()) {
|
|
838
|
+
auto map_it = tmap.find(in.id());
|
|
839
|
+
if (map_it != tmap.end()) {
|
|
840
|
+
v_inputs.push_back(map_it->second.first);
|
|
841
|
+
v_axes.push_back(map_it->second.second);
|
|
842
|
+
} else {
|
|
843
|
+
v_inputs.push_back(in);
|
|
844
|
+
v_axes.push_back(-1);
|
|
845
|
+
}
|
|
846
|
+
}
|
|
847
|
+
|
|
848
|
+
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
|
|
849
|
+
|
|
850
|
+
// For each primitive's outputs add its id, the vout id and the vax
|
|
851
|
+
auto outputs = a.outputs();
|
|
852
|
+
for (int i = 0; i < v_outputs.size(); ++i) {
|
|
853
|
+
tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});
|
|
854
|
+
}
|
|
855
|
+
}
|
|
856
|
+
|
|
857
|
+
// Populate the outputs and make sure all the output axes are
|
|
858
|
+
// in the right place
|
|
859
|
+
std::vector<array> outputs;
|
|
860
|
+
for (int i = 0; i < s_outputs.size(); ++i) {
|
|
861
|
+
if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {
|
|
862
|
+
auto& [out, vdim] = map_it->second;
|
|
863
|
+
if (vdim != out_axes[i]) {
|
|
864
|
+
if (out_axes[i] >= out.ndim()) {
|
|
865
|
+
std::ostringstream msg;
|
|
866
|
+
msg << "[vmap] Axis " << out_axes[i] << " invalid for output with "
|
|
867
|
+
<< out.ndim() << " dimensions.";
|
|
868
|
+
throw std::invalid_argument(msg.str());
|
|
869
|
+
}
|
|
870
|
+
out = moveaxis(out, vdim, out_axes[i]);
|
|
871
|
+
}
|
|
872
|
+
outputs.push_back(out);
|
|
873
|
+
} else {
|
|
874
|
+
// When the output has no input dependencies
|
|
875
|
+
// use the size of the vmapped axis in the inputs to expand the output
|
|
876
|
+
array output = expand_dims(s_outputs[i], out_axes[i]);
|
|
877
|
+
output = repeat(output, vmap_size, out_axes[i]);
|
|
878
|
+
outputs.push_back(output);
|
|
879
|
+
}
|
|
880
|
+
}
|
|
881
|
+
return outputs;
|
|
882
|
+
}
|
|
883
|
+
|
|
884
|
+
} // namespace detail
|
|
885
|
+
|
|
886
|
+
std::function<std::vector<array>(const std::vector<array>&)> vmap(
|
|
887
|
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
|
888
|
+
const std::vector<int>& in_axes /* = {} */,
|
|
889
|
+
const std::vector<int>& out_axes /* = {} */) {
|
|
890
|
+
auto infer_axes = [](auto axes) {
|
|
891
|
+
return !axes.empty() &&
|
|
892
|
+
std::all_of(axes.begin(), axes.end(), [](int ax) { return ax < 0; });
|
|
893
|
+
};
|
|
894
|
+
if (infer_axes(in_axes) != infer_axes(out_axes)) {
|
|
895
|
+
throw std::invalid_argument(
|
|
896
|
+
"[vmap] Input (or output) axes must be "
|
|
897
|
+
"specified if output (or input) axes are.");
|
|
898
|
+
}
|
|
899
|
+
auto vfun = [fun, in_axes = in_axes, out_axes = out_axes](
|
|
900
|
+
const std::vector<array>& inputs) mutable {
|
|
901
|
+
if (in_axes.size() == 0) {
|
|
902
|
+
in_axes.resize(inputs.size(), 0);
|
|
903
|
+
}
|
|
904
|
+
|
|
905
|
+
auto [trace_inputs, trace_outputs] =
|
|
906
|
+
detail::vmap_trace(fun, inputs, in_axes);
|
|
907
|
+
|
|
908
|
+
if (out_axes.size() == 0) {
|
|
909
|
+
out_axes.resize(trace_outputs.size(), 0);
|
|
910
|
+
}
|
|
911
|
+
|
|
912
|
+
return detail::vmap_replace(
|
|
913
|
+
inputs, trace_inputs, trace_outputs, in_axes, out_axes);
|
|
914
|
+
};
|
|
915
|
+
|
|
916
|
+
return vfun;
|
|
917
|
+
}
|
|
918
|
+
|
|
919
|
+
std::function<array(const array&, const array&)> vmap(
|
|
920
|
+
const std::function<array(const array&, const array&)>& fun,
|
|
921
|
+
int in_axis_a /* = 0 */,
|
|
922
|
+
int in_axis_b /* = 0 */,
|
|
923
|
+
int out_axis /* = 0 */) {
|
|
924
|
+
auto vfun = vmap(
|
|
925
|
+
[fun](const std::vector<array>& inputs) {
|
|
926
|
+
return std::vector<array>{fun(inputs[0], inputs[1])};
|
|
927
|
+
},
|
|
928
|
+
{in_axis_a, in_axis_b},
|
|
929
|
+
{out_axis});
|
|
930
|
+
return [vfun](const array& a, const array& b) { return vfun({a, b})[0]; };
|
|
931
|
+
}
|
|
932
|
+
|
|
933
|
+
std::function<array(const array&)> vmap(
|
|
934
|
+
const std::function<array(const array&)>& fun,
|
|
935
|
+
int in_axis /* = 0 */,
|
|
936
|
+
int out_axis /* = 0 */) {
|
|
937
|
+
auto vfun = vmap(
|
|
938
|
+
[fun](const std::vector<array>& inputs) {
|
|
939
|
+
return std::vector<array>{fun(inputs[0])};
|
|
940
|
+
},
|
|
941
|
+
{in_axis},
|
|
942
|
+
{out_axis});
|
|
943
|
+
return [vfun](const array& a) { return vfun({a})[0]; };
|
|
944
|
+
}
|
|
945
|
+
|
|
946
|
+
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
|
|
947
|
+
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
|
948
|
+
std::optional<std::function<std::vector<array>(
|
|
949
|
+
const std::vector<array>&,
|
|
950
|
+
const std::vector<array>&,
|
|
951
|
+
const std::vector<array>&)>> fun_vjp /* = std::nullopt */,
|
|
952
|
+
std::optional<std::function<std::vector<array>(
|
|
953
|
+
const std::vector<array>&,
|
|
954
|
+
const std::vector<array>&,
|
|
955
|
+
const std::vector<int>&)>> fun_jvp /* = std::nullopt */,
|
|
956
|
+
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
|
|
957
|
+
const std::vector<array>&,
|
|
958
|
+
const std::vector<int>&)>> fun_vmap /* = std::nullopt */) {
|
|
959
|
+
if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) {
|
|
960
|
+
return fun;
|
|
961
|
+
}
|
|
962
|
+
|
|
963
|
+
return [fun = std::move(fun),
|
|
964
|
+
fun_vjp = std::move(fun_vjp),
|
|
965
|
+
fun_jvp = std::move(fun_jvp),
|
|
966
|
+
fun_vmap = std::move(fun_vmap)](const std::vector<array>& args) {
|
|
967
|
+
// Compute the outputs
|
|
968
|
+
auto outputs = fun(args);
|
|
969
|
+
for (auto& out : outputs) {
|
|
970
|
+
out = stop_gradient(out);
|
|
971
|
+
}
|
|
972
|
+
|
|
973
|
+
// Prepare the inputs to the primitive
|
|
974
|
+
// We also add the outputs to the primitive so that it can "run" the forward
|
|
975
|
+
// pass.
|
|
976
|
+
std::vector<array> inputs = args;
|
|
977
|
+
inputs.insert(inputs.end(), outputs.begin(), outputs.end());
|
|
978
|
+
|
|
979
|
+
// Compute the stream. Maybe do it in a smarter way at some point in the
|
|
980
|
+
// future.
|
|
981
|
+
Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream()
|
|
982
|
+
: default_stream(default_device());
|
|
983
|
+
|
|
984
|
+
// Make the output info
|
|
985
|
+
std::vector<Shape> shapes;
|
|
986
|
+
std::vector<Dtype> dtypes;
|
|
987
|
+
for (const auto& out : outputs) {
|
|
988
|
+
shapes.emplace_back(out.shape());
|
|
989
|
+
dtypes.emplace_back(out.dtype());
|
|
990
|
+
}
|
|
991
|
+
|
|
992
|
+
return array::make_arrays(
|
|
993
|
+
std::move(shapes),
|
|
994
|
+
dtypes,
|
|
995
|
+
std::make_shared<CustomTransforms>(
|
|
996
|
+
to_stream(s),
|
|
997
|
+
outputs.size(),
|
|
998
|
+
|
|
999
|
+
// We use the passed vjp function or compute it from the inputs and
|
|
1000
|
+
// passed cotangents. Note that this may be less efficient than
|
|
1001
|
+
// using `fun` directly because we may not be able to fully reuse
|
|
1002
|
+
// the outputs of the forward pass.
|
|
1003
|
+
fun_vjp.value_or(
|
|
1004
|
+
[fun](auto primals, auto cotangents, auto outputs) {
|
|
1005
|
+
auto [__, vjps] = vjp(fun, primals, cotangents);
|
|
1006
|
+
return vjps;
|
|
1007
|
+
}),
|
|
1008
|
+
|
|
1009
|
+
// We use the passed jvp function or compute it from the primals
|
|
1010
|
+
// and tangents. Similarly we can't take full advantage of the
|
|
1011
|
+
// argnums so it is best to use `fun` directly if we don't need a
|
|
1012
|
+
// custom transform.
|
|
1013
|
+
//
|
|
1014
|
+
// TODO: Use stop_gradient to make full use of argnums and not
|
|
1015
|
+
// waste computation.
|
|
1016
|
+
fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
|
|
1017
|
+
std::vector<array> all_tangents;
|
|
1018
|
+
for (int i = 0, j = 0; i < primals.size(); i++) {
|
|
1019
|
+
if (j < argnums.size() && i == argnums[j]) {
|
|
1020
|
+
all_tangents.emplace_back(tangents[j++]);
|
|
1021
|
+
} else {
|
|
1022
|
+
all_tangents.emplace_back(zeros_like(primals[i]));
|
|
1023
|
+
}
|
|
1024
|
+
}
|
|
1025
|
+
auto [__, jvps] = jvp(fun, primals, all_tangents);
|
|
1026
|
+
return jvps;
|
|
1027
|
+
}),
|
|
1028
|
+
|
|
1029
|
+
// Same as above, we use the passed vmap function or we compute it
|
|
1030
|
+
// from `fun`. The output axes is selected to be all 0s which again
|
|
1031
|
+
// may be suboptimal but the only thing we can do without any
|
|
1032
|
+
// information for `fun`.
|
|
1033
|
+
fun_vmap.value_or(
|
|
1034
|
+
[fun, out_size = outputs.size()](auto inputs, auto in_axes)
|
|
1035
|
+
-> std::pair<std::vector<array>, std::vector<int>> {
|
|
1036
|
+
std::vector<int> out_axes(out_size, 0);
|
|
1037
|
+
return {vmap(fun, in_axes, out_axes)(inputs), out_axes};
|
|
1038
|
+
})),
|
|
1039
|
+
inputs);
|
|
1040
|
+
};
|
|
1041
|
+
}
|
|
1042
|
+
|
|
1043
|
+
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
|
1044
|
+
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
|
1045
|
+
std::function<std::vector<array>(
|
|
1046
|
+
const std::vector<array>&,
|
|
1047
|
+
const std::vector<array>&,
|
|
1048
|
+
const std::vector<array>&)> fun_vjp) {
|
|
1049
|
+
return custom_function(fun, fun_vjp, std::nullopt, std::nullopt);
|
|
1050
|
+
}
|
|
1051
|
+
|
|
1052
|
+
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
|
|
1053
|
+
std::function<std::vector<array>(const std::vector<array>&)> fun) {
|
|
1054
|
+
auto vjp_fun = [fun](
|
|
1055
|
+
const std::vector<array>& primals,
|
|
1056
|
+
const std::vector<array>& cotangents,
|
|
1057
|
+
const std::vector<array>& outputs) -> std::vector<array> {
|
|
1058
|
+
auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents);
|
|
1059
|
+
return vjps;
|
|
1060
|
+
};
|
|
1061
|
+
|
|
1062
|
+
return custom_vjp(fun, vjp_fun);
|
|
1063
|
+
}
|
|
1064
|
+
|
|
1065
|
+
} // namespace mlx::core
|