cui-llama.rn 1.3.0 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. package/android/src/main/CMakeLists.txt +6 -1
  2. package/android/src/main/jni.cpp +6 -6
  3. package/cpp/amx/amx.cpp +196 -0
  4. package/cpp/amx/amx.h +20 -0
  5. package/cpp/amx/common.h +101 -0
  6. package/cpp/amx/mmq.cpp +2524 -0
  7. package/cpp/amx/mmq.h +16 -0
  8. package/cpp/common.cpp +1981 -1682
  9. package/cpp/common.h +636 -600
  10. package/cpp/ggml-aarch64.c +129 -129
  11. package/cpp/ggml-aarch64.h +19 -19
  12. package/cpp/ggml-alloc.c +1038 -1040
  13. package/cpp/ggml-alloc.h +76 -76
  14. package/cpp/ggml-backend-impl.h +238 -216
  15. package/cpp/ggml-backend-reg.cpp +423 -195
  16. package/cpp/ggml-backend.cpp +1999 -1997
  17. package/cpp/ggml-backend.h +351 -328
  18. package/cpp/ggml-common.h +1859 -1853
  19. package/cpp/ggml-cpp.h +38 -38
  20. package/cpp/ggml-cpu-aarch64.c +3823 -3560
  21. package/cpp/ggml-cpu-aarch64.h +32 -30
  22. package/cpp/ggml-cpu-impl.h +386 -371
  23. package/cpp/ggml-cpu-quants.c +10835 -10822
  24. package/cpp/ggml-cpu-quants.h +63 -63
  25. package/cpp/ggml-cpu.c +99 -103
  26. package/cpp/ggml-cpu.cpp +69 -17
  27. package/cpp/ggml-cpu.h +152 -177
  28. package/cpp/ggml-impl.h +556 -550
  29. package/cpp/ggml-metal.h +66 -66
  30. package/cpp/ggml-metal.m +4426 -4294
  31. package/cpp/ggml-quants.c +5247 -5247
  32. package/cpp/ggml-quants.h +100 -100
  33. package/cpp/ggml-threading.cpp +12 -12
  34. package/cpp/ggml-threading.h +12 -12
  35. package/cpp/ggml.c +7618 -8180
  36. package/cpp/ggml.h +2255 -2411
  37. package/cpp/json-schema-to-grammar.cpp +1045 -0
  38. package/cpp/json-schema-to-grammar.h +8 -0
  39. package/cpp/json.hpp +24766 -0
  40. package/cpp/llama-grammar.cpp +1138 -1138
  41. package/cpp/llama-grammar.h +144 -144
  42. package/cpp/llama-impl.h +181 -181
  43. package/cpp/llama-sampling.cpp +2348 -2348
  44. package/cpp/llama-sampling.h +48 -48
  45. package/cpp/llama-vocab.cpp +1984 -1984
  46. package/cpp/llama-vocab.h +170 -170
  47. package/cpp/llama.cpp +22332 -22132
  48. package/cpp/llama.h +1259 -1253
  49. package/cpp/log.cpp +401 -401
  50. package/cpp/log.h +121 -121
  51. package/cpp/rn-llama.hpp +6 -6
  52. package/cpp/sampling.cpp +505 -466
  53. package/cpp/sampling.h +22 -1
  54. package/cpp/sgemm.cpp +1884 -1884
  55. package/cpp/speculative.cpp +270 -0
  56. package/cpp/speculative.h +28 -0
  57. package/cpp/unicode.cpp +11 -0
  58. package/ios/RNLlamaContext.mm +13 -0
  59. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  60. package/lib/commonjs/grammar.js +4 -2
  61. package/lib/commonjs/grammar.js.map +1 -1
  62. package/lib/commonjs/index.js.map +1 -1
  63. package/lib/module/NativeRNLlama.js.map +1 -1
  64. package/lib/module/grammar.js +2 -1
  65. package/lib/module/grammar.js.map +1 -1
  66. package/lib/module/index.js.map +1 -1
  67. package/lib/typescript/NativeRNLlama.d.ts +94 -4
  68. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  69. package/lib/typescript/grammar.d.ts +5 -6
  70. package/lib/typescript/grammar.d.ts.map +1 -1
  71. package/lib/typescript/index.d.ts +4 -2
  72. package/lib/typescript/index.d.ts.map +1 -1
  73. package/package.json +2 -1
  74. package/src/NativeRNLlama.ts +97 -10
  75. package/src/grammar.ts +10 -8
  76. package/src/index.ts +22 -1
package/cpp/amx/mmq.cpp
@@ -0,0 +1,2524 @@
1
+
2
+ #if defined(__GNUC__)
3
+ #pragma GCC diagnostic ignored "-Wpedantic"
4
+ #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
5
+ #endif
6
+
7
+ #include "amx.h"
8
+ #include "mmq.h"
9
+ #include "ggml-impl.h"
10
+ #include "ggml-cpu-impl.h"
11
+ #include "ggml-cpu-quants.h"
12
+ #include "ggml-quants.h"
13
+ #include <algorithm>
14
+ #include <type_traits>
15
+
16
+ #if defined(__gnu_linux__)
17
+ #include <sys/syscall.h>
18
+ #include <unistd.h>
19
+ #endif
20
+
21
+ #if defined(_OPENMP)
22
+ #include <omp.h>
23
+ #endif
24
+
25
+ #if (defined(_WIN32) || defined(_WIN64))
26
+ #define RESTRICT __restrict
27
+ #else
28
+ #define RESTRICT __restrict__
29
+ #endif
30
+
31
+ #if (defined(_WIN32) || defined(_WIN64))
32
+ #define ALWAYS_INLINE __forceinline
33
+ #elif __has_attribute(always_inline) || defined(__GNUC__)
34
+ #define ALWAYS_INLINE __attribute__((__always_inline__)) inline
35
+ #else
36
+ #define ALWAYS_INLINE inline
37
+ #endif
38
+
39
+ #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
40
+
41
+ namespace {
42
+
43
+ // Forced unrolling
44
+ template <int n>
45
+ struct Unroll {
46
+ template <typename Func, typename... Args>
47
+ ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
48
+ Unroll<n - 1>{}(f, args...);
49
+ f(std::integral_constant<int, n - 1>{}, args...);
50
+ }
51
+ };
52
+
53
+ template <>
54
+ struct Unroll<1> {
55
+ template <typename Func, typename... Args>
56
+ ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
57
+ f(std::integral_constant<int, 0>{}, args...);
58
+ }
59
+ };
60
+
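The `Unroll` helper above expands a functor over the compile-time indices 0..n-1, so short fixed-count loops are fully unrolled. A minimal usage sketch, assuming the `Unroll` template above is in scope (`demo_unroll` and its toy array are hypothetical, not part of the released file):

    // assumes the Unroll template defined above is visible here
    void demo_unroll() {
        float acc[4] = {0.f, 1.f, 2.f, 3.f};
        // invoked as body(0, 2.0f), body(1, 2.0f), body(2, 2.0f), body(3, 2.0f);
        // the index arrives as std::integral_constant<int, i> and converts to int
        auto body = [&](int i, float scale) {
            acc[i] *= scale;
        };
        Unroll<4>{}(body, 2.0f);
    }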
61
+ // type traits
62
+ template <typename T> struct PackedTypes {};
63
+ template <> struct PackedTypes<block_q4_0> { using type = int8_t; };
64
+ template <> struct PackedTypes<block_q4_1> { using type = uint8_t; };
65
+ template <> struct PackedTypes<block_q8_0> { using type = int8_t; };
66
+ template <typename T> using packed_B_type = typename PackedTypes<T>::type;
67
+
68
+ template <typename T>
69
+ struct do_compensate : std::integral_constant<bool,
70
+ std::is_same<T, block_q8_0>::value> {};
71
+
72
+ template <typename T>
73
+ struct do_unpack : std::integral_constant<bool,
74
+ std::is_same<T, block_q4_0>::value ||
75
+ std::is_same<T, block_q4_1>::value> {};
76
+
77
+ template <typename T>
78
+ struct is_type_qkk : std::integral_constant<bool,
79
+ std::is_same<T, block_q4_K>::value ||
80
+ std::is_same<T, block_q5_K>::value ||
81
+ std::is_same<T, block_q6_K>::value ||
82
+ std::is_same<T, block_iq4_xs>::value> {};
83
+
84
+ #define LM_GGML_DISPATCH_FLOATING_TYPES(TYPE, ...) \
85
+ [&] { \
86
+ switch (TYPE) { \
87
+ case LM_GGML_TYPE_F16: { \
88
+ using type = lm_ggml_fp16_t; \
89
+ constexpr int blck_size = 16; \
90
+ return __VA_ARGS__(); \
91
+ } \
92
+ case LM_GGML_TYPE_BF16: { \
93
+ using type = lm_ggml_bf16_t; \
94
+ constexpr int blck_size = 32; \
95
+ return __VA_ARGS__(); \
96
+ } \
97
+ default: \
98
+ fprintf(stderr, "Unsupported floating data type\n"); \
99
+ } \
100
+ }()
101
+
102
+ #define LM_GGML_DISPATCH_QTYPES(QT, ...) \
103
+ [&] { \
104
+ switch (QT) { \
105
+ case LM_GGML_TYPE_Q4_0: { \
106
+ using type = block_q4_0; \
107
+ using vec_dot_type = block_q8_0; \
108
+ constexpr int blck_size = QK4_0; \
109
+ return __VA_ARGS__(); \
110
+ } \
111
+ case LM_GGML_TYPE_Q4_1: { \
112
+ using type = block_q4_1; \
113
+ using vec_dot_type = block_q8_1; \
114
+ constexpr int blck_size = QK4_1; \
115
+ return __VA_ARGS__(); \
116
+ } \
117
+ case LM_GGML_TYPE_Q8_0: { \
118
+ using type = block_q8_0; \
119
+ using vec_dot_type = block_q8_0; \
120
+ constexpr int blck_size = QK8_0; \
121
+ return __VA_ARGS__(); \
122
+ } \
123
+ case LM_GGML_TYPE_Q4_K: { \
124
+ using type = block_q4_K; \
125
+ using vec_dot_type = block_q8_K; \
126
+ constexpr int blck_size = QK_K; \
127
+ return __VA_ARGS__(); \
128
+ } \
129
+ case LM_GGML_TYPE_Q5_K: { \
130
+ using type = block_q5_K; \
131
+ using vec_dot_type = block_q8_K; \
132
+ constexpr int blck_size = QK_K; \
133
+ return __VA_ARGS__(); \
134
+ } \
135
+ case LM_GGML_TYPE_Q6_K: { \
136
+ using type = block_q6_K; \
137
+ using vec_dot_type = block_q8_K; \
138
+ constexpr int blck_size = QK_K; \
139
+ return __VA_ARGS__(); \
140
+ } \
141
+ case LM_GGML_TYPE_IQ4_XS: { \
142
+ using type = block_iq4_xs; \
143
+ using vec_dot_type = block_q8_K; \
144
+ constexpr int blck_size = QK_K; \
145
+ return __VA_ARGS__(); \
146
+ } \
147
+ default: \
148
+ fprintf(stderr, "Unsupported quantized data type: %d\n", int(TYPE)); \
149
+ } \
150
+ }()
151
+
152
+ #define LM_GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
153
+ [&] { \
154
+ if (BOOL_V) { \
155
+ constexpr bool BOOL_NAME = true; \
156
+ return __VA_ARGS__(); \
157
+ } else { \
158
+ constexpr bool BOOL_NAME = false; \
159
+ return __VA_ARGS__(); \
160
+ } \
161
+ }()
162
+
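The dispatch macros above switch on a runtime type and invoke the trailing lambda with `type`, `vec_dot_type`, and `blck_size` bound to compile-time values for that case. A usage sketch, not part of the released file (`example_row_bytes` is hypothetical); note that the `default:` branch of `LM_GGML_DISPATCH_QTYPES` prints `int(TYPE)`, so the dispatched variable at the call site is expected to be named `TYPE`:

    // bytes needed to store one row of K elements of a quantized type
    size_t example_row_bytes(enum lm_ggml_type TYPE, int K) {
        size_t row_bytes = 0;
        LM_GGML_DISPATCH_QTYPES(TYPE, [&] {
            // here `type` and `blck_size` are compile-time constants per case
            row_bytes = (size_t)(K / blck_size) * sizeof(type);
        });
        return row_bytes;
    }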
163
+ // define amx tile config data structure
164
+ struct tile_config_t{
165
+ uint8_t palette_id = 0;
166
+ uint8_t start_row = 0;
167
+ uint8_t reserved_0[14] = {0};
168
+ uint16_t colsb[16] = {0};
169
+ uint8_t rows[16] = {0};
170
+ };
171
+
172
+ // Notes: amx tile config
173
+ //
174
+ // Typically, TMUL takes A and B tiles of size 16 x 64 containing INT8 values,
175
+ // and accumulates the result into a 16 x 16 matrix C containing INT32 values.
176
+ //
177
+ // Since many GGUF quantized types have a `block_size` of 32, a 16-16-32 config is used
178
+ // instead of the normally used 16-16-64 config.
179
+ //
180
+ // Block A: {16, 32}, dtype = int8_t
181
+ // Block B: {16, 32}, dtype = uint8_t/int8_t
182
+ // Block C: {16, 16}, dtype = int32_t
183
+ //
184
+ // Block B needs to be prepacked to vnni format before feeding into TMUL:
185
+ // packed_B: from {n, k} to {k/vnni_blk, n, vnni_blk}; viewed in 2d, we get {8, 64}
186
+ //
187
+ // Therefore, we get tileconfig:
188
+ // A B C
189
+ // rows 16 8 16
190
+ // colsb 32 64 16
191
+ //
192
+ // For tile distribution, a 2-2-4 pattern is followed, e.g. A uses TMM2-TMM3, B uses TMM0-TMM1,
193
+ // and C uses TMM4-TMM7:
194
+ // B TMM0 B TMM1
195
+ // A TMM2 C TMM4 C TMM6
196
+ // A TMM3 C TMM5 C TMM7
197
+ //
198
+ // Each `amx` kernel handles 4 blocks at a time (2MB * 2NB); when m < 2 * BLOCK_M, unpacking A
199
+ // will be needed.
200
+ //
201
+ // Another commonly used pattern, 1-3-3, is skipped here, as it is mostly used when m <= 16,
202
+ // and the single-batch gemm (m=1) has a special fast path with `avx512-vnni`.
203
+ //
204
+ // ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/
205
+ // advanced-matrix-extensions-intrinsics-functions.html
206
+ //
207
+
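A scalar reference for the tile shapes described in the notes above, not part of the released file: A is a {16, 32} int8 tile, B is packed to the VNNI layout {k/vnni_blk, n, vnni_blk} = {8, 16, 4} (viewed in 2d as {8, 64}), and C accumulates a {16, 16} int32 tile:

    #include <cstdint>

    // C[m][n] += sum_k A[m][k] * B[k][n], with B stored as B_vnni[k/4][n][k%4]
    void ref_tile_mul_16x16x32(const int8_t A[16][32],
                               const int8_t B_vnni[8][16][4],
                               int32_t C[16][16]) {
        for (int m = 0; m < 16; ++m) {
            for (int n = 0; n < 16; ++n) {
                int32_t acc = 0;
                for (int kb = 0; kb < 8; ++kb) {     // one vnni row = 4 consecutive k values
                    for (int v = 0; v < 4; ++v) {
                        acc += (int32_t)A[m][kb * 4 + v] * (int32_t)B_vnni[kb][n][v];
                    }
                }
                C[m][n] += acc;
            }
        }
    }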
208
+ #define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
209
+ void lm_ggml_tile_config_init(void) {
210
+ static thread_local bool is_first_time = true;
211
+
212
+ if (!is_first_time) {
213
+ return;
214
+ }
215
+
216
+ static thread_local tile_config_t tc;
217
+ tile_config_t current_tc;
218
+ _tile_storeconfig(&current_tc);
219
+
220
+ // load only when config changes
221
+ if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
222
+ memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
223
+ tc.palette_id = 1;
224
+ tc.start_row = 0;
225
+ TC_CONFIG_TILE(TMM0, 8, 64);
226
+ TC_CONFIG_TILE(TMM1, 8, 64);
227
+ TC_CONFIG_TILE(TMM2, 16, 32);
228
+ TC_CONFIG_TILE(TMM3, 16, 32);
229
+ TC_CONFIG_TILE(TMM4, 16, 64);
230
+ TC_CONFIG_TILE(TMM5, 16, 64);
231
+ TC_CONFIG_TILE(TMM6, 16, 64);
232
+ TC_CONFIG_TILE(TMM7, 16, 64);
233
+ _tile_loadconfig(&tc);
234
+ }
235
+
236
+ is_first_time = false;
237
+ }
238
+
239
+ // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
240
+ // See the notes `s8s8 igemm compensation in avx512-vnni` for detail.
241
+ template <typename TB>
242
+ int get_tile_size() {
243
+ int tile_size = TILE_N * sizeof(TB);
244
+ if (do_compensate<TB>::value) {
245
+ tile_size += TILE_N * sizeof(int32_t);
246
+ }
247
+ if (std::is_same<TB, block_q4_K>::value ||
248
+ std::is_same<TB, block_q5_K>::value) {
249
+ tile_size += TILE_N * 4;
250
+ }
251
+ if (std::is_same<TB, block_iq4_xs>::value) {
252
+ tile_size += TILE_N * 2;
253
+ }
254
+ return tile_size;
255
+ }
256
+
257
+ template <typename TB, int BLOCK_K>
258
+ int get_row_size(int K) {
259
+ int KB = K / BLOCK_K;
260
+ int row_size = KB * sizeof(TB);
261
+ if (do_compensate<TB>::value) {
262
+ row_size += KB * sizeof(int32_t);
263
+ }
264
+ if (std::is_same<TB, block_q4_K>::value ||
265
+ std::is_same<TB, block_q5_K>::value) {
266
+ row_size += KB * 4;
267
+ }
268
+ if (std::is_same<TB, block_iq4_xs>::value) {
269
+ row_size += KB * 2;
270
+ }
271
+ return row_size;
272
+ }
273
+
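A worked example for the two helpers above, assuming TILE_N = 16 and QK8_0 = 32 as defined elsewhere in the package: for `TB = block_q8_0`, `sizeof(block_q8_0)` is 2 (d) + 32 (qs) = 34 bytes, so `get_tile_size<block_q8_0>()` is 16 * 34 = 544 bytes for the packed quants and their d scales plus 16 * 4 = 64 bytes of s8s8 compensation, 608 bytes in total; `get_row_size<block_q8_0, QK8_0>(K)` is correspondingly (K / 32) * (34 + 4) bytes.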
274
+ // vectorized dtype conversion
275
+ inline float FP16_TO_FP32(lm_ggml_half val) {
276
+ __m256i v = _mm256_setr_epi16(
277
+ val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
278
+ __m512 o = _mm512_cvtph_ps(v);
279
+ return _mm512_cvtss_f32(o);
280
+ }
281
+
282
+ inline __m512 FP16_TO_FP32_VEC(lm_ggml_half val) {
283
+ __m256i v = _mm256_set1_epi16(val);
284
+ return _mm512_cvtph_ps(v);
285
+ }
286
+
287
+ // horizontal reduce
288
+ inline float _mm512_reduce_max_ps(const __m512 x) {
289
+ __m512 v = x;
290
+ __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
291
+ v = _mm512_max_ps(v, v1);
292
+ v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
293
+ v = _mm512_max_ps(v, v1);
294
+ v1 = _mm512_shuffle_ps(v, v, 0x4E);
295
+ v = _mm512_max_ps(v, v1);
296
+ v1 = _mm512_shuffle_ps(v, v, 0xB1);
297
+ v = _mm512_max_ps(v, v1);
298
+ return _mm512_cvtss_f32(v);
299
+ }
300
+
301
+ // transpose utils
302
+ #define SHUFFLE_EPI32(a, b, mask) \
303
+ _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
304
+ inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) {
305
+ // unpacking and 32-bit elements
306
+ v1[0] = _mm256_unpacklo_epi32(v[0], v[1]);
307
+ v1[1] = _mm256_unpackhi_epi32(v[0], v[1]);
308
+ v1[2] = _mm256_unpacklo_epi32(v[2], v[3]);
309
+ v1[3] = _mm256_unpackhi_epi32(v[2], v[3]);
310
+ v1[4] = _mm256_unpacklo_epi32(v[4], v[5]);
311
+ v1[5] = _mm256_unpackhi_epi32(v[4], v[5]);
312
+ v1[6] = _mm256_unpacklo_epi32(v[6], v[7]);
313
+ v1[7] = _mm256_unpackhi_epi32(v[6], v[7]);
314
+
315
+ // shuffling the 32-bit elements
316
+ v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44);
317
+ v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee);
318
+ v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44);
319
+ v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee);
320
+ v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44);
321
+ v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee);
322
+ v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44);
323
+ v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee);
324
+
325
+ // shuffling 128-bit elements
326
+ v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02);
327
+ v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02);
328
+ v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02);
329
+ v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02);
330
+ v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13);
331
+ v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13);
332
+ v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13);
333
+ v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13);
334
+ }
335
+
336
+ inline void transpose_16x4_32bit(__m512i * r, __m512i * d) {
337
+
338
+ static const __m512i index1 = _mm512_set_epi32(
339
+ 0x0f, 0x0b, 0x07, 0x03,
340
+ 0x0e, 0x0a, 0x06, 0x02,
341
+ 0x0d, 0x09, 0x05, 0x01,
342
+ 0x0c, 0x08, 0x04, 0x00);
343
+
344
+ d[0] = _mm512_permutexvar_epi32(index1, r[0]);
345
+ d[1] = _mm512_permutexvar_epi32(index1, r[1]);
346
+ d[2] = _mm512_permutexvar_epi32(index1, r[2]);
347
+ d[3] = _mm512_permutexvar_epi32(index1, r[3]);
348
+
349
+ r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
350
+ r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
351
+ r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);
352
+ r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);
353
+
354
+ d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);
355
+ d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);
356
+ d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);
357
+ d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);
358
+ }
359
+
360
+ inline void transpose_16x16_32bit(__m512i * v) {
361
+ __m512i v1[16];
362
+ v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
363
+ v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
364
+ v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
365
+ v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
366
+ v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
367
+ v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
368
+ v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
369
+ v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
370
+ v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
371
+ v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
372
+ v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
373
+ v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
374
+ v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
375
+ v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
376
+ v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
377
+ v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
378
+
379
+ v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
380
+ v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
381
+ v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
382
+ v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
383
+ v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
384
+ v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
385
+ v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
386
+ v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
387
+ v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
388
+ v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
389
+ v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
390
+ v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
391
+ v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
392
+ v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
393
+ v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
394
+ v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
395
+
396
+ v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
397
+ v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
398
+ v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
399
+ v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
400
+ v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
401
+ v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
402
+ v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
403
+ v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
404
+ v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
405
+ v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
406
+ v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
407
+ v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
408
+ v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
409
+ v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
410
+ v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
411
+ v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
412
+
413
+ v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
414
+ v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
415
+ v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
416
+ v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
417
+ v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
418
+ v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
419
+ v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
420
+ v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
421
+ v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
422
+ v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
423
+ v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
424
+ v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
425
+ v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
426
+ v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
427
+ v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
428
+ v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
429
+ }
430
+
431
+ void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) {
432
+ assert(k % QK_K == 0);
433
+ const int KB = k / QK_K;
434
+ constexpr int kVecs = QK_K / 16;
435
+
436
+ block_q8_K * y = reinterpret_cast<block_q8_K *>(vy);
437
+
438
+ // hold 16 float vecs from x
439
+ __m512 v[kVecs];
440
+
441
+ // hold the quants vecs
442
+ __m512i vq[kVecs / 4];
443
+
444
+ // hold the packed quants vecs
445
+ __m512i vq_packed[kVecs / 4];
446
+
447
+ const __m512 signBit = _mm512_set1_ps(-0.f);
448
+
449
+ for (int i = 0; i < KB; ++i) {
450
+ // Compute max(abs(e)) for the block
451
+ __m512 vamax = _mm512_set1_ps(0.f);
452
+ for (int j = 0; j < kVecs; ++j) {
453
+ v[j] = _mm512_loadu_ps(x); x += 16;
454
+ vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j]));
455
+ }
456
+ const float amax = _mm512_reduce_max_ps(vamax);
457
+
458
+ // Quantize these floats
459
+ const float iscale = 127.f / amax;
460
+ y[i].d = LM_GGML_FP32_TO_FP16(1 / iscale);
461
+ const float id = ( amax != 0.0f ) ? iscale : 0.f;
462
+ const __m512 vscale = _mm512_set1_ps(id);
463
+
464
+ // Apply multiplier and round to nearest integer
465
+ for (int j = 0; j < kVecs; ++j) {
466
+ v[j] = _mm512_mul_ps(v[j], vscale);
467
+ v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
468
+ }
469
+
470
+ // Pack to epi8 vecs
471
+ for (int j = 0; j < kVecs / 4; ++j) {
472
+ __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0]));
473
+ __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1]));
474
+ __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2]));
475
+ __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3]));
476
+
477
+ __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1);
478
+ __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1);
479
+
480
+ vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1);
481
+ _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]);
482
+ }
483
+
484
+ // Compute the bsums with vnni
485
+ transpose_16x4_32bit(vq, vq_packed);
486
+
487
+ const __m512i one = _mm512_set1_epi8(1);
488
+ __m512i sum = _mm512_setzero_si512();
489
+ for (int k = 0; k < 4; ++k) {
490
+ sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]);
491
+ }
492
+ _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum));
493
+ }
494
+ }
495
+
496
+ // quantize A from float to `vec_dot_type`
497
+ template <typename T>
498
+ inline void from_float(const float * x, char * vy, int64_t k);
499
+
500
+ template <>
501
+ inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) {
502
+ quantize_row_q8_0(x, (block_q8_0 *)vy, k);
503
+ }
504
+
505
+ template <>
506
+ inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) {
507
+ quantize_row_q8_1(x, (block_q8_1 *)vy, k);
508
+ }
509
+
510
+ template <>
511
+ inline void from_float<block_q8_K>(const float * x, char * vy, int64_t k) {
512
+ #if 1
513
+ // TODO: this is reference impl!
514
+ quantize_row_q8_K_ref(x, (block_q8_K *)vy, k);
515
+ #else
516
+ quantize_row_q8_K_vnni(x, vy, k);
517
+ #endif
518
+ }
519
+
520
+ // load A from memory into the array when nrows cannot fill a whole tile
521
+ void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) {
522
+ assert(nr != TILE_M);
523
+ for (int m = 0; m < nr; ++m) {
524
+ const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
525
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
526
+ }
527
+ }
528
+
529
+ void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) {
530
+ assert(nr != TILE_M);
531
+ for (int m = 0; m < nr; ++m) {
532
+ const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
533
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
534
+ }
535
+ }
536
+
537
+ template <typename TB>
538
+ void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
539
+ assert(nr <= TILE_M);
540
+ for (int m = 0; m < nr; ++m) {
541
+ const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32));
542
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
543
+ }
544
+ }
545
+
546
+ template <>
547
+ void unpack_A<block_q6_K>(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
548
+ assert(nr <= TILE_M);
549
+ // zero-pad k from 16 to 32 so that we don't have to re-configure amx
550
+ const __m128i zero = _mm_setzero_si128();
551
+ for (int m = 0; m < nr; ++m) {
552
+ const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16));
553
+ const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1);
554
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r);
555
+ }
556
+ }
557
+
558
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
559
+ inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
560
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
561
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
562
+ const __m256i lowMask = _mm256_set1_epi8(0xF);
563
+ return _mm256_and_si256(lowMask, bytes);
564
+ }
565
+
566
+ // used for block_q4_K
567
+ inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) {
568
+ const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi);
569
+ const __m256i lowMask = _mm256_set1_epi8(0xF);
570
+ const __m256i q4l = _mm256_and_si256(tmp, lowMask);
571
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask);
572
+ return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1);
573
+ }
574
+
575
+ // used for block_q5_K
576
+ inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) {
577
+ const __m256i lowMask = _mm256_set1_epi8(0xF);
578
+ __m256i hmask = _mm256_set1_epi8(1);
579
+ hmask = _mm256_slli_epi16(hmask, k);
580
+
581
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs);
582
+ const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh);
583
+
584
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask);
585
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4);
586
+ const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
587
+ hmask = _mm256_slli_epi16(hmask, 1);
588
+
589
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask);
590
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4);
591
+ const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
592
+
593
+ return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1);
594
+ }
595
+
596
+ // used for block_q6_K
597
+ inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) {
598
+ const __m256i m4 = _mm256_set1_epi8(0xF);
599
+ const __m256i m2 = _mm256_set1_epi8(0x3);
600
+
601
+ const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs);
602
+ const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32));
603
+ const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh);
604
+
605
+ const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256( q6bitsH, m2), 4);
606
+ const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4);
607
+ const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4);
608
+ const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4);
609
+
610
+ const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0);
611
+ const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1);
612
+ const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2);
613
+ const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3);
614
+
615
+ r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1);
616
+ r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1);
617
+ }
618
+
619
+ inline __m512i packNibbles(__m512i r0, __m512i r1) {
620
+ return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4));
621
+ }
622
+
623
+ template <typename TB>
624
+ inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) {
625
+ int8_t tmp[8 * 64];
626
+ __m256i v[8], v2[8];
627
+ for (int n = 0; n < 8; ++n) {
628
+ v[n] = bytes_from_nibbles_32(B[n * KB].qs);
629
+ }
630
+ transpose_8x8_32bit(v, v2);
631
+ for (int n = 0; n < 8; ++n) {
632
+ _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]);
633
+ }
634
+ for (int n = 0; n < 8; ++n) {
635
+ v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs);
636
+ }
637
+ transpose_8x8_32bit(v, v2);
638
+ for (int n = 0; n < 8; ++n) {
639
+ _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]);
640
+ }
641
+
642
+ // pack again with 128 to fully utilize vector length
643
+ for (int n = 0; n < 8; n += 2) {
644
+ __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64));
645
+ __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64));
646
+ __m512i r1r0 = packNibbles(r0, r1);
647
+ _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0);
648
+ }
649
+ }
650
+
651
+ template <>
652
+ inline void pack_qs<block_q8_0>(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
653
+ __m256i v[8], v2[8];
654
+ for (int n = 0; n < 8; ++n) {
655
+ v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs));
656
+ }
657
+ transpose_8x8_32bit(v, v2);
658
+ for (int n = 0; n < 8; ++n) {
659
+ _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]);
660
+ }
661
+ for (int n = 0; n < 8; ++n) {
662
+ v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs));
663
+ }
664
+ transpose_8x8_32bit(v, v2);
665
+ for (int n = 0; n < 8; ++n) {
666
+ _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]);
667
+ }
668
+ }
669
+
670
+ template <>
671
+ inline void pack_qs<block_q4_K>(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
672
+ __m512i v[16];
673
+ // QK_K 256 with 8 groups, handle 2 groups at a time
674
+ char * pb = (char *)packed_B;
675
+ for (int k = 0; k < QK_K / 64; ++k) {
676
+ // pack 2 groups { n, g, k} to {g, k/4, 4n}
677
+ // e.g. {16, 2, 32} to {2, 8, 64}
678
+ for (int n = 0; n < TILE_N; ++n) {
679
+ v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32);
680
+ }
681
+
682
+ transpose_16x16_32bit(v);
683
+
684
+ // pack again with 128 to fully utilize vector length
685
+ for (int n = 0; n < TILE_N; n += 2) {
686
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
687
+ pb += 64;
688
+ }
689
+ }
690
+ }
691
+
692
+ template <>
693
+ inline void pack_qs<block_q5_K>(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
694
+ __m512i v[16];
695
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
696
+ // QK_K 256 with 8 groups, handle 2 groups at a time
697
+ char * pb = (char *)packed_B;
698
+ char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
699
+ for (int k = 0; k < QK_K / 64; ++k) {
700
+ // pack 2 groups { n, g, k} to {g, k/4, 4n}
701
+ // e.g. {16, 2, 32} to {2, 8, 64}
702
+ for (int n = 0; n < TILE_N; ++n) {
703
+ v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k);
704
+ }
705
+
706
+ transpose_16x16_32bit(v);
707
+
708
+ // 1. pack lower 4bits with 2 groups
709
+ for (int n = 0; n < TILE_N; n += 2) {
710
+ // get lower 4 bits
711
+ const __m512i r0 = _mm512_and_si512(v[n], lowMask);
712
+ const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
713
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
714
+ }
715
+
716
+ // 2. pack higher 1bit with 2 groups
717
+ const __m512i hmask = _mm512_set1_epi8(0x10);
718
+ for (int g = 0; g < 2; ++g) {
719
+ __m512i hbits = _mm512_setzero_si512();
720
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4));
721
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3));
722
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2));
723
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1));
724
+ hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) );
725
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1));
726
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2));
727
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3));
728
+ _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
729
+ }
730
+ }
731
+ }
732
+
733
+ template <>
734
+ inline void pack_qs<block_q6_K>(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
735
+ __m512i v[32];
736
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
737
+ // QK_K 256 with 8 groups, handle 4 groups at a time
738
+ char * pb = (char *)packed_B;
739
+ char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
740
+ for (int k = 0; k < QK_K / 128; ++k) {
741
+ for (int n = 0; n < TILE_N; ++n) {
742
+ bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32);
743
+ }
744
+
745
+ // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7
746
+ transpose_16x16_32bit(v);
747
+ transpose_16x16_32bit(v + 16);
748
+
749
+ // 1. pack lower 4bits with 4 groups
750
+ for (int n = 0; n < 32; n += 2) {
751
+ const __m512i r0 = _mm512_and_si512(v[n], lowMask);
752
+ const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
753
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
754
+ }
755
+
756
+ // 2. pack higher 2bit with 4 groups
757
+ const __m512i hmask = _mm512_set1_epi8(0x30);
758
+ for (int g = 0; g < 8; ++g) {
759
+ __m512i hbits = _mm512_setzero_si512();
760
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4));
761
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2));
762
+ hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) );
763
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2));
764
+ _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
765
+ }
766
+ }
767
+ }
768
+
769
+ template <>
770
+ inline void pack_qs<block_iq4_xs>(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
771
+ __m512i v[16];
772
+ char * pb = (char *)packed_B;
773
+ for (int k = 0; k < QK_K / 64; ++k) {
774
+ for (int n = 0; n < TILE_N; ++n) {
775
+ __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0);
776
+ __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16);
777
+ v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1);
778
+ }
779
+
780
+ transpose_16x16_32bit(v);
781
+
782
+ // pack again with 128 to fully utilize vector length
783
+ for (int n = 0; n < TILE_N; n += 2) {
784
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
785
+ pb += 64;
786
+ }
787
+ }
788
+ }
789
+
790
+ // pack B to vnni format, in 4 bits or 8 bits
791
+ void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) {
792
+ pack_qs(packed_B, B, KB);
793
+ lm_ggml_half * d0 = reinterpret_cast<lm_ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
794
+ for (int n = 0; n < TILE_N; ++n) {
795
+ d0[n] = B[n * KB].d;
796
+ }
797
+ }
798
+
799
+ void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) {
800
+ pack_qs(packed_B, B, KB);
801
+ lm_ggml_half * d0 = reinterpret_cast<lm_ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
802
+ lm_ggml_half * m0 = d0 + TILE_N;
803
+ for (int n = 0; n < TILE_N; ++n) {
804
+ d0[n] = B[n * KB].d;
805
+ m0[n] = B[n * KB].m;
806
+ }
807
+ }
808
+
809
+ inline void s8s8_compensation(void * RESTRICT packed_B) {
810
+ // packed_B layout:
811
+ // quants {TILE_N, TILEK} int8_t
812
+ // d0 {TILE_N} lm_ggml_half
813
+ // comp {TILE_N} int32_t
814
+ const int offset = TILE_N * TILE_K + TILE_N * sizeof(lm_ggml_half);
815
+ __m512i vcomp = _mm512_setzero_si512();
816
+ const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
817
+ for (int k = 0; k < 8; ++k) {
818
+ __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64));
819
+ vcomp = _mm512_dpbusd_epi32(vcomp, off, vb);
820
+ }
821
+ _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp);
822
+ }
823
+
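A scalar sketch of the s8s8 compensation idea used above, not part of the released file: `_mm512_dpbusd_epi32` multiplies unsigned bytes by signed bytes, so for a signed-by-signed int8 dot product the A operand is biased into the unsigned range at run time, and the bias contribution, precomputed per column as comp = 128 * sum_k B[k], is subtracted afterwards:

    #include <cstdint>

    // what the biased dpbusd accumulation plus the stored compensation amounts to
    int32_t s8s8_dot_with_comp(const int8_t * a, const int8_t * b, int k, int32_t comp) {
        int32_t acc = 0;
        for (int i = 0; i < k; ++i) {
            acc += (int32_t)(uint8_t)(a[i] + 128) * (int32_t)b[i];  // unsigned x signed
        }
        // comp = 128 * sum(b[0..k-1]), so the result equals sum(a[i] * b[i])
        return acc - comp;
    }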
824
+ void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
825
+ pack_qs(packed_B, B, KB);
826
+ lm_ggml_half * d0 = reinterpret_cast<lm_ggml_half *>((char *)packed_B + TILE_N * TILE_K);
827
+ for (int n = 0; n < TILE_N; ++n) {
828
+ d0[n] = B[n * KB].d;
829
+ }
830
+ s8s8_compensation(packed_B);
831
+ }
832
+
833
+ // convert 8 * {min, scale} from int6 to int8
834
+ inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) {
835
+ const uint32_t kmask1 = 0x3f3f3f3f;
836
+ const uint32_t kmask2 = 0x0f0f0f0f;
837
+ const uint32_t kmask3 = 0x03030303;
838
+
839
+ memcpy(utmp, scales, 12);
840
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
841
+ const uint32_t uaux = utmp[1] & kmask1;
842
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
843
+ utmp[2] = uaux;
844
+ utmp[0] &= kmask1;
845
+ }
846
+
847
+ // packed_B layout:
848
+ // quants {8, TILE_N, 16} uint8
849
+ // scales {8, TILE_N} uint8
850
+ // mins {8, TILE_N} uint8
851
+ // d {TILE_N} lm_ggml_half
852
+ // dmin {TILE_N} lm_ggml_half
853
+ void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
854
+ pack_qs(packed_B, B, KB);
855
+
856
+ uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
857
+ uint8_t * mins = scales + 8 * TILE_N;
858
+ lm_ggml_half * d = reinterpret_cast<lm_ggml_half *>(mins + 8 * TILE_N);
859
+ lm_ggml_half * dmin = d + TILE_N;
860
+
861
+ union {
862
+ uint32_t u32[4];
863
+ uint8_t u8[16];
864
+ } s;
865
+
866
+ for (int n = 0; n < TILE_N; ++n) {
867
+ unpack_mins_and_scales(B[n * KB].scales, s.u32);
868
+ for (int k = 0; k < 8; ++k) {
869
+ scales[k * TILE_N + n] = s.u8[k];
870
+ mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
871
+ }
872
+ d[n] = B[n * KB].d;
873
+ dmin[n] = B[n * KB].dmin;
874
+ }
875
+ }
876
+
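As a cross-check of the layout described above, assuming TILE_N = 16, QK_K = 256, and sizeof(block_q4_K) = 144: the quants take (QK_K / 2) * TILE_N = 2048 bytes, scales and mins 8 * TILE_N = 128 bytes each, and d and dmin 2 * TILE_N = 32 bytes each, 2368 bytes in total, which matches `get_tile_size<block_q4_K>()` = TILE_N * sizeof(block_q4_K) + TILE_N * 4 = 2304 + 64 = 2368.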
877
+ // packed_B layout:
878
+ // quants {8, TILE_N, 16} uint8
879
+ // qh {8, TILE_N, 4} uint8
880
+ // scales {8, TILE_N} uint8
881
+ // mins {8, TILE_N} uint8
882
+ // d {TILE_N} lm_ggml_half
883
+ // dmin {TILE_N} lm_ggml_half
884
+ void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
885
+ pack_qs(packed_B, B, KB);
886
+
887
+ uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
888
+ uint8_t * mins = scales + 8 * TILE_N;
889
+ lm_ggml_half * d = reinterpret_cast<lm_ggml_half *>(mins + 8 * TILE_N);
890
+ lm_ggml_half * dmin = d + TILE_N;
891
+
892
+ union {
893
+ uint32_t u32[4];
894
+ uint8_t u8[16];
895
+ } s;
896
+
897
+ for (int n = 0; n < TILE_N; ++n) {
898
+ unpack_mins_and_scales(B[n * KB].scales, s.u32);
899
+ for (int k = 0; k < 8; ++k) {
900
+ scales[k * TILE_N + n] = s.u8[k];
901
+ mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
902
+ }
903
+ d[n] = B[n * KB].d;
904
+ dmin[n] = B[n * KB].dmin;
905
+ }
906
+ }
907
+
908
+ // packed_B layout:
909
+ // quants {16, TILE_N, 8} uint8
910
+ // qh {16, TILE_N, 4} uint8
911
+ // scales {16, TILE_N} uint8
912
+ // d {TILE_N} lm_ggml_half
913
+ void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
914
+ pack_qs(packed_B, B, KB);
915
+
916
+ uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
917
+ lm_ggml_half * d = reinterpret_cast<lm_ggml_half *>(scales + 16 * TILE_N);
918
+ for (int n = 0; n < TILE_N; ++n) {
919
+ const int8_t * ps = B[n * KB].scales;
920
+ for (int k = 0; k < 16; ++k) {
921
+ scales[k * TILE_N + n] = ps[k];
922
+ }
923
+ d[n] = B[n * KB].d;
924
+ }
925
+ }
926
+
927
+ // packed_B layout:
928
+ // quants {8, TILE_N, 16} uint8
929
+ // scales {8, TILE_N} int8
930
+ // d {TILE_N} lm_ggml_half
931
+ void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
932
+ pack_qs(packed_B, B, KB);
933
+
934
+ int8_t * scales = reinterpret_cast<int8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
935
+ lm_ggml_half * d = reinterpret_cast<lm_ggml_half *>(scales + 8 * TILE_N);
936
+
937
+ // pack the scales
938
+ for (int n = 0; n < TILE_N; ++n) {
939
+ uint16_t sh = B[n * KB].scales_h;
940
+ for (int k = 0; k < 8; k += 2) {
941
+ const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32;
942
+ const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32;
943
+ scales[(k + 0) * TILE_N + n] = ls1;
944
+ scales[(k + 1) * TILE_N + n] = ls2;
945
+ sh >>= 4;
946
+ }
947
+ d[n] = B[n * KB].d;
948
+ }
949
+ }
950
+
951
+ template<typename TB, typename packed_B_t = packed_B_type<TB>>
952
+ void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) {
953
+ LM_GGML_UNUSED(tile);
954
+ LM_GGML_UNUSED(packed_B);
955
+ }
956
+
957
+ template <>
958
+ void unpack_B<block_q4_0>(int8_t * RESTRICT tile, const void * RESTRICT packed_B) {
959
+ const __m512i off = _mm512_set1_epi8(8);
960
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
961
+ for (int n = 0; n < 8; n += 2) {
962
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
963
+ const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off);
964
+ const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off);
965
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
966
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
967
+ }
968
+ }
969
+
970
+ template <>
971
+ void unpack_B<block_q4_1>(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) {
972
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
973
+ for (int n = 0; n < 8; n += 2) {
974
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
975
+ const __m512i r0 = _mm512_and_si512(bytes, lowMask);
976
+ const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
977
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
978
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
979
+ }
980
+ }
981
+
982
+ // packed_B_t for QKK is int8_t
983
+ template <typename TB>
984
+ void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
985
+ const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
986
+ const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size;
987
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
988
+ for (int n = 0; n < 8; n += 2) {
989
+ __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32);
990
+ const __m512i r0 = _mm512_and_si512(bytes, lowMask);
991
+ const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
992
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
993
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
994
+ }
995
+ }
996
+
997
+ template <>
998
+ void unpack_B<block_q5_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
999
+ // lower 4bits, stride 256 bytes
1000
+ const int packed_l4_group_size = QK_K / 2 * TILE_N / 8;
1001
+ const char * pb = (const char *)packed_B + k * packed_l4_group_size;
1002
+
1003
+ // higher 1bit, stride 64 bytes
1004
+ const int packed_h1_group_size = QK_K / 8 * TILE_N / 8;
1005
+ const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size;
1006
+ const __m512i hbits = _mm512_loadu_si512(ph);
1007
+
1008
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1009
+ __m512i hmask0 = _mm512_set1_epi8(0x1);
1010
+ __m512i hmask1 = _mm512_set1_epi8(0x2);
1011
+
1012
+ for (int n = 0; n < 8; n += 2) {
1013
+ __m512i bytes = _mm512_loadu_si512(pb + n * 32);
1014
+ __m512i r0 = _mm512_and_si512(bytes, lowMask);
1015
+ __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1016
+ __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4);
1017
+ __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4);
1018
+
1019
+ hmask0 = _mm512_slli_epi16(hmask0, 2);
1020
+ hmask1 = _mm512_slli_epi16(hmask1, 2);
1021
+ r0 = _mm512_add_epi8(r0, h0);
1022
+ r1 = _mm512_add_epi8(r1, h1);
1023
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
1024
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
1025
+ }
1026
+ }
1027
+
1028
+ template <>
1029
+ void unpack_B<block_q6_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
1030
+ // lower 4bits, stride 128 bytes
1031
+ const int packed_l4_group_size = QK_K / 2 * TILE_N / 16;
1032
+ const char * pb = (const char *)packed_B + k * packed_l4_group_size;
1033
+
1034
+ // higher 2bits, stride 64 bytes
1035
+ const int packed_h2_group_size = QK_K / 4 * TILE_N / 16;
1036
+ const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size;
1037
+ const __m512i hbits = _mm512_loadu_si512(ph);
1038
+
1039
+ const __m512i off = _mm512_set1_epi8(32);
1040
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1041
+ __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011
1042
+ __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100
1043
+
1044
+ // note: skip zero padding from row 4 to row 7, as it was already done in `unpack_A`
1045
+ __m512i bytes = _mm512_loadu_si512(pb);
1046
+ __m512i r0 = _mm512_and_si512(bytes, lowMask);
1047
+ __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1048
+ __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4);
1049
+ __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2);
1050
+ _mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
1051
+ _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
1052
+
1053
+ hmask0 = _mm512_slli_epi16(hmask0, 4);
1054
+ hmask1 = _mm512_slli_epi16(hmask1, 4);
1055
+
1056
+ bytes = _mm512_loadu_si512(pb + 64);
1057
+ r0 = _mm512_and_si512(bytes, lowMask);
1058
+ r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1059
+ h0 = _mm512_and_si512(hbits, hmask0);
1060
+ h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2);
1061
+ _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
1062
+ _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
1063
+ }
1064
+
1065
+ template <>
1066
+ void unpack_B<block_iq4_xs>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
1067
+ static const __m512i values128 = _mm512_set_epi8(
1068
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1069
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1070
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1071
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
1072
+ );
1073
+
1074
+ const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
1075
+ const char * pb = (const char *)packed_B + k * packed_B_group_size;
1076
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1077
+
1078
+ for (int n = 0; n < 8; n += 2) {
1079
+ __m512i bytes = _mm512_loadu_si512(pb + n * 32);
1080
+ const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask));
1081
+ const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
1082
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
1083
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
1084
+ }
1085
+ }
1086
+
1087
+ template <typename TA, typename TB, bool is_acc>
1088
+ struct acc_C {};
1089
+
1090
+ template <bool is_acc>
1091
+ struct acc_C<block_q8_0, block_q4_0, is_acc> {
1092
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
1093
+ const int offset = TILE_N * TILE_K / 2;
1094
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
1095
+
1096
+ for (int m = 0; m < nr; ++m) {
1097
+ const __m512 vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[m * lda].d));
1098
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1099
+
1100
+ __m512 vsum;
1101
+ if (is_acc) {
1102
+ vsum = _mm512_loadu_ps(C + m * ldc);
1103
+ } else {
1104
+ vsum = _mm512_set1_ps(0.f);
1105
+ }
1106
+ vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
1107
+ _mm512_storeu_ps(C + m * ldc, vsum);
1108
+ }
1109
+ }
1110
+ };
1111
+
1112
+ template <bool is_acc>
1113
+ struct acc_C<block_q8_1, block_q4_1, is_acc> {
1114
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) {
1115
+ const int offset = TILE_N * TILE_K / 2;
1116
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
1117
+ const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(lm_ggml_half))));
1118
+
1119
+ for (int m = 0; m < nr; ++m) {
1120
+ const __m512 vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[m * lda].d));
1121
+ const __m512 vs1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[m * lda].s));
1122
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1123
+
1124
+ __m512 vsum;
1125
+ if (is_acc) {
1126
+ vsum = _mm512_loadu_ps(C + m * ldc);
1127
+ } else {
1128
+ vsum = _mm512_set1_ps(0.f);
1129
+ }
1130
+ vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
1131
+ vsum = _mm512_fmadd_ps(vm0, vs1, vsum);
1132
+ _mm512_storeu_ps(C + m * ldc, vsum);
1133
+ }
1134
+ }
1135
+ };
1136
+
1137
+ template <bool is_acc>
1138
+ struct acc_C<block_q8_0, block_q8_0, is_acc> {
1139
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
1140
+ const int offset = TILE_N * TILE_K;
1141
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
1142
+
1143
+ for (int m = 0; m < nr; ++m) {
1144
+ const __m512 vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[m * lda].d));
1145
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1146
+
1147
+ __m512 vsum;
1148
+ if (is_acc) {
1149
+ vsum = _mm512_loadu_ps(C + m * ldc);
1150
+ } else {
1151
+ vsum = _mm512_set1_ps(0.f);
1152
+ }
1153
+ vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
1154
+ _mm512_storeu_ps(C + m * ldc, vsum);
1155
+ }
1156
+ }
1157
+ };
1158
+
1159
+ template <bool is_acc>
1160
+ struct acc_C<block_q8_K, block_q4_K, is_acc> {
1161
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1162
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
1163
+ const uint8_t * mins = scales + 8 * TILE_N;
1164
+ const lm_ggml_half * d0 = reinterpret_cast<const lm_ggml_half *>(mins + 8 * TILE_N);
1165
+ const lm_ggml_half * dmin = d0 + TILE_N;
1166
+
1167
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1168
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
1169
+
1170
+ for (int m = 0; m < nr; ++m) {
1171
+ const float d1 = A[m * lda].d;
1172
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1173
+ const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
1174
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1175
+
1176
+ __m512 vsum;
1177
+ if (is_acc) {
1178
+ vsum = _mm512_loadu_ps(C + m * ldc);
1179
+ } else {
1180
+ vsum = _mm512_set1_ps(0.f);
1181
+ }
1182
+
1183
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
1184
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1185
+
1186
+ __m512i acc_m = _mm512_setzero_si512();
1187
+ for (int k = 0; k < 4; ++k) {
1188
+ __m512i vmask = _mm512_set1_epi32(k);
1189
+ __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
1190
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
1191
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1192
+ }
1193
+
1194
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1195
+ vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
1196
+ _mm512_storeu_ps(C + m * ldc, vsum);
1197
+ }
1198
+ }
1199
+ };
1200
+
1201
+ template <bool is_acc>
1202
+ struct acc_C<block_q8_K, block_q5_K, is_acc> {
1203
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1204
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
1205
+ const uint8_t * mins = scales + 8 * TILE_N;
1206
+ const lm_ggml_half * d0 = reinterpret_cast<const lm_ggml_half *>(mins + 8 * TILE_N);
1207
+ const lm_ggml_half * dmin = d0 + TILE_N;
1208
+
1209
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1210
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
1211
+
1212
+ for (int m = 0; m < nr; ++m) {
1213
+ const float d1 = A[m * lda].d;
1214
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1215
+ const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
1216
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1217
+
1218
+ __m512 vsum;
1219
+ if (is_acc) {
1220
+ vsum = _mm512_loadu_ps(C + m * ldc);
1221
+ } else {
1222
+ vsum = _mm512_set1_ps(0.f);
1223
+ }
1224
+
1225
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
1226
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1227
+
1228
+ __m512i acc_m = _mm512_setzero_si512();
1229
+ for (int k = 0; k < 4; ++k) {
1230
+ __m512i vmask = _mm512_set1_epi32(k);
1231
+ __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
1232
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
1233
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1234
+ }
1235
+
1236
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1237
+ vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
1238
+ _mm512_storeu_ps(C + m * ldc, vsum);
1239
+ }
1240
+ }
1241
+ };
1242
+
1243
+ template <bool is_acc>
1244
+ struct acc_C<block_q8_K, block_q6_K, is_acc> {
1245
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1246
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
1247
+ const lm_ggml_half * d0 = reinterpret_cast<const lm_ggml_half *>(scales + 16 * TILE_N);
1248
+
1249
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1250
+
1251
+ for (int m = 0; m < nr; ++m) {
1252
+ const float d1 = A[m * lda].d;
1253
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1254
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1255
+
1256
+ __m512 vsum;
1257
+ if (is_acc) {
1258
+ vsum = _mm512_loadu_ps(C + m * ldc);
1259
+ } else {
1260
+ vsum = _mm512_set1_ps(0.f);
1261
+ }
1262
+
1263
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1264
+ _mm512_storeu_ps(C + m * ldc, vsum);
1265
+ }
1266
+ }
1267
+ };
1268
+
1269
+ template <bool is_acc>
1270
+ struct acc_C<block_q8_K, block_iq4_xs, is_acc> {
1271
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1272
+ const int8_t * scales = reinterpret_cast<const int8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
1273
+ const lm_ggml_half * d0 = reinterpret_cast<const lm_ggml_half *>(scales + 8 * TILE_N);
1274
+
1275
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1276
+
1277
+ for (int m = 0; m < nr; ++m) {
1278
+ const float d1 = A[m * lda].d;
1279
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1280
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1281
+
1282
+ __m512 vsum;
1283
+ if (is_acc) {
1284
+ vsum = _mm512_loadu_ps(C + m * ldc);
1285
+ } else {
1286
+ vsum = _mm512_set1_ps(0.f);
1287
+ }
1288
+
1289
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1290
+ _mm512_storeu_ps(C + m * ldc, vsum);
1291
+ }
1292
+ }
1293
+ };
1294
+
1295
+ template <typename TB> constexpr int get_quants_size();
1296
+ template <> constexpr int get_quants_size<block_q4_K>() { return (QK_K / 2) * TILE_N; }
1297
+ template <> constexpr int get_quants_size<block_q5_K>() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; }
1298
+ template <> constexpr int get_quants_size<block_q6_K>() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; }
1299
+ template <> constexpr int get_quants_size<block_iq4_xs>() { return (QK_K / 2) * TILE_N; }
1300
+
1301
+ // used for QKK format
1302
+ template <typename TB, bool is_acc,
1303
+ typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
1304
+ inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) {
1305
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + get_quants_size<TB>());
1306
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N)));
1307
+
1308
+ for (int m = 0; m < nr; ++m) {
1309
+ __m512i vsumi;
1310
+ if (is_acc) {
1311
+ vsumi = _mm512_loadu_si512(sumi + m * TILE_N);
1312
+ } else {
1313
+ vsumi = _mm512_setzero_si512();
1314
+ }
1315
+ __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N);
1316
+ vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale));
1317
+ _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi);
1318
+ }
1319
+ }
1320
+
1321
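scale_C above multiplies a freshly computed int32 tile by the per-column group scales and accumulates it into the running sums. A scalar sketch of the same update (illustrative only; TILE_N fixed at 16 as in this file):

#include <cstdint>

// sumi[m][n] (+)= tile[m][n] * scale[n] for one k-group of a QK_K block;
// is_acc selects between overwriting and accumulating, mirroring the template flag.
template <bool is_acc>
void scale_C_ref(const int32_t * tile, int32_t * sumi, const int8_t * scale, int nr) {
    constexpr int tile_n = 16;
    for (int m = 0; m < nr; ++m) {
        for (int n = 0; n < tile_n; ++n) {
            const int32_t v = tile[m * tile_n + n] * scale[n];
            sumi[m * tile_n + n] = is_acc ? sumi[m * tile_n + n] + v : v;
        }
    }
}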
+ template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
1322
+ struct tinygemm_kernel_avx {
1323
+ static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) {
1324
+ LM_GGML_UNUSED(K);
1325
+ LM_GGML_UNUSED(A);
1326
+ LM_GGML_UNUSED(B);
1327
+ LM_GGML_UNUSED(C);
1328
+ LM_GGML_UNUSED(ldc);
1329
+ }
1330
+ };
1331
+
1332
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1333
+ struct tinygemm_kernel_avx<float, lm_ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1334
+ static void apply(int K, const float * RESTRICT A, const lm_ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) {
1335
+ constexpr int ROWS = BLOCK_M;
1336
+ constexpr int COLS = BLOCK_N;
1337
+ assert(BLOCK_K == 16);
1338
+
1339
+ __m512 va;
1340
+ __m512 vb[COLS];
1341
+ __m512 vc[ROWS * COLS];
1342
+
1343
+ auto loadc = [&](int idx) {
1344
+ vc[idx] = _mm512_setzero_ps();
1345
+ };
1346
+ Unroll<ROWS * COLS>{}(loadc);
1347
+
1348
+ auto compute = [&](int idx, int k) {
1349
+ // TODO: use `constexpr` here to get rid of integer div
1350
+ // when upgraded to C++17
1351
+ const int row = idx / COLS;
1352
+ const int col = idx % COLS;
1353
+
1354
+ if (col == 0) {
1355
+ va = _mm512_loadu_ps(A + row * K + k);
1356
+ }
1357
+ if (row == 0) {
1358
+ vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
1359
+ }
1360
+ vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
1361
+ };
1362
+
1363
+ for (int k = 0; k < K; k += 16) {
1364
+ Unroll<ROWS * COLS>{}(compute, k);
1365
+ }
1366
+
1367
+ auto storec = [&](int idx) {
1368
+ const int row = idx / COLS;
1369
+ const int col = idx % COLS;
1370
+ C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);
1371
+ };
1372
+ Unroll<ROWS * COLS>{}(storec);
1373
+ }
1374
+ };
1375
+
1376
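The specialization above keeps one accumulator register per (row, col) pair, broadcasting rows of A and converting columns of B from fp16, then horizontally reduces each accumulator into C. A plain scalar reference of the same BLOCK_M x BLOCK_N micro-kernel, written with float B for brevity (illustrative only):

// C[row * ldc + col] = dot(A[row, :], B[col, :]) over K elements,
// with B stored per output column exactly as the kernel reads it.
template <int ROWS, int COLS>
void tinygemm_ref(int K, const float * A, const float * B, float * C, int ldc) {
    for (int row = 0; row < ROWS; ++row) {
        for (int col = 0; col < COLS; ++col) {
            float sum = 0.f;
            for (int k = 0; k < K; ++k) {
                sum += A[row * K + k] * B[col * K + k];
            }
            C[row * ldc + col] = sum;
        }
    }
}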
+ #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
1377
+ tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \
1378
+ K, (const float *)src1->data + mb_start * K, \
1379
+ (const type *)src0->data + nb_start * K, \
1380
+ (float *)dst->data + mb_start * ldc + nb_start, ldc);
1381
+
1382
+
1383
+ // re-organize in the format {NB, KB, TILE_SIZE}:
1384
+ #define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size
1385
+
1386
+ template<typename TB, int BLOCK_K>
1387
+ void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K, int n_threads) {
1388
+ const int NB = N / TILE_N;
1389
+ const int KB = K / BLOCK_K;
1390
+ const int TILE_SIZE = get_tile_size<TB>();
1391
+
1392
+ // parallel on NB should be enough
1393
+ parallel_for(n_threads, NB, [&](int begin, int end) {
1394
+ for (int n = begin; n < end; ++n) {
1395
+ for (int k = 0; k < KB; ++k) {
1396
+ int n0 = n * TILE_N;
1397
+ pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB);
1398
+ }
1399
+ }
1400
+ });
1401
+ }
1402
+
1403
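PACKED_INDEX linearizes a (n-block, k-block) pair into a byte offset of the repacked weight, with all k blocks of one n block stored back to back. A small compile-time sketch of the layout, using arbitrary illustrative values for KB and the tile size:

#include <cstddef>

// Offset of tile (n, k) in a buffer laid out as {NB, KB, TILE_SIZE} bytes.
constexpr size_t packed_index(int n, int k, int KB, size_t tile_size) {
    return (static_cast<size_t>(n) * KB + k) * tile_size;
}

static_assert(packed_index(0, 0, 4, 1088) == 0,    "first tile starts at offset 0");
static_assert(packed_index(0, 3, 4, 1088) == 3264, "k blocks of one n block are contiguous");
static_assert(packed_index(1, 0, 4, 1088) == 4352, "the next n block follows all k blocks");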
+ template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
1404
+ struct tinygemm_kernel_vnni {};
1405
+
1406
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1407
+ struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1408
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1409
+
1410
+ constexpr int COLS = BLOCK_N / 16;
1411
+ const int TILE_SIZE = TILE_N * sizeof(block_q4_0);
1412
+
1413
+ const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
1414
+ const char * RESTRICT B = static_cast<const char *>(_B);
1415
+
1416
+ __m512i va[8];
1417
+ __m512 vc[COLS];
1418
+ __m512 vd1;
1419
+
1420
+ // sum of offsets, shared across COLS
1421
+ //
1422
+ // avx512-vnni does not have `_mm512_dpbssd_epi32`,
1423
+ // need to transform ss to us:
1424
+ // a * (b - 8) is equivalent to b * a - 8 * a
1425
+ // s u u u s u s
1426
+ //
1427
+ __m512i vcomp;
1428
+
1429
+ const __m512i off = _mm512_set1_epi8(8);
1430
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1431
+
1432
+ auto loadc = [&](int col) {
1433
+ vc[col] = _mm512_setzero_ps();
1434
+ };
1435
+ Unroll<COLS>{}(loadc);
1436
+
1437
+ auto compute = [&](int col, int i) {
1438
+ // load a and compute compensation
1439
+ if (col == 0) {
1440
+ const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
1441
+ vcomp = _mm512_setzero_si512();
1442
+ for (int k = 0; k < 8; ++k) {
1443
+ va[k] = _mm512_set1_epi32(a_ptr[k]);
1444
+ vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]);
1445
+ }
1446
+ vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[0 * KB + i].d));
1447
+ }
1448
+
1449
+ // load b
1450
+ __m512i vsum = _mm512_setzero_si512();
1451
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1452
+ for (int k = 0; k < 8; k += 2) {
1453
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
1454
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1455
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]);
1456
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1457
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]);
1458
+ }
1459
+ const int offset = TILE_N * TILE_K / 2;
1460
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
1461
+ vsum = _mm512_sub_epi32(vsum, vcomp);
1462
+
1463
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
1464
+ };
1465
+
1466
+ for (int i = 0; i < KB; ++i) {
1467
+ Unroll<COLS>{}(compute, i);
1468
+ }
1469
+
1470
+ //store to C
1471
+ auto storec = [&](int col) {
1472
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1473
+ };
1474
+ Unroll<COLS>{}(storec);
1475
+ }
1476
+ };
1477
+
1478
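The Q4_0 kernel above leans on the identity a * (b - 8) = a * b - 8 * a, so the unsigned nibbles can go straight into `_mm512_dpbusd_epi32` (unsigned x signed) and the implicit -8 offset is removed afterwards through the precomputed sum of the activations. A scalar sketch of that compensation (illustrative names only):

#include <cstdint>

// q4 holds unsigned nibbles in [0, 15]; the actual weight is (q4 - 8).
// Accumulate with the unsigned nibble, then subtract 8 * sum(a).
int32_t dot_q4_0_ref(const int8_t * a, const uint8_t * q4, int n) {
    int32_t dot  = 0;
    int32_t comp = 0;
    for (int i = 0; i < n; ++i) {
        dot  += a[i] * q4[i];   // unsigned x signed product, as dpbusd computes it
        comp += a[i];           // running activation sum, vcomp in the kernel
    }
    return dot - 8 * comp;      // equals sum(a[i] * (q4[i] - 8))
}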
+ template <int BLOCK_N, int BLOCK_K>
1479
+ struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K> {
1480
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1481
+
1482
+ constexpr int COLS = BLOCK_N / 16;
1483
+ const int TILE_SIZE = TILE_N * sizeof(block_q4_1);
1484
+
1485
+ const block_q8_1 * RESTRICT A = static_cast<const block_q8_1 *>(_A);
1486
+ const char * RESTRICT B = static_cast<const char *>(_B);
1487
+
1488
+ __m512i va[8];
1489
+ __m512i vb[8];
1490
+ __m512 vc[COLS];
1491
+ __m512 vd1, vs1;
1492
+
1493
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1494
+
1495
+ auto loadc = [&](int col) {
1496
+ vc[col] = _mm512_setzero_ps();
1497
+ };
1498
+ Unroll<COLS>{}(loadc);
1499
+
1500
+ auto compute = [&](int col, int i) {
1501
+ // load a
1502
+ if (col == 0) {
1503
+ const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
1504
+ for (int k = 0; k < 8; ++k) {
1505
+ va[k] = _mm512_set1_epi32(a_ptr[k]);
1506
+ }
1507
+ vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[0 * KB + i].d));
1508
+ vs1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[0 * KB + i].s));
1509
+ }
1510
+
1511
+ // load b
1512
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1513
+ for (int k = 0; k < 8; k += 2) {
1514
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
1515
+ vb[k + 0] = _mm512_and_si512(bytes, lowMask);
1516
+ vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1517
+ }
1518
+ const int offset = TILE_N * TILE_K / 2;
1519
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
1520
+ const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(lm_ggml_half))));
1521
+
1522
+ __m512i vsum = _mm512_setzero_si512();
1523
+ for (int k = 0; k < 8; ++k) {
1524
+ vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]);
1525
+ }
1526
+
1527
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
1528
+ vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]);
1529
+ };
1530
+
1531
+ for (int i = 0; i < KB; ++i) {
1532
+ Unroll<COLS>{}(compute, i);
1533
+ }
1534
+
1535
+ //store to C
1536
+ auto storec = [&](int col) {
1537
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1538
+ };
1539
+ Unroll<COLS>{}(storec);
1540
+ }
1541
+ };
1542
+
1543
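The Q4_1 path additionally carries the block minimum m0, which pairs with the activation sum: assuming, as in ggml's Q8_1 layout, that s1 stores d1 times the sum of the activation quants, the block contribution is d0*d1*dot + m0*s1, matching the two fmadds above. A scalar sketch (illustrative only):

#include <cstdint>

// Q4_1 dequantizes as d0 * q4 + m0; Q8_1 carries d1 and s1 = d1 * sum(q8).
float dot_q4_1_ref(float d0, float m0, const uint8_t * q4,
                   float d1, float s1, const int8_t * q8, int n) {
    int32_t isum = 0;
    for (int i = 0; i < n; ++i) {
        isum += q4[i] * q8[i];
    }
    return d0 * d1 * static_cast<float>(isum) + m0 * s1;
}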
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1544
+ struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1545
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1546
+
1547
+ constexpr int COLS = BLOCK_N / 16;
1548
+ const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t);
1549
+
1550
+ const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
1551
+ const char * RESTRICT B = static_cast<const char *>(_B);
1552
+
1553
+ __m512i va[8];
1554
+ __m512i vb[8];
1555
+ __m512 vc[COLS];
1556
+ __m512 vd1;
1557
+
1558
+ // Notes: s8s8 igemm compensation in avx512-vnni
1559
+ // change s8s8 to u8s8 with compensate
1560
+ // a * b = (a + 128) * b - 128 * b
1561
+ // s s u s u s
1562
+ //
1563
+ // (128 * b is pre-computed when packing B to vnni formats)
1564
+ //
1565
+ const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
1566
+
1567
+ auto loadc = [&](int col) {
1568
+ vc[col] = _mm512_setzero_ps();
1569
+ };
1570
+ Unroll<COLS>{}(loadc);
1571
+
1572
+ auto compute = [&](int col, int i) {
1573
+ // load a and add offset 128
1574
+ if (col == 0) {
1575
+ const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
1576
+ for (int k = 0; k < 8; ++k) {
1577
+ va[k] = _mm512_set1_epi32(a_ptr[k]);
1578
+ va[k] = _mm512_add_epi8(va[k], off);
1579
+ }
1580
+ vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[0 * KB + i].d));
1581
+ }
1582
+
1583
+ // load b
1584
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1585
+ for (int k = 0; k < 8; ++k) {
1586
+ vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64));
1587
+ }
1588
+ const int offset = TILE_N * TILE_K;
1589
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
1590
+ const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(lm_ggml_half);
1591
+ const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2));
1592
+
1593
+ __m512i vsum = _mm512_setzero_si512();
1594
+ for (int k = 0; k < 8; ++k) {
1595
+ vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]);
1596
+ }
1597
+ vsum = _mm512_sub_epi32(vsum, vcomp);
1598
+
1599
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
1600
+ };
1601
+
1602
+ for (int i = 0; i < KB; ++i) {
1603
+ Unroll<COLS>{}(compute, i);
1604
+ }
1605
+
1606
+ //store to C
1607
+ auto storec = [&](int col) {
1608
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1609
+ };
1610
+ Unroll<COLS>{}(storec);
1611
+ }
1612
+ };
1613
+
1614
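For the s8s8 case the kernel shifts the activation into unsigned range and subtracts a compensation that the packing step already computed per column (128 times the column sums of B). A scalar sketch of the identity a * b = (a + 128) * b - 128 * b (illustrative names only):

#include <cstdint>

// a and b are both signed int8; dpbusd needs an unsigned first operand.
// comp is the precomputed 128 * sum(b), stored alongside the packed B tile.
int32_t dot_s8s8_ref(const int8_t * a, const int8_t * b, int n, int32_t comp) {
    int32_t dot = 0;
    for (int i = 0; i < n; ++i) {
        dot += (static_cast<int32_t>(a[i]) + 128) * b[i];  // u8 x s8 form
    }
    return dot - comp;
}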
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1615
+ struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1616
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1617
+
1618
+ constexpr int COLS = BLOCK_N / 16;
1619
+ const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4;
1620
+
1621
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1622
+ const char * RESTRICT B = static_cast<const char *>(_B);
1623
+
1624
+ // a.qs: 8 groups, 32 bytes each group (m256i)
1625
+ __m512i va[8];
1626
+ // a.bsum: 8 groups, 2 bytes each group (m128i)
1627
+ __m512i va_bsum;
1628
+ __m512 vc[COLS];
1629
+ __m512 vd1;
1630
+
1631
+ // packed_B:
1632
+ const int offset_scales = (QK_K / 2) * TILE_N;
1633
+ const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N;
1634
+ const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N;
1635
+ const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(lm_ggml_half);
1636
+
1637
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1638
+
1639
+ auto loadc = [&](int col) {
1640
+ vc[col] = _mm512_setzero_ps();
1641
+ };
1642
+ Unroll<COLS>{}(loadc);
1643
+
1644
+ // Notes: vnni formats in QK_K
1645
+ // a) quants vnni format
1646
+ // int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32
1647
+ // from {16, 32} to {8, 64}
1648
+ //
1649
+ // b) min vnni format
1650
+ // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8
1651
+ // from {16, 8} to {4, 32}
1652
+ //
1653
+ auto compute = [&](int col, int i) {
1654
+ // load a
1655
+ if (col == 0) {
1656
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1657
+ va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
1658
+ }
1659
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1660
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1661
+ va_bsum = _mm512_castsi128_si512(q8s);
1662
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1663
+ }
1664
+
1665
+ // step 1: accumulate the quants
1666
+ __m512i acc = _mm512_setzero_si512();
1667
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1668
+ const char * b_qs = b_ptr;
1669
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1670
+ __m512i vsum = _mm512_setzero_si512();
1671
+ for (int k = 0; k < 8; k += 2) {
1672
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
1673
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
1674
+
1675
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
1676
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1677
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1678
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1679
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1680
+
1681
+ b_qs += 64;
1682
+ }
1683
+ // vacc += scale * (q8 @ q4)
1684
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
1685
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
1686
+ }
1687
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
1688
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
1689
+
1690
+ // step 2: accumulate the mins
1691
+ __m512i acc_m = _mm512_setzero_si512();
1692
+ for (int k = 0; k < 4; ++k) {
1693
+ __m512i vmask = _mm512_set1_epi32(k);
1694
+ __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
1695
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
1696
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1697
+ }
1698
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
1699
+ vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
1700
+ };
1701
+
1702
+ for (int i = 0; i < KB; ++i) {
1703
+ Unroll<COLS>{}(compute, i);
1704
+ }
1705
+
1706
+ //store to C
1707
+ auto storec = [&](int col) {
1708
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1709
+ };
1710
+ Unroll<COLS>{}(storec);
1711
+ }
1712
+ };
1713
+
1714
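The layout notes above describe the VNNI regrouping used for the K-quant tiles: a tile that stores the K quants of each of the N output columns contiguously ({N, K}, here {16, 32}) is regrouped as {K/4, N, 4} (here {8, 64} bytes per row) so that each dpbusd lane sees four consecutive k values of one column. A scalar sketch of that index mapping (illustrative only, not the package's pack_B):

#include <cstdint>
#include <vector>

// packed[(k/4) * N * 4 + n * 4 + (k % 4)] = src[n * K + k]
std::vector<int8_t> to_vnni(const std::vector<int8_t> & src, int N, int K) {
    std::vector<int8_t> packed(src.size());
    for (int n = 0; n < N; ++n) {
        for (int k = 0; k < K; ++k) {
            packed[(k / 4) * N * 4 + n * 4 + (k % 4)] = src[n * K + k];
        }
    }
    return packed;
}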
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1715
+ struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1716
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1717
+
1718
+ constexpr int COLS = BLOCK_N / 16;
1719
+ const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4;
1720
+
1721
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1722
+ const char * RESTRICT B = static_cast<const char *>(_B);
1723
+
1724
+ // a.qs: 8 groups, 32 bytes each group (m256i)
1725
+ __m512i va[8];
1726
+ // a.bsum: 8 groups, 2 bytes each group (m128i)
1727
+ __m512i va_bsum;
1728
+ __m512 vc[COLS];
1729
+ __m512 vd1;
1730
+
1731
+ // packed_B:
1732
+ const int offset_qh = (QK_K / 2) * TILE_N;
1733
+ const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;
1734
+ const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N;
1735
+ const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N;
1736
+ const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(lm_ggml_half);
1737
+
1738
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1739
+
1740
+ auto loadc = [&](int col) {
1741
+ vc[col] = _mm512_setzero_ps();
1742
+ };
1743
+ Unroll<COLS>{}(loadc);
1744
+
1745
+ // Q5_K and Q4_K share the same vnni formats, refer to the notes above.
1746
+ auto compute = [&](int col, int i) {
1747
+ // load a
1748
+ if (col == 0) {
1749
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1750
+ va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
1751
+ }
1752
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1753
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1754
+ va_bsum = _mm512_castsi128_si512(q8s);
1755
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1756
+ }
1757
+
1758
+ // step 1: accumulate the quants
1759
+ __m512i acc = _mm512_setzero_si512();
1760
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1761
+ const char * b_qs = b_ptr;
1762
+ const char * b_qh = b_ptr + offset_qh;
1763
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1764
+ __m512i vsum = _mm512_setzero_si512();
1765
+ __m512i hmask0 = _mm512_set1_epi8(0x1);
1766
+ __m512i hmask1 = _mm512_set1_epi8(0x2);
1767
+ __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64));
1768
+ for (int k = 0; k < 8; k += 2) {
1769
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
1770
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
1771
+
1772
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
1773
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1774
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1775
+
1776
+ __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4);
1777
+ __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4);
1778
+
1779
+ hmask0 = _mm512_slli_epi16(hmask0, 2);
1780
+ hmask1 = _mm512_slli_epi16(hmask1, 2);
1781
+ vb0 = _mm512_add_epi8(vb0, vh0);
1782
+ vb1 = _mm512_add_epi8(vb1, vh1);
1783
+
1784
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1785
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1786
+
1787
+ b_qs += 64;
1788
+ }
1789
+ // vacc += scale * (q8 @ q5)
1790
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
1791
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
1792
+ }
1793
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
1794
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
1795
+
1796
+ // step 2: accumulate the mins
1797
+ __m512i acc_m = _mm512_setzero_si512();
1798
+ for (int k = 0; k < 4; ++k) {
1799
+ __m512i vmask = _mm512_set1_epi32(k);
1800
+ __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
1801
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
1802
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1803
+ }
1804
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
1805
+ vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
1806
+ };
1807
+
1808
+ for (int i = 0; i < KB; ++i) {
1809
+ Unroll<COLS>{}(compute, i);
1810
+ }
1811
+
1812
+ //store to C
1813
+ auto storec = [&](int col) {
1814
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1815
+ };
1816
+ Unroll<COLS>{}(storec);
1817
+ }
1818
+ };
1819
+
1820
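In the Q5_K kernel the low nibbles and the separately stored high bits are recombined on the fly: the fifth bit is shifted to position 4 and added to the nibble before the dot product, which is what the vh0/vh1 terms above do. A scalar sketch of that reconstruction (illustrative names):

#include <cstdint>

// Rebuild one 5-bit quant from its low nibble and its packed high bit;
// qh_bit is the bit position of this value inside the high-bit byte.
inline uint8_t q5_from_parts(uint8_t low_nibble, uint8_t qh_byte, int qh_bit) {
    return static_cast<uint8_t>((low_nibble & 0x0F) | (((qh_byte >> qh_bit) & 1) << 4));
}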
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1821
+ struct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1822
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1823
+
1824
+ constexpr int COLS = BLOCK_N / 16;
1825
+ const int TILE_SIZE = TILE_N * sizeof(block_q6_K);
1826
+
1827
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1828
+ const char * RESTRICT B = static_cast<const char *>(_B);
1829
+
1830
+ // load the 256 bytes from A to 4 avx512 vectors
1831
+ __m512i va[4];
1832
+ __m512 vc[COLS];
1833
+ __m512 vd1;
1834
+
1835
+ // packed_B:
1836
+ const int offset_qh = (QK_K / 2) * TILE_N;
1837
+ const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;
1838
+ const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N;
1839
+
1840
+ // compensation
1841
+ __m512i vcomp;
1842
+
1843
+ const __m512i m32s = _mm512_set1_epi32(32);
1844
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1845
+
1846
+ auto loadc = [&](int col) {
1847
+ vc[col] = _mm512_setzero_ps();
1848
+ };
1849
+ Unroll<COLS>{}(loadc);
1850
+
1851
+ auto compute = [&](int col, int i) {
1852
+ if (col == 0) {
1853
+ // load a
1854
+ va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
1855
+ va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
1856
+ va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
1857
+ va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
1858
+
1859
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1860
+ vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s);
1861
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1862
+ }
1863
+
1864
+ // accumulate the quants
1865
+ __m512i acc = _mm512_setzero_si512();
1866
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1867
+ const char * b_qs = b_ptr;
1868
+ const char * b_qh = b_ptr + offset_qh;
1869
+ int mask = 0;
1870
+ for (int k_group = 0; k_group < QK_K / 16; ++k_group) {
1871
+ int r = k_group >> 2;
1872
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1873
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1874
+
1875
+ __m512i vsum = _mm512_setzero_si512();
1876
+ __m512i hmask = _mm512_set1_epi8(0x3);
1877
+
1878
+ __m512i bytes = _mm512_loadu_si512(b_qs);
1879
+ __m512i hbits = _mm512_loadu_si512(b_qh);
1880
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1881
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1882
+ __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4);
1883
+ __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2);
1884
+
1885
+ vb0 = _mm512_add_epi8(vb0, vh0);
1886
+ vb1 = _mm512_add_epi8(vb1, vh1);
1887
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1888
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1889
+ b_qs += 64;
1890
+
1891
+ va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1892
+ va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1893
+
1894
+ bytes = _mm512_loadu_si512(b_qs);
1895
+ vb0 = _mm512_and_si512(bytes, lowMask);
1896
+ vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1897
+ vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4));
1898
+ vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2);
1899
+ vb0 = _mm512_add_epi8(vb0, vh0);
1900
+ vb1 = _mm512_add_epi8(vb1, vh1);
1901
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1902
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1903
+ b_qs += 64;
1904
+ b_qh += 64;
1905
+
1906
+ // B * A - 32 * A
1907
+ __m512i vmask = _mm512_set1_epi32(k_group);
1908
+ vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
1909
+
1910
+ // vacc += scale * (q8 @ q6)
1911
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
1912
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
1913
+ }
1914
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
1915
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
1916
+ };
1917
+
1918
+ for (int i = 0; i < KB; ++i) {
1919
+ Unroll<COLS>{}(compute, i);
1920
+ }
1921
+
1922
+ //store to C
1923
+ auto storec = [&](int col) {
1924
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1925
+ };
1926
+ Unroll<COLS>{}(storec);
1927
+ }
1928
+ };
1929
+
1930
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1931
+ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1932
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1933
+
1934
+ constexpr int COLS = BLOCK_N / 16;
1935
+ const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2;
1936
+
1937
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1938
+ const char * RESTRICT B = static_cast<const char *>(_B);
1939
+
1940
+ // load the 256 bytes from A to 4 avx512 vectors
1941
+ __m512i va[4];
1942
+ __m512 vc[COLS];
1943
+ __m512 vd1;
1944
+
1945
+ // packed_B:
1946
+ const int offset_scales = (QK_K / 2) * TILE_N ;
1947
+ const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N;
1948
+
1949
+ // compensation
1950
+ __m512i vcomp;
1951
+
1952
+ const __m256i m128s = _mm256_set1_epi16(128);
1953
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1954
+
1955
+ const __m512i values128 = _mm512_set_epi8(
1956
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1957
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1958
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1959
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
1960
+ );
1961
+ const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
1962
+ const __m512i values256 = _mm512_add_epi8(values128, off);
1963
+
1964
+ auto loadc = [&](int col) {
1965
+ vc[col] = _mm512_setzero_ps();
1966
+ };
1967
+ Unroll<COLS>{}(loadc);
1968
+
1969
+ auto compute = [&](int col, int i) {
1970
+ if (col == 0) {
1971
+ // load a
1972
+ va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
1973
+ va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
1974
+ va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
1975
+ va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
1976
+
1977
+ // compensation: 128 * A
1978
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1979
+ vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s));
1980
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1981
+ }
1982
+
1983
+ // accumulate the quants
1984
+ __m512i acc = _mm512_setzero_si512();
1985
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1986
+ const char * b_qs = b_ptr;
1987
+ int mask = 0;
1988
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1989
+ int r = k_group >> 1;
1990
+ __m512i vmask = _mm512_set1_epi32(k_group);
1991
+ __m512i vsum = _mm512_setzero_si512();
1992
+ for (int k = 0; k < 8; k += 2) {
1993
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1994
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1995
+
1996
+ __m512i bytes = _mm512_loadu_si512(b_qs);
1997
+ __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask));
1998
+ __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
1999
+
2000
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
2001
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
2002
+ b_qs += 64;
2003
+ }
2004
+ // (B + 128) * A - 128 * A
2005
+ vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
2006
+
2007
+ // vacc += scale * (q8 @ q4)
2008
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
2009
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
2010
+ }
2011
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
2012
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
2013
+ };
2014
+
2015
+ for (int i = 0; i < KB; ++i) {
2016
+ Unroll<COLS>{}(compute, i);
2017
+ }
2018
+
2019
+ //store to C
2020
+ auto storec = [&](int col) {
2021
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
2022
+ };
2023
+ Unroll<COLS>{}(storec);
2024
+ }
2025
+ };
2026
+
2027
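IQ4_XS is non-linear: each 4-bit index selects one of 16 signed code points from a fixed table (the constants replicated four times in values128 above); the kernel adds 128 to the table so dpbusd can treat it as unsigned and subtracts 128 * sum(a) afterwards. A scalar sketch of the lookup, with the table values taken verbatim from the code above:

#include <cstdint>

// The 16 IQ4 code points, lowest index first.
static const int8_t kIq4Values[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
};

// Map a 4-bit index to its signed weight value.
inline int8_t iq4_dequant(uint8_t nibble) {
    return kIq4Values[nibble & 0x0F];
}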
+ #define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
2028
+ tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
2029
+ KB, (const char *)wdata + 0 * row_size_A, \
2030
+ (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
2031
+ (float *) dst->data + 0 * N + nb_start, ldc)
2032
+
2033
+ template <typename TA, typename TB, typename TC, int BLOCK_K,
2034
+ typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
2035
+ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) {
2036
+ using packed_B_t = packed_B_type<TB>;
2037
+ const int TILE_SIZE = get_tile_size<TB>();
2038
+ const bool need_unpack = do_unpack<TB>::value;
2039
+
2040
+ LM_GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
2041
+ const TA * RESTRICT A = static_cast<const TA *>(_A);
2042
+ const char * RESTRICT B = static_cast<const char *>(_B);
2043
+
2044
+ const int m0 = std::min(M, TILE_M);
2045
+ const int m1 = std::max(M - TILE_M, 0);
2046
+ const int lda = KB * sizeof(TA);
2047
+ //const int ldb = KB * sizeof(TB);
2048
+
2049
+ static thread_local packed_B_t Tile0[TILE_N * TILE_K];
2050
+ static thread_local packed_B_t Tile1[TILE_N * TILE_K];
2051
+ static thread_local int8_t Tile23[TILE_M * TILE_K];
2052
+
2053
+ static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
2054
+ static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
2055
+
2056
+ // double buffering C to interleave avx512 and amx
2057
+ int32_t * C_cur = TileC0;
2058
+ int32_t * C_pre = TileC1;
2059
+
2060
+ auto Tile4 = [&](int32_t * base) { return base; };
2061
+ auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; };
2062
+ auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; };
2063
+ auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; };
2064
+
2065
+ if (M == 2 * TILE_M) {
2066
+ // i = 0
2067
+ const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE);
2068
+ const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE);
2069
+ if (need_unpack) {
2070
+ unpack_B<TB>(Tile0, B_blk0);
2071
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2072
+ } else {
2073
+ _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
2074
+ }
2075
+
2076
+ _tile_zero(TMM4);
2077
+ _tile_loadd(TMM2, A[0].qs, lda);
2078
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2079
+ _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t));
2080
+
2081
+ _tile_zero(TMM5);
2082
+ _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda);
2083
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2084
+ _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
2085
+
2086
+ if (need_unpack) {
2087
+ unpack_B<TB>(Tile1, B_blk1);
2088
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2089
+ } else {
2090
+ _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
2091
+ }
2092
+
2093
+ _tile_zero(TMM6);
2094
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2095
+ _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t));
2096
+
2097
+ _tile_zero(TMM7);
2098
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2099
+ _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t));
2100
+
2101
+ for (int i = 1; i < KB; ++i) {
2102
+ // index of previous iter
2103
+ const int ii = i - 1;
2104
+ const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
2105
+ const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
2106
+ LM_GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] {
2107
+ if (need_unpack) {
2108
+ unpack_B<TB>(Tile0, B_blk0);
2109
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2110
+ } else {
2111
+ _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
2112
+ }
2113
+ _tile_zero(TMM4);
2114
+ _tile_loadd(TMM2, A[i].qs, lda);
2115
+ acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2116
+
2117
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2118
+ _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
2119
+
2120
+ _tile_zero(TMM5);
2121
+ _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda);
2122
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2123
+
2124
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2125
+ _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
2126
+
2127
+ if (need_unpack) {
2128
+ unpack_B<TB>(Tile1, B_blk1);
2129
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2130
+ } else {
2131
+ _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
2132
+ }
2133
+ _tile_zero(TMM6);
2134
+ acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2135
+
2136
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2137
+ _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
2138
+
2139
+ _tile_zero(TMM7);
2140
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2141
+
2142
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2143
+ _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
2144
+
2145
+ std::swap(C_cur, C_pre);
2146
+ });
2147
+ }
2148
+ // final accumulation
2149
+ {
2150
+ int ii = KB - 1;
2151
+ acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2152
+ acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2153
+ acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2154
+ acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2155
+ }
2156
+ } else {
2157
+ for (int i = 0; i < KB; ++i) {
2158
+ _tile_zero(TMM4);
2159
+ _tile_zero(TMM6);
2160
+ if (m1 != 0) {
2161
+ _tile_zero(TMM5);
2162
+ _tile_zero(TMM7);
2163
+ }
2164
+
2165
+ const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
2166
+ const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
2167
+ if (need_unpack) {
2168
+ unpack_B<TB>(Tile0, B_blk0);
2169
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2170
+ } else {
2171
+ _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
2172
+ }
2173
+
2174
+ if (need_unpack) {
2175
+ unpack_B<TB>(Tile1, B_blk1);
2176
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2177
+ } else {
2178
+ _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
2179
+ }
2180
+
2181
+ if (m0 == TILE_M) {
2182
+ _tile_loadd(TMM2, A[i].qs, lda);
2183
+ } else {
2184
+ unpack_A(Tile23, &A[i], KB, m0);
2185
+ _tile_loadd(TMM2, Tile23, TILE_K);
2186
+ }
2187
+
2188
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2189
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2190
+
2191
+ _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
2192
+ _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
2193
+
2194
+ LM_GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
2195
+ acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
2196
+ acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
2197
+ });
2198
+
2199
+ if (m1 != 0) {
2200
+ unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1);
2201
+ _tile_loadd(TMM3, Tile23, TILE_K);
2202
+
2203
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2204
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2205
+ _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
2206
+ _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
2207
+ LM_GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
2208
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
2209
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
2210
+ });
2211
+ }
2212
+ }
2213
+ }
2214
+ return;
2215
+ }
2216
+
2217
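The kernel above double-buffers the C tiles: while the AMX dot products of iteration i land in one buffer, the AVX-512 epilogue (acc_C) folds the previous iteration's buffer into C, and the two pointers are swapped each step. A minimal sketch of that ping-pong pattern, with illustrative stand-ins for the compute and accumulate callbacks:

#include <cstdint>
#include <utility>

// While `cur` receives the new partial results, `pre` (the previous
// iteration's results) is folded into the output.
template <typename ComputeFn, typename AccumulateFn>
void pingpong_tiles(int KB, int32_t * buf0, int32_t * buf1,
                    ComputeFn compute, AccumulateFn accumulate) {
    int32_t * cur = buf0;
    int32_t * pre = buf1;
    compute(0, cur);                 // prologue: fill the first buffer
    std::swap(cur, pre);
    for (int i = 1; i < KB; ++i) {
        accumulate(i - 1, pre);      // consume the previous results...
        compute(i, cur);             // ...while producing the next ones
        std::swap(cur, pre);
    }
    accumulate(KB - 1, pre);         // epilogue: flush the last buffer
}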
+ template <typename TA, typename TB, typename TC, int BLOCK_K,
2218
+ typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
2219
+ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
2220
+ static_assert(std::is_same<TA, block_q8_K>::value);
2221
+ const int TILE_SIZE = get_tile_size<TB>();
2222
+
2223
+ LM_GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
2224
+ const TA * RESTRICT A = static_cast<const TA *>(_A);
2225
+ const char * RESTRICT B = static_cast<const char *>(_B);
2226
+
2227
+ const int m0 = std::min(M, TILE_M);
2228
+ const int m1 = std::max(M - TILE_M, 0);
2229
+ //const int lda = KB * sizeof(TA);
2230
+
2231
+ static thread_local int8_t Tile0[TILE_N * TILE_K];
2232
+ static thread_local int8_t Tile1[TILE_N * TILE_K];
2233
+ static thread_local int8_t Tile23[TILE_M * TILE_K];
2234
+
2235
+ // mat mul result for each group
2236
+ static thread_local int32_t Tile4[TILE_M * TILE_N];
2237
+ static thread_local int32_t Tile5[TILE_M * TILE_N];
2238
+ static thread_local int32_t Tile6[TILE_M * TILE_N];
2239
+ static thread_local int32_t Tile7[TILE_M * TILE_N];
2240
+
2241
+ // sum of each QK_K block, contains 8 groups, int32
2242
+ static thread_local int32_t Sumi4[TILE_M * TILE_N];
2243
+ static thread_local int32_t Sumi5[TILE_M * TILE_N];
2244
+ static thread_local int32_t Sumi6[TILE_M * TILE_N];
2245
+ static thread_local int32_t Sumi7[TILE_M * TILE_N];
2246
+
2247
+ const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;
2248
+ for (int i = 0; i < KB; ++i) {
2249
+ // step 1: accumulate the quants across 8 groups, each group with 32
2250
+ for (int k = 0; k < QK_K / k_group_size; ++k) {
2251
+ LM_GGML_DISPATCH_BOOL(k > 0, is_acc, [&] {
2252
+ _tile_zero(TMM4);
2253
+ _tile_zero(TMM6);
2254
+
2255
+ unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k);
2256
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2257
+
2258
+ unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k);
2259
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2260
+
2261
+ unpack_A<TB>(Tile23, &A[i], KB, k, m0);
2262
+ _tile_loadd(TMM2, Tile23, TILE_K);
2263
+
2264
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2265
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2266
+
2267
+ _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
2268
+ _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
2269
+
2270
+ scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0);
2271
+ scale_C<TB, is_acc>(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0);
2272
+
2273
+ if (m1 != 0) {
2274
+ _tile_zero(TMM5);
2275
+ _tile_zero(TMM7);
2276
+
2277
+ unpack_A<TB>(Tile23, &A[TILE_M * KB + i], KB, k, m1);
2278
+ _tile_loadd(TMM3, Tile23, TILE_K);
2279
+
2280
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2281
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2282
+
2283
+ _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
2284
+ _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
2285
+
2286
+ scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1);
2287
+ scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1);
2288
+ }
2289
+ });
2290
+ }
2291
+
2292
+ // step 2: accumulate the mins
2293
+ LM_GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
2294
+ acc_C<TA, TB, is_acc>::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
2295
+ acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
2296
+ if (m1 != 0) {
2297
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
2298
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
2299
+ }
2300
+ });
2301
+ }
2302
+ return;
2303
+ }
2304
+
2305
+ } // anonymous namespace
2306
+
2307
+ // get the packed tensor size for quantized weights
2308
+ size_t lm_ggml_backend_amx_get_alloc_size(const struct lm_ggml_tensor * tensor) {
2309
+ const enum lm_ggml_type TYPE = tensor->type;
2310
+
2311
+ const int K = tensor->ne[0]; // ne0: in_features
2312
+ const int N = tensor->ne[1]; // ne1: out_features
2313
+
2314
+ auto get_tensor_size = [&] {
2315
+ size_t row_size_B{0};
2316
+ LM_GGML_DISPATCH_QTYPES(TYPE, [&] {
2317
+ row_size_B = get_row_size<type, blck_size>(K);
2318
+ });
2319
+ return N * row_size_B;
2320
+ };
2321
+
2322
+ if (qtype_has_amx_kernels(TYPE)) {
2323
+ return get_tensor_size();
2324
+ } else {
2325
+ // for f16, bf16 we don't do packing
2326
+ return lm_ggml_nbytes(tensor);
2327
+ }
2328
+ }
2329
+
2330
+ // pack weight to vnni format
2331
+ void lm_ggml_backend_amx_convert_weight(struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2332
+ LM_GGML_ASSERT(offset == 0 && size == lm_ggml_nbytes(tensor)); // only full tensor conversion is supported for now
2333
+
2334
+ const enum lm_ggml_type TYPE = tensor->type;
2335
+
2336
+ const int K = tensor->ne[0]; // ne0: in_features
2337
+ const int N = tensor->ne[1]; // ne1: out_features
2338
+
2339
+ #if defined(_OPENMP)
2340
+ // the buffer ctx is not initialized when .set_tensor is called
2341
+ int n_threads = omp_get_num_threads();
2342
+ #else
2343
+ int n_threads = 1;
2344
+ #endif
2345
+
2346
+ LM_GGML_DISPATCH_QTYPES(TYPE, [&] {
2347
+ convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K, n_threads);
2348
+ });
2349
+ }
2350
+
2351
+ size_t lm_ggml_backend_amx_desired_wsize(const struct lm_ggml_tensor * dst) {
2352
+ struct lm_ggml_tensor * src0 = dst->src[0];
2353
+
2354
+ const enum lm_ggml_type TYPE = src0->type;
2355
+
2356
+ const bool is_floating_type = TYPE == LM_GGML_TYPE_F16;
2357
+ if (is_floating_type) {
2358
+ return 0;
2359
+ }
2360
+
2361
+ const int M = dst->ne[1];
2362
+ const int K = src0->ne[0];
2363
+
2364
+ size_t desired_wsize = 0;
2365
+
2366
+ LM_GGML_DISPATCH_QTYPES(TYPE, [&] {
2367
+ const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
2368
+ desired_wsize = M * row_size_A;
2369
+ });
2370
+
2371
+ return desired_wsize;
2372
+ }
2373
+
2374
+ // NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)
2375
+ //
2376
+ // src0: weight in shape of {N, K}, quantized
2377
+ // src1: input in shape of {M, K}, float32
2378
+ // dst: output in shape of {M, N}, float32
2379
+ //
2380
+ // the function performs: dst = src1 @ src0.T
2381
+ //
2382
+ void lm_ggml_backend_amx_mul_mat(const lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) {
2383
+ struct lm_ggml_tensor * src0 = dst->src[0];
2384
+ struct lm_ggml_tensor * src1 = dst->src[1];
2385
+
2386
+ const enum lm_ggml_type TYPE = src0->type;
2387
+
2388
+ // f16 only has avx512 kernels for now,
2389
+ // amx kernels will be added once 6th gen xeon is released.
2390
+ const bool is_floating_type = TYPE == LM_GGML_TYPE_F16;
2391
+
2392
+ const int M = dst->ne[1];
2393
+ const int N = dst->ne[0];
2394
+ const int K = src0->ne[0];
2395
+ const int ldc = dst->nb[1] / dst->nb[0];
2396
+
2397
+ if (is_floating_type) {
2398
+ constexpr int BLOCK_M = 4;
2399
+ constexpr int BLOCK_N = 6;
2400
+ const int MB = div_up(M, BLOCK_M);
2401
+ const int NB = div_up(N, BLOCK_N);
2402
+
2403
+ parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
2404
+ LM_GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
2405
+ for (int i = begin; i < end; ++i) {
2406
+ int mb = i / NB;
2407
+ int nb = i % NB;
2408
+
2409
+ int mb_start = mb * BLOCK_M;
2410
+ int mb_size = std::min(BLOCK_M, M - mb_start);
2411
+ int nb_start = nb * BLOCK_N;
2412
+ int nb_size = std::min(BLOCK_N, N - nb_start);
2413
+
2414
+ switch (mb_size << 4 | nb_size) {
2415
+ case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break;
2416
+ case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break;
2417
+ case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break;
2418
+ case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break;
2419
+ case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break;
2420
+ case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break;
2421
+ case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break;
2422
+ case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break;
2423
+ case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break;
2424
+ case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break;
2425
+ case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break;
2426
+ case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break;
2427
+ default: fprintf(stderr, "Unexpected block size!\n");
2428
+ }
2429
+ }
2430
+ });
2431
+ });
2432
+ return;
2433
+ }
2434
+
2435
+ // pointer to work space, used to convert A from float to the quantized type
2436
+ void * wdata = params->wdata;
2437
+
2438
+ //TODO: performance improvement: merge quant A
2439
+ if (params->ith == 0) {
2440
+ LM_GGML_DISPATCH_QTYPES(TYPE, [&] {
2441
+ const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
2442
+ const size_t desired_wsize = M * row_size_A;
2443
+ if (params->wsize < desired_wsize) {
2444
+ LM_GGML_ABORT("insufficient work space size");
2445
+ }
2446
+
2447
+ // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size
2448
+ // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
2449
+ LM_GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
2450
+
2451
+ const float * A_data = static_cast<const float *>(src1->data);
2452
+ for (int m = 0; m < M; ++m) {
2453
+ from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
2454
+ }
2455
+ });
2456
+ }
2457
+
2458
+ lm_ggml_barrier(params->threadpool);
2459
+
2460
+ if (M == 1) {
2461
+ // MB = 1 and handle 8 tiles in each block
2462
+ constexpr int kTilesN = 4;
2463
+ constexpr int BLOCK_N = TILE_N * kTilesN;
2464
+ const int NB = div_up(N, BLOCK_N);
2465
+
2466
+ parallel_for_ggml(params, NB, [&](int begin, int end) {
2467
+ LM_GGML_DISPATCH_QTYPES(TYPE, [&] {
2468
+ const int KB = K / blck_size;
2469
+ const int TILE_SIZE = get_tile_size<type>();
2470
+ const int row_size_A = KB * sizeof(vec_dot_type);
2471
+ for (int i = begin; i < end; ++i) {
2472
+ int nb = i;
2473
+ int nb_start = nb * BLOCK_N;
2474
+ int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
2475
+
2476
+ switch (nb_size) {
2477
+ //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break;
2478
+ case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break;
2479
+ case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break;
2480
+ case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break;
2481
+ case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break;
2482
+ default: fprintf(stderr, "Unexpected n block size!\n");
2483
+ }
2484
+ }
2485
+ });
2486
+ });
2487
+ return;
2488
+ }
2489
+
2490
+ // handle 4 tiles at a time
2491
+ constexpr int BLOCK_M = TILE_M * 2;
2492
+ constexpr int BLOCK_N = TILE_N * 2;
2493
+ const int MB = div_up(M, BLOCK_M);
2494
+ const int NB = div_up(N, BLOCK_N);
2495
+
2496
+ parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
2497
+ // init tile config for each thread
2498
+ lm_ggml_tile_config_init();
2499
+
2500
+ LM_GGML_DISPATCH_QTYPES(TYPE, [&] {
2501
+ const int KB = K / blck_size;
2502
+ const int TILE_SIZE = get_tile_size<type>();
2503
+ const int row_size_A = KB * sizeof(vec_dot_type);
2504
+
2505
+ for (int i = begin; i < end; ++i) {
2506
+ int mb = i / NB;
2507
+ int nb = i % NB;
2508
+
2509
+ int mb_start = mb * BLOCK_M;
2510
+ int mb_size = std::min(BLOCK_M, M - mb_start);
2511
+ int nb_start = nb * BLOCK_N;
2512
+ int nb_size = BLOCK_N;
2513
+
2514
+ tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
2515
+ mb_size, nb_size, KB,
2516
+ (const char *)wdata + mb_start * row_size_A,
2517
+ (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
2518
+ (float *) dst->data + mb_start * N + nb_start, ldc);
2519
+ }
2520
+ });
2521
+ });
2522
+ }
2523
+
2524
+ #endif // if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
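Taken together, the entry point above computes dst = src1 @ src0.T with src0 an {N, K} weight and src1 an {M, K} activation, both row-major. A float-only scalar reference of that contract, useful as a correctness baseline for the quantized kernels (illustrative only, not part of the package):

// dst[m][n] = sum_k src1[m][k] * src0[n][k]; ldc is the row stride of dst.
void mul_mat_ref(int M, int N, int K,
                 const float * src0,   // weights, {N, K} row-major
                 const float * src1,   // activations, {M, K} row-major
                 float * dst, int ldc) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float sum = 0.f;
            for (int k = 0; k < K; ++k) {
                sum += src1[m * K + k] * src0[n * K + k];
            }
            dst[m * ldc + n] = sum;
        }
    }
}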