mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/quantized.h
@@ -0,0 +1,2502 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <metal_simdgroup>
4
+ #include <metal_stdlib>
5
+
6
+ constant bool align_M [[function_constant(200)]];
7
+ constant bool align_N [[function_constant(201)]];
8
+ constant bool align_K [[function_constant(202)]];
9
+
10
+ using namespace metal;
11
+
12
+ #define MLX_MTL_CONST static constant constexpr const
13
+
14
+ MLX_MTL_CONST int SIMD_SIZE = 32;
15
+ MLX_MTL_CONST int QUAD_SIZE = 4;
16
+
17
+ template <int bits, int wsize = 8>
18
+ inline constexpr short get_pack_factor() {
19
+ return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
20
+ }
21
+
22
+ template <int bits, int wsize = 8>
23
+ inline constexpr short get_bytes_per_pack() {
24
+ constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
25
+ return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
26
+ }
27
+
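For reference, the two helpers above determine how many quantized values live in each packed unit and how many bytes that unit occupies. The sketch below is a small host-side C++ mirror of the templates for the byte-oriented case (wsize = 8); the numbers are re-derived from the code above rather than quoted from MLX documentation.

```cpp
#include <cstdio>

// Host-side re-derivation of the packing constants used throughout this
// header (wsize is 8 for the byte-oriented loaders, 32 for the qmv paths).
constexpr int pack_factor(int bits, int wsize) {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

constexpr int bytes_per_pack(int bits, int wsize) {
  return ((bits & (bits - 1)) == 0) ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

int main() {
  const int all_bits[] = {2, 3, 4, 5, 6, 8};
  for (int b : all_bits) {
    std::printf("bits=%d  pack_factor=%d  bytes_per_pack=%d\n",
                b, pack_factor(b, 8), bytes_per_pack(b, 8));
  }
  return 0;
}
```

So 2-, 4- and 8-bit weights pack evenly into single bytes, while 3-, 5- and 6-bit weights are stored as 8 values per 3 bytes, 8 values per 5 bytes, and 4 values per 3 bytes respectively.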
28
+ template <typename T, typename U, int values_per_thread, int bits>
29
+ inline U load_vector(const device T* x, thread U* x_thread) {
30
+ static_assert(
31
+ bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
32
+ bits == 8,
33
+ "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
34
+
35
+ U sum = 0;
36
+
37
+ if (bits == 2) {
38
+ for (int i = 0; i < values_per_thread; i += 4) {
39
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
40
+ x_thread[i] = x[i];
41
+ x_thread[i + 1] = x[i + 1] / 4.0f;
42
+ x_thread[i + 2] = x[i + 2] / 16.0f;
43
+ x_thread[i + 3] = x[i + 3] / 64.0f;
44
+ }
45
+ }
46
+
47
+ else if (bits == 3) {
48
+ for (int i = 0; i < values_per_thread; i += 8) {
49
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
50
+ x[i + 6] + x[i + 7];
51
+ x_thread[i] = x[i];
52
+ x_thread[i + 1] = x[i + 1] / 8.0f;
53
+ x_thread[i + 2] = x[i + 2] / 64.0f;
54
+ x_thread[i + 3] = x[i + 3] / 2.0f;
55
+ x_thread[i + 4] = x[i + 4] / 16.0f;
56
+ x_thread[i + 5] = x[i + 5] / 128.0f;
57
+ x_thread[i + 6] = x[i + 6] / 4.0f;
58
+ x_thread[i + 7] = x[i + 7] / 32.0f;
59
+ }
60
+ }
61
+
62
+ else if (bits == 4) {
63
+ for (int i = 0; i < values_per_thread; i += 4) {
64
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
65
+ x_thread[i] = x[i];
66
+ x_thread[i + 1] = x[i + 1] / 16.0f;
67
+ x_thread[i + 2] = x[i + 2] / 256.0f;
68
+ x_thread[i + 3] = x[i + 3] / 4096.0f;
69
+ }
70
+ }
71
+
72
+ else if (bits == 5) {
73
+ for (int i = 0; i < values_per_thread; i += 8) {
74
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
75
+ x[i + 6] + x[i + 7];
76
+ x_thread[i] = x[i];
77
+ x_thread[i + 1] = x[i + 1] / 32.0f;
78
+ x_thread[i + 2] = x[i + 2] / 4.0f;
79
+ x_thread[i + 3] = x[i + 3] / 128.0f;
80
+ x_thread[i + 4] = x[i + 4] / 16.0f;
81
+ x_thread[i + 5] = x[i + 5] / 2.0f;
82
+ x_thread[i + 6] = x[i + 6] / 64.0f;
83
+ x_thread[i + 7] = x[i + 7] / 8.0f;
84
+ }
85
+ }
86
+
87
+ else if (bits == 6) {
88
+ for (int i = 0; i < values_per_thread; i += 4) {
89
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
90
+ x_thread[i] = x[i];
91
+ x_thread[i + 1] = x[i + 1] / 64.0f;
92
+ x_thread[i + 2] = x[i + 2] / 16.0f;
93
+ x_thread[i + 3] = x[i + 3] / 4.0f;
94
+ }
95
+ }
96
+
97
+ else if (bits == 8) {
98
+ for (int i = 0; i < values_per_thread; i++) {
99
+ sum += x[i];
100
+ x_thread[i] = x[i];
101
+ }
102
+ }
103
+
104
+ return sum;
105
+ }
106
+
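A detail worth noting in load_vector: for the sub-byte widths it pre-divides each activation by the positional weight of the field it will later be multiplied against (1, 4, 16, 64 for 2-bit, and so on), which lets qdot below multiply by masked but unshifted weight bits. A minimal standalone check of that identity for the 2-bit layout; plain C++ with arbitrary sample values, not code from this package:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // One packed byte holds four 2-bit weights in bit fields 0-1, 2-3, 4-5, 6-7.
  uint8_t w = 0x9c; // bit pattern 10 01 11 00 -> q3=2, q2=1, q1=3, q0=0
  float x[4] = {1.5f, -2.0f, 0.25f, 4.0f};

  // Reference: unpack with shifts, then take the dot product.
  float ref = 0.0f;
  for (int j = 0; j < 4; ++j) {
    ref += x[j] * ((w >> (2 * j)) & 0x3);
  }

  // Kernel-style: pre-scale x by 1, 1/4, 1/16, 1/64 and multiply by the
  // masked-but-unshifted field, exactly as load_vector + qdot do for bits == 2.
  float x_thread[4] = {x[0], x[1] / 4.0f, x[2] / 16.0f, x[3] / 64.0f};
  float got = x_thread[0] * (w & 0x03) + x_thread[1] * (w & 0x0c) +
              x_thread[2] * (w & 0x30) + x_thread[3] * (w & 0xc0);

  // (w & 0x0c) == 4 * q1, (w & 0x30) == 16 * q2, etc., so the factors cancel
  // and both computations agree exactly for these power-of-two scalings.
  assert(ref == got);
  return 0;
}
```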
107
+ template <typename T, typename U, int values_per_thread, int bits>
108
+ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
109
+ static_assert(
110
+ bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
111
+ bits == 8,
112
+ "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
113
+
114
+ U sum = 0;
115
+
116
+ if (bits == 2) {
117
+ for (int i = 0; i < N; i += 4) {
118
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
119
+ x_thread[i] = x[i];
120
+ x_thread[i + 1] = x[i + 1] / 4.0f;
121
+ x_thread[i + 2] = x[i + 2] / 16.0f;
122
+ x_thread[i + 3] = x[i + 3] / 64.0f;
123
+ }
124
+ }
125
+
126
+ else if (bits == 3) {
127
+ for (int i = 0; i < N; i += 8) {
128
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
129
+ x[i + 6] + x[i + 7];
130
+
131
+ x_thread[i] = x[i];
132
+ x_thread[i + 1] = x[i + 1] / 8.0f;
133
+ x_thread[i + 2] = x[i + 2] / 64.0f;
134
+ x_thread[i + 3] = x[i + 3] / 2.0f;
135
+ x_thread[i + 4] = x[i + 4] / 16.0f;
136
+ x_thread[i + 5] = x[i + 5] / 128.0f;
137
+ x_thread[i + 6] = x[i + 6] / 4.0f;
138
+ x_thread[i + 7] = x[i + 7] / 32.0f;
139
+ }
140
+ }
141
+
142
+ else if (bits == 4) {
143
+ for (int i = 0; i < N; i += 4) {
144
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
145
+ x_thread[i] = x[i];
146
+ x_thread[i + 1] = x[i + 1] / 16.0f;
147
+ x_thread[i + 2] = x[i + 2] / 256.0f;
148
+ x_thread[i + 3] = x[i + 3] / 4096.0f;
149
+ }
150
+ }
151
+
152
+ else if (bits == 5) {
153
+ for (int i = 0; i < N; i += 8) {
154
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
155
+ x[i + 6] + x[i + 7];
156
+ x_thread[i] = x[i];
157
+ x_thread[i + 1] = x[i + 1] / 32.0f;
158
+ x_thread[i + 2] = x[i + 2] / 4.0f;
159
+ x_thread[i + 3] = x[i + 3] / 128.0f;
160
+ x_thread[i + 4] = x[i + 4] / 16.0f;
161
+ x_thread[i + 5] = x[i + 5] / 2.0f;
162
+ x_thread[i + 6] = x[i + 6] / 64.0f;
163
+ x_thread[i + 7] = x[i + 7] / 8.0f;
164
+ }
165
+ }
166
+
167
+ else if (bits == 6) {
168
+ for (int i = 0; i < N; i += 4) {
169
+ sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
170
+ x_thread[i] = x[i];
171
+ x_thread[i + 1] = x[i + 1] / 64.0f;
172
+ x_thread[i + 2] = x[i + 2] / 16.0f;
173
+ x_thread[i + 3] = x[i + 3] / 4.0f;
174
+ }
175
+ }
176
+
177
+ else if (bits == 8) {
178
+ for (int i = 0; i < N; i++) {
179
+ sum += x[i];
180
+ x_thread[i] = x[i];
181
+ }
182
+ }
183
+
184
+ for (int i = N; i < values_per_thread; i++) {
185
+ x_thread[i] = 0;
186
+ }
187
+
188
+ return sum;
189
+ }
190
+
191
+ template <typename U, int values_per_thread, int bits>
192
+ inline U qdot(
193
+ const device uint8_t* w,
194
+ const thread U* x_thread,
195
+ U scale,
196
+ U bias,
197
+ U sum) {
198
+ static_assert(
199
+ bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
200
+ bits == 8,
201
+ "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
202
+
203
+ U accum = 0;
204
+
205
+ if (bits == 2) {
206
+ for (int i = 0; i < (values_per_thread / 4); i++) {
207
+ accum +=
208
+ (x_thread[4 * i] * (w[i] & 0x03) +
209
+ x_thread[4 * i + 1] * (w[i] & 0x0c) +
210
+ x_thread[4 * i + 2] * (w[i] & 0x30) +
211
+ x_thread[4 * i + 3] * (w[i] & 0xc0));
212
+ }
213
+ }
214
+
215
+ else if (bits == 3) {
216
+ for (int i = 0; i < (values_per_thread / 8); i++) {
217
+ x_thread += 8 * i;
218
+ w += 3 * i;
219
+
220
+ accum += (w[0] & 0x07) * x_thread[0];
221
+ accum += (w[0] & 0x38) * x_thread[1];
222
+ accum += (w[0] & 0xc0) * x_thread[2];
223
+ accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
224
+
225
+ accum += (w[1] & 0x0e) * x_thread[3];
226
+ accum += (w[1] & 0x70) * x_thread[4];
227
+ accum += (w[1] & 0x80) * x_thread[5];
228
+ accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
229
+
230
+ accum += (w[2] & 0x1c) * x_thread[6];
231
+ accum += (w[2] & 0xe0) * x_thread[7];
232
+ }
233
+ }
234
+
235
+ else if (bits == 4) {
236
+ const device uint16_t* ws = (const device uint16_t*)w;
237
+ for (int i = 0; i < (values_per_thread / 4); i++) {
238
+ accum +=
239
+ (x_thread[4 * i] * (ws[i] & 0x000f) +
240
+ x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
241
+ x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
242
+ x_thread[4 * i + 3] * (ws[i] & 0xf000));
243
+ }
244
+ }
245
+
246
+ else if (bits == 5) {
247
+ for (int i = 0; i < (values_per_thread / 8); i++) {
248
+ x_thread += 8 * i;
249
+ w += 5 * i;
250
+
251
+ accum += (w[0] & 0x1f) * x_thread[0];
252
+ accum += (w[0] & 0xe0) * x_thread[1];
253
+ accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
254
+ accum += (w[1] & 0x7c) * x_thread[2];
255
+ accum += (w[1] & 0x80) * x_thread[3];
256
+ accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
257
+ accum += (w[2] & 0xf0) * x_thread[4];
258
+ accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
259
+ accum += (w[3] & 0x3e) * x_thread[5];
260
+ accum += (w[3] & 0xc0) * x_thread[6];
261
+ accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
262
+ accum += (w[4] & 0xf8) * x_thread[7];
263
+ }
264
+ }
265
+
266
+ else if (bits == 6) {
267
+ for (int i = 0; i < (values_per_thread / 4); i++) {
268
+ x_thread += 4 * i;
269
+ w += 3 * i;
270
+
271
+ accum += (w[0] & 0x3f) * x_thread[0];
272
+
273
+ accum += (w[0] & 0xc0) * x_thread[1];
274
+ accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
275
+
276
+ accum += (w[1] & 0xf0) * x_thread[2];
277
+ accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
278
+
279
+ accum += (w[2] & 0xfc) * x_thread[3];
280
+ }
281
+ }
282
+
283
+ else if (bits == 8) {
284
+ for (int i = 0; i < values_per_thread; i++) {
285
+ accum += x_thread[i] * w[i];
286
+ }
287
+ }
288
+
289
+ return scale * accum + sum * bias;
290
+ }
291
+
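The return value ties the two running sums together: with affine quantization w_i = scale * q_i + bias, a group's dot product is scale * sum(x_i * q_i) + bias * sum(x_i), which is why load_vector returns the plain activation sum and qdot ends with scale * accum + sum * bias. A quick host-side check of that identity for the 8-bit branch (illustrative C++, arbitrary values):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const int n = 8;
  uint8_t q[n] = {0, 17, 255, 3, 128, 9, 64, 200};
  float x[n] = {0.5f, -1.0f, 2.0f, 0.125f, -0.25f, 3.0f, 1.0f, -2.0f};
  float scale = 0.02f, bias = -1.3f;

  // Reference: dequantize each weight, then take the dot product.
  float ref = 0.0f;
  for (int i = 0; i < n; ++i) ref += x[i] * (scale * q[i] + bias);

  // Kernel-style: accumulate x.q and x separately, fold scale/bias once.
  float accum = 0.0f, sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    accum += x[i] * q[i];
    sum += x[i];
  }
  float got = scale * accum + sum * bias;

  assert(std::fabs(ref - got) < 1e-4f);
  return 0;
}
```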
292
+ template <typename U, int values_per_thread, int bits>
293
+ inline U qdot_safe(
294
+ const device uint8_t* w,
295
+ const thread U* x_thread,
296
+ U scale,
297
+ U bias,
298
+ U sum,
299
+ int N) {
300
+ static_assert(
301
+ bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
302
+ bits == 8,
303
+ "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
304
+
305
+ U accum = 0;
306
+
307
+ if (bits == 2) {
308
+ for (int i = 0; i < (N / 4); i++) {
309
+ accum +=
310
+ (x_thread[4 * i] * (w[i] & 0x03) +
311
+ x_thread[4 * i + 1] * (w[i] & 0x0c) +
312
+ x_thread[4 * i + 2] * (w[i] & 0x30) +
313
+ x_thread[4 * i + 3] * (w[i] & 0xc0));
314
+ }
315
+ }
316
+
317
+ else if (bits == 3) {
318
+ for (int i = 0; i < (N / 8); i++) {
319
+ x_thread += 8 * i;
320
+ w += 3 * i;
321
+
322
+ accum += (w[0] & 0x07) * x_thread[0];
323
+ accum += (w[0] & 0x38) * x_thread[1];
324
+ accum += (w[0] & 0xc0) * x_thread[2];
325
+ accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
326
+
327
+ accum += (w[1] & 0x0e) * x_thread[3];
328
+ accum += (w[1] & 0x70) * x_thread[4];
329
+ accum += (w[1] & 0x80) * x_thread[5];
330
+ accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
331
+
332
+ accum += (w[2] & 0x1c) * x_thread[6];
333
+ accum += (w[2] & 0xe0) * x_thread[7];
334
+ }
335
+ }
336
+
337
+ else if (bits == 4) {
338
+ const device uint16_t* ws = (const device uint16_t*)w;
339
+ for (int i = 0; i < (N / 4); i++) {
340
+ accum +=
341
+ (x_thread[4 * i] * (ws[i] & 0x000f) +
342
+ x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
343
+ x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
344
+ x_thread[4 * i + 3] * (ws[i] & 0xf000));
345
+ }
346
+ }
347
+
348
+ else if (bits == 5) {
349
+ for (int i = 0; i < (N / 8); i++) {
350
+ x_thread += 8 * i;
351
+ w += 5 * i;
352
+
353
+ accum += (w[0] & 0x1f) * x_thread[0];
354
+ accum += (w[0] & 0xe0) * x_thread[1];
355
+ accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
356
+ accum += (w[1] & 0x7c) * x_thread[2];
357
+ accum += (w[1] & 0x80) * x_thread[3];
358
+ accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
359
+ accum += (w[2] & 0xf0) * x_thread[4];
360
+ accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
361
+ accum += (w[3] & 0x3e) * x_thread[5];
362
+ accum += (w[3] & 0xc0) * x_thread[6];
363
+ accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
364
+ accum += (w[4] & 0xf8) * x_thread[7];
365
+ }
366
+ }
367
+
368
+ else if (bits == 6) {
369
+ for (int i = 0; i < (N / 4); i++) {
370
+ x_thread += 4 * i;
371
+ w += 3 * i;
372
+
373
+ accum += (w[0] & 0x3f) * x_thread[0];
374
+
375
+ accum += (w[0] & 0xc0) * x_thread[1];
376
+ accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
377
+
378
+ accum += (w[1] & 0xf0) * x_thread[2];
379
+ accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
380
+
381
+ accum += (w[2] & 0xfc) * x_thread[3];
382
+ }
383
+ }
384
+
385
+ else if (bits == 8) {
386
+ for (int i = 0; i < N; i++) {
387
+ accum += x_thread[i] * w[i];
388
+ }
389
+ }
390
+
391
+ return scale * accum + sum * bias;
392
+ }
393
+
394
+ template <typename U, int values_per_thread, int bits>
395
+ inline void
396
+ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
397
+ static_assert(
398
+ bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
399
+ bits == 8,
400
+ "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
401
+
402
+ if (bits == 2) {
403
+ U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
404
+ for (int i = 0; i < (values_per_thread / 4); i++) {
405
+ result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
406
+ result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
407
+ result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
408
+ result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
409
+ }
410
+ }
411
+
412
+ else if (bits == 3) {
413
+ for (int i = 0; i < (values_per_thread / 8); i++) {
414
+ uint8_t w0 = w[3 * i];
415
+ uint8_t w1 = w[3 * i + 1];
416
+ uint8_t w2 = w[3 * i + 2];
417
+
418
+ result[8 * i] += x * ((w0 & 0x7) * scale + bias);
419
+ result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
420
+ result[8 * i + 2] +=
421
+ x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
422
+ result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
423
+ result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
424
+ result[8 * i + 5] +=
425
+ x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
426
+ result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
427
+ result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
428
+ }
429
+ }
430
+
431
+ else if (bits == 4) {
432
+ U s[2] = {scale, scale / 16.0f};
433
+ for (int i = 0; i < (values_per_thread / 2); i++) {
434
+ result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
435
+ result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
436
+ }
437
+ }
438
+
439
+ else if (bits == 5) {
440
+ for (int i = 0; i < (values_per_thread / 8); i++) {
441
+ uint8_t w0 = w[5 * i];
442
+ uint8_t w1 = w[5 * i + 1];
443
+ uint8_t w2 = w[5 * i + 2];
444
+ uint8_t w3 = w[5 * i + 3];
445
+ uint8_t w4 = w[5 * i + 4];
446
+ result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
447
+ result[8 * i + 1] +=
448
+ x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
449
+ result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
450
+ result[8 * i + 3] +=
451
+ x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
452
+ result[8 * i + 4] +=
453
+ x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
454
+ result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
455
+ result[8 * i + 6] +=
456
+ x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
457
+ result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
458
+ }
459
+ }
460
+
461
+ else if (bits == 6) {
462
+ for (int i = 0; i < (values_per_thread / 4); i++) {
463
+ uint8_t w0 = w[3 * i];
464
+ uint8_t w1 = w[3 * i + 1];
465
+ uint8_t w2 = w[3 * i + 2];
466
+
467
+ result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
468
+ result[4 * i + 1] +=
469
+ x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
470
+ result[4 * i + 2] +=
471
+ x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
472
+ result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
473
+ }
474
+ }
475
+
476
+ else if (bits == 8) {
477
+ for (int i = 0; i < values_per_thread; i++) {
478
+ result[i] += x * (scale * w[i] + bias);
479
+ }
480
+ }
481
+ }
482
+
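qouter is the transposed counterpart used by qvm_impl further down: it scatters a single activation x into the partial outputs for one packed run of weights, dequantizing on the fly. The sketch below is a hedged host-side reference for the 4-bit branch; it uses an explicit nibble shift, which is algebraically the same as the scale / 16 trick applied to (w & 0xf0) above. qouter_ref_4bit is an illustrative name, not an MLX function.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// result[j] += x * (scale * q_j + bias) for the 4-bit layout: each byte holds
// two weights, low nibble first, matching the bits == 4 branch above.
void qouter_ref_4bit(const uint8_t* w, float x, float scale, float bias,
                     float* result, int values) {
  for (int i = 0; i < values / 2; ++i) {
    result[2 * i] += x * (scale * (w[i] & 0x0f) + bias);
    result[2 * i + 1] += x * (scale * (w[i] >> 4) + bias);
  }
}

int main() {
  uint8_t w[2] = {0xA3, 0x7F}; // q = {3, 10, 15, 7}
  float result[4] = {0};
  qouter_ref_4bit(w, 2.0f, 0.5f, -1.0f, result, 4);
  // q0 = 3 -> 2 * (0.5 * 3 - 1.0) = 1.0, q1 = 10 -> 2 * (5.0 - 1.0) = 8.0, ...
  assert(std::fabs(result[0] - 1.0f) < 1e-6f);
  assert(std::fabs(result[1] - 8.0f) < 1e-6f);
  return 0;
}
```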
483
+ template <typename U, int N, int bits>
484
+ inline void
485
+ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
486
+ static_assert(
487
+ bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
488
+ bits == 8,
489
+ "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
490
+
491
+ if (bits == 2) {
492
+ U s[4] = {
493
+ scale,
494
+ scale / static_cast<U>(4.0f),
495
+ scale / static_cast<U>(16.0f),
496
+ scale / static_cast<U>(64.0f)};
497
+ for (int i = 0; i < (N / 4); i++) {
498
+ w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
499
+ w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
500
+ w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
501
+ w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
502
+ }
503
+ }
504
+
505
+ else if (bits == 3) {
506
+ for (int i = 0; i < (N / 8); i++) {
507
+ w_local += 8 * i;
508
+ w += 3 * i;
509
+
510
+ w_local[0] = (w[0] & 0x7) * scale + bias;
511
+ w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
512
+ w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
513
+ w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
514
+ w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
515
+ w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
516
+ w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
517
+ w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
518
+ }
519
+ }
520
+
521
+ else if (bits == 4) {
522
+ U s[2] = {scale, scale / static_cast<U>(16.0f)};
523
+ for (int i = 0; i < (N / 2); i++) {
524
+ w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
525
+ w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
526
+ }
527
+ }
528
+
529
+ else if (bits == 5) {
530
+ for (int i = 0; i < (N / 8); i++) {
531
+ w_local += 8 * i;
532
+ w += 5 * i;
533
+
534
+ w_local[0] = (w[0] & 0x1f) * scale + bias;
535
+ w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
536
+ w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
537
+ w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
538
+ w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
539
+ w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
540
+ w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
541
+ w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
542
+ }
543
+ }
544
+
545
+ else if (bits == 6) {
546
+ for (int i = 0; i < (N / 4); i++) {
547
+ w_local += 4 * i;
548
+ w += 3 * i;
549
+ w_local[0] = (w[0] & 0x3f) * scale + bias;
550
+ w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
551
+ w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
552
+ w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
553
+ }
554
+ }
555
+
556
+ else if (bits == 8) {
557
+ for (int i = 0; i < N; i++) {
558
+ w_local[i] = scale * w[i] + bias;
559
+ }
560
+ }
561
+ }
562
+
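The 3-bit branch above looks intricate, but the layout is simply a little-endian bit stream: value j of a 3-byte pack occupies bits [3j, 3j + 3) of w[0] | w[1] << 8 | w[2] << 16, and each extracted value then goes through the usual scale * q + bias. The following host-side check confirms the masks above agree with that reading (illustrative C++, arbitrary pack bytes):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  uint8_t w[3] = {0xB6, 0x5D, 0xE4}; // arbitrary 3-byte pack of eight 3-bit values
  uint32_t stream = w[0] | (w[1] << 8) | (w[2] << 16);

  // The explicit mask/shift combinations from the bits == 3 branch above.
  int expected[8] = {
      w[0] & 0x7,
      (w[0] & 0x38) >> 3,
      ((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2),
      (w[1] & 0xe) >> 1,
      (w[1] & 0x70) >> 4,
      ((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1),
      (w[2] & 0x1c) >> 2,
      (w[2] & 0xe0) >> 5,
  };
  for (int j = 0; j < 8; ++j) {
    // Value j is just the j-th 3-bit field of the little-endian bit stream.
    assert(((stream >> (3 * j)) & 0x7) == (uint32_t)expected[j]);
  }
  return 0;
}
```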
563
+ template <
564
+ typename T,
565
+ short BROWS,
566
+ short BCOLS,
567
+ short dst_ld,
568
+ short reduction_dim,
569
+ short tgp_size,
570
+ short group_size,
571
+ short bits>
572
+ struct QuantizedBlockLoader {
573
+ static_assert(
574
+ BCOLS <= group_size,
575
+ "The group size should be larger than the columns");
576
+ static_assert(
577
+ group_size % BCOLS == 0,
578
+ "The group size should be divisible by the columns");
579
+ static_assert(
580
+ bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
581
+ bits == 8,
582
+ "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
583
+
584
+ MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
585
+ MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
586
+ MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
587
+ MLX_MTL_CONST short n_reads =
588
+ (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
589
+ MLX_MTL_CONST short group_steps = group_size / BCOLS;
590
+
591
+ const int src_ld;
592
+ const int tile_stride;
593
+ short group_step_cnt;
594
+ const int group_stride;
595
+
596
+ const short thread_idx;
597
+ const short bi;
598
+ const short bj;
599
+
600
+ threadgroup T* dst;
601
+ const device uint8_t* src;
602
+ const device T* scales;
603
+ const device T* biases;
604
+
605
+ QuantizedBlockLoader(
606
+ const device uint8_t* src_,
607
+ const device T* scales_,
608
+ const device T* biases_,
609
+ const int src_ld_,
610
+ threadgroup T* dst_,
611
+ ushort simd_group_id [[simdgroup_index_in_threadgroup]],
612
+ ushort simd_lane_id [[thread_index_in_simdgroup]])
613
+ : src_ld(src_ld_),
614
+ tile_stride(
615
+ reduction_dim ? BCOLS_PACKED * bytes_per_pack
616
+ : BROWS * src_ld * bytes_per_pack / pack_factor),
617
+ group_step_cnt(0),
618
+ group_stride(BROWS * src_ld / group_size),
619
+ thread_idx(simd_group_id * 32 + simd_lane_id),
620
+ bi(n_reads * thread_idx / BCOLS_PACKED),
621
+ bj((n_reads * thread_idx) % BCOLS_PACKED),
622
+ dst(dst_ + bi * dst_ld + bj * pack_factor),
623
+ src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
624
+ bj * bytes_per_pack),
625
+ scales(scales_ + bi * src_ld / group_size),
626
+ biases(biases_ + bi * src_ld / group_size) {}
627
+
628
+ void load_unsafe() const {
629
+ if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
630
+ return;
631
+ }
632
+
633
+ T scale = *scales;
634
+ T bias = *biases;
635
+ for (int i = 0; i < n_reads; i++) {
636
+ dequantize<T, pack_factor, bits>(
637
+ src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
638
+ }
639
+ }
640
+
641
+ void load_safe(short2 src_tile_dim) const {
642
+ if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
643
+ return;
644
+ }
645
+
646
+ if (reduction_dim == 1 && bi >= src_tile_dim.x) {
647
+ for (int i = 0; i < n_reads * pack_factor; i++) {
648
+ dst[i] = T(0);
649
+ }
650
+ return;
651
+ }
652
+
653
+ if (reduction_dim == 0 && bi >= src_tile_dim.y) {
654
+ for (int i = 0; i < n_reads * pack_factor; i++) {
655
+ dst[i] = T(0);
656
+ }
657
+ return;
658
+ }
659
+
660
+ T scale = *scales;
661
+ T bias = *biases;
662
+ for (int i = 0; i < n_reads; i++) {
663
+ dequantize<T, pack_factor, bits>(
664
+ (device uint8_t*)(src + i * bytes_per_pack),
665
+ scale,
666
+ bias,
667
+ dst + i * pack_factor);
668
+ }
669
+ }
670
+
671
+ void next() {
672
+ src += tile_stride;
673
+ if (reduction_dim == 1) {
674
+ if (group_steps > 1) {
675
+ group_step_cnt++;
676
+ if (group_step_cnt == group_steps) {
677
+ group_step_cnt = 0;
678
+ scales++;
679
+ biases++;
680
+ }
681
+ } else {
682
+ scales++;
683
+ biases++;
684
+ }
685
+ } else {
686
+ scales += group_stride;
687
+ biases += group_stride;
688
+ }
689
+ }
690
+ };
691
+
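QuantizedBlockLoader divides a BROWS x BCOLS tile among tgp_size threads in packed units, with each thread dequantizing n_reads packs of pack_factor values per load. For the tile shape qmm_t_impl instantiates further down (BN x BK = 32 x 32 weights, WM * WN * SIMD_SIZE = 128 threads) and 4-bit weights, the per-thread work comes out as follows; this is a re-derivation from the constants in this header, not a quoted configuration:

```cpp
#include <cstdio>

int main() {
  // Mirrors the loader constants for the weight loader used by qmm_t_impl:
  // BROWS = BN = 32, BCOLS = BK = 32, tgp_size = WM * WN * SIMD_SIZE = 128.
  const int bits = 4, BROWS = 32, BCOLS = 32, tgp_size = 128;
  const int pack_factor = 8 / bits;     // get_pack_factor<4, 8>() == 2
  const int bytes_per_pack = 1;         // power-of-two bits, wsize == 8
  const int BCOLS_PACKED = BCOLS / pack_factor;                  // 16
  const int n_reads = (BCOLS_PACKED * BROWS < tgp_size)
      ? 1
      : (BCOLS_PACKED * BROWS) / tgp_size;                       // 512 / 128 = 4
  std::printf("each thread reads %d packs (%d bytes, %d values) per tile\n",
              n_reads, n_reads * bytes_per_pack, n_reads * pack_factor);
  return 0;
}
```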
692
+ template <typename T, int group_size, int bits, int D>
693
+ METAL_FUNC void qmv_quad_impl(
694
+ const device uint32_t* w,
695
+ const device T* scales,
696
+ const device T* biases,
697
+ const device T* x,
698
+ device T* y,
699
+ constant int& in_vec_size,
700
+ const constant int& out_vec_size,
701
+ uint3 tid [[threadgroup_position_in_grid]],
702
+ uint quad_gid [[quadgroup_index_in_threadgroup]],
703
+ uint quad_lid [[thread_index_in_quadgroup]]) {
704
+ constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
705
+ constexpr int pack_factor = 32 / bits;
706
+ constexpr int values_per_thread = D / QUAD_SIZE;
707
+ constexpr int packs_per_thread = values_per_thread / pack_factor;
708
+ constexpr int scale_step_per_thread = group_size / values_per_thread;
709
+ constexpr int results_per_quadgroup = 8;
710
+
711
+ typedef float U;
712
+
713
+ thread U x_thread[values_per_thread];
714
+ thread U result[results_per_quadgroup] = {0};
715
+
716
+ // Adjust positions
717
+ const int in_vec_size_w = in_vec_size / pack_factor;
718
+ const int in_vec_size_g = in_vec_size / group_size;
719
+ const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;
720
+
721
+ w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
722
+ scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
723
+ biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
724
+ x += tid.x * in_vec_size + quad_lid * values_per_thread;
725
+ y += tid.x * out_vec_size + out_row;
726
+
727
+ U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
728
+
729
+ for (int row = 0; row < results_per_quadgroup; row++) {
730
+ auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
731
+ const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
732
+ const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
733
+
734
+ U s = sl[0];
735
+ U b = bl[0];
736
+ if (row * quads_per_simd + out_row < out_vec_size) {
737
+ result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
738
+ }
739
+ }
740
+
741
+ for (int row = 0; row < results_per_quadgroup; row++) {
742
+ result[row] = quad_sum(result[row]);
743
+ if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
744
+ y[row * quads_per_simd] = static_cast<T>(result[row]);
745
+ }
746
+ }
747
+ }
748
+
749
+ template <typename T, int group_size, int bits>
750
+ METAL_FUNC void qmv_fast_impl(
751
+ const device uint32_t* w,
752
+ const device T* scales,
753
+ const device T* biases,
754
+ const device T* x,
755
+ device T* y,
756
+ const constant int& in_vec_size,
757
+ const constant int& out_vec_size,
758
+ uint3 tid [[threadgroup_position_in_grid]],
759
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
760
+ uint simd_lid [[thread_index_in_simdgroup]]) {
761
+ constexpr int packs_per_thread = bits == 2 ? 1 : 2;
762
+ constexpr int num_simdgroups = 2;
763
+ constexpr int results_per_simdgroup = 4;
764
+ constexpr int pack_factor = get_pack_factor<bits, 32>();
765
+ constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
766
+ constexpr int values_per_thread = pack_factor * packs_per_thread;
767
+ constexpr int block_size = values_per_thread * SIMD_SIZE;
768
+ constexpr int scale_step_per_thread = group_size / values_per_thread;
769
+
770
+ const device uint8_t* ws = (const device uint8_t*)w;
771
+
772
+ typedef float U;
773
+
774
+ thread U x_thread[values_per_thread];
775
+ thread U result[results_per_simdgroup] = {0};
776
+
777
+ // Adjust positions
778
+ const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
779
+ const int in_vec_size_g = in_vec_size / group_size;
780
+ const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
781
+ simd_gid * results_per_simdgroup;
782
+
783
+ ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
784
+ scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
785
+ biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
786
+ x += tid.x * in_vec_size + simd_lid * values_per_thread;
787
+ y += tid.x * out_vec_size + out_row;
788
+
789
+ for (int k = 0; k < in_vec_size; k += block_size) {
790
+ U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
791
+
792
+ for (int row = 0; row < results_per_simdgroup; row++) {
793
+ auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
794
+ const device T* sl = scales + row * in_vec_size_g;
795
+ const device T* bl = biases + row * in_vec_size_g;
796
+
797
+ U s = sl[0];
798
+ U b = bl[0];
799
+ result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
800
+ }
801
+
802
+ ws += block_size * bytes_per_pack / pack_factor;
803
+ scales += block_size / group_size;
804
+ biases += block_size / group_size;
805
+ x += block_size;
806
+ }
807
+
808
+ for (int row = 0; row < results_per_simdgroup; row++) {
809
+ result[row] = simd_sum(result[row]);
810
+ if (simd_lid == 0) {
811
+ y[row] = static_cast<T>(result[row]);
812
+ }
813
+ }
814
+ }
815
+
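In qmv_fast_impl each lane of a simdgroup owns a contiguous run of activations and the matching packed bytes of the four weight rows assigned to that simdgroup, so one loop iteration consumes block_size input elements. A small recomputation of those sizes for 4-bit weights and group_size = 64 (a sketch derived from the constexpr values above, not a quoted launch configuration):

```cpp
#include <cstdio>

int main() {
  const int bits = 4, group_size = 64, SIMD_SIZE = 32;
  const int packs_per_thread = (bits == 2) ? 1 : 2;
  const int pack_factor = 32 / bits;            // get_pack_factor<4, 32>() == 8
  const int values_per_thread = pack_factor * packs_per_thread;     // 16
  const int block_size = values_per_thread * SIMD_SIZE;             // 512
  const int scale_step_per_thread = group_size / values_per_thread; // 4
  std::printf("each lane handles %d activations; a simdgroup advances %d "
              "inputs per iteration; %d lanes share one scale/bias pair\n",
              values_per_thread, block_size, scale_step_per_thread);
  return 0;
}
```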
816
+ template <typename T, int group_size, int bits>
817
+ METAL_FUNC void qmv_impl(
818
+ const device uint32_t* w,
819
+ const device T* scales,
820
+ const device T* biases,
821
+ const device T* x,
822
+ device T* y,
823
+ const constant int& in_vec_size,
824
+ const constant int& out_vec_size,
825
+ uint3 tid [[threadgroup_position_in_grid]],
826
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
827
+ uint simd_lid [[thread_index_in_simdgroup]]) {
828
+ constexpr int num_simdgroups = 2;
829
+ constexpr int results_per_simdgroup = 4;
830
+ constexpr int packs_per_thread = 1;
831
+ constexpr int pack_factor = get_pack_factor<bits, 32>();
832
+ constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
833
+
834
+ constexpr int values_per_thread = pack_factor * packs_per_thread;
835
+ constexpr int block_size = values_per_thread * SIMD_SIZE;
836
+ constexpr int scale_step_per_thread = group_size / values_per_thread;
837
+
838
+ const device uint8_t* ws = (const device uint8_t*)w;
839
+
840
+ typedef float U;
841
+
842
+ thread U x_thread[values_per_thread];
843
+ thread U result[results_per_simdgroup] = {0};
844
+
845
+ // Adjust positions
846
+ const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
847
+ const int in_vec_size_g = in_vec_size / group_size;
848
+ const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
849
+ simd_gid * results_per_simdgroup;
850
+ const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
851
+
852
+ if (out_row >= out_vec_size) {
853
+ return;
854
+ }
855
+
856
+ // In this case we need to properly guard all our reads because there isn't
857
+ // even 1 tile in the matrix
858
+ if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
859
+ ws +=
860
+ out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
861
+ scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
862
+ biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
863
+ x += tid.x * in_vec_size + simd_lid * values_per_thread;
864
+ y += tid.x * out_vec_size + out_row;
865
+
866
+ int k = 0;
867
+ for (; k < in_vec_size - block_size; k += block_size) {
868
+ U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
869
+
870
+ for (int row = 0; out_row + row < out_vec_size; row++) {
871
+ auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
872
+ const device T* sl = scales + row * in_vec_size_g;
873
+ const device T* bl = biases + row * in_vec_size_g;
874
+
875
+ U s = sl[0];
876
+ U b = bl[0];
877
+ result[row] +=
878
+ qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
879
+ }
880
+
881
+ ws += block_size * bytes_per_pack / pack_factor;
882
+ scales += block_size / group_size;
883
+ biases += block_size / group_size;
884
+ x += block_size;
885
+ }
886
+ const int remaining = clamp(
887
+ static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
888
+ 0,
889
+ values_per_thread);
890
+ if (remaining > 0) {
891
+ U sum = load_vector_safe<T, U, values_per_thread, bits>(
892
+ x, x_thread, remaining);
893
+
894
+ for (int row = 0; out_row + row < out_vec_size; row++) {
895
+ auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
896
+ const device T* sl = scales + row * in_vec_size_g;
897
+ const device T* bl = biases + row * in_vec_size_g;
898
+
899
+ U s = sl[0];
900
+ U b = bl[0];
901
+ result[row] +=
902
+ qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
903
+ }
904
+ }
905
+
906
+ for (int row = 0; out_row + row < out_vec_size; row++) {
907
+ result[row] = simd_sum(result[row]);
908
+ if (simd_lid == 0) {
909
+ y[row] = static_cast<T>(result[row]);
910
+ }
911
+ }
912
+ }
913
+
914
+ // In this case the last tile is moved back to redo some output values
915
+ else {
916
+ ws += used_out_row * in_vec_size_w +
917
+ simd_lid * packs_per_thread * bytes_per_pack;
918
+ scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
919
+ biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
920
+ x += tid.x * in_vec_size + simd_lid * values_per_thread;
921
+ y += tid.x * out_vec_size + used_out_row;
922
+
923
+ int k = 0;
924
+ for (; k < in_vec_size - block_size; k += block_size) {
925
+ U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
926
+
927
+ for (int row = 0; row < results_per_simdgroup; row++) {
928
+ auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
929
+ const device T* sl = scales + row * in_vec_size_g;
930
+ const device T* bl = biases + row * in_vec_size_g;
931
+
932
+ U s = sl[0];
933
+ U b = bl[0];
934
+ result[row] +=
935
+ qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
936
+ }
937
+
938
+ ws += block_size * bytes_per_pack / pack_factor;
939
+ scales += block_size / group_size;
940
+ biases += block_size / group_size;
941
+ x += block_size;
942
+ }
943
+ const int remaining = clamp(
944
+ static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
945
+ 0,
946
+ values_per_thread);
947
+ if (remaining > 0) {
948
+ U sum = load_vector_safe<T, U, values_per_thread, bits>(
949
+ x, x_thread, remaining);
950
+
951
+ for (int row = 0; row < results_per_simdgroup; row++) {
952
+ auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
953
+ const device T* sl = scales + row * in_vec_size_g;
954
+ const device T* bl = biases + row * in_vec_size_g;
955
+
956
+ U s = sl[0];
957
+ U b = bl[0];
958
+ result[row] += qdot_safe<U, values_per_thread, bits>(
959
+ wl, x_thread, s, b, sum, remaining);
960
+ }
961
+ }
962
+ for (int row = 0; row < results_per_simdgroup; row++) {
963
+ result[row] = simd_sum(result[row]);
964
+ if (simd_lid == 0) {
965
+ y[row] = static_cast<T>(result[row]);
966
+ }
967
+ }
968
+ }
969
+ }
970
+
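The second branch of qmv_impl handles a ragged output size without per-element guards: the last tile of four rows is slid back so it stays inside out_vec_size, and the overlapping rows are simply recomputed with identical results (the "last tile is moved back" comment above). A toy walk-through of that index arithmetic, with out_vec_size = 10 chosen arbitrarily for illustration:

```cpp
#include <cstdio>

int main() {
  // Each simdgroup nominally owns results_per_simdgroup = 4 output rows.
  const int results_per_simdgroup = 4;
  const int out_vec_size = 10; // not a multiple of 4
  for (int out_row = 0; out_row < out_vec_size;
       out_row += results_per_simdgroup) {
    // used_out_row = min(out_vec_size - results_per_simdgroup, out_row)
    int used_out_row = (out_vec_size - results_per_simdgroup < out_row)
        ? out_vec_size - results_per_simdgroup
        : out_row;
    std::printf("tile at row %d writes rows [%d, %d)\n",
                out_row, used_out_row, used_out_row + results_per_simdgroup);
  }
  return 0; // the last tile overlaps the previous one instead of running past
}
```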
971
+ template <typename T, const int group_size, const int bits>
972
+ METAL_FUNC void qvm_impl(
973
+ const device uint32_t* w,
974
+ const device T* scales,
975
+ const device T* biases,
976
+ const device T* x,
977
+ device T* y,
978
+ const int in_vec_size,
979
+ const int out_vec_size,
980
+ uint3 tid [[threadgroup_position_in_grid]],
981
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
982
+ uint simd_lid [[thread_index_in_simdgroup]]) {
983
+ constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
984
+ constexpr int num_simdgroups = 2;
985
+ constexpr int pack_factor = get_pack_factor<bits, 32>();
986
+ constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
987
+
988
+ constexpr int tn = 32 / pack_factor;
989
+ constexpr int block_size = SIMD_SIZE;
990
+
991
+ using W_T =
992
+ typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
993
+ const device W_T* ws = (const device W_T*)w;
994
+
995
+ typedef float U;
996
+ typedef struct {
997
+ W_T wi[tn * bytes_per_pack];
998
+ } vec_w;
999
+
1000
+ thread vec_w w_local;
1001
+ thread U result[tn * pack_factor] = {0};
1002
+ thread U scale = 1;
1003
+ thread U bias = 0;
1004
+ thread U x_local = 0;
1005
+
1006
+ // Adjust positions
1007
+ const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
1008
+ const int out_vec_size_g = out_vec_size / group_size;
1009
+ int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
1010
+ ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
1011
+ scales += out_col / group_size + simd_lid * out_vec_size_g;
1012
+ biases += out_col / group_size + simd_lid * out_vec_size_g;
1013
+ x += tid.x * in_vec_size + simd_lid;
1014
+ y += tid.x * out_vec_size + out_col;
1015
+
1016
+ if (out_col >= out_vec_size) {
1017
+ return;
1018
+ }
1019
+
1020
+ // Loop over in_vec in blocks of block_size
1021
+ int remaining = in_vec_size % block_size;
1022
+ if (remaining == 0) {
1023
+ for (int i = 0; i < in_vec_size; i += block_size) {
1024
+ x_local = *x;
1025
+ scale = *scales;
1026
+ bias = *biases;
1027
+ w_local = *((device vec_w*)ws);
1028
+ qouter<U, tn * pack_factor, bits>(
1029
+ (thread uint8_t*)&w_local, x_local, scale, bias, result);
1030
+
1031
+ x += block_size;
1032
+ scales += block_size * out_vec_size_g;
1033
+ biases += block_size * out_vec_size_g;
1034
+ ws += block_size * out_vec_size_w;
1035
+ }
1036
+ } else {
1037
+ for (int i = block_size; i < in_vec_size; i += block_size) {
1038
+ x_local = *x;
1039
+ scale = *scales;
1040
+ bias = *biases;
1041
+ w_local = *((device vec_w*)ws);
1042
+
1043
+ qouter<U, tn * pack_factor, bits>(
1044
+ (thread uint8_t*)&w_local, x_local, scale, bias, result);
1045
+
1046
+ x += block_size;
1047
+ scales += block_size * out_vec_size_g;
1048
+ biases += block_size * out_vec_size_g;
1049
+ ws += block_size * out_vec_size_w;
1050
+ }
1051
+ if (static_cast<int>(simd_lid) < remaining) {
1052
+ x_local = *x;
1053
+ scale = *scales;
1054
+ bias = *biases;
1055
+ w_local = *((device vec_w*)ws);
1056
+ } else {
1057
+ x_local = 0;
1058
+ scale = 0;
1059
+ bias = 0;
1060
+ }
1061
+ qouter<U, tn * pack_factor, bits>(
1062
+ (thread uint8_t*)&w_local, x_local, scale, bias, result);
1063
+ }
1064
+
1065
+ // Accumulate in the simdgroup
1066
+ #pragma clang loop unroll(full)
1067
+ for (int k = 0; k < tn * pack_factor; k++) {
1068
+ result[k] = simd_sum(result[k]);
1069
+ }
1070
+
1071
+ // Store the result
1072
+ if (simd_lid == 0) {
1073
+ #pragma clang loop unroll(full)
1074
+ for (int k = 0; k < tn * pack_factor; k++) {
1075
+ y[k] = static_cast<T>(result[k]);
1076
+ }
1077
+ }
1078
+ }
1079
+
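qvm_impl is the x * W direction. As I read the indexing (ws advances by out_vec_size_w per input row and scales/biases by out_vec_size_g), the packed weights are laid out row-major as [in_vec_size, out_vec_size] with one scale/bias pair per group_size consecutive output columns. The scalar reference below captures that reading for the 8-bit case; qvm_ref is an illustrative helper, not MLX's definition of the op.

```cpp
#include <cstdint>
#include <vector>

// Scalar reference for a quantized vector-matrix product: w is [K, N]
// row-major (8-bit), scales/biases are [K, N / group_size].
std::vector<float> qvm_ref(const std::vector<float>& x,
                           const std::vector<uint8_t>& w,
                           const std::vector<float>& scales,
                           const std::vector<float>& biases,
                           int K, int N, int group_size) {
  std::vector<float> y(N, 0.0f);
  for (int k = 0; k < K; ++k) {
    for (int n = 0; n < N; ++n) {
      int g = n / group_size;
      float wkn = scales[k * (N / group_size) + g] * w[k * N + n] +
                  biases[k * (N / group_size) + g];
      y[n] += x[k] * wkn;
    }
  }
  return y;
}

int main() {
  // Tiny smoke test: K = 2, N = 4, one group spanning all four columns.
  std::vector<float> x = {1.0f, 2.0f};
  std::vector<uint8_t> w = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> s = {1.0f, 0.5f}, b = {0.0f, 1.0f};
  auto y = qvm_ref(x, w, s, b, 2, 4, 4);
  // y[0] = 1 * 1 + 2 * (0.5 * 5 + 1) = 8
  return (y[0] == 8.0f) ? 0 : 1;
}
```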
1080
+ template <
1081
+ typename T,
1082
+ const int group_size,
1083
+ const int bits,
1084
+ const bool aligned_N,
1085
+ const int BM = 32,
1086
+ const int BK = 32,
1087
+ const int BN = 32>
1088
+ METAL_FUNC void qmm_t_impl(
1089
+ const device uint32_t* w,
1090
+ const device T* scales,
1091
+ const device T* biases,
1092
+ const device T* x,
1093
+ device T* y,
1094
+ threadgroup T* Xs,
1095
+ threadgroup T* Ws,
1096
+ const constant int& K,
1097
+ const constant int& N,
1098
+ const constant int& M,
1099
+ uint3 tid [[threadgroup_position_in_grid]],
1100
+ uint lid [[thread_index_in_threadgroup]],
1101
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
1102
+ uint simd_lid [[thread_index_in_simdgroup]]) {
1103
+ static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
1104
+ static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
1105
+
1106
+ (void)lid;
1107
+
1108
+ constexpr int WM = 2;
1109
+ constexpr int WN = 2;
1110
+ constexpr int pack_factor = get_pack_factor<bits, 8>();
1111
+ constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
1112
+
1113
+ constexpr int BK_padded = (BK + 16 / sizeof(T));
1114
+
1115
+ // Instantiate the appropriate BlockMMA and Loader
1116
+ using mma_t = mlx::steel::
1117
+ BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
1118
+ using loader_x_t =
1119
+ mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
1120
+ using loader_w_t = QuantizedBlockLoader<
1121
+ T,
1122
+ BN,
1123
+ BK,
1124
+ BK_padded,
1125
+ 1,
1126
+ WM * WN * SIMD_SIZE,
1127
+ group_size,
1128
+ bits>;
1129
+
1130
+ // Set the block
1131
+ const int K_w = K * bytes_per_pack / pack_factor;
1132
+ const int K_g = K / group_size;
1133
+ const int y_row = tid.y * BM;
1134
+ const int y_col = tid.x * BN;
1135
+
1136
+ auto wl = (const device uint8_t*)w;
1137
+
1138
+ x += y_row * static_cast<int64_t>(K);
1139
+ wl += y_col * K_w;
1140
+ scales += y_col * K_g;
1141
+ biases += y_col * K_g;
1142
+ y += y_row * static_cast<int64_t>(N) + y_col;
1143
+
1144
+ // Make the x loader and mma operation
1145
+ const short num_els = min(BM, M - y_row);
1146
+ const short num_outs = min(BN, N - y_col);
1147
+ loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
1148
+ loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
1149
+ mma_t mma_op(simd_gid, simd_lid);
1150
+
1151
+ if (num_els < BM) {
1152
+ if (!aligned_N && num_outs < BN) {
1153
+ for (int k = 0; k < K; k += BK) {
1154
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1155
+ loader_x.load_safe(short2(BK, num_els));
1156
+ loader_w.load_safe(short2(BK, num_outs));
1157
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1158
+ mma_op.mma(Xs, Ws);
1159
+ loader_x.next();
1160
+ loader_w.next();
1161
+ }
1162
+ } else {
1163
+ for (int k = 0; k < K; k += BK) {
1164
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1165
+ loader_x.load_safe(short2(BK, num_els));
1166
+ loader_w.load_unsafe();
1167
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1168
+ mma_op.mma(Xs, Ws);
1169
+ loader_x.next();
1170
+ loader_w.next();
1171
+ }
1172
+ }
1173
+ } else {
1174
+ if (!aligned_N && num_outs < BN) {
1175
+ for (int k = 0; k < K; k += BK) {
1176
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1177
+ loader_x.load_unsafe();
1178
+ loader_w.load_safe(short2(BK, num_outs));
1179
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1180
+ mma_op.mma(Xs, Ws);
1181
+ loader_x.next();
1182
+ loader_w.next();
1183
+ }
1184
+ } else {
1185
+ for (int k = 0; k < K; k += BK) {
1186
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1187
+ loader_x.load_unsafe();
1188
+ loader_w.load_unsafe();
1189
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1190
+
1191
+ mma_op.mma(Xs, Ws);
1192
+ loader_x.next();
1193
+ loader_w.next();
1194
+ }
1195
+ }
1196
+ }
1197
+
1198
+ // Store results to device memory
1199
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1200
+ if (num_els < BM || num_outs < BN) {
1201
+ mma_op.store_result_safe(y, N, short2(num_outs, num_els));
1202
+ } else {
1203
+ mma_op.store_result(y, N);
1204
+ }
1205
+ }
1206
+
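The Xs and Ws arguments of qmm_t_impl are caller-provided threadgroup tiles of shape BM x BK_padded and BN x BK_padded, where BK_padded = BK + 16 / sizeof(T) adds a small row pad, presumably to avoid threadgroup-memory bank conflicts. A back-of-the-envelope footprint for the default 32 x 32 x 32 tiles with half-precision T (my arithmetic, not a figure from MLX):

```cpp
#include <cstdio>

int main() {
  // Threadgroup memory used by qmm_t_impl's tiles for T = half (2 bytes).
  const int BM = 32, BK = 32, BN = 32, sizeof_T = 2;
  const int BK_padded = BK + 16 / sizeof_T;        // 40 elements per row
  const int Xs_bytes = BM * BK_padded * sizeof_T;  // 2560 bytes
  const int Ws_bytes = BN * BK_padded * sizeof_T;  // 2560 bytes
  std::printf("Xs: %d B, Ws: %d B, total: %d B of threadgroup memory\n",
              Xs_bytes, Ws_bytes, Xs_bytes + Ws_bytes);
  return 0;
}
```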
1207
+ template <
1208
+ typename T,
1209
+ const int group_size,
1210
+ const int bits,
1211
+ const int BM = 32,
1212
+ const int BK = 32,
1213
+ const int BN = 32>
1214
+ METAL_FUNC void qmm_n_impl(
1215
+ const device uint32_t* w,
1216
+ const device T* scales,
1217
+ const device T* biases,
1218
+ const device T* x,
1219
+ device T* y,
1220
+ threadgroup T* Xs,
1221
+ threadgroup T* Ws,
1222
+ const constant int& K,
1223
+ const constant int& N,
1224
+ const constant int& M,
1225
+ uint3 tid [[threadgroup_position_in_grid]],
1226
+ uint lid [[thread_index_in_threadgroup]],
1227
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
1228
+ uint simd_lid [[thread_index_in_simdgroup]]) {
1229
+ static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
1230
+ static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
1231
+
1232
+ (void)lid;
1233
+
1234
+ constexpr int WM = 2;
1235
+ constexpr int WN = 2;
1236
+ constexpr int pack_factor = get_pack_factor<bits, 8>();
1237
+ constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
1238
+
1239
+ constexpr int BK_padded = (BK + 16 / sizeof(T));
1240
+ constexpr int BN_padded = (BN + 16 / sizeof(T));
1241
+
1242
+ // Instantiate the appropriate BlockMMA and Loader
1243
+ using mma_t = mlx::steel::
1244
+ BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
1245
+ using loader_x_t = mlx::steel::
1246
+ BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
1247
+ using loader_w_t = QuantizedBlockLoader<
1248
+ T,
1249
+ BK,
1250
+ BN,
1251
+ BN_padded,
1252
+ 0,
1253
+ WM * WN * SIMD_SIZE,
1254
+ group_size,
1255
+ bits>;
1256
+
1257
+ auto wl = (const device uint8_t*)w;
1258
+
1259
+ // Set the block
1260
+ const int y_row = tid.y * BM;
1261
+ const int y_col = tid.x * BN;
1262
+ x += y_row * static_cast<int64_t>(K);
1263
+ wl += y_col * bytes_per_pack / pack_factor;
1264
+ scales += y_col / group_size;
1265
+ biases += y_col / group_size;
1266
+ y += y_row * static_cast<int64_t>(N) + y_col;
1267
+
1268
+ // Make the x loader and mma operation
1269
+ const short num_els = min(BM, M - y_row);
1270
+ loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
1271
+ loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
1272
+ mma_t mma_op(simd_gid, simd_lid);
1273
+
1274
+ if (num_els < BM) {
1275
+ if ((K % BK) != 0) {
1276
+ const int k_blocks = K / BK;
1277
+ for (int k = 0; k < k_blocks; k++) {
1278
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1279
+ loader_x.load_safe(short2(BK, num_els));
1280
+ loader_w.load_unsafe();
1281
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1282
+ mma_op.mma(Xs, Ws);
1283
+ loader_x.next();
1284
+ loader_w.next();
1285
+ }
1286
+ const short num_k = K - k_blocks * BK;
1287
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1288
+ loader_x.load_safe(short2(num_k, num_els));
1289
+ loader_w.load_safe(short2(BN, num_k));
1290
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1291
+ mma_op.mma(Xs, Ws);
1292
+ } else {
1293
+ for (int k = 0; k < K; k += BK) {
1294
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1295
+ loader_x.load_safe(short2(BK, num_els));
1296
+ loader_w.load_unsafe();
1297
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1298
+ mma_op.mma(Xs, Ws);
1299
+ loader_x.next();
1300
+ loader_w.next();
1301
+ }
1302
+ }
1303
+ } else {
1304
+ if ((K % BK) != 0) {
1305
+ const int k_blocks = K / BK;
1306
+ for (int k = 0; k < k_blocks; k++) {
1307
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1308
+ loader_x.load_unsafe();
1309
+ loader_w.load_unsafe();
1310
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1311
+ mma_op.mma(Xs, Ws);
1312
+ loader_x.next();
1313
+ loader_w.next();
1314
+ }
1315
+ const short num_k = K - k_blocks * BK;
1316
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1317
+ loader_x.load_safe(short2(num_k, BM));
1318
+ loader_w.load_safe(short2(BN, num_k));
1319
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1320
+ mma_op.mma(Xs, Ws);
1321
+ } else {
1322
+ for (int k = 0; k < K; k += BK) {
1323
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1324
+ loader_x.load_unsafe();
1325
+ loader_w.load_unsafe();
1326
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1327
+ mma_op.mma(Xs, Ws);
1328
+ loader_x.next();
1329
+ loader_w.next();
1330
+ }
1331
+ }
1332
+ }
1333
+
1334
+ // Store results to device memory
1335
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1336
+ if (num_els < BM) {
1337
+ mma_op.store_result_safe(y, N, short2(BN, num_els));
1338
+ } else {
1339
+ mma_op.store_result(y, N);
1340
+ }
1341
+ }
1342
+
1343
+ template <typename T>
1344
+ METAL_FUNC void adjust_matrix_offsets(
1345
+ const device T*& x,
1346
+ const device uint32_t*& w,
1347
+ const device T*& scales,
1348
+ const device T*& biases,
1349
+ device T*& y,
1350
+ int output_stride,
1351
+ const constant int& x_batch_ndims,
1352
+ const constant int* x_shape,
1353
+ const constant int64_t* x_strides,
1354
+ const constant int& w_batch_ndims,
1355
+ const constant int* w_shape,
1356
+ const constant int64_t* w_strides,
1357
+ const constant int64_t* s_strides,
1358
+ const constant int64_t* b_strides,
1359
+ uint3 tid [[threadgroup_position_in_grid]]) {
1360
+ // Set the input/output matrices
1361
+ uint32_t x_idx = tid.z;
1362
+ uint32_t w_idx = tid.z;
1363
+ if (x_batch_ndims == 1) {
1364
+ x += x_idx * x_strides[0];
1365
+ } else {
1366
+ x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1367
+ }
1368
+ if (w_batch_ndims == 1) {
1369
+ w += w_idx * w_strides[0];
1370
+ scales += w_idx * s_strides[0];
1371
+ biases += w_idx * b_strides[0];
1372
+ } else {
1373
+ ulong3 idx = elem_to_loc_broadcast(
1374
+ w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1375
+ w += idx.x;
1376
+ scales += idx.y;
1377
+ biases += idx.z;
1378
+ }
1379
+ y += tid.z * output_stride;
1380
+ }
1381
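Note: the adjust_matrix_offsets overload above (and the gather variant that follows) leans on the elem_to_loc / elem_to_loc_broadcast helpers defined earlier in this header to turn a flat batch index into a strided memory offset. As a reading aid only, here is a minimal host-side C++ sketch of that stride arithmetic, assuming the conventional last-axis-fastest layout these helpers use; it is not part of the diffed file.

#include <cstdint>
#include <vector>

// Sketch: map a flat batch element index to a memory offset by peeling
// off one axis at a time, last axis fastest.
int64_t elem_to_loc_sketch(uint32_t elem,
                           const std::vector<int>& shape,
                           const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += static_cast<int64_t>(elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}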
+
+ template <typename T>
+ METAL_FUNC void adjust_matrix_offsets(
+ const device T*& x,
+ const device uint32_t*& w,
+ const device T*& scales,
+ const device T*& biases,
+ const device uint32_t* lhs_indices,
+ const device uint32_t* rhs_indices,
+ device T*& y,
+ int output_stride,
+ const constant int& batch_ndims,
+ const constant int* batch_shape,
+ const constant int64_t* lhs_strides,
+ const constant int64_t* rhs_strides,
+ const constant int& x_batch_ndims,
+ const constant int* x_shape,
+ const constant int64_t* x_strides,
+ const constant int& w_batch_ndims,
+ const constant int* w_shape,
+ const constant int64_t* w_strides,
+ const constant int64_t* s_strides,
+ const constant int64_t* b_strides,
+ uint3 tid [[threadgroup_position_in_grid]]) {
+ // Set the input/output matrices
+ uint32_t x_idx;
+ uint32_t w_idx;
+ if (batch_ndims == 1) {
+ x_idx = lhs_indices[tid.z * lhs_strides[0]];
+ w_idx = rhs_indices[tid.z * rhs_strides[0]];
+ } else {
+ ulong2 idx = elem_to_loc_broadcast(
+ tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
+ x_idx = lhs_indices[idx.x];
+ w_idx = rhs_indices[idx.y];
+ }
+ if (x_batch_ndims == 1) {
+ x += x_idx * x_strides[0];
+ } else {
+ x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
+ }
+ if (w_batch_ndims == 1) {
+ w += w_idx * w_strides[0];
+ scales += w_idx * s_strides[0];
+ biases += w_idx * b_strides[0];
+ } else {
+ ulong3 idx = elem_to_loc_broadcast(
+ w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
+ w += idx.x;
+ scales += idx.y;
+ biases += idx.z;
+ }
+ y += tid.z * output_stride;
+ }
+
+ template <typename T, int group_size, int bits, int D, bool batched>
+ [[kernel]] void affine_qmv_quad(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ device T* y [[buffer(4)]],
+ const constant int& in_vec_size [[buffer(5)]],
+ const constant int& out_vec_size [[buffer(6)]],
+ const constant int& x_batch_ndims [[buffer(7)]],
+ const constant int* x_shape [[buffer(8)]],
+ const constant int64_t* x_strides [[buffer(9)]],
+ const constant int& w_batch_ndims [[buffer(10)]],
+ const constant int* w_shape [[buffer(11)]],
+ const constant int64_t* w_strides [[buffer(12)]],
+ const constant int64_t* s_strides [[buffer(13)]],
+ const constant int64_t* b_strides [[buffer(14)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint quad_gid [[quadgroup_index_in_threadgroup]],
+ uint quad_lid [[thread_index_in_quadgroup]]) {
+ if (batched) {
+ int M = x_shape[x_batch_ndims];
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ y,
+ out_vec_size * M,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ }
+ qmv_quad_impl<T, group_size, bits, D>(
+ w,
+ scales,
+ biases,
+ x,
+ y,
+ in_vec_size,
+ out_vec_size,
+ tid,
+ quad_gid,
+ quad_lid);
+ }
+
+ template <typename T, int group_size, int bits, bool batched>
+ [[kernel]] void affine_qmv_fast(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ device T* y [[buffer(4)]],
+ const constant int& in_vec_size [[buffer(5)]],
+ const constant int& out_vec_size [[buffer(6)]],
+ const constant int& x_batch_ndims [[buffer(7)]],
+ const constant int* x_shape [[buffer(8)]],
+ const constant int64_t* x_strides [[buffer(9)]],
+ const constant int& w_batch_ndims [[buffer(10)]],
+ const constant int* w_shape [[buffer(11)]],
+ const constant int64_t* w_strides [[buffer(12)]],
+ const constant int64_t* s_strides [[buffer(13)]],
+ const constant int64_t* b_strides [[buffer(14)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ if (batched) {
+ int M = x_shape[x_batch_ndims];
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ y,
+ out_vec_size * M,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ }
+ qmv_fast_impl<T, group_size, bits>(
+ w,
+ scales,
+ biases,
+ x,
+ y,
+ in_vec_size,
+ out_vec_size,
+ tid,
+ simd_gid,
+ simd_lid);
+ }
+
+ template <typename T, const int group_size, const int bits, bool batched>
+ [[kernel]] void affine_qmv(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ device T* y [[buffer(4)]],
+ const constant int& in_vec_size [[buffer(5)]],
+ const constant int& out_vec_size [[buffer(6)]],
+ const constant int& x_batch_ndims [[buffer(7)]],
+ const constant int* x_shape [[buffer(8)]],
+ const constant int64_t* x_strides [[buffer(9)]],
+ const constant int& w_batch_ndims [[buffer(10)]],
+ const constant int* w_shape [[buffer(11)]],
+ const constant int64_t* w_strides [[buffer(12)]],
+ const constant int64_t* s_strides [[buffer(13)]],
+ const constant int64_t* b_strides [[buffer(14)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ if (batched) {
+ int M = x_shape[x_batch_ndims];
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ y,
+ out_vec_size * M,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ }
+ qmv_impl<T, group_size, bits>(
+ w,
+ scales,
+ biases,
+ x,
+ y,
+ in_vec_size,
+ out_vec_size,
+ tid,
+ simd_gid,
+ simd_lid);
+ }
+
+ template <typename T, const int group_size, const int bits, bool batched>
+ [[kernel]] void affine_qvm(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ device T* y [[buffer(4)]],
+ const constant int& in_vec_size [[buffer(5)]],
+ const constant int& out_vec_size [[buffer(6)]],
+ const constant int& x_batch_ndims [[buffer(7)]],
+ const constant int* x_shape [[buffer(8)]],
+ const constant int64_t* x_strides [[buffer(9)]],
+ const constant int& w_batch_ndims [[buffer(10)]],
+ const constant int* w_shape [[buffer(11)]],
+ const constant int64_t* w_strides [[buffer(12)]],
+ const constant int64_t* s_strides [[buffer(13)]],
+ const constant int64_t* b_strides [[buffer(14)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ if (batched) {
+ int M = x_shape[x_batch_ndims];
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ y,
+ out_vec_size * M,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ }
+ qvm_impl<T, group_size, bits>(
+ w,
+ scales,
+ biases,
+ x,
+ y,
+ in_vec_size,
+ out_vec_size,
+ tid,
+ simd_gid,
+ simd_lid);
+ }
+
+ template <typename T, const int group_size, const int bits, int split_k = 32>
+ [[kernel]] void affine_qvm_split_k(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ device T* y [[buffer(4)]],
+ const constant int& in_vec_size [[buffer(5)]],
+ const constant int& out_vec_size [[buffer(6)]],
+ const constant int& x_batch_ndims [[buffer(7)]],
+ const constant int* x_shape [[buffer(8)]],
+ const constant int64_t* x_strides [[buffer(9)]],
+ const constant int& w_batch_ndims [[buffer(10)]],
+ const constant int* w_shape [[buffer(11)]],
+ const constant int64_t* w_strides [[buffer(12)]],
+ const constant int64_t* s_strides [[buffer(13)]],
+ const constant int64_t* b_strides [[buffer(14)]],
+ const constant int& final_block_size [[buffer(15)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ int M = x_shape[x_batch_ndims];
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ y,
+ out_vec_size * M,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+
+ // When (in_vec_size % split_k != 0) the final block needs to be smaller
+ int in_vec_size_adj =
+ tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
+
+ qvm_impl<T, group_size, bits>(
+ w,
+ scales,
+ biases,
+ x,
+ y,
+ in_vec_size_adj,
+ out_vec_size,
+ tid,
+ simd_gid,
+ simd_lid);
+ }
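The split-k variant above spreads the K reduction over tid.z partitions and, per the comment in the kernel, shrinks only the last partition to final_block_size when the reduction length does not divide evenly. The host passes final_block_size in; the sketch below is only a hedged illustration of one way such a remainder could be computed (hypothetical helper, not taken from the package sources).

#include <cstdio>

// Illustrative only: size of the last split-k partition when the other
// partitions each take a ceil(total_k / split_k) sized block.
int final_block_size_sketch(int total_k, int split_k) {
  int block = (total_k + split_k - 1) / split_k;
  return total_k - (split_k - 1) * block;
}

int main() {
  // e.g. 1000 split 32 ways: 31 blocks of 32 plus one final block of 8
  std::printf("%d\n", final_block_size_sketch(1000, 32));
}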
+
+ template <
+ typename T,
+ const int group_size,
+ const int bits,
+ const bool aligned_N,
+ const bool batched,
+ const int BM = 32,
+ const int BK = 32,
+ const int BN = 32>
+ [[kernel]] void affine_qmm_t(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ device T* y [[buffer(4)]],
+ const constant int& K [[buffer(5)]],
+ const constant int& N [[buffer(6)]],
+ const constant int& M [[buffer(7)]],
+ const constant int& x_batch_ndims [[buffer(8)]],
+ const constant int* x_shape [[buffer(9)]],
+ const constant int64_t* x_strides [[buffer(10)]],
+ const constant int& w_batch_ndims [[buffer(11)]],
+ const constant int* w_shape [[buffer(12)]],
+ const constant int64_t* w_strides [[buffer(13)]],
+ const constant int64_t* s_strides [[buffer(14)]],
+ const constant int64_t* b_strides [[buffer(15)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint lid [[thread_index_in_threadgroup]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ (void)lid;
+
+ constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+ threadgroup T Xs[BM * BK_padded];
+ threadgroup T Ws[BN * BK_padded];
+
+ if (batched) {
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ y,
+ M * N,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ }
+ qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
+ w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+ }
+
+ template <
+ typename T,
+ const int group_size,
+ const int bits,
+ const bool batched,
+ const int BM = 32,
+ const int BK = 32,
+ const int BN = 32>
+ [[kernel]] void affine_qmm_n(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ device T* y [[buffer(4)]],
+ const constant int& K [[buffer(5)]],
+ const constant int& N [[buffer(6)]],
+ const constant int& M [[buffer(7)]],
+ const constant int& x_batch_ndims [[buffer(8)]],
+ const constant int* x_shape [[buffer(9)]],
+ const constant int64_t* x_strides [[buffer(10)]],
+ const constant int& w_batch_ndims [[buffer(11)]],
+ const constant int* w_shape [[buffer(12)]],
+ const constant int64_t* w_strides [[buffer(13)]],
+ const constant int64_t* s_strides [[buffer(14)]],
+ const constant int64_t* b_strides [[buffer(15)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint lid [[thread_index_in_threadgroup]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ (void)lid;
+
+ constexpr int BK_padded = (BK + 16 / sizeof(T));
+ constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+ threadgroup T Xs[BM * BK_padded];
+ threadgroup T Ws[BK * BN_padded];
+
+ if (batched) {
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ y,
+ M * N,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ }
+
+ qmm_n_impl<T, group_size, bits, BM, BK, BN>(
+ w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+ }
+
+ template <typename T, int group_size, int bits>
+ [[kernel]] void affine_gather_qmv_fast(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ const device uint32_t* lhs_indices [[buffer(4)]],
+ const device uint32_t* rhs_indices [[buffer(5)]],
+ device T* y [[buffer(6)]],
+ const constant int& in_vec_size [[buffer(7)]],
+ const constant int& out_vec_size [[buffer(8)]],
+ const constant int& x_batch_ndims [[buffer(9)]],
+ const constant int* x_shape [[buffer(10)]],
+ const constant int64_t* x_strides [[buffer(11)]],
+ const constant int& w_batch_ndims [[buffer(12)]],
+ const constant int* w_shape [[buffer(13)]],
+ const constant int64_t* w_strides [[buffer(14)]],
+ const constant int64_t* s_strides [[buffer(15)]],
+ const constant int64_t* b_strides [[buffer(16)]],
+ const constant int& batch_ndims [[buffer(17)]],
+ const constant int* batch_shape [[buffer(18)]],
+ const constant int64_t* lhs_strides [[buffer(19)]],
+ const constant int64_t* rhs_strides [[buffer(20)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ int M = x_shape[x_batch_ndims];
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ lhs_indices,
+ rhs_indices,
+ y,
+ out_vec_size * M,
+ batch_ndims,
+ batch_shape,
+ lhs_strides,
+ rhs_strides,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ qmv_fast_impl<T, group_size, bits>(
+ w,
+ scales,
+ biases,
+ x,
+ y,
+ in_vec_size,
+ out_vec_size,
+ tid,
+ simd_gid,
+ simd_lid);
+ }
+
+ template <typename T, int group_size, int bits>
+ [[kernel]] void affine_gather_qmv(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ const device uint32_t* lhs_indices [[buffer(4)]],
+ const device uint32_t* rhs_indices [[buffer(5)]],
+ device T* y [[buffer(6)]],
+ const constant int& in_vec_size [[buffer(7)]],
+ const constant int& out_vec_size [[buffer(8)]],
+ const constant int& x_batch_ndims [[buffer(9)]],
+ const constant int* x_shape [[buffer(10)]],
+ const constant int64_t* x_strides [[buffer(11)]],
+ const constant int& w_batch_ndims [[buffer(12)]],
+ const constant int* w_shape [[buffer(13)]],
+ const constant int64_t* w_strides [[buffer(14)]],
+ const constant int64_t* s_strides [[buffer(15)]],
+ const constant int64_t* b_strides [[buffer(16)]],
+ const constant int& batch_ndims [[buffer(17)]],
+ const constant int* batch_shape [[buffer(18)]],
+ const constant int64_t* lhs_strides [[buffer(19)]],
+ const constant int64_t* rhs_strides [[buffer(20)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ int M = x_shape[x_batch_ndims];
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ lhs_indices,
+ rhs_indices,
+ y,
+ out_vec_size * M,
+ batch_ndims,
+ batch_shape,
+ lhs_strides,
+ rhs_strides,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ qmv_impl<T, group_size, bits>(
+ w,
+ scales,
+ biases,
+ x,
+ y,
+ in_vec_size,
+ out_vec_size,
+ tid,
+ simd_gid,
+ simd_lid);
+ }
+
+ template <typename T, int group_size, int bits>
+ [[kernel]] void affine_gather_qvm(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ const device uint32_t* lhs_indices [[buffer(4)]],
+ const device uint32_t* rhs_indices [[buffer(5)]],
+ device T* y [[buffer(6)]],
+ const constant int& in_vec_size [[buffer(7)]],
+ const constant int& out_vec_size [[buffer(8)]],
+ const constant int& x_batch_ndims [[buffer(9)]],
+ const constant int* x_shape [[buffer(10)]],
+ const constant int64_t* x_strides [[buffer(11)]],
+ const constant int& w_batch_ndims [[buffer(12)]],
+ const constant int* w_shape [[buffer(13)]],
+ const constant int64_t* w_strides [[buffer(14)]],
+ const constant int64_t* s_strides [[buffer(15)]],
+ const constant int64_t* b_strides [[buffer(16)]],
+ const constant int& batch_ndims [[buffer(17)]],
+ const constant int* batch_shape [[buffer(18)]],
+ const constant int64_t* lhs_strides [[buffer(19)]],
+ const constant int64_t* rhs_strides [[buffer(20)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ int M = x_shape[x_batch_ndims];
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ lhs_indices,
+ rhs_indices,
+ y,
+ out_vec_size * M,
+ batch_ndims,
+ batch_shape,
+ lhs_strides,
+ rhs_strides,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ qvm_impl<T, group_size, bits>(
+ w,
+ scales,
+ biases,
+ x,
+ y,
+ in_vec_size,
+ out_vec_size,
+ tid,
+ simd_gid,
+ simd_lid);
+ }
+
+ template <
+ typename T,
+ const int group_size,
+ const int bits,
+ const bool aligned_N,
+ const int BM = 32,
+ const int BK = 32,
+ const int BN = 32>
+ [[kernel]] void affine_gather_qmm_t(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ const device uint32_t* lhs_indices [[buffer(4)]],
+ const device uint32_t* rhs_indices [[buffer(5)]],
+ device T* y [[buffer(6)]],
+ const constant int& K [[buffer(7)]],
+ const constant int& N [[buffer(8)]],
+ const constant int& M [[buffer(9)]],
+ const constant int& x_batch_ndims [[buffer(10)]],
+ const constant int* x_shape [[buffer(11)]],
+ const constant int64_t* x_strides [[buffer(12)]],
+ const constant int& w_batch_ndims [[buffer(13)]],
+ const constant int* w_shape [[buffer(14)]],
+ const constant int64_t* w_strides [[buffer(15)]],
+ const constant int64_t* s_strides [[buffer(16)]],
+ const constant int64_t* b_strides [[buffer(17)]],
+ const constant int& batch_ndims [[buffer(18)]],
+ const constant int* batch_shape [[buffer(19)]],
+ const constant int64_t* lhs_strides [[buffer(20)]],
+ const constant int64_t* rhs_strides [[buffer(21)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint lid [[thread_index_in_threadgroup]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ (void)lid;
+
+ constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+ threadgroup T Xs[BM * BK_padded];
+ threadgroup T Ws[BN * BK_padded];
+
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ lhs_indices,
+ rhs_indices,
+ y,
+ M * N,
+ batch_ndims,
+ batch_shape,
+ lhs_strides,
+ rhs_strides,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
+ w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+ }
+
+ template <
+ typename T,
+ const int group_size,
+ const int bits,
+ const int BM = 32,
+ const int BK = 32,
+ const int BN = 32>
+ [[kernel]] void affine_gather_qmm_n(
+ const device uint32_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ const device T* x [[buffer(3)]],
+ const device uint32_t* lhs_indices [[buffer(4)]],
+ const device uint32_t* rhs_indices [[buffer(5)]],
+ device T* y [[buffer(6)]],
+ const constant int& K [[buffer(7)]],
+ const constant int& N [[buffer(8)]],
+ const constant int& M [[buffer(9)]],
+ const constant int& x_batch_ndims [[buffer(10)]],
+ const constant int* x_shape [[buffer(11)]],
+ const constant int64_t* x_strides [[buffer(12)]],
+ const constant int& w_batch_ndims [[buffer(13)]],
+ const constant int* w_shape [[buffer(14)]],
+ const constant int64_t* w_strides [[buffer(15)]],
+ const constant int64_t* s_strides [[buffer(16)]],
+ const constant int64_t* b_strides [[buffer(17)]],
+ const constant int& batch_ndims [[buffer(18)]],
+ const constant int* batch_shape [[buffer(19)]],
+ const constant int64_t* lhs_strides [[buffer(20)]],
+ const constant int64_t* rhs_strides [[buffer(21)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint lid [[thread_index_in_threadgroup]],
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
+ uint simd_lid [[thread_index_in_simdgroup]]) {
+ (void)lid;
+
+ constexpr int BK_padded = (BK + 16 / sizeof(T));
+ constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+ threadgroup T Xs[BM * BK_padded];
+ threadgroup T Ws[BK * BN_padded];
+
+ adjust_matrix_offsets<T>(
+ x,
+ w,
+ scales,
+ biases,
+ lhs_indices,
+ rhs_indices,
+ y,
+ M * N,
+ batch_ndims,
+ batch_shape,
+ lhs_strides,
+ rhs_strides,
+ x_batch_ndims,
+ x_shape,
+ x_strides,
+ w_batch_ndims,
+ w_shape,
+ w_strides,
+ s_strides,
+ b_strides,
+ tid);
+ qmm_n_impl<T, group_size, bits, BM, BK, BN>(
+ w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+ }
+
+ template <
+ typename T,
+ int group_size,
+ int bits,
+ int BM,
+ int BN,
+ int BK,
+ int WM,
+ int WN,
+ bool transpose>
+ [[kernel]] void affine_gather_qmm_rhs(
+ const device T* x [[buffer(0)]],
+ const device uint32_t* w [[buffer(1)]],
+ const device T* scales [[buffer(2)]],
+ const device T* biases [[buffer(3)]],
+ const device uint32_t* indices [[buffer(4)]],
+ device T* y [[buffer(5)]],
+ const constant int& M [[buffer(6)]],
+ const constant int& N [[buffer(7)]],
+ const constant int& K [[buffer(8)]],
+ uint3 tid [[threadgroup_position_in_grid]],
+ uint simd_group_id [[simdgroup_index_in_threadgroup]],
+ uint simd_lane_id [[thread_index_in_simdgroup]]) {
+ constexpr int pack_factor = get_pack_factor<bits, 8>();
+ constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
+ constexpr int BK_padded = (BK + 16 / sizeof(T));
+ constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+ using mma_t = mlx::steel::BlockMMA<
+ T,
+ T,
+ BM,
+ BN,
+ BK,
+ WM,
+ WN,
+ false,
+ transpose,
+ BK_padded,
+ transpose ? BK_padded : BN_padded>;
+ using loader_x_t =
+ mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
+ using loader_w_t = QuantizedBlockLoader<
+ T,
+ transpose ? BN : BK,
+ transpose ? BK : BN,
+ transpose ? BK_padded : BN_padded,
+ transpose,
+ WM * WN * SIMD_SIZE,
+ group_size,
+ bits>;
+
+ threadgroup T Xs[BM * BK_padded];
+ threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
+
+ // Compute the block
+ const int K_w = K * bytes_per_pack / pack_factor;
+ const int K_g = K / group_size;
+ const int N_w = N * bytes_per_pack / pack_factor;
+ const int N_g = N / group_size;
+ const int K_it = K / BK;
+ const size_t stride_w = transpose ? N * K_w : K * N_w;
+ const size_t stride_s = transpose ? N * K_g : K * N_g;
+ const int y_row = tid.y * BM;
+ const int y_col = tid.x * BN;
+ const size_t y_row_long = size_t(y_row);
+ const size_t y_col_long = size_t(y_col);
+
+ // Prepare threadgroup bounds
+ const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
+ const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
+
+ // Calculate the final tiles in the case that K is not aligned
+ const int k_remain = K - K_it * BK;
+ const short2 tile_x = short2(k_remain, tgp_bm);
+ const short2 tile_w =
+ transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+
+ // Move x and output to the correct block
+ auto wl = (const device uint8_t*)w;
+ x += y_row_long * K;
+ y += y_row_long * N + y_col_long;
+ wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
+ scales += transpose ? y_col_long * K_g : y_col / group_size;
+ biases += transpose ? y_col_long * K_g : y_col / group_size;
+
+ // Do as many matmuls as necessary
+ uint32_t index;
+ short offset;
+ uint32_t index_next = indices[y_row];
+ short offset_next = 0;
+ int n = 0;
+ while (n < tgp_bm) {
+ n++;
+ offset = offset_next;
+ index = index_next;
+ offset_next = tgp_bm;
+ for (; n < tgp_bm; n++) {
+ if (indices[y_row + n] != index) {
+ offset_next = n;
+ index_next = indices[y_row + n];
+ break;
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_none);
+
+ // Prepare threadgroup mma operation
+ thread mma_t mma_op(simd_group_id, simd_lane_id);
+
+ // Prepare threadgroup loading operations
+ thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
+ thread loader_w_t loader_w(
+ wl + index * stride_w,
+ scales + index * stride_s,
+ biases + index * stride_s,
+ transpose ? K : N,
+ Ws,
+ simd_group_id,
+ simd_lane_id);
+
+ // Matrices are all aligned check nothing
+ if (align_M && align_N) {
+ gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
+ if (!align_K) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+ }
+
+ // Store results to device memory
+ if (offset_next - offset == BM) {
+ mma_op.store_result(y, N);
+ } else {
+ mma_op.store_result_slice(
+ y, N, short2(0, offset), short2(BN, offset_next));
+ }
+ } else {
+ // Tile aligned so check outside of the hot loop
+ if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
+ gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
+ if (!align_K) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ gemm_loop_finalize(
+ Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+ }
+
+ // Store results to device memory
+ if (offset_next - offset == BM) {
+ mma_op.store_result(y, N);
+ } else {
+ mma_op.store_result_slice(
+ y, N, short2(0, offset), short2(BN, offset_next));
+ }
+ }
+
+ // Tile partially aligned check rows
+ else if (align_N || tgp_bn == BN) {
+ gemm_loop_unaligned<false, true, transpose>(
+ Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
+ if (!align_K) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ gemm_loop_finalize(
+ Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+ }
+ mma_op.store_result_slice(
+ y, N, short2(0, offset), short2(BN, offset_next));
+ }
+
+ // Tile partially aligned check cols
+ else if (align_M || tgp_bm == BM) {
+ gemm_loop_unaligned<true, false, transpose>(
+ Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
+ if (!align_K) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ gemm_loop_finalize(
+ Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+ }
+ mma_op.store_result_slice(
+ y, N, short2(0, offset), short2(tgp_bn, offset_next));
+ }
+
+ // Nothing aligned so check both rows and cols
+ else {
+ gemm_loop_unaligned<false, false, transpose>(
+ Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
+ if (!align_K) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ gemm_loop_finalize(
+ Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
+ }
+ mma_op.store_result_slice(
+ y, N, short2(0, offset), short2(tgp_bn, offset_next));
+ }
+ }
+ }
+ }
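affine_gather_qmm_rhs processes the rows of each BM tile in runs that share the same routing index, so one QuantizedBlockLoader setup (weights, scales and biases offset by index * stride) can serve a whole run before the next index's pointers are loaded. A standalone C++ sketch of that run detection, mirroring the while/for loop above (illustrative, not part of the diffed file):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Rows of one BM tile and the expert index each row routes to.
  std::vector<uint32_t> indices = {3, 3, 3, 1, 1, 7};
  const int tgp_bm = static_cast<int>(indices.size());

  // Each maximal run [offset, offset_next) of equal indices becomes one
  // matmul against that expert's quantized weights.
  int n = 0;
  while (n < tgp_bm) {
    const int offset = n;
    const uint32_t index = indices[n];
    while (n < tgp_bm && indices[n] == index) {
      n++;
    }
    std::printf("rows [%d, %d) use expert %u\n", offset, n, index);
  }
}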
+
+ template <typename T, const int group_size, const int bits>
+ [[kernel]] void affine_quantize(
+ const device T* w [[buffer(0)]],
+ device uint8_t* out [[buffer(1)]],
+ device T* scales [[buffer(2)]],
+ device T* biases [[buffer(3)]],
+ uint2 index [[thread_position_in_grid]],
+ uint2 grid_dim [[threads_per_grid]]) {
+ constexpr float eps = 1e-7;
+ constexpr int simd_size = 32;
+ constexpr float n_bins = (1 << bits) - 1;
+ constexpr int pack_factor = get_pack_factor<bits, 8>();
+ constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
+ constexpr int values_per_reduce = group_size / simd_size;
+ constexpr int writes_per_reduce = pack_factor / values_per_reduce;
+ constexpr int writes_per_pack =
+ writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
+ constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+
+ static_assert(
+ group_size % simd_size == 0,
+ "Group size must be divisible by simd size.");
+
+ size_t offset = index.x + grid_dim.x * size_t(index.y);
+ size_t in_index = offset * values_per_reduce;
+ size_t out_index = power_of_2_bits
+ ? offset * writes_per_pack
+ : offset * bytes_per_pack / writes_per_reduce;
+
+ float w_thread[values_per_reduce];
+ float w_min = Limits<T>::max;
+ float w_max = 0;
+
+ #pragma clang loop unroll(full)
+ for (int i = 0; i < values_per_reduce; i++) {
+ float val = w[in_index + i];
+ w_thread[i] = val;
+ w_min = min(w_min, val);
+ w_max = max(w_max, val);
+ }
+
+ w_min = simd_min(w_min);
+ w_max = simd_max(w_max);
+
+ float scale = max((w_max - w_min) / n_bins, eps);
+ bool side = abs(w_min) > abs(w_max);
+ scale = side ? scale : -scale;
+ float edge = side ? w_min : w_max;
+ float q0 = round(edge / scale);
+ bool at_zero = q0 == 0.0f;
+ scale = at_zero ? scale : edge / q0;
+ float bias = at_zero ? 0 : edge;
+
+ // Write out the scales and biases
+ size_t gindex = in_index / group_size;
+ if (in_index % group_size == 0) {
+ scales[gindex] = static_cast<T>(scale);
+ biases[gindex] = static_cast<T>(bias);
+ }
+
+ using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;
+ OutType output = 0;
+
+ #pragma clang loop unroll(full)
+ for (int i = 0; i < values_per_reduce; i++) {
+ uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
+ if (bits == 8) {
+ output = val;
+ } else {
+ output |= val << (bits * (i % pack_factor));
+ }
+
+ if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
+ out[out_index + i / pack_factor] = output;
+ output = 0;
+ } else {
+ #pragma clang loop unroll(full)
+ for (int j = 1; j < writes_per_reduce; j++) {
+ uint8_t sval = simd_shuffle_down(val, j);
+ output |= static_cast<OutType>(sval)
+ << (bits * (j * values_per_reduce + i));
+ }
+ }
+ }
+ if (bits == 3 || bits == 6) {
+ if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
+ out[out_index] = output & 0xff;
+ out[out_index + 1] = (output & 0xff00) >> 8;
+ out[out_index + 2] = (output & 0xff0000) >> 16;
+ }
+ } else if (bits == 5) {
+ if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
+ out[out_index] = output & 0xff;
+ out[out_index + 1] = (output & 0xff00) >> 8;
+ out[out_index + 2] = (output & 0xff0000) >> 16;
+ out[out_index + 3] = (output & 0xff000000) >> 24;
+ out[out_index + 4] = (output & 0xff00000000) >> 32;
+ }
+ } else {
+ if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
+ out[out_index / writes_per_reduce] = output;
+ }
+ }
+ }
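For reference, the per-group math in affine_quantize boils down to: take the group's min/max, derive a scale from the range, anchor the bias at the larger-magnitude edge while nudging the scale so that the edge sits on an integer code, and round (w - bias) / scale into [0, 2^bits - 1]. The scalar C++ sketch below restates just that math for one group; it is illustrative only, while the kernel performs the reduction with simd_min/simd_max and packs the resulting codes into bytes.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar restatement of the affine quantization of one group.
void quantize_group_sketch(const std::vector<float>& w, int bits) {
  const float eps = 1e-7f;
  const float n_bins = static_cast<float>((1 << bits) - 1);
  const float w_min = *std::min_element(w.begin(), w.end());
  const float w_max = *std::max_element(w.begin(), w.end());

  float scale = std::max((w_max - w_min) / n_bins, eps);
  const bool side = std::fabs(w_min) > std::fabs(w_max);
  scale = side ? scale : -scale;
  const float edge = side ? w_min : w_max;
  const float q0 = std::round(edge / scale);
  const bool at_zero = (q0 == 0.0f);
  scale = at_zero ? scale : edge / q0;
  const float bias = at_zero ? 0.0f : edge;

  for (float v : w) {
    // (v - bias) / scale is non-negative for in-range v by construction.
    const float q = std::min(std::round((v - bias) / scale), n_bins);
    std::printf("w=%+.3f  q=%2.0f  reconstructed=%+.3f\n", v, q, q * scale + bias);
  }
}

int main() {
  quantize_group_sketch({-0.9f, -0.1f, 0.2f, 0.6f}, 4);
}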
+
+ template <typename T, const int group_size, const int bits>
+ [[kernel]] void affine_dequantize(
+ const device uint8_t* w [[buffer(0)]],
+ const device T* scales [[buffer(1)]],
+ const device T* biases [[buffer(2)]],
+ device T* out [[buffer(3)]],
+ uint2 index [[thread_position_in_grid]],
+ uint2 grid_dim [[threads_per_grid]]) {
+ constexpr int pack_factor = get_pack_factor<bits, 8>();
+ constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
+
+ size_t offset = index.x + grid_dim.x * size_t(index.y);
+ size_t oindex = offset * pack_factor;
+ size_t gindex = oindex / group_size;
+ T scale = scales[gindex];
+ T bias = biases[gindex];
+
+ out += oindex;
+
+ if (bits == 3) {
+ w += offset * bytes_per_pack;
+ out[0] = (w[0] & 0x7) * scale + bias;
+ out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
+ out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
+ out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
+ out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
+ out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
+ out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
+ out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
+ } else if (bits == 5) {
+ w += offset * bytes_per_pack;
+ out[0] = (w[0] & 0x1f) * scale + bias;
+ out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
+ out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
+ out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
+ out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
+ out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
+ out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
+ out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
+ } else if (bits == 6) {
+ w += offset * bytes_per_pack;
+ out[0] = (w[0] & 0x3f) * scale + bias;
+ out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
+ out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
+ out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
+ } else {
+ uint val = w[offset];
+ #pragma clang loop unroll(full)
+ for (int i = 0; i < pack_factor; i++) {
+ uint8_t d;
+ if (bits == 2) {
+ d = (val >> (bits * i)) & 0x03;
+ } else if (bits == 4) {
+ d = (val >> (bits * i)) & 0x0f;
+ } else if (bits == 8) {
+ d = val;
+ }
+ out[i] = scale * d + bias;
+ }
+ }
+ }
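Dequantization is simply the inverse affine map out = q * scale + bias applied with the per-group scale and bias; the per-bit-width branches above differ only in how the packed codes are unpacked from bytes. A small C++ sketch of the 4-bit case (two codes per byte, low nibble first, matching the (val >> (bits * i)) & 0x0f loop); illustrative only, with made-up scale and bias values.

#include <cstdint>
#include <cstdio>

int main() {
  // One packed byte holding two 4-bit codes (low nibble first) and the
  // scale/bias of the group those codes belong to.
  const uint8_t packed = 0xb4;  // codes 4 (low nibble) and 11 (high nibble)
  const float scale = 0.05f;
  const float bias = -0.3f;

  for (int i = 0; i < 2; ++i) {
    const uint8_t d = (packed >> (4 * i)) & 0x0f;
    std::printf("q=%u -> w=%+.3f\n", static_cast<unsigned>(d), d * scale + bias);
  }
}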