mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
|
|
2
|
+
#include "mlx/backend/cuda/steel/mma.cuh"
|
|
3
|
+
#include "mlx/backend/cuda/steel/tiles.cuh"
|
|
4
|
+
|
|
5
|
+
namespace mlx::core::cu {
|
|
6
|
+
|
|
7
|
+
/**
|
|
8
|
+
* An example gemm written with the utils.
|
|
9
|
+
*
|
|
10
|
+
* Computes A @ B.T when A and B are both aligned with the block sizes.
|
|
11
|
+
*/
|
|
12
|
+
template <typename T, int BM, int BN, int BK>
|
|
13
|
+
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
|
|
14
|
+
constexpr int WARPS_M = 2;
|
|
15
|
+
constexpr int WARPS_N = 2;
|
|
16
|
+
constexpr int NUM_WARPS = WARPS_M * WARPS_N;
|
|
17
|
+
constexpr int WARP_STEP_M = BM / WARPS_M;
|
|
18
|
+
constexpr int WARP_STEP_N = BN / WARPS_N;
|
|
19
|
+
|
|
20
|
+
// Precompute some offsets for each thread
|
|
21
|
+
const int warpid = threadIdx.x / 32;
|
|
22
|
+
const int laneid = threadIdx.x % 32;
|
|
23
|
+
const int wm = warpid / WARPS_N;
|
|
24
|
+
const int wn = warpid % WARPS_N;
|
|
25
|
+
const int offset_m = wm * WARP_STEP_M;
|
|
26
|
+
const int offset_n = wn * WARP_STEP_N;
|
|
27
|
+
|
|
28
|
+
// Allocate shared memory
|
|
29
|
+
extern __shared__ char shmem[];
|
|
30
|
+
SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
|
|
31
|
+
SharedTile<T, BN, BK>(&bs)[2] =
|
|
32
|
+
*(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);
|
|
33
|
+
|
|
34
|
+
// Allocate registers for the MMA
|
|
35
|
+
RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
|
|
36
|
+
RegisterTile<T, BM / WARPS_M, 16> A;
|
|
37
|
+
RegisterTile<T, BN / WARPS_N, 16> B;
|
|
38
|
+
|
|
39
|
+
// Move the global pointers to the tile
|
|
40
|
+
a += blockIdx.y * BM * K;
|
|
41
|
+
b += blockIdx.x * BN * K;
|
|
42
|
+
y += blockIdx.y * BM * N + blockIdx.x * BN;
|
|
43
|
+
|
|
44
|
+
// Zero the accumulators
|
|
45
|
+
C.fill(0);
|
|
46
|
+
|
|
47
|
+
// Start the SM pipeline
|
|
48
|
+
load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
|
|
49
|
+
load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
|
|
50
|
+
cp_async_commit();
|
|
51
|
+
|
|
52
|
+
int tic = 0;
|
|
53
|
+
for (int k_block = BK; k_block < K; k_block += BK) {
|
|
54
|
+
load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
|
|
55
|
+
load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
|
|
56
|
+
cp_async_commit();
|
|
57
|
+
cp_async_wait<1>();
|
|
58
|
+
__syncthreads();
|
|
59
|
+
|
|
60
|
+
MLX_UNROLL
|
|
61
|
+
for (int k = 0; k < BK / 16; k++) {
|
|
62
|
+
A.load(
|
|
63
|
+
as[tic],
|
|
64
|
+
as[tic].base_addr(),
|
|
65
|
+
offset_m + laneid % 16,
|
|
66
|
+
k * 16 + laneid / 16 * 8);
|
|
67
|
+
B.load(
|
|
68
|
+
bs[tic],
|
|
69
|
+
bs[tic].base_addr(),
|
|
70
|
+
offset_n + laneid % 16,
|
|
71
|
+
k * 16 + laneid / 16 * 8);
|
|
72
|
+
|
|
73
|
+
mma_t(C, A, B);
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
tic ^= 1;
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
// Empty the pipeline
|
|
80
|
+
cp_async_wait_all();
|
|
81
|
+
__syncthreads();
|
|
82
|
+
MLX_UNROLL
|
|
83
|
+
for (int k = 0; k < BK / 16; k++) {
|
|
84
|
+
A.load(
|
|
85
|
+
as[tic],
|
|
86
|
+
as[tic].base_addr(),
|
|
87
|
+
offset_m + laneid % 16,
|
|
88
|
+
k * 16 + laneid / 16 * 8);
|
|
89
|
+
B.load(
|
|
90
|
+
bs[tic],
|
|
91
|
+
bs[tic].base_addr(),
|
|
92
|
+
offset_n + laneid % 16,
|
|
93
|
+
k * 16 + laneid / 16 * 8);
|
|
94
|
+
|
|
95
|
+
mma_t(C, A, B);
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
C.store_global(y, N, offset_m, offset_n);
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
} // namespace mlx::core::cu
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/cuda/steel/defines.cuh"
|
|
6
|
+
#include "mlx/backend/cuda/steel/tiles.cuh"
|
|
7
|
+
|
|
8
|
+
namespace mlx::core::cu {
|
|
9
|
+
|
|
10
|
+
/**
 * Generic mma_t overload, picked by overload resolution for element-type
 * combinations that have no specialized (non-template) implementation.
 *
 * NOTE(review): this is intentionally empty. C is left untouched and no
 * diagnostic is emitted. A real fallback implementation, or a compile-time
 * error, would be preferable to silently producing no accumulation.
 */
template <typename U, typename T>
__device__ inline void
mma_t(Tile16x16<U>& /* C */, Tile16x16<T>& /* A */, Tile16x16<T>& /* B */) {
  // Intentionally a no-op.
}
|
|
19
|
+
|
|
20
|
+
/**
 * Multiply two 16x16 bfloat16 fragments and accumulate into a 16x16 float
 * fragment, i.e. C += A @ B.T.
 *
 * The work is issued as two m16n8k16 tensor-core MMAs:
 *  - output columns 0-7 (C.values[0], C.values[1]) use B rows 0-7
 *    (B.values[0], B.values[2]);
 *  - output columns 8-15 (C.values[2], C.values[3]) use B rows 8-15
 *    (B.values[1], B.values[3]).
 * All four A registers feed both MMAs.
 *
 * When MLX_CUDA_SM_80_ENABLED is not defined this compiles to nothing and C
 * is left unchanged.
 */
__device__ __forceinline__ void mma_t(
    Tile16x16<float>& C,
    Tile16x16<__nv_bfloat16>& A,
    Tile16x16<__nv_bfloat16>& B) {
#if defined(MLX_CUDA_SM_80_ENABLED)
  // View one packed bf16 pair as the 32-bit register operand PTX expects.
  auto reg = [](auto& pair) { return *reinterpret_cast<uint32_t*>(&pair); };

  // Issue one m16n8k16 MMA. `lo` / `hi` hold the 16x8 accumulator fragment
  // (rows g and g + 8) and are updated in place. b0 / b1 are this 8-column
  // slab's B registers.
  auto mma_16x8 = [&](auto& lo, auto& hi, uint32_t b0, uint32_t b1) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0, %1, %2, %3}, "
        "{%4, %5, %6, %7}, "
        "{%8, %9}, "
        "{%10, %11, %12, %13};"

        // D matrix
        : "+f"(lo.x),
          "+f"(lo.y),
          "+f"(hi.x),
          "+f"(hi.y)

        // A matrix
        : "r"(reg(A.values[0])),
          "r"(reg(A.values[1])),
          "r"(reg(A.values[2])),
          "r"(reg(A.values[3])),

          // B matrix
          "r"(b0),
          "r"(b1),

          // C matrix
          "f"(lo.x),
          "f"(lo.y),
          "f"(hi.x),
          "f"(hi.y));
  };

  mma_16x8(C.values[0], C.values[1], reg(B.values[0]), reg(B.values[2]));
  mma_16x8(C.values[2], C.values[3], reg(B.values[1]), reg(B.values[3]));
#endif
}
|
|
89
|
+
|
|
90
|
+
/**
 * Multiply register tiles, C += A @ B.T, by delegating each pair of 16x16
 * fragments to the Tile16x16 overload of mma_t.
 *
 * A is M x K, B is N x K and C is M x N, all split into 16x16 fragments
 * stored row-major in `data`. The K dimension is the outermost loop, so
 * consecutive MMAs target different accumulator fragments and are
 * independent of each other.
 */
template <typename U, typename T, int M, int N, int K>
__device__ __forceinline__ void mma_t(
    RegisterTile<U, M, N>& C,
    RegisterTile<T, M, K>& A,
    RegisterTile<T, N, K>& B) {
  using ATile = RegisterTile<T, M, K>;
  using BTile = RegisterTile<T, N, K>;
  constexpr int k_steps = ATile::TILES_X;
  constexpr int m_steps = ATile::TILES_Y;
  constexpr int n_steps = BTile::TILES_Y;

  MLX_UNROLL
  for (int kk = 0; kk < k_steps; ++kk) {
    MLX_UNROLL
    for (int mm = 0; mm < m_steps; ++mm) {
      auto& a_frag = A.data[mm * k_steps + kk];
      MLX_UNROLL
      for (int nn = 0; nn < n_steps; ++nn) {
        mma_t(C.data[mm * n_steps + nn], a_frag, B.data[nn * k_steps + kk]);
      }
    }
  }
}
|
|
116
|
+
|
|
117
|
+
} // namespace mlx::core::cu
|
|
@@ -0,0 +1,450 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/cuda/steel/utils.cuh"
|
|
6
|
+
#include "mlx/backend/cuda/vector_types.cuh"
|
|
7
|
+
|
|
8
|
+
namespace mlx::core::cu {
|
|
9
|
+
|
|
10
|
+
/**
 * The basic building block for Ampere mmas: a 16x16 tile distributed across
 * the warp.
 *
 * Each thread holds 8 values, stored as 4 pairs in `values`. They are
 * distributed according to
 * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
 *
 * With g = laneid / 4 and c = 2 * (laneid % 4), each pair holds two
 * adjacent columns:
 *   values[0] -> row g + 0, columns c + 0, c + 1
 *   values[1] -> row g + 8, columns c + 0, c + 1
 *   values[2] -> row g + 0, columns c + 8, c + 9
 *   values[3] -> row g + 8, columns c + 8, c + 9
 *
 * For usage instructions see the individual methods, e.g. load().
 */
template <typename T>
struct Tile16x16 {
  using T2 = Vector2_t<T>;

