mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,827 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
4
|
+
|
|
5
|
+
using namespace metal;
|
|
6
|
+
|
|
7
|
+
// Shorthand for `static constant constexpr const`; used to declare the
// compile-time constants of the kernel structs below.
#define MLX_MTL_CONST static constant constexpr const
|
|
8
|
+
// Placed directly before a loop to ask the compiler to fully unroll it.
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
|
9
|
+
|
|
10
|
+
// Sentinel "mask" type. Passing nomask_t as out_mask_t / op_mask_t makes the
// kernels below treat that mask as absent (see has_output_mask /
// has_operand_mask). Every conversion to bool yields true; overloads are
// provided for unqualified, threadgroup, device and constant objects.
struct _NoMask {
  char x;

  constexpr METAL_FUNC operator bool() const constant {
    return true;
  }

  constexpr METAL_FUNC operator bool() const device {
    return true;
  }

  constexpr METAL_FUNC operator bool() const threadgroup {
    return true;
  }

  constexpr METAL_FUNC operator bool() {
    return true;
  }
};

using nomask_t = _NoMask;
|
|
28
|
+
|
|
29
|
+
// Elementwise functor: converts an InT value to OutT, then multiplies it by
// the stored `scale` (the multiplication is evaluated in OutT).
template <typename OutT, typename InT = OutT>
struct ScaleOp {
  OutT scale;

