mlx 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mlx has been flagged as potentially problematic.
- checksums.yaml +7 -0
- data/ext/mlx/CMakeLists.txt +7 -0
- data/ext/mlx/Makefile +273 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/mkmf.log +44 -0
- data/ext/mlx/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
- data/ext/mlx/native.cpp +8027 -0
- data/ext/mlx/native.o +0 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version +1 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/.clang-format +87 -0
- data/mlx/.git +1 -0
- data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
- data/mlx/.github/actions/build-docs/action.yml +38 -0
- data/mlx/.github/actions/build-linux/action.yml +38 -0
- data/mlx/.github/actions/build-linux-release/action.yml +42 -0
- data/mlx/.github/actions/build-macos/action.yml +80 -0
- data/mlx/.github/actions/build-macos-release/action.yml +36 -0
- data/mlx/.github/actions/build-windows/action.yml +26 -0
- data/mlx/.github/actions/setup-linux/action.yml +93 -0
- data/mlx/.github/actions/setup-macos/action.yml +24 -0
- data/mlx/.github/actions/setup-windows/action.yml +42 -0
- data/mlx/.github/actions/test-linux/action.yml +69 -0
- data/mlx/.github/actions/test-windows/action.yml +20 -0
- data/mlx/.github/dependabot.yml +6 -0
- data/mlx/.github/pull_request_template.md +12 -0
- data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
- data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
- data/mlx/.github/workflows/build_and_test.yml +152 -0
- data/mlx/.github/workflows/documentation.yml +28 -0
- data/mlx/.github/workflows/nightly.yml +104 -0
- data/mlx/.github/workflows/release.yml +256 -0
- data/mlx/.gitignore +81 -0
- data/mlx/.pre-commit-config.yaml +27 -0
- data/mlx/ACKNOWLEDGMENTS.md +268 -0
- data/mlx/CITATION.cff +24 -0
- data/mlx/CMakeLists.txt +437 -0
- data/mlx/CODE_OF_CONDUCT.md +132 -0
- data/mlx/CONTRIBUTING.md +38 -0
- data/mlx/LICENSE +21 -0
- data/mlx/MANIFEST.in +6 -0
- data/mlx/README.md +121 -0
- data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- data/mlx/benchmarks/cpp/autograd.cpp +39 -0
- data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- data/mlx/benchmarks/cpp/time_utils.h +39 -0
- data/mlx/benchmarks/numpy/single_ops.py +39 -0
- data/mlx/benchmarks/numpy/time_utils.py +20 -0
- data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
- data/mlx/benchmarks/python/comparative/README.md +15 -0
- data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- data/mlx/benchmarks/python/comparative/compare.py +284 -0
- data/mlx/benchmarks/python/compile_bench.py +107 -0
- data/mlx/benchmarks/python/conv1d_bench.py +123 -0
- data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- data/mlx/benchmarks/python/conv_bench.py +135 -0
- data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- data/mlx/benchmarks/python/distributed_bench.py +66 -0
- data/mlx/benchmarks/python/einsum_bench.py +84 -0
- data/mlx/benchmarks/python/fft_bench.py +118 -0
- data/mlx/benchmarks/python/gather_bench.py +52 -0
- data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- data/mlx/benchmarks/python/hadamard_bench.py +70 -0
- data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
- data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
- data/mlx/benchmarks/python/masked_scatter.py +212 -0
- data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
- data/mlx/benchmarks/python/rope_bench.py +35 -0
- data/mlx/benchmarks/python/scatter_bench.py +96 -0
- data/mlx/benchmarks/python/sdpa_bench.py +223 -0
- data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
- data/mlx/benchmarks/python/single_ops.py +132 -0
- data/mlx/benchmarks/python/synchronize_bench.py +55 -0
- data/mlx/benchmarks/python/time_utils.py +38 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/docs/.clang-format +2 -0
- data/mlx/docs/.gitignore +3 -0
- data/mlx/docs/.nojekyll +0 -0
- data/mlx/docs/Doxyfile +51 -0
- data/mlx/docs/Makefile +18 -0
- data/mlx/docs/README.md +54 -0
- data/mlx/docs/index.html +1 -0
- data/mlx/docs/requirements.txt +5 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
- data/mlx/docs/src/_static/mlx_logo.png +0 -0
- data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
- data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
- data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
- data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
- data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
- data/mlx/docs/src/_templates/module-base-class.rst +33 -0
- data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
- data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
- data/mlx/docs/src/conf.py +99 -0
- data/mlx/docs/src/cpp/ops.rst +7 -0
- data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
- data/mlx/docs/src/dev/extensions.rst +811 -0
- data/mlx/docs/src/dev/metal_debugger.rst +68 -0
- data/mlx/docs/src/dev/metal_logging.rst +40 -0
- data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
- data/mlx/docs/src/examples/data_parallelism.rst +91 -0
- data/mlx/docs/src/examples/linear_regression.rst +77 -0
- data/mlx/docs/src/examples/llama-inference.rst +382 -0
- data/mlx/docs/src/examples/mlp.rst +134 -0
- data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
- data/mlx/docs/src/index.rst +96 -0
- data/mlx/docs/src/install.rst +340 -0
- data/mlx/docs/src/python/array.rst +65 -0
- data/mlx/docs/src/python/cuda.rst +9 -0
- data/mlx/docs/src/python/data_types.rst +78 -0
- data/mlx/docs/src/python/devices_and_streams.rst +21 -0
- data/mlx/docs/src/python/distributed.rst +22 -0
- data/mlx/docs/src/python/export.rst +14 -0
- data/mlx/docs/src/python/fast.rst +16 -0
- data/mlx/docs/src/python/fft.rst +24 -0
- data/mlx/docs/src/python/linalg.rst +27 -0
- data/mlx/docs/src/python/memory_management.rst +16 -0
- data/mlx/docs/src/python/metal.rst +12 -0
- data/mlx/docs/src/python/nn/distributed.rst +30 -0
- data/mlx/docs/src/python/nn/functions.rst +40 -0
- data/mlx/docs/src/python/nn/init.rst +45 -0
- data/mlx/docs/src/python/nn/layers.rst +74 -0
- data/mlx/docs/src/python/nn/losses.rst +25 -0
- data/mlx/docs/src/python/nn/module.rst +38 -0
- data/mlx/docs/src/python/nn.rst +186 -0
- data/mlx/docs/src/python/ops.rst +184 -0
- data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
- data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
- data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
- data/mlx/docs/src/python/optimizers.rst +78 -0
- data/mlx/docs/src/python/random.rst +48 -0
- data/mlx/docs/src/python/transforms.rst +22 -0
- data/mlx/docs/src/python/tree_utils.rst +23 -0
- data/mlx/docs/src/usage/compile.rst +516 -0
- data/mlx/docs/src/usage/distributed.rst +572 -0
- data/mlx/docs/src/usage/export.rst +288 -0
- data/mlx/docs/src/usage/function_transforms.rst +191 -0
- data/mlx/docs/src/usage/indexing.rst +194 -0
- data/mlx/docs/src/usage/launching_distributed.rst +234 -0
- data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
- data/mlx/docs/src/usage/numpy.rst +124 -0
- data/mlx/docs/src/usage/quick_start.rst +67 -0
- data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
- data/mlx/docs/src/usage/unified_memory.rst +78 -0
- data/mlx/docs/src/usage/using_streams.rst +18 -0
- data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
- data/mlx/examples/cmake_project/README.md +26 -0
- data/mlx/examples/cmake_project/example.cpp +14 -0
- data/mlx/examples/cpp/CMakeLists.txt +12 -0
- data/mlx/examples/cpp/distributed.cpp +22 -0
- data/mlx/examples/cpp/linear_regression.cpp +54 -0
- data/mlx/examples/cpp/logistic_regression.cpp +54 -0
- data/mlx/examples/cpp/metal_capture.cpp +31 -0
- data/mlx/examples/cpp/timer.h +20 -0
- data/mlx/examples/cpp/tutorial.cpp +99 -0
- data/mlx/examples/export/CMakeLists.txt +22 -0
- data/mlx/examples/export/README.md +49 -0
- data/mlx/examples/export/eval_mlp.cpp +25 -0
- data/mlx/examples/export/eval_mlp.py +52 -0
- data/mlx/examples/export/train_mlp.cpp +35 -0
- data/mlx/examples/export/train_mlp.py +76 -0
- data/mlx/examples/extensions/CMakeLists.txt +78 -0
- data/mlx/examples/extensions/README.md +24 -0
- data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
- data/mlx/examples/extensions/axpby/axpby.h +90 -0
- data/mlx/examples/extensions/axpby/axpby.metal +47 -0
- data/mlx/examples/extensions/bindings.cpp +39 -0
- data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
- data/mlx/examples/extensions/pyproject.toml +8 -0
- data/mlx/examples/extensions/requirements.txt +4 -0
- data/mlx/examples/extensions/setup.py +18 -0
- data/mlx/examples/extensions/test.py +12 -0
- data/mlx/examples/python/linear_regression.py +46 -0
- data/mlx/examples/python/logistic_regression.py +49 -0
- data/mlx/examples/python/qqmm.py +117 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- data/mlx/pyproject.toml +7 -0
- data/mlx/python/mlx/__main__.py +27 -0
- data/mlx/python/mlx/_distributed_utils/common.py +135 -0
- data/mlx/python/mlx/_distributed_utils/config.py +631 -0
- data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
- data/mlx/python/mlx/_reprlib_fix.py +16 -0
- data/mlx/python/mlx/_stub_patterns.txt +36 -0
- data/mlx/python/mlx/extension.py +88 -0
- data/mlx/python/mlx/nn/__init__.py +5 -0
- data/mlx/python/mlx/nn/init.py +441 -0
- data/mlx/python/mlx/nn/layers/__init__.py +105 -0
- data/mlx/python/mlx/nn/layers/activations.py +661 -0
- data/mlx/python/mlx/nn/layers/base.py +675 -0
- data/mlx/python/mlx/nn/layers/containers.py +24 -0
- data/mlx/python/mlx/nn/layers/convolution.py +232 -0
- data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
- data/mlx/python/mlx/nn/layers/distributed.py +601 -0
- data/mlx/python/mlx/nn/layers/dropout.py +137 -0
- data/mlx/python/mlx/nn/layers/embedding.py +53 -0
- data/mlx/python/mlx/nn/layers/linear.py +180 -0
- data/mlx/python/mlx/nn/layers/normalization.py +363 -0
- data/mlx/python/mlx/nn/layers/pooling.py +398 -0
- data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
- data/mlx/python/mlx/nn/layers/quantized.py +426 -0
- data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
- data/mlx/python/mlx/nn/layers/transformer.py +354 -0
- data/mlx/python/mlx/nn/layers/upsample.py +277 -0
- data/mlx/python/mlx/nn/losses.py +610 -0
- data/mlx/python/mlx/nn/utils.py +165 -0
- data/mlx/python/mlx/optimizers/__init__.py +4 -0
- data/mlx/python/mlx/optimizers/optimizers.py +976 -0
- data/mlx/python/mlx/optimizers/schedulers.py +158 -0
- data/mlx/python/mlx/py.typed +1 -0
- data/mlx/python/mlx/utils.py +325 -0
- data/mlx/python/src/CMakeLists.txt +96 -0
- data/mlx/python/src/array.cpp +1525 -0
- data/mlx/python/src/buffer.h +124 -0
- data/mlx/python/src/constants.cpp +15 -0
- data/mlx/python/src/convert.cpp +504 -0
- data/mlx/python/src/convert.h +50 -0
- data/mlx/python/src/cuda.cpp +19 -0
- data/mlx/python/src/device.cpp +98 -0
- data/mlx/python/src/distributed.cpp +352 -0
- data/mlx/python/src/export.cpp +356 -0
- data/mlx/python/src/fast.cpp +627 -0
- data/mlx/python/src/fft.cpp +514 -0
- data/mlx/python/src/indexing.cpp +1016 -0
- data/mlx/python/src/indexing.h +41 -0
- data/mlx/python/src/linalg.cpp +663 -0
- data/mlx/python/src/load.cpp +531 -0
- data/mlx/python/src/load.h +51 -0
- data/mlx/python/src/memory.cpp +125 -0
- data/mlx/python/src/metal.cpp +98 -0
- data/mlx/python/src/mlx.cpp +51 -0
- data/mlx/python/src/mlx_func.cpp +116 -0
- data/mlx/python/src/mlx_func.h +31 -0
- data/mlx/python/src/ops.cpp +5545 -0
- data/mlx/python/src/random.cpp +516 -0
- data/mlx/python/src/small_vector.h +76 -0
- data/mlx/python/src/stream.cpp +147 -0
- data/mlx/python/src/transforms.cpp +1542 -0
- data/mlx/python/src/trees.cpp +311 -0
- data/mlx/python/src/trees.h +62 -0
- data/mlx/python/src/utils.cpp +98 -0
- data/mlx/python/src/utils.h +78 -0
- data/mlx/python/tests/__main__.py +5 -0
- data/mlx/python/tests/cuda_skip.py +62 -0
- data/mlx/python/tests/mlx_distributed_tests.py +314 -0
- data/mlx/python/tests/mlx_tests.py +116 -0
- data/mlx/python/tests/mpi_test_distributed.py +142 -0
- data/mlx/python/tests/nccl_test_distributed.py +52 -0
- data/mlx/python/tests/ring_test_distributed.py +131 -0
- data/mlx/python/tests/test_array.py +2139 -0
- data/mlx/python/tests/test_autograd.py +880 -0
- data/mlx/python/tests/test_bf16.py +196 -0
- data/mlx/python/tests/test_blas.py +1429 -0
- data/mlx/python/tests/test_compile.py +1277 -0
- data/mlx/python/tests/test_constants.py +41 -0
- data/mlx/python/tests/test_conv.py +1198 -0
- data/mlx/python/tests/test_conv_transpose.py +810 -0
- data/mlx/python/tests/test_device.py +150 -0
- data/mlx/python/tests/test_double.py +306 -0
- data/mlx/python/tests/test_einsum.py +363 -0
- data/mlx/python/tests/test_eval.py +200 -0
- data/mlx/python/tests/test_export_import.py +614 -0
- data/mlx/python/tests/test_fast.py +923 -0
- data/mlx/python/tests/test_fast_sdpa.py +647 -0
- data/mlx/python/tests/test_fft.py +323 -0
- data/mlx/python/tests/test_graph.py +37 -0
- data/mlx/python/tests/test_init.py +139 -0
- data/mlx/python/tests/test_linalg.py +621 -0
- data/mlx/python/tests/test_load.py +447 -0
- data/mlx/python/tests/test_losses.py +427 -0
- data/mlx/python/tests/test_memory.py +77 -0
- data/mlx/python/tests/test_nn.py +1986 -0
- data/mlx/python/tests/test_ops.py +3261 -0
- data/mlx/python/tests/test_optimizers.py +584 -0
- data/mlx/python/tests/test_quantized.py +1160 -0
- data/mlx/python/tests/test_random.py +392 -0
- data/mlx/python/tests/test_reduce.py +223 -0
- data/mlx/python/tests/test_tree.py +96 -0
- data/mlx/python/tests/test_upsample.py +100 -0
- data/mlx/python/tests/test_vmap.py +860 -0
- data/mlx/setup.py +315 -0
- data/mlx/tests/CMakeLists.txt +44 -0
- data/mlx/tests/allocator_tests.cpp +41 -0
- data/mlx/tests/arg_reduce_tests.cpp +204 -0
- data/mlx/tests/array_tests.cpp +663 -0
- data/mlx/tests/autograd_tests.cpp +1399 -0
- data/mlx/tests/blas_tests.cpp +110 -0
- data/mlx/tests/compile_tests.cpp +818 -0
- data/mlx/tests/creations_tests.cpp +239 -0
- data/mlx/tests/custom_vjp_tests.cpp +55 -0
- data/mlx/tests/device_tests.cpp +35 -0
- data/mlx/tests/einsum_tests.cpp +85 -0
- data/mlx/tests/eval_tests.cpp +93 -0
- data/mlx/tests/export_import_tests.cpp +164 -0
- data/mlx/tests/fft_tests.cpp +366 -0
- data/mlx/tests/gpu_tests.cpp +523 -0
- data/mlx/tests/linalg_tests.cpp +639 -0
- data/mlx/tests/load_tests.cpp +270 -0
- data/mlx/tests/ops_tests.cpp +4159 -0
- data/mlx/tests/random_tests.cpp +716 -0
- data/mlx/tests/scheduler_tests.cpp +121 -0
- data/mlx/tests/tests.cpp +26 -0
- data/mlx/tests/utils_tests.cpp +67 -0
- data/mlx/tests/vmap_tests.cpp +547 -0
- metadata +958 -0
|
@@ -0,0 +1,976 @@
|
|
|
1
|
+
# Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
from typing import Callable, List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import mlx.core as mx
|
|
6
|
+
from mlx.nn import Module
|
|
7
|
+
from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Optimizer:
|
|
11
|
+
"""The base class for all optimizers. It allows us to implement an
|
|
12
|
+
optimizer on a per-parameter basis and apply it to a parameter tree.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
def __init__(self, schedulers=None):
|
|
16
|
+
self._initialized = False
|
|
17
|
+
self._state = {"step": mx.array(0, mx.uint64)}
|
|
18
|
+
self._schedulers = {k: v for k, v in (schedulers or {}).items()}
|
|
19
|
+
|
|
20
|
+
def update(self, model: Module, gradients: dict):
|
|
21
|
+
"""Apply the gradients to the parameters of the model and update the
|
|
22
|
+
model with the new parameters.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
model (mlx.nn.Module): An mlx module to be updated.
|
|
26
|
+
gradients (dict): A Python tree of gradients, most likely computed
|
|
27
|
+
via :func:`mlx.nn.value_and_grad`.
|
|
28
|
+
"""
|
|
29
|
+
model.update(self.apply_gradients(gradients, model))
|
|
30
|
+
|
|
31
|
+
def init(self, parameters: dict):
|
|
32
|
+
"""Initialize the optimizer's state
|
|
33
|
+
|
|
34
|
+
This function can be used to initialize optimizers which have state
|
|
35
|
+
(like momentum in :class:`SGD`). Using this method is optional as the
|
|
36
|
+
optimizer will initialize itself if the state is not yet set. However,
|
|
37
|
+
there are some cases where explicit initialization is useful in order
|
|
38
|
+
to have access to the :attr:`Optimizer.state` before the first call to
|
|
39
|
+
:meth:`Optimizer.update`.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
model (dict): A Python tree of parameters.
|
|
43
|
+
|
|
44
|
+
Example:
|
|
45
|
+
>>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
|
|
46
|
+
>>> model = nn.Linear(2, 2)
|
|
47
|
+
>>> optimizer.init(model.trainable_parameters())
|
|
48
|
+
>>> optimizer.state.keys()
|
|
49
|
+
dict_keys(['step', 'learning_rate', 'weight', 'bias'])
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
# Initialize the optimizer state to match the parameter state
|
|
53
|
+
def update_state(params, state):
|
|
54
|
+
if isinstance(params, (list, tuple)):
|
|
55
|
+
state = list(state)
|
|
56
|
+
for i in range(len(state)):
|
|
57
|
+
state[i] = update_state(params[i], state[i])
|
|
58
|
+
if len(state) != len(params):
|
|
59
|
+
state.extend(tree_map(lambda _: {}, params[len(state) :]))
|
|
60
|
+
return type(params)(state)
|
|
61
|
+
elif isinstance(params, dict):
|
|
62
|
+
for k, v in params.items():
|
|
63
|
+
if k not in state:
|
|
64
|
+
state[k] = tree_map(lambda _: {}, v)
|
|
65
|
+
else:
|
|
66
|
+
state[k] = update_state(v, state[k])
|
|
67
|
+
return state
|
|
68
|
+
else:
|
|
69
|
+
return state
|
|
70
|
+
|
|
71
|
+
update_state(parameters, self._state)
|
|
72
|
+
tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
|
|
73
|
+
self._initialized = True
|
|
74
|
+
|
|
75
|
+
def init_single(self, parameter: mx.array, state: dict):
|
|
76
|
+
"""To be extended by the children classes to implement each optimizer's
|
|
77
|
+
state initialization.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
parameter (mx.array): A single parameter that will be optimized.
|
|
81
|
+
state (dict): The optimizer's state.
|
|
82
|
+
"""
|
|
83
|
+
raise NotImplementedError()
|
|
84
|
+
|
|
85
|
+
def apply_gradients(self, gradients: dict, parameters: dict):
|
|
86
|
+
"""Apply the gradients to the parameters and return the updated parameters.
|
|
87
|
+
|
|
88
|
+
Can be used to update a model via
|
|
89
|
+
``model.update(opt.apply_gradients(grads, model))`` which is precisely
|
|
90
|
+
how :meth:`Optimizer.update` is implemented.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
gradients (dict): A Python tree of gradients.
|
|
94
|
+
parameters (dict): A Python tree of parameters. It can be a
|
|
95
|
+
superset of the gradients. In that case the returned python
|
|
96
|
+
tree will be of the same structure as the gradients.
|
|
97
|
+
"""
|
|
98
|
+
if not self._initialized:
|
|
99
|
+
self.init(gradients)
|
|
100
|
+
|
|
101
|
+
# Update any scheduled variables
|
|
102
|
+
for param, scheduler in self._schedulers.items():
|
|
103
|
+
self.state[param] = scheduler(self.step)
|
|
104
|
+
|
|
105
|
+
# Increment the step
|
|
106
|
+
self.state["step"] = self.step + 1
|
|
107
|
+
|
|
108
|
+
# Apply the update
|
|
109
|
+
return tree_map(self.apply_single, gradients, parameters, self.state)
|
|
110
|
+
|
|
111
|
+
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
|
112
|
+
"""To be extended by derived classes to implement the optimizer's update.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
gradient (mx.array): The ``parameter`` gradient.
|
|
116
|
+
parameter (mx.array): The ``parameter`` to update.
|
|
117
|
+
state (dict): The optimizer's state.
|
|
118
|
+
"""
|
|
119
|
+
raise NotImplementedError()
|
|
120
|
+
|
|
121
|
+
@property
|
|
122
|
+
def state(self):
|
|
123
|
+
"""The optimizer's state dictionary."""
|
|
124
|
+
return self._state
|
|
125
|
+
|
|
126
|
+
@state.setter
|
|
127
|
+
def state(self, state: dict):
|
|
128
|
+
self._initialized = False
|
|
129
|
+
self._state = state
|
|
130
|
+
|
|
131
|
+
@property
|
|
132
|
+
def step(self):
|
|
133
|
+
return self.state["step"]
|
|
134
|
+
|
|
135
|
+
@property
|
|
136
|
+
def learning_rate(self):
|
|
137
|
+
return self.state["learning_rate"]
|
|
138
|
+
|
|
139
|
+
@learning_rate.setter
|
|
140
|
+
def learning_rate(self, learning_rate: Union[float, mx.array]):
|
|
141
|
+
self.state["learning_rate"] = mx.array(learning_rate)
|
|
142
|
+
|
|
143
|
+
def _maybe_schedule(
|
|
144
|
+
self, name: str, param: Union[float, Callable[[mx.array], mx.array]]
|
|
145
|
+
):
|
|
146
|
+
"""
|
|
147
|
+
To be used by derived classes to optionally put a parameter on a schedule.
|
|
148
|
+
"""
|
|
149
|
+
if isinstance(param, Callable):
|
|
150
|
+
self._schedulers[name] = param
|
|
151
|
+
parameter = param(self.step)
|
|
152
|
+
else:
|
|
153
|
+
parameter = mx.array(param)
|
|
154
|
+
self.state[name] = parameter
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class MultiOptimizer(Optimizer):
|
|
158
|
+
"""Wraps a list of optimizers with corresponding weight predicates/filters
|
|
159
|
+
to make it easy to use different optimizers for different weights.
|
|
160
|
+
|
|
161
|
+
The predicates take the full "path" of the weight and the weight itself and
|
|
162
|
+
return True if it should be considered for this optimizer. The last
|
|
163
|
+
optimizer in the list is a fallback optimizer and no predicate should be
|
|
164
|
+
given for it.
|
|
165
|
+
|
|
166
|
+
Args:
|
|
167
|
+
optimizers (list[Optimizer]): A list of optimizers to delegate to
|
|
168
|
+
filters (list[Callable[[str, array], bool]): A list of predicates that
|
|
169
|
+
should be one less than the provided optimizers.
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
def __init__(self, optimizers, filters: list = []):
|
|
173
|
+
super().__init__()
|
|
174
|
+
self._state = {}
|
|
175
|
+
|
|
176
|
+
if len(filters) != len(optimizers) - 1:
|
|
177
|
+
raise ValueError(
|
|
178
|
+
f"Given {len(filters)} filters but {len(optimizers)-1} needed."
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
self.optimizers = optimizers
|
|
182
|
+
self.filters = filters + [lambda *args, **kwargs: True]
|
|
183
|
+
|
|
184
|
+
def _split_dictionary(self, gradients: dict):
|
|
185
|
+
if len(self.optimizers) == 1:
|
|
186
|
+
return [gradients]
|
|
187
|
+
|
|
188
|
+
parts = [[] for _ in range(len(self.optimizers))]
|
|
189
|
+
flat_gradients = tree_flatten(gradients)
|
|
190
|
+
for k, g in flat_gradients:
|
|
191
|
+
for i, fn in enumerate(self.filters):
|
|
192
|
+
if fn(k, g):
|
|
193
|
+
parts[i].append((k, g))
|
|
194
|
+
break
|
|
195
|
+
|
|
196
|
+
return [tree_unflatten(p) for p in parts]
|
|
197
|
+
|
|
198
|
+
def init(self, parameters: dict):
|
|
199
|
+
for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
|
|
200
|
+
o.init(p)
|
|
201
|
+
|
|
202
|
+
def apply_gradients(self, gradients: dict, parameters: dict):
|
|
203
|
+
tree = {}
|
|
204
|
+
for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
|
|
205
|
+
tree = tree_merge(tree, o.apply_gradients(g, parameters))
|
|
206
|
+
return tree
|
|
207
|
+
|
|
208
|
+
@property
|
|
209
|
+
def state(self):
|
|
210
|
+
return {"states": [o.state for o in self.optimizers]}
|
|
211
|
+
|
|
212
|
+
@state.setter
|
|
213
|
+
def state(self, state: dict):
|
|
214
|
+
if "states" not in state or len(state["states"]) != len(self.optimizers):
|
|
215
|
+
raise ValueError("Invalid state provided")
|
|
216
|
+
|
|
217
|
+
for o, s in zip(self.optimizers, state["states"]):
|
|
218
|
+
o.state = s
|
|
219
|
+
|
|
220
|
+
@property
|
|
221
|
+
def learning_rate(self):
|
|
222
|
+
return self.optimizers[0].learning_rate
|
|
223
|
+
|
|
224
|
+
@learning_rate.setter
|
|
225
|
+
def learning_rate(self, learning_rate: Union[float, mx.array]):
|
|
226
|
+
for o in self.optimizers:
|
|
227
|
+
o.learning_rate = learning_rate
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class SGD(Optimizer):
    r"""The stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

    .. math::

        v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
        w_{t+1} &= w_t - \lambda v_{t+1}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
        dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
        nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        dampening: float = 0.0,
        nesterov: bool = False,
    ):
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening."
            )
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["v"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the SGD parameter update and stores :math:`v` in the
        optimizer state."""

        if self.weight_decay != 0:
            gradient += self.weight_decay * parameter

        if self.momentum <= 0:
            return parameter - self.learning_rate.astype(gradient.dtype) * gradient

        v = self.momentum * state.get("v")
        if self.dampening > 0:
            v += (1 - self.dampening) * gradient
        else:
            v += gradient

        if self.nesterov:
            update = gradient + self.momentum * v
        else:
            update = v

        state["v"] = v
        return parameter - self.learning_rate.astype(gradient.dtype) * update


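# Illustrative sketch (not part of the upstream file): one SGD-with-momentum step computed
# directly from the update rule above (v = mu * v + g, w = w - lr * v), assuming the usual
# ``import mlx.core as mx``. Values are made up.
import mlx.core as mx

_w = mx.array([1.0, -2.0])
_g = mx.array([0.5, 0.5])
_v = mx.zeros_like(_w)
_mu, _lr = 0.9, 0.1
_v = _mu * _v + _g        # dampening tau = 0, so the full gradient is accumulated
_w = _w - _lr * _v        # plain (non-Nesterov) update

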
class RMSprop(Optimizer):
    r"""The RMSprop optimizer [1].

    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning

    .. math::

        v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        alpha (float, optional): The smoothing constant :math:`\alpha`.
            Default: ``0.99``
        eps (float, optional): The term :math:`\epsilon` added to the denominator
            to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        alpha: float = 0.99,
        eps: float = 1e-8,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.alpha = alpha
        self.eps = eps

        if self.alpha < 0.0:
            raise ValueError(
                f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
            )
        if self.eps < 0.0:
            raise ValueError(
                f"RMSprop epsilon should be >0, {self.eps} was provided instead"
            )

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["v"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
        lr = self.learning_rate.astype(gradient.dtype)
        alpha = self.alpha
        eps = self.eps

        v = state["v"]
        v = alpha * v + (1 - alpha) * mx.square(gradient)
        state["v"] = v

        return parameter - lr * gradient / (mx.sqrt(v) + eps)


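# Illustrative sketch (not part of the upstream file): one RMSprop step following the rule
# above, v = alpha * v + (1 - alpha) * g^2 and w = w - lr * g / (sqrt(v) + eps). Assumes
# ``import mlx.core as mx``; values are made up.
import mlx.core as mx

_w, _g = mx.array([1.0, 2.0]), mx.array([0.1, -0.3])
_v = mx.zeros_like(_w)
_alpha, _lr, _eps = 0.99, 1e-3, 1e-8
_v = _alpha * _v + (1 - _alpha) * mx.square(_g)
_w = _w - _lr * _g / (mx.sqrt(_v) + _eps)

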
class Adagrad(Optimizer):
    r"""The Adagrad optimizer [1].

    Our Adagrad implementation follows the original paper. In detail,

    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
    for online learning and stochastic optimization. JMLR 2011.

    .. math::

        v_{t+1} &= v_t + g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        eps (float, optional): The term :math:`\epsilon` added to the
            denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        eps: float = 1e-8,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.eps = eps

        if self.eps < 0.0:
            raise ValueError(
                f"Adagrad epsilon should be >0, {self.eps} was provided instead"
            )

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["v"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Adagrad parameter update and stores :math:`v` in the
        optimizer state."""
        lr = self.learning_rate.astype(gradient.dtype)
        eps = self.eps

        v = state["v"] + mx.square(gradient)
        state["v"] = v

        return parameter - lr * gradient / (mx.sqrt(v) + eps)


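# Illustrative sketch (not part of the upstream file): Adagrad accumulates squared
# gradients, so the effective step size along a coordinate only ever shrinks. Assumes
# ``import mlx.core as mx``; values are made up.
import mlx.core as mx

_w, _v = mx.array([0.5, 0.5]), mx.zeros((2,))
_lr, _eps = 1e-2, 1e-8
for _g in (mx.array([1.0, 0.0]), mx.array([1.0, 0.0])):
    _v = _v + mx.square(_g)
    _w = _w - _lr * _g / (mx.sqrt(_v) + _eps)
# the second step along the first coordinate is smaller than the first (~0.007 vs 0.01)

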
class AdaDelta(Optimizer):
    r"""The AdaDelta optimizer with a learning rate [1].

    Our AdaDelta implementation follows the original paper. In detail,

    [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.

    .. math::

        v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
        \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
        u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
        w_{t+1} &= w_t - \lambda \Delta w_{t+1}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        rho (float, optional): The coefficient :math:`\rho` used for computing a
            running average of squared gradients. Default: ``0.9``
        eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
            numerical stability. Default: ``1e-6``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        rho: float = 0.9,
        eps: float = 1e-6,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.rho = rho
        self.eps = eps
        if self.rho < 0.0:
            raise ValueError(
                f"AdaDelta rho should be >=0, {self.rho} was provided instead"
            )
        if self.eps < 0.0:
            raise ValueError(
                f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
            )

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["v"] = mx.zeros_like(parameter)
        state["u"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the AdaDelta parameter update and stores :math:`v` and
        :math:`u` in the optimizer state."""
        lr = self.learning_rate.astype(gradient.dtype)
        rho = self.rho
        eps = self.eps

        v = state["v"]
        u = state["u"]

        v = rho * v + (1 - rho) * mx.square(gradient)
        d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
        u = rho * u + (1 - rho) * mx.square(d)

        state["v"] = v
        state["u"] = u

        return parameter - lr * d


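# Illustrative sketch (not part of the upstream file): one AdaDelta step; the ratio
# sqrt(u + eps) / sqrt(v + eps) rescales the gradient before the learning rate is applied.
# Assumes ``import mlx.core as mx``; values are made up.
import mlx.core as mx

_w, _g = mx.array([1.0]), mx.array([0.2])
_v, _u = mx.zeros_like(_w), mx.zeros_like(_w)
_rho, _eps, _lr = 0.9, 1e-6, 1.0
_v = _rho * _v + (1 - _rho) * mx.square(_g)
_d = mx.sqrt(_u + _eps) / mx.sqrt(_v + _eps) * _g
_u = _rho * _u + (1 - _rho) * mx.square(_d)
_w = _w - _lr * _d

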
class Adam(Optimizer):
    r"""The Adam optimizer [1]. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing running averages of the
            gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
            denominator to improve numerical stability. Default: ``1e-8``
        bias_correction (bool, optional): If set to ``True``, bias correction
            is applied. Default: ``False``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        betas: List[float] = [0.9, 0.999],
        eps: float = 1e-8,
        bias_correction: bool = False,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.betas = betas
        self.eps = eps
        self.bias_correction = bias_correction

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["m"] = mx.zeros_like(parameter)
        state["v"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Adam parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate.astype(gradient.dtype)
        b1, b2 = self.betas
        eps = self.eps
        bias_correction = self.bias_correction
        step = self.step

        m = state["m"]
        v = state["v"]
        m = b1 * m + (1 - b1) * gradient
        v = b2 * v + (1 - b2) * mx.square(gradient)
        state["m"] = m
        state["v"] = v

        if bias_correction:
            c1 = (lr / (1 - b1**step)).astype(gradient.dtype)
            c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)
            numerator = c1 * m
            denominator = mx.sqrt(v) * c2 + eps
            return parameter - numerator / denominator
        else:
            return parameter - lr * m / (mx.sqrt(v) + eps)


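# Illustrative sketch (not part of the upstream file): one Adam step without bias
# correction, matching the default ``bias_correction=False`` path above. Assumes
# ``import mlx.core as mx``; values are made up.
import mlx.core as mx

_w, _g = mx.array([1.0, -1.0]), mx.array([0.3, 0.1])
_m, _v = mx.zeros_like(_w), mx.zeros_like(_w)
_b1, _b2, _lr, _eps = 0.9, 0.999, 1e-3, 1e-8
_m = _b1 * _m + (1 - _b1) * _g
_v = _b2 * _v + (1 - _b2) * mx.square(_g)
_w = _w - _lr * _m / (mx.sqrt(_v) + _eps)

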
class AdamW(Adam):
    r"""The AdamW optimizer [1]. We update the weights with a weight_decay
    (:math:`\lambda`) value:

    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
    regularization. ICLR 2019.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)

    Args:
        learning_rate (float or callable): The learning rate :math:`\alpha`.
        betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing running averages of the
            gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
            denominator to improve numerical stability. Default: ``1e-8``
        weight_decay (float, optional): The weight decay :math:`\lambda`.
            Default: ``0.01``.
        bias_correction (bool, optional): If set to ``True``, bias correction
            is applied. Default: ``False``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        betas: List[float] = [0.9, 0.999],
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        bias_correction: bool = False,
    ):
        super().__init__(
            learning_rate=learning_rate,
            betas=betas,
            eps=eps,
            bias_correction=bias_correction,
        )
        self.weight_decay = weight_decay

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the AdamW parameter update by modifying the parameters
        passed into Adam.
        """

        lr = self.learning_rate.astype(gradient.dtype)
        return super().apply_single(
            gradient, parameter * (1 - lr * self.weight_decay), state
        )


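# Illustrative sketch (not part of the upstream file): AdamW's decoupled weight decay is
# applied to the parameter before the Adam step, i.e. w <- w * (1 - lr * wd), which is what
# ``apply_single`` above does before delegating to ``Adam.apply_single``. Assumes
# ``import mlx.core as mx``; values are made up.
import mlx.core as mx

_w = mx.array([1.0, 2.0])
_lr, _wd = 1e-3, 0.01
_w_decayed = _w * (1 - _lr * _wd)   # the plain Adam update is then applied to _w_decayed

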
class Adamax(Adam):
    r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].

    Our Adamax implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}

    Args:
        learning_rate (float or callable): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing running averages of the
            gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
            denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        betas: List[float] = [0.9, 0.999],
        eps: float = 1e-8,
    ):
        super().__init__(learning_rate, betas, eps)
        if not 0.0 <= eps:
            raise ValueError(
                f"Epsilon value should be >=0, {self.eps} was provided instead"
            )

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["m"] = mx.zeros_like(parameter)
        state["v"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Adamax parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate.astype(gradient.dtype)
        b1, b2 = self.betas
        eps = self.eps

        m = state["m"]
        v = state["v"]

        m = b1 * m + (1 - b1) * gradient
        v = mx.maximum(b2 * v, mx.abs(gradient))
        state["m"] = m
        state["v"] = v

        return parameter - lr * m / (v + eps)


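# Illustrative sketch (not part of the upstream file): the Adamax second moment is an
# infinity-norm style running maximum rather than an average of squared gradients. Assumes
# ``import mlx.core as mx``; values are made up.
import mlx.core as mx

_g = mx.array([0.3, -0.7])
_v = mx.array([0.5, 0.1])
_b2 = 0.999
_v = mx.maximum(_b2 * _v, mx.abs(_g))   # elementwise maximum, per the rule above

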
class Lion(Optimizer):
    r"""The Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.
    We recommend a learning rate that is 3-10x smaller than AdamW and a
    weight decay 3-10x larger than AdamW to maintain the effective
    regularization strength (lr * wd). Our Lion implementation follows the
    original paper. In detail,

    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
    preprint arXiv:2302.06675.

    .. math::

        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
        w_{t + 1} &= w_t - \eta (\text{sign}(c_{t + 1}) + \lambda w_t)

    Args:
        learning_rate (float or callable): The learning rate :math:`\eta`.
        betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing the gradient
            momentum and update direction. Default: ``(0.9, 0.99)``
        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        betas: List[float] = [0.9, 0.99],
        weight_decay: float = 0.0,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.betas = betas
        self.weight_decay = weight_decay

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["m"] = mx.zeros_like(parameter)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Lion parameter update and stores :math:`m`
        in the optimizer state."""
        lr = self.learning_rate.astype(gradient.dtype)
        b1, b2 = self.betas
        weight_decay = self.weight_decay

        m = state["m"]
        c = b1 * m + (1 - b1) * gradient
        state["m"] = b2 * m + (1 - b2) * gradient
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)


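# Illustrative sketch (not part of the upstream file): Lion moves every coordinate by the
# sign of an interpolated momentum, so each step has magnitude lr per coordinate (plus any
# weight decay). Assumes ``import mlx.core as mx``; values are made up.
import mlx.core as mx

_w, _g = mx.array([1.0, -2.0]), mx.array([0.05, -0.4])
_m = mx.zeros_like(_w)
_b1, _b2, _lr = 0.9, 0.99, 1e-4
_c = _b1 * _m + (1 - _b1) * _g        # update direction
_m = _b2 * _m + (1 - _b2) * _g        # retained momentum
_w = _w - _lr * mx.sign(_c)           # weight decay omitted (lambda = 0)

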
class Adafactor(Optimizer):
    r"""The Adafactor optimizer.

    Our Adafactor implementation follows the original paper: `Adafactor:
    Adaptive Learning Rates with Sublinear Memory Cost
    <https://arxiv.org/abs/1804.04235>`_

    Args:
        learning_rate (float or callable, optional): The learning rate.
            Default: ``None``.
        eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
            added to the square of the gradients to improve numerical
            stability and the second term :math:`\epsilon_2` is used for
            parameter scaling if ``scale_parameter`` is set to ``True``.
            Default: ``(1e-30, 1e-3)``.
        clip_threshold (float, optional): Clips the unscaled update at
            ``clip_threshold``. Default: ``1.0``.
        decay_rate (float, optional): Coefficient for the running average
            of the squared gradient. Default: ``-0.8``.
        beta_1 (float, optional): If set to a value bigger than zero
            then first moment will be used. Default: ``None``.
        weight_decay (float, optional): The weight decay :math:`\lambda`.
            Default: ``0.0``.
        scale_parameter (bool, optional): If set to ``True`` the learning rate
            will be scaled by :math:`\max(\epsilon_2, \text{RMS}(w_{t-1}))`.
            Default: ``True``.
        relative_step (bool, optional): If set to ``True`` the ``learning_rate``
            will be ignored and relative step size will be computed.
            Default: ``True``.
        warmup_init (bool, optional): If set to ``True`` then the relative
            step size will be calculated by the current step. Default:
            ``False``.
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array], None] = None,
        eps: Tuple[float, float] = (1e-30, 1e-3),
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta_1: Optional[float] = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step: bool = True,
        warmup_init: bool = False,
    ):
        super().__init__()
        if learning_rate is not None:
            self._maybe_schedule("learning_rate", learning_rate)
        self.eps = eps
        self.clip_threshold = clip_threshold
        self.decay_rate = decay_rate
        self.beta_1 = beta_1
        self.weight_decay = weight_decay
        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.warmup_init = warmup_init

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        if parameter.ndim >= 2:
            shape = parameter.shape
            dtype = parameter.dtype
            state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype)
            state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype)
        else:
            state["exp_avg_sq"] = mx.zeros_like(parameter)

        if self.beta_1 is not None:
            state["exp_avg"] = mx.zeros_like(parameter)

    def _compute_rms(self, inputs):
        return mx.sqrt(mx.mean(mx.square(inputs)))

    def _compute_learning_rate(self, step, parameter_rms):
        if self.relative_step:
            min_step = 1e-6 * step if self.warmup_init else 1e-2
            relative_step_size = mx.minimum(min_step, mx.rsqrt(step))
        else:
            relative_step_size = self.learning_rate

        relative_step_size = relative_step_size.astype(parameter_rms.dtype)
        parameter_scale = 1.0
        if self.scale_parameter:
            parameter_scale = mx.maximum(self.eps[1], parameter_rms)
        return parameter_scale * relative_step_size

    def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = mx.rsqrt(
            exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)
        )
        c_factor = mx.rsqrt(exp_avg_sq_col)
        return mx.matmul(
            mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
        )

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Adafactor parameter and state update."""
        factored = gradient.ndim >= 2

        step = self.step
        use_first_moment = self.beta_1 is not None

        parameter_rms = self._compute_rms(parameter)
        learning_rate = self._compute_learning_rate(step, parameter_rms)
        beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype)
        update = mx.square(gradient) + self.eps[0]

        if factored:
            exp_avg_sq_row = state["exp_avg_sq_row"]
            exp_avg_sq_col = state["exp_avg_sq_col"]
            exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
                (1 - beta_2) * mx.mean(update, axis=-1)
            )
            exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (
                (1 - beta_2) * mx.mean(update, axis=-2)
            )
            state["exp_avg_sq_row"] = exp_avg_sq_row
            state["exp_avg_sq_col"] = exp_avg_sq_col
            update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
            update = update * gradient
        else:
            exp_avg_sq = state["exp_avg_sq"]
            exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
            state["exp_avg_sq"] = exp_avg_sq
            update = mx.rsqrt(exp_avg_sq) * gradient

        update = update / mx.maximum(
            1.0, self._compute_rms(update) / self.clip_threshold
        )
        update = learning_rate * update

        if use_first_moment:
            exp_avg = state["exp_avg"]
            exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
            state["exp_avg"] = exp_avg
            update = exp_avg

        if self.weight_decay != 0:
            parameter += parameter * (-self.weight_decay * learning_rate)
        return parameter - update


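# Illustrative sketch (not part of the upstream file): the factored second moment keeps one
# running average per row and one per column and rebuilds a full scaling matrix from their
# outer product, mirroring ``_approximate_exp_moving_avg`` above. Assumes
# ``import mlx.core as mx``; values are made up.
import mlx.core as mx

_sq = mx.square(mx.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])) + 1e-30
_row = mx.mean(_sq, axis=-1)   # shape (2,), one entry per row
_col = mx.mean(_sq, axis=-2)   # shape (3,), one entry per column
_r = mx.rsqrt(_row / mx.mean(_row, axis=-1, keepdims=True))
_c = mx.rsqrt(_col)
_scale = mx.matmul(mx.expand_dims(_r, axis=-1), mx.expand_dims(_c, axis=0))  # shape (2, 3)

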
class Muon(Optimizer):
    r"""The Muon optimizer.

    Our Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer follows the
    original implementation: `Muon: An optimizer for hidden layers in neural
    networks <https://kellerjordan.github.io/posts/muon/>`_

    Note:
        - Muon may be sub-optimal for the embedding layer, the final fully
          connected layer, or any 0D/1D parameters. Those should be optimized
          by a different method (e.g., :class:`AdamW`).
        - For 4D convolutional filters, it works by flattening their last
          dimensions.

    Args:
        learning_rate (float or callable): The learning rate.
        momentum (float, optional): The momentum strength. Default: ``0.95``
        weight_decay (float, optional): The weight decay (L2 penalty).
            Default: ``0.01``
        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
            better performance. Default: ``True``
        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
            orthogonalization. Default: ``5``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        momentum: float = 0.95,
        weight_decay: float = 0.01,
        nesterov: bool = True,
        ns_steps: int = 5,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.nesterov = nesterov
        self.ns_steps = ns_steps

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        state["v"] = mx.zeros_like(parameter)

    def _zeropower_via_newtonschulz5(self, X, steps: int):
        assert (
            X.ndim == 2
        ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
        a, b, c = (3.4445, -4.7750, 2.0315)
        transpose_needed = X.shape[-2] > X.shape[-1]

        if transpose_needed:
            X = X.T

        X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)

        for _ in range(steps):
            A = X @ X.T
            B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
            X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)

        if transpose_needed:
            X = X.T
        return X

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Muon parameter update"""

        if self.weight_decay != 0:
            gradient = gradient + self.weight_decay * parameter

        v = self.momentum * state["v"]
        v = v + (1 - self.momentum) * gradient
        state["v"] = v

        if self.nesterov:
            update = gradient * (1 - self.momentum) + v * self.momentum
        else:
            update = v

        lr = self.learning_rate.astype(gradient.dtype)

        if update.ndim >= 2:
            original_shape = update.shape
            reshape_needed = update.ndim > 2

            if reshape_needed:
                update = mx.reshape(update, (update.shape[0], -1))

            update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)

            if reshape_needed:
                update = mx.reshape(update, original_shape)

            lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5

        return parameter - lr * update


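# Illustrative sketch (not part of the upstream file): the Newton-Schulz iteration above
# approximately orthogonalizes the momentum-averaged update, so after a few steps
# X @ X.T is close to the identity. Written with plain ops equivalent to the mx.addmm
# calls above; assumes ``import mlx.core as mx`` and a made-up 2x2 input.
import mlx.core as mx

_X = mx.array([[2.0, 0.0], [1.0, 1.0]])
_X = _X / (mx.linalg.norm(_X, keepdims=True) + 1e-7)
_a, _b, _c = 3.4445, -4.7750, 2.0315
for _ in range(5):
    _A = _X @ _X.T
    _B = _b * _A + _c * (_A @ _A)   # same quantity as mx.addmm(b * A, A, A, alpha=c)
    _X = _a * _X + _B @ _X          # same quantity as mx.addmm(a * X, B, X)
# _X @ _X.T is now approximately the 2x2 identity

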
def clip_grad_norm(grads, max_norm):
    """Clips the global norm of the gradients.

    This function ensures that the global norm of the gradients does not exceed
    ``max_norm``. It scales down the gradients proportionally if their norm is
    greater than ``max_norm``.

    Example:
        >>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
        >>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
        >>> print(clipped_grads)
        {"w1": mx.array([...]), "w2": mx.array([...])}

    Args:
        grads (dict): A dictionary containing the gradient arrays.
        max_norm (float): The maximum allowed global norm of the gradients.

    Returns:
        (dict, float): The possibly rescaled gradients and the original
        gradient norm.
    """
    norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
    total_norm = mx.sqrt(norm_squared)
    normalizer = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
    clipped_grads = tree_map(lambda g: g * normalizer, grads)
    return clipped_grads, total_norm
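

# Illustrative sketch (not part of the upstream file): the same global-norm clipping
# written out for a flat list of gradients, without the tree utilities. Assumes
# ``import mlx.core as mx``; values are made up and max_norm is taken to be 2.0.
import mlx.core as mx

_grads = [mx.array([2.0, 3.0]), mx.array([1.0])]
_total = mx.sqrt(sum((g * g).sum() for g in _grads))
_scale = mx.minimum(2.0 / (_total + 1e-6), 1.0)
_clipped = [g * _scale for g in _grads]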