  T2 values[4];

  /** Set all 8 values held by the calling thread to `v`. */
  __device__ inline void fill(T v) {
    T2 v2 = {v, v};
    for (int i = 0; i < 4; i++) {
      values[i] = v2;
    }
  }

  /**
   * Load a 16x16 tile from shared memory with `ldmatrix ... .x4`.
   *
   * The instruction is a bit unusual: the address a thread provides and the
   * elements that thread receives are not the same.
   *
   * We load four 8x8 matrices. Each 8x8 row (8 x 16-bit = 16 bytes) must be
   * contiguous in shared memory, so the warp provides 4 * 8 = 32 row
   * addresses, one per lane.
   *
   * Lanes 0-7 provide the row addresses of the first matrix (-> values[0]),
   * lanes 8-15 those of the second (-> values[1]), and so on. For the layout
   * above, lane `laneid` therefore passes the shared-memory address of the
   * element at
   *
   *   row = laneid % 16, col = (laneid / 16) * 8
   *
   * e.g. for a non-swizzled row-major tile with BK columns:
   *
   *   base_addr + ((laneid % 16) * BK + (laneid / 16) * 8) * sizeof(T)
   *
   * Only 16-bit element types (bfloat16 / half) are handled. For any other T
   * this is a no-op.
   *
   * See
   * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
   */
  __device__ __forceinline__ void load(uint32_t row_address) {
    if constexpr (
        std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
      asm volatile(
          "ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
          : "=r"(*(uint32_t*)&(values[0])),
            "=r"(*(uint32_t*)&(values[1])),
            "=r"(*(uint32_t*)&(values[2])),
            "=r"(*(uint32_t*)&(values[3]))
          : "r"(row_address));
    }
  }