  METAL_FUNC OutT apply(InT x) const {
    const OutT converted = static_cast<OutT>(x);
    return converted * scale;
  }
};
|
|
37
|
+
|
|
38
|
+
/// Block-masked matrix-vector product: out_vec = mat * in_vec.
///
/// `mat` has out_vec_size rows of in_vec_size contiguous elements each, with
/// consecutive rows matrix_ld elements apart. Optional masks are selected by
/// the mask types: nomask_t disables a mask, `bool` masks gate whole blocks,
/// and any other mask type also multiplies the affected values by the mask
/// entry.
template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM, /* Threadgroup rows (in simdgroups) */
    const int BN, /* Threadgroup cols (in simdgroups) */
    const int SM, /* Simdgroup rows (in threads) */
    const int SN, /* Simdgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    typename AccT = float>
struct GEMVKernel {
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

  MLX_MTL_CONST int blockM = threadsM * TM;
  MLX_MTL_CONST int blockN = threadsN * TN;

  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  static_assert(
      SN == 8 || SN == 16 || SN == 32,
      "gemv block must have a width of 8, 16, or 32");

  static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");

  MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  // Non-bool masks are multiplicative: their values scale the result.
  MLX_MTL_CONST bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
  MLX_MTL_CONST bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v<out_mask_t, bool>;

  // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
  //   into blocks of (blockM, blockN) divided among threadgroups
  // - Every thread works on a block of (TM, TN)
  // - We assume each threadgroup has (threadsN, threadsM, 1) threads
  //
  // 1. A thread loads TN elements each from mat along TM rows
  //    and the corresponding scalar from the vector
  // 2. The thread then multiplies and adds to accumulate its local result for
  //    the block
  // 3. At the end, each thread has accumulated results over all blocks across
  //    the rows. These are then summed up across the threadgroup
  // 4. Each threadgroup writes its accumulated blockM outputs
  //
  // Edge case handling:
  // - The threadgroup with the largest tid has blocks that exceed the matrix
  //   * The blocks that start outside the matrix are never read (thread results
  //     remain zero)
  //   * The last thread that partially overlaps with the matrix is shifted
  //     inwards such that the thread block fits exactly in the matrix

  // One (blockM + TM)-element slot of partial sums per simdgroup column.
  MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN * (blockM + TM) : 0;
  MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;

  // Copies src[src_offset .. src_offset + TN) into dst (cast to U) with no
  // bounds checks.
  template <typename U = T>
  static METAL_FUNC void
  load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
    MLX_MTL_PRAGMA_UNROLL
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = static_cast<U>(src[src_offset + tn]);
    }
  }

  // Like load_unsafe, but positions at or beyond src_size are filled with 0.
  template <typename U = T>
  static METAL_FUNC void load_safe(
      const device T* src,
      thread U dst[TN],
      const int src_offset = 0,
      const int src_size = TN) {
    if (src_offset + TN <= src_size) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = static_cast<U>(src[src_offset + tn]);
      }
    } else { // Edgecase
      MLX_MTL_PRAGMA_UNROLL
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = src_offset + tn < src_size
            ? static_cast<U>(src[src_offset + tn])
            : U(0);
      }
    }
  }

  // Kernel body.
  //
  // mask_strides holds two ints per present mask, in the order out, mat, vec
  // (absent masks take no slots). Only these entries are read:
  //   out_mask_strides[1]  - per output row block
  //   mat_mask_strides[1]  - per output row block
  //   mat_mask_strides[0]  - per blockN step along in_vec
  //   vec_mask_strides[1]  - per blockN step along in_vec
  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const device out_mask_t* out_mask [[buffer(20)]],
      const device op_mask_t* mat_mask [[buffer(21)]],
      const device op_mask_t* vec_mask [[buffer(22)]],
      const constant int* mask_strides [[buffer(23)]],
      threadgroup AccT* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Appease compiler
    (void)lid;

    // Thread local accumulation results
    thread AccT result[TM] = {0};
    thread T inter[TN];
    thread AccT v_coeff[TN];

    // Thread position inside its simdgroup (SM x SN grid of lanes)
    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    // Simdgroup position inside the threadgroup (BM x BN grid of simdgroups)
    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;

    int bm = (simdM + thrM) * TM;
    int bn = (simdN + thrN) * TN;

    // Block position
    int out_row = tid.x * blockM + bm;

    // Exit simdgroup if rows out of bound
    // NOTE(review): when BN > 1 this return (and the output-mask return
    // below) must be uniform across the threadgroup, or the barrier in the
    // reduction is not reached by every thread. It is uniform when all
    // simdgroups share the same rows (e.g. BM == 1 and SN == 32, which gives
    // bm == 0 for every thread) - confirm for other instantiations.
    if (out_row >= out_vec_size)
      return;

    // Adjust tail simdgroup to ensure in bound reads
    out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;

    // Prepare mask offsets
    const constant int* out_mask_strides = mask_strides;
    const constant int* mat_mask_strides =
        mask_strides + (has_output_mask ? 2 : 0);
    const constant int* vec_mask_strides =
        mat_mask_strides + (has_operand_mask ? 2 : 0);

    // Mask block index of this thread's rows. It is derived from the
    // tail-shifted out_row.
    // NOTE(review): presumably mask blocks are blockN rows tall here; if
    // out_vec_size is not a multiple of the mask block size, the tail shift
    // can move out_row into the previous mask block - confirm callers pad.
    const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);

    const int out_mask_offset =
        !has_output_mask ? 0 : m_block_idx * out_mask_strides[1];

    int mat_mask_offset =
        !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
    int vec_mask_offset = 0;
    const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
    const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];

    T out_scale{1};

    // Check output mask
    if (has_output_mask) {
      auto mask_out = out_mask[out_mask_offset];

      // Write zeros and return if mask is 0
      if (!mask_out) {
        if (simdN == 0 && thrN == 0) {
          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            out_vec[out_row + tm] = T(0.);
          }
        }

        return;
      }

      // Store scalar if multiplicative mask
      if (has_mul_output_mask) {
        out_scale = T(mask_out);
      }
    }

    // Advance matrix
    mat += out_row * matrix_ld;

    // Prepare for loop
    constexpr const uniform<int> loop_stride = make_uniform(blockN);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;

    // Loop over in_vec in blocks of blockN
    for (int i = 0; i < n_iter; ++i) {
      // Skip the whole block if either operand mask entry is zero
      if (!has_operand_mask ||
          (bool(mat_mask[mat_mask_offset]) &&
           bool(vec_mask[vec_mask_offset]))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        load_unsafe<AccT>(in_vec, v_coeff, bn);

        // Apply scale
        if (has_mul_operand_mask) {
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            v_coeff[tn] *= block_scale;
          }
        }

        // Per thread work loop
        int mat_offset = 0;
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          // Load for the row
          load_unsafe(mat, inter, mat_offset + bn);

          // Accumulate results
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            result[tm] += inter[tn] * v_coeff[tn];
          }

          mat_offset += matrix_ld;
        }
      }

      bn += blockN;
      mat_mask_offset += mat_mask_step;
      vec_mask_offset += vec_mask_step;
    }

    // Partial last block along in_vec (bounds-checked loads)
    if (leftover > 0) {
      if (!has_operand_mask ||
          (bool(mat_mask[mat_mask_offset]) &&
           bool(vec_mask[vec_mask_offset]))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        load_safe<AccT>(in_vec, v_coeff, bn, in_size);

        // Apply scale
        if (has_mul_operand_mask) {
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            v_coeff[tn] *= block_scale;
          }
        }

        // Per thread work loop
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          // Load for the row
          load_safe(&mat[tm * matrix_ld], inter, bn, in_size);

          // Accumulate results
          MLX_MTL_PRAGMA_UNROLL
          for (int tn = 0; tn < TN; tn++) {
            result[tm] += inter[tn] * v_coeff[tn];
          }
        }
      }
    }

    // Apply out scale
    if (has_mul_output_mask) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        result[tm] *= out_scale;
      }
    }

    // Simdgroup accumulations: fold the SN lanes of each row into thrN == 0
    MLX_MTL_PRAGMA_UNROLL
    for (int tm = 0; tm < TM; tm++) {
      MLX_MTL_PRAGMA_UNROLL
      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
        result[tm] += simd_shuffle_down(result[tm], sn);
      }
    }

    // Threadgroup accumulation results
    if (needs_tgp_reduction) {
      threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;

      // Each simdgroup column publishes its partial row sums
      if (thrN == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int tm = 0; tm < TM; tm++) {
          tgp_results[tm] = result[tm];
        }
      }

      // The barrier is issued outside the thrN == 0 branch because MSL
      // requires it to be executed by every thread of the threadgroup.
      // mem_threadgroup orders the writes above before the reads below, so
      // partial sums from other simdgroups are visible; mem_none would not.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Simdgroup column 0 gathers the partial sums of the other BN - 1
      if (thrN == 0 && sgN == 0) {
        MLX_MTL_PRAGMA_UNROLL
        for (int sgn = 1; sgn < BN; sgn++) {
          MLX_MTL_PRAGMA_UNROLL
          for (int tm = 0; tm < TM; tm++) {
            result[tm] += tgp_results[sgn * (blockM + TM) + tm];
          }
        }
      }
    }

    // Write outputs
    if (simdN == 0 && thrN == 0) {
      MLX_MTL_PRAGMA_UNROLL
      for (int tm = 0; tm < TM; tm++) {
        out_vec[out_row + tm] = static_cast<T>(result[tm]);
      }
    }
  }
};
|
|
349
|
+
|
|
350
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
351
|
+
/// Vector matrix multiplication
|
|
352
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
353
|
+
|
|
354
|
+
template <
|
|
355
|
+
typename T,
|
|
356
|
+
typename out_mask_t,
|
|
357
|
+
typename op_mask_t,
|
|
358
|
+
const int BM, /* Threadgroup rows (in simdgroups) */
|
|
359
|
+
const int BN, /* Threadgroup cols (in simdgroups) */
|
|
360
|
+
const int SM, /* Simdgroup rows (in threads) */
|
|
361
|
+
const int SN, /* Simdgroup cols (in threads) */
|
|
362
|
+
const int TM, /* Thread rows (in elements) */
|
|
363
|
+
const int TN, /* Thread cols (in elements) */
|
|
364
|
+
typename AccT = float>
|
|
365
|
+
struct GEMVTKernel {
|
|
366
|
+
MLX_MTL_CONST int threadsM = BM * SM;
|
|
367
|
+
MLX_MTL_CONST int threadsN = BN * SN;
|
|
368
|
+
|
|
369
|
+
MLX_MTL_CONST int blockM = threadsM * TM;
|
|
370
|
+
MLX_MTL_CONST int blockN = threadsN * TN;
|
|
371
|
+
|
|
372
|
+
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
|
|
373
|
+
|
|
374
|
+
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
|
|
375
|
+
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
|
|
376
|
+
|
|
377
|
+
MLX_MTL_CONST bool has_mul_operand_mask =
|
|
378
|
+
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
|
|
379
|
+
MLX_MTL_CONST bool has_mul_output_mask =
|
|
380
|
+
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
|
|
381
|
+
|
|
382
|
+
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
|
|
383
|
+
// into blocks of (blockM, blockN) divided among threadgroups
|
|
384
|
+
// - Every thread works on a block of (TM, TN)
|
|
385
|
+
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
|
|
386
|
+
//
|
|
387
|
+
// 1. A thread loads TN elements each from mat along TM contiguous rows
|
|
388
|
+
// and the corresponding scalar from the vector
|
|
389
|
+
// 2. The thread then accumulates its local result for the block
|
|
390
|
+
// 3. At the end, each thread has accumulated results over all blocks across
|
|
391
|
+
// the rows. These are then summed up across the threadgroup
|
|
392
|
+
// 4. Each threadgroup writes its accumulated BN * TN outputs
|
|
393
|
+
//
|
|
394
|
+
// Edge case handling:
|
|
395
|
+
// - The threadgroup with the largest tid has blocks that exceed the matrix
|
|
396
|
+
// * The blocks that start outside the matrix are never read (thread results
|
|
397
|
+
// remain zero)
|
|
398
|
+
// * The last thread that partially overlaps with the matrix is shifted
|
|
399
|
+
// inwards such that the thread block fits exactly in the matrix
|
|
400
|
+
|
|
401
|
+
MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
|
|
402
|
+
MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
|
|
403
|
+
|
|
404
|
+
static METAL_FUNC void run(
|
|
405
|
+
const device T* mat [[buffer(0)]],
|
|
406
|
+
const device T* in_vec [[buffer(1)]],
|
|
407
|
+
device T* out_vec [[buffer(3)]],
|
|
408
|
+
const constant int& in_vec_size [[buffer(4)]],
|
|
409
|
+
const constant int& out_vec_size [[buffer(5)]],
|
|
410
|
+
const constant int& marix_ld [[buffer(6)]],
|
|
411
|
+
const device out_mask_t* out_mask [[buffer(20)]],
|
|
412
|
+
const device op_mask_t* mat_mask [[buffer(21)]],
|
|
413
|
+
const device op_mask_t* vec_mask [[buffer(22)]],
|
|
414
|
+
const constant int* mask_strides [[buffer(23)]],
|
|
415
|
+
threadgroup AccT* tgp_memory [[threadgroup(0)]],
|
|
416
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
417
|
+
uint3 lid [[thread_position_in_threadgroup]],
|
|
418
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
419
|
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
420
|
+
// Appease compiler
|
|
421
|
+
(void)lid;
|
|
422
|
+
|
|
423
|
+
// Thread local accumulation results
|
|
424
|
+
AccT result[TN] = {0};
|
|
425
|
+
T inter[TN];
|
|
426
|
+
AccT v_coeff[TM];
|
|
427
|
+
|
|
428
|
+
const int thrM = SN != 32 ? simd_lid / SN : 0;
|
|
429
|
+
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
|
|
430
|
+
|
|
431
|
+
const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
|
|
432
|
+
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
|
|
433
|
+
|
|
434
|
+
const int simdM = SM * sgM;
|
|
435
|
+
const int simdN = SN * sgN;
|
|
436
|
+
|
|
437
|
+
int cm = (simdM + thrM);
|
|
438
|
+
int cn = (simdN + thrN);
|
|
439
|
+
|
|
440
|
+
int bm = cm * TM;
|
|
441
|
+
int bn = cn * TN;
|
|
442
|
+
|
|
443
|
+
int out_col = tid.x * blockN + bn;
|
|
444
|
+
|
|
445
|
+
// Prepare mask offsets
|
|
446
|
+
const constant int* out_mask_strides = mask_strides;
|
|
447
|
+
const constant int* mat_mask_strides =
|
|
448
|
+
out_mask_strides + (has_output_mask ? 2 : 0);
|
|
449
|
+
const constant int* vec_mask_strides =
|
|
450
|
+
mat_mask_strides + (has_operand_mask ? 2 : 0);
|
|
451
|
+
|
|
452
|
+
const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
|
|
453
|
+
|
|
454
|
+
const int out_mask_offset =
|
|
455
|
+
!has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
|
|
456
|
+
|
|
457
|
+
int mat_mask_offset =
|
|
458
|
+
!has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
|
|
459
|
+
int vec_mask_offset = 0;
|
|
460
|
+
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
|
|
461
|
+
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
|
|
462
|
+
|
|
463
|
+
T out_scale{1};
|
|
464
|
+
|
|
465
|
+
// Check output mask
|
|
466
|
+
if (has_output_mask) {
|
|
467
|
+
auto mask_out = out_mask[out_mask_offset];
|
|
468
|
+
|
|
469
|
+
// Write zeros and return if mask is 0
|
|
470
|
+
if (!mask_out) {
|
|
471
|
+
if (cm == 0 && out_col < out_vec_size) {
|
|
472
|
+
if (out_col + TN <= out_vec_size) {
|
|
473
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
474
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
475
|
+
out_vec[out_col + tn] = T(0.);
|
|
476
|
+
}
|
|
477
|
+
} else {
|
|
478
|
+
for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
|
|
479
|
+
out_vec[out_col + tn] = T(0.);
|
|
480
|
+
}
|
|
481
|
+
}
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
return;
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
// Store scalar if multiplicative mask
|
|
488
|
+
if (has_mul_output_mask) {
|
|
489
|
+
out_scale = T(mask_out);
|
|
490
|
+
}
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
// Prepare for loop
|
|
494
|
+
constexpr const uniform<int> loop_stride = make_uniform(blockM);
|
|
495
|
+
const uniform<int> in_size = make_uniform(in_vec_size);
|
|
496
|
+
const uniform<int> n_iter = in_size / loop_stride;
|
|
497
|
+
const uniform<int> last_iter = loop_stride * n_iter;
|
|
498
|
+
const uniform<int> leftover = in_size - last_iter;
|
|
499
|
+
|
|
500
|
+
// Edgecase handling
|
|
501
|
+
if (out_col < out_vec_size) {
|
|
502
|
+
out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
|
|
503
|
+
|
|
504
|
+
// Per thread accumulation main loop
|
|
505
|
+
for (int i = 0; i < n_iter; ++i) {
|
|
506
|
+
// Adding a threadgroup_barrier improves performance slightly
|
|
507
|
+
// This is possibly it may help exploit cache better
|
|
508
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
509
|
+
|
|
510
|
+
if (!has_operand_mask ||
|
|
511
|
+
(bool(mat_mask[mat_mask_offset]) &&
|
|
512
|
+
bool(vec_mask[vec_mask_offset]))) {
|
|
513
|
+
T block_scale{1};
|
|
514
|
+
if (has_mul_operand_mask) {
|
|
515
|
+
block_scale =
|
|
516
|
+
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
|
517
|
+
}
|
|
518
|
+
|
|
519
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
520
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
521
|
+
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
|
522
|
+
}
|
|
523
|
+
|
|
524
|
+
// Apply scale
|
|
525
|
+
if (has_mul_operand_mask) {
|
|
526
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
527
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
528
|
+
v_coeff[tm] *= block_scale;
|
|
529
|
+
}
|
|
530
|
+
}
|
|
531
|
+
|
|
532
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
533
|
+
for (int tm = 0; tm < TM; tm++) {
|
|
534
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
535
|
+
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
|
536
|
+
}
|
|
537
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
538
|
+
result[tn] += v_coeff[tm] * inter[tn];
|
|
539
|
+
}
|
|
540
|
+
}
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
bm += blockM;
|
|
544
|
+
mat_mask_offset += mat_mask_step;
|
|
545
|
+
vec_mask_offset += vec_mask_step;
|
|
546
|
+
}
|
|
547
|
+
|
|
548
|
+
if (leftover > 0) {
|
|
549
|
+
if (!has_operand_mask ||
|
|
550
|
+
(bool(mat_mask[mat_mask_offset]) &&
|
|
551
|
+
bool(vec_mask[vec_mask_offset]))) {
|
|
552
|
+
T block_scale{1};
|
|
553
|
+
if (has_mul_operand_mask) {
|
|
554
|
+
block_scale =
|
|
555
|
+
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
|
|
556
|
+
}
|
|
557
|
+
|
|
558
|
+
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
|
|
559
|
+
v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
|
|
560
|
+
|
|
561
|
+
if (has_mul_operand_mask) {
|
|
562
|
+
v_coeff[tm] *= block_scale;
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
566
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
567
|
+
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
|
|
568
|
+
}
|
|
569
|
+
|
|
570
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
571
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
572
|
+
result[tn] += v_coeff[tm] * inter[tn];
|
|
573
|
+
}
|
|
574
|
+
}
|
|
575
|
+
}
|
|
576
|
+
}
|
|
577
|
+
}
|
|
578
|
+
|
|
579
|
+
// Apply out scale
|
|
580
|
+
if (has_mul_output_mask) {
|
|
581
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
582
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
583
|
+
result[tn] *= out_scale;
|
|
584
|
+
}
|
|
585
|
+
}
|
|
586
|
+
|
|
587
|
+
// Simdgroup accumulations
|
|
588
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
589
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
590
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
591
|
+
for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
|
|
592
|
+
result[tn] += simd_shuffle_down(result[tn], SN * sm);
|
|
593
|
+
}
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
// Threadgroup accumulation results
|
|
597
|
+
if (needs_tgp_reduction) {
|
|
598
|
+
threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
|
|
599
|
+
if (thrM == 0) {
|
|
600
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
601
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
602
|
+
tgp_results[tn] = result[tn];
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
threadgroup_barrier(mem_flags::mem_none);
|
|
606
|
+
|
|
607
|
+
if (sgM == 0) {
|
|
608
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
609
|
+
for (int sgm = 1; sgm < BM; sgm++) {
|
|
610
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
611
|
+
for (int tn = 0; tn < TN; tn++) {
|
|
612
|
+
result[tn] += tgp_results[sgm * (blockN + TN) + tn];
|
|
613
|
+
}
|
|
614
|
+
}
|
|
615
|
+
}
|
|
616
|
+
}
|
|
617
|
+
}
|
|
618
|
+
|
|
619
|
+
// Threadgroup accumulation and writing out results
|
|
620
|
+
if (cm == 0 && out_col < out_vec_size) {
|
|
621
|
+
MLX_MTL_PRAGMA_UNROLL
|
|
622
|
+
for (int j = 0; j < TN; j++) {
|
|
623
|
+
out_vec[out_col + j] = static_cast<T>(result[j]);
|
|
624
|
+
}
|
|
625
|
+
}
|
|
626
|
+
}
|
|
627
|
+
};
|
|
628
|
+
|
|
629
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
630
|
+
/// Matrix vector multiplication
|
|
631
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
632
|
+
|
|
633
|
+
template <
|
|
634
|
+
typename T,
|
|
635
|
+
typename out_mask_t,
|
|
636
|
+
typename op_mask_t,
|
|
637
|
+
const int BM, /* Threadgroup rows (in simdgroups) */
|
|
638
|
+
const int BN, /* Threadgroup cols (in simdgroups) */
|
|
639
|
+
const int SM, /* Simdgroup rows (in threads) */
|
|
640
|
+
const int SN, /* Simdgroup cols (in threads) */
|
|
641
|
+
const int TM, /* Thread rows (in elements) */
|
|
642
|
+
const int TN, /* Thread cols (in elements) */
|
|
643
|
+
const bool kDoNCBatch> /* Batch ndim > 1 */
|
|
644
|
+
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
|
|
645
|
+
const device T* mat [[buffer(0)]],
|
|
646
|
+
const device T* in_vec [[buffer(1)]],
|
|
647
|
+
device T* out_vec [[buffer(3)]],
|
|
648
|
+
const constant int& in_vec_size [[buffer(4)]],
|
|
649
|
+
const constant int& out_vec_size [[buffer(5)]],
|
|
650
|
+
const constant int& marix_ld [[buffer(6)]],
|
|
651
|
+
const constant int& batch_ndim [[buffer(9)]],
|
|
652
|
+
const constant int* batch_shape [[buffer(10)]],
|
|
653
|
+
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
|
654
|
+
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
|
655
|
+
const device out_mask_t* out_mask [[buffer(20)]],
|
|
656
|
+
const device op_mask_t* mat_mask [[buffer(21)]],
|
|
657
|
+
const device op_mask_t* vec_mask [[buffer(22)]],
|
|
658
|
+
const constant int* mask_strides [[buffer(23)]],
|
|
659
|
+
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
|
660
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
661
|
+
uint3 lid [[thread_position_in_threadgroup]],
|
|
662
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
663
|
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
664
|
+
using gemv_kernel =
|
|
665
|
+
GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
|
|
666
|
+
threadgroup float tgp_memory
|
|
667
|
+
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
|
|
668
|
+
|
|
669
|
+
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
|
|
670
|
+
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
|
|
671
|
+
|
|
672
|
+
// Update batch offsets
|
|
673
|
+
if (kDoNCBatch) {
|
|
674
|
+
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
|
675
|
+
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
|
676
|
+
|
|
677
|
+
if (has_output_mask) {
|
|
678
|
+
out_mask +=
|
|
679
|
+
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
|
|
680
|
+
mask_batch_strides += batch_ndim;
|
|
681
|
+
}
|
|
682
|
+
|
|
683
|
+
if (has_operand_mask) {
|
|
684
|
+
const constant auto* mask_strides_mat = mask_batch_strides;
|
|
685
|
+
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
|
|
686
|
+
|
|
687
|
+
ulong2 batch_offsets = elem_to_loc_broadcast(
|
|
688
|
+
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
|
|
689
|
+
|
|
690
|
+
mat_mask += batch_offsets.x;
|
|
691
|
+
vec_mask += batch_offsets.y;
|
|
692
|
+
}
|
|
693
|
+
|
|
694
|
+
} else {
|
|
695
|
+
in_vec += tid.z * vector_batch_stride[0];
|
|
696
|
+
mat += tid.z * matrix_batch_stride[0];
|
|
697
|
+
|
|
698
|
+
if (has_output_mask) {
|
|
699
|
+
out_mask += tid.z * mask_batch_strides[0];
|
|
700
|
+
mask_batch_strides += batch_ndim;
|
|
701
|
+
}
|
|
702
|
+
|
|
703
|
+
if (has_operand_mask) {
|
|
704
|
+
mat_mask += tid.z * mask_batch_strides[0];
|
|
705
|
+
vec_mask += tid.z * mask_batch_strides[batch_ndim];
|
|
706
|
+
}
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
out_vec += tid.z * out_vec_size;
|
|
710
|
+
|
|
711
|
+
gemv_kernel::run(
|
|
712
|
+
mat,
|
|
713
|
+
in_vec,
|
|
714
|
+
out_vec,
|
|
715
|
+
in_vec_size,
|
|
716
|
+
out_vec_size,
|
|
717
|
+
marix_ld,
|
|
718
|
+
out_mask,
|
|
719
|
+
mat_mask,
|
|
720
|
+
vec_mask,
|
|
721
|
+
mask_strides,
|
|
722
|
+
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
|
|
723
|
+
tid,
|
|
724
|
+
lid,
|
|
725
|
+
simd_gid,
|
|
726
|
+
simd_lid);
|
|
727
|
+
}
|
|
728
|
+
|
|
729
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
730
|
+
/// Vector matrix multiplication
|
|
731
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
732
|
+
|
|
733
|
+
template <
|
|
734
|
+
typename T,
|
|
735
|
+
typename out_mask_t,
|
|
736
|
+
typename op_mask_t,
|
|
737
|
+
const int BM, /* Threadgroup rows (in simdgroups) */
|
|
738
|
+
const int BN, /* Threadgroup cols (in simdgroups) */
|
|
739
|
+
const int SM, /* Simdgroup rows (in threads) */
|
|
740
|
+
const int SN, /* Simdgroup cols (in threads) */
|
|
741
|
+
const int TM, /* Thread rows (in elements) */
|
|
742
|
+
const int TN, /* Thread cols (in elements) */
|
|
743
|
+
const bool kDoNCBatch> /* Batch ndim > 1 */
|
|
744
|
+
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
|
|
745
|
+
const device T* mat [[buffer(0)]],
|
|
746
|
+
const device T* in_vec [[buffer(1)]],
|
|
747
|
+
device T* out_vec [[buffer(3)]],
|
|
748
|
+
const constant int& in_vec_size [[buffer(4)]],
|
|
749
|
+
const constant int& out_vec_size [[buffer(5)]],
|
|
750
|
+
const constant int& marix_ld [[buffer(6)]],
|
|
751
|
+
const constant int& batch_ndim [[buffer(9)]],
|
|
752
|
+
const constant int* batch_shape [[buffer(10)]],
|
|
753
|
+
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
|
754
|
+
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
|
755
|
+
const device out_mask_t* out_mask [[buffer(20)]],
|
|
756
|
+
const device op_mask_t* mat_mask [[buffer(21)]],
|
|
757
|
+
const device op_mask_t* vec_mask [[buffer(22)]],
|
|
758
|
+
const constant int* mask_strides [[buffer(23)]],
|
|
759
|
+
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
|
760
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
761
|
+
uint3 lid [[thread_position_in_threadgroup]],
|
|
762
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
763
|
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
764
|
+
using gemv_kernel =
|
|
765
|
+
GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
|
|
766
|
+
threadgroup float tgp_memory
|
|
767
|
+
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
|
|
768
|
+
|
|
769
|
+
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
|
|
770
|
+
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
|
|
771
|
+
|
|
772
|
+
// Update batch offsets
|
|
773
|
+
if (kDoNCBatch) {
|
|
774
|
+
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
|
|
775
|
+
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
|
|
776
|
+
|
|
777
|
+
if (has_output_mask) {
|
|
778
|
+
out_mask +=
|
|
779
|
+
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
|
|
780
|
+
mask_batch_strides += batch_ndim;
|
|
781
|
+
}
|
|
782
|
+
|
|
783
|
+
if (has_operand_mask) {
|
|
784
|
+
const constant auto* mask_strides_mat = mask_batch_strides;
|
|
785
|
+
const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
|
|
786
|
+
|
|
787
|
+
ulong2 batch_offsets = elem_to_loc_broadcast(
|
|
788
|
+
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
|
|
789
|
+
|
|
790
|
+
mat_mask += batch_offsets.x;
|
|
791
|
+
vec_mask += batch_offsets.y;
|
|
792
|
+
}
|
|
793
|
+
|
|
794
|
+
} else {
|
|
795
|
+
in_vec += tid.z * vector_batch_stride[0];
|
|
796
|
+
mat += tid.z * matrix_batch_stride[0];
|
|
797
|
+
|
|
798
|
+
if (has_output_mask) {
|
|
799
|
+
out_mask += tid.z * mask_batch_strides[0];
|
|
800
|
+
mask_batch_strides += batch_ndim;
|
|
801
|
+
}
|
|
802
|
+
|
|
803
|
+
if (has_operand_mask) {
|
|
804
|
+
mat_mask += tid.z * mask_batch_strides[0];
|
|
805
|
+
vec_mask += tid.z * mask_batch_strides[batch_ndim];
|
|
806
|
+
}
|
|
807
|
+
}
|
|
808
|
+
|
|
809
|
+
out_vec += tid.z * out_vec_size;
|
|
810
|
+
|
|
811
|
+
gemv_kernel::run(
|
|
812
|
+
mat,
|
|
813
|
+
in_vec,
|
|
814
|
+
out_vec,
|
|
815
|
+
in_vec_size,
|
|
816
|
+
out_vec_size,
|
|
817
|
+
marix_ld,
|
|
818
|
+
out_mask,
|
|
819
|
+
mat_mask,
|
|
820
|
+
vec_mask,
|
|
821
|
+
mask_strides,
|
|
822
|
+
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
|
|
823
|
+
tid,
|
|
824
|
+
lid,
|
|
825
|
+
simd_gid,
|
|
826
|
+
simd_lid);
|
|
827
|
+
}
|