mlx 1.0.0
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/CMakeLists.txt +7 -0
- data/ext/mlx/Makefile +273 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/mkmf.log +44 -0
- data/ext/mlx/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
- data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
- data/ext/mlx/native.cpp +8027 -0
- data/ext/mlx/native.o +0 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version +1 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/.clang-format +87 -0
- data/mlx/.git +1 -0
- data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
- data/mlx/.github/actions/build-docs/action.yml +38 -0
- data/mlx/.github/actions/build-linux/action.yml +38 -0
- data/mlx/.github/actions/build-linux-release/action.yml +42 -0
- data/mlx/.github/actions/build-macos/action.yml +80 -0
- data/mlx/.github/actions/build-macos-release/action.yml +36 -0
- data/mlx/.github/actions/build-windows/action.yml +26 -0
- data/mlx/.github/actions/setup-linux/action.yml +93 -0
- data/mlx/.github/actions/setup-macos/action.yml +24 -0
- data/mlx/.github/actions/setup-windows/action.yml +42 -0
- data/mlx/.github/actions/test-linux/action.yml +69 -0
- data/mlx/.github/actions/test-windows/action.yml +20 -0
- data/mlx/.github/dependabot.yml +6 -0
- data/mlx/.github/pull_request_template.md +12 -0
- data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
- data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
- data/mlx/.github/workflows/build_and_test.yml +152 -0
- data/mlx/.github/workflows/documentation.yml +28 -0
- data/mlx/.github/workflows/nightly.yml +104 -0
- data/mlx/.github/workflows/release.yml +256 -0
- data/mlx/.gitignore +81 -0
- data/mlx/.pre-commit-config.yaml +27 -0
- data/mlx/ACKNOWLEDGMENTS.md +268 -0
- data/mlx/CITATION.cff +24 -0
- data/mlx/CMakeLists.txt +437 -0
- data/mlx/CODE_OF_CONDUCT.md +132 -0
- data/mlx/CONTRIBUTING.md +38 -0
- data/mlx/LICENSE +21 -0
- data/mlx/MANIFEST.in +6 -0
- data/mlx/README.md +121 -0
- data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- data/mlx/benchmarks/cpp/autograd.cpp +39 -0
- data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- data/mlx/benchmarks/cpp/time_utils.h +39 -0
- data/mlx/benchmarks/numpy/single_ops.py +39 -0
- data/mlx/benchmarks/numpy/time_utils.py +20 -0
- data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
- data/mlx/benchmarks/python/comparative/README.md +15 -0
- data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- data/mlx/benchmarks/python/comparative/compare.py +284 -0
- data/mlx/benchmarks/python/compile_bench.py +107 -0
- data/mlx/benchmarks/python/conv1d_bench.py +123 -0
- data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- data/mlx/benchmarks/python/conv_bench.py +135 -0
- data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- data/mlx/benchmarks/python/distributed_bench.py +66 -0
- data/mlx/benchmarks/python/einsum_bench.py +84 -0
- data/mlx/benchmarks/python/fft_bench.py +118 -0
- data/mlx/benchmarks/python/gather_bench.py +52 -0
- data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- data/mlx/benchmarks/python/hadamard_bench.py +70 -0
- data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
- data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
- data/mlx/benchmarks/python/masked_scatter.py +212 -0
- data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
- data/mlx/benchmarks/python/rope_bench.py +35 -0
- data/mlx/benchmarks/python/scatter_bench.py +96 -0
- data/mlx/benchmarks/python/sdpa_bench.py +223 -0
- data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
- data/mlx/benchmarks/python/single_ops.py +132 -0
- data/mlx/benchmarks/python/synchronize_bench.py +55 -0
- data/mlx/benchmarks/python/time_utils.py +38 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/docs/.clang-format +2 -0
- data/mlx/docs/.gitignore +3 -0
- data/mlx/docs/.nojekyll +0 -0
- data/mlx/docs/Doxyfile +51 -0
- data/mlx/docs/Makefile +18 -0
- data/mlx/docs/README.md +54 -0
- data/mlx/docs/index.html +1 -0
- data/mlx/docs/requirements.txt +5 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
- data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
- data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
- data/mlx/docs/src/_static/mlx_logo.png +0 -0
- data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
- data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
- data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
- data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
- data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
- data/mlx/docs/src/_templates/module-base-class.rst +33 -0
- data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
- data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
- data/mlx/docs/src/conf.py +99 -0
- data/mlx/docs/src/cpp/ops.rst +7 -0
- data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
- data/mlx/docs/src/dev/extensions.rst +811 -0
- data/mlx/docs/src/dev/metal_debugger.rst +68 -0
- data/mlx/docs/src/dev/metal_logging.rst +40 -0
- data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
- data/mlx/docs/src/examples/data_parallelism.rst +91 -0
- data/mlx/docs/src/examples/linear_regression.rst +77 -0
- data/mlx/docs/src/examples/llama-inference.rst +382 -0
- data/mlx/docs/src/examples/mlp.rst +134 -0
- data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
- data/mlx/docs/src/index.rst +96 -0
- data/mlx/docs/src/install.rst +340 -0
- data/mlx/docs/src/python/array.rst +65 -0
- data/mlx/docs/src/python/cuda.rst +9 -0
- data/mlx/docs/src/python/data_types.rst +78 -0
- data/mlx/docs/src/python/devices_and_streams.rst +21 -0
- data/mlx/docs/src/python/distributed.rst +22 -0
- data/mlx/docs/src/python/export.rst +14 -0
- data/mlx/docs/src/python/fast.rst +16 -0
- data/mlx/docs/src/python/fft.rst +24 -0
- data/mlx/docs/src/python/linalg.rst +27 -0
- data/mlx/docs/src/python/memory_management.rst +16 -0
- data/mlx/docs/src/python/metal.rst +12 -0
- data/mlx/docs/src/python/nn/distributed.rst +30 -0
- data/mlx/docs/src/python/nn/functions.rst +40 -0
- data/mlx/docs/src/python/nn/init.rst +45 -0
- data/mlx/docs/src/python/nn/layers.rst +74 -0
- data/mlx/docs/src/python/nn/losses.rst +25 -0
- data/mlx/docs/src/python/nn/module.rst +38 -0
- data/mlx/docs/src/python/nn.rst +186 -0
- data/mlx/docs/src/python/ops.rst +184 -0
- data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
- data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
- data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
- data/mlx/docs/src/python/optimizers.rst +78 -0
- data/mlx/docs/src/python/random.rst +48 -0
- data/mlx/docs/src/python/transforms.rst +22 -0
- data/mlx/docs/src/python/tree_utils.rst +23 -0
- data/mlx/docs/src/usage/compile.rst +516 -0
- data/mlx/docs/src/usage/distributed.rst +572 -0
- data/mlx/docs/src/usage/export.rst +288 -0
- data/mlx/docs/src/usage/function_transforms.rst +191 -0
- data/mlx/docs/src/usage/indexing.rst +194 -0
- data/mlx/docs/src/usage/launching_distributed.rst +234 -0
- data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
- data/mlx/docs/src/usage/numpy.rst +124 -0
- data/mlx/docs/src/usage/quick_start.rst +67 -0
- data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
- data/mlx/docs/src/usage/unified_memory.rst +78 -0
- data/mlx/docs/src/usage/using_streams.rst +18 -0
- data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
- data/mlx/examples/cmake_project/README.md +26 -0
- data/mlx/examples/cmake_project/example.cpp +14 -0
- data/mlx/examples/cpp/CMakeLists.txt +12 -0
- data/mlx/examples/cpp/distributed.cpp +22 -0
- data/mlx/examples/cpp/linear_regression.cpp +54 -0
- data/mlx/examples/cpp/logistic_regression.cpp +54 -0
- data/mlx/examples/cpp/metal_capture.cpp +31 -0
- data/mlx/examples/cpp/timer.h +20 -0
- data/mlx/examples/cpp/tutorial.cpp +99 -0
- data/mlx/examples/export/CMakeLists.txt +22 -0
- data/mlx/examples/export/README.md +49 -0
- data/mlx/examples/export/eval_mlp.cpp +25 -0
- data/mlx/examples/export/eval_mlp.py +52 -0
- data/mlx/examples/export/train_mlp.cpp +35 -0
- data/mlx/examples/export/train_mlp.py +76 -0
- data/mlx/examples/extensions/CMakeLists.txt +78 -0
- data/mlx/examples/extensions/README.md +24 -0
- data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
- data/mlx/examples/extensions/axpby/axpby.h +90 -0
- data/mlx/examples/extensions/axpby/axpby.metal +47 -0
- data/mlx/examples/extensions/bindings.cpp +39 -0
- data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
- data/mlx/examples/extensions/pyproject.toml +8 -0
- data/mlx/examples/extensions/requirements.txt +4 -0
- data/mlx/examples/extensions/setup.py +18 -0
- data/mlx/examples/extensions/test.py +12 -0
- data/mlx/examples/python/linear_regression.py +46 -0
- data/mlx/examples/python/logistic_regression.py +49 -0
- data/mlx/examples/python/qqmm.py +117 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- data/mlx/pyproject.toml +7 -0
- data/mlx/python/mlx/__main__.py +27 -0
- data/mlx/python/mlx/_distributed_utils/common.py +135 -0
- data/mlx/python/mlx/_distributed_utils/config.py +631 -0
- data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
- data/mlx/python/mlx/_reprlib_fix.py +16 -0
- data/mlx/python/mlx/_stub_patterns.txt +36 -0
- data/mlx/python/mlx/extension.py +88 -0
- data/mlx/python/mlx/nn/__init__.py +5 -0
- data/mlx/python/mlx/nn/init.py +441 -0
- data/mlx/python/mlx/nn/layers/__init__.py +105 -0
- data/mlx/python/mlx/nn/layers/activations.py +661 -0
- data/mlx/python/mlx/nn/layers/base.py +675 -0
- data/mlx/python/mlx/nn/layers/containers.py +24 -0
- data/mlx/python/mlx/nn/layers/convolution.py +232 -0
- data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
- data/mlx/python/mlx/nn/layers/distributed.py +601 -0
- data/mlx/python/mlx/nn/layers/dropout.py +137 -0
- data/mlx/python/mlx/nn/layers/embedding.py +53 -0
- data/mlx/python/mlx/nn/layers/linear.py +180 -0
- data/mlx/python/mlx/nn/layers/normalization.py +363 -0
- data/mlx/python/mlx/nn/layers/pooling.py +398 -0
- data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
- data/mlx/python/mlx/nn/layers/quantized.py +426 -0
- data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
- data/mlx/python/mlx/nn/layers/transformer.py +354 -0
- data/mlx/python/mlx/nn/layers/upsample.py +277 -0
- data/mlx/python/mlx/nn/losses.py +610 -0
- data/mlx/python/mlx/nn/utils.py +165 -0
- data/mlx/python/mlx/optimizers/__init__.py +4 -0
- data/mlx/python/mlx/optimizers/optimizers.py +976 -0
- data/mlx/python/mlx/optimizers/schedulers.py +158 -0
- data/mlx/python/mlx/py.typed +1 -0
- data/mlx/python/mlx/utils.py +325 -0
- data/mlx/python/src/CMakeLists.txt +96 -0
- data/mlx/python/src/array.cpp +1525 -0
- data/mlx/python/src/buffer.h +124 -0
- data/mlx/python/src/constants.cpp +15 -0
- data/mlx/python/src/convert.cpp +504 -0
- data/mlx/python/src/convert.h +50 -0
- data/mlx/python/src/cuda.cpp +19 -0
- data/mlx/python/src/device.cpp +98 -0
- data/mlx/python/src/distributed.cpp +352 -0
- data/mlx/python/src/export.cpp +356 -0
- data/mlx/python/src/fast.cpp +627 -0
- data/mlx/python/src/fft.cpp +514 -0
- data/mlx/python/src/indexing.cpp +1016 -0
- data/mlx/python/src/indexing.h +41 -0
- data/mlx/python/src/linalg.cpp +663 -0
- data/mlx/python/src/load.cpp +531 -0
- data/mlx/python/src/load.h +51 -0
- data/mlx/python/src/memory.cpp +125 -0
- data/mlx/python/src/metal.cpp +98 -0
- data/mlx/python/src/mlx.cpp +51 -0
- data/mlx/python/src/mlx_func.cpp +116 -0
- data/mlx/python/src/mlx_func.h +31 -0
- data/mlx/python/src/ops.cpp +5545 -0
- data/mlx/python/src/random.cpp +516 -0
- data/mlx/python/src/small_vector.h +76 -0
- data/mlx/python/src/stream.cpp +147 -0
- data/mlx/python/src/transforms.cpp +1542 -0
- data/mlx/python/src/trees.cpp +311 -0
- data/mlx/python/src/trees.h +62 -0
- data/mlx/python/src/utils.cpp +98 -0
- data/mlx/python/src/utils.h +78 -0
- data/mlx/python/tests/__main__.py +5 -0
- data/mlx/python/tests/cuda_skip.py +62 -0
- data/mlx/python/tests/mlx_distributed_tests.py +314 -0
- data/mlx/python/tests/mlx_tests.py +116 -0
- data/mlx/python/tests/mpi_test_distributed.py +142 -0
- data/mlx/python/tests/nccl_test_distributed.py +52 -0
- data/mlx/python/tests/ring_test_distributed.py +131 -0
- data/mlx/python/tests/test_array.py +2139 -0
- data/mlx/python/tests/test_autograd.py +880 -0
- data/mlx/python/tests/test_bf16.py +196 -0
- data/mlx/python/tests/test_blas.py +1429 -0
- data/mlx/python/tests/test_compile.py +1277 -0
- data/mlx/python/tests/test_constants.py +41 -0
- data/mlx/python/tests/test_conv.py +1198 -0
- data/mlx/python/tests/test_conv_transpose.py +810 -0
- data/mlx/python/tests/test_device.py +150 -0
- data/mlx/python/tests/test_double.py +306 -0
- data/mlx/python/tests/test_einsum.py +363 -0
- data/mlx/python/tests/test_eval.py +200 -0
- data/mlx/python/tests/test_export_import.py +614 -0
- data/mlx/python/tests/test_fast.py +923 -0
- data/mlx/python/tests/test_fast_sdpa.py +647 -0
- data/mlx/python/tests/test_fft.py +323 -0
- data/mlx/python/tests/test_graph.py +37 -0
- data/mlx/python/tests/test_init.py +139 -0
- data/mlx/python/tests/test_linalg.py +621 -0
- data/mlx/python/tests/test_load.py +447 -0
- data/mlx/python/tests/test_losses.py +427 -0
- data/mlx/python/tests/test_memory.py +77 -0
- data/mlx/python/tests/test_nn.py +1986 -0
- data/mlx/python/tests/test_ops.py +3261 -0
- data/mlx/python/tests/test_optimizers.py +584 -0
- data/mlx/python/tests/test_quantized.py +1160 -0
- data/mlx/python/tests/test_random.py +392 -0
- data/mlx/python/tests/test_reduce.py +223 -0
- data/mlx/python/tests/test_tree.py +96 -0
- data/mlx/python/tests/test_upsample.py +100 -0
- data/mlx/python/tests/test_vmap.py +860 -0
- data/mlx/setup.py +315 -0
- data/mlx/tests/CMakeLists.txt +44 -0
- data/mlx/tests/allocator_tests.cpp +41 -0
- data/mlx/tests/arg_reduce_tests.cpp +204 -0
- data/mlx/tests/array_tests.cpp +663 -0
- data/mlx/tests/autograd_tests.cpp +1399 -0
- data/mlx/tests/blas_tests.cpp +110 -0
- data/mlx/tests/compile_tests.cpp +818 -0
- data/mlx/tests/creations_tests.cpp +239 -0
- data/mlx/tests/custom_vjp_tests.cpp +55 -0
- data/mlx/tests/device_tests.cpp +35 -0
- data/mlx/tests/einsum_tests.cpp +85 -0
- data/mlx/tests/eval_tests.cpp +93 -0
- data/mlx/tests/export_import_tests.cpp +164 -0
- data/mlx/tests/fft_tests.cpp +366 -0
- data/mlx/tests/gpu_tests.cpp +523 -0
- data/mlx/tests/linalg_tests.cpp +639 -0
- data/mlx/tests/load_tests.cpp +270 -0
- data/mlx/tests/ops_tests.cpp +4159 -0
- data/mlx/tests/random_tests.cpp +716 -0
- data/mlx/tests/scheduler_tests.cpp +121 -0
- data/mlx/tests/tests.cpp +26 -0
- data/mlx/tests/utils_tests.cpp +67 -0
- data/mlx/tests/vmap_tests.cpp +547 -0
- metadata +958 -0
data/mlx/mlx/backend/metal/matmul.cpp
@@ -0,0 +1,2572 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include <algorithm>
+#include <cassert>
+#include <numeric>
+#include <sstream>
+
+#include "mlx/backend/common/broadcasting.h"
+#include "mlx/backend/common/matmul.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/backend/metal/binary.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/kernels/defines.h"
+#include "mlx/backend/metal/kernels/steel/gemm/params.h"
+#include "mlx/backend/metal/matmul.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/primitives.h"
+#include "mlx/utils.h"
+
+namespace mlx::core {
+
+namespace {
+
+std::tuple<bool, int64_t, array> check_transpose(
+    std::vector<array>& copies,
+    const Stream& s,
+    const array& arr,
+    bool is_vector) {
+  auto stx = arr.strides()[arr.ndim() - 2];
+  auto sty = arr.strides()[arr.ndim() - 1];
+  if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
+    return std::make_tuple(false, stx, arr);
+  } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
+    return std::make_tuple(true, sty, arr);
+  } else {
+    array arr_copy = contiguous_copy_gpu(arr, s);
+    copies.push_back(arr_copy);
+    return std::make_tuple(false, arr.shape(-1), arr_copy);
+  }
+};
+
+inline array
+ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
+  if (!x.flags().row_contiguous) {
+    array x_copy = contiguous_copy_gpu(x, s);
+    d.add_temporary(x_copy, s.index);
+    return x_copy;
+  } else {
+    return x;
+  }
+}
+
+inline std::tuple<bool, int64_t, array>
+ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
+  if (x.flags().row_contiguous) {
+    return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
+  }
+
+  bool rc = true;
+  for (int i = 0; i < x.ndim() - 3; i++) {
+    rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
+  }
+  if (rc) {
+    auto stx = x.strides()[x.ndim() - 2];
+    auto sty = x.strides()[x.ndim() - 1];
+    auto K = x.shape(-2);
+    auto N = x.shape(-1);
+    if (sty == 1 && (N != 1 || stx == N)) {
+      return std::make_tuple(false, stx, x);
+    }
+    if (stx == 1 && (N != 1 || sty == K)) {
+      return std::make_tuple(true, sty, x);
+    }
+  }
+
+  array x_copy = contiguous_copy_gpu(x, s);
+  d.add_temporary(x_copy, s.index);
+  return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
+}
+
+} // namespace
+
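The helpers above reduce every input layout to one of three cases: usable as-is, usable as a transposed view, or needing a contiguous copy. Below is a minimal standalone sketch of that stride test for a 2-D view with shape {M, K} and strides {stx, sty}; the classify function and main driver are illustrative and not part of this file, and the is_vector special case is omitted (the real helper also returns the copy itself, using arr.shape(-1) as its leading dimension).

#include <cstdint>
#include <cstdio>
#include <tuple>

// Returns {needs_transpose, leading_dimension, needs_copy}, mirroring the
// three branches of check_transpose above.
static std::tuple<bool, int64_t, bool> classify(int64_t stx, int64_t sty) {
  if (sty == 1) {
    return {false, stx, false}; // rows contiguous: use as-is, ld = stx
  }
  if (stx == 1) {
    return {true, sty, false}; // columns contiguous: a transposed view, ld = sty
  }
  return {false, 0, true}; // neither axis unit-stride: materialize a copy
}

int main() {
  // A row-contiguous 4x8 matrix has strides {8, 1}: no transpose, ld = 8.
  auto [t, ld, copy] = classify(8, 1);
  std::printf("transpose=%d ld=%lld copy=%d\n", int(t), (long long)ld, int(copy));
  return 0;
}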
+///////////////////////////////////////////////////////////////////////////////
+// Steel matmul fallback
+///////////////////////////////////////////////////////////////////////////////
+
+#define GEMM_TPARAM_MACRO(devc)                                           \
+  if (devc == 'g' || devc == 'p') { /* Small device */                    \
+    if (out.dtype() == complex64) {                                       \
+      bm = 64;                                                            \
+      bn = 32;                                                            \
+      bk = 8;                                                             \
+      wm = 4;                                                             \
+      wn = 1;                                                             \
+    } else if (!transpose_a && transpose_b) { /* nt */                    \
+      bm = 64;                                                            \
+      bn = 32;                                                            \
+      bk = 32;                                                            \
+      wm = 2;                                                             \
+      wn = 2;                                                             \
+    } else if (out.dtype() != float32) { /* half and bfloat */            \
+      bm = 64;                                                            \
+      bn = 64;                                                            \
+      bk = 16;                                                            \
+      wm = 1;                                                             \
+      wn = 2;                                                             \
+    }                                                                     \
+  } else if (devc == 'd') { /* Large device */                            \
+    if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \
+      if (out.dtype() != float32) { /* half and bfloat */                 \
+        if (2 * std::max(M, N) > K) { /* Reasonable K */                  \
+          bm = 64;                                                        \
+          bn = 64;                                                        \
+          bk = 16;                                                        \
+          wm = 1;                                                         \
+          wn = 2;                                                         \
+        } else if (!transpose_a && transpose_b) { /* nt with large k */   \
+          bm = 64;                                                        \
+          bn = 32;                                                        \
+          bk = 32;                                                        \
+          wm = 2;                                                         \
+          wn = 2;                                                         \
+        } else { /* nn with large K */                                    \
+          bm = 32;                                                        \
+          bn = 64;                                                        \
+          bk = 16;                                                        \
+          wm = 1;                                                         \
+          wn = 2;                                                         \
+        }                                                                 \
+      } /* float takes default */                                         \
+    } else { /* smaller matmul */                                         \
+      if (out.dtype() != float32) { /* half and bfloat */                 \
+        if (!transpose_a && transpose_b) { /* nt */                       \
+          bm = 64;                                                        \
+          bn = 32;                                                        \
+          bk = 32;                                                        \
+          wm = 2;                                                         \
+          wn = 2;                                                         \
+        } else { /* nn */                                                 \
+          bm = 64;                                                        \
+          bn = 64;                                                        \
+          bk = 16;                                                        \
+          wm = 1;                                                         \
+          wn = 2;                                                         \
+        }                                                                 \
+      } else { /* floats */                                               \
+        if (!transpose_a && transpose_b) { /* nt */                       \
+          bm = 32;                                                        \
+          bn = 64;                                                        \
+          bk = 16;                                                        \
+          wm = 1;                                                         \
+          wn = 2;                                                         \
+        } else { /* nn */                                                 \
+          bm = 64;                                                        \
+          bn = 32;                                                        \
+          bk = 32;                                                        \
+          wm = 2;                                                         \
+          wn = 2;                                                         \
+        }                                                                 \
+      }                                                                   \
+    }                                                                     \
+  } else { /* Medium device */                                            \
+    bm = 64;                                                              \
+    bn = 64;                                                              \
+    bk = 16;                                                              \
+    wm = 2;                                                               \
+    wn = 2;                                                               \
+  }
+
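The macro above is a straight lookup from device class, dtype, and transpose flags to tile parameters; any branch that sets nothing leaves the caller's defaults in place. A sketch of just the small-device ('g'/'p') branch follows; the Tiles struct and small_device_tiles function are illustrative names, not part of this file.

// Mirrors the 'g'/'p' branch of GEMM_TPARAM_MACRO: checks run in the same
// order, so an "nt" half-precision matmul takes the "nt" tiles, not the
// half/bfloat ones.
struct Tiles {
  int bm, bn, bk, wm, wn;
};

static Tiles small_device_tiles(
    bool is_complex64,
    bool is_float32,
    bool transpose_a,
    bool transpose_b,
    Tiles caller_defaults) {
  if (is_complex64) {
    return {64, 32, 8, 4, 1};
  }
  if (!transpose_a && transpose_b) { // "nt"
    return {64, 32, 32, 2, 2};
  }
  if (!is_float32) { // half and bfloat
    return {64, 64, 16, 1, 2};
  }
  return caller_defaults; // float32 non-"nt" keeps the caller's defaults
}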
///////////////////////////////////////////////////////////////////////////////
|
|
172
|
+
// Regular steel matmul dispatch
|
|
173
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
174
|
+
|
|
175
|
+
template <bool CHECK_AB>
|
|
176
|
+
void steel_matmul_regular_axpby_nax(
|
|
177
|
+
const Stream& s,
|
|
178
|
+
metal::Device& d,
|
|
179
|
+
const array& a,
|
|
180
|
+
const array& b,
|
|
181
|
+
const array& c,
|
|
182
|
+
array& out,
|
|
183
|
+
int M,
|
|
184
|
+
int N,
|
|
185
|
+
int K,
|
|
186
|
+
int batch_size_out,
|
|
187
|
+
int lda,
|
|
188
|
+
int ldb,
|
|
189
|
+
int ldd,
|
|
190
|
+
bool transpose_a,
|
|
191
|
+
bool transpose_b,
|
|
192
|
+
std::vector<array>& copies,
|
|
193
|
+
Shape batch_shape,
|
|
194
|
+
Strides batch_strides,
|
|
195
|
+
int64_t A_batch_stride,
|
|
196
|
+
int64_t B_batch_stride,
|
|
197
|
+
int64_t matrix_stride_out,
|
|
198
|
+
int64_t C_batch_stride /* = 0*/,
|
|
199
|
+
float alpha /* = 1.0f */,
|
|
200
|
+
float beta /* = 0.0f */) {
|
|
201
|
+
using namespace mlx::steel;
|
|
202
|
+
|
|
203
|
+
// Determine dispatch kernel
|
|
204
|
+
int bm = 128, bn = 128, bk = 512;
|
|
205
|
+
int wm = 4, wn = 4;
|
|
206
|
+
|
|
207
|
+
// Prepare kernel name
|
|
208
|
+
  std::ostringstream kname;

  // clang-format off
  kname << "steel_gemm_fused_nax_"
        << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n')
        << "_" << type_to_name(a)
        << "_" << type_to_name(out)
        << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn; // clang-format on

  std::string base_name = kname.str();

  const bool has_batch = (batch_shape.size() > 1);
  const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);
  const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);
  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  metal::MTLFCList func_consts = {
      {&has_batch, MTL::DataType::DataTypeBool, 10},
      {&use_out_source, MTL::DataType::DataTypeBool, 100},
      {&do_axpby, MTL::DataType::DataTypeBool, 110},
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };

  // clang-format off
  kname << "_has_batch_" << (has_batch ? 't' : 'n')
        << "_use_out_source_" << (use_out_source ? 't' : 'n')
        << "_do_axpby_" << (do_axpby ? 't' : 'n')
        << "_align_M_" << (align_M ? 't' : 'n')
        << "_align_N_" << (align_N ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on

  std::string hash_name = kname.str();

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_fused_nax_kernel(
      /* metal::Device& d = */ d,
      /* const std::string& kernel_name = */ base_name,
      /* const std::string& hash_name = */ hash_name,
      /* const metal::MTLFCList& func_consts = */ func_consts,
      /* const array& out = */ out,
      /* bool transpose_a = */ transpose_a,
      /* bool transpose_b = */ transpose_b,
      /* int bm = */ bm,
      /* int bn = */ bn,
      /* int bk = */ bk,
      /* int wm = */ wm,
      /* int wn = */ wn);

  compute_encoder.set_compute_pipeline_state(kernel);

  // Use problem size to determine threadblock swizzle
  int tn = (N + bn - 1) / bn;
  int tm = (M + bm - 1) / bm;

  // TODO: Explore device-based tuning for swizzle
  int swizzle_log = tm <= 3 ? 0 : 1;

  // Prepare steel matmul params
  GEMMParams params{/* const int M = */ M,
                    /* const int N = */ N,
                    /* const int K = */ K,
                    /* const int lda = */ lda,
                    /* const int ldb = */ ldb,
                    /* const int ldd = */ ldd,
                    /* const int tiles_n = */ tn,
                    /* const int tiles_m = */ tm,
                    /* const int64_t batch_stride_a = */ A_batch_stride,
                    /* const int64_t batch_stride_b = */ B_batch_stride,
                    /* const int64_t batch_stride_d = */ matrix_stride_out,
                    /* const int swizzle_log = */ swizzle_log,
                    /* const int gemm_k_iterations_aligned = */ (K / bk),
                    /* const int batch_ndim = */ int(batch_shape.size())};

  // Prepare launch grid params
  int tile = 1 << swizzle_log;
  tm = (tm + tile - 1) / tile;
  tn = tn * tile;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_output_array(out, 3);

  compute_encoder.set_bytes(params, 4);

  if (has_batch) {
    compute_encoder.set_vector_bytes(batch_shape, 6);
    compute_encoder.set_vector_bytes(batch_strides, 7);
  }

  if (use_out_source) {
    int ldc = c.strides()[c.ndim() - 2];
    int fdc = c.strides()[c.ndim() - 1];

    GEMMAddMMParams params{/* const int ldc = */ ldc,
                           /* const int fdc = */ fdc,
                           /* const int64_t batch_stride_c = */ C_batch_stride,
                           /* const float alpha = */ alpha,
                           /* const float beta = */ beta};

    compute_encoder.set_input_array(c, 2);
    compute_encoder.set_bytes(params, 5);
  }

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  // Record copies
  d.add_temporaries(std::move(copies), s.index);
}
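
// ---------------------------------------------------------------------------
// Illustrative sketch (an editorial addition, not from the package source):
// the host code above folds `swizzle_log` into the launch grid by shrinking
// the m-tile count and stretching the n-tile count. A threadblock swizzle of
// this kind is typically undone on the device roughly as below, so that
// consecutively issued threadgroups land on nearby output tiles and reuse
// cached A/B blocks. The struct, function name, and exact mapping here are
// assumptions for illustration; the real mapping lives in the steel kernels,
// which this diff does not show.
// ---------------------------------------------------------------------------
struct TileIdSketch {
  int tile_x; // n-tile this threadgroup would compute
  int tile_y; // m-tile this threadgroup would compute
};

static TileIdSketch unswizzle_tile_sketch(
    int tgid_x, int tgid_y, int swizzle_log) {
  int tile = 1 << swizzle_log;
  // Inverse of the host transform: grid.x = tn * tile, grid.y = ceil(tm/tile)
  return {tgid_x >> swizzle_log,
          (tgid_y << swizzle_log) + (tgid_x & (tile - 1))};
}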

template <bool CHECK_AB>
void steel_matmul_regular_axpby(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    const array& c,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    int ldd,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    Shape batch_shape,
    Strides batch_strides,
    int64_t A_batch_stride,
    int64_t B_batch_stride,
    int64_t matrix_stride_out,
    int64_t C_batch_stride /* = 0 */,
    float alpha /* = 1.0f */,
    float beta /* = 0.0f */) {
  if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
      (env::enable_tf32() || a.dtype() != float32)) {
    return steel_matmul_regular_axpby_nax<CHECK_AB>(
        /* const Stream& s = */ s,
        /* metal::Device& d = */ d,
        /* const array& a = */ a,
        /* const array& b = */ b,
        /* const array& c = */ c,
        /* array& out = */ out,
        /* int M = */ M,
        /* int N = */ N,
        /* int K = */ K,
        /* int batch_size_out = */ batch_size_out,
        /* int lda = */ lda,
        /* int ldb = */ ldb,
        /* int ldd = */ ldd,
        /* bool transpose_a = */ transpose_a,
        /* bool transpose_b = */ transpose_b,
        /* std::vector<array>& copies = */ copies,
        /* Shape batch_shape = */ batch_shape,
        /* Strides batch_strides = */ batch_strides,
        /* int64_t A_batch_stride = */ A_batch_stride,
        /* int64_t B_batch_stride = */ B_batch_stride,
        /* int64_t matrix_stride_out = */ matrix_stride_out,
        /* int64_t C_batch_stride = */ C_batch_stride,
        /* float alpha = */ alpha,
        /* float beta = */ beta);
  }

  using namespace mlx::steel;

  // Determine dispatch kernel
  int bm = 64, bn = 64, bk = 16;
  int wm = 2, wn = 2;

  char devc = d.get_architecture().back();
  GEMM_TPARAM_MACRO(devc)

  // Prepare kernel name
  std::ostringstream kname;

  // clang-format off
  kname << "steel_gemm_fused_"
        << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n')
        << "_" << type_to_name(a)
        << "_" << type_to_name(out)
        << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn; // clang-format on

  std::string base_name = kname.str();

  const bool has_batch = (batch_shape.size() > 1);
  const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);
  const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);
  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  metal::MTLFCList func_consts = {
      {&has_batch, MTL::DataType::DataTypeBool, 10},
      {&use_out_source, MTL::DataType::DataTypeBool, 100},
      {&do_axpby, MTL::DataType::DataTypeBool, 110},
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };

  // clang-format off
  kname << "_has_batch_" << (has_batch ? 't' : 'n')
        << "_use_out_source_" << (use_out_source ? 't' : 'n')
        << "_do_axpby_" << (do_axpby ? 't' : 'n')
        << "_align_M_" << (align_M ? 't' : 'n')
        << "_align_N_" << (align_N ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on

  std::string hash_name = kname.str();

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_fused_kernel(
      /* metal::Device& d = */ d,
      /* const std::string& kernel_name = */ base_name,
      /* const std::string& hash_name = */ hash_name,
      /* const metal::MTLFCList& func_consts = */ func_consts,
      /* const array& out = */ out,
      /* bool transpose_a = */ transpose_a,
      /* bool transpose_b = */ transpose_b,
      /* int bm = */ bm,
      /* int bn = */ bn,
      /* int bk = */ bk,
      /* int wm = */ wm,
      /* int wn = */ wn);

  compute_encoder.set_compute_pipeline_state(kernel);

  // Use problem size to determine threadblock swizzle
  int tn = (N + bn - 1) / bn;
  int tm = (M + bm - 1) / bm;

  // TODO: Explore device-based tuning for swizzle
  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);

  // Prepare steel matmul params
  GEMMParams params{/* const int M = */ M,
                    /* const int N = */ N,
                    /* const int K = */ K,
                    /* const int lda = */ lda,
                    /* const int ldb = */ ldb,
                    /* const int ldd = */ ldd,
                    /* const int tiles_n = */ tn,
                    /* const int tiles_m = */ tm,
                    /* const int64_t batch_stride_a = */ A_batch_stride,
                    /* const int64_t batch_stride_b = */ B_batch_stride,
                    /* const int64_t batch_stride_d = */ matrix_stride_out,
                    /* const int swizzle_log = */ swizzle_log,
                    /* const int gemm_k_iterations_aligned = */ (K / bk),
                    /* const int batch_ndim = */ int(batch_shape.size())};

  // Prepare launch grid params
  int tile = 1 << swizzle_log;
  tm = (tm + tile - 1) / tile;
  tn = tn * tile;

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_output_array(out, 3);

  compute_encoder.set_bytes(params, 4);

  if (has_batch) {
    compute_encoder.set_vector_bytes(batch_shape, 6);
    compute_encoder.set_vector_bytes(batch_strides, 7);
  }

  if (use_out_source) {
    int ldc = c.strides()[c.ndim() - 2];
    int fdc = c.strides()[c.ndim() - 1];

    GEMMAddMMParams params{/* const int ldc = */ ldc,
                           /* const int fdc = */ fdc,
                           /* const int64_t batch_stride_c = */ C_batch_stride,
                           /* const float alpha = */ alpha,
                           /* const float beta = */ beta};

    compute_encoder.set_input_array(c, 2);
    compute_encoder.set_bytes(params, 5);
  }

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  // Record copies
  d.add_temporaries(std::move(copies), s.index);
}
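
// ---------------------------------------------------------------------------
// Illustrative sketch (an editorial addition, not from the package source):
// how the fused kernel's function constants fall out of (alpha, beta) in the
// two functions above. With CHECK_AB set, `use_out_source` is false only for
// the pair (alpha, beta) = (0, 1), and `do_axpby` is false whenever both
// scales are exactly 1 -- i.e. plain D = A@B + C reads C but skips the
// scaling epilogue, while any other pair enables full axpby.
// ---------------------------------------------------------------------------
static void axpby_flags_sketch(
    float alpha, float beta, bool& use_out_source, bool& do_axpby) {
  use_out_source = (alpha != 0.0f || beta != 1.0f); // C must be read
  do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); // and scaled
}
// e.g. (1, 1)   -> use_out_source = true,  do_axpby = false  (addmm, no scale)
//      (2, 0.5) -> use_out_source = true,  do_axpby = true   (full axpby)
//      (0, 1)   -> use_out_source = false, do_axpby = false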

///////////////////////////////////////////////////////////////////////////////
// Split k steel matmul
///////////////////////////////////////////////////////////////////////////////

template <bool CHECK_AB = true>
void steel_gemm_splitk_axpby(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    const array& c,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    float alpha = 1.0f,
    float beta = 0.0f) {
  using namespace mlx::steel;

  int _tm = (M + 32 - 1) / 32;
  int _tn = (N + 32 - 1) / 32;
  int _tk = K / 16;

  int bm = M < 40 ? 16 : 32;
  int bn = N < 40 ? 16 : 32;
  int bk = 16;
  int wm = 2, wn = 2;

  // As _tk grows use more partitions, as _tm * _tn grow use fewer partitions
  int split_k_partitions =
      std::min(std::max(2, next_power_of_2(_tk / (_tm * _tn))), 32);
  int split_k_partition_stride = M * N;
  int gemm_k_iterations = (K / bk) / split_k_partitions;
  int split_k_partition_size = gemm_k_iterations * bk;

  array C_split(
      {split_k_partitions, M, N},
      issubdtype(out.dtype(), complexfloating) ? complex64 : float32,
      nullptr,
      {});
  C_split.set_data(allocator::malloc(C_split.nbytes()));
  copies.push_back(C_split);

  bool mn_aligned = M % bm == 0 && N % bn == 0;
  bool k_aligned = K % bk == 0;
  std::ostringstream kname;

  // clang-format off
  kname << "steel_gemm_splitk_"
        << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n')
        << "_" << type_to_name(a)
        << "_" << type_to_name(C_split)
        << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn
        << "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
        << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on

  // Encode and dispatch gemm kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_splitk_kernel(
      /* metal::Device& d = */ d,
      /* const std::string& kernel_name = */ kname.str(),
      /* const array& in = */ a,
      /* const array& out = */ C_split,
      /* bool transpose_a = */ transpose_a,
      /* bool transpose_b = */ transpose_b,
      /* int bm = */ bm,
      /* int bn = */ bn,
      /* int bk = */ bk,
      /* int wm = */ wm,
      /* int wn = */ wn,
      /* bool mn_aligned = */ mn_aligned,
      /* bool k_aligned = */ k_aligned);

  compute_encoder.set_compute_pipeline_state(kernel);

  int tn = (N + bn - 1) / bn;
  int tm = (M + bm - 1) / bm;

  GEMMSpiltKParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ ldb,
      /* const int ldc = */ N,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int split_k_partitions = */ split_k_partitions,
      /* const int split_k_partition_stride = */ split_k_partition_stride,
      /* const int split_k_partition_size = */ split_k_partition_size,
      /* const int swizzle_log = */ 0, // no swizzle
      /* const int gemm_k_iterations_aligned = */ gemm_k_iterations};

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);

  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_output_array(C_split, 2);

  compute_encoder.set_bytes(params, 3);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  // Do accum kernel
  {
    const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);

    auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
        type_to_name(C_split);

    if (do_axpby) {
      kernel_name = kernel_name + "_axbpy";
    }

    auto kernel = get_steel_gemm_splitk_accum_kernel(
        /* metal::Device& d = */ d,
        /* const std::string& kernel_name = */ kernel_name,
        /* const array& in = */ C_split,
        /* const array& out = */ out,
        /* bool axbpy = */ do_axpby);
    compute_encoder.set_compute_pipeline_state(kernel);

    // Set the arguments for the kernel
    compute_encoder.set_input_array(C_split, 0);
    compute_encoder.set_output_array(out, 1);
    compute_encoder.set_bytes(split_k_partitions, 2);
    compute_encoder.set_bytes(split_k_partition_stride, 3);
    compute_encoder.set_bytes(N, 4);

    if (do_axpby) {
      int ldc = c.strides()[c.ndim() - 2];
      int fdc = c.strides()[c.ndim() - 1];

      compute_encoder.set_input_array(c, 5);
      compute_encoder.set_bytes(ldc, 6);
      compute_encoder.set_bytes(fdc, 7);
      compute_encoder.set_bytes(alpha, 8);
      compute_encoder.set_bytes(beta, 9);
    }

    // Launch enough thread groups for each output
    MTL::Size grid_dims = MTL::Size(N, M, 1);
    auto group_dims = get_block_dims(N, M, 1);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  }

  d.add_temporaries(std::move(copies), s.index);
}
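
// ---------------------------------------------------------------------------
// Illustrative sketch (an editorial addition, not from the package source):
// a worked example of the partition heuristic above. `next_pow2_sketch` is a
// local stand-in for the `next_power_of_2` helper the function uses, which is
// defined elsewhere in mlx; its exact semantics here are an assumption.
//
// For M = N = 512, K = 16384 with the 32-wide tile counts above:
//   _tm = _tn = 16, _tk = 1024, so _tk / (_tm * _tn) = 4
//   partitions = min(max(2, next_pow2(4)), 32) = 4
// Each partition then accumulates K / 4 = 4096 of the reduction into its own
// M x N slice of C_split, and the accum kernel sums the 4 slices.
// ---------------------------------------------------------------------------
static int next_pow2_sketch(int x) {
  int p = 1;
  while (p < x) p <<= 1;
  return p;
}

static int splitk_partitions_sketch(int M, int N, int K) {
  int _tm = (M + 31) / 32, _tn = (N + 31) / 32, _tk = K / 16;
  int p = next_pow2_sketch(_tk / (_tm * _tn));
  if (p < 2) p = 2;   // always split at least in two
  if (p > 32) p = 32; // cap the accumulation fan-in
  return p;
}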

///////////////////////////////////////////////////////////////////////////////
// NAX Split k steel matmul
///////////////////////////////////////////////////////////////////////////////

template <bool CHECK_AB = true>
void steel_gemm_splitk_axpby_nax(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    const array& c,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    float alpha = 1.0f,
    float beta = 0.0f) {
  using namespace mlx::steel;

  constexpr int bm = 128, bn = 128, bk = 512;
  constexpr int wm = 4, wn = 4;

  // Determine how many partitions to split K into
  constexpr int split_k_partition_size = 3072;
  int split_k_partitions =
      (K + split_k_partition_size - 1) / split_k_partition_size;

  const int bk_iters_per_partition = split_k_partition_size / bk;
  const int split_k_partition_stride = M * N;

  array C_split({split_k_partitions, M, N}, float32, nullptr, {});
  C_split.set_data(allocator::malloc(C_split.nbytes()));
  copies.push_back(C_split);

  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  // Per-tile align_K is checked at runtime; only the last tile can be unaligned
  metal::MTLFCList func_consts = {
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201}};

  std::ostringstream kname;

  // clang-format off
  kname << "steel_gemm_splitk_nax_"
        << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n')
        << "_" << type_to_name(a)
        << "_" << type_to_name(C_split)
        << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn; // clang-format on

  std::string base_name = kname.str();

  // clang-format off
  kname << "_align_M_" << (align_M ? 't' : 'n')
        << "_align_N_" << (align_N ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on

  std::string hash_name = kname.str();

  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_splitk_nax_kernel(
      /* metal::Device& d = */ d,
      /* const std::string& kernel_name = */ base_name,
      /* const std::string& hash_name = */ hash_name,
      /* const metal::MTLFCList& func_consts = */ func_consts,
      /* const array& out = */ C_split,
      /* bool transpose_a = */ transpose_a,
      /* bool transpose_b = */ transpose_b,
      /* int bm = */ bm,
      /* int bn = */ bn,
      /* int bk = */ bk,
      /* int wm = */ wm,
      /* int wn = */ wn);

  compute_encoder.set_compute_pipeline_state(kernel);

  int tn = (N + bn - 1) / bn;
  int tm = (M + bm - 1) / bm;

  int swizzle_log = tm <= 3 ? 0 : 1;

  // Compute swizzled tile counts
  int tile = 1 << swizzle_log;
  int tm_swizzled = (tm + tile - 1) / tile;
  int tn_swizzled = tn * tile;

  GEMMSpiltKParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ ldb,
      /* const int ldc = */ N,
      /* const int tiles_n = */ tn,
      /* const int tiles_m = */ tm,
      /* const int split_k_partitions = */ split_k_partitions,
      /* const int split_k_partition_stride = */ split_k_partition_stride,
      /* const int split_k_partition_size = */ split_k_partition_size,
      /* const int swizzle_log = */ swizzle_log,
      /* const int gemm_k_iterations_aligned = */ bk_iters_per_partition};

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  // Use 1D grid with K-partition-major layout: [Partition0: M×N
  // tiles][Partition1: M×N tiles]... Grid size is 1D to prevent driver/HW from
  // using its own heuristic to exploit 2D locality by launching threadgroups in
  // a non-linear order
  MTL::Size grid_dims =
      MTL::Size(tn_swizzled * tm_swizzled * split_k_partitions, 1, 1);

  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_output_array(C_split, 2);

  compute_encoder.set_bytes(params, 3);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  // Do accum kernel
  {
    const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);

    auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
        type_to_name(C_split);

    if (do_axpby) {
      kernel_name = kernel_name + "_axbpy";
    }

    auto kernel = get_steel_gemm_splitk_accum_kernel(
        /* metal::Device& d = */ d,
        /* const std::string& kernel_name = */ kernel_name,
        /* const array& in = */ C_split,
        /* const array& out = */ out,
        /* bool axbpy = */ do_axpby);
    compute_encoder.set_compute_pipeline_state(kernel);

    // Set the arguments for the kernel
    compute_encoder.set_input_array(C_split, 0);
    compute_encoder.set_output_array(out, 1);
    compute_encoder.set_bytes(split_k_partitions, 2);
    compute_encoder.set_bytes(split_k_partition_stride, 3);
    compute_encoder.set_bytes(N, 4);

    if (do_axpby) {
      int ldc = c.strides()[c.ndim() - 2];
      int fdc = c.strides()[c.ndim() - 1];

      compute_encoder.set_input_array(c, 5);
      compute_encoder.set_bytes(ldc, 6);
      compute_encoder.set_bytes(fdc, 7);
      compute_encoder.set_bytes(alpha, 8);
      compute_encoder.set_bytes(beta, 9);
    }

    // Launch enough thread groups for each output
    MTL::Size grid_dims = MTL::Size(N, M, 1);
    auto group_dims = get_block_dims(N, M, 1);
    compute_encoder.dispatch_threads(grid_dims, group_dims);
  }

  d.add_temporaries(std::move(copies), s.index);
}
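
// ---------------------------------------------------------------------------
// Illustrative sketch (an editorial addition, not from the package source):
// the NAX split-k launch above linearizes the grid partition-major, exactly
// as the comment on `grid_dims` describes. A kernel consuming that flat
// threadgroup id would decompose it roughly as below; the struct and names
// are assumptions for illustration.
// ---------------------------------------------------------------------------
struct SplitKTileIdSketch {
  int partition; // which K partition this threadgroup reduces
  int tile_y;    // swizzled m-tile index within the partition
  int tile_x;    // swizzled n-tile index within the partition
};

static SplitKTileIdSketch decode_flat_tgid_sketch(
    int flat, int tn_swizzled, int tm_swizzled) {
  int tiles_per_partition = tn_swizzled * tm_swizzled;
  SplitKTileIdSketch id;
  id.partition = flat / tiles_per_partition; // all of partition 0 first, ...
  int rem = flat % tiles_per_partition;
  id.tile_y = rem / tn_swizzled;
  id.tile_x = rem % tn_swizzled;
  return id;
}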

///////////////////////////////////////////////////////////////////////////////
// Split matmul routing
///////////////////////////////////////////////////////////////////////////////

template <bool CHECK_AB>
void steel_matmul_axpby(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    const array& c,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    Shape batch_shape /* = {} */,
    Strides A_batch_stride /* = {} */,
    Strides B_batch_stride /* = {} */,
    Strides C_batch_stride /* = {} */,
    float alpha /* = 1.0f */,
    float beta /* = 0.0f */) {
  if (batch_shape.empty()) {
    /////////////////////////////////////////////////////////////////////////////
    // Check and collapse batch dimensions
    if constexpr (CHECK_AB) {
      auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] =
          collapse_batches(a, b, c);

      batch_shape = batch_shape_;
      A_batch_stride = A_bstride_;
      B_batch_stride = B_bstride_;
      C_batch_stride = C_bstride_;
      // Collapse batches into M if needed
      if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
          a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
          C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
          B_batch_stride.back() == 0) {
        M *= batch_shape.back();
        batch_size_out = 1;

        A_batch_stride = {0};
        B_batch_stride = {0};
        C_batch_stride = {0};
        batch_shape = {1};
      }
    } else {
      auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);

      batch_shape = batch_shape_;
      A_batch_stride = A_bstride_;
      B_batch_stride = B_bstride_;
      // Collapse batches into M if needed
      if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
          a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
          B_batch_stride.back() == 0) {
        M *= batch_shape.back();
        batch_size_out = 1;

        A_batch_stride = {0};
        B_batch_stride = {0};
        batch_shape = {1};
      }
    }
  }

  /////////////////////////////////////////////////////////////////////////////
  // Split K specialization

  int _tm = (M + 16 - 1) / 16;
  int _tn = (N + 16 - 1) / 16;
  int _tk = K / 16;

  // Case 1: Small M×N with large K, use SIMD split-K
  char devc = d.get_architecture().back();
  // Max and Ultra dispatch larger sizes to splitk
  int min_tmn_threshold = (devc == 's' || devc == 'd') ? 2048 : 1024;
  if (batch_size_out == 1 && (_tm * _tn) <= min_tmn_threshold && _tk >= 8 &&
      K >= std::max(M, N)) {
    return steel_gemm_splitk_axpby<CHECK_AB>(
        /* const Stream& s = */ s,
        /* metal::Device& d = */ d,
        /* const array& a = */ a,
        /* const array& b = */ b,
        /* const array& c = */ c,
        /* array& out = */ out,
        /* int M = */ M,
        /* int N = */ N,
        /* int K = */ K,
        /* int batch_size_out = */ batch_size_out,
        /* int lda = */ lda,
        /* int ldb = */ ldb,
        /* bool transpose_a = */ transpose_a,
        /* bool transpose_b = */ transpose_b,
        /* std::vector<array>& copies = */ copies,
        /* float alpha = */ alpha,
        /* float beta = */ beta);
  }

  // Case 2: Large K with sufficient M, N, and NAX is available, use NAX split-K
  // TODO: Add device-specific tuning for more NAX GPUs in the future
  constexpr int min_mn_threshold = 2048 * 2048;
  constexpr int min_k_threshold = 10240;
  if (batch_size_out == 1 && metal::is_nax_available() &&
      !issubdtype(a.dtype(), complexfloating) &&
      (env::enable_tf32() || a.dtype() != float32) &&
      int64_t(M) * N >= min_mn_threshold && K >= min_k_threshold &&
      K >= (3 * std::max(M, N))) {
    return steel_gemm_splitk_axpby_nax<CHECK_AB>(
        /* const Stream& s = */ s,
        /* metal::Device& d = */ d,
        /* const array& a = */ a,
        /* const array& b = */ b,
        /* const array& c = */ c,
        /* array& out = */ out,
        /* int M = */ M,
        /* int N = */ N,
        /* int K = */ K,
        /* int batch_size_out = */ batch_size_out,
        /* int lda = */ lda,
        /* int ldb = */ ldb,
        /* bool transpose_a = */ transpose_a,
        /* bool transpose_b = */ transpose_b,
        /* std::vector<array>& copies = */ copies,
        /* float alpha = */ alpha,
        /* float beta = */ beta);
  }

  /////////////////////////////////////////////////////////////////////////////
  // Regular kernel dispatch
  auto batch_strides = A_batch_stride;
  batch_strides.insert(
      batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
  if (CHECK_AB && !C_batch_stride.empty()) {
    batch_strides.insert(
        batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
  }

  int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back();
  int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back();
  int64_t C_batch_stride_ = C_batch_stride.empty() ? 0 : C_batch_stride.back();

  return steel_matmul_regular_axpby<CHECK_AB>(
      /* const Stream& s = */ s,
      /* metal::Device& d = */ d,
      /* const array& a = */ a,
      /* const array& b = */ b,
      /* const array& c = */ c,
      /* array& out = */ out,
      /* int M = */ M,
      /* int N = */ N,
      /* int K = */ K,
      /* int batch_size_out = */ batch_size_out,
      /* int lda = */ lda,
      /* int ldb = */ ldb,
      /* int ldd = */ N,
      /* bool transpose_a = */ transpose_a,
      /* bool transpose_b = */ transpose_b,
      /* std::vector<array>& copies = */ copies,
      /* Shape batch_shape = */ std::move(batch_shape),
      /* Strides batch_strides = */ std::move(batch_strides),
      /* int64_t A_batch_stride = */ A_batch_stride_,
      /* int64_t B_batch_stride = */ B_batch_stride_,
      /* int64_t matrix_stride_out = */ int64_t(M) * N,
      /* int64_t C_batch_stride = */ C_batch_stride_,
      /* float alpha = */ alpha,
      /* float beta = */ beta);
}
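
// ---------------------------------------------------------------------------
// Illustrative sketch (an editorial addition, not from the package source):
// the routing above tries, in order, (1) the SIMD split-k path for small M*N
// with dominant K, (2) the NAX split-k path for large-enough problems on NAX
// hardware, and otherwise (3) the regular fused kernel. The case-1 predicate,
// restated for a concrete shape: M = N = 256, K = 8192 gives _tm = _tn = 16
// and _tk = 512, so _tm * _tn = 256 <= 1024, _tk >= 8, and K >= max(M, N);
// split-k is taken.
// ---------------------------------------------------------------------------
static bool routes_to_simd_splitk_sketch(
    int M, int N, int K, int batch_size_out, int min_tmn_threshold) {
  int _tm = (M + 15) / 16, _tn = (N + 15) / 16, _tk = K / 16;
  return batch_size_out == 1 && (_tm * _tn) <= min_tmn_threshold &&
      _tk >= 8 && K >= (M > N ? M : N);
}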

///////////////////////////////////////////////////////////////////////////////
// GEMV dispatch
///////////////////////////////////////////////////////////////////////////////

template <bool CHECK_AB = true>
void gemv_axbpy(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    const array& c,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    Shape batch_shape = {},
    Strides A_batch_stride = {},
    Strides B_batch_stride = {},
    Strides C_batch_stride = {},
    float alpha = 1.0f,
    float beta = 0.0f) {
  // Collect problem info
  bool is_b_matrix = N != 1;

  auto& mat = is_b_matrix ? b : a;
  auto& vec = is_b_matrix ? a : b;
  bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
  int in_vector_len = K;
  int out_vector_len = is_b_matrix ? N : M;

  int mat_ld = is_b_matrix ? ldb : lda;

  auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
  auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;

  // Determine if inputs have simple batching / broadcasting
  bool contiguous_kernel = (batch_shape.size() == 1);

  int batch_ndim = batch_shape.size();

  // Determine dispatch kernel
  int tm = 4, tn = 4;
  int sm = 1, sn = 32;
  int bm = 1, bn = 1;
  int n_out_per_tgp;
  std::ostringstream kname;

  if (transpose_mat) {
    if (in_vector_len >= 8192 && out_vector_len >= 2048) {
      sm = 4;
      sn = 8;
    } else {
      sm = 8;
      sn = 4;
    }

    if (out_vector_len >= 2048) {
      bn = 16;
    } else if (out_vector_len >= 512) {
      bn = 4;
    } else {
      bn = 2;
    }

    // Specialized kernel for very small outputs
    tn = out_vector_len < tn ? 1 : tn;

    n_out_per_tgp = bn * sn * tn;
    kname << "gemv_t_" << type_to_name(out);

  } else {
    bm = out_vector_len >= 4096 ? 8 : 4;
    sn = 32;

    if (K <= 64) {
      bm = 1;
      sm = 8;
      sn = 4;
    } else if (K >= 16 * out_vector_len) {
      bm = 1;
      bn = 8;
    }

    // Specialized kernel for very small outputs
    tm = out_vector_len < tm ? 1 : tm;

    n_out_per_tgp = bm * sm * tm;
    kname << "gemv_" << type_to_name(out);
  }

  const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);

  // clang-format off
  kname << "_bm" << bm << "_bn" << bn
        << "_sm" << sm << "_sn" << sn
        << "_tm" << tm << "_tn" << tn
        << "_nc" << !contiguous_kernel
        << "_axpby" << do_axpby; // clang-format on

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder.set_compute_pipeline_state(kernel);

  int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
  MTL::Size group_dims = MTL::Size(32, bn, bm);
  MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

  compute_encoder.set_input_array(mat, 0);
  compute_encoder.set_input_array(vec, 1);
  compute_encoder.set_output_array(out, 3);

  compute_encoder.set_bytes(in_vector_len, 4);
  compute_encoder.set_bytes(out_vector_len, 5);
  compute_encoder.set_bytes(mat_ld, 6);

  compute_encoder.set_bytes(batch_ndim, 9);
  compute_encoder.set_vector_bytes(batch_shape, 10);
  compute_encoder.set_vector_bytes(batch_strides_vec, 11);
  compute_encoder.set_vector_bytes(batch_strides_mat, 12);

  if (do_axpby) {
    compute_encoder.set_input_array(c, 2);

    compute_encoder.set_bytes(alpha, 7);
    compute_encoder.set_bytes(beta, 8);

    compute_encoder.set_vector_bytes(C_batch_stride, 13);

    int bias_stride = c.strides()[c.ndim() - 1];
    compute_encoder.set_bytes(bias_stride, 14);
  }

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  d.add_temporaries(std::move(copies), s.index);
}
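
// ---------------------------------------------------------------------------
// Illustrative sketch (an editorial addition, not from the package source):
// each threadgroup in the gemv above produces n_out_per_tgp contiguous
// outputs, so the launch covers the output vector with a ceil-div. E.g. the
// non-transposed path with bm = 4, sm = 1, tm = 4 gives n_out_per_tgp = 16;
// an out_vector_len of 1000 then launches (1000 + 16 - 1) / 16 = 63
// threadgroups along x, with one grid layer per batch entry.
// ---------------------------------------------------------------------------
static int gemv_tgp_count_sketch(int out_vector_len, int n_out_per_tgp) {
  return (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
}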

inline void gemv(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies,
    Shape batch_shape = {},
    Strides A_batch_stride = {},
    Strides B_batch_stride = {}) {
  return gemv_axbpy<false>(
      /* const Stream& s = */ s,
      /* metal::Device& d = */ d,
      /* const array& a = */ a,
      /* const array& b = */ b,
      /* const array& c = */ b,
      /* array& out = */ out,
      /* int M = */ M,
      /* int N = */ N,
      /* int K = */ K,
      /* int batch_size_out = */ batch_size_out,
      /* int lda = */ lda,
      /* int ldb = */ ldb,
      /* bool transpose_a = */ transpose_a,
      /* bool transpose_b = */ transpose_b,
      /* std::vector<array>& copies = */ copies,
      /* Shape batch_shape = */ batch_shape,
      /* Strides A_batch_stride = */ A_batch_stride,
      /* Strides B_batch_stride = */ B_batch_stride);
}
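
// Editorial note: with CHECK_AB = false, `do_axpby` in gemv_axbpy is always
// false, so the `c` argument is never bound or read; the wrapper above passes
// `b` again in the `c` slot purely as a placeholder reference.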

///////////////////////////////////////////////////////////////////////////////
// Matmul implementation
///////////////////////////////////////////////////////////////////////////////

void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (!issubdtype(out.dtype(), inexact)) {
    throw std::runtime_error("[matmul] dtype must be inexact.");
  }
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& a_pre = inputs[0];
  auto& b_pre = inputs[1];
  // Return 0s if either input is empty
  if (a_pre.size() == 0 || b_pre.size() == 0) {
    array zero = array(0, a_pre.dtype());
    fill_gpu(zero, out, s);
    d.add_temporary(std::move(zero), s.index);
    return;
  }

  out.set_data(allocator::malloc(out.nbytes()));

  /////////////////////////////////////////////////////////////////////////////
  // Init checks and prep

  int M = a_pre.shape(-2);
  int N = b_pre.shape(-1);
  int K = a_pre.shape(-1);

  // Keep a vector with copies to be cleared in the completed buffer to release
  // the arrays
  std::vector<array> copies;
  auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
  auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

  /////////////////////////////////////////////////////////////////////////////
  // Check and collapse batch dimensions

  auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);

  auto batch_size_out = out.size() / (size_t(M) * size_t(N));

  // Collapse batches into M if needed
  if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
      a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
      B_batch_stride.back() == 0) {
    M *= batch_shape.back();
    batch_size_out = 1;

    A_batch_stride = {0};
    B_batch_stride = {0};
    batch_shape = {1};
  }

  /////////////////////////////////////////////////////////////////////////////
  // Gemv specialization

  // Route to gemv if needed
  if (std::min(M, N) == 1) {
    return gemv(
        /* const Stream& s = */ s,
        /* metal::Device& d = */ d,
        /* const array& a = */ a,
        /* const array& b = */ b,
        /* array& out = */ out,
        /* int M = */ M,
        /* int N = */ N,
        /* int K = */ K,
        /* int batch_size_out = */ batch_size_out,
        /* int lda = */ a_cols,
        /* int ldb = */ b_cols,
        /* bool transpose_a = */ a_transposed,
        /* bool transpose_b = */ b_transposed,
        /* std::vector<array>& copies = */ copies,
        /* Shape batch_shape = */ std::move(batch_shape),
        /* Strides A_batch_stride = */ std::move(A_batch_stride),
        /* Strides B_batch_stride = */ std::move(B_batch_stride));
  }

  /////////////////////////////////////////////////////////////////////////////
  // Gemm specialization

  return steel_matmul(
      /* const Stream& s = */ s,
      /* metal::Device& d = */ d,
      /* const array& a = */ a,
      /* const array& b = */ b,
      /* array& out = */ out,
      /* int M = */ M,
      /* int N = */ N,
      /* int K = */ K,
      /* int batch_size_out = */ batch_size_out,
      /* int lda = */ a_cols,
      /* int ldb = */ b_cols,
      /* bool transpose_a = */ a_transposed,
      /* bool transpose_b = */ b_transposed,
      /* std::vector<array>& copies = */ copies,
      /* Shape batch_shape = */ std::move(batch_shape),
      /* Strides A_batch_stride = */ std::move(A_batch_stride),
      /* Strides B_batch_stride = */ std::move(B_batch_stride));
}
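
// ---------------------------------------------------------------------------
// Editorial note (a worked example, not from the package source): the
// "collapse batches into M" rewrite above turns a broadcasted batched matmul
// into one larger 2D matmul. For a of shape (8, 32, 64), row-contiguous
// (strides (2048, 64, 1), so A_batch_stride = 32 * 64 = 2048 = M * K and the
// row stride equals K), times b of shape (64, 16) broadcast over the batch
// (B_batch_stride = 0), all conditions hold and the product is re-expressed
// as a single (256, 64) @ (64, 16): M grows from 32 to 8 * 32 and
// batch_size_out drops to 1, so the steel kernel runs one large tile grid
// instead of 8 small ones.
// ---------------------------------------------------------------------------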

///////////////////////////////////////////////////////////////////////////////
// AddMM implementation
///////////////////////////////////////////////////////////////////////////////

void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 3);
  if (!issubdtype(out.dtype(), floating)) {
    throw std::runtime_error(
        "[matmul] Does not yet support non-floating point types.");
  }

  // Return early if the output is empty
  if (out.size() == 0) {
    out.set_data(allocator::malloc(out.nbytes()));
    return;
  }

  auto& s = stream();
  auto& d = metal::device(s.device);

  // Handle empty matrix case (K=0)
  if (inputs[0].shape(-1) == 0) {
    auto& c = inputs[2];
    if (beta_ == 1.0f) {
      copy_gpu(
          c,
          out,
          c.flags().row_contiguous ? CopyType::Vector : CopyType::General,
          s);
    } else {
      array beta_scalar = array(beta_, c.dtype());
      binary_op_gpu({c, beta_scalar}, out, "Multiply", s);
      d.add_temporary(std::move(beta_scalar), s.index);
    }
    return;
  }

  out.set_data(allocator::malloc(out.nbytes()));

  auto& a_pre = inputs[0];
  auto& b_pre = inputs[1];
  auto& c_pre = inputs[2];

  /////////////////////////////////////////////////////////////////////////////
  // Init checks and prep

  int M = a_pre.shape(-2);
  int N = b_pre.shape(-1);
  int K = a_pre.shape(-1);

  // Keep a vector with copies to be cleared in the completed buffer to release
  // the arrays
  std::vector<array> copies;
  auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
  auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);

  array c = c_pre;

  int lda = a_cols;
  int ldb = b_cols;

  /////////////////////////////////////////////////////////////////////////////
  // Check and collapse batch dimensions
  auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
      collapse_batches(a, b, c);

  int64_t matrix_stride_out = M * static_cast<int64_t>(N);
  auto batch_size_out = out.size() / (matrix_stride_out);

  // Collapse batches into M if needed
  if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
      a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
      C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
      B_batch_stride.back() == 0) {
    M *= batch_shape.back();
    batch_size_out = 1;

    A_batch_stride = {0};
    B_batch_stride = {0};
    C_batch_stride = {0};
    batch_shape = {1};
  }

  /////////////////////////////////////////////////////////////////////////////
  // Gemv specialization

  // Route to gemv if needed
  if (std::min(M, N) == 1) {
    return gemv_axbpy(
        /* const Stream& s = */ s,
        /* metal::Device& d = */ d,
        /* const array& a = */ a,
        /* const array& b = */ b,
        /* const array& c = */ c,
        /* array& out = */ out,
        /* int M = */ M,
        /* int N = */ N,
        /* int K = */ K,
        /* int batch_size_out = */ batch_size_out,
        /* int lda = */ lda,
        /* int ldb = */ ldb,
        /* bool transpose_a = */ transpose_a,
        /* bool transpose_b = */ transpose_b,
        /* std::vector<array>& copies = */ copies,
        /* Shape batch_shape = */ batch_shape,
        /* Strides A_batch_stride = */ A_batch_stride,
        /* Strides B_batch_stride = */ B_batch_stride,
        /* Strides C_batch_stride = */ C_batch_stride,
        /* float alpha = */ alpha_,
        /* float beta = */ beta_);
  }

  /////////////////////////////////////////////////////////////////////////////
  // Regular addmm dispatch

  return steel_matmul_axpby(
      /* const Stream& s = */ s,
      /* metal::Device& d = */ d,
      /* const array& a = */ a,
      /* const array& b = */ b,
      /* const array& c = */ c,
      /* array& out = */ out,
      /* int M = */ M,
      /* int N = */ N,
      /* int K = */ K,
      /* int batch_size_out = */ batch_size_out,
      /* int lda = */ lda,
      /* int ldb = */ ldb,
      /* bool transpose_a = */ transpose_a,
      /* bool transpose_b = */ transpose_b,
      /* std::vector<array>& copies = */ copies,
      /* Shape batch_shape = */ batch_shape,
      /* Strides A_batch_stride = */ A_batch_stride,
      /* Strides B_batch_stride = */ B_batch_stride,
      /* Strides C_batch_stride = */ C_batch_stride,
      /* float alpha = */ alpha_,
      /* float beta = */ beta_);
}
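
// ---------------------------------------------------------------------------
// Editorial note (a worked example, not from the package source): in the
// axpby paths above, `ldc` and `fdc` are c's strides for its last two axes,
// which is what lets AddMM take a broadcast bias without materializing it.
// For c of logical shape (M, N) backed by a length-N row (strides (0, 1)):
// ldc = 0, fdc = 1 -- every output row reads the same bias row. For a
// per-row bias stored as (M, 1) broadcast across columns (strides (1, 0)):
// ldc = 1, fdc = 0.
// ---------------------------------------------------------------------------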
|
|
1441
|
+
|
|
1442
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
1443
|
+
// BlockMaskedMM implementation
|
|
1444
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
1445
|
+
|
|
1446
|
+
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
1447
|
+
using namespace mlx::steel;
|
|
1448
|
+
// assert(inputs.size() == 2);
|
|
1449
|
+
if (!issubdtype(out.dtype(), floating)) {
|
|
1450
|
+
throw std::runtime_error(
|
|
1451
|
+
"[matmul] Does not yet support non-floating point types.");
|
|
1452
|
+
}
|
|
1453
|
+
auto& s = stream();
|
|
1454
|
+
auto& d = metal::device(s.device);
|
|
1455
|
+
|
|
1456
|
+
auto& a_pre = inputs[0];
|
|
1457
|
+
auto& b_pre = inputs[1];
|
|
1458
|
+
// Return 0s if either input is empty
|
|
1459
|
+
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
|
1460
|
+
array zero = array(0, a_pre.dtype());
|
|
1461
|
+
fill_gpu(zero, out, s);
|
|
1462
|
+
d.add_temporary(std::move(zero), s.index);
|
|
1463
|
+
return;
|
|
1464
|
+
}
|
|
1465
|
+
|
|
1466
|
+
out.set_data(allocator::malloc(out.nbytes()));
|
|
1467
|
+
|
|
1468
|
+
/////////////////////////////////////////////////////////////////////////////
|
|
1469
|
+
// Init checks and prep
|
|
1470
|
+
|
|
1471
|
+
int M = a_pre.shape(-2);
|
|
1472
|
+
int N = b_pre.shape(-1);
|
|
1473
|
+
int K = a_pre.shape(-1);
|
|
1474
|
+
|
|
1475
|
+
// Keep a vector with copies to be cleared in the completed buffer to release
|
|
1476
|
+
// the arrays
|
|
1477
|
+
std::vector<array> copies;
|
|
1478
|
+
auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
|
|
1479
|
+
auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
|
|
1480
|
+
|
|
1481
|
+
int lda = a_cols;
|
|
1482
|
+
int ldb = b_cols;
|
|
1483
|
+
|
|
1484
|
+
/////////////////////////////////////////////////////////////////////////////
|
|
1485
|
+
// Check and collapse batch dimensions
|
|
1486
|
+
|
|
1487
|
+
bool has_op_mask = inputs.size() > 3;
|
|
1488
|
+
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
|
|
1489
|
+
|
|
1490
|
+
// Prepare kernel name
|
|
1491
|
+
std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
|
|
1492
|
+
std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";
|
|
1493
|
+
|
|
1494
|
+
Shape batch_shape{1};
|
|
1495
|
+
Strides A_batch_stride{0};
|
|
1496
|
+
Strides B_batch_stride{0};
|
|
1497
|
+
Strides outmask_bstride{0};
|
|
1498
|
+
Strides Amask_bstride{0};
|
|
1499
|
+
Strides Bmask_bstride{0};
|
|
1500
|
+
int64_t A_batch_str = 0;
|
|
1501
|
+
int64_t B_batch_str = 0;
|
|
1502
|
+
|
|
1503
|
+
Strides batch_strides;
|
|
1504
|
+
|
|
1505
|
+
if (out.ndim() > 2) {
|
|
1506
|
+
Shape bshape{out.shape().begin(), out.shape().end() - 2};
|
|
1507
|
+
std::vector<Strides> bstrides;
|
|
1508
|
+
|
|
1509
|
+
for (auto& arr : inputs) {
|
|
1510
|
+
bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
|
|
1511
|
+
}
|
|
1512
|
+
|
|
1513
|
+
// auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
|
|
1514
|
+
batch_shape = bshape;
|
|
1515
|
+
A_batch_str = bstrides[0].back();
|
|
1516
|
+
B_batch_str = bstrides[1].back();
|
|
1517
|
+
|
|
1518
|
+
for (auto& bstr : bstrides) {
|
|
1519
|
+
batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
|
|
1520
|
+
}
|
|
1521
|
+
|
|
1522
|
+
A_batch_stride = bstrides[0];
|
|
1523
|
+
B_batch_stride = bstrides[1];
|
|
1524
|
+
|
|
1525
|
+
if (has_out_mask) {
|
|
1526
|
+
outmask_bstride = bstrides[2];
|
|
1527
|
+
}
|
|
1528
|
+
if (has_op_mask) {
|
|
1529
|
+
Amask_bstride = bstrides[has_out_mask + 2];
|
|
1530
|
+
Bmask_bstride = bstrides[has_out_mask + 3];
|
|
1531
|
+
}
|
|
1532
|
+
|
|
1533
|
+
} else {
|
|
1534
|
+
batch_strides = Strides(inputs.size(), 0);
|
|
1535
|
+
}
|
|
1536
|
+
|
|
1537
|
+
int64_t matrix_stride_out = static_cast<int64_t>(M) * N;
|
|
1538
|
+
size_t batch_size_out = out.size() / (matrix_stride_out);
|
|
1539
|
+
|
|
1540
|
+
/////////////////////////////////////////////////////////////////////////////
|
|
1541
|
+
// Gemv specialization
|
|
1542
|
+
|
|
1543
|
+
// Route to gemv if needed
|
|
1544
|
+
if (std::min(M, N) == 1) {
|
|
1545
|
+
// Collect problem info
|
|
1546
|
+
bool is_b_matrix = N != 1;
|
|
1547
|
+
|
|
1548
|
+
auto& mat = is_b_matrix ? b : a;
|
|
1549
|
+
auto& vec = is_b_matrix ? a : b;
|
|
1550
|
+
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
|
|
1551
|
+
int in_vector_len = K;
|
|
1552
|
+
int out_vector_len = is_b_matrix ? N : M;
|
|
1553
|
+
|
|
1554
|
+
int mat_ld = is_b_matrix ? b_cols : a_cols;
|
|
1555
|
+
|
|
1556
|
+
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
|
|
1557
|
+
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
|
|
1558
|
+
|
|
1559
|
+
auto mask_bstrides_mat = is_b_matrix ? Bmask_bstride : Amask_bstride;
|
|
1560
|
+
auto mask_bstrides_vec = is_b_matrix ? Amask_bstride : Bmask_bstride;
|
|
1561
|
+
|
|
1562
|
+
auto mat_mask_idx = int(has_out_mask) + (is_b_matrix ? 3 : 2);
|
|
1563
|
+
auto vec_mask_idx = int(has_out_mask) + (is_b_matrix ? 2 : 3);
|
|
1564
|
+
|
|
1565
|
+
// Determine if inputs have simple batching / broadcasting
|
|
1566
|
+
bool contiguous_kernel = (batch_shape.size() == 1);
|
|
1567
|
+
|
|
1568
|
+
int batch_ndim = batch_shape.size();
|
|
1569
|
+
|
|
1570
|
+
// Determine dispatch kernel
|
|
1571
|
+
int tm = 4, tn = 4;
|
|
1572
|
+
int sm = 1, sn = 32;
|
|
1573
|
+
int bm = 1, bn = 1;
|
|
1574
|
+
int n_out_per_tgp;
|
|
1575
|
+
std::ostringstream kname;
|
|
1576
|
+
|
|
1577
|
+
if (transpose_mat) {
|
|
1578
|
+
sm = 8;
|
|
1579
|
+
sn = 4;
|
|
1580
|
+
bm = 1;
|
|
1581
|
+
bn = (block_size_ == 64 && out_vector_len >= 2048) ? 4 : 2;
|
|
1582
|
+
tm = block_size_ == 32 ? 4 : 8;
|
|
1583
|
+
tn = 4;
|
|
1584
|
+
|
|
1585
|
+
// Specialized kernel for very small outputs
|
|
1586
|
+
tn = out_vector_len < tn ? 1 : tn;
|
|
1587
|
+
|
|
1588
|
+
n_out_per_tgp = bn * sn * tn;
|
|
1589
|
+
kname << "gemv_t";
|
|
1590
|
+
|
|
1591
|
+
} else {
|
|
1592
|
+
if (block_size_ == 32) {
|
|
1593
|
+
sm = 4;
|
|
1594
|
+
sn = 8;
|
|
1595
|
+
bm = 2;
|
|
1596
|
+
} else {
|
|
1597
|
+
sm = 2;
|
|
1598
|
+
sn = 16;
|
|
1599
|
+
bm = out_vector_len >= 512 ? 4 : 2;
|
|
1600
|
+
}
|
|
1601
|
+
|
|
1602
|
+
// Specialized kernel for very small outputs
|
|
1603
|
+
tm = out_vector_len < tm ? 1 : tm;
|
|
1604
|
+
|
|
1605
|
+
n_out_per_tgp = bm * sm * tm;
|
|
1606
|
+
kname << "gemv";
|
|
1607
|
+
}
|
|
1608
|
+
|
|
1609
|
+
kname << "_outmask_" << out_mask_nm;
|
|
1610
|
+
kname << "_opmask_" << op_mask_nm;
|
|
1611
|
+
kname << "_" << type_to_name(out);
|
|
1612
|
+
kname << "_bm" << bm << "_bn" << bn;
|
|
1613
|
+
kname << "_sm" << sm << "_sn" << sn;
|
|
1614
|
+
kname << "_tm" << tm << "_tn" << tn;
|
|
1615
|
+
kname << "_nc" << !contiguous_kernel;
|
|
1616
|
+
|
|
1617
|
+
// Encode and dispatch kernel
|
|
1618
|
+
auto kernel = get_gemv_masked_kernel(
|
|
1619
|
+
d,
|
|
1620
|
+
kname.str(),
|
|
1621
|
+
out,
|
|
1622
|
+
has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
|
|
1623
|
+
has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
|
|
1624
|
+
transpose_mat,
|
|
1625
|
+
bm,
|
|
1626
|
+
bn,
|
|
1627
|
+
sm,
|
|
1628
|
+
sn,
|
|
1629
|
+
tm,
|
|
1630
|
+
tn,
|
|
1631
|
+
contiguous_kernel);
|
|
1632
|
+
|
|
1633
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
1634
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
1635
|
+
|
|
1636
|
+
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
|
|
1637
|
+
MTL::Size group_dims = MTL::Size(32, bn, bm);
|
|
1638
|
+
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
|
|
1639
|
+
|
|
1640
|
+
// Get mask params
|
|
1641
|
+
std::vector<int> mask_strides;
|
|
1642
|
+
Strides mask_batch_strides;
    if (has_out_mask) {
      auto& out_mask = inputs[2];

      if (transpose_mat) {
        mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -1 : -2));
        mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -2 : -1));
      } else {
        mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -1 : -2));
        mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -2 : -1));
      }

      mask_batch_strides.insert(
          mask_batch_strides.end(),
          outmask_bstride.begin(),
          outmask_bstride.end());

      compute_encoder.set_input_array(out_mask, 20);
    }

    if (has_op_mask) {
      auto& mat_mask = inputs[mat_mask_idx];

      if (transpose_mat) {
        mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -2 : -1));
        mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -1 : -2));
      } else {
        mask_strides.push_back(mat_mask.strides(is_b_matrix ? -2 : -1));
        mask_strides.push_back(mat_mask.strides(is_b_matrix ? -1 : -2));
      }

      mask_batch_strides.insert(
          mask_batch_strides.end(),
          mask_bstrides_mat.begin(),
          mask_bstrides_mat.end());

      compute_encoder.set_input_array(mat_mask, 21);

      auto& vec_mask = inputs[vec_mask_idx];
      if (transpose_mat) {
        mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -1 : -2));
        mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -2 : -1));
      } else {
        mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -1 : -2));
        mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -2 : -1));
      }

      mask_batch_strides.insert(
          mask_batch_strides.end(),
          mask_bstrides_vec.begin(),
          mask_bstrides_vec.end());

      compute_encoder.set_input_array(vec_mask, 22);
    }

    // Get gemv params
    compute_encoder.set_input_array(mat, 0);
    compute_encoder.set_input_array(vec, 1);
    compute_encoder.set_output_array(out, 3);

    compute_encoder.set_bytes(in_vector_len, 4);
    compute_encoder.set_bytes(out_vector_len, 5);
    compute_encoder.set_bytes(mat_ld, 6);
    compute_encoder.set_bytes(batch_ndim, 9);
    compute_encoder.set_vector_bytes(batch_shape, 10);
    compute_encoder.set_vector_bytes(batch_strides_vec, 11);
    compute_encoder.set_vector_bytes(batch_strides_mat, 12);

    compute_encoder.set_vector_bytes(mask_strides, 23);
    compute_encoder.set_vector_bytes(mask_batch_strides, 24);

    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

    d.add_temporaries(std::move(copies), s.index);
    return;
  }

  /////////////////////////////////////////////////////////////////////////////
  // Regular kernel dispatch

  // Determine dispatch kernel
  int bm = block_size_, bn = block_size_, bk = 16;
  int wm = 2, wn = 2;
  bool mn_aligned = M % bm == 0 && N % bn == 0;
  bool k_aligned = K % bk == 0;
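  // The alignment flags pick the kernel variant: fully aligned tiles let the
  // kernel skip the bounds checks that ragged edge tiles require.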

  std::ostringstream kname;
  kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_"
        << op_mask_nm << "_" << (transpose_a ? 't' : 'n')
        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
        << "aligned"
        << "_K_" << (k_aligned ? "t" : "n") << "aligned";

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_masked_kernel(
      d,
      kname.str(),
      out,
      has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
      has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn,
      mn_aligned,
      k_aligned);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Use problem size to determine threadblock swizzle
  int tn = (N + bn - 1) / bn;
  int tm = (M + bm - 1) / bm;

  // TODO: Explore device-based tuning for swizzle
  int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);

  // Prepare steel matmul params
  GEMMParams params{/* const int M = */ M,
                    /* const int N = */ N,
                    /* const int K = */ K,
                    /* const int lda = */ lda,
                    /* const int ldb = */ ldb,
                    /* const int ldd = */ N,
                    /* const int tiles_n = */ tn,
                    /* const int tiles_m = */ tm,
                    /* const int64_t batch_stride_a = */ A_batch_str,
                    /* const int64_t batch_stride_b = */ B_batch_str,
                    /* const int64_t batch_stride_d = */ matrix_stride_out,
                    /* const int swizzle_log = */ swizzle_log,
                    /* const int gemm_k_iterations_aligned = */ (K / bk),
                    /* const int batch_ndim = */ int(batch_shape.size())};

  // Prepare launch grid params
  int tile = 1 << swizzle_log;
  tm = (tm + tile - 1) / tile;
  tn = tn * tile;
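  // Tile swizzling folds 2^swizzle_log rows of tiles into the x dimension to
  // improve locality across neighboring threadgroups; with swizzle_log = 0,
  // tile == 1 and the two lines above are a no-op.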

  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);

  std::vector<int> mask_strides;
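  // The block kernel only needs the innermost two strides of each mask; mask
  // batching is handled through the batch_shape/batch_strides buffers bound
  // further down.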

  if (has_out_mask) {
    auto& out_mask = inputs[2];
    mask_strides.push_back(*(out_mask.strides().end() - 1));
    mask_strides.push_back(*(out_mask.strides().end() - 2));

    compute_encoder.set_input_array(out_mask, 10);
  }

  if (has_op_mask) {
    auto& lhs_mask = inputs[2 + has_out_mask];
    mask_strides.push_back(*(lhs_mask.strides().end() - 1));
    mask_strides.push_back(*(lhs_mask.strides().end() - 2));

    compute_encoder.set_input_array(lhs_mask, 11);

    auto& rhs_mask = inputs[3 + has_out_mask];
    mask_strides.push_back(*(rhs_mask.strides().end() - 1));
    mask_strides.push_back(*(rhs_mask.strides().end() - 2));

    compute_encoder.set_input_array(rhs_mask, 12);
  }

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_output_array(out, 3);

  compute_encoder.set_bytes(params, 4);

  compute_encoder.set_vector_bytes(batch_shape, 6);
  compute_encoder.set_vector_bytes(batch_strides, 7);

  compute_encoder.set_vector_bytes(mask_strides, 13);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  d.add_temporaries(std::move(copies), s.index);
}

///////////////////////////////////////////////////////////////////////////////
// GatherMM implementation
///////////////////////////////////////////////////////////////////////////////

void gather_mm_rhs(
    const array& a_,
    const array& b_,
    const array& indices_,
    array& out,
    metal::Device& d,
    const Stream& s) {
  array indices = ensure_row_contiguous(indices_, d, s);
  auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);

  // Broadcast a with indices. If we are here that means lhs_indices were not
  // provided so the lhs_indices are implied to be the shape of a broadcasted
  // with rhs_indices. We need only broadcast a and copy it as if applying the
  // lhs_indices.
  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
      return ensure_row_contiguous(x, d, s);
    }

    auto x_shape = indices.shape();
    x_shape.push_back(x.shape(-2));
    x_shape.push_back(x.shape(-1));
    array new_x(std::move(x_shape), x.dtype(), nullptr, {});
    broadcast(x, new_x);
    return ensure_row_contiguous(new_x, d, s);
  };
  array a = broadcast_with_indices(a_);
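
  // After broadcasting, a is row contiguous with one (rows, K) block per
  // index, so the left operand can be treated as a single flat (M, K) matrix.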
  // Extract the matmul shapes
  int K = a.shape(-1);
  int M = a.size() / K;
  int N = b.shape(-1);
  int lda = a.strides()[a.ndim() - 2]; // should be K

  // Define the dispatch blocks
  int bm = 16, bn = 64, bk = 16;
  int wm = 1, wn = 2;

  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  // Define the kernel name
  std::string base_name;
  base_name.reserve(64);
  concatenate(
      base_name,
      "steel_gather_mm_rhs_n",
      transpose_b ? 't' : 'n',
      '_',
      type_to_name(a),
      '_',
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn);

  metal::MTLFCList func_consts = {
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };
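
  // Alignment is passed as Metal function constants (indices 200-202), so the
  // specialization happens at pipeline-creation time; the cache key below must
  // therefore include the constants on top of the base kernel name.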
  // And the kernel hash that includes the function constants
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      base_name,
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n',
      "_align_K_",
      align_K ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_gather_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      false,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn,
      true);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Prepare the matmul params
  auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
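  // batch_stride_b is the distance between consecutive gathered B matrices;
  // a and out are addressed as one flat (M, N) problem, so their batch
  // strides are zero below.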
  steel::GEMMParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ static_cast<int>(ldb),
      /* const int ldd = */ N,
      /* const int tiles_n = */ (N + bn - 1) / bn,
      /* const int tiles_m = */ (M + bm - 1) / bm,
      /* const int64_t batch_stride_a = */ 0,
      /* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
      /* const int64_t batch_stride_d = */ 0,
      /* const int swizzle_log = */ 0,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ 0};

  // Prepare the grid
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_input_array(indices, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(params, 4);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void gather_mm_rhs_nax(
    const array& a_,
    const array& b_,
    const array& indices_,
    array& out,
    metal::Device& d,
    const Stream& s) {
  array indices = ensure_row_contiguous(indices_, d, s);
  auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);

  // Broadcast a with indices. If we are here that means lhs_indices were not
  // provided so the lhs_indices are implied to be the shape of a broadcasted
  // with rhs_indices. We need only broadcast a and copy it as if applying the
  // lhs_indices.
  auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
    if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
      return ensure_row_contiguous(x, d, s);
    }

    auto x_shape = indices.shape();
    x_shape.push_back(x.shape(-2));
    x_shape.push_back(x.shape(-1));
    array new_x(std::move(x_shape), x.dtype(), nullptr, {});
    broadcast(x, new_x);
    return ensure_row_contiguous(new_x, d, s);
  };
  array a = broadcast_with_indices(a_);

  // Extract the matmul shapes
  int K = a.shape(-1);
  int M = a.size() / K;
  int N = b.shape(-1);
  int lda = a.strides()[a.ndim() - 2]; // should be K
  int E = b.shape(0);
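  // E is the number of gathered B matrices (e.g. experts in a mixture-of-
  // experts layer); the average rows per matrix, M / E, picks the tile height.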

  // Define the dispatch blocks
  int bm, bn = 128, bk = 128, wm, wn = 4;
  if (M / E > 48) {
    bm = 64;
    wm = 2;
  } else if (M / E > 24) {
    bm = 32;
    wm = 1;
  } else {
    bm = 16;
    wm = 1;
  }

  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  // Define the kernel name
  std::string base_name;
  base_name.reserve(64);
  concatenate(
      base_name,
      "steel_gather_mm_rhs_nax_n",
      transpose_b ? 't' : 'n',
      '_',
      type_to_name(a),
      '_',
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn);

  metal::MTLFCList func_consts = {
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };

  // And the kernel hash that includes the function constants
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      base_name,
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n',
      "_align_K_",
      align_K ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_gather_nax_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      false,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn,
      true);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Prepare the matmul params
  auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
  steel::GEMMParams params{
      /* const int M = */ M,
      /* const int N = */ N,
      /* const int K = */ K,
      /* const int lda = */ lda,
      /* const int ldb = */ static_cast<int>(ldb),
      /* const int ldd = */ N,
      /* const int tiles_n = */ (N + bn - 1) / bn,
      /* const int tiles_m = */ (M + bm - 1) / bm,
      /* const int64_t batch_stride_a = */ 0,
      /* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
      /* const int64_t batch_stride_d = */ 0,
      /* const int swizzle_log = */ 0,
      /* const int gemm_k_iterations_aligned = */ (K / bk),
      /* const int batch_ndim = */ 0};

  // Prepare the grid
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_input_array(indices, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(params, 4);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void gather_mv(
    const array& mat_,
    const array& vec_,
    const array& mat_indices_,
    const array& vec_indices_,
    array& out,
    int N,
    int K,
    bool is_mv,
    metal::Device& d,
    const Stream& s) {
  // Copy if needed
  std::vector<array> copies;
  auto [transpose_mat, mat_cols, mat] =
      check_transpose(copies, s, mat_, N == 1);
  auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
  d.add_temporaries(std::move(copies), s.index);

  // If we are doing vector matrix instead of matrix vector we need to flip the
  // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
  // as a one dimensional array.
  transpose_mat = (!is_mv) ^ transpose_mat;

  // Define some shapes
  int in_vector_len = K;
  int out_vector_len = N;
  int mat_ld = mat_cols;

  int batch_size_out = out.size() / N;
  int batch_ndim = out.ndim() - 2;
  int batch_ndim_mat = mat.ndim() - 2;
  int batch_ndim_vec = vec.ndim() - 2;
  Strides index_strides = vec_indices_.strides();
  index_strides.insert(
      index_strides.end(),
      mat_indices_.strides().begin(),
      mat_indices_.strides().end());
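  // The two index arrays' strides are concatenated (vector indices first,
  // matrix indices second) into the single strides buffer bound at index 11.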

  // Determine dispatch kernel
  int tm = 4, tn = 4;
  int sm = 1, sn = 32;
  int bm = 1, bn = 1;
  int n_out_per_tgp;
  std::ostringstream kname;
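
  // Rough meaning of the tuning knobs: each thread accumulates a tm x tn
  // result, a SIMD group is laid out sm x sn (sm * sn == 32 lanes), and a
  // threadgroup holds bm x bn SIMD groups; the branches below pick them from
  // the problem size.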
  if (transpose_mat) {
    if (in_vector_len >= 8192 && out_vector_len >= 2048) {
      sm = 4;
      sn = 8;
    } else {
      sm = 8;
      sn = 4;
    }

    if (out_vector_len >= 2048) {
      bn = 16;
    } else if (out_vector_len >= 512) {
      bn = 4;
    } else {
      bn = 2;
    }

    // Specialized kernel for very small outputs
    tn = out_vector_len < tn ? 1 : tn;

    n_out_per_tgp = bn * sn * tn;
    kname << "gemv_t_gather_" << type_to_name(out);

  } else {
    bm = out_vector_len >= 4096 ? 8 : 4;
    sn = 32;

    // Specialized kernel for very small outputs
    tm = out_vector_len < tm ? 1 : tm;

    n_out_per_tgp = bm * sm * tm;
    kname << "gemv_gather_" << type_to_name(out);
  }

  kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
        << tm << "_tn" << tn;

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder.set_compute_pipeline_state(kernel);

  int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
  MTL::Size group_dims = MTL::Size(32, bn, bm);
  MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);

  compute_encoder.set_input_array(mat, 0);
  compute_encoder.set_input_array(vec, 1);
  compute_encoder.set_output_array(out, 3);

  compute_encoder.set_bytes(in_vector_len, 4);
  compute_encoder.set_bytes(out_vector_len, 5);
  compute_encoder.set_bytes(mat_ld, 6);

  compute_encoder.set_bytes(batch_ndim, 9);
  compute_encoder.set_vector_bytes(out.shape(), 10);
  compute_encoder.set_vector_bytes(index_strides, 11);

  compute_encoder.set_bytes(batch_ndim_vec, 12);
  compute_encoder.set_vector_bytes(vec.shape(), 13);
  compute_encoder.set_vector_bytes(vec.strides(), 14);

  compute_encoder.set_bytes(batch_ndim_mat, 15);
  compute_encoder.set_vector_bytes(mat.shape(), 16);
  compute_encoder.set_vector_bytes(mat.strides(), 17);

  compute_encoder.set_input_array(vec_indices_, 18);
  compute_encoder.set_input_array(mat_indices_, 19);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void gather_mm(
    const array& a_,
    const array& b_,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,
    int M,
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
  // Copy if needed
  std::vector<array> copies;
  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
  d.add_temporaries(std::move(copies), s.index);

  // Determine dispatch kernel
  int bm = 64, bn = 64, bk = 16;
  int wm = 2, wn = 2;
  size_t batch_size_out = out.size() / M / N;
  int batch_ndim = out.ndim() - 2;
  int batch_ndim_a = a.ndim() - 2;
  int batch_ndim_b = b.ndim() - 2;

  char devc = d.get_architecture().back();
  GEMM_TPARAM_MACRO(devc)
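  // GEMM_TPARAM_MACRO overrides the default bm/bn/bk/wm/wn tile parameters
  // for the detected GPU family (devc is the trailing character of the
  // architecture string).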

  const bool has_batch = batch_ndim > 1;
  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;
  const bool align_K = (K % bk) == 0;

  // Define the kernel name
  std::string base_name;
  base_name.reserve(128);
  concatenate(
      base_name,
      "steel_gather_mm_",
      transpose_a ? 't' : 'n',
      transpose_b ? 't' : 'n',
      "_",
      type_to_name(a),
      "_",
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn);

  metal::MTLFCList func_consts = {
      {&has_batch, MTL::DataType::DataTypeBool, 10},
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
      {&align_K, MTL::DataType::DataTypeBool, 202},
  };

  // And the kernel hash that includes the function constants
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      base_name,
      "_has_batch_",
      has_batch ? 't' : 'n',
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n',
      "_align_K_",
      align_K ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_gather_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn,
      false);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Prepare the matmul params
  steel::GEMMParams params{/* const int M = */ M,
                           /* const int N = */ N,
                           /* const int K = */ K,
                           /* const int lda = */ static_cast<int>(lda),
                           /* const int ldb = */ static_cast<int>(ldb),
                           /* const int ldd = */ N,
                           /* const int tiles_n = */ (N + bn - 1) / bn,
                           /* const int tiles_m = */ (M + bm - 1) / bm,
                           /* const int64_t batch_stride_a = */
                           (batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
                           /* const int64_t batch_stride_b = */
                           (batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
                           /* const int64_t batch_stride_d = */ M * N,
                           /* const int swizzle_log = */ 0,
                           /* const int gemm_k_iterations_aligned = */ (K / bk),
                           /* const int batch_ndim = */ batch_ndim};

  // Prepare the grid
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims =
      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_input_array(lhs_indices, 2);
  compute_encoder.set_input_array(rhs_indices, 3);
  compute_encoder.set_output_array(out, 4);
  compute_encoder.set_bytes(params, 5);
  compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
  compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
  compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
  compute_encoder.set_bytes(batch_ndim_a, 9);
  compute_encoder.set_vector_bytes(a.shape(), 10);
  compute_encoder.set_vector_bytes(a.strides(), 11);
  compute_encoder.set_bytes(batch_ndim_b, 12);
  compute_encoder.set_vector_bytes(b.shape(), 13);
  compute_encoder.set_vector_bytes(b.strides(), 14);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& a = inputs[0];
  auto& b = inputs[1];
  auto& lhs_indices = inputs[2];
  auto& rhs_indices = inputs[3];

  // Return 0s if either input is empty
  if (a.size() == 0 || b.size() == 0) {
    array zero = array(0, a.dtype());
    fill_gpu(zero, out, s);
    d.add_temporary(std::move(zero), s.index);
    return;
  }

  out.set_data(allocator::malloc(out.nbytes()));

  // Extract shapes from inputs.
  int M = a.shape(-2);
  int N = b.shape(-1);
  int K = a.shape(-1);
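
  // Routing: sorted row-gather batches take the fast rhs-only path (with a
  // hardware-specific nax variant where available), degenerate matrix/vector
  // shapes go to the gather gemv, and everything else falls through to the
  // general gather gemm.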
  // We are walking a in order and b is also in order so we can batch up the
  // matmuls and reuse reading a and b.
  if (M == 1 && right_sorted_ == true) {
    if (metal::is_nax_available() &&
        (env::enable_tf32() || a.dtype() != float32)) {
      return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
    }
    gather_mm_rhs(a, b, rhs_indices, out, d, s);
    return;
  }

  // Route to gather gemv if either a or b is a vector
  if (M == 1) {
    gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
    return;
  }
  if (N == 1) {
    gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
    return;
  }

  // Route to the non-specialized gather mm
  gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
}

void segmented_mm(
    const array& a_,
    const array& b_,
    const array& segments_,
    array& out,
    int M,
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
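  // check_segments_layout returns (segments_contiguous, segments): fully
  // row-contiguous input (or a fresh copy) is flagged true, while a
  // batch-strided layout whose innermost dimensions are unit-stride is usable
  // in place and flagged false; the flag feeds the segments_contiguous
  // function constant below.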
  auto check_segments_layout = [&d, &s](const array& x) {
    // Contiguous so return early
    if (x.flags().row_contiguous) {
      return std::make_tuple(true, x);
    }

    bool rc = true;
    for (int i = 0; i < x.ndim() - 2; i++) {
      rc &=
          (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1);
    }
    rc &= x.strides(x.ndim() - 1) == 1;
    if (x.ndim() > 1) {
      rc &= x.strides(x.ndim() - 2) == 1;
    }

    if (rc) {
      return std::make_tuple(false, x);
    }

    array x_copy = contiguous_copy_gpu(x, s);
    d.add_temporary(x_copy, s.index);
    return std::make_tuple(true, x_copy);
  };

  // Copy if needed
  std::vector<array> copies;
  auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
  auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
  auto [segments_contiguous, segments] = check_segments_layout(segments_);
  d.add_temporaries(std::move(copies), s.index);

  // Determine dispatch kernel
  int bm = 64, bn = 64, bk = 16;
  int wm = 2, wn = 2;
  size_t batch_size_out = out.size() / M / N;

  char devc = d.get_architecture().back();
  GEMM_TPARAM_MACRO(devc)

  const bool align_M = (M % bm) == 0;
  const bool align_N = (N % bn) == 0;

  // Define the kernel name
  std::string base_name;
  base_name.reserve(128);
  concatenate(
      base_name,
      "steel_segmented_mm_",
      transpose_a ? 't' : 'n',
      transpose_b ? 't' : 'n',
      "_",
      type_to_name(a),
      "_",
      type_to_name(out),
      "_bm",
      bm,
      "_bn",
      bn,
      "_bk",
      bk,
      "_wm",
      wm,
      "_wn",
      wn);

  metal::MTLFCList func_consts = {
      {&segments_contiguous, MTL::DataType::DataTypeBool, 199},
      {&align_M, MTL::DataType::DataTypeBool, 200},
      {&align_N, MTL::DataType::DataTypeBool, 201},
  };

  // And the kernel hash that includes the function constants
  std::string hash_name;
  hash_name.reserve(128);
  concatenate(
      hash_name,
      base_name,
      "_segments_contiguous_",
      segments_contiguous ? 't' : 'n',
      "_align_M_",
      align_M ? 't' : 'n',
      "_align_N_",
      align_N ? 't' : 'n');

  // Get and set the kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = get_steel_gemm_segmented_kernel(
      d,
      base_name,
      hash_name,
      func_consts,
      out,
      transpose_a,
      transpose_b,
      bm,
      bn,
      bk,
      wm,
      wn);
  compute_encoder.set_compute_pipeline_state(kernel);

  // Prepare the matmul params
  steel::GEMMParams params{/* const int M = */ M,
                           /* const int N = */ N,
                           /* const int K = */ K,
                           /* const int lda = */ static_cast<int>(lda),
                           /* const int ldb = */ static_cast<int>(ldb),
                           /* const int ldd = */ N,
                           /* const int tiles_n = */ (N + bn - 1) / bn,
                           /* const int tiles_m = */ (M + bm - 1) / bm,
                           /* const int64_t batch_stride_a = */ 0,
                           /* const int64_t batch_stride_b = */ 0,
                           /* const int64_t batch_stride_d = */ M * N,
                           /* const int swizzle_log = */ 0,
                           /* const int gemm_k_iterations_aligned = */ 0,
                           /* const int batch_ndim = */ 0};

  // Prepare the grid
  MTL::Size group_dims = MTL::Size(32, wn, wm);
  MTL::Size grid_dims =
      MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);

  // Launch kernel
  compute_encoder.set_input_array(a, 0);
  compute_encoder.set_input_array(b, 1);
  compute_encoder.set_input_array(segments, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(params, 4);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  auto& a = inputs[0];
  auto& b = inputs[1];
  auto& segments = inputs[2];

  out.set_data(allocator::malloc(out.nbytes()));

  // Extract shapes from inputs.
  int M = a.shape(-2);
  int N = b.shape(-1);
  int K = a.shape(-1);

  segmented_mm(a, b, segments, out, M, N, K, d, s);
}

} // namespace mlx::core