  /**
   * Store the tile to the address pointed to by `x`, whose rows are `N`
   * elements apart.
   *
   * The provided pointer is a generic pointer, but this is meant for storing
   * to global memory. For storing to shared memory we should use `stmatrix`.
   *
   * This also showcases the format of the tile quite nicely. Each register
   * holds two adjacent values. The indices are
   *
   *   row + 0, col + 0
   *   row + 8, col + 0
   *   row + 0, col + 8
   *   row + 8, col + 8
   *
   * In the vectorized paths we index in units of Vector2_t<U>, so the column
   * offsets are 4 instead of 8. Those paths assume N is even and `x` is
   * aligned to sizeof(Vector2_t<U>). Any other (T, U) combination falls back
   * to converting and storing one element at a time.
   */
  template <typename U>
  __device__ inline void store_global(U* x, int N) {
    using U2 = Vector2_t<U>;
    const int laneid = threadIdx.x % 32;
    const int row = laneid / 4;
    const int col = laneid % 4;
    if constexpr (std::is_same_v<U2, T2>) {
      U2* x2 = reinterpret_cast<U2*>(x);
      x2[(row + 0) * (N / 2) + col + 0] = values[0];
      x2[(row + 0) * (N / 2) + col + 4] = values[2];
      x2[(row + 8) * (N / 2) + col + 0] = values[1];
      x2[(row + 8) * (N / 2) + col + 4] = values[3];
    } else if constexpr (
        std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {
      U2* x2 = reinterpret_cast<U2*>(x);
      x2[(row + 0) * (N / 2) + col + 0] =
          __floats2bfloat162_rn(values[0].x, values[0].y);
      x2[(row + 0) * (N / 2) + col + 4] =
          __floats2bfloat162_rn(values[2].x, values[2].y);
      x2[(row + 8) * (N / 2) + col + 0] =
          __floats2bfloat162_rn(values[1].x, values[1].y);
      x2[(row + 8) * (N / 2) + col + 4] =
          __floats2bfloat162_rn(values[3].x, values[3].y);
    } else {
      // Generic path (e.g. float accumulators stored as __half). Previously
      // these combinations silently stored nothing. Convert element-wise,
      // using the same mapping as store_global_safe.
      x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
      x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
      x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
      x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
      x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
      x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
      x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
      x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
    }
  }

  /**
   * Element-wise store of the tile to `x` (row stride `N` elements), writing
   * only the rows whose index is below `max_rows`. Columns are not checked.
   */
  template <typename U>
  __device__ inline void store_global_safe(U* x, int N, int max_rows) {
    const int laneid = threadIdx.x % 32;
    const int row = laneid / 4;
    const int col = laneid % 4;
    if (row < max_rows) {
      x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
      x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
      x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
      x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
    }
    if (row + 8 < max_rows) {
      x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
      x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
      x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
      x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
    }
  }
};
|
|
124
|
+
|
|
125
|
+
/**
 * A ROWS x COLS register tile made of (ROWS / 16) x (COLS / 16) Tile16x16
 * fragments, distributed across one warp.
 *
 * Fragment (ty, tx) covers rows [16 * ty, 16 * ty + 16) and columns
 * [16 * tx, 16 * tx + 16). It lives at data[ty * TILES_X + tx].
 */
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;

