whispercpp 1.3.0 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
@@ -0,0 +1,2511 @@
1
+
2
+ #if defined(__GNUC__)
3
+ #pragma GCC diagnostic ignored "-Wpedantic"
4
+ #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
5
+ #endif
6
+
7
+ #include "amx.h"
8
+ #include "mmq.h"
9
+ #include "ggml-impl.h"
10
+ #include "ggml-cpu-impl.h"
11
+ #include "ggml-cpu-quants.h"
12
+ #include "ggml-quants.h"
13
+ #include <algorithm>
14
+ #include <type_traits>
15
+
16
+ #if defined(__gnu_linux__)
17
+ #include <sys/syscall.h>
18
+ #include <unistd.h>
19
+ #endif
20
+
21
+ #if (defined(_WIN32) || defined(_WIN64))
22
+ #define RESTRICT __restrict
23
+ #else
24
+ #define RESTRICT __restrict__
25
+ #endif
26
+
27
+ #if (defined(_WIN32) || defined(_WIN64))
28
+ #define ALWAYS_INLINE __forceinline
29
+ #elif __has_attribute(always_inline) || defined(__GNUC__)
30
+ #define ALWAYS_INLINE __attribute__((__always_inline__)) inline
31
+ #else
32
+ #define ALWAYS_INLINE inline
33
+ #endif
34
+
35
+ #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
36
+
37
+ namespace {
38
+
39
+ // Forced unrolling
40
+ template <int n>
41
+ struct Unroll {
42
+ template <typename Func, typename... Args>
43
+ ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
44
+ Unroll<n - 1>{}(f, args...);
45
+ f(std::integral_constant<int, n - 1>{}, args...);
46
+ }
47
+ };
48
+
49
+ template <>
50
+ struct Unroll<1> {
51
+ template <typename Func, typename... Args>
52
+ ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
53
+ f(std::integral_constant<int, 0>{}, args...);
54
+ }
55
+ };
56
+
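A minimal usage sketch of the `Unroll` helper above (illustrative only, not part of the packaged file; the accumulator array and lambda body are hypothetical). The functor receives a `std::integral_constant`, so the index is a compile-time constant and the body is stamped out with no runtime loop:

    // hypothetical example: force-unroll four iterations of an accumulation
    float acc[4] = {0.f, 0.f, 0.f, 0.f};
    Unroll<4>{}([&](auto idx) {
        constexpr int i = decltype(idx)::value;  // compile-time index 0..3
        acc[i] += 1.0f;                          // expanded four times, no loop emitted
    });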
57
+ // type traits
58
+ template <typename T> struct PackedTypes {};
59
+ template <> struct PackedTypes<block_q4_0> { using type = int8_t; };
60
+ template <> struct PackedTypes<block_q4_1> { using type = uint8_t; };
61
+ template <> struct PackedTypes<block_q8_0> { using type = int8_t; };
62
+ template <typename T> using packed_B_type = typename PackedTypes<T>::type;
63
+
64
+ template <typename T>
65
+ struct do_compensate : std::integral_constant<bool,
66
+ std::is_same<T, block_q8_0>::value> {};
67
+
68
+ template <typename T>
69
+ struct do_unpack : std::integral_constant<bool,
70
+ std::is_same<T, block_q4_0>::value ||
71
+ std::is_same<T, block_q4_1>::value> {};
72
+
73
+ template <typename T>
74
+ struct is_type_qkk : std::integral_constant<bool,
75
+ std::is_same<T, block_q4_K>::value ||
76
+ std::is_same<T, block_q5_K>::value ||
77
+ std::is_same<T, block_q6_K>::value ||
78
+ std::is_same<T, block_iq4_xs>::value> {};
79
+
80
+ #define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...) \
81
+ [&] { \
82
+ switch (TYPE) { \
83
+ case GGML_TYPE_F16: { \
84
+ using type = ggml_fp16_t; \
85
+ constexpr int blck_size = 16; \
86
+ return __VA_ARGS__(); \
87
+ } \
88
+ case GGML_TYPE_BF16: { \
89
+ using type = ggml_bf16_t; \
90
+ constexpr int blck_size = 32; \
91
+ return __VA_ARGS__(); \
92
+ } \
93
+ default: \
94
+ fprintf(stderr, "Unsupported floating data type\n"); \
95
+ } \
96
+ }()
97
+
98
+ #define GGML_DISPATCH_QTYPES(QT, ...) \
99
+ [&] { \
100
+ switch (QT) { \
101
+ case GGML_TYPE_Q4_0: { \
102
+ using type = block_q4_0; \
103
+ using vec_dot_type = block_q8_0; \
104
+ constexpr int blck_size = QK4_0; \
105
+ return __VA_ARGS__(); \
106
+ } \
107
+ case GGML_TYPE_Q4_1: { \
108
+ using type = block_q4_1; \
109
+ using vec_dot_type = block_q8_1; \
110
+ constexpr int blck_size = QK4_1; \
111
+ return __VA_ARGS__(); \
112
+ } \
113
+ case GGML_TYPE_Q8_0: { \
114
+ using type = block_q8_0; \
115
+ using vec_dot_type = block_q8_0; \
116
+ constexpr int blck_size = QK8_0; \
117
+ return __VA_ARGS__(); \
118
+ } \
119
+ case GGML_TYPE_Q4_K: { \
120
+ using type = block_q4_K; \
121
+ using vec_dot_type = block_q8_K; \
122
+ constexpr int blck_size = QK_K; \
123
+ return __VA_ARGS__(); \
124
+ } \
125
+ case GGML_TYPE_Q5_K: { \
126
+ using type = block_q5_K; \
127
+ using vec_dot_type = block_q8_K; \
128
+ constexpr int blck_size = QK_K; \
129
+ return __VA_ARGS__(); \
130
+ } \
131
+ case GGML_TYPE_Q6_K: { \
132
+ using type = block_q6_K; \
133
+ using vec_dot_type = block_q8_K; \
134
+ constexpr int blck_size = QK_K; \
135
+ return __VA_ARGS__(); \
136
+ } \
137
+ case GGML_TYPE_IQ4_XS: { \
138
+ using type = block_iq4_xs; \
139
+ using vec_dot_type = block_q8_K; \
140
+ constexpr int blck_size = QK_K; \
141
+ return __VA_ARGS__(); \
142
+ } \
143
+ default: \
144
+ fprintf(stderr, "Unsupported quantized data type: %d\n", int(QT)); \
145
+ } \
146
+ }()
147
+
148
+ #define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
149
+ [&] { \
150
+ if (BOOL_V) { \
151
+ constexpr bool BOOL_NAME = true; \
152
+ return __VA_ARGS__(); \
153
+ } else { \
154
+ constexpr bool BOOL_NAME = false; \
155
+ return __VA_ARGS__(); \
156
+ } \
157
+ }()
158
+
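A hedged sketch of how the dispatch macros above are meant to be consumed (the kernel `launch_tinygemm` and the wrapper `dispatch_example` are hypothetical names used only for illustration): the selected `type`, `vec_dot_type` and `blck_size` are visible inside the trailing lambda and can parameterize a template.

    // hypothetical templated kernel, standing in for the real GEMM entry points later in the file
    template <typename TB, typename TA, int BLOCK_K>
    void launch_tinygemm(const void * A, const void * B, float * C, int K) {
        (void)A; (void)B; (void)C; (void)K;  // placeholder body
    }

    void dispatch_example(enum ggml_type qtype, const void * A, const void * B, float * C, int K) {
        GGML_DISPATCH_QTYPES(qtype, [&] {
            // the macro binds: type = B's block type, vec_dot_type = A's block type, blck_size
            launch_tinygemm<type, vec_dot_type, blck_size>(A, B, C, K);
        });
    }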
159
+ // define amx tile config data structure
160
+ struct tile_config_t{
161
+ uint8_t palette_id = 0;
162
+ uint8_t start_row = 0;
163
+ uint8_t reserved_0[14] = {0};
164
+ uint16_t colsb[16] = {0};
165
+ uint8_t rows[16] = {0};
166
+ };
167
+
168
+ // Notes: amx tile config
169
+ //
170
+ // Typically, TMUL multiplies tiles A and B of size 16 x 64 containing INT8 values,
171
+ // and accumulates the result into a 16 x 16 matrix C containing INT32 values.
172
+ //
173
+ // Since many GGUF quantized types use a `block_size` of 32, a 16-16-32 config is used
174
+ // instead of the normally used 16-16-64 config.
175
+ //
176
+ // Block A: {16, 32}, dtype = int8_t
177
+ // Block B: {16, 32}, dtype = uint8_t/int8_t
178
+ // Block C: {16, 16}, dtype = int32_t
179
+ //
180
+ // Block B needs to be prepacked to vnni format before feeding into TMUL:
181
+ // packed_B: from {n, k} to {k/vnni_blk, n, vnni_blck}, viewed in 2d, we get {8, 64}
182
+ //
183
+ // Therefore, we get tileconfig:
184
+ // A B C
185
+ // rows 16 8 16
186
+ // colsb 32 64 16
187
+ //
188
+ // For tile distribution, follow a 2-2-4 pattern, e.g. A used TMM2-TMM3, B used TMM0-TMM1,
189
+ // C used TMM4-TMM7:
190
+ // B TMM0 B TMM1
191
+ // A TMM2 C TMM4 C TMM6
192
+ // A TMM3 C TMM5 C TMM7
193
+ //
194
+ // Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB; when m < 2 * BLOCK_M, unpacking A
195
+ // will be needed.
196
+ //
197
+ // Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16;
198
+ // and the single-batch gemm (m=1) has a special fast path with `avx512-vnni`.
199
+ //
200
+ // ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/
201
+ // advanced-matrix-extensions-intrinsics-functions.html
202
+ //
203
+
204
+ #define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
205
+ void ggml_tile_config_init(void) {
206
+ static thread_local bool is_first_time = true;
207
+
208
+ if (!is_first_time) {
209
+ return;
210
+ }
211
+
212
+ static thread_local tile_config_t tc;
213
+ tile_config_t current_tc;
214
+ _tile_storeconfig(&current_tc);
215
+
216
+ // load only when config changes
217
+ if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
218
+ memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
219
+ tc.palette_id = 1;
220
+ tc.start_row = 0;
221
+ TC_CONFIG_TILE(TMM0, 8, 64);
222
+ TC_CONFIG_TILE(TMM1, 8, 64);
223
+ TC_CONFIG_TILE(TMM2, 16, 32);
224
+ TC_CONFIG_TILE(TMM3, 16, 32);
225
+ TC_CONFIG_TILE(TMM4, 16, 64);
226
+ TC_CONFIG_TILE(TMM5, 16, 64);
227
+ TC_CONFIG_TILE(TMM6, 16, 64);
228
+ TC_CONFIG_TILE(TMM7, 16, 64);
229
+ _tile_loadconfig(&tc);
230
+ }
231
+
232
+ is_first_time = false;
233
+ }
234
+
235
+ // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
236
+ // See the notes `s8s8 igemm compensation in avx512-vnni` for detail.
237
+ template <typename TB>
238
+ int get_tile_size() {
239
+ int tile_size = TILE_N * sizeof(TB);
240
+ if (do_compensate<TB>::value) {
241
+ tile_size += TILE_N * sizeof(int32_t);
242
+ }
243
+ if (std::is_same<TB, block_q4_K>::value ||
244
+ std::is_same<TB, block_q5_K>::value) {
245
+ tile_size += TILE_N * 4;
246
+ }
247
+ if (std::is_same<TB, block_iq4_xs>::value) {
248
+ tile_size += TILE_N * 2;
249
+ }
250
+ return tile_size;
251
+ }
252
+
253
+ template <typename TB, int BLOCK_K>
254
+ int get_row_size(int K) {
255
+ int KB = K / BLOCK_K;
256
+ int row_size = KB * sizeof(TB);
257
+ if (do_compensate<TB>::value) {
258
+ row_size += KB * sizeof(int32_t);
259
+ }
260
+ if (std::is_same<TB, block_q4_K>::value ||
261
+ std::is_same<TB, block_q5_K>::value) {
262
+ row_size += KB * 4;
263
+ }
264
+ if (std::is_same<TB, block_iq4_xs>::value) {
265
+ row_size += KB * 2;
266
+ }
267
+ return row_size;
268
+ }
269
+
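As a worked example of the two size helpers above (assuming ggml's usual 34-byte `block_q8_0`, i.e. a 2-byte half-precision scale plus 32 quants): for K = 4096 and BLOCK_K = QK8_0 = 32 there are 128 blocks per row, so `get_row_size<block_q8_0, QK8_0>(4096)` comes to 128 * 34 B of quantized blocks plus 128 * 4 B of s8s8 compensation, i.e. 4864 bytes per packed row.

    // sanity check for the arithmetic above (hypothetical, not in the packaged file)
    static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "2-byte scale + 32 quants = 34 bytes");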
270
+ // vectorized dtype conversion
271
+ inline float FP16_TO_FP32(ggml_half val) {
272
+ __m256i v = _mm256_setr_epi16(
273
+ val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
274
+ __m512 o = _mm512_cvtph_ps(v);
275
+ return _mm512_cvtss_f32(o);
276
+ }
277
+
278
+ inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
279
+ __m256i v = _mm256_set1_epi16(val);
280
+ return _mm512_cvtph_ps(v);
281
+ }
282
+
283
+ // horizontal reduce
284
+ inline float _mm512_reduce_max_ps(const __m512 x) {
285
+ __m512 v = x;
286
+ __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
287
+ v = _mm512_max_ps(v, v1);
288
+ v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
289
+ v = _mm512_max_ps(v, v1);
290
+ v1 = _mm512_shuffle_ps(v, v, 0x4E);
291
+ v = _mm512_max_ps(v, v1);
292
+ v1 = _mm512_shuffle_ps(v, v, 0xB1);
293
+ v = _mm512_max_ps(v, v1);
294
+ return _mm512_cvtss_f32(v);
295
+ }
296
+
297
+ // transpose utils
298
+ #define SHUFFLE_EPI32(a, b, mask) \
299
+ _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
300
+ inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) {
301
+ // unpacking and interleaving 32-bit elements
302
+ v1[0] = _mm256_unpacklo_epi32(v[0], v[1]);
303
+ v1[1] = _mm256_unpackhi_epi32(v[0], v[1]);
304
+ v1[2] = _mm256_unpacklo_epi32(v[2], v[3]);
305
+ v1[3] = _mm256_unpackhi_epi32(v[2], v[3]);
306
+ v1[4] = _mm256_unpacklo_epi32(v[4], v[5]);
307
+ v1[5] = _mm256_unpackhi_epi32(v[4], v[5]);
308
+ v1[6] = _mm256_unpacklo_epi32(v[6], v[7]);
309
+ v1[7] = _mm256_unpackhi_epi32(v[6], v[7]);
310
+
311
+ // shuffling the 32-bit elements
312
+ v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44);
313
+ v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee);
314
+ v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44);
315
+ v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee);
316
+ v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44);
317
+ v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee);
318
+ v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44);
319
+ v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee);
320
+
321
+ // shuffling 128-bit elements
322
+ v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02);
323
+ v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02);
324
+ v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02);
325
+ v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02);
326
+ v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13);
327
+ v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13);
328
+ v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13);
329
+ v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13);
330
+ }
331
+
332
+ inline void transpose_16x4_32bit(__m512i * r, __m512i * d) {
333
+
334
+ static const __m512i index1 = _mm512_set_epi32(
335
+ 0x0f, 0x0b, 0x07, 0x03,
336
+ 0x0e, 0x0a, 0x06, 0x02,
337
+ 0x0d, 0x09, 0x05, 0x01,
338
+ 0x0c, 0x08, 0x04, 0x00);
339
+
340
+ d[0] = _mm512_permutexvar_epi32(index1, r[0]);
341
+ d[1] = _mm512_permutexvar_epi32(index1, r[1]);
342
+ d[2] = _mm512_permutexvar_epi32(index1, r[2]);
343
+ d[3] = _mm512_permutexvar_epi32(index1, r[3]);
344
+
345
+ r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
346
+ r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
347
+ r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);
348
+ r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);
349
+
350
+ d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);
351
+ d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);
352
+ d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);
353
+ d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);
354
+ }
355
+
356
+ inline void transpose_16x16_32bit(__m512i * v) {
357
+ __m512i v1[16];
358
+ v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
359
+ v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
360
+ v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
361
+ v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
362
+ v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
363
+ v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
364
+ v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
365
+ v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
366
+ v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
367
+ v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
368
+ v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
369
+ v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
370
+ v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
371
+ v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
372
+ v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
373
+ v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
374
+
375
+ v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
376
+ v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
377
+ v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
378
+ v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
379
+ v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
380
+ v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
381
+ v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
382
+ v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
383
+ v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
384
+ v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
385
+ v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
386
+ v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
387
+ v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
388
+ v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
389
+ v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
390
+ v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
391
+
392
+ v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
393
+ v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
394
+ v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
395
+ v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
396
+ v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
397
+ v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
398
+ v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
399
+ v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
400
+ v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
401
+ v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
402
+ v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
403
+ v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
404
+ v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
405
+ v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
406
+ v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
407
+ v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
408
+
409
+ v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
410
+ v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
411
+ v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
412
+ v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
413
+ v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
414
+ v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
415
+ v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
416
+ v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
417
+ v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
418
+ v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
419
+ v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
420
+ v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
421
+ v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
422
+ v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
423
+ v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
424
+ v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
425
+ }
426
+
427
+ void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) {
428
+ assert(k % QK_K == 0);
429
+ const int KB = k / QK_K;
430
+ constexpr int kVecs = QK_K / 16;
431
+
432
+ block_q8_K * y = reinterpret_cast<block_q8_K *>(vy);
433
+
434
+ // hold 16 float vecs from x
435
+ __m512 v[kVecs];
436
+
437
+ // hold the quants vecs
438
+ __m512i vq[kVecs / 4];
439
+
440
+ // hold the packed quants vecs
441
+ __m512i vq_packed[kVecs / 4];
442
+
443
+ const __m512 signBit = _mm512_set1_ps(-0.f);
444
+
445
+ for (int i = 0; i < KB; ++i) {
446
+ // Compute max(abs(e)) for the block
447
+ __m512 vamax = _mm512_set1_ps(0.f);
448
+ for (int j = 0; j < kVecs; ++j) {
449
+ v[j] = _mm512_loadu_ps(x); x += 16;
450
+ vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j]));
451
+ }
452
+ const float amax = _mm512_reduce_max_ps(vamax);
453
+
454
+ // Quantize these floats
455
+ const float iscale = 127.f / amax;
456
+ y[i].d = GGML_FP32_TO_FP16(1 / iscale);
457
+ const float id = ( amax != 0.0f ) ? iscale : 0.f;
458
+ const __m512 vscale = _mm512_set1_ps(id);
459
+
460
+ // Apply multiplier and round to nearest integer
461
+ for (int j = 0; j < kVecs; ++j) {
462
+ v[j] = _mm512_mul_ps(v[j], vscale);
463
+ v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
464
+ }
465
+
466
+ // Pack to epi8 vecs
467
+ for (int j = 0; j < kVecs / 4; ++j) {
468
+ __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0]));
469
+ __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1]));
470
+ __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2]));
471
+ __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3]));
472
+
473
+ __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1);
474
+ __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1);
475
+
476
+ vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1);
477
+ _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]);
478
+ }
479
+
480
+ // Compute the bsums with vnni
481
+ transpose_16x4_32bit(vq, vq_packed);
482
+
483
+ const __m512i one = _mm512_set1_epi8(1);
484
+ __m512i sum = _mm512_setzero_si512();
485
+ for (int k = 0; k < 4; ++k) {
486
+ sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]);
487
+ }
488
+ _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum));
489
+ }
490
+ }
491
+
492
+ // quantize A from float to `vec_dot_type`
493
+ template <typename T>
494
+ inline void from_float(const float * x, char * vy, int64_t k);
495
+
496
+ template <>
497
+ inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) {
498
+ quantize_row_q8_0(x, (block_q8_0 *)vy, k);
499
+ }
500
+
501
+ template <>
502
+ inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) {
503
+ quantize_row_q8_1(x, (block_q8_1 *)vy, k);
504
+ }
505
+
506
+ template <>
507
+ inline void from_float<block_q8_K>(const float * x, char * vy, int64_t k) {
508
+ #if 1
509
+ // TODO: this is reference impl!
510
+ quantize_row_q8_K_ref(x, (block_q8_K *)vy, k);
511
+ #else
512
+ quantize_row_q8_K_vnni(x, vy, k);
513
+ #endif
514
+ }
515
+
516
+ // load A from memory into the tile buffer when nrows cannot fill a whole tile
517
+ void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) {
518
+ assert(nr != TILE_M);
519
+ for (int m = 0; m < nr; ++m) {
520
+ const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
521
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
522
+ }
523
+ }
524
+
525
+ void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) {
526
+ assert(nr != TILE_M);
527
+ for (int m = 0; m < nr; ++m) {
528
+ const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
529
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
530
+ }
531
+ }
532
+
533
+ template <typename TB>
534
+ void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
535
+ assert(nr <= TILE_M);
536
+ for (int m = 0; m < nr; ++m) {
537
+ const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32));
538
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
539
+ }
540
+ }
541
+
542
+ template <>
543
+ void unpack_A<block_q6_K>(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
544
+ assert(nr <= TILE_M);
545
+ // zero padding k from 16 to 32, so that we don't have to re-config amx
546
+ const __m128i zero = _mm_setzero_si128();
547
+ for (int m = 0; m < nr; ++m) {
548
+ const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16));
549
+ const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1);
550
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r);
551
+ }
552
+ }
553
+
554
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
555
+ inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
556
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
557
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
558
+ const __m256i lowMask = _mm256_set1_epi8(0xF);
559
+ return _mm256_and_si256(lowMask, bytes);
560
+ }
561
+
562
+ // used for block_q4_K
563
+ inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) {
564
+ const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi);
565
+ const __m256i lowMask = _mm256_set1_epi8(0xF);
566
+ const __m256i q4l = _mm256_and_si256(tmp, lowMask);
567
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask);
568
+ return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1);
569
+ }
570
+
571
+ // used for block_q5_K
572
+ inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) {
573
+ const __m256i lowMask = _mm256_set1_epi8(0xF);
574
+ __m256i hmask = _mm256_set1_epi8(1);
575
+ hmask = _mm256_slli_epi16(hmask, k);
576
+
577
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs);
578
+ const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh);
579
+
580
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask);
581
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4);
582
+ const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
583
+ hmask = _mm256_slli_epi16(hmask, 1);
584
+
585
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask);
586
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4);
587
+ const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
588
+
589
+ return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1);
590
+ }
591
+
592
+ // used for block_q6_K
593
+ inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) {
594
+ const __m256i m4 = _mm256_set1_epi8(0xF);
595
+ const __m256i m2 = _mm256_set1_epi8(0x3);
596
+
597
+ const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs);
598
+ const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32));
599
+ const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh);
600
+
601
+ const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256( q6bitsH, m2), 4);
602
+ const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4);
603
+ const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4);
604
+ const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4);
605
+
606
+ const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0);
607
+ const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1);
608
+ const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2);
609
+ const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3);
610
+
611
+ r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1);
612
+ r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1);
613
+ }
614
+
615
+ inline __m512i packNibbles(__m512i r0, __m512i r1) {
616
+ return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4));
617
+ }
618
+
619
+ template <typename TB>
620
+ inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) {
621
+ int8_t tmp[8 * 64];
622
+ __m256i v[8], v2[8];
623
+ for (int n = 0; n < 8; ++n) {
624
+ v[n] = bytes_from_nibbles_32(B[n * KB].qs);
625
+ }
626
+ transpose_8x8_32bit(v, v2);
627
+ for (int n = 0; n < 8; ++n) {
628
+ _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]);
629
+ }
630
+ for (int n = 0; n < 8; ++n) {
631
+ v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs);
632
+ }
633
+ transpose_8x8_32bit(v, v2);
634
+ for (int n = 0; n < 8; ++n) {
635
+ _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]);
636
+ }
637
+
638
+ // pack again with 128 to fully utilize vector length
639
+ for (int n = 0; n < 8; n += 2) {
640
+ __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64));
641
+ __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64));
642
+ __m512i r1r0 = packNibbles(r0, r1);
643
+ _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0);
644
+ }
645
+ }
646
+
647
+ template <>
648
+ inline void pack_qs<block_q8_0>(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
649
+ __m256i v[8], v2[8];
650
+ for (int n = 0; n < 8; ++n) {
651
+ v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs));
652
+ }
653
+ transpose_8x8_32bit(v, v2);
654
+ for (int n = 0; n < 8; ++n) {
655
+ _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]);
656
+ }
657
+ for (int n = 0; n < 8; ++n) {
658
+ v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs));
659
+ }
660
+ transpose_8x8_32bit(v, v2);
661
+ for (int n = 0; n < 8; ++n) {
662
+ _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]);
663
+ }
664
+ }
665
+
666
+ template <>
667
+ inline void pack_qs<block_q4_K>(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
668
+ __m512i v[16];
669
+ // QK_K 256 with 8 groups, handle 2 groups at a time
670
+ char * pb = (char *)packed_B;
671
+ for (int k = 0; k < QK_K / 64; ++k) {
672
+ // pack 2 groups { n, g, k} to {g, k/4, 4n}
673
+ // e.g. {16, 2, 32} to {2, 8, 64}
674
+ for (int n = 0; n < TILE_N; ++n) {
675
+ v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32);
676
+ }
677
+
678
+ transpose_16x16_32bit(v);
679
+
680
+ // pack again with 128 to fully utilize vector length
681
+ for (int n = 0; n < TILE_N; n += 2) {
682
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
683
+ pb += 64;
684
+ }
685
+ }
686
+ }
687
+
688
+ template <>
689
+ inline void pack_qs<block_q5_K>(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
690
+ __m512i v[16];
691
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
692
+ // QK_K 256 with 8 groups, handle 2 groups at a time
693
+ char * pb = (char *)packed_B;
694
+ char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
695
+ for (int k = 0; k < QK_K / 64; ++k) {
696
+ // pack 2 groups { n, g, k} to {g, k/4, 4n}
697
+ // e.g. {16, 2, 32} to {2, 8, 64}
698
+ for (int n = 0; n < TILE_N; ++n) {
699
+ v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k);
700
+ }
701
+
702
+ transpose_16x16_32bit(v);
703
+
704
+ // 1. pack lower 4bits with 2 groups
705
+ for (int n = 0; n < TILE_N; n += 2) {
706
+ // get lower 4 bits
707
+ const __m512i r0 = _mm512_and_si512(v[n], lowMask);
708
+ const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
709
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
710
+ }
711
+
712
+ // 2. pack higher 1bit with 2 groups
713
+ const __m512i hmask = _mm512_set1_epi8(0x10);
714
+ for (int g = 0; g < 2; ++g) {
715
+ __m512i hbits = _mm512_setzero_si512();
716
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4));
717
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3));
718
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2));
719
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1));
720
+ hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) );
721
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1));
722
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2));
723
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3));
724
+ _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
725
+ }
726
+ }
727
+ }
728
+
729
+ template <>
730
+ inline void pack_qs<block_q6_K>(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
731
+ __m512i v[32];
732
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
733
+ // QK_K 256 with 8 groups, handle 4 groups at a time
734
+ char * pb = (char *)packed_B;
735
+ char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
736
+ for (int k = 0; k < QK_K / 128; ++k) {
737
+ for (int n = 0; n < TILE_N; ++n) {
738
+ bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32);
739
+ }
740
+
741
+ // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7
742
+ transpose_16x16_32bit(v);
743
+ transpose_16x16_32bit(v + 16);
744
+
745
+ // 1. pack lower 4bits with 4 groups
746
+ for (int n = 0; n < 32; n += 2) {
747
+ const __m512i r0 = _mm512_and_si512(v[n], lowMask);
748
+ const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
749
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
750
+ }
751
+
752
+ // 2. pack higher 2bit with 4 groups
753
+ const __m512i hmask = _mm512_set1_epi8(0x30);
754
+ for (int g = 0; g < 8; ++g) {
755
+ __m512i hbits = _mm512_setzero_si512();
756
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4));
757
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2));
758
+ hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) );
759
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2));
760
+ _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
761
+ }
762
+ }
763
+ }
764
+
765
+ template <>
766
+ inline void pack_qs<block_iq4_xs>(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
767
+ __m512i v[16];
768
+ char * pb = (char *)packed_B;
769
+ for (int k = 0; k < QK_K / 64; ++k) {
770
+ for (int n = 0; n < TILE_N; ++n) {
771
+ __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0);
772
+ __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16);
773
+ v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1);
774
+ }
775
+
776
+ transpose_16x16_32bit(v);
777
+
778
+ // pack again with 128 to fully utilize vector length
779
+ for (int n = 0; n < TILE_N; n += 2) {
780
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
781
+ pb += 64;
782
+ }
783
+ }
784
+ }
785
+
786
+ // pack B to vnni formats in 4bits or 8 bits
787
+ void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) {
788
+ pack_qs(packed_B, B, KB);
789
+ ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
790
+ for (int n = 0; n < TILE_N; ++n) {
791
+ d0[n] = B[n * KB].d;
792
+ }
793
+ }
794
+
795
+ void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) {
796
+ pack_qs(packed_B, B, KB);
797
+ ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
798
+ ggml_half * m0 = d0 + TILE_N;
799
+ for (int n = 0; n < TILE_N; ++n) {
800
+ d0[n] = B[n * KB].d;
801
+ m0[n] = B[n * KB].m;
802
+ }
803
+ }
804
+
805
+ inline void s8s8_compensation(void * RESTRICT packed_B) {
806
+ // packed_B layout:
807
+ // quants {TILE_N, TILE_K} int8_t
808
+ // d0 {TILE_N} ggml_half
809
+ // comp {TILE_N} int32_t
810
+ const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);
811
+ __m512i vcomp = _mm512_setzero_si512();
812
+ const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
813
+ for (int k = 0; k < 8; ++k) {
814
+ __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64));
815
+ vcomp = _mm512_dpbusd_epi32(vcomp, off, vb);
816
+ }
817
+ _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp);
818
+ }
819
+
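A hedged note on the compensation term just computed: `_mm512_dpbusd_epi32` multiplies unsigned bytes by signed bytes, so a signed-A by signed-B dot product is obtained by biasing A with +128 (the `0x80` constant above) and later subtracting the per-column sum stored here, comp[n] = sum over k of 128 * B[k][n]. In pseudo-code:

    // C[m][n] = sum_k (A[m][k] + 128) * B[k][n]  -  comp[n]
    //         = sum_k  A[m][k]        * B[k][n]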
820
+ void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
821
+ pack_qs(packed_B, B, KB);
822
+ ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K);
823
+ for (int n = 0; n < TILE_N; ++n) {
824
+ d0[n] = B[n * KB].d;
825
+ }
826
+ s8s8_compensation(packed_B);
827
+ }
828
+
829
+ // convert 8 * {min, scale} from int6 to int8
830
+ inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) {
831
+ const uint32_t kmask1 = 0x3f3f3f3f;
832
+ const uint32_t kmask2 = 0x0f0f0f0f;
833
+ const uint32_t kmask3 = 0x03030303;
834
+
835
+ memcpy(utmp, scales, 12);
836
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
837
+ const uint32_t uaux = utmp[1] & kmask1;
838
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
839
+ utmp[2] = uaux;
840
+ utmp[0] &= kmask1;
841
+ }
842
+
843
+ // packed_B layout:
844
+ // quants {8, TILE_N, 16} uint8
845
+ // scales {8, TILE_N} uint8
846
+ // mins {8, TILE_N} uint8
847
+ // d {TILE_N} ggml_half
848
+ // dmin {TILE_N} ggml_half
849
+ void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
850
+ pack_qs(packed_B, B, KB);
851
+
852
+ uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
853
+ uint8_t * mins = scales + 8 * TILE_N;
854
+ ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);
855
+ ggml_half * dmin = d + TILE_N;
856
+
857
+ union {
858
+ uint32_t u32[4];
859
+ uint8_t u8[16];
860
+ } s;
861
+
862
+ for (int n = 0; n < TILE_N; ++n) {
863
+ unpack_mins_and_scales(B[n * KB].scales, s.u32);
864
+ for (int k = 0; k < 8; ++k) {
865
+ scales[k * TILE_N + n] = s.u8[k];
866
+ mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
867
+ }
868
+ d[n] = B[n * KB].d;
869
+ dmin[n] = B[n * KB].dmin;
870
+ }
871
+ }
872
+
873
+ // packed_B layout:
874
+ // quants {8, TILE_N, 16} uint8
875
+ // qh {8, TILE_N, 4} uint8
876
+ // scales {8, TILE_N} uint8
877
+ // mins {8, TILE_N} uint8
878
+ // d {TILE_N} ggml_half
879
+ // dmin {TILE_N} ggml_half
880
+ void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
881
+ pack_qs(packed_B, B, KB);
882
+
883
+ uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
884
+ uint8_t * mins = scales + 8 * TILE_N;
885
+ ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);
886
+ ggml_half * dmin = d + TILE_N;
887
+
888
+ union {
889
+ uint32_t u32[4];
890
+ uint8_t u8[16];
891
+ } s;
892
+
893
+ for (int n = 0; n < TILE_N; ++n) {
894
+ unpack_mins_and_scales(B[n * KB].scales, s.u32);
895
+ for (int k = 0; k < 8; ++k) {
896
+ scales[k * TILE_N + n] = s.u8[k];
897
+ mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
898
+ }
899
+ d[n] = B[n * KB].d;
900
+ dmin[n] = B[n * KB].dmin;
901
+ }
902
+ }
903
+
904
+ // packed_B layout:
905
+ // quants {16, TILE_N, 8} uint8
906
+ // qh {16, TILE_N, 4} uint8
907
+ // scales {16, TILE_N} uint8
908
+ // d {TILE_N} ggml_half
909
+ void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
910
+ pack_qs(packed_B, B, KB);
911
+
912
+ uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
913
+ ggml_half * d = reinterpret_cast<ggml_half *>(scales + 16 * TILE_N);
914
+ for (int n = 0; n < TILE_N; ++n) {
915
+ const int8_t * ps = B[n * KB].scales;
916
+ for (int k = 0; k < 16; ++k) {
917
+ scales[k * TILE_N + n] = ps[k];
918
+ }
919
+ d[n] = B[n * KB].d;
920
+ }
921
+ }
922
+
923
+ // packed_B layout:
924
+ // quants {8, TILE_N, 16} uint8
925
+ // scales {8, TILE_N} int8
926
+ // d {TILE_N} ggml_half
927
+ void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
928
+ pack_qs(packed_B, B, KB);
929
+
930
+ int8_t * scales = reinterpret_cast<int8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
931
+ ggml_half * d = reinterpret_cast<ggml_half *>(scales + 8 * TILE_N);
932
+
933
+ // pack the scales
934
+ for (int n = 0; n < TILE_N; ++n) {
935
+ uint16_t sh = B[n * KB].scales_h;
936
+ for (int k = 0; k < 8; k += 2) {
937
+ const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32;
938
+ const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32;
939
+ scales[(k + 0) * TILE_N + n] = ls1;
940
+ scales[(k + 1) * TILE_N + n] = ls2;
941
+ sh >>= 4;
942
+ }
943
+ d[n] = B[n * KB].d;
944
+ }
945
+ }
946
+
947
+ template<typename TB, typename packed_B_t = packed_B_type<TB>>
948
+ void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) {
949
+ GGML_UNUSED(tile);
950
+ GGML_UNUSED(packed_B);
951
+ }
952
+
953
+ template <>
954
+ void unpack_B<block_q4_0>(int8_t * RESTRICT tile, const void * RESTRICT packed_B) {
955
+ const __m512i off = _mm512_set1_epi8(8);
956
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
957
+ for (int n = 0; n < 8; n += 2) {
958
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
959
+ const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off);
960
+ const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off);
961
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
962
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
963
+ }
964
+ }
965
+
966
+ template <>
967
+ void unpack_B<block_q4_1>(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) {
968
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
969
+ for (int n = 0; n < 8; n += 2) {
970
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
971
+ const __m512i r0 = _mm512_and_si512(bytes, lowMask);
972
+ const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
973
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
974
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
975
+ }
976
+ }
977
+
978
+ // packed_B_t for QKK is int8_t
979
+ template <typename TB>
980
+ void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
981
+ const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
982
+ const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size;
983
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
984
+ for (int n = 0; n < 8; n += 2) {
985
+ __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32);
986
+ const __m512i r0 = _mm512_and_si512(bytes, lowMask);
987
+ const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
988
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
989
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
990
+ }
991
+ }
992
+
993
+ template <>
994
+ void unpack_B<block_q5_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
995
+ // lower 4bits, stride 256 bytes
996
+ const int packed_l4_group_size = QK_K / 2 * TILE_N / 8;
997
+ const char * pb = (const char *)packed_B + k * packed_l4_group_size;
998
+
999
+ // higher 1bit, stride 64 bytes
1000
+ const int packed_h1_group_size = QK_K / 8 * TILE_N / 8;
1001
+ const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size;
1002
+ const __m512i hbits = _mm512_loadu_si512(ph);
1003
+
1004
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1005
+ __m512i hmask0 = _mm512_set1_epi8(0x1);
1006
+ __m512i hmask1 = _mm512_set1_epi8(0x2);
1007
+
1008
+ for (int n = 0; n < 8; n += 2) {
1009
+ __m512i bytes = _mm512_loadu_si512(pb + n * 32);
1010
+ __m512i r0 = _mm512_and_si512(bytes, lowMask);
1011
+ __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1012
+ __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4);
1013
+ __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4);
1014
+
1015
+ hmask0 = _mm512_slli_epi16(hmask0, 2);
1016
+ hmask1 = _mm512_slli_epi16(hmask1, 2);
1017
+ r0 = _mm512_add_epi8(r0, h0);
1018
+ r1 = _mm512_add_epi8(r1, h1);
1019
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
1020
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
1021
+ }
1022
+ }
1023
+
1024
+ template <>
1025
+ void unpack_B<block_q6_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
1026
+ // lower 4bits, stride 128 bytes
1027
+ const int packed_l4_group_size = QK_K / 2 * TILE_N / 16;
1028
+ const char * pb = (const char *)packed_B + k * packed_l4_group_size;
1029
+
1030
+ // higher 2bits, stride 64 bytes
1031
+ const int packed_h2_group_size = QK_K / 4 * TILE_N / 16;
1032
+ const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size;
1033
+ const __m512i hbits = _mm512_loadu_si512(ph);
1034
+
1035
+ const __m512i off = _mm512_set1_epi8(32);
1036
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1037
+ __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011
1038
+ __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100
1039
+
1040
+ // notes: skip the zero padding from row 4 to row 7, as it was already done in `unpack_A`
1041
+ __m512i bytes = _mm512_loadu_si512(pb);
1042
+ __m512i r0 = _mm512_and_si512(bytes, lowMask);
1043
+ __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1044
+ __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4);
1045
+ __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2);
1046
+ _mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
1047
+ _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
1048
+
1049
+ hmask0 = _mm512_slli_epi16(hmask0, 4);
1050
+ hmask1 = _mm512_slli_epi16(hmask1, 4);
1051
+
1052
+ bytes = _mm512_loadu_si512(pb + 64);
1053
+ r0 = _mm512_and_si512(bytes, lowMask);
1054
+ r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1055
+ h0 = _mm512_and_si512(hbits, hmask0);
1056
+ h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2);
1057
+ _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
1058
+ _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
1059
+ }
1060
+
1061
+ template <>
1062
+ void unpack_B<block_iq4_xs>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
1063
+ static const __m512i values128 = _mm512_set_epi8(
1064
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1065
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1066
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1067
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
1068
+ );
1069
+
1070
+ const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
1071
+ const char * pb = (const char *)packed_B + k * packed_B_group_size;
1072
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1073
+
1074
+ for (int n = 0; n < 8; n += 2) {
1075
+ __m512i bytes = _mm512_loadu_si512(pb + n * 32);
1076
+ const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask));
1077
+ const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
1078
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
1079
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
1080
+ }
1081
+ }
1082
+
1083
+ template <typename TA, typename TB, bool is_acc>
1084
+ struct acc_C {};
1085
+
1086
+ template <bool is_acc>
1087
+ struct acc_C<block_q8_0, block_q4_0, is_acc> {
1088
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
1089
+ const int offset = TILE_N * TILE_K / 2;
1090
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
1091
+
1092
+ for (int m = 0; m < nr; ++m) {
1093
+ const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
1094
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1095
+
1096
+ __m512 vsum;
1097
+ if (is_acc) {
1098
+ vsum = _mm512_loadu_ps(C + m * ldc);
1099
+ } else {
1100
+ vsum = _mm512_set1_ps(0.f);
1101
+ }
1102
+ vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
1103
+ _mm512_storeu_ps(C + m * ldc, vsum);
1104
+ }
1105
+ }
1106
+ };
1107
+
1108
+ template <bool is_acc>
1109
+ struct acc_C<block_q8_1, block_q4_1, is_acc> {
1110
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) {
1111
+ const int offset = TILE_N * TILE_K / 2;
1112
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
1113
+ const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half))));
1114
+
1115
+ for (int m = 0; m < nr; ++m) {
1116
+ const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
1117
+ const __m512 vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].s));
1118
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1119
+
1120
+ __m512 vsum;
1121
+ if (is_acc) {
1122
+ vsum = _mm512_loadu_ps(C + m * ldc);
1123
+ } else {
1124
+ vsum = _mm512_set1_ps(0.f);
1125
+ }
1126
+ vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
1127
+ vsum = _mm512_fmadd_ps(vm0, vs1, vsum);
1128
+ _mm512_storeu_ps(C + m * ldc, vsum);
1129
+ }
1130
+ }
1131
+ };
1132
+
1133
+ template <bool is_acc>
1134
+ struct acc_C<block_q8_0, block_q8_0, is_acc> {
1135
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
1136
+ const int offset = TILE_N * TILE_K;
1137
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
1138
+
1139
+ for (int m = 0; m < nr; ++m) {
1140
+ const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
1141
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1142
+
1143
+ __m512 vsum;
1144
+ if (is_acc) {
1145
+ vsum = _mm512_loadu_ps(C + m * ldc);
1146
+ } else {
1147
+ vsum = _mm512_set1_ps(0.f);
1148
+ }
1149
+ vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
1150
+ _mm512_storeu_ps(C + m * ldc, vsum);
1151
+ }
1152
+ }
1153
+ };
1154
+
1155
+ template <bool is_acc>
1156
+ struct acc_C<block_q8_K, block_q4_K, is_acc> {
1157
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1158
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
1159
+ const uint8_t * mins = scales + 8 * TILE_N;
1160
+ const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);
1161
+ const ggml_half * dmin = d0 + TILE_N;
1162
+
1163
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1164
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
1165
+
1166
+ for (int m = 0; m < nr; ++m) {
1167
+ const float d1 = A[m * lda].d;
1168
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1169
+ const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
1170
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1171
+
1172
+ __m512 vsum;
1173
+ if (is_acc) {
1174
+ vsum = _mm512_loadu_ps(C + m * ldc);
1175
+ } else {
1176
+ vsum = _mm512_set1_ps(0.f);
1177
+ }
1178
+
1179
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
1180
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1181
+
1182
+ __m512i acc_m = _mm512_setzero_si512();
1183
+ for (int k = 0; k < 4; ++k) {
1184
+ __m512i vmask = _mm512_set1_epi32(k);
1185
+ __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
1186
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
1187
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1188
+ }
1189
+
1190
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1191
+ vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
1192
+ _mm512_storeu_ps(C + m * ldc, vsum);
1193
+ }
1194
+ }
1195
+ };
1196
+
1197
+ template <bool is_acc>
1198
+ struct acc_C<block_q8_K, block_q5_K, is_acc> {
1199
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1200
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
1201
+ const uint8_t * mins = scales + 8 * TILE_N;
1202
+ const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);
1203
+ const ggml_half * dmin = d0 + TILE_N;
1204
+
1205
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1206
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
1207
+
1208
+ for (int m = 0; m < nr; ++m) {
1209
+ const float d1 = A[m * lda].d;
1210
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1211
+ const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
1212
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1213
+
1214
+ __m512 vsum;
1215
+ if (is_acc) {
1216
+ vsum = _mm512_loadu_ps(C + m * ldc);
1217
+ } else {
1218
+ vsum = _mm512_set1_ps(0.f);
1219
+ }
1220
+
1221
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
1222
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1223
+
1224
+ __m512i acc_m = _mm512_setzero_si512();
1225
+ for (int k = 0; k < 4; ++k) {
1226
+ __m512i vmask = _mm512_set1_epi32(k);
1227
+ __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
1228
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
1229
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1230
+ }
1231
+
1232
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1233
+ vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
1234
+ _mm512_storeu_ps(C + m * ldc, vsum);
1235
+ }
1236
+ }
1237
+ };
1238
+
1239
+ template <bool is_acc>
1240
+ struct acc_C<block_q8_K, block_q6_K, is_acc> {
1241
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1242
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
1243
+ const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 16 * TILE_N);
1244
+
1245
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1246
+
1247
+ for (int m = 0; m < nr; ++m) {
1248
+ const float d1 = A[m * lda].d;
1249
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1250
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1251
+
1252
+ __m512 vsum;
1253
+ if (is_acc) {
1254
+ vsum = _mm512_loadu_ps(C + m * ldc);
1255
+ } else {
1256
+ vsum = _mm512_set1_ps(0.f);
1257
+ }
1258
+
1259
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1260
+ _mm512_storeu_ps(C + m * ldc, vsum);
1261
+ }
1262
+ }
1263
+ };
1264
+
1265
+ template <bool is_acc>
1266
+ struct acc_C<block_q8_K, block_iq4_xs, is_acc> {
1267
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1268
+ const int8_t * scales = reinterpret_cast<const int8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
1269
+ const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 8 * TILE_N);
1270
+
1271
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1272
+
1273
+ for (int m = 0; m < nr; ++m) {
1274
+ const float d1 = A[m * lda].d;
1275
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1276
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1277
+
1278
+ __m512 vsum;
1279
+ if (is_acc) {
1280
+ vsum = _mm512_loadu_ps(C + m * ldc);
1281
+ } else {
1282
+ vsum = _mm512_set1_ps(0.f);
1283
+ }
1284
+
1285
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1286
+ _mm512_storeu_ps(C + m * ldc, vsum);
1287
+ }
1288
+ }
1289
+ };
1290
+
1291
+ template <typename TB> constexpr int get_quants_size();
1292
+ template <> constexpr int get_quants_size<block_q4_K>() { return (QK_K / 2) * TILE_N; }
1293
+ template <> constexpr int get_quants_size<block_q5_K>() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; }
1294
+ template <> constexpr int get_quants_size<block_q6_K>() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; }
1295
+ template <> constexpr int get_quants_size<block_iq4_xs>() { return (QK_K / 2) * TILE_N; }
1296
+
1297
+ // used for QKK format
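+ // scale_C multiplies the raw int32 dot products of one k group by that group's
+ // int8 sub-scales (stored right after the quants in packed_B, see get_quants_size)
+ // and accumulates the result into sumi.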
1298
+ template <typename TB, bool is_acc,
1299
+ typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
1300
+ inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) {
1301
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + get_quants_size<TB>());
1302
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N)));
1303
+
1304
+ for (int m = 0; m < nr; ++m) {
1305
+ __m512i vsumi;
1306
+ if (is_acc) {
1307
+ vsumi = _mm512_loadu_si512(sumi + m * TILE_N);
1308
+ } else {
1309
+ vsumi = _mm512_setzero_si512();
1310
+ }
1311
+ __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N);
1312
+ vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale));
1313
+ _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi);
1314
+ }
1315
+ }
1316
+
1317
+ template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
1318
+ struct tinygemm_kernel_avx {
1319
+ static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) {
1320
+ GGML_UNUSED(K);
1321
+ GGML_UNUSED(A);
1322
+ GGML_UNUSED(B);
1323
+ GGML_UNUSED(C);
1324
+ GGML_UNUSED(ldc);
1325
+ }
1326
+ };
1327
+
1328
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1329
+ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1330
+ static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) {
1331
+ constexpr int ROWS = BLOCK_M;
1332
+ constexpr int COLS = BLOCK_N;
1333
+ assert(BLOCK_K == 16);
1334
+
1335
+ __m512 va;
1336
+ __m512 vb[COLS];
1337
+ __m512 vc[ROWS * COLS];
1338
+
1339
+ auto loadc = [&](auto idx) {
1340
+ vc[idx] = _mm512_setzero_ps();
1341
+ };
1342
+ Unroll<ROWS * COLS>{}(loadc);
1343
+
1344
+ auto compute = [&](auto idx, auto k) {
1345
+ constexpr int row = idx / COLS;
1346
+ constexpr int col = idx % COLS;
1347
+
1348
+ if constexpr (col == 0) {
1349
+ va = _mm512_loadu_ps(A + row * K + k);
1350
+ }
1351
+ if constexpr (row == 0) {
1352
+ vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
1353
+ }
1354
+ vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
1355
+ };
1356
+
1357
+ for (int k = 0; k < K; k += 16) {
1358
+ Unroll<ROWS * COLS>{}(compute, k);
1359
+ }
1360
+
1361
+ auto storec = [&](auto idx) {
1362
+ constexpr int row = idx / COLS;
1363
+ constexpr int col = idx % COLS;
1364
+ C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);
1365
+ };
1366
+ Unroll<ROWS * COLS>{}(storec);
1367
+ }
1368
+ };
1369
+
1370
+ #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
1371
+ tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \
1372
+ K, (const float *)src1->data + mb_start * K, \
1373
+ (const type *)src0->data + nb_start * K, \
1374
+ (float *)dst->data + mb_start * ldc + nb_start, ldc);
1375
+
1376
+
1377
+ // re-organize in the format {NB, KB, TILE_SIZE}:
1378
+ #define PACKED_INDEX(n, k, KB, tile_size) (((n) * (KB) + (k)) * (tile_size))
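+ // e.g. (hypothetical sizes) KB = 8, tile_size = 1024:
+ //   PACKED_INDEX(2, 3, 8, 1024) = (2 * 8 + 3) * 1024 = 19456,
+ //   i.e. tile (n = 2, k = 3) starts 19 tiles into the packed buffer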
1379
+
1380
+ template<typename TB, int BLOCK_K>
1381
+ void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K) {
1382
+ const int NB = N / TILE_N;
1383
+ const int KB = K / BLOCK_K;
1384
+ const int TILE_SIZE = get_tile_size<TB>();
1385
+
1386
+ // parallel on NB should be enough
1387
+ parallel_for(NB, [&](int begin, int end) {
1388
+ for (int n = begin; n < end; ++n) {
1389
+ for (int k = 0; k < KB; ++k) {
1390
+ int n0 = n * TILE_N;
1391
+ pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB);
1392
+ }
1393
+ }
1394
+ });
1395
+ }
1396
+
1397
+ template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
1398
+ struct tinygemm_kernel_vnni {};
1399
+
1400
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1401
+ struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1402
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1403
+
1404
+ constexpr int COLS = BLOCK_N / 16;
1405
+ const int TILE_SIZE = TILE_N * sizeof(block_q4_0);
1406
+
1407
+ const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
1408
+ const char * RESTRICT B = static_cast<const char *>(_B);
1409
+
1410
+ __m512i va[8];
1411
+ __m512 vc[COLS];
1412
+ __m512 vd1;
1413
+
1414
+ // sum of offsets, shared across COLS
1415
+ //
1416
+ // avx512-vnni does not have `_mm512_dpbssd_epi32`,
1417
+ // need to transform ss to us:
1418
+ // a * (b - 8) is equivalent to b * a - 8 * a
1419
+ // s u u u s u s
1420
+ //
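+ // e.g. a = 3, b = 5 (unsigned nibble): 3 * (5 - 8) = -9, and 5 * 3 - 8 * 3 = 15 - 24 = -9;
+ // vcomp accumulates the 8 * a term via _mm512_dpbusd_epi32 with the constant 8,
+ // and is subtracted from the b * a dot product after the k loop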
1421
+ __m512i vcomp;
1422
+
1423
+ const __m512i off = _mm512_set1_epi8(8);
1424
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1425
+
1426
+ auto loadc = [&](auto col) {
1427
+ vc[col] = _mm512_setzero_ps();
1428
+ };
1429
+ Unroll<COLS>{}(loadc);
1430
+
1431
+ auto compute = [&](auto col, auto i) {
1432
+ // load a and compute compensation
1433
+ if constexpr (col == 0) {
1434
+ const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
1435
+ vcomp = _mm512_setzero_si512();
1436
+ for (int k = 0; k < 8; ++k) {
1437
+ va[k] = _mm512_set1_epi32(a_ptr[k]);
1438
+ vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]);
1439
+ }
1440
+ vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
1441
+ }
1442
+
1443
+ // load b
1444
+ __m512i vsum = _mm512_setzero_si512();
1445
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1446
+ for (int k = 0; k < 8; k += 2) {
1447
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
1448
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1449
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]);
1450
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1451
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]);
1452
+ }
1453
+ const int offset = TILE_N * TILE_K / 2;
1454
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
1455
+ vsum = _mm512_sub_epi32(vsum, vcomp);
1456
+
1457
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
1458
+ };
1459
+
1460
+ for (int i = 0; i < KB; ++i) {
1461
+ Unroll<COLS>{}(compute, i);
1462
+ }
1463
+
1464
+ //store to C
1465
+ auto storec = [&](auto col) {
1466
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1467
+ };
1468
+ Unroll<COLS>{}(storec);
1469
+ }
1470
+ };
1471
+
1472
+ template <int BLOCK_N, int BLOCK_K>
1473
+ struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K> {
1474
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1475
+
1476
+ constexpr int COLS = BLOCK_N / 16;
1477
+ const int TILE_SIZE = TILE_N * sizeof(block_q4_1);
1478
+
1479
+ const block_q8_1 * RESTRICT A = static_cast<const block_q8_1 *>(_A);
1480
+ const char * RESTRICT B = static_cast<const char *>(_B);
1481
+
1482
+ __m512i va[8];
1483
+ __m512i vb[8];
1484
+ __m512 vc[COLS];
1485
+ __m512 vd1, vs1;
1486
+
1487
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1488
+
1489
+ auto loadc = [&](auto col) {
1490
+ vc[col] = _mm512_setzero_ps();
1491
+ };
1492
+ Unroll<COLS>{}(loadc);
1493
+
1494
+ auto compute = [&](auto col, auto i) {
1495
+ // load a
1496
+ if constexpr (col == 0) {
1497
+ const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
1498
+ for (int k = 0; k < 8; ++k) {
1499
+ va[k] = _mm512_set1_epi32(a_ptr[k]);
1500
+ }
1501
+ vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
1502
+ vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].s));
1503
+ }
1504
+
1505
+ // load b
1506
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1507
+ for (int k = 0; k < 8; k += 2) {
1508
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
1509
+ vb[k + 0] = _mm512_and_si512(bytes, lowMask);
1510
+ vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1511
+ }
1512
+ const int offset = TILE_N * TILE_K / 2;
1513
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
1514
+ const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half))));
1515
+
1516
+ __m512i vsum = _mm512_setzero_si512();
1517
+ for (int k = 0; k < 8; ++k) {
1518
+ vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]);
1519
+ }
1520
+
1521
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
1522
+ vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]);
1523
+ };
1524
+
1525
+ for (int i = 0; i < KB; ++i) {
1526
+ Unroll<COLS>{}(compute, i);
1527
+ }
1528
+
1529
+ //store to C
1530
+ auto storec = [&](auto col) {
1531
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1532
+ };
1533
+ Unroll<COLS>{}(storec);
1534
+ }
1535
+ };
1536
+
1537
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1538
+ struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1539
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1540
+
1541
+ constexpr int COLS = BLOCK_N / 16;
1542
+ const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t);
1543
+
1544
+ const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
1545
+ const char * RESTRICT B = static_cast<const char *>(_B);
1546
+
1547
+ __m512i va[8];
1548
+ __m512i vb[8];
1549
+ __m512 vc[COLS];
1550
+ __m512 vd1;
1551
+
1552
+ // Notes: s8s8 igemm compensation in avx512-vnni
1553
+ // change s8s8 to u8s8 with compensation
1554
+ // a * b = (a + 128) * b - 128 * b
1555
+ // s s u s u s
1556
+ //
1557
+ // (128 * b is pre-computed when packing B to vnni formats)
1558
+ //
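+ // e.g. a = -3, b = 7: (-3 + 128) * 7 - 128 * 7 = 875 - 896 = -21 = a * b;
+ // the 128 * b term is read back below as vcomp from the packed B tile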
1559
+ const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
1560
+
1561
+ auto loadc = [&](auto col) {
1562
+ vc[col] = _mm512_setzero_ps();
1563
+ };
1564
+ Unroll<COLS>{}(loadc);
1565
+
1566
+ auto compute = [&](auto col, auto i) {
1567
+ // load a and add offset 128
1568
+ if constexpr (col == 0) {
1569
+ const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
1570
+ for (int k = 0; k < 8; ++k) {
1571
+ va[k] = _mm512_set1_epi32(a_ptr[k]);
1572
+ va[k] = _mm512_add_epi8(va[k], off);
1573
+ }
1574
+ vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
1575
+ }
1576
+
1577
+ // load b
1578
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1579
+ for (int k = 0; k < 8; ++k) {
1580
+ vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64));
1581
+ }
1582
+ const int offset = TILE_N * TILE_K;
1583
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
1584
+ const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);
1585
+ const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2));
1586
+
1587
+ __m512i vsum = _mm512_setzero_si512();
1588
+ for (int k = 0; k < 8; ++k) {
1589
+ vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]);
1590
+ }
1591
+ vsum = _mm512_sub_epi32(vsum, vcomp);
1592
+
1593
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
1594
+ };
1595
+
1596
+ for (int i = 0; i < KB; ++i) {
1597
+ Unroll<COLS>{}(compute, i);
1598
+ }
1599
+
1600
+ //store to C
1601
+ auto storec = [&](auto col) {
1602
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1603
+ };
1604
+ Unroll<COLS>{}(storec);
1605
+ }
1606
+ };
1607
+
1608
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1609
+ struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1610
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1611
+
1612
+ constexpr int COLS = BLOCK_N / 16;
1613
+ const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4;
1614
+
1615
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1616
+ const char * RESTRICT B = static_cast<const char *>(_B);
1617
+
1618
+ // a.qs: 8 groups, 32 bytes each group (m256i)
1619
+ __m512i va[8];
1620
+ // a.bsum: 8 groups, 2 bytes each group (m128i)
1621
+ __m512i va_bsum;
1622
+ __m512 vc[COLS];
1623
+ __m512 vd1;
1624
+
1625
+ // packed_B:
1626
+ const int offset_scales = (QK_K / 2) * TILE_N;
1627
+ const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N;
1628
+ const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N;
1629
+ const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);
1630
+
1631
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1632
+
1633
+ auto loadc = [&](auto col) {
1634
+ vc[col] = _mm512_setzero_ps();
1635
+ };
1636
+ Unroll<COLS>{}(loadc);
1637
+
1638
+ // Notes: vnni formats in QK_K
1639
+ // a) quants vnni format
1640
+ // int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32
1641
+ // from {16, 32} to {8, 64}
1642
+ //
1643
+ // b) min vnni format
1644
+ // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8
1645
+ // from {16, 8} to {4, 32}
1646
+ //
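+ // i.e. with the usual VNNI interleave, element (n, k) of a {TILE_N, 32} group ends up
+ // at row k / 4, byte n * 4 + k % 4 of the {8, 64} tile, so each 32-bit dpbusd lane
+ // multiplies 4 consecutive k values belonging to one output column n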
1647
+ auto compute = [&](auto col, auto i) {
1648
+ // load a
1649
+ if constexpr (col == 0) {
1650
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1651
+ va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
1652
+ }
1653
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1654
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1655
+ va_bsum = _mm512_castsi128_si512(q8s);
1656
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1657
+ }
1658
+
1659
+ // step 1: accumulate the quants
1660
+ __m512i acc = _mm512_setzero_si512();
1661
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1662
+ const char * b_qs = b_ptr;
1663
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1664
+ __m512i vsum = _mm512_setzero_si512();
1665
+ for (int k = 0; k < 8; k += 2) {
1666
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
1667
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
1668
+
1669
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
1670
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1671
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1672
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1673
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1674
+
1675
+ b_qs += 64;
1676
+ }
1677
+ // vacc += scale * (q8 @ q4)
1678
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
1679
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
1680
+ }
1681
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
1682
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
1683
+
1684
+ // step 2: accumulate the mins
1685
+ __m512i acc_m = _mm512_setzero_si512();
1686
+ for (int k = 0; k < 4; ++k) {
1687
+ __m512i vmask = _mm512_set1_epi32(k);
1688
+ __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
1689
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
1690
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1691
+ }
1692
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
1693
+ vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
1694
+ };
1695
+
1696
+ for (int i = 0; i < KB; ++i) {
1697
+ Unroll<COLS>{}(compute, i);
1698
+ }
1699
+
1700
+ //store to C
1701
+ auto storec = [&](auto col) {
1702
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1703
+ };
1704
+ Unroll<COLS>{}(storec);
1705
+ }
1706
+ };
1707
+
1708
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1709
+ struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1710
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1711
+
1712
+ constexpr int COLS = BLOCK_N / 16;
1713
+ const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4;
1714
+
1715
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1716
+ const char * RESTRICT B = static_cast<const char *>(_B);
1717
+
1718
+ // a.qs: 8 groups, 32 bytes each group (m256i)
1719
+ __m512i va[8];
1720
+ // a.bsum: 8 groups, 2 bytes each group (m128i)
1721
+ __m512i va_bsum;
1722
+ __m512 vc[COLS];
1723
+ __m512 vd1;
1724
+
1725
+ // packed_B:
1726
+ const int offset_qh = (QK_K / 2) * TILE_N;
1727
+ const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;
1728
+ const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N;
1729
+ const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N;
1730
+ const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);
1731
+
1732
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1733
+
1734
+ auto loadc = [&](auto col) {
1735
+ vc[col] = _mm512_setzero_ps();
1736
+ };
1737
+ Unroll<COLS>{}(loadc);
1738
+
1739
+ // Q5_K and Q4_K share the same vnni formats; refer to the notes above.
1740
+ auto compute = [&](auto col, auto i) {
1741
+ // load a
1742
+ if constexpr (col == 0) {
1743
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1744
+ va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
1745
+ }
1746
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1747
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1748
+ va_bsum = _mm512_castsi128_si512(q8s);
1749
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1750
+ }
1751
+
1752
+ // step 1: accumulate the quants
1753
+ __m512i acc = _mm512_setzero_si512();
1754
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1755
+ const char * b_qs = b_ptr;
1756
+ const char * b_qh = b_ptr + offset_qh;
1757
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1758
+ __m512i vsum = _mm512_setzero_si512();
1759
+ __m512i hmask0 = _mm512_set1_epi8(0x1);
1760
+ __m512i hmask1 = _mm512_set1_epi8(0x2);
1761
+ __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64));
1762
+ for (int k = 0; k < 8; k += 2) {
1763
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
1764
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
1765
+
1766
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
1767
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1768
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1769
+
1770
+ __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4);
1771
+ __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4);
1772
+
1773
+ hmask0 = _mm512_slli_epi16(hmask0, 2);
1774
+ hmask1 = _mm512_slli_epi16(hmask1, 2);
1775
+ vb0 = _mm512_add_epi8(vb0, vh0);
1776
+ vb1 = _mm512_add_epi8(vb1, vh1);
1777
+
1778
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1779
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1780
+
1781
+ b_qs += 64;
1782
+ }
1783
+ // vacc += scale * (q8 @ q5)
1784
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
1785
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
1786
+ }
1787
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
1788
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
1789
+
1790
+ // step 2: accumulate the mins
1791
+ __m512i acc_m = _mm512_setzero_si512();
1792
+ for (int k = 0; k < 4; ++k) {
1793
+ __m512i vmask = _mm512_set1_epi32(k);
1794
+ __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
1795
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
1796
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1797
+ }
1798
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
1799
+ vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
1800
+ };
1801
+
1802
+ for (int i = 0; i < KB; ++i) {
1803
+ Unroll<COLS>{}(compute, i);
1804
+ }
1805
+
1806
+ //store to C
1807
+ auto storec = [&](auto col) {
1808
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1809
+ };
1810
+ Unroll<COLS>{}(storec);
1811
+ }
1812
+ };
1813
+
1814
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1815
+ struct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1816
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1817
+
1818
+ constexpr int COLS = BLOCK_N / 16;
1819
+ const int TILE_SIZE = TILE_N * sizeof(block_q6_K);
1820
+
1821
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1822
+ const char * RESTRICT B = static_cast<const char *>(_B);
1823
+
1824
+ // load the 256 bytes from A to 4 avx512 vectors
1825
+ __m512i va[4];
1826
+ __m512 vc[COLS];
1827
+ __m512 vd1;
1828
+
1829
+ // packed_B:
1830
+ const int offset_qh = (QK_K / 2) * TILE_N;
1831
+ const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;
1832
+ const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N;
1833
+
1834
+ // compensation
1835
+ __m512i vcomp;
1836
+
1837
+ const __m512i m32s = _mm512_set1_epi32(32);
1838
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1839
+
1840
+ auto loadc = [&](auto col) {
1841
+ vc[col] = _mm512_setzero_ps();
1842
+ };
1843
+ Unroll<COLS>{}(loadc);
1844
+
1845
+ auto compute = [&](auto col, auto i) {
1846
+ if constexpr (col == 0) {
1847
+ // load a
1848
+ va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
1849
+ va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
1850
+ va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
1851
+ va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
1852
+
1853
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1854
+ vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s);
1855
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1856
+ }
1857
+
1858
+ // accumulate the quants
1859
+ __m512i acc = _mm512_setzero_si512();
1860
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1861
+ const char * b_qs = b_ptr;
1862
+ const char * b_qh = b_ptr + offset_qh;
1863
+ int mask = 0;
1864
+ for (int k_group = 0; k_group < QK_K / 16; ++k_group) {
1865
+ int r = k_group >> 2;
1866
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1867
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1868
+
1869
+ __m512i vsum = _mm512_setzero_si512();
1870
+ __m512i hmask = _mm512_set1_epi8(0x3);
1871
+
1872
+ __m512i bytes = _mm512_loadu_si512(b_qs);
1873
+ __m512i hbits = _mm512_loadu_si512(b_qh);
1874
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1875
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1876
+ __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4);
1877
+ __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2);
1878
+
1879
+ vb0 = _mm512_add_epi8(vb0, vh0);
1880
+ vb1 = _mm512_add_epi8(vb1, vh1);
1881
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1882
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1883
+ b_qs += 64;
1884
+
1885
+ va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1886
+ va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1887
+
1888
+ bytes = _mm512_loadu_si512(b_qs);
1889
+ vb0 = _mm512_and_si512(bytes, lowMask);
1890
+ vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1891
+ vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4));
1892
+ vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2);
1893
+ vb0 = _mm512_add_epi8(vb0, vh0);
1894
+ vb1 = _mm512_add_epi8(vb1, vh1);
1895
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1896
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1897
+ b_qs += 64;
1898
+ b_qh += 64;
1899
+
1900
+ // B * A - 32 * A
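+ // (q6_K quants are stored with a +32 offset; subtracting 32 * sum(A) for this group,
+ //  precomputed into vcomp above, restores the signed result)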
1901
+ __m512i vmask = _mm512_set1_epi32(k_group);
1902
+ vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
1903
+
1904
+ // vacc += scale * (q8 @ q6)
1905
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
1906
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
1907
+ }
1908
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
1909
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
1910
+ };
1911
+
1912
+ for (int i = 0; i < KB; ++i) {
1913
+ Unroll<COLS>{}(compute, i);
1914
+ }
1915
+
1916
+ //store to C
1917
+ auto storec = [&](int col) {
1918
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1919
+ };
1920
+ Unroll<COLS>{}(storec);
1921
+ }
1922
+ };
1923
+
1924
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1925
+ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1926
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1927
+
1928
+ constexpr int COLS = BLOCK_N / 16;
1929
+ const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2;
1930
+
1931
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1932
+ const char * RESTRICT B = static_cast<const char *>(_B);
1933
+
1934
+ // load the 256 bytes from A to 4 avx512 vectors
1935
+ __m512i va[4];
1936
+ __m512 vc[COLS];
1937
+ __m512 vd1;
1938
+
1939
+ // packed_B:
1940
+ const int offset_scales = (QK_K / 2) * TILE_N ;
1941
+ const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N;
1942
+
1943
+ // compensation
1944
+ __m512i vcomp;
1945
+
1946
+ const __m256i m128s = _mm256_set1_epi16(128);
1947
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1948
+
1949
+ const __m512i values128 = _mm512_set_epi8(
1950
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1951
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1952
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1953
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
1954
+ );
1955
+ const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
1956
+ const __m512i values256 = _mm512_add_epi8(values128, off);
1957
+
1958
+ auto loadc = [&](auto col) {
1959
+ vc[col] = _mm512_setzero_ps();
1960
+ };
1961
+ Unroll<COLS>{}(loadc);
1962
+
1963
+ auto compute = [&](auto col, auto i) {
1964
+ if constexpr (col == 0) {
1965
+ // load a
1966
+ va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
1967
+ va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
1968
+ va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
1969
+ va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
1970
+
1971
+ // compensation: 128 * A
1972
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1973
+ vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s));
1974
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1975
+ }
1976
+
1977
+ // accumulate the quants
1978
+ __m512i acc = _mm512_setzero_si512();
1979
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1980
+ const char * b_qs = b_ptr;
1981
+ int mask = 0;
1982
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1983
+ int r = k_group >> 1;
1984
+ __m512i vmask = _mm512_set1_epi32(k_group);
1985
+ __m512i vsum = _mm512_setzero_si512();
1986
+ for (int k = 0; k < 8; k += 2) {
1987
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1988
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1989
+
1990
+ __m512i bytes = _mm512_loadu_si512(b_qs);
1991
+ __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask));
1992
+ __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
1993
+
1994
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1995
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1996
+ b_qs += 64;
1997
+ }
1998
+ // (B + 128) * A - 128 * A
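+ // (values256 is the 4-bit lookup table shifted by +128 so that vb stays unsigned for
+ //  _mm512_dpbusd_epi32; the 128 * A bias, built into vcomp above with _mm256_madd_epi16,
+ //  is removed here)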
1999
+ vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
2000
+
2001
+ // vacc += scale * (q8 @ q4)
2002
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
2003
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
2004
+ }
2005
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
2006
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
2007
+ };
2008
+
2009
+ for (int i = 0; i < KB; ++i) {
2010
+ Unroll<COLS>{}(compute, i);
2011
+ }
2012
+
2013
+ //store to C
2014
+ auto storec = [&](auto col) {
2015
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
2016
+ };
2017
+ Unroll<COLS>{}(storec);
2018
+ }
2019
+ };
2020
+
2021
+ #define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
2022
+ tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
2023
+ KB, (const char *)wdata + 0 * row_size_A, \
2024
+ (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
2025
+ (float *) dst->data + 0 * N + nb_start, ldc)
2026
+
2027
+ template <typename TA, typename TB, typename TC, int BLOCK_K,
2028
+ typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
2029
+ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) {
2030
+ using packed_B_t = packed_B_type<TB>;
2031
+ const int TILE_SIZE = get_tile_size<TB>();
2032
+ const bool need_unpack = do_unpack<TB>::value;
2033
+
2034
+ GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
2035
+ const TA * RESTRICT A = static_cast<const TA *>(_A);
2036
+ const char * RESTRICT B = static_cast<const char *>(_B);
2037
+
2038
+ const int m0 = std::min(M, TILE_M);
2039
+ const int m1 = std::max(M - TILE_M, 0);
2040
+ const int lda = KB * sizeof(TA);
2041
+ //const int ldb = KB * sizeof(TB);
2042
+
2043
+ static thread_local packed_B_t Tile0[TILE_N * TILE_K];
2044
+ static thread_local packed_B_t Tile1[TILE_N * TILE_K];
2045
+ static thread_local int8_t Tile23[TILE_M * TILE_K];
2046
+
2047
+ static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
2048
+ static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
2049
+
2050
+ // double buffering C to interleave avx512 and amx
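+ // while the AMX tiles compute iteration i into C_cur, the AVX-512 acc_C calls scale
+ // and accumulate iteration i - 1 (read from C_pre) into C; the two buffers are
+ // swapped at the end of each iteration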
2051
+ int32_t * C_cur = TileC0;
2052
+ int32_t * C_pre = TileC1;
2053
+
2054
+ auto Tile4 = [&](int32_t * base) { return base; };
2055
+ auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; };
2056
+ auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; };
2057
+ auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; };
2058
+
2059
+ if (M == 2 * TILE_M) {
2060
+ // i = 0
2061
+ const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE);
2062
+ const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE);
2063
+ if (need_unpack) {
2064
+ unpack_B<TB>(Tile0, B_blk0);
2065
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2066
+ } else {
2067
+ _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
2068
+ }
2069
+
2070
+ _tile_zero(TMM4);
2071
+ _tile_loadd(TMM2, A[0].qs, lda);
2072
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2073
+ _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t));
2074
+
2075
+ _tile_zero(TMM5);
2076
+ _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda);
2077
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2078
+ _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
2079
+
2080
+ if (need_unpack) {
2081
+ unpack_B<TB>(Tile1, B_blk1);
2082
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2083
+ } else {
2084
+ _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
2085
+ }
2086
+
2087
+ _tile_zero(TMM6);
2088
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2089
+ _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t));
2090
+
2091
+ _tile_zero(TMM7);
2092
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2093
+ _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t));
2094
+
2095
+ for (int i = 1; i < KB; ++i) {
2096
+ // index of previous iter
2097
+ const int ii = i - 1;
2098
+ const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
2099
+ const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
2100
+ GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] {
2101
+ if (need_unpack) {
2102
+ unpack_B<TB>(Tile0, B_blk0);
2103
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2104
+ } else {
2105
+ _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
2106
+ }
2107
+ _tile_zero(TMM4);
2108
+ _tile_loadd(TMM2, A[i].qs, lda);
2109
+ acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2110
+
2111
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2112
+ _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
2113
+
2114
+ _tile_zero(TMM5);
2115
+ _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda);
2116
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2117
+
2118
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2119
+ _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
2120
+
2121
+ if (need_unpack) {
2122
+ unpack_B<TB>(Tile1, B_blk1);
2123
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2124
+ } else {
2125
+ _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
2126
+ }
2127
+ _tile_zero(TMM6);
2128
+ acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2129
+
2130
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2131
+ _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
2132
+
2133
+ _tile_zero(TMM7);
2134
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2135
+
2136
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2137
+ _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
2138
+
2139
+ std::swap(C_cur, C_pre);
2140
+ });
2141
+ }
2142
+ // final accumulation
2143
+ {
2144
+ int ii = KB - 1;
2145
+ acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2146
+ acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2147
+ acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2148
+ acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2149
+ }
2150
+ } else {
2151
+ for (int i = 0; i < KB; ++i) {
2152
+ _tile_zero(TMM4);
2153
+ _tile_zero(TMM6);
2154
+ if (m1 != 0) {
2155
+ _tile_zero(TMM5);
2156
+ _tile_zero(TMM7);
2157
+ }
2158
+
2159
+ const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
2160
+ const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
2161
+ if (need_unpack) {
2162
+ unpack_B<TB>(Tile0, B_blk0);
2163
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2164
+ } else {
2165
+ _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
2166
+ }
2167
+
2168
+ if (need_unpack) {
2169
+ unpack_B<TB>(Tile1, B_blk1);
2170
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2171
+ } else {
2172
+ _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
2173
+ }
2174
+
2175
+ if (m0 == TILE_M) {
2176
+ _tile_loadd(TMM2, A[i].qs, lda);
2177
+ } else {
2178
+ unpack_A(Tile23, &A[i], KB, m0);
2179
+ _tile_loadd(TMM2, Tile23, TILE_K);
2180
+ }
2181
+
2182
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2183
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2184
+
2185
+ _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
2186
+ _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
2187
+
2188
+ GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
2189
+ acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
2190
+ acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
2191
+ });
2192
+
2193
+ if (m1 != 0) {
2194
+ unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1);
2195
+ _tile_loadd(TMM3, Tile23, TILE_K);
2196
+
2197
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2198
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2199
+ _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
2200
+ _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
2201
+ GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
2202
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
2203
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
2204
+ });
2205
+ }
2206
+ }
2207
+ }
2208
+ return;
2209
+ }
2210
+
2211
+ template <typename TA, typename TB, typename TC, int BLOCK_K,
2212
+ typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
2213
+ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
2214
+ static_assert(std::is_same<TA, block_q8_K>::value);
2215
+ const int TILE_SIZE = get_tile_size<TB>();
2216
+
2217
+ GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
2218
+ const TA * RESTRICT A = static_cast<const TA *>(_A);
2219
+ const char * RESTRICT B = static_cast<const char *>(_B);
2220
+
2221
+ const int m0 = std::min(M, TILE_M);
2222
+ const int m1 = std::max(M - TILE_M, 0);
2223
+ //const int lda = KB * sizeof(TA);
2224
+
2225
+ static thread_local int8_t Tile0[TILE_N * TILE_K];
2226
+ static thread_local int8_t Tile1[TILE_N * TILE_K];
2227
+ static thread_local int8_t Tile23[TILE_M * TILE_K];
2228
+
2229
+ // mat mul result for each group
2230
+ static thread_local int32_t Tile4[TILE_M * TILE_N];
2231
+ static thread_local int32_t Tile5[TILE_M * TILE_N];
2232
+ static thread_local int32_t Tile6[TILE_M * TILE_N];
2233
+ static thread_local int32_t Tile7[TILE_M * TILE_N];
2234
+
2235
+ // sum of each QK_K block, contains 8 groups, int32
2236
+ static thread_local int32_t Sumi4[TILE_M * TILE_N];
2237
+ static thread_local int32_t Sumi5[TILE_M * TILE_N];
2238
+ static thread_local int32_t Sumi6[TILE_M * TILE_N];
2239
+ static thread_local int32_t Sumi7[TILE_M * TILE_N];
2240
+
2241
+ const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;
2242
+ for (int i = 0; i < KB; ++i) {
2243
+ // step 1: accumulate the quants, one k group at a time (QK_K / k_group_size groups of k_group_size values each)
2244
+ for (int k = 0; k < QK_K / k_group_size; ++k) {
2245
+ GGML_DISPATCH_BOOL(k > 0, is_acc, [&] {
2246
+ _tile_zero(TMM4);
2247
+ _tile_zero(TMM6);
2248
+
2249
+ unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k);
2250
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2251
+
2252
+ unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k);
2253
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2254
+
2255
+ unpack_A<TB>(Tile23, &A[i], KB, k, m0);
2256
+ _tile_loadd(TMM2, Tile23, TILE_K);
2257
+
2258
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2259
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2260
+
2261
+ _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
2262
+ _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
2263
+
2264
+ scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0);
2265
+ scale_C<TB, is_acc>(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0);
2266
+
2267
+ if (m1 != 0) {
2268
+ _tile_zero(TMM5);
2269
+ _tile_zero(TMM7);
2270
+
2271
+ unpack_A<TB>(Tile23, &A[TILE_M * KB + i], KB, k, m1);
2272
+ _tile_loadd(TMM3, Tile23, TILE_K);
2273
+
2274
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2275
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2276
+
2277
+ _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
2278
+ _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
2279
+
2280
+ scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1);
2281
+ scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1);
2282
+ }
2283
+ });
2284
+ }
2285
+
2286
+ // step 2: accumulate the mins
2287
+ GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
2288
+ acc_C<TA, TB, is_acc>::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
2289
+ acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
2290
+ if (m1 != 0) {
2291
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
2292
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
2293
+ }
2294
+ });
2295
+ }
2296
+ return;
2297
+ }
2298
+
2299
+ } // anonymous namespace
2300
+
2301
+ // get the packed tensor size for quantized weights
2302
+ size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) {
2303
+ const enum ggml_type TYPE = tensor->type;
2304
+
2305
+ const int K = tensor->ne[0]; // ne0: in_features
2306
+ const int N = tensor->ne[1]; // ne1: out_features
2307
+
2308
+ auto get_tensor_size = [&] {
2309
+ size_t row_size_B{0};
2310
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2311
+ row_size_B = get_row_size<type, blck_size>(K);
2312
+ });
2313
+ return N * row_size_B;
2314
+ };
2315
+
2316
+ if (qtype_has_amx_kernels(TYPE)) {
2317
+ return get_tensor_size();
2318
+ } else {
2319
+ // for f16, bf16 we don't do packing
2320
+ return ggml_nbytes(tensor);
2321
+ }
2322
+ }
2323
+
2324
+ // pack weight to vnni format
2325
+ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2326
+ GGML_ASSERT(offset == 0 && size == ggml_nbytes(tensor)); // only full tensor conversion is supported for now
2327
+
2328
+ const enum ggml_type TYPE = tensor->type;
2329
+
2330
+ const int K = tensor->ne[0]; // ne0: in_features
2331
+ const int N = tensor->ne[1]; // ne1: out_features
2332
+
2333
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2334
+ convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K);
2335
+ });
2336
+ }
2337
+
2338
+ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
2339
+ struct ggml_tensor * src0 = dst->src[0];
2340
+
2341
+ const enum ggml_type TYPE = src0->type;
2342
+
2343
+ const bool is_floating_type = TYPE == GGML_TYPE_F16;
2344
+ if (is_floating_type) {
2345
+ return 0;
2346
+ }
2347
+
2348
+ const int M = dst->ne[1];
2349
+ const int K = src0->ne[0];
2350
+
2351
+ size_t desired_wsize = 0;
2352
+
2353
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2354
+ const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
2355
+ desired_wsize = M * row_size_A;
2356
+ });
2357
+
2358
+ return desired_wsize;
2359
+ }
2360
+
2361
+ // NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)
2362
+ //
2363
+ // src0: weight in shape of {N, K}, quantized
2364
+ // src1: input in shape of {M, K}, float32
2365
+ // dst: output in shape of {M, N}, float32
2366
+ //
2367
+ // the function performs: dst = src1 @ src0.T
2368
+ //
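+ // e.g. (hypothetical sizes) M = 32, K = 4096, N = 11008:
+ //   src1 {32, 4096} @ src0.T {4096, 11008} -> dst {32, 11008}
+ //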
2369
+ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
2370
+ struct ggml_tensor * src0 = dst->src[0];
2371
+ struct ggml_tensor * src1 = dst->src[1];
2372
+
2373
+ const enum ggml_type TYPE = src0->type;
2374
+
2375
+ // f16 only has avx512 kernels for now,
2376
+ // amx kernels will be added once 6th gen xeon is released.
2377
+ const bool is_floating_type = TYPE == GGML_TYPE_F16;
2378
+
2379
+ const int M = dst->ne[1];
2380
+ const int N = dst->ne[0];
2381
+ const int K = src0->ne[0];
2382
+ const int ldc = dst->nb[1] / dst->nb[0];
2383
+
2384
+ if (is_floating_type) {
2385
+ constexpr int BLOCK_M = 4;
2386
+ constexpr int BLOCK_N = 6;
2387
+ const int MB = div_up(M, BLOCK_M);
2388
+ const int NB = div_up(N, BLOCK_N);
2389
+
2390
+ parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
2391
+ GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
2392
+ for (int i = begin; i < end; ++i) {
2393
+ int mb = i / NB;
2394
+ int nb = i % NB;
2395
+
2396
+ int mb_start = mb * BLOCK_M;
2397
+ int mb_size = std::min(BLOCK_M, M - mb_start);
2398
+ int nb_start = nb * BLOCK_N;
2399
+ int nb_size = std::min(BLOCK_N, N - nb_start);
2400
+
2401
+ switch (mb_size << 4 | nb_size) {
2402
+ case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break;
2403
+ case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break;
2404
+ case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break;
2405
+ case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break;
2406
+ case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break;
2407
+ case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break;
2408
+ case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break;
2409
+ case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break;
2410
+ case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break;
2411
+ case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break;
2412
+ case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break;
2413
+ case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break;
2414
+ default: fprintf(stderr, "Unexpected block size!\n");
2415
+ }
2416
+ }
2417
+ });
2418
+ });
2419
+ return;
2420
+ }
2421
+
2422
+ // pointer to work space, used to convert A from float to the quantized type
2423
+ void * wdata = params->wdata;
2424
+
2425
+ //TODO: performance improvement: merge quant A
2426
+ if (params->ith == 0) {
2427
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2428
+ const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
2429
+ const size_t desired_wsize = M * row_size_A;
2430
+ if (params->wsize < desired_wsize) {
2431
+ GGML_ABORT("insufficient work space size");
2432
+ }
2433
+
2434
+ // Q4_0, Q4_1 and Q8_0 handle 1 TILE_K per blck_size
2435
+ // Q4_K, Q5_K, Q6_K and IQ4_XS handle 8 TILE_K per blck_size
2436
+ GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
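+ // e.g. with TILE_K = 32: QK4_0 = QK4_1 = QK8_0 = 32 covers one tile in K,
+ // while QK_K = 256 covers eight tiles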
2437
+
2438
+ const float * A_data = static_cast<const float *>(src1->data);
2439
+ for (int m = 0; m < M; ++m) {
2440
+ from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
2441
+ }
2442
+ });
2443
+ }
2444
+
2445
+ ggml_barrier(params->threadpool);
2446
+
2447
+ if (M == 1) {
2448
+ // MB = 1 and handle kTilesN tiles in each block
2449
+ constexpr int kTilesN = 4;
2450
+ constexpr int BLOCK_N = TILE_N * kTilesN;
2451
+ const int NB = div_up(N, BLOCK_N);
2452
+
2453
+ parallel_for_ggml(params, NB, [&](int begin, int end) {
2454
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2455
+ const int KB = K / blck_size;
2456
+ const int TILE_SIZE = get_tile_size<type>();
2457
+ const int row_size_A = KB * sizeof(vec_dot_type);
2458
+ for (int i = begin; i < end; ++i) {
2459
+ int nb = i;
2460
+ int nb_start = nb * BLOCK_N;
2461
+ int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
2462
+
2463
+ switch (nb_size) {
2464
+ //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break;
2465
+ case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break;
2466
+ case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break;
2467
+ case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break;
2468
+ case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break;
2469
+ default: fprintf(stderr, "Unexpected n block size!\n");
2470
+ }
2471
+ }
2472
+ });
2473
+ });
2474
+ return;
2475
+ }
2476
+
2477
+ // handle 4 tiles at a time
2478
+ constexpr int BLOCK_M = TILE_M * 2;
2479
+ constexpr int BLOCK_N = TILE_N * 2;
2480
+ const int MB = div_up(M, BLOCK_M);
2481
+ const int NB = div_up(N, BLOCK_N);
2482
+
2483
+ parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
2484
+ // init tile config for each thread
2485
+ ggml_tile_config_init();
2486
+
2487
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2488
+ const int KB = K / blck_size;
2489
+ const int TILE_SIZE = get_tile_size<type>();
2490
+ const int row_size_A = KB * sizeof(vec_dot_type);
2491
+
2492
+ for (int i = begin; i < end; ++i) {
2493
+ int mb = i / NB;
2494
+ int nb = i % NB;
2495
+
2496
+ int mb_start = mb * BLOCK_M;
2497
+ int mb_size = std::min(BLOCK_M, M - mb_start);
2498
+ int nb_start = nb * BLOCK_N;
2499
+ int nb_size = BLOCK_N;
2500
+
2501
+ tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
2502
+ mb_size, nb_size, KB,
2503
+ (const char *)wdata + mb_start * row_size_A,
2504
+ (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
2505
+ (float *) dst->data + mb_start * N + nb_start, ldc);
2506
+ }
2507
+ });
2508
+ });
2509
+ }
2510
+
2511
+ #endif // if defined(__AMX_INT8__) && defined(__AVX512VNNI__)