whispercpp 1.3.0 → 1.3.1

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +60 -11
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -16
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/{whisper.h → include/whisper.h} +23 -22
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1492 -9
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -21755
data/ext/ggml/src/ggml-amx/mmq.cpp
@@ -0,0 +1,2510 @@
1
+
2
+ #if defined(__GNUC__)
3
+ #pragma GCC diagnostic ignored "-Wpedantic"
4
+ #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
5
+ #endif
6
+
7
+ #include "mmq.h"
8
+ #include "ggml-impl.h"
9
+ #include "ggml-quants.h"
10
+ #include <algorithm>
11
+ #include <type_traits>
12
+
13
+ #if defined(__gnu_linux__)
14
+ #include <sys/syscall.h>
15
+ #include <unistd.h>
16
+ #endif
17
+
18
+ #if defined(_OPENMP)
19
+ #include <omp.h>
20
+ #endif
21
+
22
+ #if (defined(_WIN32) || defined(_WIN64))
23
+ #define RESTRICT __restrict
24
+ #else
25
+ #define RESTRICT __restrict__
26
+ #endif
27
+
28
+ #if (defined(_WIN32) || defined(_WIN64))
29
+ #define ALWAYS_INLINE __forceinline
30
+ #elif __has_attribute(always_inline) || defined(__GNUC__)
31
+ #define ALWAYS_INLINE __attribute__((__always_inline__)) inline
32
+ #else
33
+ #define ALWAYS_INLINE inline
34
+ #endif
35
+
36
+ #if defined(__AMX_INT8__)
37
+
38
+ namespace {
39
+
40
+ // Forced unrolling
41
+ template <int n>
42
+ struct Unroll {
43
+ template <typename Func, typename... Args>
44
+ ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
45
+ Unroll<n - 1>{}(f, args...);
46
+ f(std::integral_constant<int, n - 1>{}, args...);
47
+ }
48
+ };
49
+
50
+ template <>
51
+ struct Unroll<1> {
52
+ template <typename Func, typename... Args>
53
+ ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
54
+ f(std::integral_constant<int, 0>{}, args...);
55
+ }
56
+ };
57
+
58
+ // type traits
59
+ template <typename T> struct PackedTypes {};
60
+ template <> struct PackedTypes<block_q4_0> { using type = int8_t; };
61
+ template <> struct PackedTypes<block_q4_1> { using type = uint8_t; };
62
+ template <> struct PackedTypes<block_q8_0> { using type = int8_t; };
63
+ template <typename T> using packed_B_type = typename PackedTypes<T>::type;
64
+
65
+ template <typename T>
66
+ struct do_compensate : std::integral_constant<bool,
67
+ std::is_same<T, block_q8_0>::value> {};
68
+
69
+ template <typename T>
70
+ struct do_unpack : std::integral_constant<bool,
71
+ std::is_same<T, block_q4_0>::value ||
72
+ std::is_same<T, block_q4_1>::value> {};
73
+
74
+ template <typename T>
75
+ struct is_type_qkk : std::integral_constant<bool,
76
+ std::is_same<T, block_q4_K>::value ||
77
+ std::is_same<T, block_q5_K>::value ||
78
+ std::is_same<T, block_q6_K>::value ||
79
+ std::is_same<T, block_iq4_xs>::value> {};
80
+
81
+ #define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...) \
82
+ [&] { \
83
+ switch (TYPE) { \
84
+ case GGML_TYPE_F16: { \
85
+ using type = ggml_fp16_t; \
86
+ constexpr int blck_size = 16; \
87
+ return __VA_ARGS__(); \
88
+ } \
89
+ case GGML_TYPE_BF16: { \
90
+ using type = ggml_bf16_t; \
91
+ constexpr int blck_size = 32; \
92
+ return __VA_ARGS__(); \
93
+ } \
94
+ default: \
95
+ fprintf(stderr, "Unsupported floating data type\n"); \
96
+ } \
97
+ }()
98
+
99
+ #define GGML_DISPATCH_QTYPES(QT, ...) \
100
+ [&] { \
101
+ switch (QT) { \
102
+ case GGML_TYPE_Q4_0: { \
103
+ using type = block_q4_0; \
104
+ using vec_dot_type = block_q8_0; \
105
+ constexpr int blck_size = QK4_0; \
106
+ return __VA_ARGS__(); \
107
+ } \
108
+ case GGML_TYPE_Q4_1: { \
109
+ using type = block_q4_1; \
110
+ using vec_dot_type = block_q8_1; \
111
+ constexpr int blck_size = QK4_1; \
112
+ return __VA_ARGS__(); \
113
+ } \
114
+ case GGML_TYPE_Q8_0: { \
115
+ using type = block_q8_0; \
116
+ using vec_dot_type = block_q8_0; \
117
+ constexpr int blck_size = QK8_0; \
118
+ return __VA_ARGS__(); \
119
+ } \
120
+ case GGML_TYPE_Q4_K: { \
121
+ using type = block_q4_K; \
122
+ using vec_dot_type = block_q8_K; \
123
+ constexpr int blck_size = QK_K; \
124
+ return __VA_ARGS__(); \
125
+ } \
126
+ case GGML_TYPE_Q5_K: { \
127
+ using type = block_q5_K; \
128
+ using vec_dot_type = block_q8_K; \
129
+ constexpr int blck_size = QK_K; \
130
+ return __VA_ARGS__(); \
131
+ } \
132
+ case GGML_TYPE_Q6_K: { \
133
+ using type = block_q6_K; \
134
+ using vec_dot_type = block_q8_K; \
135
+ constexpr int blck_size = QK_K; \
136
+ return __VA_ARGS__(); \
137
+ } \
138
+ case GGML_TYPE_IQ4_XS: { \
139
+ using type = block_iq4_xs; \
140
+ using vec_dot_type = block_q8_K; \
141
+ constexpr int blck_size = QK_K; \
142
+ return __VA_ARGS__(); \
143
+ } \
144
+ default: \
145
+ fprintf(stderr, "Unsupported quantized data type: %d\n", int(TYPE)); \
146
+ } \
147
+ }()
148
+
149
+ #define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
150
+ [&] { \
151
+ if (BOOL_V) { \
152
+ constexpr bool BOOL_NAME = true; \
153
+ return __VA_ARGS__(); \
154
+ } else { \
155
+ constexpr bool BOOL_NAME = false; \
156
+ return __VA_ARGS__(); \
157
+ } \
158
+ }()
159
+
160
+ // define amx tile config data structure
161
+ struct tile_config_t{
162
+ uint8_t palette_id = 0;
163
+ uint8_t start_row = 0;
164
+ uint8_t reserved_0[14] = {0};
165
+ uint16_t colsb[16] = {0};
166
+ uint8_t rows[16] = {0};
167
+ };
168
+
169
+ // Notes: amx tile config
170
+ //
171
+ // Typically, TMUL calculates A and B of size 16 x 64 containing INT8 values,
172
+ // and accumulates the result into a 16 x 16 matrix C containing INT32 values.
173
+ //
174
+ // Since many GGUF quantized types have a `block_size` of 32, a 16-16-32 config is used
175
+ // instead of the normally used 16-16-64 config.
176
+ //
177
+ // Block A: {16, 32}, dtype = int8_t
178
+ // Block B: {16, 32}, dtype = uint8_t/int8_t
179
+ // Block C: {16, 16}, dtype = int32_t
180
+ //
181
+ // Block B needs to be prepacked to vnni format before feeding into TMUL:
182
+ // packed_B: from {n, k} to {k/vnni_blk, n, vnni_blk}; viewed in 2d, we get {8, 64}
183
+ //
184
+ // Therefore, we get tileconfig:
185
+ // A B C
186
+ // rows 16 8 16
187
+ // colsb 32 64 16
188
+ //
189
+ // For tile distribution, follow a 2-2-4 pattern, e.g. A used TMM2-TMM3, B used TMM0-TMM1,
190
+ // C used TMM4-TMM7:
191
+ // B TMM0 B TMM1
192
+ // A TMM2 C TMM4 C TMM6
193
+ // A TMM3 C TMM5 C TMM7
194
+ //
195
+ // Each `amx` kernel handles 4 blocks at a time (2MB * 2NB); when m < 2 * BLOCK_M, unpacking A
196
+ // will be needed.
197
+ //
198
+ // Another commonly used pattern, 1-3-3, is skipped here, as it is mostly used when m <= 16;
199
+ // and the single batch gemm (m=1) has a special fast path with `avx512-vnni`.
200
+ //
201
+ // ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/
202
+ // advanced-matrix-extensions-intrinsics-functions.html
203
+ //
204
+
205
+ #define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
206
+ void ggml_tile_config_init(void) {
207
+ static thread_local bool is_first_time = true;
208
+
209
+ if (!is_first_time) {
210
+ return;
211
+ }
212
+
213
+ static thread_local tile_config_t tc;
214
+ tile_config_t current_tc;
215
+ _tile_storeconfig(&current_tc);
216
+
217
+ // load only when config changes
218
+ if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
219
+ memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
220
+ tc.palette_id = 1;
221
+ tc.start_row = 0;
222
+ TC_CONFIG_TILE(TMM0, 8, 64);
223
+ TC_CONFIG_TILE(TMM1, 8, 64);
224
+ TC_CONFIG_TILE(TMM2, 16, 32);
225
+ TC_CONFIG_TILE(TMM3, 16, 32);
226
+ TC_CONFIG_TILE(TMM4, 16, 64);
227
+ TC_CONFIG_TILE(TMM5, 16, 64);
228
+ TC_CONFIG_TILE(TMM6, 16, 64);
229
+ TC_CONFIG_TILE(TMM7, 16, 64);
230
+ _tile_loadconfig(&tc);
231
+ }
232
+
233
+ is_first_time = false;
234
+ }
235
+
236
+ // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
237
+ // See the notes `s8s8 igemm compensation in avx512-vnni` for detail.
238
+ template <typename TB>
239
+ int get_tile_size() {
240
+ int tile_size = TILE_N * sizeof(TB);
241
+ if (do_compensate<TB>::value) {
242
+ tile_size += TILE_N * sizeof(int32_t);
243
+ }
244
+ if (std::is_same<TB, block_q4_K>::value ||
245
+ std::is_same<TB, block_q5_K>::value) {
246
+ tile_size += TILE_N * 4;
247
+ }
248
+ if (std::is_same<TB, block_iq4_xs>::value) {
249
+ tile_size += TILE_N * 2;
250
+ }
251
+ return tile_size;
252
+ }
253
+
254
+ template <typename TB, int BLOCK_K>
255
+ int get_row_size(int K) {
256
+ int KB = K / BLOCK_K;
257
+ int row_size = KB * sizeof(TB);
258
+ if (do_compensate<TB>::value) {
259
+ row_size += KB * sizeof(int32_t);
260
+ }
261
+ if (std::is_same<TB, block_q4_K>::value ||
262
+ std::is_same<TB, block_q5_K>::value) {
263
+ row_size += KB * 4;
264
+ }
265
+ if (std::is_same<TB, block_iq4_xs>::value) {
266
+ row_size += KB * 2;
267
+ }
268
+ return row_size;
269
+ }
270
+
271
+ // vectorized dtype conversion
272
+ inline float FP16_TO_FP32(ggml_half val) {
273
+ __m256i v = _mm256_setr_epi16(
274
+ val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
275
+ __m512 o = _mm512_cvtph_ps(v);
276
+ return _mm512_cvtss_f32(o);
277
+ }
278
+
279
+ inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
280
+ __m256i v = _mm256_set1_epi16(val);
281
+ return _mm512_cvtph_ps(v);
282
+ }
283
+
284
+ // horizontal reduce
285
+ inline float _mm512_reduce_max_ps(const __m512 x) {
286
+ __m512 v = x;
287
+ __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
288
+ v = _mm512_max_ps(v, v1);
289
+ v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
290
+ v = _mm512_max_ps(v, v1);
291
+ v1 = _mm512_shuffle_ps(v, v, 0x4E);
292
+ v = _mm512_max_ps(v, v1);
293
+ v1 = _mm512_shuffle_ps(v, v, 0xB1);
294
+ v = _mm512_max_ps(v, v1);
295
+ return _mm512_cvtss_f32(v);
296
+ }
297
+
298
+ // transpose utils
299
+ #define SHUFFLE_EPI32(a, b, mask) \
300
+ _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
301
+ inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) {
302
+ // unpacking and 32-bit elements
303
+ v1[0] = _mm256_unpacklo_epi32(v[0], v[1]);
304
+ v1[1] = _mm256_unpackhi_epi32(v[0], v[1]);
305
+ v1[2] = _mm256_unpacklo_epi32(v[2], v[3]);
306
+ v1[3] = _mm256_unpackhi_epi32(v[2], v[3]);
307
+ v1[4] = _mm256_unpacklo_epi32(v[4], v[5]);
308
+ v1[5] = _mm256_unpackhi_epi32(v[4], v[5]);
309
+ v1[6] = _mm256_unpacklo_epi32(v[6], v[7]);
310
+ v1[7] = _mm256_unpackhi_epi32(v[6], v[7]);
311
+
312
+ // shuffling the 32-bit elements
313
+ v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44);
314
+ v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee);
315
+ v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44);
316
+ v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee);
317
+ v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44);
318
+ v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee);
319
+ v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44);
320
+ v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee);
321
+
322
+ // shuffling 128-bit elements
323
+ v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02);
324
+ v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02);
325
+ v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02);
326
+ v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02);
327
+ v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13);
328
+ v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13);
329
+ v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13);
330
+ v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13);
331
+ }
332
+
333
+ inline void transpose_16x4_32bit(__m512i * r, __m512i * d) {
334
+
335
+ static const __m512i index1 = _mm512_set_epi32(
336
+ 0x0f, 0x0b, 0x07, 0x03,
337
+ 0x0e, 0x0a, 0x06, 0x02,
338
+ 0x0d, 0x09, 0x05, 0x01,
339
+ 0x0c, 0x08, 0x04, 0x00);
340
+
341
+ d[0] = _mm512_permutexvar_epi32(index1, r[0]);
342
+ d[1] = _mm512_permutexvar_epi32(index1, r[1]);
343
+ d[2] = _mm512_permutexvar_epi32(index1, r[2]);
344
+ d[3] = _mm512_permutexvar_epi32(index1, r[3]);
345
+
346
+ r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
347
+ r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
348
+ r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);
349
+ r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);
350
+
351
+ d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);
352
+ d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);
353
+ d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);
354
+ d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);
355
+ }
356
+
357
+ inline void transpose_16x16_32bit(__m512i * v) {
358
+ __m512i v1[16];
359
+ v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
360
+ v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
361
+ v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
362
+ v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
363
+ v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
364
+ v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
365
+ v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
366
+ v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
367
+ v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
368
+ v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
369
+ v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
370
+ v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
371
+ v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
372
+ v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
373
+ v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
374
+ v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
375
+
376
+ v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
377
+ v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
378
+ v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
379
+ v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
380
+ v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
381
+ v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
382
+ v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
383
+ v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
384
+ v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
385
+ v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
386
+ v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
387
+ v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
388
+ v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
389
+ v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
390
+ v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
391
+ v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
392
+
393
+ v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
394
+ v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
395
+ v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
396
+ v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
397
+ v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
398
+ v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
399
+ v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
400
+ v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
401
+ v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
402
+ v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
403
+ v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
404
+ v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
405
+ v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
406
+ v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
407
+ v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
408
+ v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
409
+
410
+ v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
411
+ v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
412
+ v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
413
+ v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
414
+ v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
415
+ v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
416
+ v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
417
+ v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
418
+ v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
419
+ v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
420
+ v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
421
+ v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
422
+ v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
423
+ v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
424
+ v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
425
+ v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
426
+ }
427
+
428
+ void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) {
429
+ assert(k % QK_K == 0);
430
+ const int KB = k / QK_K;
431
+ constexpr int kVecs = QK_K / 16;
432
+
433
+ block_q8_K * y = reinterpret_cast<block_q8_K *>(vy);
434
+
435
+ // hold 16 float vecs from x
436
+ __m512 v[kVecs];
437
+
438
+ // hold the quants vecs
439
+ __m512i vq[kVecs / 4];
440
+
441
+ // hold the packed quants vecs
442
+ __m512i vq_packed[kVecs / 4];
443
+
444
+ const __m512 signBit = _mm512_set1_ps(-0.f);
445
+
446
+ for (int i = 0; i < KB; ++i) {
447
+ // Compute max(abs(e)) for the block
448
+ __m512 vamax = _mm512_set1_ps(0.f);
449
+ for (int j = 0; j < kVecs; ++j) {
450
+ v[j] = _mm512_loadu_ps(x); x += 16;
451
+ vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j]));
452
+ }
453
+ const float amax = _mm512_reduce_max_ps(vamax);
454
+
455
+ // Quantize these floats
456
+ const float iscale = 127.f / amax;
457
+ y[i].d = GGML_FP32_TO_FP16(1 / iscale);
458
+ const float id = ( amax != 0.0f ) ? iscale : 0.f;
459
+ const __m512 vscale = _mm512_set1_ps(id);
460
+
461
+ // Apply multiplier and round to nearest integer
462
+ for (int j = 0; j < kVecs; ++j) {
463
+ v[j] = _mm512_mul_ps(v[j], vscale);
464
+ v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
465
+ }
466
+
467
+ // Pack to epi8 vecs
468
+ for (int j = 0; j < kVecs / 4; ++j) {
469
+ __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0]));
470
+ __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1]));
471
+ __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2]));
472
+ __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3]));
473
+
474
+ __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1);
475
+ __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1);
476
+
477
+ vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1);
478
+ _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]);
479
+ }
480
+
481
+ // Compute the bsums with vnni
482
+ transpose_16x4_32bit(vq, vq_packed);
483
+
484
+ const __m512i one = _mm512_set1_epi8(1);
485
+ __m512i sum = _mm512_setzero_si512();
486
+ for (int k = 0; k < 4; ++k) {
487
+ sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]);
488
+ }
489
+ _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum));
490
+ }
491
+ }
492
+
493
+ // quantize A from float to `vec_dot_type`
494
+ template <typename T>
495
+ inline void from_float(const float * x, char * vy, int64_t k);
496
+
497
+ template <>
498
+ inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) {
499
+ // FIXME: using unoptimized reference impl until moved to CPU backend
500
+ quantize_row_q8_0_ref(x, (block_q8_0 *)vy, k);
501
+ }
502
+
503
+ template <>
504
+ inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) {
505
+ quantize_row_q8_1_ref(x, (block_q8_1 *)vy, k);
506
+ }
507
+
508
+ template <>
509
+ inline void from_float<block_q8_K>(const float * x, char * vy, int64_t k) {
510
+ #if 1
511
+ // TODO: this is reference impl!
512
+ quantize_row_q8_K_ref(x, (block_q8_K *)vy, k);
513
+ #else
514
+ quantize_row_q8_K_vnni(x, vy, k);
515
+ #endif
516
+ }
517
+
518
+ // load A from memory into an array when nrows cannot fill a whole tile
519
+ void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) {
520
+ assert(nr != TILE_M);
521
+ for (int m = 0; m < nr; ++m) {
522
+ const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
523
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
524
+ }
525
+ }
526
+
527
+ void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) {
528
+ assert(nr != TILE_M);
529
+ for (int m = 0; m < nr; ++m) {
530
+ const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
531
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
532
+ }
533
+ }
534
+
535
+ template <typename TB>
536
+ void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
537
+ assert(nr <= TILE_M);
538
+ for (int m = 0; m < nr; ++m) {
539
+ const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32));
540
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
541
+ }
542
+ }
543
+
544
+ template <>
545
+ void unpack_A<block_q6_K>(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
546
+ assert(nr <= TILE_M);
547
+ // zero-pad k from 16 to 32 so that we don't have to reconfigure amx
548
+ const __m128i zero = _mm_setzero_si128();
549
+ for (int m = 0; m < nr; ++m) {
550
+ const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16));
551
+ const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1);
552
+ _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r);
553
+ }
554
+ }
555
+
556
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
557
+ inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
558
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
559
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
560
+ const __m256i lowMask = _mm256_set1_epi8(0xF);
561
+ return _mm256_and_si256(lowMask, bytes);
562
+ }
563
+
564
+ // used for block_q4_K
565
+ inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) {
566
+ const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi);
567
+ const __m256i lowMask = _mm256_set1_epi8(0xF);
568
+ const __m256i q4l = _mm256_and_si256(tmp, lowMask);
569
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask);
570
+ return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1);
571
+ }
572
+
573
+ // used for block_q5_K
574
+ inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) {
575
+ const __m256i lowMask = _mm256_set1_epi8(0xF);
576
+ __m256i hmask = _mm256_set1_epi8(1);
577
+ hmask = _mm256_slli_epi16(hmask, k);
578
+
579
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs);
580
+ const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh);
581
+
582
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask);
583
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4);
584
+ const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
585
+ hmask = _mm256_slli_epi16(hmask, 1);
586
+
587
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask);
588
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4);
589
+ const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
590
+
591
+ return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1);
592
+ }
593
+
594
+ // used for block_q6_K
595
+ inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) {
596
+ const __m256i m4 = _mm256_set1_epi8(0xF);
597
+ const __m256i m2 = _mm256_set1_epi8(0x3);
598
+
599
+ const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs);
600
+ const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32));
601
+ const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh);
602
+
603
+ const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256( q6bitsH, m2), 4);
604
+ const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4);
605
+ const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4);
606
+ const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4);
607
+
608
+ const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0);
609
+ const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1);
610
+ const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2);
611
+ const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3);
612
+
613
+ r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1);
614
+ r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1);
615
+ }
616
+
617
+ inline __m512i packNibbles(__m512i r0, __m512i r1) {
618
+ return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4));
619
+ }
620
+
621
+ template <typename TB>
622
+ inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) {
623
+ int8_t tmp[8 * 64];
624
+ __m256i v[8], v2[8];
625
+ for (int n = 0; n < 8; ++n) {
626
+ v[n] = bytes_from_nibbles_32(B[n * KB].qs);
627
+ }
628
+ transpose_8x8_32bit(v, v2);
629
+ for (int n = 0; n < 8; ++n) {
630
+ _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]);
631
+ }
632
+ for (int n = 0; n < 8; ++n) {
633
+ v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs);
634
+ }
635
+ transpose_8x8_32bit(v, v2);
636
+ for (int n = 0; n < 8; ++n) {
637
+ _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]);
638
+ }
639
+
640
+ // pack again with 128 to fully utilize vector length
641
+ for (int n = 0; n < 8; n += 2) {
642
+ __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64));
643
+ __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64));
644
+ __m512i r1r0 = packNibbles(r0, r1);
645
+ _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0);
646
+ }
647
+ }
648
+
649
+ template <>
650
+ inline void pack_qs<block_q8_0>(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
651
+ __m256i v[8], v2[8];
652
+ for (int n = 0; n < 8; ++n) {
653
+ v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs));
654
+ }
655
+ transpose_8x8_32bit(v, v2);
656
+ for (int n = 0; n < 8; ++n) {
657
+ _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]);
658
+ }
659
+ for (int n = 0; n < 8; ++n) {
660
+ v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs));
661
+ }
662
+ transpose_8x8_32bit(v, v2);
663
+ for (int n = 0; n < 8; ++n) {
664
+ _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]);
665
+ }
666
+ }
667
+
668
+ template <>
669
+ inline void pack_qs<block_q4_K>(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
670
+ __m512i v[16];
671
+ // QK_K 256 with 8 groups, handle 2 groups at a time
672
+ char * pb = (char *)packed_B;
673
+ for (int k = 0; k < QK_K / 64; ++k) {
674
+ // pack 2 groups { n, g, k} to {g, k/4, 4n}
675
+ // e.g. {16, 2, 32} to {2, 8, 64}
676
+ for (int n = 0; n < TILE_N; ++n) {
677
+ v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32);
678
+ }
679
+
680
+ transpose_16x16_32bit(v);
681
+
682
+ // pack again with 128 to fully utilize vector length
683
+ for (int n = 0; n < TILE_N; n += 2) {
684
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
685
+ pb += 64;
686
+ }
687
+ }
688
+ }
689
+
690
+ template <>
691
+ inline void pack_qs<block_q5_K>(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
692
+ __m512i v[16];
693
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
694
+ // QK_K 256 with 8 groups, handle 2 groups at a time
695
+ char * pb = (char *)packed_B;
696
+ char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
697
+ for (int k = 0; k < QK_K / 64; ++k) {
698
+ // pack 2 groups { n, g, k} to {g, k/4, 4n}
699
+ // e.g. {16, 2, 32} to {2, 8, 64}
700
+ for (int n = 0; n < TILE_N; ++n) {
701
+ v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k);
702
+ }
703
+
704
+ transpose_16x16_32bit(v);
705
+
706
+ // 1. pack lower 4bits with 2 groups
707
+ for (int n = 0; n < TILE_N; n += 2) {
708
+ // get lower 4 bits
709
+ const __m512i r0 = _mm512_and_si512(v[n], lowMask);
710
+ const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
711
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
712
+ }
713
+
714
+ // 2. pack higher 1bit with 2 groups
715
+ const __m512i hmask = _mm512_set1_epi8(0x10);
716
+ for (int g = 0; g < 2; ++g) {
717
+ __m512i hbits = _mm512_setzero_si512();
718
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4));
719
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3));
720
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2));
721
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1));
722
+ hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) );
723
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1));
724
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2));
725
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3));
726
+ _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
727
+ }
728
+ }
729
+ }
730
+
731
+ template <>
732
+ inline void pack_qs<block_q6_K>(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
733
+ __m512i v[32];
734
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
735
+ // QK_K 256 with 8 groups, handle 4 groups at a time
736
+ char * pb = (char *)packed_B;
737
+ char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
738
+ for (int k = 0; k < QK_K / 128; ++k) {
739
+ for (int n = 0; n < TILE_N; ++n) {
740
+ bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32);
741
+ }
742
+
743
+ // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7
744
+ transpose_16x16_32bit(v);
745
+ transpose_16x16_32bit(v + 16);
746
+
747
+ // 1. pack lower 4bits with 4 groups
748
+ for (int n = 0; n < 32; n += 2) {
749
+ const __m512i r0 = _mm512_and_si512(v[n], lowMask);
750
+ const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
751
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
752
+ }
753
+
754
+ // 2. pack higher 2bit with 4 groups
755
+ const __m512i hmask = _mm512_set1_epi8(0x30);
756
+ for (int g = 0; g < 8; ++g) {
757
+ __m512i hbits = _mm512_setzero_si512();
758
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4));
759
+ hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2));
760
+ hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) );
761
+ hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2));
762
+ _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
763
+ }
764
+ }
765
+ }
766
+
767
+ template <>
768
+ inline void pack_qs<block_iq4_xs>(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
769
+ __m512i v[16];
770
+ char * pb = (char *)packed_B;
771
+ for (int k = 0; k < QK_K / 64; ++k) {
772
+ for (int n = 0; n < TILE_N; ++n) {
773
+ __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0);
774
+ __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16);
775
+ v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1);
776
+ }
777
+
778
+ transpose_16x16_32bit(v);
779
+
780
+ // pack again with 128 to fully utilize vector length
781
+ for (int n = 0; n < TILE_N; n += 2) {
782
+ _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
783
+ pb += 64;
784
+ }
785
+ }
786
+ }
787
+
788
+ // pack B to vnni formats in 4bits or 8 bits
789
+ void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) {
790
+ pack_qs(packed_B, B, KB);
791
+ ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
792
+ for (int n = 0; n < TILE_N; ++n) {
793
+ d0[n] = B[n * KB].d;
794
+ }
795
+ }
796
+
797
+ void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) {
798
+ pack_qs(packed_B, B, KB);
799
+ ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
800
+ ggml_half * m0 = d0 + TILE_N;
801
+ for (int n = 0; n < TILE_N; ++n) {
802
+ d0[n] = B[n * KB].d;
803
+ m0[n] = B[n * KB].m;
804
+ }
805
+ }
806
+
807
+ inline void s8s8_compensation(void * RESTRICT packed_B) {
808
+ // packed_B layout:
809
+ // quants {TILE_N, TILEK} int8_t
810
+ // d0 {TILE_N} ggml_half
811
+ // comp {TILE_N} int32_t
812
+ const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);
813
+ __m512i vcomp = _mm512_setzero_si512();
814
+ const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
815
+ for (int k = 0; k < 8; ++k) {
816
+ __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64));
817
+ vcomp = _mm512_dpbusd_epi32(vcomp, off, vb);
818
+ }
819
+ _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp);
820
+ }
821
+
822
+ void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
823
+ pack_qs(packed_B, B, KB);
824
+ ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K);
825
+ for (int n = 0; n < TILE_N; ++n) {
826
+ d0[n] = B[n * KB].d;
827
+ }
828
+ s8s8_compensation(packed_B);
829
+ }
830
+
831
+ // convert 8 * {min, scale} from int6 to int8
832
+ inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) {
833
+ const uint32_t kmask1 = 0x3f3f3f3f;
834
+ const uint32_t kmask2 = 0x0f0f0f0f;
835
+ const uint32_t kmask3 = 0x03030303;
836
+
837
+ memcpy(utmp, scales, 12);
838
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
839
+ const uint32_t uaux = utmp[1] & kmask1;
840
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
841
+ utmp[2] = uaux;
842
+ utmp[0] &= kmask1;
843
+ }
844
+
845
+ // packed_B layout:
846
+ // quants {8, TILE_N, 16} uint8
847
+ // scales {8, TILE_N} uint8
848
+ // mins {8, TILE_N} uint8
849
+ // d {TILE_N} ggml_half
850
+ // dmin {TILE_N} ggml_half
851
+ void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
852
+ pack_qs(packed_B, B, KB);
853
+
854
+ uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
855
+ uint8_t * mins = scales + 8 * TILE_N;
856
+ ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);
857
+ ggml_half * dmin = d + TILE_N;
858
+
859
+ union {
860
+ uint32_t u32[4];
861
+ uint8_t u8[16];
862
+ } s;
863
+
864
+ for (int n = 0; n < TILE_N; ++n) {
865
+ unpack_mins_and_scales(B[n * KB].scales, s.u32);
866
+ for (int k = 0; k < 8; ++k) {
867
+ scales[k * TILE_N + n] = s.u8[k];
868
+ mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
869
+ }
870
+ d[n] = B[n * KB].d;
871
+ dmin[n] = B[n * KB].dmin;
872
+ }
873
+ }
874
+
875
+ // packed_B layout:
876
+ // quants {8, TILE_N, 16} uint8
877
+ // qh {8, TILE_N, 4} uint8
878
+ // scales {8, TILE_N} uint8
879
+ // mins {8, TILE_N} uint8
880
+ // d {TILE_N} ggml_half
881
+ // dmin {TILE_N} ggml_half
882
+ void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
883
+ pack_qs(packed_B, B, KB);
884
+
885
+ uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
886
+ uint8_t * mins = scales + 8 * TILE_N;
887
+ ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);
888
+ ggml_half * dmin = d + TILE_N;
889
+
890
+ union {
891
+ uint32_t u32[4];
892
+ uint8_t u8[16];
893
+ } s;
894
+
895
+ for (int n = 0; n < TILE_N; ++n) {
896
+ unpack_mins_and_scales(B[n * KB].scales, s.u32);
897
+ for (int k = 0; k < 8; ++k) {
898
+ scales[k * TILE_N + n] = s.u8[k];
899
+ mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
900
+ }
901
+ d[n] = B[n * KB].d;
902
+ dmin[n] = B[n * KB].dmin;
903
+ }
904
+ }
905
+
906
+ // packed_B layout:
907
+ // quants {16, TILE_N, 8} uint8
908
+ // qh {16, TILE_N, 4} uint8
909
+ // scales {16, TILE_N} uint8
910
+ // d {TILE_N} ggml_half
911
+ void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
912
+ pack_qs(packed_B, B, KB);
913
+
914
+ uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
915
+ ggml_half * d = reinterpret_cast<ggml_half *>(scales + 16 * TILE_N);
916
+ for (int n = 0; n < TILE_N; ++n) {
917
+ const int8_t * ps = B[n * KB].scales;
918
+ for (int k = 0; k < 16; ++k) {
919
+ scales[k * TILE_N + n] = ps[k];
920
+ }
921
+ d[n] = B[n * KB].d;
922
+ }
923
+ }
924
+
925
+ // packed_B layout:
926
+ // quants {8, TILE_N, 16} uint8
927
+ // scales {8, TILE_N} int8
928
+ // d {TILE_N} ggml_half
929
+ void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
930
+ pack_qs(packed_B, B, KB);
931
+
932
+ int8_t * scales = reinterpret_cast<int8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
933
+ ggml_half * d = reinterpret_cast<ggml_half *>(scales + 8 * TILE_N);
934
+
935
+ // pack the scales
936
+ for (int n = 0; n < TILE_N; ++n) {
937
+ uint16_t sh = B[n * KB].scales_h;
938
+ for (int k = 0; k < 8; k += 2) {
939
+ const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32;
940
+ const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32;
941
+ scales[(k + 0) * TILE_N + n] = ls1;
942
+ scales[(k + 1) * TILE_N + n] = ls2;
943
+ sh >>= 4;
944
+ }
945
+ d[n] = B[n * KB].d;
946
+ }
947
+ }
948
+
949
+ template<typename TB, typename packed_B_t = packed_B_type<TB>>
950
+ void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) {
951
+ GGML_UNUSED(tile);
952
+ GGML_UNUSED(packed_B);
953
+ };
954
+
955
+ template <>
956
+ void unpack_B<block_q4_0>(int8_t * RESTRICT tile, const void * RESTRICT packed_B) {
957
+ const __m512i off = _mm512_set1_epi8(8);
958
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
959
+ for (int n = 0; n < 8; n += 2) {
960
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
961
+ const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off);
962
+ const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off);
963
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
964
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
965
+ }
966
+ }
967
+
968
+ template <>
969
+ void unpack_B<block_q4_1>(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) {
970
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
971
+ for (int n = 0; n < 8; n += 2) {
972
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
973
+ const __m512i r0 = _mm512_and_si512(bytes, lowMask);
974
+ const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
975
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
976
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
977
+ }
978
+ }
979
+
980
+ // packed_B_t for QKK is int8_t
981
+ template <typename TB>
982
+ void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
983
+ const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
984
+ const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size;
985
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
986
+ for (int n = 0; n < 8; n += 2) {
987
+ __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32);
988
+ const __m512i r0 = _mm512_and_si512(bytes, lowMask);
989
+ const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
990
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
991
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
992
+ }
993
+ }
994
+
995
+ template <>
996
+ void unpack_B<block_q5_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
997
+ // lower 4bits, stride 256 bytes
998
+ const int packed_l4_group_size = QK_K / 2 * TILE_N / 8;
999
+ const char * pb = (const char *)packed_B + k * packed_l4_group_size;
1000
+
1001
+ // higher 1bit, stride 64 bytes
1002
+ const int packed_h1_group_size = QK_K / 8 * TILE_N / 8;
1003
+ const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size;
1004
+ const __m512i hbits = _mm512_loadu_si512(ph);
1005
+
1006
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1007
+ __m512i hmask0 = _mm512_set1_epi8(0x1);
1008
+ __m512i hmask1 = _mm512_set1_epi8(0x2);
1009
+
1010
+ for (int n = 0; n < 8; n += 2) {
1011
+ __m512i bytes = _mm512_loadu_si512(pb + n * 32);
1012
+ __m512i r0 = _mm512_and_si512(bytes, lowMask);
1013
+ __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1014
+ __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4);
1015
+ __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4);
1016
+
1017
+ hmask0 = _mm512_slli_epi16(hmask0, 2);
1018
+ hmask1 = _mm512_slli_epi16(hmask1, 2);
1019
+ r0 = _mm512_add_epi8(r0, h0);
1020
+ r1 = _mm512_add_epi8(r1, h1);
1021
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
1022
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
1023
+ }
1024
+ }
1025
+
1026
+ template <>
1027
+ void unpack_B<block_q6_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
1028
+ // lower 4bits, stride 128 bytes
1029
+ const int packed_l4_group_size = QK_K / 2 * TILE_N / 16;
1030
+ const char * pb = (const char *)packed_B + k * packed_l4_group_size;
1031
+
1032
+ // higher 2bits, stride 64 bytes
1033
+ const int packed_h2_group_size = QK_K / 4 * TILE_N / 16;
1034
+ const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size;
1035
+ const __m512i hbits = _mm512_loadu_si512(ph);
1036
+
1037
+ const __m512i off = _mm512_set1_epi8(32);
1038
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1039
+ __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011
1040
+ __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100
1041
+
1042
+ // notes: skip zero padding from row4 to row7 as we have done so in `unpack_A`
1043
+ __m512i bytes = _mm512_loadu_si512(pb);
1044
+ __m512i r0 = _mm512_and_si512(bytes, lowMask);
1045
+ __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1046
+ __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4);
1047
+ __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2);
1048
+ _mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
1049
+ _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
1050
+
1051
+ hmask0 = _mm512_slli_epi16(hmask0, 4);
1052
+ hmask1 = _mm512_slli_epi16(hmask1, 4);
1053
+
1054
+ bytes = _mm512_loadu_si512(pb + 64);
1055
+ r0 = _mm512_and_si512(bytes, lowMask);
1056
+ r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1057
+ h0 = _mm512_and_si512(hbits, hmask0);
1058
+ h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2);
1059
+ _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
1060
+ _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
1061
+ }
1062
+
1063
+ template <>
1064
+ void unpack_B<block_iq4_xs>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
1065
+ static const __m512i values128 = _mm512_set_epi8(
1066
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1067
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1068
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1069
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
1070
+ );
1071
+
1072
+ const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
1073
+ const char * pb = (const char *)packed_B + k * packed_B_group_size;
1074
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1075
+
1076
+ for (int n = 0; n < 8; n += 2) {
1077
+ __m512i bytes = _mm512_loadu_si512(pb + n * 32);
1078
+ const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask));
1079
+ const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
1080
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
1081
+ _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
1082
+ }
1083
+ }
1084
+
1085
+ template <typename TA, typename TB, bool is_acc>
1086
+ struct acc_C {};
1087
+
1088
+ template <bool is_acc>
1089
+ struct acc_C<block_q8_0, block_q4_0, is_acc> {
1090
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
1091
+ const int offset = TILE_N * TILE_K / 2;
1092
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
1093
+
1094
+ for (int m = 0; m < nr; ++m) {
1095
+ const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
1096
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1097
+
1098
+ __m512 vsum;
1099
+ if (is_acc) {
1100
+ vsum = _mm512_loadu_ps(C + m * ldc);
1101
+ } else {
1102
+ vsum = _mm512_set1_ps(0.f);
1103
+ }
1104
+ vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
1105
+ _mm512_storeu_ps(C + m * ldc, vsum);
1106
+ }
1107
+ }
1108
+ };
1109
+
1110
+ template <bool is_acc>
1111
+ struct acc_C<block_q8_1, block_q4_1, is_acc> {
1112
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) {
1113
+ const int offset = TILE_N * TILE_K / 2;
1114
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
1115
+ const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half))));
1116
+
1117
+ for (int m = 0; m < nr; ++m) {
1118
+ const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
1119
+ const __m512 vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].s));
1120
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1121
+
1122
+ __m512 vsum;
1123
+ if (is_acc) {
1124
+ vsum = _mm512_loadu_ps(C + m * ldc);
1125
+ } else {
1126
+ vsum = _mm512_set1_ps(0.f);
1127
+ }
1128
+ vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
1129
+ vsum = _mm512_fmadd_ps(vm0, vs1, vsum);
1130
+ _mm512_storeu_ps(C + m * ldc, vsum);
1131
+ }
1132
+ }
1133
+ };
1134
+
1135
+ template <bool is_acc>
1136
+ struct acc_C<block_q8_0, block_q8_0, is_acc> {
1137
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
1138
+ const int offset = TILE_N * TILE_K;
1139
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
1140
+
1141
+ for (int m = 0; m < nr; ++m) {
1142
+ const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
1143
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1144
+
1145
+ __m512 vsum;
1146
+ if (is_acc) {
1147
+ vsum = _mm512_loadu_ps(C + m * ldc);
1148
+ } else {
1149
+ vsum = _mm512_set1_ps(0.f);
1150
+ }
1151
+ vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
1152
+ _mm512_storeu_ps(C + m * ldc, vsum);
1153
+ }
1154
+ }
1155
+ };
1156
+
1157
+ template <bool is_acc>
1158
+ struct acc_C<block_q8_K, block_q4_K, is_acc> {
1159
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1160
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
1161
+ const uint8_t * mins = scales + 8 * TILE_N;
1162
+ const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);
1163
+ const ggml_half * dmin = d0 + TILE_N;
1164
+
1165
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1166
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
1167
+
1168
+ for (int m = 0; m < nr; ++m) {
1169
+ const float d1 = A[m * lda].d;
1170
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1171
+ const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
1172
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1173
+
1174
+ __m512 vsum;
1175
+ if (is_acc) {
1176
+ vsum = _mm512_loadu_ps(C + m * ldc);
1177
+ } else {
1178
+ vsum = _mm512_set1_ps(0.f);
1179
+ }
1180
+
1181
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
1182
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1183
+
1184
+ __m512i acc_m = _mm512_setzero_si512();
1185
+ for (int k = 0; k < 4; ++k) {
1186
+ __m512i vmask = _mm512_set1_epi32(k);
1187
+ __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
1188
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
1189
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1190
+ }
1191
+
1192
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1193
+ vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
1194
+ _mm512_storeu_ps(C + m * ldc, vsum);
1195
+ }
1196
+ }
1197
+ };
1198
+
1199
+ template <bool is_acc>
1200
+ struct acc_C<block_q8_K, block_q5_K, is_acc> {
1201
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1202
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
1203
+ const uint8_t * mins = scales + 8 * TILE_N;
1204
+ const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);
1205
+ const ggml_half * dmin = d0 + TILE_N;
1206
+
1207
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1208
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
1209
+
1210
+ for (int m = 0; m < nr; ++m) {
1211
+ const float d1 = A[m * lda].d;
1212
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1213
+ const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
1214
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1215
+
1216
+ __m512 vsum;
1217
+ if (is_acc) {
1218
+ vsum = _mm512_loadu_ps(C + m * ldc);
1219
+ } else {
1220
+ vsum = _mm512_set1_ps(0.f);
1221
+ }
1222
+
1223
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
1224
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1225
+
1226
+ __m512i acc_m = _mm512_setzero_si512();
1227
+ for (int k = 0; k < 4; ++k) {
1228
+ __m512i vmask = _mm512_set1_epi32(k);
1229
+ __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
1230
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
1231
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1232
+ }
1233
+
1234
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1235
+ vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
1236
+ _mm512_storeu_ps(C + m * ldc, vsum);
1237
+ }
1238
+ }
1239
+ };
1240
+
1241
+ template <bool is_acc>
1242
+ struct acc_C<block_q8_K, block_q6_K, is_acc> {
1243
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1244
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
1245
+ const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 16 * TILE_N);
1246
+
1247
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1248
+
1249
+ for (int m = 0; m < nr; ++m) {
1250
+ const float d1 = A[m * lda].d;
1251
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1252
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1253
+
1254
+ __m512 vsum;
1255
+ if (is_acc) {
1256
+ vsum = _mm512_loadu_ps(C + m * ldc);
1257
+ } else {
1258
+ vsum = _mm512_set1_ps(0.f);
1259
+ }
1260
+
1261
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1262
+ _mm512_storeu_ps(C + m * ldc, vsum);
1263
+ }
1264
+ }
1265
+ };
1266
+
1267
+ template <bool is_acc>
1268
+ struct acc_C<block_q8_K, block_iq4_xs, is_acc> {
1269
+ static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
1270
+ const int8_t * scales = reinterpret_cast<const int8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
1271
+ const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 8 * TILE_N);
1272
+
1273
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
1274
+
1275
+ for (int m = 0; m < nr; ++m) {
1276
+ const float d1 = A[m * lda].d;
1277
+ const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
1278
+ const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
1279
+
1280
+ __m512 vsum;
1281
+ if (is_acc) {
1282
+ vsum = _mm512_loadu_ps(C + m * ldc);
1283
+ } else {
1284
+ vsum = _mm512_set1_ps(0.f);
1285
+ }
1286
+
1287
+ vsum = _mm512_fmadd_ps(vtile, vd, vsum);
1288
+ _mm512_storeu_ps(C + m * ldc, vsum);
1289
+ }
1290
+ }
1291
+ };
1292
+
1293
+ template <typename TB> constexpr int get_quants_size();
1294
+ template <> constexpr int get_quants_size<block_q4_K>() { return (QK_K / 2) * TILE_N; }
1295
+ template <> constexpr int get_quants_size<block_q5_K>() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; }
1296
+ template <> constexpr int get_quants_size<block_q6_K>() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; }
1297
+ template <> constexpr int get_quants_size<block_iq4_xs>() { return (QK_K / 2) * TILE_N; }
1298
+
1299
+ // used for QKK format
1300
+ template <typename TB, bool is_acc,
1301
+ typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
1302
+ inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) {
1303
+ const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + get_quants_size<TB>());
1304
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N)));
1305
+
1306
+ for (int m = 0; m < nr; ++m) {
1307
+ __m512i vsumi;
1308
+ if (is_acc) {
1309
+ vsumi = _mm512_loadu_si512(sumi + m * TILE_N);
1310
+ } else {
1311
+ vsumi = _mm512_setzero_si512();
1312
+ }
1313
+ __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N);
1314
+ vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale));
1315
+ _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi);
1316
+ }
1317
+ }
1318
+
1319
+ template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
1320
+ struct tinygemm_kernel_avx {
1321
+ static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) {
1322
+ GGML_UNUSED(K);
1323
+ GGML_UNUSED(A);
1324
+ GGML_UNUSED(B);
1325
+ GGML_UNUSED(C);
1326
+ GGML_UNUSED(ldc);
1327
+ }
1328
+ };
1329
+
1330
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1331
+ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1332
+ static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) {
1333
+ constexpr int ROWS = BLOCK_M;
1334
+ constexpr int COLS = BLOCK_N;
1335
+ assert(BLOCK_K == 16);
1336
+
1337
+ __m512 va;
1338
+ __m512 vb[COLS];
1339
+ __m512 vc[ROWS * COLS];
1340
+
1341
+ auto loadc = [&](int idx) {
1342
+ vc[idx] = _mm512_setzero_ps();
1343
+ };
1344
+ Unroll<ROWS * COLS>{}(loadc);
1345
+
1346
+ auto compute = [&](int idx, int k) {
1347
+ // TODO: use `constexpr` here to get rid of integer div
1348
+ // when upgraded to C++17
1349
+ const int row = idx / COLS;
1350
+ const int col = idx % COLS;
1351
+
1352
+ if (col == 0) {
1353
+ va = _mm512_loadu_ps(A + row * K + k);
1354
+ }
1355
+ if (row == 0) {
1356
+ vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
1357
+ }
1358
+ vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
1359
+ };
1360
+
1361
+ for (int k = 0; k < K; k += 16) {
1362
+ Unroll<ROWS * COLS>{}(compute, k);
1363
+ }
1364
+
1365
+ auto storec = [&](int idx) {
1366
+ const int row = idx / COLS;
1367
+ const int col = idx % COLS;
1368
+ C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);
1369
+ };
1370
+ Unroll<ROWS * COLS>{}(storec);
1371
+ }
1372
+ };
1373
+
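A scalar reference sketch of what the register-blocked AVX-512 kernel above computes: each accumulator vc[idx] holds the dot product of one row of A against one K-contiguous column of B, reduced into C[row * ldc + col]. The fp16 B operand is replaced by plain float for brevity; sizes and values are illustrative only.

// minimal scalar sketch of the tinygemm_kernel_avx<float, fp16> contract (illustrative)
#include <cstdio>

static void tinygemm_ref(int ROWS, int COLS, int K,
                         const float * A, const float * B, float * C, int ldc) {
    for (int r = 0; r < ROWS; ++r) {
        for (int c = 0; c < COLS; ++c) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k) {
                acc += A[r * K + k] * B[c * K + k];   // both operands are K-contiguous
            }
            C[r * ldc + c] = acc;                     // mirrors the _mm512_reduce_add_ps store
        }
    }
}

int main() {
    const float A[2 * 4] = {1, 2, 3, 4, 5, 6, 7, 8};
    const float B[2 * 4] = {1, 0, 0, 0, 0, 1, 0, 0};
    float C[2 * 2];
    tinygemm_ref(2, 2, 4, A, B, C, 2);
    std::printf("%g %g %g %g\n", C[0], C[1], C[2], C[3]); // 1 2 5 6
    return 0;
}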
1374
+ #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
1375
+ tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \
1376
+ K, (const float *)src1->data + mb_start * K, \
1377
+ (const type *)src0->data + nb_start * K, \
1378
+ (float *)dst->data + mb_start * ldc + nb_start, ldc);
1379
+
1380
+
1381
+ // re-organize in the format {NB, KB, TILE_SIZE}:
1382
+ #define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size
1383
+
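A minimal sketch of the offset arithmetic behind PACKED_INDEX, assuming a packed buffer laid out as {NB, KB, TILE_SIZE} where all K-blocks of one tile column are stored contiguously; the KB and tile-size values below are illustrative only.

// byte offset of tile (n, k) in a {NB, KB, TILE_SIZE} layout (illustrative sketch)
#include <cstddef>
#include <cstdio>

static std::size_t packed_index(std::size_t n, std::size_t k, std::size_t KB, std::size_t tile_size) {
    return (n * KB + k) * tile_size;
}

int main() {
    const std::size_t KB = 8, TILE_SIZE = 512;               // illustrative values
    std::printf("%zu\n", packed_index(2, 3, KB, TILE_SIZE)); // (2 * 8 + 3) * 512 = 9728
    return 0;
}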
1384
+ template<typename TB, int BLOCK_K>
1385
+ void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K, int n_threads) {
1386
+ const int NB = N / TILE_N;
1387
+ const int KB = K / BLOCK_K;
1388
+ const int TILE_SIZE = get_tile_size<TB>();
1389
+
1390
+ // parallel on NB should be enough
1391
+ parallel_for(n_threads, NB, [&](int begin, int end) {
1392
+ for (int n = begin; n < end; ++n) {
1393
+ for (int k = 0; k < KB; ++k) {
1394
+ int n0 = n * TILE_N;
1395
+ pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB);
1396
+ }
1397
+ }
1398
+ });
1399
+ }
1400
+
1401
+ template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
1402
+ struct tinygemm_kernel_vnni {};
1403
+
1404
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1405
+ struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1406
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1407
+
1408
+ constexpr int COLS = BLOCK_N / 16;
1409
+ const int TILE_SIZE = TILE_N * sizeof(block_q4_0);
1410
+
1411
+ const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
1412
+ const char * RESTRICT B = static_cast<const char *>(_B);
1413
+
1414
+ __m512i va[8];
1415
+ __m512 vc[COLS];
1416
+ __m512 vd1;
1417
+
1418
+ // sum of offsets, shared across COLS
1419
+ //
1420
+ // avx512-vnni does not have `_mm512_dpbssd_epi32`,
1421
+ // need to transform ss to us:
1422
+ // a * (b - 8) is equivalent to b * a - 8 * a
1423
+ // s u u u s u s
1424
+ //
1425
+ __m512i vcomp;
1426
+
1427
+ const __m512i off = _mm512_set1_epi8(8);
1428
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1429
+
1430
+ auto loadc = [&](int col) {
1431
+ vc[col] = _mm512_setzero_ps();
1432
+ };
1433
+ Unroll<COLS>{}(loadc);
1434
+
1435
+ auto compute = [&](int col, int i) {
1436
+ // load a and compute compensation
1437
+ if (col == 0) {
1438
+ const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
1439
+ vcomp = _mm512_setzero_si512();
1440
+ for (int k = 0; k < 8; ++k) {
1441
+ va[k] = _mm512_set1_epi32(a_ptr[k]);
1442
+ vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]);
1443
+ }
1444
+ vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
1445
+ }
1446
+
1447
+ // load b
1448
+ __m512i vsum = _mm512_setzero_si512();
1449
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1450
+ for (int k = 0; k < 8; k += 2) {
1451
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
1452
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1453
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]);
1454
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1455
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]);
1456
+ }
1457
+ const int offset = TILE_N * TILE_K / 2;
1458
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
1459
+ vsum = _mm512_sub_epi32(vsum, vcomp);
1460
+
1461
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
1462
+ };
1463
+
1464
+ for (int i = 0; i < KB; ++i) {
1465
+ Unroll<COLS>{}(compute, i);
1466
+ }
1467
+
1468
+ //store to C
1469
+ auto storec = [&](int col) {
1470
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1471
+ };
1472
+ Unroll<COLS>{}(storec);
1473
+ }
1474
+ };
1475
+
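A scalar sketch of the ss-to-us rewrite described in the comment inside the q4_0 kernel above: because q4_0 stores unsigned nibbles with an implied offset of 8, a * (b - 8) is evaluated as b * a - 8 * a, which maps onto the unsigned-by-signed dot product plus the precomputed compensation term (vcomp = 8 * sum(a)); the values are illustrative only.

// scalar form of the u8 x s8 compensation trick (illustrative sketch)
#include <cstdint>
#include <cstdio>

int main() {
    const int8_t  a[4] = {5, -3, 7, -1};   // signed q8 activations (illustrative)
    const uint8_t b[4] = {0, 15, 8, 3};    // unsigned q4 nibbles, true value is b - 8

    int32_t direct = 0, dp = 0, comp = 0;
    for (int k = 0; k < 4; ++k) {
        direct += a[k] * (int32_t(b[k]) - 8);  // what we actually want
        dp     += int32_t(b[k]) * a[k];        // u8 x s8 dot product (dpbusd-style)
        comp   += 8 * a[k];                    // compensation: 8 * sum(a)
    }
    std::printf("%d == %d\n", direct, dp - comp); // both sides match
    return 0;
}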
1476
+ template <int BLOCK_N, int BLOCK_K>
1477
+ struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K> {
1478
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1479
+
1480
+ constexpr int COLS = BLOCK_N / 16;
1481
+ const int TILE_SIZE = TILE_N * sizeof(block_q4_1);
1482
+
1483
+ const block_q8_1 * RESTRICT A = static_cast<const block_q8_1 *>(_A);
1484
+ const char * RESTRICT B = static_cast<const char *>(_B);
1485
+
1486
+ __m512i va[8];
1487
+ __m512i vb[8];
1488
+ __m512 vc[COLS];
1489
+ __m512 vd1, vs1;
1490
+
1491
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1492
+
1493
+ auto loadc = [&](int col) {
1494
+ vc[col] = _mm512_setzero_ps();
1495
+ };
1496
+ Unroll<COLS>{}(loadc);
1497
+
1498
+ auto compute = [&](int col, int i) {
1499
+ // load a
1500
+ if (col == 0) {
1501
+ const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
1502
+ for (int k = 0; k < 8; ++k) {
1503
+ va[k] = _mm512_set1_epi32(a_ptr[k]);
1504
+ }
1505
+ vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
1506
+ vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].s));
1507
+ }
1508
+
1509
+ // load b
1510
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1511
+ for (int k = 0; k < 8; k += 2) {
1512
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
1513
+ vb[k + 0] = _mm512_and_si512(bytes, lowMask);
1514
+ vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1515
+ }
1516
+ const int offset = TILE_N * TILE_K / 2;
1517
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
1518
+ const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half))));
1519
+
1520
+ __m512i vsum = _mm512_setzero_si512();
1521
+ for (int k = 0; k < 8; ++k) {
1522
+ vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]);
1523
+ }
1524
+
1525
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
1526
+ vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]);
1527
+ };
1528
+
1529
+ for (int i = 0; i < KB; ++i) {
1530
+ Unroll<COLS>{}(compute, i);
1531
+ }
1532
+
1533
+ //store to C
1534
+ auto storec = [&](int col) {
1535
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1536
+ };
1537
+ Unroll<COLS>{}(storec);
1538
+ }
1539
+ };
1540
+
1541
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1542
+ struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1543
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1544
+
1545
+ constexpr int COLS = BLOCK_N / 16;
1546
+ const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t);
1547
+
1548
+ const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
1549
+ const char * RESTRICT B = static_cast<const char *>(_B);
1550
+
1551
+ __m512i va[8];
1552
+ __m512i vb[8];
1553
+ __m512 vc[COLS];
1554
+ __m512 vd1;
1555
+
1556
+ // Notes: s8s8 igemm compensation in avx512-vnni
1557
+ // change s8s8 to u8s8 with compensation
1558
+ // a * b = (a + 128) * b - 128 * b
1559
+ // s s u s u s
1560
+ //
1561
+ // (128 * b is pre-computed when packing B to vnni formats)
1562
+ //
1563
+ const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
1564
+
1565
+ auto loadc = [&](int col) {
1566
+ vc[col] = _mm512_setzero_ps();
1567
+ };
1568
+ Unroll<COLS>{}(loadc);
1569
+
1570
+ auto compute = [&](int col, int i) {
1571
+ // load a and add offset 128
1572
+ if (col == 0) {
1573
+ const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
1574
+ for (int k = 0; k < 8; ++k) {
1575
+ va[k] = _mm512_set1_epi32(a_ptr[k]);
1576
+ va[k] = _mm512_add_epi8(va[k], off);
1577
+ }
1578
+ vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
1579
+ }
1580
+
1581
+ // load b
1582
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1583
+ for (int k = 0; k < 8; ++k) {
1584
+ vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64));
1585
+ }
1586
+ const int offset = TILE_N * TILE_K;
1587
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
1588
+ const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);
1589
+ const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2));
1590
+
1591
+ __m512i vsum = _mm512_setzero_si512();
1592
+ for (int k = 0; k < 8; ++k) {
1593
+ vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]);
1594
+ }
1595
+ vsum = _mm512_sub_epi32(vsum, vcomp);
1596
+
1597
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
1598
+ };
1599
+
1600
+ for (int i = 0; i < KB; ++i) {
1601
+ Unroll<COLS>{}(compute, i);
1602
+ }
1603
+
1604
+ //store to C
1605
+ auto storec = [&](int col) {
1606
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1607
+ };
1608
+ Unroll<COLS>{}(storec);
1609
+ }
1610
+ };
1611
+
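A scalar sketch of the s8s8-to-u8s8 compensation noted in the q8_0 x q8_0 kernel above: the signed activation is shifted by +128 so the unsigned-by-signed dot product can be used, and the 128 * sum(b) term (pre-computed when packing B) is subtracted afterwards; values are illustrative only.

// scalar form of the +128 offset compensation (illustrative sketch)
#include <cstdint>
#include <cstdio>

int main() {
    const int8_t a[4] = {-7, 100, -128, 33};   // signed q8 activations (illustrative)
    const int8_t b[4] = {12, -5, 90, -64};     // signed q8 weights

    int32_t direct = 0, dp = 0, comp = 0;
    for (int k = 0; k < 4; ++k) {
        direct += int32_t(a[k]) * b[k];          // what we actually want
        dp     += (int32_t(a[k]) + 128) * b[k];  // u8 x s8 dot product after the shift
        comp   += 128 * b[k];                    // precomputed alongside packed B
    }
    std::printf("%d == %d\n", direct, dp - comp);
    return 0;
}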
1612
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1613
+ struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1614
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1615
+
1616
+ constexpr int COLS = BLOCK_N / 16;
1617
+ const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4;
1618
+
1619
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1620
+ const char * RESTRICT B = static_cast<const char *>(_B);
1621
+
1622
+ // a.qs: 8 groups, 32 bytes each group (m256i)
1623
+ __m512i va[8];
1624
+ // a.bsum: 8 groups, 2 bytes each group (m128i)
1625
+ __m512i va_bsum;
1626
+ __m512 vc[COLS];
1627
+ __m512 vd1;
1628
+
1629
+ // packed_B:
1630
+ const int offset_scales = (QK_K / 2) * TILE_N;
1631
+ const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N;
1632
+ const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N;
1633
+ const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);
1634
+
1635
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1636
+
1637
+ auto loadc = [&](int col) {
1638
+ vc[col] = _mm512_setzero_ps();
1639
+ };
1640
+ Unroll<COLS>{}(loadc);
1641
+
1642
+ // Notes: vnni formats in QK_K
1643
+ // a) quants vnni format
1644
+ // int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32
1645
+ // from {16, 32} to {8, 64}
1646
+ //
1647
+ // b) min vnni format
1648
+ // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8
1649
+ // from {16, 8} to {4, 32}
1650
+ //
1651
+ auto compute = [&](int col, int i) {
1652
+ // load a
1653
+ if (col == 0) {
1654
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1655
+ va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
1656
+ }
1657
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1658
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1659
+ va_bsum = _mm512_castsi128_si512(q8s);
1660
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1661
+ }
1662
+
1663
+ // step 1: accumulate the quants
1664
+ __m512i acc = _mm512_setzero_si512();
1665
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1666
+ const char * b_qs = b_ptr;
1667
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1668
+ __m512i vsum = _mm512_setzero_si512();
1669
+ for (int k = 0; k < 8; k += 2) {
1670
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
1671
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
1672
+
1673
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
1674
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1675
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1676
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1677
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1678
+
1679
+ b_qs += 64;
1680
+ }
1681
+ // vacc += scale * (q8 @ q4)
1682
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
1683
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
1684
+ }
1685
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
1686
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
1687
+
1688
+ // step 2: accumulate the mins
1689
+ __m512i acc_m = _mm512_setzero_si512();
1690
+ for (int k = 0; k < 4; ++k) {
1691
+ __m512i vmask = _mm512_set1_epi32(k);
1692
+ __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
1693
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
1694
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1695
+ }
1696
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
1697
+ vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
1698
+ };
1699
+
1700
+ for (int i = 0; i < KB; ++i) {
1701
+ Unroll<COLS>{}(compute, i);
1702
+ }
1703
+
1704
+ //store to C
1705
+ auto storec = [&](int col) {
1706
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1707
+ };
1708
+ Unroll<COLS>{}(storec);
1709
+ }
1710
+ };
1711
+
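A scalar sketch of the two-term split this Q4_K kernel relies on: with a weight dequantizing as w = d * scale[g] * q - dmin * min[g], the dot product against q8 values (block scale d1) separates into a scale term over the q8 . q4 group dot products and a mins term over the activation block sums; the group count, group size, and values below are illustrative only.

// scalar form of the Q4_K scale/mins decomposition (illustrative sketch)
#include <cstdio>

int main() {
    const int G = 2, GS = 4;                                   // illustrative: 2 groups of 4
    const int   q[G][GS]   = {{1, 7, 3, 0}, {15, 2, 9, 4}};    // 4-bit quants
    const int   y[G][GS]   = {{5, -2, 0, 3}, {-1, 4, 2, -6}};  // q8 activations
    const int   scale[G]   = {3, 5};
    const int   minv[G]    = {2, 7};
    const float d = 0.25f, dmin = 0.125f, d1 = 0.5f;

    float direct = 0.f;
    int   acc = 0, acc_m = 0;
    for (int g = 0; g < G; ++g) {
        int dot = 0, bsum = 0;
        for (int k = 0; k < GS; ++k) {
            direct += d1 * y[g][k] * (d * scale[g] * q[g][k] - dmin * minv[g]);
            dot  += q[g][k] * y[g][k];
            bsum += y[g][k];
        }
        acc   += scale[g] * dot;    // integer part: dpbusd dot products scaled per group
        acc_m += minv[g]  * bsum;   // mins folded against the activation block sums
    }
    const float split = d1 * d * acc - d1 * dmin * acc_m;
    std::printf("%g == %g\n", direct, split);
    return 0;
}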
1712
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1713
+ struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1714
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1715
+
1716
+ constexpr int COLS = BLOCK_N / 16;
1717
+ const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4;
1718
+
1719
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1720
+ const char * RESTRICT B = static_cast<const char *>(_B);
1721
+
1722
+ // a.qs: 8 groups, 32 bytes each group (m256i)
1723
+ __m512i va[8];
1724
+ // a.bsum: 8 groups, 2 bytes each group (m128i)
1725
+ __m512i va_bsum;
1726
+ __m512 vc[COLS];
1727
+ __m512 vd1;
1728
+
1729
+ // packed_B:
1730
+ const int offset_qh = (QK_K / 2) * TILE_N;
1731
+ const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;
1732
+ const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N;
1733
+ const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N;
1734
+ const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);
1735
+
1736
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1737
+
1738
+ auto loadc = [&](int col) {
1739
+ vc[col] = _mm512_setzero_ps();
1740
+ };
1741
+ Unroll<COLS>{}(loadc);
1742
+
1743
+ // Q5_K and Q4_K share the same vnni formats, refer to notes above.
1744
+ auto compute = [&](int col, int i) {
1745
+ // load a
1746
+ if (col == 0) {
1747
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1748
+ va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
1749
+ }
1750
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1751
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1752
+ va_bsum = _mm512_castsi128_si512(q8s);
1753
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1754
+ }
1755
+
1756
+ // step 1: accumulate the quants
1757
+ __m512i acc = _mm512_setzero_si512();
1758
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1759
+ const char * b_qs = b_ptr;
1760
+ const char * b_qh = b_ptr + offset_qh;
1761
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1762
+ __m512i vsum = _mm512_setzero_si512();
1763
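+ // q5_K stores the 5th bit of each quant separately in qh: the masked high bit is shifted down to bit 0 and back up to bit 4 before being added onto the low nibble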
+ __m512i hmask0 = _mm512_set1_epi8(0x1);
1764
+ __m512i hmask1 = _mm512_set1_epi8(0x2);
1765
+ __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64));
1766
+ for (int k = 0; k < 8; k += 2) {
1767
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
1768
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
1769
+
1770
+ __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
1771
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1772
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1773
+
1774
+ __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4);
1775
+ __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4);
1776
+
1777
+ hmask0 = _mm512_slli_epi16(hmask0, 2);
1778
+ hmask1 = _mm512_slli_epi16(hmask1, 2);
1779
+ vb0 = _mm512_add_epi8(vb0, vh0);
1780
+ vb1 = _mm512_add_epi8(vb1, vh1);
1781
+
1782
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1783
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1784
+
1785
+ b_qs += 64;
1786
+ }
1787
+ // vacc += scale * (q8 @ q5)
1788
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
1789
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
1790
+ }
1791
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
1792
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
1793
+
1794
+ // step 2: accumulate the mins
1795
+ __m512i acc_m = _mm512_setzero_si512();
1796
+ for (int k = 0; k < 4; ++k) {
1797
+ __m512i vmask = _mm512_set1_epi32(k);
1798
+ __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
1799
+ __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
1800
+ acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
1801
+ }
1802
+ const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
1803
+ vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
1804
+ };
1805
+
1806
+ for (int i = 0; i < KB; ++i) {
1807
+ Unroll<COLS>{}(compute, i);
1808
+ }
1809
+
1810
+ //store to C
1811
+ auto storec = [&](int col) {
1812
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1813
+ };
1814
+ Unroll<COLS>{}(storec);
1815
+ }
1816
+ };
1817
+
1818
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1819
+ struct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1820
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1821
+
1822
+ constexpr int COLS = BLOCK_N / 16;
1823
+ const int TILE_SIZE = TILE_N * sizeof(block_q6_K);
1824
+
1825
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1826
+ const char * RESTRICT B = static_cast<const char *>(_B);
1827
+
1828
+ // load the 256 bytes from A to 4 avx512 vectors
1829
+ __m512i va[4];
1830
+ __m512 vc[COLS];
1831
+ __m512 vd1;
1832
+
1833
+ // packed_B:
1834
+ const int offset_qh = (QK_K / 2) * TILE_N;
1835
+ const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;
1836
+ const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N;
1837
+
1838
+ // compensation
1839
+ __m512i vcomp;
1840
+
1841
+ const __m512i m32s = _mm512_set1_epi32(32);
1842
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1843
+
1844
+ auto loadc = [&](int col) {
1845
+ vc[col] = _mm512_setzero_ps();
1846
+ };
1847
+ Unroll<COLS>{}(loadc);
1848
+
1849
+ auto compute = [&](int col, int i) {
1850
+ if (col == 0) {
1851
+ // load a
1852
+ va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
1853
+ va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
1854
+ va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
1855
+ va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
1856
+
1857
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1858
+ vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s);
1859
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1860
+ }
1861
+
1862
+ // accumulate the quants
1863
+ __m512i acc = _mm512_setzero_si512();
1864
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1865
+ const char * b_qs = b_ptr;
1866
+ const char * b_qh = b_ptr + offset_qh;
1867
+ int mask = 0;
1868
+ for (int k_group = 0; k_group < QK_K / 16; ++k_group) {
1869
+ int r = k_group >> 2;
1870
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1871
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1872
+
1873
+ __m512i vsum = _mm512_setzero_si512();
1874
+ __m512i hmask = _mm512_set1_epi8(0x3);
1875
+
1876
+ __m512i bytes = _mm512_loadu_si512(b_qs);
1877
+ __m512i hbits = _mm512_loadu_si512(b_qh);
1878
+ __m512i vb0 = _mm512_and_si512(bytes, lowMask);
1879
+ __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1880
+ __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4);
1881
+ __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2);
1882
+
1883
+ vb0 = _mm512_add_epi8(vb0, vh0);
1884
+ vb1 = _mm512_add_epi8(vb1, vh1);
1885
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1886
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1887
+ b_qs += 64;
1888
+
1889
+ va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1890
+ va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1891
+
1892
+ bytes = _mm512_loadu_si512(b_qs);
1893
+ vb0 = _mm512_and_si512(bytes, lowMask);
1894
+ vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
1895
+ vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4));
1896
+ vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2);
1897
+ vb0 = _mm512_add_epi8(vb0, vh0);
1898
+ vb1 = _mm512_add_epi8(vb1, vh1);
1899
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1900
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
1901
+ b_qs += 64;
1902
+ b_qh += 64;
1903
+
1904
+ // B * A - 32 * A
1905
+ __m512i vmask = _mm512_set1_epi32(k_group);
1906
+ vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
1907
+
1908
+ // vacc += scale * (q8 @ q6)
1909
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
1910
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
1911
+ }
1912
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
1913
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
1914
+ };
1915
+
1916
+ for (int i = 0; i < KB; ++i) {
1917
+ Unroll<COLS>{}(compute, i);
1918
+ }
1919
+
1920
+ //store to C
1921
+ auto storec = [&](int col) {
1922
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
1923
+ };
1924
+ Unroll<COLS>{}(storec);
1925
+ }
1926
+ };
1927
+
1928
+ template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
1929
+ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> {
1930
+ static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
1931
+
1932
+ constexpr int COLS = BLOCK_N / 16;
1933
+ const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2;
1934
+
1935
+ const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
1936
+ const char * RESTRICT B = static_cast<const char *>(_B);
1937
+
1938
+ // load the 256 bytes from A to 4 avx512 vectors
1939
+ __m512i va[4];
1940
+ __m512 vc[COLS];
1941
+ __m512 vd1;
1942
+
1943
+ // packed_B:
1944
+ const int offset_scales = (QK_K / 2) * TILE_N;
1945
+ const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N;
1946
+
1947
+ // compensation
1948
+ __m512i vcomp;
1949
+
1950
+ const __m256i m128s = _mm256_set1_epi16(128);
1951
+ const __m512i lowMask = _mm512_set1_epi8(0xF);
1952
+
1953
+ const __m512i values128 = _mm512_set_epi8(
1954
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1955
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1956
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
1957
+ 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
1958
+ );
1959
+ const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
1960
+ const __m512i values256 = _mm512_add_epi8(values128, off);
1961
+
1962
+ auto loadc = [&](int col) {
1963
+ vc[col] = _mm512_setzero_ps();
1964
+ };
1965
+ Unroll<COLS>{}(loadc);
1966
+
1967
+ auto compute = [&](int col, int i) {
1968
+ if (col == 0) {
1969
+ // load a
1970
+ va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
1971
+ va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
1972
+ va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
1973
+ va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
1974
+
1975
+ // compensation: 128 * A
1976
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
1977
+ vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s));
1978
+ vd1 = _mm512_set1_ps(A[0 * KB + i].d);
1979
+ }
1980
+
1981
+ // accumulate the quants
1982
+ __m512i acc = _mm512_setzero_si512();
1983
+ const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
1984
+ const char * b_qs = b_ptr;
1985
+ int mask = 0;
1986
+ for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
1987
+ int r = k_group >> 1;
1988
+ __m512i vmask = _mm512_set1_epi32(k_group);
1989
+ __m512i vsum = _mm512_setzero_si512();
1990
+ for (int k = 0; k < 8; k += 2) {
1991
+ __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1992
+ __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
1993
+
1994
+ __m512i bytes = _mm512_loadu_si512(b_qs);
1995
+ __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask));
1996
+ __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
1997
+
1998
+ vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
1999
+ vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
2000
+ b_qs += 64;
2001
+ }
2002
+ // (B + 128) * A - 128 * A
2003
+ vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
2004
+
2005
+ // vacc += scale * (q8 @ q4)
2006
+ const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
2007
+ acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
2008
+ }
2009
+ const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
2010
+ vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
2011
+ };
2012
+
2013
+ for (int i = 0; i < KB; ++i) {
2014
+ Unroll<COLS>{}(compute, i);
2015
+ }
2016
+
2017
+ //store to C
2018
+ auto storec = [&](int col) {
2019
+ _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
2020
+ };
2021
+ Unroll<COLS>{}(storec);
2022
+ }
2023
+ };
2024
+
2025
+ #define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
2026
+ tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
2027
+ KB, (const char *)wdata + 0 * row_size_A, \
2028
+ (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
2029
+ (float *) dst->data + 0 * N + nb_start, ldc)
2030
+
2031
+ template <typename TA, typename TB, typename TC, int BLOCK_K,
2032
+ typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
2033
+ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) {
2034
+ using packed_B_t = packed_B_type<TB>;
2035
+ const int TILE_SIZE = get_tile_size<TB>();
2036
+ const bool need_unpack = do_unpack<TB>::value;
2037
+
2038
+ GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
2039
+ const TA * RESTRICT A = static_cast<const TA *>(_A);
2040
+ const char * RESTRICT B = static_cast<const char *>(_B);
2041
+
2042
+ const int m0 = std::min(M, TILE_M);
2043
+ const int m1 = std::max(M - TILE_M, 0);
2044
+ const int lda = KB * sizeof(TA);
2045
+ //const int ldb = KB * sizeof(TB);
2046
+
2047
+ static thread_local packed_B_t Tile0[TILE_N * TILE_K];
2048
+ static thread_local packed_B_t Tile1[TILE_N * TILE_K];
2049
+ static thread_local int8_t Tile23[TILE_M * TILE_K];
2050
+
2051
+ static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
2052
+ static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
2053
+
2054
+ // double buffering C to interleave avx512 and amx
2055
+ int32_t * C_cur = TileC0;
2056
+ int32_t * C_pre = TileC1;
2057
+
2058
+ auto Tile4 = [&](int32_t * base) { return base; };
2059
+ auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; };
2060
+ auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; };
2061
+ auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; };
2062
+
2063
+ if (M == 2 * TILE_M) {
2064
+ // i = 0
2065
+ const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE);
2066
+ const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE);
2067
+ if (need_unpack) {
2068
+ unpack_B<TB>(Tile0, B_blk0);
2069
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2070
+ } else {
2071
+ _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
2072
+ }
2073
+
2074
+ _tile_zero(TMM4);
2075
+ _tile_loadd(TMM2, A[0].qs, lda);
2076
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2077
+ _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t));
2078
+
2079
+ _tile_zero(TMM5);
2080
+ _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda);
2081
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2082
+ _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
2083
+
2084
+ if (need_unpack) {
2085
+ unpack_B<TB>(Tile1, B_blk1); // the second B tile comes from B_blk1, matching the else branch and the main loop below
2086
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2087
+ } else {
2088
+ _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
2089
+ }
2090
+
2091
+ _tile_zero(TMM6);
2092
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2093
+ _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t));
2094
+
2095
+ _tile_zero(TMM7);
2096
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2097
+ _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t));
2098
+
2099
+ for (int i = 1; i < KB; ++i) {
2100
+ // index of previous iter
2101
+ const int ii = i - 1;
2102
+ const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
2103
+ const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
2104
+ GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] {
2105
+ if (need_unpack) {
2106
+ unpack_B<TB>(Tile0, B_blk0);
2107
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2108
+ } else {
2109
+ _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
2110
+ }
2111
+ _tile_zero(TMM4);
2112
+ _tile_loadd(TMM2, A[i].qs, lda);
2113
+ acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2114
+
2115
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2116
+ _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
2117
+
2118
+ _tile_zero(TMM5);
2119
+ _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda);
2120
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2121
+
2122
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2123
+ _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
2124
+
2125
+ if (need_unpack) {
2126
+ unpack_B<TB>(Tile1, B_blk1);
2127
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2128
+ } else {
2129
+ _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
2130
+ }
2131
+ _tile_zero(TMM6);
2132
+ acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2133
+
2134
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2135
+ _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
2136
+
2137
+ _tile_zero(TMM7);
2138
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2139
+
2140
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2141
+ _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
2142
+
2143
+ std::swap(C_cur, C_pre);
2144
+ });
2145
+ }
2146
+ // final accumulation
2147
+ {
2148
+ int ii = KB - 1;
2149
+ acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2150
+ acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
2151
+ acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2152
+ acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
2153
+ }
2154
+ } else {
2155
+ for (int i = 0; i < KB; ++i) {
2156
+ _tile_zero(TMM4);
2157
+ _tile_zero(TMM6);
2158
+ if (m1 != 0) {
2159
+ _tile_zero(TMM5);
2160
+ _tile_zero(TMM7);
2161
+ }
2162
+
2163
+ const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
2164
+ const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
2165
+ if (need_unpack) {
2166
+ unpack_B<TB>(Tile0, B_blk0);
2167
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2168
+ } else {
2169
+ _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
2170
+ }
2171
+
2172
+ if (need_unpack) {
2173
+ unpack_B<TB>(Tile1, B_blk1);
2174
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2175
+ } else {
2176
+ _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
2177
+ }
2178
+
2179
+ if (m0 == TILE_M) {
2180
+ _tile_loadd(TMM2, A[i].qs, lda);
2181
+ } else {
2182
+ unpack_A(Tile23, &A[i], KB, m0);
2183
+ _tile_loadd(TMM2, Tile23, TILE_K);
2184
+ }
2185
+
2186
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2187
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2188
+
2189
+ _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
2190
+ _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
2191
+
2192
+ GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
2193
+ acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
2194
+ acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
2195
+ });
2196
+
2197
+ if (m1 != 0) {
2198
+ unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1);
2199
+ _tile_loadd(TMM3, Tile23, TILE_K);
2200
+
2201
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2202
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2203
+ _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
2204
+ _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
2205
+ GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
2206
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
2207
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
2208
+ });
2209
+ }
2210
+ }
2211
+ }
2212
+ return;
2213
+ }
2214
+
2215
+ template <typename TA, typename TB, typename TC, int BLOCK_K,
2216
+ typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
2217
+ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
2218
+ static_assert(std::is_same<TA, block_q8_K>::value);
2219
+ const int TILE_SIZE = get_tile_size<TB>();
2220
+
2221
+ GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
2222
+ const TA * RESTRICT A = static_cast<const TA *>(_A);
2223
+ const char * RESTRICT B = static_cast<const char *>(_B);
2224
+
2225
+ const int m0 = std::min(M, TILE_M);
2226
+ const int m1 = std::max(M - TILE_M, 0);
2227
+ //const int lda = KB * sizeof(TA);
2228
+
2229
+ static thread_local int8_t Tile0[TILE_N * TILE_K];
2230
+ static thread_local int8_t Tile1[TILE_N * TILE_K];
2231
+ static thread_local int8_t Tile23[TILE_M * TILE_K];
2232
+
2233
+ // mat mul result for each group
2234
+ static thread_local int32_t Tile4[TILE_M * TILE_N];
2235
+ static thread_local int32_t Tile5[TILE_M * TILE_N];
2236
+ static thread_local int32_t Tile6[TILE_M * TILE_N];
2237
+ static thread_local int32_t Tile7[TILE_M * TILE_N];
2238
+
2239
+ // sum of each QK_K block, contains 8 groups, int32
2240
+ static thread_local int32_t Sumi4[TILE_M * TILE_N];
2241
+ static thread_local int32_t Sumi5[TILE_M * TILE_N];
2242
+ static thread_local int32_t Sumi6[TILE_M * TILE_N];
2243
+ static thread_local int32_t Sumi7[TILE_M * TILE_N];
2244
+
2245
+ const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;
2246
+ for (int i = 0; i < KB; ++i) {
2247
+ // step 1: accumulate the quants group by group (8 groups of 32, or 16 groups of 16 for q6_K)
2248
+ for (int k = 0; k < QK_K / k_group_size; ++k) {
2249
+ GGML_DISPATCH_BOOL(k > 0, is_acc, [&] {
2250
+ _tile_zero(TMM4);
2251
+ _tile_zero(TMM6);
2252
+
2253
+ unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k);
2254
+ _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
2255
+
2256
+ unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k);
2257
+ _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
2258
+
2259
+ unpack_A<TB>(Tile23, &A[i], KB, k, m0);
2260
+ _tile_loadd(TMM2, Tile23, TILE_K);
2261
+
2262
+ _tile_dpbssd(TMM4, TMM2, TMM0);
2263
+ _tile_dpbssd(TMM6, TMM2, TMM1);
2264
+
2265
+ _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
2266
+ _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
2267
+
2268
+ scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0);
2269
+ scale_C<TB, is_acc>(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0);
2270
+
2271
+ if (m1 != 0) {
2272
+ _tile_zero(TMM5);
2273
+ _tile_zero(TMM7);
2274
+
2275
+ unpack_A<TB>(Tile23, &A[TILE_M * KB + i], KB, k, m1);
2276
+ _tile_loadd(TMM3, Tile23, TILE_K);
2277
+
2278
+ _tile_dpbssd(TMM5, TMM3, TMM0);
2279
+ _tile_dpbssd(TMM7, TMM3, TMM1);
2280
+
2281
+ _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
2282
+ _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
2283
+
2284
+ scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1);
2285
+ scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1);
2286
+ }
2287
+ });
2288
+ }
2289
+
2290
+ // step 2: accumulate the mins
2291
+ GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
2292
+ acc_C<TA, TB, is_acc>::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
2293
+ acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
2294
+ if (m1 != 0) {
2295
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
2296
+ acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
2297
+ }
2298
+ });
2299
+ }
2300
+ return;
2301
+ }
2302
+
2303
+ } // anonymous namespace
2304
+
2305
+ // get the packed tensor size for quantized weights
2306
+ size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) {
2307
+ const enum ggml_type TYPE = tensor->type;
2308
+
2309
+ const int K = tensor->ne[0]; // ne0: in_features
2310
+ const int N = tensor->ne[1]; // ne1: out_features
2311
+
2312
+ auto get_tensor_size = [&] {
2313
+ size_t row_size_B{0};
2314
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2315
+ row_size_B = get_row_size<type, blck_size>(K);
2316
+ });
2317
+ return N * row_size_B;
2318
+ };
2319
+
2320
+ if (qtype_has_amx_kernels(TYPE)) {
2321
+ return get_tensor_size();
2322
+ } else {
2323
+ // for f16, bf16 we don't do packing
2324
+ return ggml_nbytes(tensor);
2325
+ }
2326
+ }
2327
+
2328
+ // pack weight to vnni format
2329
+ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2330
+
2331
+ size_t alloc_size = ggml_backend_amx_get_alloc_size(tensor);
2332
+ GGML_ASSERT(alloc_size == size);
2333
+
2334
+ const enum ggml_type TYPE = tensor->type;
2335
+
2336
+ const int K = tensor->ne[0]; // ne0: in_features
2337
+ const int N = tensor->ne[1]; // ne1: out_features
2338
+
2339
+ #if defined(_OPENMP)
2340
+ // the buffer ctx is not initialized when .set_tensor is called
2341
+ int n_threads = omp_get_num_threads();
2342
+ #else
2343
+ int n_threads = 1;
2344
+ #endif
2345
+
2346
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2347
+ convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K, n_threads);
2348
+ });
2349
+ }
2350
+
2351
+ // NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)
2352
+ //
2353
+ // src0: weight in shape of {N, K}, quantized
2354
+ // src1: input in shape of {M, K}, float32
2355
+ // dst: output in shape of {M, N}, float32
2356
+ //
2357
+ // the function performs: dst = src1 @ src0.T
2358
+ //
2359
+ void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) {
2360
+ struct ggml_tensor * src0 = dst->src[0];
2361
+ struct ggml_tensor * src1 = dst->src[1];
2362
+
2363
+ const enum ggml_type TYPE = src0->type;
2364
+
2365
+ const int n_threads = ctx->n_threads;
2366
+
2367
+ // f16 only has avx512 kernels for now,
2368
+ // amx kernels will be added once 6th gen xeon is released.
2369
+ const bool is_floating_type = TYPE == GGML_TYPE_F16;
2370
+
2371
+ const int M = dst->ne[1];
2372
+ const int N = dst->ne[0];
2373
+ const int K = src0->ne[0];
2374
+ const int ldc = dst->nb[1] / dst->nb[0];
2375
+
2376
+ if (is_floating_type) {
2377
+ constexpr int BLOCK_M = 4;
2378
+ constexpr int BLOCK_N = 6;
2379
+ const int MB = div_up(M, BLOCK_M);
2380
+ const int NB = div_up(N, BLOCK_N);
2381
+
2382
+ parallel_for(n_threads, MB * NB, [&](int begin, int end) {
2383
+ GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
2384
+ for (int i = begin; i < end; ++i) {
2385
+ int mb = i / NB;
2386
+ int nb = i % NB;
2387
+
2388
+ int mb_start = mb * BLOCK_M;
2389
+ int mb_size = std::min(BLOCK_M, M - mb_start);
2390
+ int nb_start = nb * BLOCK_N;
2391
+ int nb_size = std::min(BLOCK_N, N - nb_start);
2392
+
2393
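+ // dispatch key: mb_size in the high nibble, nb_size in the low nibble (e.g. mb_size = 3, nb_size = 4 -> 0x34)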
+ switch (mb_size << 4 | nb_size) {
2394
+ case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break;
2395
+ case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break;
2396
+ case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break;
2397
+ case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break;
2398
+ case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break;
2399
+ case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break;
2400
+ case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break;
2401
+ case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break;
2402
+ case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break;
2403
+ case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break;
2404
+ case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break;
2405
+ case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break;
2406
+ default: fprintf(stderr, "Unexpected block size!\n");
2407
+ }
2408
+ }
2409
+ });
2410
+ });
2411
+ return;
2412
+ }
2413
+
2414
+ // pointer to work space, used to convert A from float to quantized type
2415
+ void * wdata = nullptr;
2416
+
2417
+ //TODO: performance improvement: merge quant A
2418
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2419
+ const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
2420
+ const size_t desired_wsize = M * row_size_A;
2421
+ if (ctx->work_size < desired_wsize) {
2422
+ ctx->work_data.reset(new char[desired_wsize]);
2423
+ ctx->work_size = desired_wsize;
2424
+ }
2425
+ wdata = ctx->work_data.get();
2426
+
2427
+ // Q4_0, Q4_1, Q8_0 handle 1 TILE_K per blck_size
2428
+ // Q4_K, Q5_K, Q6_K, IQ4_XS handle 8 TILE_K per blck_size
2429
+ GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
2430
+
2431
+ const float * A_data = static_cast<const float *>(src1->data);
2432
+ for (int m = 0; m < M; ++m) {
2433
+ from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
2434
+ }
2435
+ });
2436
+
2437
+ if (M == 1) {
2438
+ // MB = 1: handle kTilesN TILE_N-wide tiles in each block
2439
+ constexpr int kTilesN = 4;
2440
+ constexpr int BLOCK_N = TILE_N * kTilesN;
2441
+ const int NB = div_up(N, BLOCK_N);
2442
+
2443
+ parallel_for(n_threads, NB, [&](int begin, int end) {
2444
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2445
+ const int KB = K / blck_size;
2446
+ const int TILE_SIZE = get_tile_size<type>();
2447
+ const int row_size_A = KB * sizeof(vec_dot_type);
2448
+ for (int i = begin; i < end; ++i) {
2449
+ int nb = i;
2450
+ int nb_start = nb * BLOCK_N;
2451
+ int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
2452
+
2453
+ switch (nb_size) {
2454
+ //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break;
2455
+ case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break;
2456
+ case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break;
2457
+ case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break;
2458
+ case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break;
2459
+ default: fprintf(stderr, "Unexpected n block size!\n");
2460
+ }
2461
+ }
2462
+ });
2463
+ });
2464
+ return;
2465
+ }
2466
+
2467
+ // handle 4 tiles at a time
2468
+ constexpr int BLOCK_M = TILE_M * 2;
2469
+ constexpr int BLOCK_N = TILE_N * 2;
2470
+ const int MB = div_up(M, BLOCK_M);
2471
+ const int NB = div_up(N, BLOCK_N);
2472
+
2473
+ parallel_for(n_threads, MB * NB, [&](int begin, int end) {
2474
+ // init tile config for each thread
2475
+ ggml_tile_config_init();
2476
+
2477
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
2478
+ const int KB = K / blck_size;
2479
+ const int TILE_SIZE = get_tile_size<type>();
2480
+ const int row_size_A = KB * sizeof(vec_dot_type);
2481
+
2482
+ for (int i = begin; i < end; ++i) {
2483
+ int mb = i / NB;
2484
+ int nb = i % NB;
2485
+
2486
+ int mb_start = mb * BLOCK_M;
2487
+ int mb_size = std::min(BLOCK_M, M - mb_start);
2488
+ int nb_start = nb * BLOCK_N;
2489
+ int nb_size = BLOCK_N;
2490
+
2491
+ tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
2492
+ mb_size, nb_size, KB,
2493
+ (const char *)wdata + mb_start * row_size_A,
2494
+ (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
2495
+ (float *) dst->data + mb_start * N + nb_start, ldc);
2496
+ }
2497
+ });
2498
+ });
2499
+ }
2500
+
2501
+ #else // if defined(__AMX_INT8__)
2502
+
2503
+ void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) {
2504
+ fprintf(stderr, "GGML is not compiled with AMX support!\n");
2505
+
2506
+ GGML_UNUSED(ctx);
2507
+ GGML_UNUSED(dst);
2508
+ }
2509
+
2510
+ #endif // if defined(__AMX_INT8__)
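A plain-float reference sketch of the semantics described in the header comment of ggml_backend_amx_mul_mat above (dst = src1 @ src0.T, with src0 of shape {N, K}, src1 of shape {M, K}, dst of shape {M, N}), with no quantization or packing; shapes and values are illustrative only.

// scalar reference for the mul_mat contract (illustrative sketch)
#include <cstdio>

static void mul_mat_ref(int M, int N, int K,
                        const float * src0,  // weight, row-major {N, K}
                        const float * src1,  // input,  row-major {M, K}
                        float * dst) {       // output, row-major {M, N}
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k) {
                acc += src1[m * K + k] * src0[n * K + k];   // dst = src1 @ src0.T
            }
            dst[m * N + n] = acc;
        }
    }
}

int main() {
    const float w[2 * 3] = {1, 0, 0, 0, 1, 0};  // N = 2, K = 3
    const float x[1 * 3] = {4, 5, 6};           // M = 1
    float y[1 * 2];
    mul_mat_ref(1, 2, 3, w, x, y);
    std::printf("%g %g\n", y[0], y[1]);         // 4 5
    return 0;
}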