  Tile16x16<T> data[TILES_X * TILES_Y];

  // Set every value of every fragment to `v`.
  __device__ inline void fill(T v) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ++ty) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; ++tx) {
        data[ty * TILES_X + tx].fill(v);
      }
    }
  }

  // Load every fragment from the shared tile `tile`. `row` / `col` are the
  // coordinates this lane passes for fragment (0, 0). Each further fragment
  // is offset by 16 rows / columns. Addresses are resolved via tile.loc().
  template <typename Tile>
  __device__ __forceinline__ void
  load(Tile& tile, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ++ty) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; ++tx) {
        auto& frag = data[ty * TILES_X + tx];
        frag.load(tile.loc(base_address, row + ty * 16, col + tx * 16));
      }
    }
  }

  // Same as load() above, but delegates each fragment load to the callable
  // `f(fragment, tile, base_address, frag_row, frag_col)`.
  template <typename Tile, typename F>
  __device__ __forceinline__ void
  load(Tile& tile, F f, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ++ty) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; ++tx) {
        auto& frag = data[ty * TILES_X + tx];
        f(frag, tile, base_address, row + ty * 16, col + tx * 16);
      }
    }
  }

  // Store the whole tile to `x` (row stride `N` elements), with its top-left
  // corner at (row, col).
  template <typename U>
  __device__ inline void store_global(U* x, int N, int row, int col) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ++ty) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; ++tx) {
        const int r = row + ty * 16;
        const int c = col + tx * 16;
        data[ty * TILES_X + tx].store_global(x + r * N + c, N);
      }
    }
  }

  // Like store_global(), but skips output rows whose index (relative to `x`)
  // is >= max_rows.
  template <typename U>
  __device__ inline void
  store_global_safe(U* x, int N, int row, int col, int max_rows) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ++ty) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; ++tx) {
        const int r = row + ty * 16;
        const int c = col + tx * 16;
        data[ty * TILES_X + tx].store_global_safe(
            x + r * N + c, N, max_rows - r);
      }
    }
  }
};
|
|
204
|
+
|
|
205
|
+
// NOTE: a second definition of `RegisterTile` used to follow here. It was a
// strict subset of the definition earlier in this header (fill / load /
// store_global only). As a redefinition of the same class template in the
// same namespace, it made every translation unit including this header
// ill-formed, so it has been removed.
|
|
255
|
+
|
|
256
|
+
/**
 * A ROWS x COLS tile of T stored in shared memory (base_addr() converts the
 * address with __cvta_generic_to_shared), accessed through an address
 * swizzle.
 *
 * When swizzle_bytes > 0 the tile is split column-wise into subtiles of
 * subtile_cols = swizzle_bytes / sizeof(T) columns. Each subtile is stored
 * contiguously as a ROWS x subtile_cols row-major block. The final address
 * then has bits 4-6 XORed with the bits just above bit 7 (up to
 * swizzle_repeat = 8 * swizzle_bytes), which permutes 16-byte chunks
 * (presumably to reduce shared-memory bank conflicts for ldmatrix / cp.async
 * accesses).
 *
 * When swizzle_bytes == 0 the tile is plain row-major.
 *
 * Because the XOR only touches bits 4-6, an access that stays inside one
 * 16-byte-aligned chunk stays contiguous. The float4 / float2 / float store()
 * overloads rely on that.
 *
 * NOTE(review): `data` has no explicit alignment here; the float4 stores and
 * cp.async<16> writes assume the tile is at least 16-byte aligned. Confirm
 * the placement at each use site.
 */
