mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
|
@@ -0,0 +1,868 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <metal_simdgroup>
|
|
4
|
+
#include <metal_stdlib>
|
|
5
|
+
|
|
6
|
+
#include "mlx/backend/metal/kernels/utils.h"
|
|
7
|
+
|
|
8
|
+
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
9
|
+
|
|
10
|
+
using namespace metal;
|
|
11
|
+
|
|
12
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
13
|
+
/// Matrix vector multiplication
|
|
14
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
15
|
+
|
|
16
|
+
#define MLX_MTL_CONST static constant constexpr const
|
|
17
|
+
|
|
18
|
+
template <typename U>
|
|
19
|
+
struct DefaultAccT {
|
|
20
|
+
using type = float;
|
|
21
|
+
};
|
|
22
|
+
template <>
|
|
23
|
+
struct DefaultAccT<complex64_t> {
|
|
24
|
+
using type = complex64_t;
|
|
25
|
+
};
|
|
26
|
+
|
|
27
|
+
template <
|
|
28
|
+
typename T,
|
|
29
|
+
const int BM, /* Threadgroup rows (in simdgroups) */
|
|
30
|
+
const int BN, /* Threadgroup cols (in simdgroups) */
|
|
31
|
+
const int SM, /* Simdgroup rows (in threads) */
|
|
32
|
+
const int SN, /* Simdgroup cols (in threads) */
|
|
33
|
+
const int TM, /* Thread rows (in elements) */
|
|
34
|
+
const int TN, /* Thread cols (in elements) */
|
|
35
|
+
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
|
|
36
|
+
typename AccT = typename DefaultAccT<T>::type>
|
|
37
|
+
struct GEMVKernel {
|
|
38
|
+
using acc_type = AccT;
|
|
39
|
+
|
|
40
|
+
MLX_MTL_CONST int threadsM = BM * SM;
|
|
41
|
+
MLX_MTL_CONST int threadsN = BN * SN;
|
|
42
|
+
|
|
43
|
+
MLX_MTL_CONST int blockM = threadsM * TM;
|
|
44
|
+
MLX_MTL_CONST int blockN = threadsN * TN;
|
|
45
|
+
|
|
46
|
+
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
|
47
|
+
|
|
48
|
+
static_assert(
|
|
49
|
+
SN == 4 || SN == 8 || SN == 16 || SN == 32,
|
|
50
|
+
"gemv block must have a width of 4, 8, 16, or 32");
|
|
51
|
+
|
|
52
|
+
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
|
|
53
|
+
// into blocks of (blockM, blockN) divided among threadgroups
|
|
54
|
+
// - Every thread works on a block of (TM, TN)
|
|
55
|
+
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
|
|
56
|
+
//
|
|
57
|
+
// 1. A thread loads TN elements each from mat along TM rows
|
|
58
|
+
// and the corresponding scalar from the vector
|
|
59
|
+
// 2. The thread then multiplies and adds to accumulate its local result for
|
|
60
|
+
// the block
|
|
61
|
+
// 3. At the end, each thread has accumulated results over all blocks across
|
|
62
|
+
// the rows. These are then summed up across the threadgroup
|
|
63
|
+
// 4. Each threadgroup writes its accumulated blockM outputs
|
|
64
|
+
//
|
|
65
|
+
// Edge case handling:
|
|
66
|
+
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
|
67
|
+
// * The blocks that start outside the matrix are never read (thread results
|
|
68
|
+
// remain zero)
|
|
69
|
+
// * The last thread that partially overlaps with the matrix is shifted
|
|
70
|
+
// inwards such that the thread block fits exactly in the matrix
|
|
71
|
+
|
|
72
|
+
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
|
|
73
|
+
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
|
|
74
|
+
|
|
75
|
+
template <typename U = T>
|
|
76
|
+
static METAL_FUNC void
|
|
77
|
+
load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
|
|
78
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
79
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
80
|
+
dst[tn] = static_cast<U>(src[src_offset + tn]);
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
template <typename U = T>
|
|
85
|
+
static METAL_FUNC void load_safe(
|
|
86
|
+
const device T* src,
|
|
87
|
+
thread U dst[TN],
|
|
88
|
+
const int src_offset = 0,
|
|
89
|
+
const int src_size = TN) {
|
|
90
|
+
if (src_offset + TN <= src_size) {
|
|
91
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
92
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
93
|
+
dst[tn] = static_cast<U>(src[src_offset + tn]);
|
|
94
|
+
}
|
|
95
|
+
} else { // Edgecase
|
|
96
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
97
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
98
|
+
dst[tn] = src_offset + tn < src_size
|
|
99
|
+
? static_cast<U>(src[src_offset + tn])
|
|
100
|
+
: U(0);
|
|
101
|
+
}
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
static METAL_FUNC void run(
|
|
106
|
+
const device T* mat [[buffer(0)]],
|
|
107
|
+
const device T* in_vec [[buffer(1)]],
|
|
108
|
+
const device T* bias [[buffer(2)]],
|
|
109
|
+
device T* out_vec [[buffer(3)]],
|
|
110
|
+
const constant int& in_vec_size [[buffer(4)]],
|
|
111
|
+
const constant int& out_vec_size [[buffer(5)]],
|
|
112
|
+
const constant int& matrix_ld [[buffer(6)]],
|
|
113
|
+
const constant float& alpha [[buffer(7)]],
|
|
114
|
+
const constant float& beta [[buffer(8)]],
|
|
115
|
+
const constant int& bias_stride [[buffer(14)]],
|
|
116
|
+
threadgroup AccT* tgp_memory [[threadgroup(0)]],
|
|
117
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
118
|
+
uint3 lid [[thread_position_in_threadgroup]],
|
|
119
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
120
|
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
121
|
+
// Appease compiler
|
|
122
|
+
(void)lid;
|
|
123
|
+
|
|
124
|
+
// Thread local accumulation results
|
|
125
|
+
thread AccT result[TM] = {0};
|
|
126
|
+
thread T inter[TN];
|
|
127
|
+
thread AccT v_coeff[TN];
|
|
128
|
+
|
|
129
|
+
const int thrM = SN != 32 ? simd_lid / SN : 0;
|
|
130
|
+
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
|
|
131
|
+
|
|
132
|
+
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
|
|
133
|
+
|
|
134
|
+
const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
|
|
135
|
+
const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
|
|
136
|
+
|
|
137
|
+
int bm = (simdM + thrM) * TM;
|
|
138
|
+
int bn = (simdN + thrN) * TN;
|
|
139
|
+
|
|
140
|
+
// Block position
|
|
141
|
+
int out_row = tid.x * blockM + bm;
|
|
142
|
+
|
|
143
|
+
// Exit simdgroup if rows out of bound
|
|
144
|
+
if (out_row >= out_vec_size)
|
|
145
|
+
return;
|
|
146
|
+
|
|
147
|
+
// Adjust tail simdgroup to ensure in bound reads
|
|
148
|
+
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
|
|
149
|
+
|
|
150
|
+
// Advance matrix
|
|
151
|
+
mat += out_row * matrix_ld;
|
|
152
|
+
|
|
153
|
+
constexpr const uniform<int> loop_stride = make_uniform(blockN);
|
|
154
|
+
const uniform<int> in_size = make_uniform(in_vec_size);
|
|
155
|
+
const uniform<int> n_iter = in_size / loop_stride;
|
|
156
|
+
const uniform<int> last_iter = loop_stride * n_iter;
|
|
157
|
+
const uniform<int> leftover = in_size - last_iter;
|
|
158
|
+
|
|
159
|
+
// Loop over in_vec in blocks of blockN
|
|
160
|
+
for (int i = 0; i < n_iter; ++i) {
|
|
161
|
+
load_unsafe<AccT>(in_vec, v_coeff, bn);
|
|
162
|
+
|
|
163
|
+
// Per thread work loop
|
|
164
|
+
int mat_offset = 0;
|
|
165
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
166
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
167
|
+
// Load for the row
|
|
168
|
+
load_unsafe(mat, inter, mat_offset + bn);
|
|
169
|
+
|
|
170
|
+
// Accumulate results
|
|
171
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
172
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
173
|
+
result[tm] += inter[tn] * v_coeff[tn];
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
mat_offset += matrix_ld;
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
bn += blockN;
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
if (leftover > 0) {
|
|
183
|
+
load_safe<AccT>(in_vec, v_coeff, bn, in_size);
|
|
184
|
+
|
|
185
|
+
// Per thread work loop
|
|
186
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
187
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
188
|
+
// Load for the row
|
|
189
|
+
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
|
|
190
|
+
|
|
191
|
+
// Accumulate results
|
|
192
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
193
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
194
|
+
result[tm] += inter[tn] * v_coeff[tn];
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
// Simdgroup accumulations
|
|
200
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
201
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
202
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
203
|
+
for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
|
|
204
|
+
result[tm] += simd_shuffle_down(result[tm], sn);
|
|
205
|
+
}
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
// Threadgroup accumulation results
|
|
209
|
+
if (needs_tgp_reduction) {
|
|
210
|
+
threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
|
|
211
|
+
if (thrN == 0) {
|
|
212
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
213
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
214
|
+
tgp_results[tm] = result[tm];
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
218
|
+
|
|
219
|
+
if (sgN == 0) {
|
|
220
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
221
|
+
for (int sgn = 1; sgn < BN; sgn++) {
|
|
222
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
223
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
224
|
+
result[tm] += tgp_results[sgn * (blockM + TM) + tm];
|
|
225
|
+
}
|
|
226
|
+
}
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
// Write outputs
|
|
232
|
+
if (simdN == 0 && thrN == 0) {
|
|
233
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
234
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
235
|
+
if (kDoAxpby) {
|
|
236
|
+
out_vec[out_row + tm] =
|
|
237
|
+
static_cast<T>(alpha) * static_cast<T>(result[tm]) +
|
|
238
|
+
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
|
|
239
|
+
} else {
|
|
240
|
+
out_vec[out_row + tm] = static_cast<T>(result[tm]);
|
|
241
|
+
}
|
|
242
|
+
}
|
|
243
|
+
}
|
|
244
|
+
}
|
|
245
|
+
};
|
|
246
|
+
|
|
247
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
248
|
+
/// Vector matrix multiplication
|
|
249
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
250
|
+
|
|
251
|
+
template <
|
|
252
|
+
typename T,
|
|
253
|
+
const int BM, /* Threadgroup rows (in simdgroups) */
|
|
254
|
+
const int BN, /* Threadgroup cols (in simdgroups) */
|
|
255
|
+
const int SM, /* Simdgroup rows (in threads) */
|
|
256
|
+
const int SN, /* Simdgroup cols (in threads) */
|
|
257
|
+
const int TM, /* Thread rows (in elements) */
|
|
258
|
+
const int TN, /* Thread cols (in elements) */
|
|
259
|
+
const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
|
|
260
|
+
typename AccT = typename DefaultAccT<T>::type>
|
|
261
|
+
struct GEMVTKernel {
|
|
262
|
+
using acc_type = AccT;
|
|
263
|
+
|
|
264
|
+
MLX_MTL_CONST int threadsM = BM * SM;
|
|
265
|
+
MLX_MTL_CONST int threadsN = BN * SN;
|
|
266
|
+
|
|
267
|
+
MLX_MTL_CONST int blockM = threadsM * TM;
|
|
268
|
+
MLX_MTL_CONST int blockN = threadsN * TN;
|
|
269
|
+
|
|
270
|
+
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
|
271
|
+
|
|
272
|
+
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
|
273
|
+
// into blocks of (blockM, blockN) divided among threadgroups
|
|
274
|
+
// - Every thread works on a block of (TM, TN)
|
|
275
|
+
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
|
|
276
|
+
//
|
|
277
|
+
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
|
278
|
+
// and the corresponding scalar from the vector
|
|
279
|
+
// 2. The thread then accumulates its local result for the block
|
|
280
|
+
// 3. At the end, each thread has accumulated results over all blocks across
|
|
281
|
+
// the rows. These are then summed up across the threadgroup
|
|
282
|
+
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
|
283
|
+
//
|
|
284
|
+
// Edge case handling:
|
|
285
|
+
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
|
286
|
+
// * The blocks that start outside the matrix are never read (thread results
|
|
287
|
+
// remain zero)
|
|
288
|
+
// * The last thread that partially overlaps with the matrix is shifted
|
|
289
|
+
// inwards such that the thread block fits exactly in the matrix
|
|
290
|
+
|
|
291
|
+
MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
|
|
292
|
+
MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
|
|
293
|
+
|
|
294
|
+
static METAL_FUNC void run(
|
|
295
|
+
const device T* mat [[buffer(0)]],
|
|
296
|
+
const device T* in_vec [[buffer(1)]],
|
|
297
|
+
const device T* bias [[buffer(2)]],
|
|
298
|
+
device T* out_vec [[buffer(3)]],
|
|
299
|
+
const constant int& in_vec_size [[buffer(4)]],
|
|
300
|
+
const constant int& out_vec_size [[buffer(5)]],
|
|
301
|
+
const constant int& marix_ld [[buffer(6)]],
|
|
302
|
+
const constant float& alpha [[buffer(7)]],
|
|
303
|
+
const constant float& beta [[buffer(8)]],
|
|
304
|
+
const constant int& bias_stride [[buffer(14)]],
|
|
305
|
+
threadgroup AccT* tgp_memory [[threadgroup(0)]],
|
|
306
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
307
|
+
uint3 lid [[thread_position_in_threadgroup]],
|
|
308
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
309
|
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
310
|
+
// Appease compiler
|
|
311
|
+
(void)lid;
|
|
312
|
+
|
|
313
|
+
// Thread local accumulation results
|
|
314
|
+
AccT result[TN] = {0};
|
|
315
|
+
T inter[TN];
|
|
316
|
+
AccT v_coeff[TM];
|
|
317
|
+
const int thrM = SN != 32 ? simd_lid / SN : 0;
|
|
318
|
+
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
|
|
319
|
+
|
|
320
|
+
const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
|
|
321
|
+
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
|
|
322
|
+
|
|
323
|
+
const int simdM = SM * sgM;
|
|
324
|
+
const int simdN = SN * sgN;
|
|
325
|
+
|
|
326
|
+
int cm = (simdM + thrM);
|
|
327
|
+
int cn = (simdN + thrN);
|
|
328
|
+
|
|
329
|
+
int bm = cm * TM;
|
|
330
|
+
int bn = cn * TN;
|
|
331
|
+
|
|
332
|
+
int out_col = tid.x * blockN + bn;
|
|
333
|
+
|
|
334
|
+
constexpr const uniform<int> loop_stride = make_uniform(blockM);
|
|
335
|
+
const uniform<int> in_size = make_uniform(in_vec_size);
|
|
336
|
+
const uniform<int> n_iter = in_size / loop_stride;
|
|
337
|
+
const uniform<int> last_iter = loop_stride * n_iter;
|
|
338
|
+
const uniform<int> leftover = in_size - last_iter;
|
|
339
|
+
|
|
340
|
+
// Edgecase handling
|
|
341
|
+
if (out_col < out_vec_size) {
|
|
342
|
+
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
|
|
343
|
+
|
|
344
|
+
// Per thread accumulation main loop
|
|
345
|
+
for (int i = 0; i < n_iter; ++i) {
|
|
346
|
+
// Adding a threadgroup_barrier improves performance slightly
|
|
347
|
+
// This is possibly it may help exploit cache better
|
|
348
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
349
|
+
|
|
350
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
351
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
352
|
+
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
356
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
357
|
+
auto vc = static_cast<AccT>(v_coeff[tm]);
|
|
358
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
359
|
+
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
|
360
|
+
}
|
|
361
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
362
|
+
result[tn] += vc * inter[tn];
|
|
363
|
+
}
|
|
364
|
+
}
|
|
365
|
+
|
|
366
|
+
bm += blockM;
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
if (leftover > 0) {
|
|
370
|
+
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
|
|
371
|
+
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
|
372
|
+
|
|
373
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
374
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
375
|
+
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
379
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
380
|
+
result[tn] += v_coeff[tm] * inter[tn];
|
|
381
|
+
}
|
|
382
|
+
}
|
|
383
|
+
}
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
// Simdgroup accumulations
|
|
387
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
388
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
389
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
390
|
+
for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
|
|
391
|
+
result[tn] += simd_shuffle_down(result[tn], SN * sm);
|
|
392
|
+
}
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
// Threadgroup accumulation results
|
|
396
|
+
if (needs_tgp_reduction) {
|
|
397
|
+
threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
|
|
398
|
+
if (thrM == 0) {
|
|
399
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
400
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
401
|
+
tgp_results[tn] = result[tn];
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
405
|
+
|
|
406
|
+
if (sgM == 0) {
|
|
407
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
408
|
+
for (int sgm = 1; sgm < BM; sgm++) {
|
|
409
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
410
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
411
|
+
result[tn] += tgp_results[sgm * (blockN + TN) + tn];
|
|
412
|
+
}
|
|
413
|
+
}
|
|
414
|
+
}
|
|
415
|
+
}
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
// Threadgroup accumulation and writing out results
|
|
419
|
+
if (cm == 0 && out_col < out_vec_size) {
|
|
420
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
421
|
+
for (int j = 0; j < TN; j++) {
|
|
422
|
+
if (kDoAxpby) {
|
|
423
|
+
out_vec[out_col + j] =
|
|
424
|
+
static_cast<T>(alpha) * static_cast<T>(result[j]) +
|
|
425
|
+
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
|
|
426
|
+
} else {
|
|
427
|
+
out_vec[out_col + j] = static_cast<T>(result[j]);
|
|
428
|
+
}
|
|
429
|
+
}
|
|
430
|
+
}
|
|
431
|
+
}
|
|
432
|
+
};
|
|
433
|
+
|
|
434
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
435
|
+
/// Matrix vector multiplication
|
|
436
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
437
|
+
|
|
438
|
+
template <
|
|
439
|
+
typename T,
|
|
440
|
+
const int BM, /* Threadgroup rows (in simdgroups) */
|
|
441
|
+
const int BN, /* Threadgroup cols (in simdgroups) */
|
|
442
|
+
const int SM, /* Simdgroup rows (in threads) */
|
|
443
|
+
const int SN, /* Simdgroup cols (in threads) */
|
|
444
|
+
const int TM, /* Thread rows (in elements) */
|
|
445
|
+
const int TN, /* Thread cols (in elements) */
|
|
446
|
+
const bool kDoNCBatch, /* Batch ndim > 1 */
|
|
447
|
+
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
|
448
|
+
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv(
|
|
449
|
+
const device T* mat [[buffer(0)]],
|
|
450
|
+
const device T* in_vec [[buffer(1)]],
|
|
451
|
+
const device T* bias [[buffer(2)]],
|
|
452
|
+
device T* out_vec [[buffer(3)]],
|
|
453
|
+
const constant int& in_vec_size [[buffer(4)]],
|
|
454
|
+
const constant int& out_vec_size [[buffer(5)]],
|
|
455
|
+
const constant int& marix_ld [[buffer(6)]],
|
|
456
|
+
const constant float& alpha [[buffer(7)]],
|
|
457
|
+
const constant float& beta [[buffer(8)]],
|
|
458
|
+
const constant int& batch_ndim [[buffer(9)]],
|
|
459
|
+
const constant int* batch_shape [[buffer(10)]],
|
|
460
|
+
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
|
461
|
+
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
|
462
|
+
const constant int64_t* bias_batch_stride [[buffer(13)]],
|
|
463
|
+
const constant int& bias_stride [[buffer(14)]],
|
|
464
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
465
|
+
uint3 lid [[thread_position_in_threadgroup]],
|
|
466
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
467
|
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
468
|
+
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
|
|
469
|
+
threadgroup typename gemv_kernel::acc_type tgp_memory
|
|
470
|
+
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
|
|
471
|
+
|
|
472
|
+
// Update batch offsets
|
|
473
|
+
if (kDoNCBatch) {
|
|
474
|
+
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
|
475
|
+
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
|
476
|
+
|
|
477
|
+
if (kDoAxpby) {
|
|
478
|
+
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
} else {
|
|
482
|
+
in_vec += tid.z * vector_batch_stride[0];
|
|
483
|
+
mat += tid.z * matrix_batch_stride[0];
|
|
484
|
+
|
|
485
|
+
if (kDoAxpby) {
|
|
486
|
+
bias += tid.z * bias_batch_stride[0];
|
|
487
|
+
}
|
|
488
|
+
}
|
|
489
|
+
|
|
490
|
+
out_vec += tid.z * out_vec_size;
|
|
491
|
+
|
|
492
|
+
gemv_kernel::run(
|
|
493
|
+
mat,
|
|
494
|
+
in_vec,
|
|
495
|
+
bias,
|
|
496
|
+
out_vec,
|
|
497
|
+
in_vec_size,
|
|
498
|
+
out_vec_size,
|
|
499
|
+
marix_ld,
|
|
500
|
+
alpha,
|
|
501
|
+
beta,
|
|
502
|
+
bias_stride,
|
|
503
|
+
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
|
|
504
|
+
tid,
|
|
505
|
+
lid,
|
|
506
|
+
simd_gid,
|
|
507
|
+
simd_lid);
|
|
508
|
+
}
|
|
509
|
+
|
|
510
|
+
#define instantiate_gemv_helper( \
|
|
511
|
+
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
|
512
|
+
instantiate_kernel( \
|
|
513
|
+
"gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
|
|
514
|
+
"_tn" #tn "_nc" #nc "_axpby" #axpby, \
|
|
515
|
+
gemv, \
|
|
516
|
+
itype, \
|
|
517
|
+
bm, \
|
|
518
|
+
bn, \
|
|
519
|
+
sm, \
|
|
520
|
+
sn, \
|
|
521
|
+
tm, \
|
|
522
|
+
tn, \
|
|
523
|
+
nc, \
|
|
524
|
+
axpby)
|
|
525
|
+
|
|
526
|
+
// clang-format off
|
|
527
|
+
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
|
|
528
|
+
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
|
|
529
|
+
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
|
|
530
|
+
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
|
|
531
|
+
instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
|
|
532
|
+
|
|
533
|
+
// clang-format off
|
|
534
|
+
#define instantiate_gemv_blocks(name, itype) \
|
|
535
|
+
instantiate_gemv(name, itype, 1, 8, 1, 32, 4, 4) \
|
|
536
|
+
instantiate_gemv(name, itype, 1, 8, 1, 32, 1, 4) \
|
|
537
|
+
instantiate_gemv(name, itype, 1, 1, 8, 4, 4, 4) \
|
|
538
|
+
instantiate_gemv(name, itype, 1, 1, 8, 4, 1, 4) \
|
|
539
|
+
instantiate_gemv(name, itype, 4, 1, 1, 32, 1, 4) \
|
|
540
|
+
instantiate_gemv(name, itype, 4, 1, 1, 32, 4, 4) \
|
|
541
|
+
instantiate_gemv(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
|
|
542
|
+
|
|
543
|
+
instantiate_gemv_blocks(float32, float);
|
|
544
|
+
instantiate_gemv_blocks(float16, half);
|
|
545
|
+
instantiate_gemv_blocks(bfloat16, bfloat16_t);
|
|
546
|
+
instantiate_gemv_blocks(complex64, complex64_t);
|
|
547
|
+
|
|
548
|
+
template <
|
|
549
|
+
typename T,
|
|
550
|
+
const int BM, /* Threadgroup rows (in simdgroups) */
|
|
551
|
+
const int BN, /* Threadgroup cols (in simdgroups) */
|
|
552
|
+
const int SM, /* Simdgroup rows (in threads) */
|
|
553
|
+
const int SN, /* Simdgroup cols (in threads) */
|
|
554
|
+
const int TM, /* Thread rows (in elements) */
|
|
555
|
+
const int TN> /* Thread cols (in elements) */
|
|
556
|
+
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_gather(
|
|
557
|
+
const device T* mat [[buffer(0)]],
|
|
558
|
+
const device T* in_vec [[buffer(1)]],
|
|
559
|
+
const device T* bias [[buffer(2)]],
|
|
560
|
+
device T* out_vec [[buffer(3)]],
|
|
561
|
+
const constant int& in_vec_size [[buffer(4)]],
|
|
562
|
+
const constant int& out_vec_size [[buffer(5)]],
|
|
563
|
+
const constant int& marix_ld [[buffer(6)]],
|
|
564
|
+
const constant float& alpha [[buffer(7)]],
|
|
565
|
+
const constant float& beta [[buffer(8)]],
|
|
566
|
+
const constant int& batch_ndim [[buffer(9)]],
|
|
567
|
+
const constant int* batch_shape [[buffer(10)]],
|
|
568
|
+
const constant int64_t* index_batch_strides [[buffer(11)]],
|
|
569
|
+
const constant int& vector_batch_ndim [[buffer(12)]],
|
|
570
|
+
const constant int* vector_batch_shape [[buffer(13)]],
|
|
571
|
+
const constant int64_t* vector_batch_stride [[buffer(14)]],
|
|
572
|
+
const constant int& matrix_batch_ndim [[buffer(15)]],
|
|
573
|
+
const constant int* matrix_batch_shape [[buffer(16)]],
|
|
574
|
+
const constant int64_t* matrix_batch_stride [[buffer(17)]],
|
|
575
|
+
const constant uint32_t* vec_indices [[buffer(18)]],
|
|
576
|
+
const constant uint32_t* mat_indices [[buffer(19)]],
|
|
577
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
578
|
+
uint3 lid [[thread_position_in_threadgroup]],
|
|
579
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
580
|
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
581
|
+
using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
|
|
582
|
+
threadgroup typename gemv_kernel::acc_type tgp_memory
|
|
583
|
+
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
|
|
584
|
+
|
|
585
|
+
uint32_t indx_vec;
|
|
586
|
+
uint32_t indx_mat;
|
|
587
|
+
|
|
588
|
+
// Update batch offsets
|
|
589
|
+
if (batch_ndim > 1) {
|
|
590
|
+
const constant auto* veci_bstrides = index_batch_strides;
|
|
591
|
+
const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
|
|
592
|
+
|
|
593
|
+
ulong2 batch_offsets = elem_to_loc_broadcast(
|
|
594
|
+
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
|
|
595
|
+
|
|
596
|
+
indx_vec = vec_indices[batch_offsets.x];
|
|
597
|
+
indx_mat = mat_indices[batch_offsets.y];
|
|
598
|
+
|
|
599
|
+
} else {
|
|
600
|
+
indx_vec = vec_indices[index_batch_strides[0] * tid.z];
|
|
601
|
+
indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
|
|
602
|
+
}
|
|
603
|
+
|
|
604
|
+
if (vector_batch_ndim > 1) {
|
|
605
|
+
in_vec += elem_to_loc(
|
|
606
|
+
indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
|
|
607
|
+
} else {
|
|
608
|
+
in_vec += indx_vec * vector_batch_stride[0];
|
|
609
|
+
}
|
|
610
|
+
|
|
611
|
+
if (matrix_batch_ndim > 1) {
|
|
612
|
+
mat += elem_to_loc(
|
|
613
|
+
indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
|
|
614
|
+
} else {
|
|
615
|
+
mat += indx_mat * matrix_batch_stride[0];
|
|
616
|
+
}
|
|
617
|
+
|
|
618
|
+
out_vec += tid.z * out_vec_size;
|
|
619
|
+
|
|
620
|
+
gemv_kernel::run(
|
|
621
|
+
mat,
|
|
622
|
+
in_vec,
|
|
623
|
+
bias,
|
|
624
|
+
out_vec,
|
|
625
|
+
in_vec_size,
|
|
626
|
+
out_vec_size,
|
|
627
|
+
marix_ld,
|
|
628
|
+
alpha,
|
|
629
|
+
beta,
|
|
630
|
+
batch_ndim, // Not used
|
|
631
|
+
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
|
|
632
|
+
tid,
|
|
633
|
+
lid,
|
|
634
|
+
simd_gid,
|
|
635
|
+
simd_lid);
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
// clang-format off
|
|
639
|
+
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
|
|
640
|
+
instantiate_kernel( \
|
|
641
|
+
"gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
|
|
642
|
+
"_sn" #sn "_tm" #tm "_tn" #tn, \
|
|
643
|
+
gemv_gather, itype, bm, bn, sm, sn, tm, tn)
|
|
644
|
+
|
|
645
|
+
#define instantiate_gemv_bs_blocks(name, itype) \
|
|
646
|
+
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
|
|
647
|
+
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
|
|
648
|
+
instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
|
|
649
|
+
|
|
650
|
+
instantiate_gemv_bs_blocks(float32, float);
|
|
651
|
+
instantiate_gemv_bs_blocks(float16, half);
|
|
652
|
+
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
|
|
653
|
+
instantiate_gemv_bs_blocks(complex64, complex64_t);
|
|
654
|
+
|
|
655
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
656
|
+
/// Vector matrix multiplication
|
|
657
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
658
|
+
|
|
659
|
+
template <
|
|
660
|
+
typename T,
|
|
661
|
+
const int BM, /* Threadgroup rows (in simdgroups) */
|
|
662
|
+
const int BN, /* Threadgroup cols (in simdgroups) */
|
|
663
|
+
const int SM, /* Simdgroup rows (in threads) */
|
|
664
|
+
const int SN, /* Simdgroup cols (in threads) */
|
|
665
|
+
const int TM, /* Thread rows (in elements) */
|
|
666
|
+
const int TN, /* Thread cols (in elements) */
|
|
667
|
+
const bool kDoNCBatch, /* Batch ndim > 1 */
|
|
668
|
+
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
|
|
669
|
+
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t(
|
|
670
|
+
const device T* mat [[buffer(0)]],
|
|
671
|
+
const device T* in_vec [[buffer(1)]],
|
|
672
|
+
const device T* bias [[buffer(2)]],
|
|
673
|
+
device T* out_vec [[buffer(3)]],
|
|
674
|
+
const constant int& in_vec_size [[buffer(4)]],
|
|
675
|
+
const constant int& out_vec_size [[buffer(5)]],
|
|
676
|
+
const constant int& marix_ld [[buffer(6)]],
|
|
677
|
+
const constant float& alpha [[buffer(7)]],
|
|
678
|
+
const constant float& beta [[buffer(8)]],
|
|
679
|
+
const constant int& batch_ndim [[buffer(9)]],
|
|
680
|
+
const constant int* batch_shape [[buffer(10)]],
|
|
681
|
+
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
|
682
|
+
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
|
683
|
+
const constant int64_t* bias_batch_stride [[buffer(13)]],
|
|
684
|
+
const constant int& bias_stride [[buffer(14)]],
|
|
685
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
686
|
+
uint3 lid [[thread_position_in_threadgroup]],
|
|
687
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
688
|
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
689
|
+
using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
|
|
690
|
+
threadgroup typename gemv_kernel::acc_type tgp_memory
|
|
691
|
+
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
|
|
692
|
+
|
|
693
|
+
// Update batch offsets
|
|
694
|
+
if (kDoNCBatch) {
|
|
695
|
+
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
|
696
|
+
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
|
697
|
+
|
|
698
|
+
if (kDoAxpby) {
|
|
699
|
+
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
|
|
700
|
+
}
|
|
701
|
+
|
|
702
|
+
} else {
|
|
703
|
+
in_vec += tid.z * vector_batch_stride[0];
|
|
704
|
+
mat += tid.z * matrix_batch_stride[0];
|
|
705
|
+
|
|
706
|
+
if (kDoAxpby) {
|
|
707
|
+
bias += tid.z * bias_batch_stride[0];
|
|
708
|
+
}
|
|
709
|
+
}
|
|
710
|
+
|
|
711
|
+
out_vec += tid.z * out_vec_size;
|
|
712
|
+
|
|
713
|
+
gemv_kernel::run(
|
|
714
|
+
mat,
|
|
715
|
+
in_vec,
|
|
716
|
+
bias,
|
|
717
|
+
out_vec,
|
|
718
|
+
in_vec_size,
|
|
719
|
+
out_vec_size,
|
|
720
|
+
marix_ld,
|
|
721
|
+
alpha,
|
|
722
|
+
beta,
|
|
723
|
+
bias_stride,
|
|
724
|
+
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
|
|
725
|
+
tid,
|
|
726
|
+
lid,
|
|
727
|
+
simd_gid,
|
|
728
|
+
simd_lid);
|
|
729
|
+
}
|
|
730
|
+
|
|
731
|
+
// clang-format off
|
|
732
|
+
#define instantiate_gemv_t_helper( \
|
|
733
|
+
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
|
|
734
|
+
instantiate_kernel( \
|
|
735
|
+
"gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
|
|
736
|
+
"_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \
|
|
737
|
+
gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)
|
|
738
|
+
|
|
739
|
+
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
|
|
740
|
+
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
|
|
741
|
+
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
|
|
742
|
+
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
|
|
743
|
+
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
|
|
744
|
+
|
|
745
|
+
// clang-format off
|
|
746
|
+
#define instantiate_gemv_t_blocks(name, itype) \
|
|
747
|
+
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 1) \
|
|
748
|
+
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
|
|
749
|
+
instantiate_gemv_t(name, itype, 1, 4, 8, 4, 4, 4) \
|
|
750
|
+
instantiate_gemv_t(name, itype, 1, 16, 8, 4, 4, 4) \
|
|
751
|
+
instantiate_gemv_t(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on
|
|
752
|
+
|
|
753
|
+
// clang-format off
|
|
754
|
+
instantiate_gemv_t_blocks(float32, float);
|
|
755
|
+
instantiate_gemv_t_blocks(float16, half);
|
|
756
|
+
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
|
|
757
|
+
instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on
|
|
758
|
+
|
|
759
|
+
// Gather variant of the gemv_t kernel.
//
// For the batch element selected by tid.z, one entry is read from
// `vec_indices` and one from `mat_indices`. Those entries are turned into
// element offsets that advance `in_vec` and `mat`, `out_vec` is advanced by
// tid.z * out_vec_size, and the work is then delegated to GEMVTKernel::run.
//
// NOTE(review): the last GEMVTKernel template argument is `false`; presumably
// this disables the alpha/beta bias update, which would make `bias`, `alpha`
// and `beta` pass-through only -- confirm against the GEMVTKernel definition.
template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_gather(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    const device T* bias [[buffer(2)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant float& alpha [[buffer(7)]],
    const constant float& beta [[buffer(8)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* index_batch_strides [[buffer(11)]],
    const constant int& vector_batch_ndim [[buffer(12)]],
    const constant int* vector_batch_shape [[buffer(13)]],
    const constant int64_t* vector_batch_stride [[buffer(14)]],
    const constant int& matrix_batch_ndim [[buffer(15)]],
    const constant int* matrix_batch_shape [[buffer(16)]],
    const constant int64_t* matrix_batch_stride [[buffer(17)]],
    const constant uint32_t* vec_indices [[buffer(18)]],
    const constant uint32_t* mat_indices [[buffer(19)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using kernel_t = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;

  // Threadgroup scratch for the kernel. When the kernel reports that it needs
  // none, a one-element placeholder is declared here and nullptr is passed to
  // run() further down.
  threadgroup typename kernel_t::acc_type
      tgp_memory[kernel_t::tgp_mem_size == 0 ? 1 : kernel_t::tgp_mem_size];

  // Entries gathered from the two index buffers for this batch element.
  uint32_t vec_idx;
  uint32_t mat_idx;

  // `index_batch_strides` packs two stride arrays back to back: the first
  // `batch_ndim` entries address `vec_indices`, the following ones address
  // `mat_indices`.
  if (batch_ndim > 1) {
    ulong2 index_locs = elem_to_loc_broadcast(
        tid.z,
        batch_shape,
        index_batch_strides,
        index_batch_strides + batch_ndim,
        batch_ndim);

    vec_idx = vec_indices[index_locs.x];
    mat_idx = mat_indices[index_locs.y];
  } else {
    vec_idx = vec_indices[index_batch_strides[0] * tid.z];
    mat_idx = mat_indices[index_batch_strides[batch_ndim] * tid.z];
  }

  // Advance the vector pointer to the gathered vector.
  if (vector_batch_ndim <= 1) {
    in_vec += vec_idx * vector_batch_stride[0];
  } else {
    in_vec += elem_to_loc(
        vec_idx, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
  }

  // Advance the matrix pointer to the gathered matrix.
  if (matrix_batch_ndim <= 1) {
    mat += mat_idx * matrix_batch_stride[0];
  } else {
    mat += elem_to_loc(
        mat_idx, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
  }

  // Each batch element writes its own run of `out_vec_size` outputs.
  out_vec += tid.z * out_vec_size;

  kernel_t::run(
      mat,
      in_vec,
      bias,
      out_vec,
      in_vec_size,
      out_vec_size,
      marix_ld,
      alpha,
      beta,
      batch_ndim, // Not used
      kernel_t::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}
|
|
848
|
+
|
|
849
|
+
// clang-format off
// Instantiates gemv_t_gather for one element type and tile configuration.
// The kernel name is assembled by stringifying the macro arguments:
//   gemv_t_gather_<name>_bm<bm>_bn<bn>_sm<sm>_sn<sn>_tm<tm>_tn<tn>
#define instantiate_gemv_t_bs_helper(                        \
    kname, dtype, tg_m, tg_n, sg_m, sg_n, th_m, th_n)        \
  instantiate_kernel(                                        \
    "gemv_t_gather_" #kname                                  \
    "_bm" #tg_m                                              \
    "_bn" #tg_n                                              \
    "_sm" #sg_m                                              \
    "_sn" #sg_n                                              \
    "_tm" #th_m                                              \
    "_tn" #th_n,                                             \
    gemv_t_gather, dtype, tg_m, tg_n, sg_m, sg_n, th_m, th_n)
|
|
856
|
+
|
|
857
|
+
// Tile configurations instantiated for the gather kernel. The six numbers in
// each row are forwarded to instantiate_gemv_t_bs_helper as
// (bm, bn, sm, sn, tm, tn).
#define instantiate_gemv_t_bs_blocks(kname, dtype)              \
  instantiate_gemv_t_bs_helper(kname, dtype, 1,  2, 8, 4, 4, 1) \
  instantiate_gemv_t_bs_helper(kname, dtype, 1,  2, 8, 4, 4, 4) \
  instantiate_gemv_t_bs_helper(kname, dtype, 1,  4, 8, 4, 4, 4) \
  instantiate_gemv_t_bs_helper(kname, dtype, 1, 16, 8, 4, 4, 4) \
  instantiate_gemv_t_bs_helper(kname, dtype, 1, 16, 4, 8, 4, 4)

// One expansion per supported element type: (kernel-name tag, Metal type).
instantiate_gemv_t_bs_blocks(float32,   float);
instantiate_gemv_t_bs_blocks(float16,   half);
instantiate_gemv_t_bs_blocks(bfloat16,  bfloat16_t);
instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on
|