mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/gemv_masked.h
@@ -0,0 +1,827 @@
+ // Copyright © 2023-2024 Apple Inc.
+
+ #include "mlx/backend/metal/kernels/steel/utils.h"
+
+ using namespace metal;
+
+ #define MLX_MTL_CONST static constant constexpr const
+ #define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
+ struct _NoMask {
+   char x;
+
+   constexpr METAL_FUNC operator bool() {
+     return true;
+   }
+   constexpr METAL_FUNC operator bool() const threadgroup {
+     return true;
+   }
+   constexpr METAL_FUNC operator bool() const device {
+     return true;
+   }
+   constexpr METAL_FUNC operator bool() const constant {
+     return true;
+   }
+ };
+
+ typedef struct _NoMask nomask_t;
+
+ template <typename OutT, typename InT = OutT>
+ struct ScaleOp {
+   OutT scale;
+
+   METAL_FUNC OutT apply(InT x) const {
+     return static_cast<OutT>(x) * scale;
+   }
+ };
+
+ template <
+     typename T,
+     typename out_mask_t,
+     typename op_mask_t,
+     const int BM, /* Threadgroup rows (in simdgroups) */
+     const int BN, /* Threadgroup cols (in simdgroups) */
+     const int SM, /* Simdgroup rows (in threads) */
+     const int SN, /* Simdgroup cols (in threads) */
+     const int TM, /* Thread rows (in elements) */
+     const int TN, /* Thread cols (in elements) */
+     typename AccT = float>
+ struct GEMVKernel {
+   MLX_MTL_CONST int threadsM = BM * SM;
+   MLX_MTL_CONST int threadsN = BN * SN;
+
+   MLX_MTL_CONST int blockM = threadsM * TM;
+   MLX_MTL_CONST int blockN = threadsN * TN;
+
+   static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
+
+   static_assert(
+       SN == 8 || SN == 16 || SN == 32,
+       "gemv block must have a width of 8, 16, or 32");
+
+   static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
+
+   MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+   MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+   MLX_MTL_CONST bool has_mul_operand_mask =
+       has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
+   MLX_MTL_CONST bool has_mul_output_mask =
+       has_output_mask && !metal::is_same_v<out_mask_t, bool>;
+
+   // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
+   //   into blocks of (blockM, blockN) divided among threadgroups
+   // - Every thread works on a block of (TM, TN)
+   // - We assume each threadgroup has (threadsN, threadsM, 1) threads
+   //
+   // 1. A thread loads TN elements each from mat along TM rows
+   //    and the corresponding scalar from the vector
+   // 2. The thread then multiplies and adds to accumulate its local result for
+   //    the block
+   // 3. At the end, each thread has accumulated results over all blocks across
+   //    the rows. These are then summed up across the threadgroup
+   // 4. Each threadgroup writes its accumulated blockM outputs
+   //
+   // Edge case handling:
+   // - The threadgroup with the largest tid has blocks that exceed the matrix
+   //   * The blocks that start outside the matrix are never read (thread results
+   //     remain zero)
+   //   * The last thread that partially overlaps with the matrix is shifted
+   //     inwards such that the thread block fits exactly in the matrix
+
+   MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
+   MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
+
+   template <typename U = T>
+   static METAL_FUNC void
+   load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
+     MLX_MTL_PRAGMA_UNROLL
+     for (int tn = 0; tn < TN; tn++) {
+       dst[tn] = static_cast<U>(src[src_offset + tn]);
+     }
+   }
+
+   template <typename U = T>
+   static METAL_FUNC void load_safe(
+       const device T* src,
+       thread U dst[TN],
+       const int src_offset = 0,
+       const int src_size = TN) {
+     if (src_offset + TN <= src_size) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tn = 0; tn < TN; tn++) {
+         dst[tn] = static_cast<U>(src[src_offset + tn]);
+       }
+     } else { // Edgecase
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tn = 0; tn < TN; tn++) {
+         dst[tn] = src_offset + tn < src_size
+             ? static_cast<U>(src[src_offset + tn])
+             : U(0);
+       }
+     }
+   }
+
+   static METAL_FUNC void run(
+       const device T* mat [[buffer(0)]],
+       const device T* in_vec [[buffer(1)]],
+       device T* out_vec [[buffer(3)]],
+       const constant int& in_vec_size [[buffer(4)]],
+       const constant int& out_vec_size [[buffer(5)]],
+       const constant int& matrix_ld [[buffer(6)]],
+       const device out_mask_t* out_mask [[buffer(20)]],
+       const device op_mask_t* mat_mask [[buffer(21)]],
+       const device op_mask_t* vec_mask [[buffer(22)]],
+       const constant int* mask_strides [[buffer(23)]],
+       threadgroup AccT* tgp_memory [[threadgroup(0)]],
+       uint3 tid [[threadgroup_position_in_grid]],
+       uint3 lid [[thread_position_in_threadgroup]],
+       uint simd_gid [[simdgroup_index_in_threadgroup]],
+       uint simd_lid [[thread_index_in_simdgroup]]) {
+     // Appease compiler
+     (void)lid;
+
+     // Thread local accumulation results
+     thread AccT result[TM] = {0};
+     thread T inter[TN];
+     thread AccT v_coeff[TN];
+
+     const int thrM = SN != 32 ? simd_lid / SN : 0;
+     const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
+
+     const int sgN = BN != 1 ? (simd_gid % BN) : 0;
+
+     const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
+     const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
+
+     int bm = (simdM + thrM) * TM;
+     int bn = (simdN + thrN) * TN;
+
+     // Block position
+     int out_row = tid.x * blockM + bm;
+
+     // Exit simdgroup if rows out of bound
+     if (out_row >= out_vec_size)
+       return;
+
+     // Adjust tail simdgroup to ensure in bound reads
+     out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
+
+     // Prepare mask offsets
+     const constant int* out_mask_strides = mask_strides;
+     const constant int* mat_mask_strides =
+         mask_strides + (has_output_mask ? 2 : 0);
+     const constant int* vec_mask_strides =
+         mat_mask_strides + (has_operand_mask ? 2 : 0);
+
+     const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
+
+     const int out_mask_offset =
+         !has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
+
+     int mat_mask_offset =
+         !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
+     int vec_mask_offset = 0;
+     const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
+     const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
+
+     T out_scale{1};
+
+     // Check output mask
+     if (has_output_mask) {
+       auto mask_out = out_mask[out_mask_offset];
+
+       // Write zeros and return if mask is 0
+       if (!mask_out) {
+         if (simdN == 0 && thrN == 0) {
+           MLX_MTL_PRAGMA_UNROLL
+           for (int tm = 0; tm < TM; tm++) {
+             out_vec[out_row + tm] = T(0.);
+           }
+         }
+
+         return;
+       }
+
+       // Store scalar if multiplicative mask
+       if (has_mul_output_mask) {
+         out_scale = T(mask_out);
+       }
+     }
+
+     // Advance matrix
+     mat += out_row * matrix_ld;
+
+     // Prepare for loop
+     constexpr const uniform<int> loop_stride = make_uniform(blockN);
+     const uniform<int> in_size = make_uniform(in_vec_size);
+     const uniform<int> n_iter = in_size / loop_stride;
+     const uniform<int> last_iter = loop_stride * n_iter;
+     const uniform<int> leftover = in_size - last_iter;
+
+     // Loop over in_vec in blocks of blockN
+     for (int i = 0; i < n_iter; ++i) {
+       if (!has_operand_mask ||
+           (bool(mat_mask[mat_mask_offset]) &&
+            bool(vec_mask[vec_mask_offset]))) {
+         T block_scale{1};
+         if (has_mul_operand_mask) {
+           block_scale =
+               T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
+         }
+
+         load_unsafe<AccT>(in_vec, v_coeff, bn);
+
+         // Apply scale
+         if (has_mul_operand_mask) {
+           MLX_MTL_PRAGMA_UNROLL
+           for (int tn = 0; tn < TN; tn++) {
+             v_coeff[tn] *= block_scale;
+           }
+         }
+
+         // Per thread work loop
+         int mat_offset = 0;
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tm = 0; tm < TM; tm++) {
+           // Load for the row
+           load_unsafe(mat, inter, mat_offset + bn);
+
+           // Accumulate results
+           MLX_MTL_PRAGMA_UNROLL
+           for (int tn = 0; tn < TN; tn++) {
+             result[tm] += inter[tn] * v_coeff[tn];
+           }
+
+           mat_offset += matrix_ld;
+         }
+       }
+
+       bn += blockN;
+       mat_mask_offset += mat_mask_step;
+       vec_mask_offset += vec_mask_step;
+     }
+
+     if (leftover > 0) {
+       if (!has_operand_mask ||
+           (bool(mat_mask[mat_mask_offset]) &&
+            bool(vec_mask[vec_mask_offset]))) {
+         T block_scale{1};
+         if (has_mul_operand_mask) {
+           block_scale =
+               T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
+         }
+
+         load_safe<AccT>(in_vec, v_coeff, bn, in_size);
+
+         // Apply scale
+         if (has_mul_operand_mask) {
+           MLX_MTL_PRAGMA_UNROLL
+           for (int tn = 0; tn < TN; tn++) {
+             v_coeff[tn] *= block_scale;
+           }
+         }
+
+         // Per thread work loop
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tm = 0; tm < TM; tm++) {
+           // Load for the row
+           load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
+
+           // Accumulate results
+           MLX_MTL_PRAGMA_UNROLL
+           for (int tn = 0; tn < TN; tn++) {
+             result[tm] += inter[tn] * v_coeff[tn];
+           }
+         }
+       }
+     }
+
+     // Apply out scale
+     if (has_mul_output_mask) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tm = 0; tm < TM; tm++) {
+         result[tm] *= out_scale;
+       }
+     }
+
+     // Simdgroup accumulations
+     MLX_MTL_PRAGMA_UNROLL
+     for (int tm = 0; tm < TM; tm++) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
+         result[tm] += simd_shuffle_down(result[tm], sn);
+       }
+     }
+
+     // Threadgroup accumulation results
+     if (needs_tgp_reduction) {
+       threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
+       if (thrN == 0) {
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tm = 0; tm < TM; tm++) {
+           tgp_results[tm] = result[tm];
+         }
+
+         threadgroup_barrier(mem_flags::mem_none);
+
+         if (sgN == 0) {
+           MLX_MTL_PRAGMA_UNROLL
+           for (int sgn = 1; sgn < BN; sgn++) {
+             MLX_MTL_PRAGMA_UNROLL
+             for (int tm = 0; tm < TM; tm++) {
+               result[tm] += tgp_results[sgn * (blockM + TM) + tm];
+             }
+           }
+         }
+       }
+     }
+
+     // Write outputs
+     if (simdN == 0 && thrN == 0) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tm = 0; tm < TM; tm++) {
+         out_vec[out_row + tm] = static_cast<T>(result[tm]);
+       }
+     }
+   }
+ };
+
+ ///////////////////////////////////////////////////////////////////////////////
+ /// Vector matrix multiplication
+ ///////////////////////////////////////////////////////////////////////////////
+
+ template <
+     typename T,
+     typename out_mask_t,
+     typename op_mask_t,
+     const int BM, /* Threadgroup rows (in simdgroups) */
+     const int BN, /* Threadgroup cols (in simdgroups) */
+     const int SM, /* Simdgroup rows (in threads) */
+     const int SN, /* Simdgroup cols (in threads) */
+     const int TM, /* Thread rows (in elements) */
+     const int TN, /* Thread cols (in elements) */
+     typename AccT = float>
+ struct GEMVTKernel {
+   MLX_MTL_CONST int threadsM = BM * SM;
+   MLX_MTL_CONST int threadsN = BN * SN;
+
+   MLX_MTL_CONST int blockM = threadsM * TM;
+   MLX_MTL_CONST int blockN = threadsN * TN;
+
+   static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
+
+   MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+   MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+   MLX_MTL_CONST bool has_mul_operand_mask =
+       has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
+   MLX_MTL_CONST bool has_mul_output_mask =
+       has_output_mask && !metal::is_same_v<out_mask_t, bool>;
+
+   // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
+   //   into blocks of (blockM, blockN) divided among threadgroups
+   // - Every thread works on a block of (TM, TN)
+   // - We assume each threadgroup has (threadsN, threadsM, 1) threads
+   //
+   // 1. A thread loads TN elements each from mat along TM contiguous rows
+   //    and the corresponding scalar from the vector
+   // 2. The thread then accumulates its local result for the block
+   // 3. At the end, each thread has accumulated results over all blocks across
+   //    the rows. These are then summed up across the threadgroup
+   // 4. Each threadgroup writes its accumulated BN * TN outputs
+   //
+   // Edge case handling:
+   // - The threadgroup with the largest tid has blocks that exceed the matrix
+   //   * The blocks that start outside the matrix are never read (thread results
+   //     remain zero)
+   //   * The last thread that partially overlaps with the matrix is shifted
+   //     inwards such that the thread block fits exactly in the matrix
+
+   MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
+   MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
+
+   static METAL_FUNC void run(
+       const device T* mat [[buffer(0)]],
+       const device T* in_vec [[buffer(1)]],
+       device T* out_vec [[buffer(3)]],
+       const constant int& in_vec_size [[buffer(4)]],
+       const constant int& out_vec_size [[buffer(5)]],
+       const constant int& marix_ld [[buffer(6)]],
+       const device out_mask_t* out_mask [[buffer(20)]],
+       const device op_mask_t* mat_mask [[buffer(21)]],
+       const device op_mask_t* vec_mask [[buffer(22)]],
+       const constant int* mask_strides [[buffer(23)]],
+       threadgroup AccT* tgp_memory [[threadgroup(0)]],
+       uint3 tid [[threadgroup_position_in_grid]],
+       uint3 lid [[thread_position_in_threadgroup]],
+       uint simd_gid [[simdgroup_index_in_threadgroup]],
+       uint simd_lid [[thread_index_in_simdgroup]]) {
+     // Appease compiler
+     (void)lid;
+
+     // Thread local accumulation results
+     AccT result[TN] = {0};
+     T inter[TN];
+     AccT v_coeff[TM];
+
+     const int thrM = SN != 32 ? simd_lid / SN : 0;
+     const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
+
+     const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
+     const int sgN = BN != 1 ? (simd_gid % BN) : 0;
+
+     const int simdM = SM * sgM;
+     const int simdN = SN * sgN;
+
+     int cm = (simdM + thrM);
+     int cn = (simdN + thrN);
+
+     int bm = cm * TM;
+     int bn = cn * TN;
+
+     int out_col = tid.x * blockN + bn;
+
+     // Prepare mask offsets
+     const constant int* out_mask_strides = mask_strides;
+     const constant int* mat_mask_strides =
+         out_mask_strides + (has_output_mask ? 2 : 0);
+     const constant int* vec_mask_strides =
+         mat_mask_strides + (has_operand_mask ? 2 : 0);
+
+     const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
+
+     const int out_mask_offset =
+         !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
+
+     int mat_mask_offset =
+         !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
+     int vec_mask_offset = 0;
+     const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
+     const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
+
+     T out_scale{1};
+
+     // Check output mask
+     if (has_output_mask) {
+       auto mask_out = out_mask[out_mask_offset];
+
+       // Write zeros and return if mask is 0
+       if (!mask_out) {
+         if (cm == 0 && out_col < out_vec_size) {
+           if (out_col + TN <= out_vec_size) {
+             MLX_MTL_PRAGMA_UNROLL
+             for (int tn = 0; tn < TN; tn++) {
+               out_vec[out_col + tn] = T(0.);
+             }
+           } else {
+             for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
+               out_vec[out_col + tn] = T(0.);
+             }
+           }
+         }
+
+         return;
+       }
+
+       // Store scalar if multiplicative mask
+       if (has_mul_output_mask) {
+         out_scale = T(mask_out);
+       }
+     }
+
+     // Prepare for loop
+     constexpr const uniform<int> loop_stride = make_uniform(blockM);
+     const uniform<int> in_size = make_uniform(in_vec_size);
+     const uniform<int> n_iter = in_size / loop_stride;
+     const uniform<int> last_iter = loop_stride * n_iter;
+     const uniform<int> leftover = in_size - last_iter;
+
+     // Edgecase handling
+     if (out_col < out_vec_size) {
+       out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
+
+       // Per thread accumulation main loop
+       for (int i = 0; i < n_iter; ++i) {
+         // Adding a threadgroup_barrier improves performance slightly
+         // This is possibly it may help exploit cache better
+         threadgroup_barrier(mem_flags::mem_none);
+
+         if (!has_operand_mask ||
+             (bool(mat_mask[mat_mask_offset]) &&
+              bool(vec_mask[vec_mask_offset]))) {
+           T block_scale{1};
+           if (has_mul_operand_mask) {
+             block_scale =
+                 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
+           }
+
+           MLX_MTL_PRAGMA_UNROLL
+           for (int tm = 0; tm < TM; tm++) {
+             v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
+           }
+
+           // Apply scale
+           if (has_mul_operand_mask) {
+             MLX_MTL_PRAGMA_UNROLL
+             for (int tm = 0; tm < TM; tm++) {
+               v_coeff[tm] *= block_scale;
+             }
+           }
+
+           MLX_MTL_PRAGMA_UNROLL
+           for (int tm = 0; tm < TM; tm++) {
+             for (int tn = 0; tn < TN; tn++) {
+               inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
+             }
+             for (int tn = 0; tn < TN; tn++) {
+               result[tn] += v_coeff[tm] * inter[tn];
+             }
+           }
+         }
+
+         bm += blockM;
+         mat_mask_offset += mat_mask_step;
+         vec_mask_offset += vec_mask_step;
+       }
+
+       if (leftover > 0) {
+         if (!has_operand_mask ||
+             (bool(mat_mask[mat_mask_offset]) &&
+              bool(vec_mask[vec_mask_offset]))) {
+           T block_scale{1};
+           if (has_mul_operand_mask) {
+             block_scale =
+                 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
+           }
+
+           for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
+             v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
+
+             if (has_mul_operand_mask) {
+               v_coeff[tm] *= block_scale;
+             }
+
+             MLX_MTL_PRAGMA_UNROLL
+             for (int tn = 0; tn < TN; tn++) {
+               inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
+             }
+
+             MLX_MTL_PRAGMA_UNROLL
+             for (int tn = 0; tn < TN; tn++) {
+               result[tn] += v_coeff[tm] * inter[tn];
+             }
+           }
+         }
+       }
+     }
+
+     // Apply out scale
+     if (has_mul_output_mask) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tn = 0; tn < TN; tn++) {
+         result[tn] *= out_scale;
+       }
+     }
+
+     // Simdgroup accumulations
+     MLX_MTL_PRAGMA_UNROLL
+     for (int tn = 0; tn < TN; tn++) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
+         result[tn] += simd_shuffle_down(result[tn], SN * sm);
+       }
+     }
+
+     // Threadgroup accumulation results
+     if (needs_tgp_reduction) {
+       threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
+       if (thrM == 0) {
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tn = 0; tn < TN; tn++) {
+           tgp_results[tn] = result[tn];
+         }
+
+         threadgroup_barrier(mem_flags::mem_none);
+
+         if (sgM == 0) {
+           MLX_MTL_PRAGMA_UNROLL
+           for (int sgm = 1; sgm < BM; sgm++) {
+             MLX_MTL_PRAGMA_UNROLL
+             for (int tn = 0; tn < TN; tn++) {
+               result[tn] += tgp_results[sgm * (blockN + TN) + tn];
+             }
+           }
+         }
+       }
+     }
+
+     // Threadgroup accumulation and writing out results
+     if (cm == 0 && out_col < out_vec_size) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (int j = 0; j < TN; j++) {
+         out_vec[out_col + j] = static_cast<T>(result[j]);
+       }
+     }
+   }
+ };
+
+ ///////////////////////////////////////////////////////////////////////////////
+ /// Matrix vector multiplication
+ ///////////////////////////////////////////////////////////////////////////////
+
+ template <
+     typename T,
+     typename out_mask_t,
+     typename op_mask_t,
+     const int BM, /* Threadgroup rows (in simdgroups) */
+     const int BN, /* Threadgroup cols (in simdgroups) */
+     const int SM, /* Simdgroup rows (in threads) */
+     const int SN, /* Simdgroup cols (in threads) */
+     const int TM, /* Thread rows (in elements) */
+     const int TN, /* Thread cols (in elements) */
+     const bool kDoNCBatch> /* Batch ndim > 1 */
+ [[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
+     const device T* mat [[buffer(0)]],
+     const device T* in_vec [[buffer(1)]],
+     device T* out_vec [[buffer(3)]],
+     const constant int& in_vec_size [[buffer(4)]],
+     const constant int& out_vec_size [[buffer(5)]],
+     const constant int& marix_ld [[buffer(6)]],
+     const constant int& batch_ndim [[buffer(9)]],
+     const constant int* batch_shape [[buffer(10)]],
+     const constant int64_t* vector_batch_stride [[buffer(11)]],
+     const constant int64_t* matrix_batch_stride [[buffer(12)]],
+     const device out_mask_t* out_mask [[buffer(20)]],
+     const device op_mask_t* mat_mask [[buffer(21)]],
+     const device op_mask_t* vec_mask [[buffer(22)]],
+     const constant int* mask_strides [[buffer(23)]],
+     const constant int64_t* mask_batch_strides [[buffer(24)]],
+     uint3 tid [[threadgroup_position_in_grid]],
+     uint3 lid [[thread_position_in_threadgroup]],
+     uint simd_gid [[simdgroup_index_in_threadgroup]],
+     uint simd_lid [[thread_index_in_simdgroup]]) {
+   using gemv_kernel =
+       GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
+   threadgroup float tgp_memory
+       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
+
+   constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+   constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+   // Update batch offsets
+   if (kDoNCBatch) {
+     in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
+     mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
+
+     if (has_output_mask) {
+       out_mask +=
+           elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
+       mask_batch_strides += batch_ndim;
+     }
+
+     if (has_operand_mask) {
+       const constant auto* mask_strides_mat = mask_batch_strides;
+       const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
+
+       ulong2 batch_offsets = elem_to_loc_broadcast(
+           tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
+
+       mat_mask += batch_offsets.x;
+       vec_mask += batch_offsets.y;
+     }
+
+   } else {
+     in_vec += tid.z * vector_batch_stride[0];
+     mat += tid.z * matrix_batch_stride[0];
+
+     if (has_output_mask) {
+       out_mask += tid.z * mask_batch_strides[0];
+       mask_batch_strides += batch_ndim;
+     }
+
+     if (has_operand_mask) {
+       mat_mask += tid.z * mask_batch_strides[0];
+       vec_mask += tid.z * mask_batch_strides[batch_ndim];
+     }
+   }
+
+   out_vec += tid.z * out_vec_size;
+
+   gemv_kernel::run(
+       mat,
+       in_vec,
+       out_vec,
+       in_vec_size,
+       out_vec_size,
+       marix_ld,
+       out_mask,
+       mat_mask,
+       vec_mask,
+       mask_strides,
+       gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
+       tid,
+       lid,
+       simd_gid,
+       simd_lid);
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////
+ /// Vector matrix multiplication
+ ///////////////////////////////////////////////////////////////////////////////
+
+ template <
+     typename T,
+     typename out_mask_t,
+     typename op_mask_t,
+     const int BM, /* Threadgroup rows (in simdgroups) */
+     const int BN, /* Threadgroup cols (in simdgroups) */
+     const int SM, /* Simdgroup rows (in threads) */
+     const int SN, /* Simdgroup cols (in threads) */
+     const int TM, /* Thread rows (in elements) */
+     const int TN, /* Thread cols (in elements) */
+     const bool kDoNCBatch> /* Batch ndim > 1 */
+ [[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
+     const device T* mat [[buffer(0)]],
+     const device T* in_vec [[buffer(1)]],
+     device T* out_vec [[buffer(3)]],
+     const constant int& in_vec_size [[buffer(4)]],
+     const constant int& out_vec_size [[buffer(5)]],
+     const constant int& marix_ld [[buffer(6)]],
+     const constant int& batch_ndim [[buffer(9)]],
+     const constant int* batch_shape [[buffer(10)]],
+     const constant int64_t* vector_batch_stride [[buffer(11)]],
+     const constant int64_t* matrix_batch_stride [[buffer(12)]],
+     const device out_mask_t* out_mask [[buffer(20)]],
+     const device op_mask_t* mat_mask [[buffer(21)]],
+     const device op_mask_t* vec_mask [[buffer(22)]],
+     const constant int* mask_strides [[buffer(23)]],
+     const constant int64_t* mask_batch_strides [[buffer(24)]],
+     uint3 tid [[threadgroup_position_in_grid]],
+     uint3 lid [[thread_position_in_threadgroup]],
+     uint simd_gid [[simdgroup_index_in_threadgroup]],
+     uint simd_lid [[thread_index_in_simdgroup]]) {
+   using gemv_kernel =
+       GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
+   threadgroup float tgp_memory
+       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
+
+   constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
+   constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
+
+   // Update batch offsets
+   if (kDoNCBatch) {
+     in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
+     mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
+
+     if (has_output_mask) {
+       out_mask +=
+           elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
+       mask_batch_strides += batch_ndim;
+     }
+
+     if (has_operand_mask) {
+       const constant auto* mask_strides_mat = mask_batch_strides;
+       const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
+
+       ulong2 batch_offsets = elem_to_loc_broadcast(
+           tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
+
+       mat_mask += batch_offsets.x;
+       vec_mask += batch_offsets.y;
+     }
+
+   } else {
+     in_vec += tid.z * vector_batch_stride[0];
+     mat += tid.z * matrix_batch_stride[0];
+
+     if (has_output_mask) {
+       out_mask += tid.z * mask_batch_strides[0];
+       mask_batch_strides += batch_ndim;
+     }
+
+     if (has_operand_mask) {
+       mat_mask += tid.z * mask_batch_strides[0];
+       vec_mask += tid.z * mask_batch_strides[batch_ndim];
+     }
+   }
+
+   out_vec += tid.z * out_vec_size;
+
+   gemv_kernel::run(
+       mat,
+       in_vec,
+       out_vec,
+       in_vec_size,
+       out_vec_size,
+       marix_ld,
+       out_mask,
+       mat_mask,
+       vec_mask,
+       mask_strides,
+       gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
+       tid,
+       lid,
+       simd_gid,
+       simd_lid);
+ }
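
The comment block inside GEMVKernel above describes how the masked GEMV tiles work: a threadgroup of BM x BN simdgroups (32 threads each, arranged SM x SN) covers blockM output rows, and each thread accumulates a TM x TN sub-tile before the simdgroup and threadgroup reductions. The standalone C++ sketch below (not part of the wheel) simply replays that index arithmetic on the host so the mapping from `simd_gid`/`simd_lid` to row/column offsets is easier to see; the parameters BM=2, BN=2, SM=8, SN=4, TM=4, TN=4 are hypothetical values chosen only so that `SM * SN == 32`, as the kernel's static_assert requires.

```cpp
// Host-side illustration of GEMVKernel's thread-to-tile mapping.
// Assumed, illustrative tile parameters; real dispatch values are chosen
// by the MLX Metal backend at runtime.
#include <cstdio>

int main() {
  constexpr int BM = 2, BN = 2, SM = 8, SN = 4, TM = 4, TN = 4;
  constexpr int blockM = BM * SM * TM; // output rows covered by one threadgroup
  constexpr int blockN = BN * SN * TN; // input columns consumed per iteration

  // Each of the BM * BN simdgroups has SM * SN == 32 threads.
  for (int simd_gid = 0; simd_gid < BM * BN; ++simd_gid) {
    for (int simd_lid = 0; simd_lid < SM * SN; ++simd_lid) {
      // Same expressions as in GEMVKernel::run.
      const int thrM = SN != 32 ? simd_lid / SN : 0;
      const int thrN = SN != 32 ? simd_lid % SN : simd_lid;
      const int simdM = BN != 1 ? SM * (simd_gid / BN) : SM * simd_gid;
      const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
      const int bm = (simdM + thrM) * TM; // first of TM output rows
      const int bn = (simdN + thrN) * TN; // first of TN input columns
      if (simd_lid < 2) { // show a couple of threads per simdgroup
        std::printf("simd_gid=%d simd_lid=%d -> rows [%d, %d), cols [%d, %d)\n",
                    simd_gid, simd_lid, bm, bm + TM, bn, bn + TN);
      }
    }
  }
  std::printf("threadgroup tile: blockM=%d rows, blockN=%d cols\n",
              blockM, blockN);
}
```

GEMVTKernel uses the same arithmetic with the roles of rows and columns swapped (it strides the vector by blockM and writes TN output columns per thread), which is why the two structs share the mask-offset and reduction logic almost line for line.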