template <typename T, int ROWS_, int COLS_>
struct SharedTile {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;
  static constexpr int NUMEL = ROWS * COLS;

  // Swizzle taken from ThunderKittens. Should be changed when we switch to
  // cute Layouts.
  //
  // See includes/types/shared/st.cuh
  //
  // I do feel that it is too math heavy and can be improved. Also the math is
  // done every time although the addresses don't change from load to load. I
  // guess we are expecting the compiler to figure that out.
  //
  // Swizzle width in bytes, chosen from how many 16-column fragments
  // (TILES_X) a row holds:
  //  - 2-byte T: 128, 64 or 32;
  //  - 4-byte T: 128 or 64;
  //  - any other element size: 0, meaning no swizzle (plain row-major).
  static constexpr int swizzle_bytes =
      (sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
                      : (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));

  T data[ROWS * COLS];

  // Shared-memory (32-bit) address of the first element, for use with
  // ldmatrix / cp.async and loc().
  __device__ inline uint32_t base_addr() const {
    return __cvta_generic_to_shared(&data[0]);
  }

  // Return a pointer to the element at (row, col) using the swizzle.
  // `ptr` is a generic pointer to the start of the tile. The swizzle is
  // applied to the absolute generic address.
  __device__ static inline T* ptr(T* ptr, int row, int col) {
    if constexpr (swizzle_bytes > 0) {
      static constexpr int swizzle_repeat = swizzle_bytes * 8;
      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
      // Index of the column subtile this element lives in.
      const int outer_idx = col / subtile_cols;
      const uint64_t addr =
          (uint64_t)(&ptr
                         [outer_idx * ROWS * subtile_cols + row * subtile_cols +
                          col % subtile_cols]);
      // XOR bits 4-6 (16-byte chunk within a 128-byte span) with the bits
      // from 7 up to swizzle_repeat.
      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
      return (T*)(addr ^ swizzle);
    } else {
      return ptr + row * COLS + col;
    }
  }

