mlx 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlx might be problematic. Click here for more details.
- checksums.yaml +7 -0
- data/ext/mlx/CMakeLists.txt +7 -0
- data/ext/mlx/Makefile +273 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/mkmf.log +44 -0
- data/ext/mlx/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
- data/ext/mlx/native.cpp +8027 -0
- data/ext/mlx/native.o +0 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version +1 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/.clang-format +87 -0
- data/mlx/.git +1 -0
- data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
- data/mlx/.github/actions/build-docs/action.yml +38 -0
- data/mlx/.github/actions/build-linux/action.yml +38 -0
- data/mlx/.github/actions/build-linux-release/action.yml +42 -0
- data/mlx/.github/actions/build-macos/action.yml +80 -0
- data/mlx/.github/actions/build-macos-release/action.yml +36 -0
- data/mlx/.github/actions/build-windows/action.yml +26 -0
- data/mlx/.github/actions/setup-linux/action.yml +93 -0
- data/mlx/.github/actions/setup-macos/action.yml +24 -0
- data/mlx/.github/actions/setup-windows/action.yml +42 -0
- data/mlx/.github/actions/test-linux/action.yml +69 -0
- data/mlx/.github/actions/test-windows/action.yml +20 -0
- data/mlx/.github/dependabot.yml +6 -0
- data/mlx/.github/pull_request_template.md +12 -0
- data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
- data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
- data/mlx/.github/workflows/build_and_test.yml +152 -0
- data/mlx/.github/workflows/documentation.yml +28 -0
- data/mlx/.github/workflows/nightly.yml +104 -0
- data/mlx/.github/workflows/release.yml +256 -0
- data/mlx/.gitignore +81 -0
- data/mlx/.pre-commit-config.yaml +27 -0
- data/mlx/ACKNOWLEDGMENTS.md +268 -0
- data/mlx/CITATION.cff +24 -0
- data/mlx/CMakeLists.txt +437 -0
- data/mlx/CODE_OF_CONDUCT.md +132 -0
- data/mlx/CONTRIBUTING.md +38 -0
- data/mlx/LICENSE +21 -0
- data/mlx/MANIFEST.in +6 -0
- data/mlx/README.md +121 -0
- data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- data/mlx/benchmarks/cpp/autograd.cpp +39 -0
- data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- data/mlx/benchmarks/cpp/time_utils.h +39 -0
- data/mlx/benchmarks/numpy/single_ops.py +39 -0
- data/mlx/benchmarks/numpy/time_utils.py +20 -0
- data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
- data/mlx/benchmarks/python/comparative/README.md +15 -0
- data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- data/mlx/benchmarks/python/comparative/compare.py +284 -0
- data/mlx/benchmarks/python/compile_bench.py +107 -0
- data/mlx/benchmarks/python/conv1d_bench.py +123 -0
- data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- data/mlx/benchmarks/python/conv_bench.py +135 -0
- data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- data/mlx/benchmarks/python/distributed_bench.py +66 -0
- data/mlx/benchmarks/python/einsum_bench.py +84 -0
- data/mlx/benchmarks/python/fft_bench.py +118 -0
- data/mlx/benchmarks/python/gather_bench.py +52 -0
- data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- data/mlx/benchmarks/python/hadamard_bench.py +70 -0
- data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
- data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
- data/mlx/benchmarks/python/masked_scatter.py +212 -0
- data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
- data/mlx/benchmarks/python/rope_bench.py +35 -0
- data/mlx/benchmarks/python/scatter_bench.py +96 -0
- data/mlx/benchmarks/python/sdpa_bench.py +223 -0
- data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
- data/mlx/benchmarks/python/single_ops.py +132 -0
- data/mlx/benchmarks/python/synchronize_bench.py +55 -0
- data/mlx/benchmarks/python/time_utils.py +38 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/docs/.clang-format +2 -0
- data/mlx/docs/.gitignore +3 -0
- data/mlx/docs/.nojekyll +0 -0
- data/mlx/docs/Doxyfile +51 -0
- data/mlx/docs/Makefile +18 -0
- data/mlx/docs/README.md +54 -0
- data/mlx/docs/index.html +1 -0
- data/mlx/docs/requirements.txt +5 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
- data/mlx/docs/src/_static/mlx_logo.png +0 -0
- data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
- data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
- data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
- data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
- data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
- data/mlx/docs/src/_templates/module-base-class.rst +33 -0
- data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
- data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
- data/mlx/docs/src/conf.py +99 -0
- data/mlx/docs/src/cpp/ops.rst +7 -0
- data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
- data/mlx/docs/src/dev/extensions.rst +811 -0
- data/mlx/docs/src/dev/metal_debugger.rst +68 -0
- data/mlx/docs/src/dev/metal_logging.rst +40 -0
- data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
- data/mlx/docs/src/examples/data_parallelism.rst +91 -0
- data/mlx/docs/src/examples/linear_regression.rst +77 -0
- data/mlx/docs/src/examples/llama-inference.rst +382 -0
- data/mlx/docs/src/examples/mlp.rst +134 -0
- data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
- data/mlx/docs/src/index.rst +96 -0
- data/mlx/docs/src/install.rst +340 -0
- data/mlx/docs/src/python/array.rst +65 -0
- data/mlx/docs/src/python/cuda.rst +9 -0
- data/mlx/docs/src/python/data_types.rst +78 -0
- data/mlx/docs/src/python/devices_and_streams.rst +21 -0
- data/mlx/docs/src/python/distributed.rst +22 -0
- data/mlx/docs/src/python/export.rst +14 -0
- data/mlx/docs/src/python/fast.rst +16 -0
- data/mlx/docs/src/python/fft.rst +24 -0
- data/mlx/docs/src/python/linalg.rst +27 -0
- data/mlx/docs/src/python/memory_management.rst +16 -0
- data/mlx/docs/src/python/metal.rst +12 -0
- data/mlx/docs/src/python/nn/distributed.rst +30 -0
- data/mlx/docs/src/python/nn/functions.rst +40 -0
- data/mlx/docs/src/python/nn/init.rst +45 -0
- data/mlx/docs/src/python/nn/layers.rst +74 -0
- data/mlx/docs/src/python/nn/losses.rst +25 -0
- data/mlx/docs/src/python/nn/module.rst +38 -0
- data/mlx/docs/src/python/nn.rst +186 -0
- data/mlx/docs/src/python/ops.rst +184 -0
- data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
- data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
- data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
- data/mlx/docs/src/python/optimizers.rst +78 -0
- data/mlx/docs/src/python/random.rst +48 -0
- data/mlx/docs/src/python/transforms.rst +22 -0
- data/mlx/docs/src/python/tree_utils.rst +23 -0
- data/mlx/docs/src/usage/compile.rst +516 -0
- data/mlx/docs/src/usage/distributed.rst +572 -0
- data/mlx/docs/src/usage/export.rst +288 -0
- data/mlx/docs/src/usage/function_transforms.rst +191 -0
- data/mlx/docs/src/usage/indexing.rst +194 -0
- data/mlx/docs/src/usage/launching_distributed.rst +234 -0
- data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
- data/mlx/docs/src/usage/numpy.rst +124 -0
- data/mlx/docs/src/usage/quick_start.rst +67 -0
- data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
- data/mlx/docs/src/usage/unified_memory.rst +78 -0
- data/mlx/docs/src/usage/using_streams.rst +18 -0
- data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
- data/mlx/examples/cmake_project/README.md +26 -0
- data/mlx/examples/cmake_project/example.cpp +14 -0
- data/mlx/examples/cpp/CMakeLists.txt +12 -0
- data/mlx/examples/cpp/distributed.cpp +22 -0
- data/mlx/examples/cpp/linear_regression.cpp +54 -0
- data/mlx/examples/cpp/logistic_regression.cpp +54 -0
- data/mlx/examples/cpp/metal_capture.cpp +31 -0
- data/mlx/examples/cpp/timer.h +20 -0
- data/mlx/examples/cpp/tutorial.cpp +99 -0
- data/mlx/examples/export/CMakeLists.txt +22 -0
- data/mlx/examples/export/README.md +49 -0
- data/mlx/examples/export/eval_mlp.cpp +25 -0
- data/mlx/examples/export/eval_mlp.py +52 -0
- data/mlx/examples/export/train_mlp.cpp +35 -0
- data/mlx/examples/export/train_mlp.py +76 -0
- data/mlx/examples/extensions/CMakeLists.txt +78 -0
- data/mlx/examples/extensions/README.md +24 -0
- data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
- data/mlx/examples/extensions/axpby/axpby.h +90 -0
- data/mlx/examples/extensions/axpby/axpby.metal +47 -0
- data/mlx/examples/extensions/bindings.cpp +39 -0
- data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
- data/mlx/examples/extensions/pyproject.toml +8 -0
- data/mlx/examples/extensions/requirements.txt +4 -0
- data/mlx/examples/extensions/setup.py +18 -0
- data/mlx/examples/extensions/test.py +12 -0
- data/mlx/examples/python/linear_regression.py +46 -0
- data/mlx/examples/python/logistic_regression.py +49 -0
- data/mlx/examples/python/qqmm.py +117 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- data/mlx/pyproject.toml +7 -0
- data/mlx/python/mlx/__main__.py +27 -0
- data/mlx/python/mlx/_distributed_utils/common.py +135 -0
- data/mlx/python/mlx/_distributed_utils/config.py +631 -0
- data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
- data/mlx/python/mlx/_reprlib_fix.py +16 -0
- data/mlx/python/mlx/_stub_patterns.txt +36 -0
- data/mlx/python/mlx/extension.py +88 -0
- data/mlx/python/mlx/nn/__init__.py +5 -0
- data/mlx/python/mlx/nn/init.py +441 -0
- data/mlx/python/mlx/nn/layers/__init__.py +105 -0
- data/mlx/python/mlx/nn/layers/activations.py +661 -0
- data/mlx/python/mlx/nn/layers/base.py +675 -0
- data/mlx/python/mlx/nn/layers/containers.py +24 -0
- data/mlx/python/mlx/nn/layers/convolution.py +232 -0
- data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
- data/mlx/python/mlx/nn/layers/distributed.py +601 -0
- data/mlx/python/mlx/nn/layers/dropout.py +137 -0
- data/mlx/python/mlx/nn/layers/embedding.py +53 -0
- data/mlx/python/mlx/nn/layers/linear.py +180 -0
- data/mlx/python/mlx/nn/layers/normalization.py +363 -0
- data/mlx/python/mlx/nn/layers/pooling.py +398 -0
- data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
- data/mlx/python/mlx/nn/layers/quantized.py +426 -0
- data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
- data/mlx/python/mlx/nn/layers/transformer.py +354 -0
- data/mlx/python/mlx/nn/layers/upsample.py +277 -0
- data/mlx/python/mlx/nn/losses.py +610 -0
- data/mlx/python/mlx/nn/utils.py +165 -0
- data/mlx/python/mlx/optimizers/__init__.py +4 -0
- data/mlx/python/mlx/optimizers/optimizers.py +976 -0
- data/mlx/python/mlx/optimizers/schedulers.py +158 -0
- data/mlx/python/mlx/py.typed +1 -0
- data/mlx/python/mlx/utils.py +325 -0
- data/mlx/python/src/CMakeLists.txt +96 -0
- data/mlx/python/src/array.cpp +1525 -0
- data/mlx/python/src/buffer.h +124 -0
- data/mlx/python/src/constants.cpp +15 -0
- data/mlx/python/src/convert.cpp +504 -0
- data/mlx/python/src/convert.h +50 -0
- data/mlx/python/src/cuda.cpp +19 -0
- data/mlx/python/src/device.cpp +98 -0
- data/mlx/python/src/distributed.cpp +352 -0
- data/mlx/python/src/export.cpp +356 -0
- data/mlx/python/src/fast.cpp +627 -0
- data/mlx/python/src/fft.cpp +514 -0
- data/mlx/python/src/indexing.cpp +1016 -0
- data/mlx/python/src/indexing.h +41 -0
- data/mlx/python/src/linalg.cpp +663 -0
- data/mlx/python/src/load.cpp +531 -0
- data/mlx/python/src/load.h +51 -0
- data/mlx/python/src/memory.cpp +125 -0
- data/mlx/python/src/metal.cpp +98 -0
- data/mlx/python/src/mlx.cpp +51 -0
- data/mlx/python/src/mlx_func.cpp +116 -0
- data/mlx/python/src/mlx_func.h +31 -0
- data/mlx/python/src/ops.cpp +5545 -0
- data/mlx/python/src/random.cpp +516 -0
- data/mlx/python/src/small_vector.h +76 -0
- data/mlx/python/src/stream.cpp +147 -0
- data/mlx/python/src/transforms.cpp +1542 -0
- data/mlx/python/src/trees.cpp +311 -0
- data/mlx/python/src/trees.h +62 -0
- data/mlx/python/src/utils.cpp +98 -0
- data/mlx/python/src/utils.h +78 -0
- data/mlx/python/tests/__main__.py +5 -0
- data/mlx/python/tests/cuda_skip.py +62 -0
- data/mlx/python/tests/mlx_distributed_tests.py +314 -0
- data/mlx/python/tests/mlx_tests.py +116 -0
- data/mlx/python/tests/mpi_test_distributed.py +142 -0
- data/mlx/python/tests/nccl_test_distributed.py +52 -0
- data/mlx/python/tests/ring_test_distributed.py +131 -0
- data/mlx/python/tests/test_array.py +2139 -0
- data/mlx/python/tests/test_autograd.py +880 -0
- data/mlx/python/tests/test_bf16.py +196 -0
- data/mlx/python/tests/test_blas.py +1429 -0
- data/mlx/python/tests/test_compile.py +1277 -0
- data/mlx/python/tests/test_constants.py +41 -0
- data/mlx/python/tests/test_conv.py +1198 -0
- data/mlx/python/tests/test_conv_transpose.py +810 -0
- data/mlx/python/tests/test_device.py +150 -0
- data/mlx/python/tests/test_double.py +306 -0
- data/mlx/python/tests/test_einsum.py +363 -0
- data/mlx/python/tests/test_eval.py +200 -0
- data/mlx/python/tests/test_export_import.py +614 -0
- data/mlx/python/tests/test_fast.py +923 -0
- data/mlx/python/tests/test_fast_sdpa.py +647 -0
- data/mlx/python/tests/test_fft.py +323 -0
- data/mlx/python/tests/test_graph.py +37 -0
- data/mlx/python/tests/test_init.py +139 -0
- data/mlx/python/tests/test_linalg.py +621 -0
- data/mlx/python/tests/test_load.py +447 -0
- data/mlx/python/tests/test_losses.py +427 -0
- data/mlx/python/tests/test_memory.py +77 -0
- data/mlx/python/tests/test_nn.py +1986 -0
- data/mlx/python/tests/test_ops.py +3261 -0
- data/mlx/python/tests/test_optimizers.py +584 -0
- data/mlx/python/tests/test_quantized.py +1160 -0
- data/mlx/python/tests/test_random.py +392 -0
- data/mlx/python/tests/test_reduce.py +223 -0
- data/mlx/python/tests/test_tree.py +96 -0
- data/mlx/python/tests/test_upsample.py +100 -0
- data/mlx/python/tests/test_vmap.py +860 -0
- data/mlx/setup.py +315 -0
- data/mlx/tests/CMakeLists.txt +44 -0
- data/mlx/tests/allocator_tests.cpp +41 -0
- data/mlx/tests/arg_reduce_tests.cpp +204 -0
- data/mlx/tests/array_tests.cpp +663 -0
- data/mlx/tests/autograd_tests.cpp +1399 -0
- data/mlx/tests/blas_tests.cpp +110 -0
- data/mlx/tests/compile_tests.cpp +818 -0
- data/mlx/tests/creations_tests.cpp +239 -0
- data/mlx/tests/custom_vjp_tests.cpp +55 -0
- data/mlx/tests/device_tests.cpp +35 -0
- data/mlx/tests/einsum_tests.cpp +85 -0
- data/mlx/tests/eval_tests.cpp +93 -0
- data/mlx/tests/export_import_tests.cpp +164 -0
- data/mlx/tests/fft_tests.cpp +366 -0
- data/mlx/tests/gpu_tests.cpp +523 -0
- data/mlx/tests/linalg_tests.cpp +639 -0
- data/mlx/tests/load_tests.cpp +270 -0
- data/mlx/tests/ops_tests.cpp +4159 -0
- data/mlx/tests/random_tests.cpp +716 -0
- data/mlx/tests/scheduler_tests.cpp +121 -0
- data/mlx/tests/tests.cpp +26 -0
- data/mlx/tests/utils_tests.cpp +67 -0
- data/mlx/tests/vmap_tests.cpp +547 -0
- metadata +958 -0
|
@@ -0,0 +1,514 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <nanobind/nanobind.h>
|
|
4
|
+
#include <nanobind/stl/optional.h>
|
|
5
|
+
#include <nanobind/stl/variant.h>
|
|
6
|
+
#include <nanobind/stl/vector.h>
|
|
7
|
+
#include <numeric>
|
|
8
|
+
|
|
9
|
+
#include "mlx/fft.h"
|
|
10
|
+
#include "mlx/ops.h"
|
|
11
|
+
#include "python/src/small_vector.h"
|
|
12
|
+
|
|
13
|
+
namespace mx = mlx::core;
|
|
14
|
+
namespace nb = nanobind;
|
|
15
|
+
using namespace nb::literals;
|
|
16
|
+
|
|
17
|
+
void init_fft(nb::module_& parent_module) {
|
|
18
|
+
auto m = parent_module.def_submodule(
|
|
19
|
+
"fft", "mlx.core.fft: Fast Fourier Transforms.");
|
|
20
|
+
m.def(
|
|
21
|
+
"fft",
|
|
22
|
+
[](const mx::array& a,
|
|
23
|
+
const std::optional<int>& n,
|
|
24
|
+
int axis,
|
|
25
|
+
mx::StreamOrDevice s) {
|
|
26
|
+
if (n.has_value()) {
|
|
27
|
+
return mx::fft::fft(a, n.value(), axis, s);
|
|
28
|
+
} else {
|
|
29
|
+
return mx::fft::fft(a, axis, s);
|
|
30
|
+
}
|
|
31
|
+
},
|
|
32
|
+
"a"_a,
|
|
33
|
+
"n"_a = nb::none(),
|
|
34
|
+
"axis"_a = -1,
|
|
35
|
+
"stream"_a = nb::none(),
|
|
36
|
+
R"pbdoc(
|
|
37
|
+
One dimensional discrete Fourier Transform.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
a (array): The input array.
|
|
41
|
+
n (int, optional): Size of the transformed axis. The
|
|
42
|
+
corresponding axis in the input is truncated or padded with
|
|
43
|
+
zeros to match ``n``. The default value is ``a.shape[axis]``.
|
|
44
|
+
axis (int, optional): Axis along which to perform the FFT. The
|
|
45
|
+
default is ``-1``.
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
array: The DFT of the input along the given axis.
|
|
49
|
+
)pbdoc");
|
|
50
|
+
m.def(
|
|
51
|
+
"ifft",
|
|
52
|
+
[](const mx::array& a,
|
|
53
|
+
const std::optional<int>& n,
|
|
54
|
+
int axis,
|
|
55
|
+
mx::StreamOrDevice s) {
|
|
56
|
+
if (n.has_value()) {
|
|
57
|
+
return mx::fft::ifft(a, n.value(), axis, s);
|
|
58
|
+
} else {
|
|
59
|
+
return mx::fft::ifft(a, axis, s);
|
|
60
|
+
}
|
|
61
|
+
},
|
|
62
|
+
"a"_a,
|
|
63
|
+
"n"_a = nb::none(),
|
|
64
|
+
"axis"_a = -1,
|
|
65
|
+
"stream"_a = nb::none(),
|
|
66
|
+
R"pbdoc(
|
|
67
|
+
One dimensional inverse discrete Fourier Transform.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
a (array): The input array.
|
|
71
|
+
n (int, optional): Size of the transformed axis. The
|
|
72
|
+
corresponding axis in the input is truncated or padded with
|
|
73
|
+
zeros to match ``n``. The default value is ``a.shape[axis]``.
|
|
74
|
+
axis (int, optional): Axis along which to perform the FFT. The
|
|
75
|
+
default is ``-1``.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
array: The inverse DFT of the input along the given axis.
|
|
79
|
+
)pbdoc");
|
|
80
|
+
m.def(
|
|
81
|
+
"fft2",
|
|
82
|
+
[](const mx::array& a,
|
|
83
|
+
const std::optional<mx::Shape>& n,
|
|
84
|
+
const std::optional<std::vector<int>>& axes,
|
|
85
|
+
mx::StreamOrDevice s) {
|
|
86
|
+
if (axes.has_value() && n.has_value()) {
|
|
87
|
+
return mx::fft::fftn(a, n.value(), axes.value(), s);
|
|
88
|
+
} else if (axes.has_value()) {
|
|
89
|
+
return mx::fft::fftn(a, axes.value(), s);
|
|
90
|
+
} else if (n.has_value()) {
|
|
91
|
+
throw std::invalid_argument(
|
|
92
|
+
"[fft2] `axes` should not be `None` if `s` is not `None`.");
|
|
93
|
+
} else {
|
|
94
|
+
return mx::fft::fftn(a, s);
|
|
95
|
+
}
|
|
96
|
+
},
|
|
97
|
+
"a"_a,
|
|
98
|
+
"s"_a = nb::none(),
|
|
99
|
+
"axes"_a.none() = std::vector<int>{-2, -1},
|
|
100
|
+
"stream"_a = nb::none(),
|
|
101
|
+
R"pbdoc(
|
|
102
|
+
Two dimensional discrete Fourier Transform.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
a (array): The input array.
|
|
106
|
+
s (list(int), optional): Sizes of the transformed axes. The
|
|
107
|
+
corresponding axes in the input are truncated or padded with
|
|
108
|
+
zeros to match the sizes in ``s``. The default value is the
|
|
109
|
+
sizes of ``a`` along ``axes``.
|
|
110
|
+
axes (list(int), optional): Axes along which to perform the FFT.
|
|
111
|
+
The default is ``[-2, -1]``.
|
|
112
|
+
|
|
113
|
+
Returns:
|
|
114
|
+
array: The DFT of the input along the given axes.
|
|
115
|
+
)pbdoc");
|
|
116
|
+
m.def(
|
|
117
|
+
"ifft2",
|
|
118
|
+
[](const mx::array& a,
|
|
119
|
+
const std::optional<mx::Shape>& n,
|
|
120
|
+
const std::optional<std::vector<int>>& axes,
|
|
121
|
+
mx::StreamOrDevice s) {
|
|
122
|
+
if (axes.has_value() && n.has_value()) {
|
|
123
|
+
return mx::fft::ifftn(a, n.value(), axes.value(), s);
|
|
124
|
+
} else if (axes.has_value()) {
|
|
125
|
+
return mx::fft::ifftn(a, axes.value(), s);
|
|
126
|
+
} else if (n.has_value()) {
|
|
127
|
+
throw std::invalid_argument(
|
|
128
|
+
"[ifft2] `axes` should not be `None` if `s` is not `None`.");
|
|
129
|
+
} else {
|
|
130
|
+
return mx::fft::ifftn(a, s);
|
|
131
|
+
}
|
|
132
|
+
},
|
|
133
|
+
"a"_a,
|
|
134
|
+
"s"_a = nb::none(),
|
|
135
|
+
"axes"_a.none() = std::vector<int>{-2, -1},
|
|
136
|
+
"stream"_a = nb::none(),
|
|
137
|
+
R"pbdoc(
|
|
138
|
+
Two dimensional inverse discrete Fourier Transform.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
a (array): The input array.
|
|
142
|
+
s (list(int), optional): Sizes of the transformed axes. The
|
|
143
|
+
corresponding axes in the input are truncated or padded with
|
|
144
|
+
zeros to match the sizes in ``s``. The default value is the
|
|
145
|
+
sizes of ``a`` along ``axes``.
|
|
146
|
+
axes (list(int), optional): Axes along which to perform the FFT.
|
|
147
|
+
The default is ``[-2, -1]``.
|
|
148
|
+
|
|
149
|
+
Returns:
|
|
150
|
+
array: The inverse DFT of the input along the given axes.
|
|
151
|
+
)pbdoc");
|
|
152
|
+
m.def(
|
|
153
|
+
"fftn",
|
|
154
|
+
[](const mx::array& a,
|
|
155
|
+
const std::optional<mx::Shape>& n,
|
|
156
|
+
const std::optional<std::vector<int>>& axes,
|
|
157
|
+
mx::StreamOrDevice s) {
|
|
158
|
+
if (axes.has_value() && n.has_value()) {
|
|
159
|
+
return mx::fft::fftn(a, n.value(), axes.value(), s);
|
|
160
|
+
} else if (axes.has_value()) {
|
|
161
|
+
return mx::fft::fftn(a, axes.value(), s);
|
|
162
|
+
} else if (n.has_value()) {
|
|
163
|
+
throw std::invalid_argument(
|
|
164
|
+
"[fftn] `axes` should not be `None` if `s` is not `None`.");
|
|
165
|
+
} else {
|
|
166
|
+
return mx::fft::fftn(a, s);
|
|
167
|
+
}
|
|
168
|
+
},
|
|
169
|
+
"a"_a,
|
|
170
|
+
"s"_a = nb::none(),
|
|
171
|
+
"axes"_a = nb::none(),
|
|
172
|
+
"stream"_a = nb::none(),
|
|
173
|
+
R"pbdoc(
|
|
174
|
+
n-dimensional discrete Fourier Transform.
|
|
175
|
+
|
|
176
|
+
Args:
|
|
177
|
+
a (array): The input array.
|
|
178
|
+
s (list(int), optional): Sizes of the transformed axes. The
|
|
179
|
+
corresponding axes in the input are truncated or padded with
|
|
180
|
+
zeros to match the sizes in ``s``. The default value is the
|
|
181
|
+
sizes of ``a`` along ``axes``.
|
|
182
|
+
axes (list(int), optional): Axes along which to perform the FFT.
|
|
183
|
+
The default is ``None`` in which case the FFT is over the last
|
|
184
|
+
``len(s)`` axes are or all axes if ``s`` is also ``None``.
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
array: The DFT of the input along the given axes.
|
|
188
|
+
)pbdoc");
|
|
189
|
+
m.def(
|
|
190
|
+
"ifftn",
|
|
191
|
+
[](const mx::array& a,
|
|
192
|
+
const std::optional<mx::Shape>& n,
|
|
193
|
+
const std::optional<std::vector<int>>& axes,
|
|
194
|
+
mx::StreamOrDevice s) {
|
|
195
|
+
if (axes.has_value() && n.has_value()) {
|
|
196
|
+
return mx::fft::ifftn(a, n.value(), axes.value(), s);
|
|
197
|
+
} else if (axes.has_value()) {
|
|
198
|
+
return mx::fft::ifftn(a, axes.value(), s);
|
|
199
|
+
} else if (n.has_value()) {
|
|
200
|
+
throw std::invalid_argument(
|
|
201
|
+
"[ifftn] `axes` should not be `None` if `s` is not `None`.");
|
|
202
|
+
} else {
|
|
203
|
+
return mx::fft::ifftn(a, s);
|
|
204
|
+
}
|
|
205
|
+
},
|
|
206
|
+
"a"_a,
|
|
207
|
+
"s"_a = nb::none(),
|
|
208
|
+
"axes"_a = nb::none(),
|
|
209
|
+
"stream"_a = nb::none(),
|
|
210
|
+
R"pbdoc(
|
|
211
|
+
n-dimensional inverse discrete Fourier Transform.
|
|
212
|
+
|
|
213
|
+
Args:
|
|
214
|
+
a (array): The input array.
|
|
215
|
+
s (list(int), optional): Sizes of the transformed axes. The
|
|
216
|
+
corresponding axes in the input are truncated or padded with
|
|
217
|
+
zeros to match the sizes in ``s``. The default value is the
|
|
218
|
+
sizes of ``a`` along ``axes``.
|
|
219
|
+
axes (list(int), optional): Axes along which to perform the FFT.
|
|
220
|
+
The default is ``None`` in which case the FFT is over the last
|
|
221
|
+
``len(s)`` axes or all axes if ``s`` is also ``None``.
|
|
222
|
+
|
|
223
|
+
Returns:
|
|
224
|
+
array: The inverse DFT of the input along the given axes.
|
|
225
|
+
)pbdoc");
|
|
226
|
+
m.def(
|
|
227
|
+
"rfft",
|
|
228
|
+
[](const mx::array& a,
|
|
229
|
+
const std::optional<int>& n,
|
|
230
|
+
int axis,
|
|
231
|
+
mx::StreamOrDevice s) {
|
|
232
|
+
if (n.has_value()) {
|
|
233
|
+
return mx::fft::rfft(a, n.value(), axis, s);
|
|
234
|
+
} else {
|
|
235
|
+
return mx::fft::rfft(a, axis, s);
|
|
236
|
+
}
|
|
237
|
+
},
|
|
238
|
+
"a"_a,
|
|
239
|
+
"n"_a = nb::none(),
|
|
240
|
+
"axis"_a = -1,
|
|
241
|
+
"stream"_a = nb::none(),
|
|
242
|
+
R"pbdoc(
|
|
243
|
+
One dimensional discrete Fourier Transform on a real input.
|
|
244
|
+
|
|
245
|
+
The output has the same shape as the input except along ``axis`` in
|
|
246
|
+
which case it has size ``n // 2 + 1``.
|
|
247
|
+
|
|
248
|
+
Args:
|
|
249
|
+
a (array): The input array. If the array is complex it will be silently
|
|
250
|
+
cast to a real type.
|
|
251
|
+
n (int, optional): Size of the transformed axis. The
|
|
252
|
+
corresponding axis in the input is truncated or padded with
|
|
253
|
+
zeros to match ``n``. The default value is ``a.shape[axis]``.
|
|
254
|
+
axis (int, optional): Axis along which to perform the FFT. The
|
|
255
|
+
default is ``-1``.
|
|
256
|
+
|
|
257
|
+
Returns:
|
|
258
|
+
array: The DFT of the input along the given axis. The output
|
|
259
|
+
data type will be complex.
|
|
260
|
+
)pbdoc");
|
|
261
|
+
m.def(
|
|
262
|
+
"irfft",
|
|
263
|
+
[](const mx::array& a,
|
|
264
|
+
const std::optional<int>& n,
|
|
265
|
+
int axis,
|
|
266
|
+
mx::StreamOrDevice s) {
|
|
267
|
+
if (n.has_value()) {
|
|
268
|
+
return mx::fft::irfft(a, n.value(), axis, s);
|
|
269
|
+
} else {
|
|
270
|
+
return mx::fft::irfft(a, axis, s);
|
|
271
|
+
}
|
|
272
|
+
},
|
|
273
|
+
"a"_a,
|
|
274
|
+
"n"_a = nb::none(),
|
|
275
|
+
"axis"_a = -1,
|
|
276
|
+
"stream"_a = nb::none(),
|
|
277
|
+
R"pbdoc(
|
|
278
|
+
The inverse of :func:`rfft`.
|
|
279
|
+
|
|
280
|
+
The output has the same shape as the input except along ``axis`` in
|
|
281
|
+
which case it has size ``n``.
|
|
282
|
+
|
|
283
|
+
Args:
|
|
284
|
+
a (array): The input array.
|
|
285
|
+
n (int, optional): Size of the transformed axis. The
|
|
286
|
+
corresponding axis in the input is truncated or padded with
|
|
287
|
+
zeros to match ``n // 2 + 1``. The default value is
|
|
288
|
+
``a.shape[axis] // 2 + 1``.
|
|
289
|
+
axis (int, optional): Axis along which to perform the FFT. The
|
|
290
|
+
default is ``-1``.
|
|
291
|
+
|
|
292
|
+
Returns:
|
|
293
|
+
array: The real array containing the inverse of :func:`rfft`.
|
|
294
|
+
)pbdoc");
|
|
295
|
+
m.def(
|
|
296
|
+
"rfft2",
|
|
297
|
+
[](const mx::array& a,
|
|
298
|
+
const std::optional<mx::Shape>& n,
|
|
299
|
+
const std::optional<std::vector<int>>& axes,
|
|
300
|
+
mx::StreamOrDevice s) {
|
|
301
|
+
if (axes.has_value() && n.has_value()) {
|
|
302
|
+
return mx::fft::rfftn(a, n.value(), axes.value(), s);
|
|
303
|
+
} else if (axes.has_value()) {
|
|
304
|
+
return mx::fft::rfftn(a, axes.value(), s);
|
|
305
|
+
} else if (n.has_value()) {
|
|
306
|
+
throw std::invalid_argument(
|
|
307
|
+
"[rfft2] `axes` should not be `None` if `s` is not `None`.");
|
|
308
|
+
} else {
|
|
309
|
+
return mx::fft::rfftn(a, s);
|
|
310
|
+
}
|
|
311
|
+
},
|
|
312
|
+
"a"_a,
|
|
313
|
+
"s"_a = nb::none(),
|
|
314
|
+
"axes"_a.none() = std::vector<int>{-2, -1},
|
|
315
|
+
"stream"_a = nb::none(),
|
|
316
|
+
R"pbdoc(
|
|
317
|
+
Two dimensional real discrete Fourier Transform.
|
|
318
|
+
|
|
319
|
+
The output has the same shape as the input except along the dimensions in
|
|
320
|
+
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
|
|
321
|
+
treated as the real axis and will have size ``s[-1] // 2 + 1``.
|
|
322
|
+
|
|
323
|
+
Args:
|
|
324
|
+
a (array): The input array. If the array is complex it will be silently
|
|
325
|
+
cast to a real type.
|
|
326
|
+
s (list(int), optional): Sizes of the transformed axes. The
|
|
327
|
+
corresponding axes in the input are truncated or padded with
|
|
328
|
+
zeros to match the sizes in ``s``. The default value is the
|
|
329
|
+
sizes of ``a`` along ``axes``.
|
|
330
|
+
axes (list(int), optional): Axes along which to perform the FFT.
|
|
331
|
+
The default is ``[-2, -1]``.
|
|
332
|
+
|
|
333
|
+
Returns:
|
|
334
|
+
array: The real DFT of the input along the given axes. The output
|
|
335
|
+
data type will be complex.
|
|
336
|
+
)pbdoc");
|
|
337
|
+
m.def(
|
|
338
|
+
"irfft2",
|
|
339
|
+
[](const mx::array& a,
|
|
340
|
+
const std::optional<mx::Shape>& n,
|
|
341
|
+
const std::optional<std::vector<int>>& axes,
|
|
342
|
+
mx::StreamOrDevice s) {
|
|
343
|
+
if (axes.has_value() && n.has_value()) {
|
|
344
|
+
return mx::fft::irfftn(a, n.value(), axes.value(), s);
|
|
345
|
+
} else if (axes.has_value()) {
|
|
346
|
+
return mx::fft::irfftn(a, axes.value(), s);
|
|
347
|
+
} else if (n.has_value()) {
|
|
348
|
+
throw std::invalid_argument(
|
|
349
|
+
"[irfft2] `axes` should not be `None` if `s` is not `None`.");
|
|
350
|
+
} else {
|
|
351
|
+
return mx::fft::irfftn(a, s);
|
|
352
|
+
}
|
|
353
|
+
},
|
|
354
|
+
"a"_a,
|
|
355
|
+
"s"_a = nb::none(),
|
|
356
|
+
"axes"_a.none() = std::vector<int>{-2, -1},
|
|
357
|
+
"stream"_a = nb::none(),
|
|
358
|
+
R"pbdoc(
|
|
359
|
+
The inverse of :func:`rfft2`.
|
|
360
|
+
|
|
361
|
+
Note the input is generally complex. The dimensions of the input
|
|
362
|
+
specified in ``axes`` are padded or truncated to match the sizes
|
|
363
|
+
from ``s``. The last axis in ``axes`` is treated as the real axis
|
|
364
|
+
and will have size ``s[-1] // 2 + 1``.
|
|
365
|
+
|
|
366
|
+
Args:
|
|
367
|
+
a (array): The input array.
|
|
368
|
+
s (list(int), optional): Sizes of the transformed axes. The
|
|
369
|
+
corresponding axes in the input are truncated or padded with
|
|
370
|
+
zeros to match the sizes in ``s`` except for the last axis
|
|
371
|
+
which has size ``s[-1] // 2 + 1``. The default value is the
|
|
372
|
+
sizes of ``a`` along ``axes``.
|
|
373
|
+
axes (list(int), optional): Axes along which to perform the FFT.
|
|
374
|
+
The default is ``[-2, -1]``.
|
|
375
|
+
|
|
376
|
+
Returns:
|
|
377
|
+
array: The real array containing the inverse of :func:`rfft2`.
|
|
378
|
+
)pbdoc");
|
|
379
|
+
m.def(
|
|
380
|
+
"rfftn",
|
|
381
|
+
[](const mx::array& a,
|
|
382
|
+
const std::optional<mx::Shape>& n,
|
|
383
|
+
const std::optional<std::vector<int>>& axes,
|
|
384
|
+
mx::StreamOrDevice s) {
|
|
385
|
+
if (axes.has_value() && n.has_value()) {
|
|
386
|
+
return mx::fft::rfftn(a, n.value(), axes.value(), s);
|
|
387
|
+
} else if (axes.has_value()) {
|
|
388
|
+
return mx::fft::rfftn(a, axes.value(), s);
|
|
389
|
+
} else if (n.has_value()) {
|
|
390
|
+
throw std::invalid_argument(
|
|
391
|
+
"[rfftn] `axes` should not be `None` if `s` is not `None`.");
|
|
392
|
+
} else {
|
|
393
|
+
return mx::fft::rfftn(a, s);
|
|
394
|
+
}
|
|
395
|
+
},
|
|
396
|
+
"a"_a,
|
|
397
|
+
"s"_a = nb::none(),
|
|
398
|
+
"axes"_a = nb::none(),
|
|
399
|
+
"stream"_a = nb::none(),
|
|
400
|
+
R"pbdoc(
|
|
401
|
+
n-dimensional real discrete Fourier Transform.
|
|
402
|
+
|
|
403
|
+
The output has the same shape as the input except along the dimensions in
|
|
404
|
+
``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is
|
|
405
|
+
treated as the real axis and will have size ``s[-1] // 2 + 1``.
|
|
406
|
+
|
|
407
|
+
Args:
|
|
408
|
+
a (array): The input array. If the array is complex it will be silently
|
|
409
|
+
cast to a real type.
|
|
410
|
+
s (list(int), optional): Sizes of the transformed axes. The
|
|
411
|
+
corresponding axes in the input are truncated or padded with
|
|
412
|
+
zeros to match the sizes in ``s``. The default value is the
|
|
413
|
+
sizes of ``a`` along ``axes``.
|
|
414
|
+
axes (list(int), optional): Axes along which to perform the FFT.
|
|
415
|
+
The default is ``None`` in which case the FFT is over the last
|
|
416
|
+
``len(s)`` axes or all axes if ``s`` is also ``None``.
|
|
417
|
+
|
|
418
|
+
Returns:
|
|
419
|
+
array: The real DFT of the input along the given axes. The output
|
|
420
|
+
)pbdoc");
|
|
421
|
+
m.def(
|
|
422
|
+
"irfftn",
|
|
423
|
+
[](const mx::array& a,
|
|
424
|
+
const std::optional<mx::Shape>& n,
|
|
425
|
+
const std::optional<std::vector<int>>& axes,
|
|
426
|
+
mx::StreamOrDevice s) {
|
|
427
|
+
if (axes.has_value() && n.has_value()) {
|
|
428
|
+
return mx::fft::irfftn(a, n.value(), axes.value(), s);
|
|
429
|
+
} else if (axes.has_value()) {
|
|
430
|
+
return mx::fft::irfftn(a, axes.value(), s);
|
|
431
|
+
} else if (n.has_value()) {
|
|
432
|
+
throw std::invalid_argument(
|
|
433
|
+
"[irfftn] `axes` should not be `None` if `s` is not `None`.");
|
|
434
|
+
} else {
|
|
435
|
+
return mx::fft::irfftn(a, s);
|
|
436
|
+
}
|
|
437
|
+
},
|
|
438
|
+
"a"_a,
|
|
439
|
+
"s"_a = nb::none(),
|
|
440
|
+
"axes"_a = nb::none(),
|
|
441
|
+
"stream"_a = nb::none(),
|
|
442
|
+
R"pbdoc(
|
|
443
|
+
The inverse of :func:`rfftn`.
|
|
444
|
+
|
|
445
|
+
Note the input is generally complex. The dimensions of the input
|
|
446
|
+
specified in ``axes`` are padded or truncated to match the sizes
|
|
447
|
+
from ``s``. The last axis in ``axes`` is treated as the real axis
|
|
448
|
+
and will have size ``s[-1] // 2 + 1``.
|
|
449
|
+
|
|
450
|
+
Args:
|
|
451
|
+
a (array): The input array.
|
|
452
|
+
s (list(int), optional): Sizes of the transformed axes. The
|
|
453
|
+
corresponding axes in the input are truncated or padded with
|
|
454
|
+
zeros to match the sizes in ``s``. The default value is the
|
|
455
|
+
sizes of ``a`` along ``axes``.
|
|
456
|
+
axes (list(int), optional): Axes along which to perform the FFT.
|
|
457
|
+
The default is ``None`` in which case the FFT is over the last
|
|
458
|
+
``len(s)`` axes or all axes if ``s`` is also ``None``.
|
|
459
|
+
|
|
460
|
+
Returns:
|
|
461
|
+
array: The real array containing the inverse of :func:`rfftn`.
|
|
462
|
+
)pbdoc");
|
|
463
|
+
m.def(
|
|
464
|
+
"fftshift",
|
|
465
|
+
[](const mx::array& a,
|
|
466
|
+
const std::optional<std::vector<int>>& axes,
|
|
467
|
+
mx::StreamOrDevice s) {
|
|
468
|
+
if (axes.has_value()) {
|
|
469
|
+
return mx::fft::fftshift(a, axes.value(), s);
|
|
470
|
+
} else {
|
|
471
|
+
return mx::fft::fftshift(a, s);
|
|
472
|
+
}
|
|
473
|
+
},
|
|
474
|
+
"a"_a,
|
|
475
|
+
"axes"_a = nb::none(),
|
|
476
|
+
"stream"_a = nb::none(),
|
|
477
|
+
R"pbdoc(
|
|
478
|
+
Shift the zero-frequency component to the center of the spectrum.
|
|
479
|
+
|
|
480
|
+
Args:
|
|
481
|
+
a (array): The input array.
|
|
482
|
+
axes (list(int), optional): Axes over which to perform the shift.
|
|
483
|
+
If ``None``, shift all axes.
|
|
484
|
+
|
|
485
|
+
Returns:
|
|
486
|
+
array: The shifted array with the same shape as the input.
|
|
487
|
+
)pbdoc");
|
|
488
|
+
m.def(
|
|
489
|
+
"ifftshift",
|
|
490
|
+
[](const mx::array& a,
|
|
491
|
+
const std::optional<std::vector<int>>& axes,
|
|
492
|
+
mx::StreamOrDevice s) {
|
|
493
|
+
if (axes.has_value()) {
|
|
494
|
+
return mx::fft::ifftshift(a, axes.value(), s);
|
|
495
|
+
} else {
|
|
496
|
+
return mx::fft::ifftshift(a, s);
|
|
497
|
+
}
|
|
498
|
+
},
|
|
499
|
+
"a"_a,
|
|
500
|
+
"axes"_a = nb::none(),
|
|
501
|
+
"stream"_a = nb::none(),
|
|
502
|
+
R"pbdoc(
|
|
503
|
+
The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes,
|
|
504
|
+
the behavior differs for odd-length axes.
|
|
505
|
+
|
|
506
|
+
Args:
|
|
507
|
+
a (array): The input array.
|
|
508
|
+
axes (list(int), optional): Axes over which to perform the inverse shift.
|
|
509
|
+
If ``None``, shift all axes.
|
|
510
|
+
|
|
511
|
+
Returns:
|
|
512
|
+
array: The inverse-shifted array with the same shape as the input.
|
|
513
|
+
)pbdoc");
|
|
514
|
+
}
|