  // Return the location of the element at (row, col) using the swizzle.
  // Same mapping as ptr(), but `ptr` is a shared-memory byte address (e.g.
  // from base_addr()) and the result is a shared-memory byte address.
  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
    if constexpr (swizzle_bytes > 0) {
      static constexpr int swizzle_repeat = swizzle_bytes * 8;
      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
      const int outer_idx = col / subtile_cols;
      const uint32_t addr = ptr +
          sizeof(T) *
              (outer_idx * ROWS * subtile_cols + row * subtile_cols +
               col % subtile_cols);
      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
      return (addr ^ swizzle);
    } else {
      return ptr + sizeof(T) * (row * COLS + col);
    }
  }

  // Convenience functions to edit elements going through the swizzle.
  __device__ inline T& operator()(int row, int col) {
    return *ptr(data, row, col);
  }
  // Vector stores of 16 / 8 / 4 bytes starting at element (row, col). The
  // start is assumed not to straddle a 16-byte swizzle chunk.
  __device__ inline void store(float4& v, int row, int col) {
    *(reinterpret_cast<float4*>(ptr(data, row, col))) = v;
  }
  __device__ inline void store(float2& v, int row, int col) {
    *(reinterpret_cast<float2*>(ptr(data, row, col))) = v;
  }
  __device__ inline void store(float& v, int row, int col) {
    *(reinterpret_cast<float*>(ptr(data, row, col))) = v;
  }
  // Store N consecutive elements starting at (row, col). Uses a single
  // 4/8/16-byte vector store when the array is exactly that size, otherwise
  // stores element by element through the swizzle.
  template <int N>
  __device__ inline void store(T (&v)[N], int row, int col) {
    if constexpr (sizeof(T) * N == 4) {
      store(*(reinterpret_cast<float*>(&v[0])), row, col);
    } else if constexpr (sizeof(T) * N == 8) {
      store(*(reinterpret_cast<float2*>(&v[0])), row, col);
    } else if constexpr (sizeof(T) * N == 16) {
      store(*(reinterpret_cast<float4*>(&v[0])), row, col);
    } else {
      MLX_UNROLL
      for (int i = 0; i < N; i++) {
        *ptr(data, row, col + i) = v[i];
      }
    }
  }
};
|
|
345
|
+
|
|
346
|
+
/**
 * Synchronously copy a Tile::ROWS x Tile::COLS block from global memory
 * `x` (row stride `N` elements) into the shared tile `tile`.
 *
 * Each thread reads 16 bytes (one float4) at a time and immediately writes
 * it through tile.store(), so the swizzle is respected. This can also be used
 * as a fallback for architectures before sm_80.
 *
 * Assumes the calling block has NUM_WARPS * 32 threads (threadIdx.x is used
 * as the flat thread id). Assumes the tile's element count divides evenly
 * into 16-byte loads per thread; any remainder is not loaded.
 */
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load(Tile& tile, const T* x, int N) {
  constexpr int kThreads = NUM_WARPS * 32;
  constexpr int kVec = sizeof(float4) / sizeof(T);
  constexpr int kVecsPerRow = Tile::COLS / kVec;
  constexpr int kRowStride = kThreads / kVecsPerRow;
  constexpr int kIters = (Tile::NUMEL / kVec) / kThreads;

  // First (row, column) this thread handles. Later iterations step down by
  // kRowStride rows.
  const int r0 = threadIdx.x / kVecsPerRow;
  const int c0 = (threadIdx.x % kVecsPerRow) * kVec;
  const T* src = x + (r0 * N + c0);

  MLX_UNROLL
  for (int it = 0; it < kIters; ++it) {
    float4 chunk = *reinterpret_cast<const float4*>(src + it * kRowStride * N);
    tile.store(chunk, r0 + it * kRowStride, c0);
  }
}
|
|
373
|
+
|
|
374
|
+
/**
 * The asynchronous equivalent of load().
 *
 * Issues one 16-byte cp.async per (thread, iteration). Each copy goes from
 * global memory `x` (row stride `N` elements) to the swizzled shared location
 * tile.loc(base_address, ...). The copies are batched into a group by
 * cp_async_commit(). They are only guaranteed to be complete after a
 * matching wait, and visible to other threads after a block barrier.
 *
 * Typical usage:
 *
 *   load_async(...)
 *   load_async(...)
 *   cp_async_commit()
 *   do_other_stuff()
 *   cp_async_wait_all()
 *   __syncthreads()
 *   do_stuff_with_shmem()
 *
 * Assumes the calling block has NUM_WARPS * 32 threads and that the tile
 * divides evenly into 16-byte copies per thread.
 */
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void
load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
  constexpr int kThreads = NUM_WARPS * 32;
  constexpr int kVec = sizeof(float4) / sizeof(T);
  constexpr int kVecsPerRow = Tile::COLS / kVec;
  constexpr int kRowStride = kThreads / kVecsPerRow;
  constexpr int kIters = (Tile::NUMEL / kVec) / kThreads;

  const int r0 = threadIdx.x / kVecsPerRow;
  const int c0 = (threadIdx.x % kVecsPerRow) * kVec;
  const T* src = x + (r0 * N + c0);

  MLX_UNROLL
  for (int it = 0; it < kIters; ++it) {
    const int r = r0 + it * kRowStride;
    cp_async<16>(tile.loc(base_address, r, c0), src + it * kRowStride * N);
  }
}
|
|
412
|
+
|
|
413
|
+
/**
 * Same as load_async(), but only copies tile rows whose index is below
 * `max_rows`.
 *
 * Rows at or past `max_rows` are not read from global memory. They are
 * instead zero-filled with an ordinary (synchronous) shared-memory store.
 *
 * NOTE: it should be changed to use a predicated cp.async instead.
 */
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load_async_safe(
    Tile& tile,
    uint32_t base_address,
    const T* x,
    int N,
    int max_rows) {
  constexpr int kThreads = NUM_WARPS * 32;
  constexpr int kVec = sizeof(float4) / sizeof(T);
  constexpr int kVecsPerRow = Tile::COLS / kVec;
  constexpr int kRowStride = kThreads / kVecsPerRow;
  constexpr int kIters = (Tile::NUMEL / kVec) / kThreads;

  const int r0 = threadIdx.x / kVecsPerRow;
  const int c0 = (threadIdx.x % kVecsPerRow) * kVec;
  const T* src = x + (r0 * N + c0);

  MLX_UNROLL
  for (int it = 0; it < kIters; ++it) {
    const int r = r0 + it * kRowStride;
    if (r >= max_rows) {
      // Out-of-range row: write zeros instead of reading global memory.
      float4 zeros = {0, 0, 0, 0};
      tile.store(zeros, r, c0);
      continue;
    }
    cp_async<16>(tile.loc(base_address, r, c0), src + it * kRowStride * N);
  }
}
|
|
449
|
+
|
|
450
|
+
} // namespace mlx::core::cu
|