whispercpp 1.3.0 → 1.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
@@ -0,0 +1,2510 @@
|
|
1
|
+
|
2
|
+
#if defined(__GNUC__)
|
3
|
+
#pragma GCC diagnostic ignored "-Wpedantic"
|
4
|
+
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
|
5
|
+
#endif
|
6
|
+
|
7
|
+
#include "mmq.h"
|
8
|
+
#include "ggml-impl.h"
|
9
|
+
#include "ggml-quants.h"
|
10
|
+
#include <algorithm>
|
11
|
+
#include <type_traits>
|
12
|
+
|
13
|
+
#if defined(__gnu_linux__)
|
14
|
+
#include <sys/syscall.h>
|
15
|
+
#include <unistd.h>
|
16
|
+
#endif
|
17
|
+
|
18
|
+
#if defined(_OPENMP)
|
19
|
+
#include <omp.h>
|
20
|
+
#endif
|
21
|
+
|
22
|
+
#if (defined(_WIN32) || defined(_WIN64))
|
23
|
+
#define RESTRICT __restrict
|
24
|
+
#else
|
25
|
+
#define RESTRICT __restrict__
|
26
|
+
#endif
|
27
|
+
|
28
|
+
#if (defined(_WIN32) || defined(_WIN64))
|
29
|
+
#define ALWAYS_INLINE __forceinline
|
30
|
+
#elif __has_attribute(always_inline) || defined(__GNUC__)
|
31
|
+
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
|
32
|
+
#else
|
33
|
+
#define ALWAYS_INLINE inline
|
34
|
+
#endif
|
35
|
+
|
36
|
+
#if defined(__AMX_INT8__)
|
37
|
+
|
38
|
+
namespace {
|
39
|
+
|
40
|
+
// Forced unrolling
|
41
|
+
template <int n>
|
42
|
+
struct Unroll {
|
43
|
+
template <typename Func, typename... Args>
|
44
|
+
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
45
|
+
Unroll<n - 1>{}(f, args...);
|
46
|
+
f(std::integral_constant<int, n - 1>{}, args...);
|
47
|
+
}
|
48
|
+
};
|
49
|
+
|
50
|
+
template <>
|
51
|
+
struct Unroll<1> {
|
52
|
+
template <typename Func, typename... Args>
|
53
|
+
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
54
|
+
f(std::integral_constant<int, 0>{}, args...);
|
55
|
+
}
|
56
|
+
};
|
57
|
+
|
58
|
+
// type traits
|
59
|
+
template <typename T> struct PackedTypes {};
|
60
|
+
template <> struct PackedTypes<block_q4_0> { using type = int8_t; };
|
61
|
+
template <> struct PackedTypes<block_q4_1> { using type = uint8_t; };
|
62
|
+
template <> struct PackedTypes<block_q8_0> { using type = int8_t; };
|
63
|
+
template <typename T> using packed_B_type = typename PackedTypes<T>::type;
|
64
|
+
|
65
|
+
template <typename T>
|
66
|
+
struct do_compensate : std::integral_constant<bool,
|
67
|
+
std::is_same<T, block_q8_0>::value> {};
|
68
|
+
|
69
|
+
template <typename T>
|
70
|
+
struct do_unpack : std::integral_constant<bool,
|
71
|
+
std::is_same<T, block_q4_0>::value ||
|
72
|
+
std::is_same<T, block_q4_1>::value> {};
|
73
|
+
|
74
|
+
template <typename T>
|
75
|
+
struct is_type_qkk : std::integral_constant<bool,
|
76
|
+
std::is_same<T, block_q4_K>::value ||
|
77
|
+
std::is_same<T, block_q5_K>::value ||
|
78
|
+
std::is_same<T, block_q6_K>::value ||
|
79
|
+
std::is_same<T, block_iq4_xs>::value> {};
|
80
|
+
|
81
|
+
// Dispatch on a floating-point ggml type.
//
// Expands to an immediately-invoked lambda that switches on TYPE and calls the
// trailing callable with these local names in scope:
//   GGML_TYPE_F16  -> type = ggml_fp16_t, blck_size = 16
//   GGML_TYPE_BF16 -> type = ggml_bf16_t, blck_size = 32
// Any other value only prints a message to stderr.
#define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...) \
    [&] { \
        switch (TYPE) { \
            case GGML_TYPE_F16: { \
                using type = ggml_fp16_t; \
                constexpr int blck_size = 16; \
                return __VA_ARGS__(); \
            } \
            case GGML_TYPE_BF16: { \
                using type = ggml_bf16_t; \
                constexpr int blck_size = 32; \
                return __VA_ARGS__(); \
            } \
            default: \
                fprintf(stderr, "Unsupported floating data type\n"); \
        } \
    }()
|
98
|
+
|
99
|
+
// Dispatch on a quantized ggml type.
//
// Expands to an immediately-invoked lambda that switches on QT and calls the
// trailing callable with these local names in scope:
//   type         - the weight block struct for QT
//   vec_dot_type - the activation block struct paired with it
//   blck_size    - the number of elements per block
// Any other value only prints a message to stderr.
//
// The default branch formats QT itself. It previously referred to `TYPE`,
// which is not a parameter of this macro, so it compiled only when the caller
// happened to have a variable with that exact name.
#define GGML_DISPATCH_QTYPES(QT, ...) \
    [&] { \
        switch (QT) { \
            case GGML_TYPE_Q4_0: { \
                using type = block_q4_0; \
                using vec_dot_type = block_q8_0; \
                constexpr int blck_size = QK4_0; \
                return __VA_ARGS__(); \
            } \
            case GGML_TYPE_Q4_1: { \
                using type = block_q4_1; \
                using vec_dot_type = block_q8_1; \
                constexpr int blck_size = QK4_1; \
                return __VA_ARGS__(); \
            } \
            case GGML_TYPE_Q8_0: { \
                using type = block_q8_0; \
                using vec_dot_type = block_q8_0; \
                constexpr int blck_size = QK8_0; \
                return __VA_ARGS__(); \
            } \
            case GGML_TYPE_Q4_K: { \
                using type = block_q4_K; \
                using vec_dot_type = block_q8_K; \
                constexpr int blck_size = QK_K; \
                return __VA_ARGS__(); \
            } \
            case GGML_TYPE_Q5_K: { \
                using type = block_q5_K; \
                using vec_dot_type = block_q8_K; \
                constexpr int blck_size = QK_K; \
                return __VA_ARGS__(); \
            } \
            case GGML_TYPE_Q6_K: { \
                using type = block_q6_K; \
                using vec_dot_type = block_q8_K; \
                constexpr int blck_size = QK_K; \
                return __VA_ARGS__(); \
            } \
            case GGML_TYPE_IQ4_XS: { \
                using type = block_iq4_xs; \
                using vec_dot_type = block_q8_K; \
                constexpr int blck_size = QK_K; \
                return __VA_ARGS__(); \
            } \
            default: \
                fprintf(stderr, "Unsupported quantized data type: %d\n", static_cast<int>(QT)); \
        } \
    }()
|
148
|
+
|
149
|
+
// Turn a runtime bool into a compile-time constant.
//
// Expands to an immediately-invoked lambda that calls the trailing callable
// with `constexpr bool BOOL_NAME` in scope, set to true when BOOL_V is true
// and to false otherwise, and returns whatever the callable returns.
#define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
    [&] { \
        if (BOOL_V) { \
            constexpr bool BOOL_NAME = true; \
            return __VA_ARGS__(); \
        } \
        constexpr bool BOOL_NAME = false; \
        return __VA_ARGS__(); \
    }()
|
159
|
+
|
160
|
+
// AMX tile configuration block, passed to _tile_loadconfig() and
// _tile_storeconfig(). All fields default to zero.
struct tile_config_t {
    uint8_t  palette_id     = 0;
    uint8_t  start_row      = 0;
    uint8_t  reserved_0[14] = {0};
    uint16_t colsb[16]      = {0};  // one `colsb` entry per tile slot
    uint8_t  rows[16]       = {0};  // one `rows` entry per tile slot
};

// The tile-config intrinsics read and write a 64-byte block, so catch any
// accidental layout change at compile time.
static_assert(sizeof(tile_config_t) == 64, "tile_config_t must be exactly 64 bytes");
|
168
|
+
|
169
|
+
// Notes: amx tile config
|
170
|
+
//
|
171
|
+
// Typically, TMUL calculates A and B of size 16 x 64 containing INT8 values,
|
172
|
+
// and accumulate the result to a 16 x 16 matrix C containing INT32 values,
|
173
|
+
//
|
174
|
+
// As many GGUF quantized types have a `block_size` of 32, a 16-16-32 config is used
|
175
|
+
// instead of the normally used 16-16-64 config.
|
176
|
+
//
|
177
|
+
// Block A: {16, 32}, dtype = int8_t
|
178
|
+
// Block B: {16, 32}, dtype = uint8_t/int8_t
|
179
|
+
// Block C: {16, 16}, dtype = int32_t
|
180
|
+
//
|
181
|
+
// Block B needs to be prepacked to vnni format before feeding into TMUL:
|
182
|
+
// packed_B: from {n, k} to {k/vnni_blk, n, vnni_blk}, viewed in 2d, we get {8, 64}
|
183
|
+
//
|
184
|
+
// Therefore, we get tileconfig:
|
185
|
+
// A B C
|
186
|
+
// rows 16 8 16
|
187
|
+
// colsb 32 64 64 (bytes per row; C holds 16 int32 columns)
|
188
|
+
//
|
189
|
+
// For tile distribution, follow a 2-2-4 pattern, e.g. A used TMM2-TMM3, B used TMM0-TMM1,
|
190
|
+
// C used TMM4-TMM7:
|
191
|
+
// B TMM0 B TMM1
|
192
|
+
// A TMM2 C TMM4 C TMM6
|
193
|
+
// A TMM3 C TMM5 C TMM7
|
194
|
+
//
|
195
|
+
// Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB, when m < 2 * BLOCK_M, unpack A
|
196
|
+
// will be needed.
|
197
|
+
//
|
198
|
+
// Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16;
|
199
|
+
// and the single batch gemm (m=1) has a special fast path with `avx512-vnni`.
|
200
|
+
//
|
201
|
+
// ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/
|
202
|
+
// advanced-matrix-extensions-intrinsics-functions.html
|
203
|
+
//
|
204
|
+
|
205
|
+
#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
|
206
|
+
void ggml_tile_config_init(void) {
|
207
|
+
static thread_local bool is_first_time = true;
|
208
|
+
|
209
|
+
if (!is_first_time) {
|
210
|
+
return;
|
211
|
+
}
|
212
|
+
|
213
|
+
static thread_local tile_config_t tc;
|
214
|
+
tile_config_t current_tc;
|
215
|
+
_tile_storeconfig(¤t_tc);
|
216
|
+
|
217
|
+
// load only when config changes
|
218
|
+
if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
|
219
|
+
memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
|
220
|
+
tc.palette_id = 1;
|
221
|
+
tc.start_row = 0;
|
222
|
+
TC_CONFIG_TILE(TMM0, 8, 64);
|
223
|
+
TC_CONFIG_TILE(TMM1, 8, 64);
|
224
|
+
TC_CONFIG_TILE(TMM2, 16, 32);
|
225
|
+
TC_CONFIG_TILE(TMM3, 16, 32);
|
226
|
+
TC_CONFIG_TILE(TMM4, 16, 64);
|
227
|
+
TC_CONFIG_TILE(TMM5, 16, 64);
|
228
|
+
TC_CONFIG_TILE(TMM6, 16, 64);
|
229
|
+
TC_CONFIG_TILE(TMM7, 16, 64);
|
230
|
+
_tile_loadconfig(&tc);
|
231
|
+
}
|
232
|
+
|
233
|
+
is_first_time = false;
|
234
|
+
}
|
235
|
+
|
236
|
+
// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
|
237
|
+
// See the notes `s8s8 igemm compensation in avx512-vnni` for detail.
|
238
|
+
template <typename TB>
|
239
|
+
int get_tile_size() {
|
240
|
+
int tile_size = TILE_N * sizeof(TB);
|
241
|
+
if (do_compensate<TB>::value) {
|
242
|
+
tile_size += TILE_N * sizeof(int32_t);
|
243
|
+
}
|
244
|
+
if (std::is_same<TB, block_q4_K>::value ||
|
245
|
+
std::is_same<TB, block_q5_K>::value) {
|
246
|
+
tile_size += TILE_N * 4;
|
247
|
+
}
|
248
|
+
if (std::is_same<TB, block_iq4_xs>::value) {
|
249
|
+
tile_size += TILE_N * 2;
|
250
|
+
}
|
251
|
+
return tile_size;
|
252
|
+
}
|
253
|
+
|
254
|
+
template <typename TB, int BLOCK_K>
|
255
|
+
int get_row_size(int K) {
|
256
|
+
int KB = K / BLOCK_K;
|
257
|
+
int row_size = KB * sizeof(TB);
|
258
|
+
if (do_compensate<TB>::value) {
|
259
|
+
row_size += KB * sizeof(int32_t);
|
260
|
+
}
|
261
|
+
if (std::is_same<TB, block_q4_K>::value ||
|
262
|
+
std::is_same<TB, block_q5_K>::value) {
|
263
|
+
row_size += KB * 4;
|
264
|
+
}
|
265
|
+
if (std::is_same<TB, block_iq4_xs>::value) {
|
266
|
+
row_size += KB * 2;
|
267
|
+
}
|
268
|
+
return row_size;
|
269
|
+
}
|
270
|
+
|
271
|
+
// vectorized dtype conversion
|
272
|
+
inline float FP16_TO_FP32(ggml_half val) {
|
273
|
+
__m256i v = _mm256_setr_epi16(
|
274
|
+
val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
|
275
|
+
__m512 o = _mm512_cvtph_ps(v);
|
276
|
+
return _mm512_cvtss_f32(o);
|
277
|
+
}
|
278
|
+
|
279
|
+
// Convert one half-precision value to float and return it replicated across
// all 16 elements of a __m512.
inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
    const __m256i halves = _mm256_set1_epi16(val);
    const __m512 floats = _mm512_cvtph_ps(halves);
    return floats;
}
|
283
|
+
|
284
|
+
// horizontal reduce
|
285
|
+
inline float _mm512_reduce_max_ps(const __m512 x) {
|
286
|
+
__m512 v = x;
|
287
|
+
__m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
|
288
|
+
v = _mm512_max_ps(v, v1);
|
289
|
+
v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
|
290
|
+
v = _mm512_max_ps(v, v1);
|
291
|
+
v1 = _mm512_shuffle_ps(v, v, 0x4E);
|
292
|
+
v = _mm512_max_ps(v, v1);
|
293
|
+
v1 = _mm512_shuffle_ps(v, v, 0xB1);
|
294
|
+
v = _mm512_max_ps(v, v1);
|
295
|
+
return _mm512_cvtss_f32(v);
|
296
|
+
}
|
297
|
+
|
298
|
+
// transpose utils
|
299
|
+
#define SHUFFLE_EPI32(a, b, mask) \
|
300
|
+
_mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
|
301
|
+
inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) {
|
302
|
+
// unpacking and 32-bit elements
|
303
|
+
v1[0] = _mm256_unpacklo_epi32(v[0], v[1]);
|
304
|
+
v1[1] = _mm256_unpackhi_epi32(v[0], v[1]);
|
305
|
+
v1[2] = _mm256_unpacklo_epi32(v[2], v[3]);
|
306
|
+
v1[3] = _mm256_unpackhi_epi32(v[2], v[3]);
|
307
|
+
v1[4] = _mm256_unpacklo_epi32(v[4], v[5]);
|
308
|
+
v1[5] = _mm256_unpackhi_epi32(v[4], v[5]);
|
309
|
+
v1[6] = _mm256_unpacklo_epi32(v[6], v[7]);
|
310
|
+
v1[7] = _mm256_unpackhi_epi32(v[6], v[7]);
|
311
|
+
|
312
|
+
// shuffling the 32-bit elements
|
313
|
+
v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44);
|
314
|
+
v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee);
|
315
|
+
v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44);
|
316
|
+
v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee);
|
317
|
+
v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44);
|
318
|
+
v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee);
|
319
|
+
v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44);
|
320
|
+
v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee);
|
321
|
+
|
322
|
+
// shuffling 128-bit elements
|
323
|
+
v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02);
|
324
|
+
v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02);
|
325
|
+
v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02);
|
326
|
+
v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02);
|
327
|
+
v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13);
|
328
|
+
v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13);
|
329
|
+
v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13);
|
330
|
+
v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13);
|
331
|
+
}
|
332
|
+
|
333
|
+
// Rearrange the 32-bit elements of the four vectors r[0..3] into d[0..3] in
// three steps: a per-vector element permutation, then two rounds of 128-bit
// lane shuffles. r[0..3] is overwritten with intermediate values.
inline void transpose_16x4_32bit(__m512i * r, __m512i * d) {

    // gathers elements 0,4,8,12 first, then 1,5,9,13, and so on
    static const __m512i gather_stride4 = _mm512_set_epi32(
        0x0f, 0x0b, 0x07, 0x03,
        0x0e, 0x0a, 0x06, 0x02,
        0x0d, 0x09, 0x05, 0x01,
        0x0c, 0x08, 0x04, 0x00);

    for (int i = 0; i < 4; ++i) {
        d[i] = _mm512_permutexvar_epi32(gather_stride4, r[i]);
    }

    for (int p = 0; p < 2; ++p) {
        r[2 * p]     = _mm512_shuffle_i32x4(d[2 * p], d[2 * p + 1], 0x44);
        r[2 * p + 1] = _mm512_shuffle_i32x4(d[2 * p], d[2 * p + 1], 0xee);
    }

    for (int p = 0; p < 2; ++p) {
        d[2 * p]     = _mm512_shuffle_i32x4(r[p], r[p + 2], 0x88);
        d[2 * p + 1] = _mm512_shuffle_i32x4(r[p], r[p + 2], 0xdd);
    }
}
|
356
|
+
|
357
|
+
// In-place transpose of a 16x16 matrix of 32-bit elements held in v[0..15].
// Four stages: 32-bit interleave, 64-bit interleave, then two rounds of
// 128-bit lane shuffles. A local scratch array holds the in-between results.
inline void transpose_16x16_32bit(__m512i * v) {
    __m512i t[16];

    // stage 1: interleave 32-bit elements of each adjacent row pair
    for (int p = 0; p < 8; ++p) {
        t[2 * p]     = _mm512_unpacklo_epi32(v[2 * p], v[2 * p + 1]);
        t[2 * p + 1] = _mm512_unpackhi_epi32(v[2 * p], v[2 * p + 1]);
    }

    // stage 2: interleave 64-bit elements within each group of four rows
    for (int g = 0; g < 4; ++g) {
        const int b = 4 * g;
        v[b + 0] = _mm512_unpacklo_epi64(t[b + 0], t[b + 2]);
        v[b + 1] = _mm512_unpackhi_epi64(t[b + 0], t[b + 2]);
        v[b + 2] = _mm512_unpacklo_epi64(t[b + 1], t[b + 3]);
        v[b + 3] = _mm512_unpackhi_epi64(t[b + 1], t[b + 3]);
    }

    // stage 3: 128-bit lane shuffle within each half of eight rows
    for (int h = 0; h < 2; ++h) {
        const int b = 8 * h;
        for (int i = 0; i < 4; ++i) {
            t[b + i]     = _mm512_shuffle_i32x4(v[b + i], v[b + i + 4], 0x88);
            t[b + i + 4] = _mm512_shuffle_i32x4(v[b + i], v[b + i + 4], 0xdd);
        }
    }

    // stage 4: 128-bit lane shuffle across the two halves
    for (int i = 0; i < 8; ++i) {
        v[i]     = _mm512_shuffle_i32x4(t[i], t[i + 8], 0x88);
        v[i + 8] = _mm512_shuffle_i32x4(t[i], t[i + 8], 0xdd);
    }
}
|
427
|
+
|
428
|
+
void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) {
|
429
|
+
assert(k % QK_K == 0);
|
430
|
+
const int KB = k / QK_K;
|
431
|
+
constexpr int kVecs = QK_K / 16;
|
432
|
+
|
433
|
+
block_q8_K * y = reinterpret_cast<block_q8_K *>(vy);
|
434
|
+
|
435
|
+
// hold 16 float vecs from x
|
436
|
+
__m512 v[kVecs];
|
437
|
+
|
438
|
+
// hold the quants vecs
|
439
|
+
__m512i vq[kVecs / 4];
|
440
|
+
|
441
|
+
// hold the packed quants vecs
|
442
|
+
__m512i vq_packed[kVecs / 4];
|
443
|
+
|
444
|
+
const __m512 signBit = _mm512_set1_ps(-0.f);
|
445
|
+
|
446
|
+
for (int i = 0; i < KB; ++i) {
|
447
|
+
// Compute max(abs(e)) for the block
|
448
|
+
__m512 vamax = _mm512_set1_ps(0.f);
|
449
|
+
for (int j = 0; j < kVecs; ++j) {
|
450
|
+
v[j] = _mm512_loadu_ps(x); x += 16;
|
451
|
+
vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j]));
|
452
|
+
}
|
453
|
+
const float amax = _mm512_reduce_max_ps(vamax);
|
454
|
+
|
455
|
+
// Quantize these floats
|
456
|
+
const float iscale = 127.f / amax;
|
457
|
+
y[i].d = GGML_FP32_TO_FP16(1 / iscale);
|
458
|
+
const float id = ( amax != 0.0f ) ? iscale : 0.f;
|
459
|
+
const __m512 vscale = _mm512_set1_ps(id);
|
460
|
+
|
461
|
+
// Apply multiplier and round to nearest integer
|
462
|
+
for (int j = 0; j < kVecs; ++j) {
|
463
|
+
v[j] = _mm512_mul_ps(v[j], vscale);
|
464
|
+
v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
465
|
+
}
|
466
|
+
|
467
|
+
// Pack to epi8 vecs
|
468
|
+
for (int j = 0; j < kVecs / 4; ++j) {
|
469
|
+
__m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0]));
|
470
|
+
__m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1]));
|
471
|
+
__m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2]));
|
472
|
+
__m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3]));
|
473
|
+
|
474
|
+
__m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1);
|
475
|
+
__m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1);
|
476
|
+
|
477
|
+
vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1);
|
478
|
+
_mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]);
|
479
|
+
}
|
480
|
+
|
481
|
+
// Compute the bsums with vnni
|
482
|
+
transpose_16x4_32bit(vq, vq_packed);
|
483
|
+
|
484
|
+
const __m512i one = _mm512_set1_epi8(1);
|
485
|
+
__m512i sum = _mm512_setzero_si512();
|
486
|
+
for (int k = 0; k < 4; ++k) {
|
487
|
+
sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]);
|
488
|
+
}
|
489
|
+
_mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum));
|
490
|
+
}
|
491
|
+
}
|
492
|
+
|
493
|
+
// quantize A from float to `vec_dot_type`
|
494
|
+
template <typename T>
|
495
|
+
inline void from_float(const float * x, char * vy, int64_t k);
|
496
|
+
|
497
|
+
template <>
|
498
|
+
inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) {
|
499
|
+
// FIXME: using unoptimized reference impl until moved to CPU backend
|
500
|
+
quantize_row_q8_0_ref(x, (block_q8_0 *)vy, k);
|
501
|
+
}
|
502
|
+
|
503
|
+
template <>
|
504
|
+
inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) {
|
505
|
+
quantize_row_q8_1_ref(x, (block_q8_1 *)vy, k);
|
506
|
+
}
|
507
|
+
|
508
|
+
template <>
|
509
|
+
inline void from_float<block_q8_K>(const float * x, char * vy, int64_t k) {
|
510
|
+
#if 1
|
511
|
+
// TODO: this is reference impl!
|
512
|
+
quantize_row_q8_K_ref(x, (block_q8_K *)vy, k);
|
513
|
+
#else
|
514
|
+
quantize_row_q8_K_vnni(x, vy, k);
|
515
|
+
#endif
|
516
|
+
}
|
517
|
+
|
518
|
+
// load A from memory to array when nrows can not fill in whole tile
|
519
|
+
void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) {
|
520
|
+
assert(nr != TILE_M);
|
521
|
+
for (int m = 0; m < nr; ++m) {
|
522
|
+
const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
|
523
|
+
_mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
|
524
|
+
}
|
525
|
+
}
|
526
|
+
|
527
|
+
void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) {
|
528
|
+
assert(nr != TILE_M);
|
529
|
+
for (int m = 0; m < nr; ++m) {
|
530
|
+
const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
|
531
|
+
_mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
|
532
|
+
}
|
533
|
+
}
|
534
|
+
|
535
|
+
template <typename TB>
|
536
|
+
void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
|
537
|
+
assert(nr <= TILE_M);
|
538
|
+
for (int m = 0; m < nr; ++m) {
|
539
|
+
const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32));
|
540
|
+
_mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
|
541
|
+
}
|
542
|
+
}
|
543
|
+
|
544
|
+
template <>
|
545
|
+
void unpack_A<block_q6_K>(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
|
546
|
+
assert(nr <= TILE_M);
|
547
|
+
// zero padding k from 16 to 32, so that we don't have to re-config amx
|
548
|
+
const __m128i zero = _mm_setzero_si128();
|
549
|
+
for (int m = 0; m < nr; ++m) {
|
550
|
+
const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16));
|
551
|
+
const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1);
|
552
|
+
_mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r);
|
553
|
+
}
|
554
|
+
}
|
555
|
+
|
556
|
+
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
557
|
+
inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
|
558
|
+
const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
|
559
|
+
const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
|
560
|
+
const __m256i lowMask = _mm256_set1_epi8(0xF);
|
561
|
+
return _mm256_and_si256(lowMask, bytes);
|
562
|
+
}
|
563
|
+
|
564
|
+
// used for block_q4_K
|
565
|
+
inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) {
|
566
|
+
const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi);
|
567
|
+
const __m256i lowMask = _mm256_set1_epi8(0xF);
|
568
|
+
const __m256i q4l = _mm256_and_si256(tmp, lowMask);
|
569
|
+
const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask);
|
570
|
+
return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1);
|
571
|
+
}
|
572
|
+
|
573
|
+
// used for block_q5_K
|
574
|
+
inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) {
|
575
|
+
const __m256i lowMask = _mm256_set1_epi8(0xF);
|
576
|
+
__m256i hmask = _mm256_set1_epi8(1);
|
577
|
+
hmask = _mm256_slli_epi16(hmask, k);
|
578
|
+
|
579
|
+
const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs);
|
580
|
+
const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh);
|
581
|
+
|
582
|
+
const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask);
|
583
|
+
const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4);
|
584
|
+
const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
|
585
|
+
hmask = _mm256_slli_epi16(hmask, 1);
|
586
|
+
|
587
|
+
const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask);
|
588
|
+
const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4);
|
589
|
+
const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
|
590
|
+
|
591
|
+
return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1);
|
592
|
+
}
|
593
|
+
|
594
|
+
// used for block_q6_K
//
// Expands 64 bytes of packed nibbles (qs) and 32 bytes of 2-bit fields (qh)
// into 128 bytes, written to r0 and r1:
//   r0 = { low nibbles of qs[0..31]   | qh bits 0-1,  low nibbles of qs[32..63]  | qh bits 2-3 }
//   r1 = { high nibbles of qs[0..31]  | qh bits 4-5,  high nibbles of qs[32..63] | qh bits 6-7 }
// Each 2-bit field is placed at bits 4..5 of the corresponding output byte.
inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) {
    const __m256i mask4 = _mm256_set1_epi8(0xF);
    const __m256i mask2 = _mm256_set1_epi8(0x3);

    const __m256i low_a = _mm256_loadu_si256((const __m256i *)qs);
    const __m256i low_b = _mm256_loadu_si256((const __m256i *)(qs + 32));
    const __m256i high  = _mm256_loadu_si256((const __m256i *)qh);

    // isolate each 2-bit field of qh and move it up to bits 4..5
    const __m256i top0 = _mm256_slli_epi16(_mm256_and_si256(high, mask2), 4);
    const __m256i top1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(high, 2), mask2), 4);
    const __m256i top2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(high, 4), mask2), 4);
    const __m256i top3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(high, 6), mask2), 4);

    const __m256i grp0 = _mm256_or_si256(_mm256_and_si256(low_a, mask4), top0);
    const __m256i grp1 = _mm256_or_si256(_mm256_and_si256(low_b, mask4), top1);
    const __m256i grp2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(low_a, 4), mask4), top2);
    const __m256i grp3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(low_b, 4), mask4), top3);

    r0 = _mm512_inserti32x8(_mm512_castsi256_si512(grp0), grp1, 1);
    r1 = _mm512_inserti32x8(_mm512_castsi256_si512(grp2), grp3, 1);
}
|
616
|
+
|
617
|
+
// Combines two vectors into one: r0 OR-ed with r1 shifted left by 4 bits
// (the shift operates on 16-bit lanes).
// NOTE(review): presumably both inputs hold values below 16 in every byte so
// that each output byte carries exactly two nibbles — confirm with callers.
inline __m512i packNibbles(__m512i r0, __m512i r1) {
    const __m512i upper = _mm512_slli_epi16(r1, 4);
    return _mm512_or_si512(r0, upper);
}
|
620
|
+
|
621
|
+
template <typename TB>
|
622
|
+
inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) {
|
623
|
+
int8_t tmp[8 * 64];
|
624
|
+
__m256i v[8], v2[8];
|
625
|
+
for (int n = 0; n < 8; ++n) {
|
626
|
+
v[n] = bytes_from_nibbles_32(B[n * KB].qs);
|
627
|
+
}
|
628
|
+
transpose_8x8_32bit(v, v2);
|
629
|
+
for (int n = 0; n < 8; ++n) {
|
630
|
+
_mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]);
|
631
|
+
}
|
632
|
+
for (int n = 0; n < 8; ++n) {
|
633
|
+
v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs);
|
634
|
+
}
|
635
|
+
transpose_8x8_32bit(v, v2);
|
636
|
+
for (int n = 0; n < 8; ++n) {
|
637
|
+
_mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]);
|
638
|
+
}
|
639
|
+
|
640
|
+
// pack again with 128 to fully utilize vector length
|
641
|
+
for (int n = 0; n < 8; n += 2) {
|
642
|
+
__m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64));
|
643
|
+
__m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64));
|
644
|
+
__m512i r1r0 = packNibbles(r0, r1);
|
645
|
+
_mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0);
|
646
|
+
}
|
647
|
+
}
|
648
|
+
|
649
|
+
template <>
|
650
|
+
inline void pack_qs<block_q8_0>(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
|
651
|
+
__m256i v[8], v2[8];
|
652
|
+
for (int n = 0; n < 8; ++n) {
|
653
|
+
v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs));
|
654
|
+
}
|
655
|
+
transpose_8x8_32bit(v, v2);
|
656
|
+
for (int n = 0; n < 8; ++n) {
|
657
|
+
_mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]);
|
658
|
+
}
|
659
|
+
for (int n = 0; n < 8; ++n) {
|
660
|
+
v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs));
|
661
|
+
}
|
662
|
+
transpose_8x8_32bit(v, v2);
|
663
|
+
for (int n = 0; n < 8; ++n) {
|
664
|
+
_mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]);
|
665
|
+
}
|
666
|
+
}
|
667
|
+
|
668
|
+
template <>
|
669
|
+
inline void pack_qs<block_q4_K>(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
|
670
|
+
__m512i v[16];
|
671
|
+
// QK_K 256 with 8 groups, handle 2 groups at a time
|
672
|
+
char * pb = (char *)packed_B;
|
673
|
+
for (int k = 0; k < QK_K / 64; ++k) {
|
674
|
+
// pack 2 groups { n, g, k} to {g, k/4, 4n}
|
675
|
+
// e.g. {16, 2, 32} to {2, 8, 64}
|
676
|
+
for (int n = 0; n < TILE_N; ++n) {
|
677
|
+
v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32);
|
678
|
+
}
|
679
|
+
|
680
|
+
transpose_16x16_32bit(v);
|
681
|
+
|
682
|
+
// pack again with 128 to fully utilize vector length
|
683
|
+
for (int n = 0; n < TILE_N; n += 2) {
|
684
|
+
_mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
|
685
|
+
pb += 64;
|
686
|
+
}
|
687
|
+
}
|
688
|
+
}
|
689
|
+
|
690
|
+
template <>
|
691
|
+
inline void pack_qs<block_q5_K>(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
|
692
|
+
__m512i v[16];
|
693
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
694
|
+
// QK_K 256 with 8 groups, handle 2 groups at a time
|
695
|
+
char * pb = (char *)packed_B;
|
696
|
+
char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
|
697
|
+
for (int k = 0; k < QK_K / 64; ++k) {
|
698
|
+
// pack 2 groups { n, g, k} to {g, k/4, 4n}
|
699
|
+
// e.g. {16, 2, 32} to {2, 8, 64}
|
700
|
+
for (int n = 0; n < TILE_N; ++n) {
|
701
|
+
v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k);
|
702
|
+
}
|
703
|
+
|
704
|
+
transpose_16x16_32bit(v);
|
705
|
+
|
706
|
+
// 1. pack lower 4bits with 2 groups
|
707
|
+
for (int n = 0; n < TILE_N; n += 2) {
|
708
|
+
// get lower 4 bits
|
709
|
+
const __m512i r0 = _mm512_and_si512(v[n], lowMask);
|
710
|
+
const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
|
711
|
+
_mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
|
712
|
+
}
|
713
|
+
|
714
|
+
// 2. pack higher 1bit with 2 groups
|
715
|
+
const __m512i hmask = _mm512_set1_epi8(0x10);
|
716
|
+
for (int g = 0; g < 2; ++g) {
|
717
|
+
__m512i hbits = _mm512_setzero_si512();
|
718
|
+
hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4));
|
719
|
+
hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3));
|
720
|
+
hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2));
|
721
|
+
hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1));
|
722
|
+
hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) );
|
723
|
+
hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1));
|
724
|
+
hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2));
|
725
|
+
hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3));
|
726
|
+
_mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
|
727
|
+
}
|
728
|
+
}
|
729
|
+
}
|
730
|
+
|
731
|
+
template <>
|
732
|
+
inline void pack_qs<block_q6_K>(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
|
733
|
+
__m512i v[32];
|
734
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
735
|
+
// QK_K 256 with 8 groups, handle 4 groups at a time
|
736
|
+
char * pb = (char *)packed_B;
|
737
|
+
char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
|
738
|
+
for (int k = 0; k < QK_K / 128; ++k) {
|
739
|
+
for (int n = 0; n < TILE_N; ++n) {
|
740
|
+
bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32);
|
741
|
+
}
|
742
|
+
|
743
|
+
// top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7
|
744
|
+
transpose_16x16_32bit(v);
|
745
|
+
transpose_16x16_32bit(v + 16);
|
746
|
+
|
747
|
+
// 1. pack lower 4bits with 4 groups
|
748
|
+
for (int n = 0; n < 32; n += 2) {
|
749
|
+
const __m512i r0 = _mm512_and_si512(v[n], lowMask);
|
750
|
+
const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
|
751
|
+
_mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
|
752
|
+
}
|
753
|
+
|
754
|
+
// 2. pack higher 2bit with 4 groups
|
755
|
+
const __m512i hmask = _mm512_set1_epi8(0x30);
|
756
|
+
for (int g = 0; g < 8; ++g) {
|
757
|
+
__m512i hbits = _mm512_setzero_si512();
|
758
|
+
hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4));
|
759
|
+
hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2));
|
760
|
+
hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) );
|
761
|
+
hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2));
|
762
|
+
_mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
|
763
|
+
}
|
764
|
+
}
|
765
|
+
}
|
766
|
+
|
767
|
+
template <>
|
768
|
+
inline void pack_qs<block_iq4_xs>(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
|
769
|
+
__m512i v[16];
|
770
|
+
char * pb = (char *)packed_B;
|
771
|
+
for (int k = 0; k < QK_K / 64; ++k) {
|
772
|
+
for (int n = 0; n < TILE_N; ++n) {
|
773
|
+
__m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0);
|
774
|
+
__m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16);
|
775
|
+
v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1);
|
776
|
+
}
|
777
|
+
|
778
|
+
transpose_16x16_32bit(v);
|
779
|
+
|
780
|
+
// pack again with 128 to fully utilize vector length
|
781
|
+
for (int n = 0; n < TILE_N; n += 2) {
|
782
|
+
_mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
|
783
|
+
pb += 64;
|
784
|
+
}
|
785
|
+
}
|
786
|
+
}
|
787
|
+
|
788
|
+
// pack B to vnni formats in 4bits or 8 bits
|
789
|
+
void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) {
|
790
|
+
pack_qs(packed_B, B, KB);
|
791
|
+
ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
|
792
|
+
for (int n = 0; n < TILE_N; ++n) {
|
793
|
+
d0[n] = B[n * KB].d;
|
794
|
+
}
|
795
|
+
}
|
796
|
+
|
797
|
+
void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) {
|
798
|
+
pack_qs(packed_B, B, KB);
|
799
|
+
ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
|
800
|
+
ggml_half * m0 = d0 + TILE_N;
|
801
|
+
for (int n = 0; n < TILE_N; ++n) {
|
802
|
+
d0[n] = B[n * KB].d;
|
803
|
+
m0[n] = B[n * KB].m;
|
804
|
+
}
|
805
|
+
}
|
806
|
+
|
807
|
+
inline void s8s8_compensation(void * RESTRICT packed_B) {
|
808
|
+
// packed_B layout:
|
809
|
+
// quants {TILE_N, TILEK} int8_t
|
810
|
+
// d0 {TILE_N} ggml_half
|
811
|
+
// comp {TILE_N} int32_t
|
812
|
+
const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);
|
813
|
+
__m512i vcomp = _mm512_setzero_si512();
|
814
|
+
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
|
815
|
+
for (int k = 0; k < 8; ++k) {
|
816
|
+
__m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64));
|
817
|
+
vcomp = _mm512_dpbusd_epi32(vcomp, off, vb);
|
818
|
+
}
|
819
|
+
_mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp);
|
820
|
+
}
|
821
|
+
|
822
|
+
void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
|
823
|
+
pack_qs(packed_B, B, KB);
|
824
|
+
ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K);
|
825
|
+
for (int n = 0; n < TILE_N; ++n) {
|
826
|
+
d0[n] = B[n * KB].d;
|
827
|
+
}
|
828
|
+
s8s8_compensation(packed_B);
|
829
|
+
}
|
830
|
+
|
831
|
+
// convert 8 * {min, scale} from int6 to int8
//
// scales : 12 input bytes holding sixteen packed 6-bit values
// utmp   : 4 output words (16 bytes), one value per byte:
//   bytes  0..3  = low 6 bits of input bytes 0..3
//   bytes  4..7  = low nibble of input bytes 8..11, plus the top 2 bits of
//                  input bytes 0..3 placed at bits 4..5
//   bytes  8..11 = low 6 bits of input bytes 4..7
//   bytes 12..15 = high nibble of input bytes 8..11, plus the top 2 bits of
//                  input bytes 4..7 placed at bits 4..5
inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) {
    const uint32_t low6 = 0x3f3f3f3f;
    const uint32_t low4 = 0x0f0f0f0f;
    const uint32_t low2 = 0x03030303;

    // read the three packed words up front so outputs can be written in order
    uint32_t packed[3];
    memcpy(packed, scales, sizeof(packed));

    utmp[0] = packed[0] & low6;
    utmp[1] = (packed[2] & low4) | (((packed[0] >> 6) & low2) << 4);
    utmp[2] = packed[1] & low6;
    utmp[3] = ((packed[2] >> 4) & low4) | (((packed[1] >> 6) & low2) << 4);
}
|
844
|
+
|
845
|
+
// packed_B layout:
|
846
|
+
// quants {8, TILE_N, 16} uint8
|
847
|
+
// scales {8, TILE_N} uint8
|
848
|
+
// mins {8, TILE_N} uint8
|
849
|
+
// d {TILE_N} ggml_half
|
850
|
+
// dmin {TILE_N} ggml_half
|
851
|
+
void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
|
852
|
+
pack_qs(packed_B, B, KB);
|
853
|
+
|
854
|
+
uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
|
855
|
+
uint8_t * mins = scales + 8 * TILE_N;
|
856
|
+
ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);
|
857
|
+
ggml_half * dmin = d + TILE_N;
|
858
|
+
|
859
|
+
union {
|
860
|
+
uint32_t u32[4];
|
861
|
+
uint8_t u8[16];
|
862
|
+
} s;
|
863
|
+
|
864
|
+
for (int n = 0; n < TILE_N; ++n) {
|
865
|
+
unpack_mins_and_scales(B[n * KB].scales, s.u32);
|
866
|
+
for (int k = 0; k < 8; ++k) {
|
867
|
+
scales[k * TILE_N + n] = s.u8[k];
|
868
|
+
mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
|
869
|
+
}
|
870
|
+
d[n] = B[n * KB].d;
|
871
|
+
dmin[n] = B[n * KB].dmin;
|
872
|
+
}
|
873
|
+
}
|
874
|
+
|
875
|
+
// packed_B layout:
|
876
|
+
// quants {8, TILE_N, 16} uint8
|
877
|
+
// qh {8, TILE_N, 4} uint8
|
878
|
+
// scales {8, TILE_N} uint8
|
879
|
+
// mins {8, TILE_N} uint8
|
880
|
+
// d {TILE_N} ggml_half
|
881
|
+
// dmin {TILE_N} ggml_half
|
882
|
+
void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
|
883
|
+
pack_qs(packed_B, B, KB);
|
884
|
+
|
885
|
+
uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
|
886
|
+
uint8_t * mins = scales + 8 * TILE_N;
|
887
|
+
ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);
|
888
|
+
ggml_half * dmin = d + TILE_N;
|
889
|
+
|
890
|
+
union {
|
891
|
+
uint32_t u32[4];
|
892
|
+
uint8_t u8[16];
|
893
|
+
} s;
|
894
|
+
|
895
|
+
for (int n = 0; n < TILE_N; ++n) {
|
896
|
+
unpack_mins_and_scales(B[n * KB].scales, s.u32);
|
897
|
+
for (int k = 0; k < 8; ++k) {
|
898
|
+
scales[k * TILE_N + n] = s.u8[k];
|
899
|
+
mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
|
900
|
+
}
|
901
|
+
d[n] = B[n * KB].d;
|
902
|
+
dmin[n] = B[n * KB].dmin;
|
903
|
+
}
|
904
|
+
}
|
905
|
+
|
906
|
+
// packed_B layout:
|
907
|
+
// quants {16, TILE_N, 8} uint8
|
908
|
+
// qh {16, TILE_N, 4} uint8
|
909
|
+
// scales {16, TILE_N} uint8
|
910
|
+
// d {TILE_N} ggml_half
|
911
|
+
void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
|
912
|
+
pack_qs(packed_B, B, KB);
|
913
|
+
|
914
|
+
uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
|
915
|
+
ggml_half * d = reinterpret_cast<ggml_half *>(scales + 16 * TILE_N);
|
916
|
+
for (int n = 0; n < TILE_N; ++n) {
|
917
|
+
const int8_t * ps = B[n * KB].scales;
|
918
|
+
for (int k = 0; k < 16; ++k) {
|
919
|
+
scales[k * TILE_N + n] = ps[k];
|
920
|
+
}
|
921
|
+
d[n] = B[n * KB].d;
|
922
|
+
}
|
923
|
+
}
|
924
|
+
|
925
|
+
// packed_B layout:
|
926
|
+
// quants {8, TILE_N, 16} uint8
|
927
|
+
// scales {8, TILE_N} int8
|
928
|
+
// d {TILE_N} ggml_half
|
929
|
+
void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
|
930
|
+
pack_qs(packed_B, B, KB);
|
931
|
+
|
932
|
+
int8_t * scales = reinterpret_cast<int8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
|
933
|
+
ggml_half * d = reinterpret_cast<ggml_half *>(scales + 8 * TILE_N);
|
934
|
+
|
935
|
+
// pack the scales
|
936
|
+
for (int n = 0; n < TILE_N; ++n) {
|
937
|
+
uint16_t sh = B[n * KB].scales_h;
|
938
|
+
for (int k = 0; k < 8; k += 2) {
|
939
|
+
const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32;
|
940
|
+
const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32;
|
941
|
+
scales[(k + 0) * TILE_N + n] = ls1;
|
942
|
+
scales[(k + 1) * TILE_N + n] = ls2;
|
943
|
+
sh >>= 4;
|
944
|
+
}
|
945
|
+
d[n] = B[n * KB].d;
|
946
|
+
}
|
947
|
+
}
|
948
|
+
|
949
|
+
template<typename TB, typename packed_B_t = packed_B_type<TB>>
|
950
|
+
void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) {
|
951
|
+
GGML_UNUSED(tile);
|
952
|
+
GGML_UNUSED(packed_B);
|
953
|
+
};
|
954
|
+
|
955
|
+
template <>
|
956
|
+
void unpack_B<block_q4_0>(int8_t * RESTRICT tile, const void * RESTRICT packed_B) {
|
957
|
+
const __m512i off = _mm512_set1_epi8(8);
|
958
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
959
|
+
for (int n = 0; n < 8; n += 2) {
|
960
|
+
__m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
|
961
|
+
const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off);
|
962
|
+
const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off);
|
963
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
|
964
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
|
965
|
+
}
|
966
|
+
}
|
967
|
+
|
968
|
+
template <>
|
969
|
+
void unpack_B<block_q4_1>(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) {
|
970
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
971
|
+
for (int n = 0; n < 8; n += 2) {
|
972
|
+
__m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
|
973
|
+
const __m512i r0 = _mm512_and_si512(bytes, lowMask);
|
974
|
+
const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
975
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
|
976
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
|
977
|
+
}
|
978
|
+
}
|
979
|
+
|
980
|
+
// packed_B_t for QKK is int8_t
|
981
|
+
template <typename TB>
|
982
|
+
void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
|
983
|
+
const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
|
984
|
+
const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size;
|
985
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
986
|
+
for (int n = 0; n < 8; n += 2) {
|
987
|
+
__m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32);
|
988
|
+
const __m512i r0 = _mm512_and_si512(bytes, lowMask);
|
989
|
+
const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
990
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
|
991
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
|
992
|
+
}
|
993
|
+
}
|
994
|
+
|
995
|
+
template <>
|
996
|
+
void unpack_B<block_q5_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
|
997
|
+
// lower 4bits, stride 256 bytes
|
998
|
+
const int packed_l4_group_size = QK_K / 2 * TILE_N / 8;
|
999
|
+
const char * pb = (const char *)packed_B + k * packed_l4_group_size;
|
1000
|
+
|
1001
|
+
// higher 1bit, stride 64 bytes
|
1002
|
+
const int packed_h1_group_size = QK_K / 8 * TILE_N / 8;
|
1003
|
+
const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size;
|
1004
|
+
const __m512i hbits = _mm512_loadu_si512(ph);
|
1005
|
+
|
1006
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
1007
|
+
__m512i hmask0 = _mm512_set1_epi8(0x1);
|
1008
|
+
__m512i hmask1 = _mm512_set1_epi8(0x2);
|
1009
|
+
|
1010
|
+
for (int n = 0; n < 8; n += 2) {
|
1011
|
+
__m512i bytes = _mm512_loadu_si512(pb + n * 32);
|
1012
|
+
__m512i r0 = _mm512_and_si512(bytes, lowMask);
|
1013
|
+
__m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
1014
|
+
__m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4);
|
1015
|
+
__m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4);
|
1016
|
+
|
1017
|
+
hmask0 = _mm512_slli_epi16(hmask0, 2);
|
1018
|
+
hmask1 = _mm512_slli_epi16(hmask1, 2);
|
1019
|
+
r0 = _mm512_add_epi8(r0, h0);
|
1020
|
+
r1 = _mm512_add_epi8(r1, h1);
|
1021
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
|
1022
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
|
1023
|
+
}
|
1024
|
+
}
|
1025
|
+
|
1026
|
+
template <>
|
1027
|
+
void unpack_B<block_q6_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
|
1028
|
+
// lower 4bits, stride 128 bytes
|
1029
|
+
const int packed_l4_group_size = QK_K / 2 * TILE_N / 16;
|
1030
|
+
const char * pb = (const char *)packed_B + k * packed_l4_group_size;
|
1031
|
+
|
1032
|
+
// higher 2bits, stride 64 bytes
|
1033
|
+
const int packed_h2_group_size = QK_K / 4 * TILE_N / 16;
|
1034
|
+
const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size;
|
1035
|
+
const __m512i hbits = _mm512_loadu_si512(ph);
|
1036
|
+
|
1037
|
+
const __m512i off = _mm512_set1_epi8(32);
|
1038
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
1039
|
+
__m512i hmask0 = _mm512_set1_epi8(0x3); // 0011
|
1040
|
+
__m512i hmask1 = _mm512_set1_epi8(0xC); // 1100
|
1041
|
+
|
1042
|
+
// notes: skip zero padding from row4 to row7 as we have done so in `unpack_A`
|
1043
|
+
__m512i bytes = _mm512_loadu_si512(pb);
|
1044
|
+
__m512i r0 = _mm512_and_si512(bytes, lowMask);
|
1045
|
+
__m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
1046
|
+
__m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4);
|
1047
|
+
__m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2);
|
1048
|
+
_mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
|
1049
|
+
_mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
|
1050
|
+
|
1051
|
+
hmask0 = _mm512_slli_epi16(hmask0, 4);
|
1052
|
+
hmask1 = _mm512_slli_epi16(hmask1, 4);
|
1053
|
+
|
1054
|
+
bytes = _mm512_loadu_si512(pb + 64);
|
1055
|
+
r0 = _mm512_and_si512(bytes, lowMask);
|
1056
|
+
r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
1057
|
+
h0 = _mm512_and_si512(hbits, hmask0);
|
1058
|
+
h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2);
|
1059
|
+
_mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
|
1060
|
+
_mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
|
1061
|
+
}
|
1062
|
+
|
1063
|
+
template <>
|
1064
|
+
void unpack_B<block_iq4_xs>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
|
1065
|
+
static const __m512i values128 = _mm512_set_epi8(
|
1066
|
+
113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
|
1067
|
+
113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
|
1068
|
+
113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
|
1069
|
+
113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
|
1070
|
+
);
|
1071
|
+
|
1072
|
+
const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
|
1073
|
+
const char * pb = (const char *)packed_B + k * packed_B_group_size;
|
1074
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
1075
|
+
|
1076
|
+
for (int n = 0; n < 8; n += 2) {
|
1077
|
+
__m512i bytes = _mm512_loadu_si512(pb + n * 32);
|
1078
|
+
const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask));
|
1079
|
+
const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
|
1080
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0);
|
1081
|
+
_mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
|
1082
|
+
}
|
1083
|
+
}
|
1084
|
+
|
1085
|
+
// Primary template, intentionally empty. The partial specializations for the
// supported (TA, TB) pairs each provide a static `apply`; `is_acc` selects
// whether `apply` adds to the existing contents of C or overwrites them.
template <typename TA, typename TB, bool is_acc>
struct acc_C {
};
|
1087
|
+
|
1088
|
+
template <bool is_acc>
|
1089
|
+
struct acc_C<block_q8_0, block_q4_0, is_acc> {
|
1090
|
+
static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
|
1091
|
+
const int offset = TILE_N * TILE_K / 2;
|
1092
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
|
1093
|
+
|
1094
|
+
for (int m = 0; m < nr; ++m) {
|
1095
|
+
const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
|
1096
|
+
const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
|
1097
|
+
|
1098
|
+
__m512 vsum;
|
1099
|
+
if (is_acc) {
|
1100
|
+
vsum = _mm512_loadu_ps(C + m * ldc);
|
1101
|
+
} else {
|
1102
|
+
vsum = _mm512_set1_ps(0.f);
|
1103
|
+
}
|
1104
|
+
vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
|
1105
|
+
_mm512_storeu_ps(C + m * ldc, vsum);
|
1106
|
+
}
|
1107
|
+
}
|
1108
|
+
};
|
1109
|
+
|
1110
|
+
template <bool is_acc>
|
1111
|
+
struct acc_C<block_q8_1, block_q4_1, is_acc> {
|
1112
|
+
static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) {
|
1113
|
+
const int offset = TILE_N * TILE_K / 2;
|
1114
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
|
1115
|
+
const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half))));
|
1116
|
+
|
1117
|
+
for (int m = 0; m < nr; ++m) {
|
1118
|
+
const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
|
1119
|
+
const __m512 vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].s));
|
1120
|
+
const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
|
1121
|
+
|
1122
|
+
__m512 vsum;
|
1123
|
+
if (is_acc) {
|
1124
|
+
vsum = _mm512_loadu_ps(C + m * ldc);
|
1125
|
+
} else {
|
1126
|
+
vsum = _mm512_set1_ps(0.f);
|
1127
|
+
}
|
1128
|
+
vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
|
1129
|
+
vsum = _mm512_fmadd_ps(vm0, vs1, vsum);
|
1130
|
+
_mm512_storeu_ps(C + m * ldc, vsum);
|
1131
|
+
}
|
1132
|
+
}
|
1133
|
+
};
|
1134
|
+
|
1135
|
+
template <bool is_acc>
|
1136
|
+
struct acc_C<block_q8_0, block_q8_0, is_acc> {
|
1137
|
+
static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
|
1138
|
+
const int offset = TILE_N * TILE_K;
|
1139
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
|
1140
|
+
|
1141
|
+
for (int m = 0; m < nr; ++m) {
|
1142
|
+
const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
|
1143
|
+
const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
|
1144
|
+
|
1145
|
+
__m512 vsum;
|
1146
|
+
if (is_acc) {
|
1147
|
+
vsum = _mm512_loadu_ps(C + m * ldc);
|
1148
|
+
} else {
|
1149
|
+
vsum = _mm512_set1_ps(0.f);
|
1150
|
+
}
|
1151
|
+
vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
|
1152
|
+
_mm512_storeu_ps(C + m * ldc, vsum);
|
1153
|
+
}
|
1154
|
+
}
|
1155
|
+
};
|
1156
|
+
|
1157
|
+
template <bool is_acc>
|
1158
|
+
struct acc_C<block_q8_K, block_q4_K, is_acc> {
|
1159
|
+
static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
|
1160
|
+
const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
|
1161
|
+
const uint8_t * mins = scales + 8 * TILE_N;
|
1162
|
+
const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);
|
1163
|
+
const ggml_half * dmin = d0 + TILE_N;
|
1164
|
+
|
1165
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
|
1166
|
+
const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
|
1167
|
+
|
1168
|
+
for (int m = 0; m < nr; ++m) {
|
1169
|
+
const float d1 = A[m * lda].d;
|
1170
|
+
const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
|
1171
|
+
const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
|
1172
|
+
const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
|
1173
|
+
|
1174
|
+
__m512 vsum;
|
1175
|
+
if (is_acc) {
|
1176
|
+
vsum = _mm512_loadu_ps(C + m * ldc);
|
1177
|
+
} else {
|
1178
|
+
vsum = _mm512_set1_ps(0.f);
|
1179
|
+
}
|
1180
|
+
|
1181
|
+
const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
|
1182
|
+
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
|
1183
|
+
|
1184
|
+
__m512i acc_m = _mm512_setzero_si512();
|
1185
|
+
for (int k = 0; k < 4; ++k) {
|
1186
|
+
__m512i vmask = _mm512_set1_epi32(k);
|
1187
|
+
__m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
|
1188
|
+
__m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
|
1189
|
+
acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
|
1190
|
+
}
|
1191
|
+
|
1192
|
+
vsum = _mm512_fmadd_ps(vtile, vd, vsum);
|
1193
|
+
vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
|
1194
|
+
_mm512_storeu_ps(C + m * ldc, vsum);
|
1195
|
+
}
|
1196
|
+
}
|
1197
|
+
};
|
1198
|
+
|
1199
|
+
template <bool is_acc>
|
1200
|
+
struct acc_C<block_q8_K, block_q5_K, is_acc> {
|
1201
|
+
static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
|
1202
|
+
const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
|
1203
|
+
const uint8_t * mins = scales + 8 * TILE_N;
|
1204
|
+
const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);
|
1205
|
+
const ggml_half * dmin = d0 + TILE_N;
|
1206
|
+
|
1207
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
|
1208
|
+
const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
|
1209
|
+
|
1210
|
+
for (int m = 0; m < nr; ++m) {
|
1211
|
+
const float d1 = A[m * lda].d;
|
1212
|
+
const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
|
1213
|
+
const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
|
1214
|
+
const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
|
1215
|
+
|
1216
|
+
__m512 vsum;
|
1217
|
+
if (is_acc) {
|
1218
|
+
vsum = _mm512_loadu_ps(C + m * ldc);
|
1219
|
+
} else {
|
1220
|
+
vsum = _mm512_set1_ps(0.f);
|
1221
|
+
}
|
1222
|
+
|
1223
|
+
const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
|
1224
|
+
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
|
1225
|
+
|
1226
|
+
__m512i acc_m = _mm512_setzero_si512();
|
1227
|
+
for (int k = 0; k < 4; ++k) {
|
1228
|
+
__m512i vmask = _mm512_set1_epi32(k);
|
1229
|
+
__m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
|
1230
|
+
__m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
|
1231
|
+
acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
|
1232
|
+
}
|
1233
|
+
|
1234
|
+
vsum = _mm512_fmadd_ps(vtile, vd, vsum);
|
1235
|
+
vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
|
1236
|
+
_mm512_storeu_ps(C + m * ldc, vsum);
|
1237
|
+
}
|
1238
|
+
}
|
1239
|
+
};
|
1240
|
+
|
1241
|
+
template <bool is_acc>
|
1242
|
+
struct acc_C<block_q8_K, block_q6_K, is_acc> {
|
1243
|
+
static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
|
1244
|
+
const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
|
1245
|
+
const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 16 * TILE_N);
|
1246
|
+
|
1247
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
|
1248
|
+
|
1249
|
+
for (int m = 0; m < nr; ++m) {
|
1250
|
+
const float d1 = A[m * lda].d;
|
1251
|
+
const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
|
1252
|
+
const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
|
1253
|
+
|
1254
|
+
__m512 vsum;
|
1255
|
+
if (is_acc) {
|
1256
|
+
vsum = _mm512_loadu_ps(C + m * ldc);
|
1257
|
+
} else {
|
1258
|
+
vsum = _mm512_set1_ps(0.f);
|
1259
|
+
}
|
1260
|
+
|
1261
|
+
vsum = _mm512_fmadd_ps(vtile, vd, vsum);
|
1262
|
+
_mm512_storeu_ps(C + m * ldc, vsum);
|
1263
|
+
}
|
1264
|
+
}
|
1265
|
+
};
|
1266
|
+
|
1267
|
+
// acc_C specialization for block_q8_K activations against packed block_iq4_xs weights.
// For each of the `nr` rows m: loads 16 int32 values from tile + m * TILE_N, converts them
// to float, multiplies them by A[m * lda].d times the 16 fp16 scales `d0` read from packed_B,
// and writes 16 floats to C + m * ldc. With is_acc == true the product is added to the
// values already in C; otherwise C is overwritten.
template <bool is_acc>
|
1268
|
+
struct acc_C<block_q8_K, block_iq4_xs, is_acc> {
|
1269
|
+
static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
|
1270
|
+
// the int8 scales sit after (QK_K / 2) * TILE_N bytes of packed quants
const int8_t * scales = reinterpret_cast<const int8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
|
1271
|
+
// the fp16 `d` values follow 8 * TILE_N bytes of scales
const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 8 * TILE_N);
|
1272
|
+
|
1273
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
|
1274
|
+
|
1275
|
+
for (int m = 0; m < nr; ++m) {
|
1276
|
+
const float d1 = A[m * lda].d;
|
1277
|
+
const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
|
1278
|
+
const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
|
1279
|
+
|
1280
|
+
__m512 vsum;
|
1281
|
+
if (is_acc) {
|
1282
|
+
vsum = _mm512_loadu_ps(C + m * ldc);
|
1283
|
+
} else {
|
1284
|
+
vsum = _mm512_set1_ps(0.f);
|
1285
|
+
}
|
1286
|
+
|
1287
|
+
vsum = _mm512_fmadd_ps(vtile, vd, vsum);
|
1288
|
+
_mm512_storeu_ps(C + m * ldc, vsum);
|
1289
|
+
}
|
1290
|
+
}
|
1291
|
+
};
|
1292
|
+
|
1293
|
+
// Size in bytes of the quantized-value section at the start of one packed B tile for
// block type TB. scale_C adds this value to packed_B to locate the scales.
// Only the four specializations below are defined.
template <typename TB> constexpr int get_quants_size();
|
1294
|
+
template <> constexpr int get_quants_size<block_q4_K>() { return (QK_K / 2) * TILE_N; }
|
1295
|
+
// q5_K: QK_K / 2 bytes of low nibbles plus QK_K / 8 bytes of high bits per column
template <> constexpr int get_quants_size<block_q5_K>() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; }
|
1296
|
+
// q6_K: QK_K / 2 bytes of low nibbles plus QK_K / 4 bytes of high bits per column
template <> constexpr int get_quants_size<block_q6_K>() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; }
|
1297
|
+
template <> constexpr int get_quants_size<block_iq4_xs>() { return (QK_K / 2) * TILE_N; }
|
1298
|
+
|
1299
|
+
// used for QKK format
//
// Multiplies an int32 tile by the integer scales of sub-block `k` and writes it to `sumi`.
// The 16 scale bytes at scales + k * TILE_N are sign-extended to int32 (one per column).
// For each of the `nr` rows, 16 int32 values from tile + m * TILE_N are multiplied by those
// scales and stored to sumi + m * TILE_N; with is_acc == true they are added to the values
// already in `sumi`, otherwise `sumi` is overwritten.
|
1300
|
+
template <typename TB, bool is_acc,
|
1301
|
+
typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
|
1302
|
+
inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) {
|
1303
|
+
// scales start right after the packed quants of this tile
const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + get_quants_size<TB>());
|
1304
|
+
const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N)));
|
1305
|
+
|
1306
|
+
for (int m = 0; m < nr; ++m) {
|
1307
|
+
__m512i vsumi;
|
1308
|
+
if (is_acc) {
|
1309
|
+
vsumi = _mm512_loadu_si512(sumi + m * TILE_N);
|
1310
|
+
} else {
|
1311
|
+
vsumi = _mm512_setzero_si512();
|
1312
|
+
}
|
1313
|
+
__m512i vtile = _mm512_loadu_si512(tile + m * TILE_N);
|
1314
|
+
vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale));
|
1315
|
+
_mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi);
|
1316
|
+
}
|
1317
|
+
}
|
1318
|
+
|
1319
|
+
// Primary template of the AVX-512 tiny GEMM kernel. This generic version does nothing:
// apply() only marks its arguments as unused. The computation lives in the specializations.
template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
|
1320
|
+
struct tinygemm_kernel_avx {
|
1321
|
+
static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) {
|
1322
|
+
GGML_UNUSED(K);
|
1323
|
+
GGML_UNUSED(A);
|
1324
|
+
GGML_UNUSED(B);
|
1325
|
+
GGML_UNUSED(C);
|
1326
|
+
GGML_UNUSED(ldc);
|
1327
|
+
}
|
1328
|
+
};
|
1329
|
+
|
1330
|
+
// fp32 (A) x fp16 (B) kernel: for every (row, col) in a BLOCK_M x BLOCK_N tile computes
//   C[row * ldc + col] = sum over k of A[row * K + k] * B[col * K + k]
// Both A and B are read with row stride K. K is consumed 16 elements at a time
// (the loop assumes K is a multiple of 16) and C is overwritten, not accumulated.
// NOTE(review): Unroll<N> is defined outside this chunk; it presumably calls the lambda
// once per index in [0, N) with that index as first argument — confirm.
template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
|
1331
|
+
struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> {
|
1332
|
+
static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) {
|
1333
|
+
constexpr int ROWS = BLOCK_M;
|
1334
|
+
constexpr int COLS = BLOCK_N;
|
1335
|
+
assert(BLOCK_K == 16);
|
1336
|
+
|
1337
|
+
__m512 va;
|
1338
|
+
__m512 vb[COLS];
|
1339
|
+
__m512 vc[ROWS * COLS];
|
1340
|
+
|
1341
|
+
// zero the accumulators
auto loadc = [&](int idx) {
|
1342
|
+
vc[idx] = _mm512_setzero_ps();
|
1343
|
+
};
|
1344
|
+
Unroll<ROWS * COLS>{}(loadc);
|
1345
|
+
|
1346
|
+
auto compute = [&](int idx, int k) {
|
1347
|
+
// TODO: use `constexpr` here to get rid of integer div
|
1348
|
+
// when upgraded to C++17
|
1349
|
+
const int row = idx / COLS;
|
1350
|
+
const int col = idx % COLS;
|
1351
|
+
|
1352
|
+
// `va` is reloaded when a new row starts and reused for the remaining columns
if (col == 0) {
|
1353
|
+
va = _mm512_loadu_ps(A + row * K + k);
|
1354
|
+
}
|
1355
|
+
// vb[col] is converted from fp16 while processing row 0 and reused by later rows
if (row == 0) {
|
1356
|
+
vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
|
1357
|
+
}
|
1358
|
+
vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
|
1359
|
+
};
|
1360
|
+
|
1361
|
+
for (int k = 0; k < K; k += 16) {
|
1362
|
+
Unroll<ROWS * COLS>{}(compute, k);
|
1363
|
+
}
|
1364
|
+
|
1365
|
+
// horizontally sum the 16 lanes of each accumulator into one output element
auto storec = [&](int idx) {
|
1366
|
+
const int row = idx / COLS;
|
1367
|
+
const int col = idx % COLS;
|
1368
|
+
C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);
|
1369
|
+
};
|
1370
|
+
Unroll<ROWS * COLS>{}(storec);
|
1371
|
+
}
|
1372
|
+
};
|
1373
|
+
|
1374
|
+
// Invokes tinygemm_kernel_avx for an MB_SIZE x NB_SIZE output tile. The expansion uses
// names that must exist at the call site: type, blck_size, K, src0, src1, dst, mb_start,
// nb_start and ldc. A is taken from src1->data at row mb_start, B from src0->data at row
// nb_start (both with row stride K), and the result goes to dst->data at (mb_start, nb_start).
#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
|
1375
|
+
tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \
|
1376
|
+
K, (const float *)src1->data + mb_start * K, \
|
1377
|
+
(const type *)src0->data + nb_start * K, \
|
1378
|
+
(float *)dst->data + mb_start * ldc + nb_start, ldc);
|
1379
|
+
|
1380
|
+
|
1381
|
+
// re-organize in the format {NB, KB, TILE_SIZE}:
|
1382
|
+
// Byte offset of tile (n, k) in a buffer laid out as {NB, KB, tile_size}. Every argument and
// the whole expansion are parenthesized so that expression arguments (e.g. `a + b`) and
// surrounding operators cannot change the result through operator precedence.
#define PACKED_INDEX(n, k, KB, tile_size) (((n) * (KB) + (k)) * (tile_size))
|
1383
|
+
|
1384
|
+
// Repacks the N x K matrix B (stored as KB = K / BLOCK_K blocks of type TB per row) into
// `packed_B`, one tile per (n, k) pair at byte offset PACKED_INDEX(n, k, KB, TILE_SIZE).
// Each pack_B call receives the destination tile, a pointer to the block of row n * TILE_N
// at block-column k, and KB. N / TILE_N and K / BLOCK_K are integer divisions, so any
// remainder rows or columns are not visited.
// NOTE(review): parallel_for, pack_B and get_tile_size are defined outside this chunk;
// parallel_for presumably splits [0, NB) into [begin, end) ranges across n_threads — confirm.
template<typename TB, int BLOCK_K>
|
1385
|
+
void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K, int n_threads) {
|
1386
|
+
const int NB = N / TILE_N;
|
1387
|
+
const int KB = K / BLOCK_K;
|
1388
|
+
const int TILE_SIZE = get_tile_size<TB>();
|
1389
|
+
|
1390
|
+
// parallel on NB should be enough
|
1391
|
+
parallel_for(n_threads, NB, [&](int begin, int end) {
|
1392
|
+
for (int n = begin; n < end; ++n) {
|
1393
|
+
for (int k = 0; k < KB; ++k) {
|
1394
|
+
int n0 = n * TILE_N;
|
1395
|
+
pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB);
|
1396
|
+
}
|
1397
|
+
}
|
1398
|
+
});
|
1399
|
+
}
|
1400
|
+
|
1401
|
+
// Primary template of the AVX512-VNNI tiny GEMM kernel. It is intentionally empty (there is
// no apply() member); only the specializations for concrete block-type pairs are usable.
template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
|
1402
|
+
struct tinygemm_kernel_vnni {};
|
1403
|
+
|
1404
|
+
// VNNI kernel for block_q8_0 activations against packed block_q4_0 weights.
// Only one row of A is processed (A[0 * KB + i]); COLS = BLOCK_N / 16 packed column tiles
// of 16 columns each are produced. For every block i in [0, KB):
//   - the 32 int8 activations are broadcast dword by dword into va[0..7];
//   - the 4-bit weights are unpacked (low nibble, then high nibble) and multiplied with the
//     activations using _mm512_dpbusd_epi32 (unsigned x signed);
//   - vcomp (8 * sum of the activations) is subtracted to undo the missing "- 8" offset;
//   - the int32 result is scaled by vd0 * vd1 (fp16 scales of B and of A) and added to vc.
// The result overwrites 16 floats per column tile at C + col * 16.
template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
|
1405
|
+
struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
|
1406
|
+
static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
|
1407
|
+
|
1408
|
+
constexpr int COLS = BLOCK_N / 16;
|
1409
|
+
const int TILE_SIZE = TILE_N * sizeof(block_q4_0);
|
1410
|
+
|
1411
|
+
const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
|
1412
|
+
const char * RESTRICT B = static_cast<const char *>(_B);
|
1413
|
+
|
1414
|
+
__m512i va[8];
|
1415
|
+
__m512 vc[COLS];
|
1416
|
+
__m512 vd1;
|
1417
|
+
|
1418
|
+
// sum of offsets, shared across COLS
|
1419
|
+
//
|
1420
|
+
// avx512-vnni does not have `_mm512_dpbssd_epi32`,
|
1421
|
+
// need to transform ss to us:
|
1422
|
+
// a * (b - 8) is equivalent to b * a - 8 * a
|
1423
|
+
// s u u u s u s
|
1424
|
+
//
|
1425
|
+
__m512i vcomp;
|
1426
|
+
|
1427
|
+
const __m512i off = _mm512_set1_epi8(8);
|
1428
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
1429
|
+
|
1430
|
+
auto loadc = [&](int col) {
|
1431
|
+
vc[col] = _mm512_setzero_ps();
|
1432
|
+
};
|
1433
|
+
Unroll<COLS>{}(loadc);
|
1434
|
+
|
1435
|
+
auto compute = [&](int col, int i) {
|
1436
|
+
// load a and compute compensation
|
1437
|
+
if (col == 0) {
|
1438
|
+
const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
|
1439
|
+
vcomp = _mm512_setzero_si512();
|
1440
|
+
for (int k = 0; k < 8; ++k) {
|
1441
|
+
va[k] = _mm512_set1_epi32(a_ptr[k]);
|
1442
|
+
vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]);
|
1443
|
+
}
|
1444
|
+
vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
|
1445
|
+
}
|
1446
|
+
|
1447
|
+
// load b
|
1448
|
+
__m512i vsum = _mm512_setzero_si512();
|
1449
|
+
const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
|
1450
|
+
for (int k = 0; k < 8; k += 2) {
|
1451
|
+
__m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
|
1452
|
+
__m512i vb0 = _mm512_and_si512(bytes, lowMask);
|
1453
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]);
|
1454
|
+
__m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
1455
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]);
|
1456
|
+
}
|
1457
|
+
// the fp16 scales of the tile follow TILE_N * TILE_K / 2 bytes of packed nibbles
const int offset = TILE_N * TILE_K / 2;
|
1458
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
|
1459
|
+
vsum = _mm512_sub_epi32(vsum, vcomp);
|
1460
|
+
|
1461
|
+
vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
|
1462
|
+
};
|
1463
|
+
|
1464
|
+
for (int i = 0; i < KB; ++i) {
|
1465
|
+
Unroll<COLS>{}(compute, i);
|
1466
|
+
}
|
1467
|
+
|
1468
|
+
//store to C
|
1469
|
+
auto storec = [&](int col) {
|
1470
|
+
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
|
1471
|
+
};
|
1472
|
+
Unroll<COLS>{}(storec);
|
1473
|
+
}
|
1474
|
+
};
|
1475
|
+
|
1476
|
+
// VNNI kernel for block_q8_1 activations against packed block_q4_1 weights, specialized for
// BLOCK_M == 1 (only A[0 * KB + i] is read). COLS = BLOCK_N / 16 packed column tiles of 16
// columns each are produced. For every block i in [0, KB):
//   - the 4-bit weights are unpacked (low nibble, then high nibble) into vb[0..7] and
//     multiplied with the broadcast int8 activations using _mm512_dpbusd_epi32;
//   - the int32 result is scaled by vd0 * vd1 (fp16 `d` of B and of A) and added to vc;
//   - vm0 * vs1 is added as well, where vm0 is the fp16 vector stored right after the 16
//     `d` values in packed B and vs1 is the `s` field of the activation block.
// The result overwrites 16 floats per column tile at C + col * 16.
template <int BLOCK_N, int BLOCK_K>
|
1477
|
+
struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K> {
|
1478
|
+
static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
|
1479
|
+
|
1480
|
+
constexpr int COLS = BLOCK_N / 16;
|
1481
|
+
const int TILE_SIZE = TILE_N * sizeof(block_q4_1);
|
1482
|
+
|
1483
|
+
const block_q8_1 * RESTRICT A = static_cast<const block_q8_1 *>(_A);
|
1484
|
+
const char * RESTRICT B = static_cast<const char *>(_B);
|
1485
|
+
|
1486
|
+
__m512i va[8];
|
1487
|
+
__m512i vb[8];
|
1488
|
+
__m512 vc[COLS];
|
1489
|
+
__m512 vd1, vs1;
|
1490
|
+
|
1491
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
1492
|
+
|
1493
|
+
auto loadc = [&](int col) {
|
1494
|
+
vc[col] = _mm512_setzero_ps();
|
1495
|
+
};
|
1496
|
+
Unroll<COLS>{}(loadc);
|
1497
|
+
|
1498
|
+
auto compute = [&](int col, int i) {
|
1499
|
+
// load a
|
1500
|
+
if (col == 0) {
|
1501
|
+
const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
|
1502
|
+
for (int k = 0; k < 8; ++k) {
|
1503
|
+
va[k] = _mm512_set1_epi32(a_ptr[k]);
|
1504
|
+
}
|
1505
|
+
vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
|
1506
|
+
vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].s));
|
1507
|
+
}
|
1508
|
+
|
1509
|
+
// load b
|
1510
|
+
const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
|
1511
|
+
for (int k = 0; k < 8; k += 2) {
|
1512
|
+
__m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
|
1513
|
+
vb[k + 0] = _mm512_and_si512(bytes, lowMask);
|
1514
|
+
vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
1515
|
+
}
|
1516
|
+
// the fp16 `d` values follow TILE_N * TILE_K / 2 bytes of packed nibbles, then vm0
const int offset = TILE_N * TILE_K / 2;
|
1517
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
|
1518
|
+
const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half))));
|
1519
|
+
|
1520
|
+
__m512i vsum = _mm512_setzero_si512();
|
1521
|
+
for (int k = 0; k < 8; ++k) {
|
1522
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]);
|
1523
|
+
}
|
1524
|
+
|
1525
|
+
vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
|
1526
|
+
vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]);
|
1527
|
+
};
|
1528
|
+
|
1529
|
+
for (int i = 0; i < KB; ++i) {
|
1530
|
+
Unroll<COLS>{}(compute, i);
|
1531
|
+
}
|
1532
|
+
|
1533
|
+
//store to C
|
1534
|
+
auto storec = [&](int col) {
|
1535
|
+
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
|
1536
|
+
};
|
1537
|
+
Unroll<COLS>{}(storec);
|
1538
|
+
}
|
1539
|
+
};
|
1540
|
+
|
1541
|
+
// VNNI kernel for block_q8_0 activations against packed block_q8_0 weights.
// Only one row of A is processed (A[0 * KB + i]); COLS = BLOCK_N / 16 packed column tiles
// of 16 columns each are produced. For every block i in [0, KB):
//   - 0x80 is added to every activation byte so it can serve as the unsigned operand of
//     _mm512_dpbusd_epi32, with the int8 weights as the signed operand;
//   - the per-column compensation vcomp stored in the packed tile is subtracted;
//   - the int32 result is scaled by vd0 * vd1 (fp16 scales of B and of A) and added to vc.
// The result overwrites 16 floats per column tile at C + col * 16.
template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
|
1542
|
+
struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
|
1543
|
+
static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
|
1544
|
+
|
1545
|
+
constexpr int COLS = BLOCK_N / 16;
|
1546
|
+
// each packed tile carries an extra int32 compensation value per column
const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t);
|
1547
|
+
|
1548
|
+
const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
|
1549
|
+
const char * RESTRICT B = static_cast<const char *>(_B);
|
1550
|
+
|
1551
|
+
__m512i va[8];
|
1552
|
+
__m512i vb[8];
|
1553
|
+
__m512 vc[COLS];
|
1554
|
+
__m512 vd1;
|
1555
|
+
|
1556
|
+
// Notes: s8s8 igemm compensation in avx512-vnni
|
1557
|
+
// change s8s8 to u8s8 with compensate
|
1558
|
+
// a * b = (a + 128) * b - 128 * b
|
1559
|
+
// s s u s u s
|
1560
|
+
//
|
1561
|
+
// (128 * b is pre-computed when packing B to vnni formats)
|
1562
|
+
//
|
1563
|
+
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
|
1564
|
+
|
1565
|
+
auto loadc = [&](int col) {
|
1566
|
+
vc[col] = _mm512_setzero_ps();
|
1567
|
+
};
|
1568
|
+
Unroll<COLS>{}(loadc);
|
1569
|
+
|
1570
|
+
auto compute = [&](int col, int i) {
|
1571
|
+
// load a and add offset 128
|
1572
|
+
if (col == 0) {
|
1573
|
+
const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
|
1574
|
+
for (int k = 0; k < 8; ++k) {
|
1575
|
+
va[k] = _mm512_set1_epi32(a_ptr[k]);
|
1576
|
+
va[k] = _mm512_add_epi8(va[k], off);
|
1577
|
+
}
|
1578
|
+
vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
|
1579
|
+
}
|
1580
|
+
|
1581
|
+
// load b
|
1582
|
+
const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
|
1583
|
+
for (int k = 0; k < 8; ++k) {
|
1584
|
+
vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64));
|
1585
|
+
}
|
1586
|
+
// tile layout: TILE_N * TILE_K quant bytes, then 16 fp16 scales, then 16 int32 compensations
const int offset = TILE_N * TILE_K;
|
1587
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
|
1588
|
+
const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);
|
1589
|
+
const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2));
|
1590
|
+
|
1591
|
+
__m512i vsum = _mm512_setzero_si512();
|
1592
|
+
for (int k = 0; k < 8; ++k) {
|
1593
|
+
vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]);
|
1594
|
+
}
|
1595
|
+
vsum = _mm512_sub_epi32(vsum, vcomp);
|
1596
|
+
|
1597
|
+
vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
|
1598
|
+
};
|
1599
|
+
|
1600
|
+
for (int i = 0; i < KB; ++i) {
|
1601
|
+
Unroll<COLS>{}(compute, i);
|
1602
|
+
}
|
1603
|
+
|
1604
|
+
//store to C
|
1605
|
+
auto storec = [&](int col) {
|
1606
|
+
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
|
1607
|
+
};
|
1608
|
+
Unroll<COLS>{}(storec);
|
1609
|
+
}
|
1610
|
+
};
|
1611
|
+
|
1612
|
+
// VNNI kernel for block_q8_K activations against packed block_q4_K weights.
// Only one row of A is processed (A[0 * KB + i]); COLS = BLOCK_N / 16 packed column tiles
// of 16 columns each are produced. For every super-block i in [0, KB):
//   step 1: for each of the QK_K / 32 groups, the 4-bit weights are unpacked and multiplied
//           with the int8 activations (_mm512_dpbusd_epi32), the group sum is multiplied by
//           the group's integer scales, and the total is scaled by vd0 * vd1 and added to vc;
//   step 2: the pairwise-added activation bsums are multiplied with the integer mins
//           (_mm512_dpwssds_epi32) and the result, scaled by vdmin * vd1, is subtracted
//           from vc (_mm512_fnmadd_ps).
// The result overwrites 16 floats per column tile at C + col * 16.
template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
|
1613
|
+
struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
|
1614
|
+
static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
|
1615
|
+
|
1616
|
+
constexpr int COLS = BLOCK_N / 16;
|
1617
|
+
const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4;
|
1618
|
+
|
1619
|
+
const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
|
1620
|
+
const char * RESTRICT B = static_cast<const char *>(_B);
|
1621
|
+
|
1622
|
+
// a.qs: 8 groups, 32 bytes each group (m256i)
|
1623
|
+
__m512i va[8];
|
1624
|
+
// a.bsum: 8 groups, 2 bytes each group (m128i)
|
1625
|
+
__m512i va_bsum;
|
1626
|
+
__m512 vc[COLS];
|
1627
|
+
__m512 vd1;
|
1628
|
+
|
1629
|
+
// packed_B: byte offsets of the sections inside one packed tile
// (quants, scales, mins, fp16 d, fp16 dmin)
|
1630
|
+
const int offset_scales = (QK_K / 2) * TILE_N;
|
1631
|
+
const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N;
|
1632
|
+
const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N;
|
1633
|
+
const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);
|
1634
|
+
|
1635
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
1636
|
+
|
1637
|
+
auto loadc = [&](int col) {
|
1638
|
+
vc[col] = _mm512_setzero_ps();
|
1639
|
+
};
|
1640
|
+
Unroll<COLS>{}(loadc);
|
1641
|
+
|
1642
|
+
// Notes: vnni formats in QK_K
|
1643
|
+
// a) quants vnni format
|
1644
|
+
// int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32
|
1645
|
+
// from {16, 32} to {8, 64}
|
1646
|
+
//
|
1647
|
+
// b) min vnni format
|
1648
|
+
// int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8
|
1649
|
+
// from {16, 8} to {4, 32}
|
1650
|
+
//
|
1651
|
+
auto compute = [&](int col, int i) {
|
1652
|
+
// load a
|
1653
|
+
if (col == 0) {
|
1654
|
+
for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
|
1655
|
+
va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
|
1656
|
+
}
|
1657
|
+
// add adjacent pairs of the 16 int16 bsums, giving 8 sums (one per 32-value group)
const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
|
1658
|
+
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
|
1659
|
+
va_bsum = _mm512_castsi128_si512(q8s);
|
1660
|
+
vd1 = _mm512_set1_ps(A[0 * KB + i].d);
|
1661
|
+
}
|
1662
|
+
|
1663
|
+
// step 1: accumulate the quants
|
1664
|
+
__m512i acc = _mm512_setzero_si512();
|
1665
|
+
const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
|
1666
|
+
const char * b_qs = b_ptr;
|
1667
|
+
for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
|
1668
|
+
__m512i vsum = _mm512_setzero_si512();
|
1669
|
+
for (int k = 0; k < 8; k += 2) {
|
1670
|
+
// broadcast activation dwords k and k + 1 of this group to all lanes
__m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
|
1671
|
+
__m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
|
1672
|
+
|
1673
|
+
__m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
|
1674
|
+
__m512i vb0 = _mm512_and_si512(bytes, lowMask);
|
1675
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
|
1676
|
+
__m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
1677
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
|
1678
|
+
|
1679
|
+
b_qs += 64;
|
1680
|
+
}
|
1681
|
+
// vacc += scale * (q8 @ q4)
|
1682
|
+
const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
|
1683
|
+
acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
|
1684
|
+
}
|
1685
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
|
1686
|
+
vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
|
1687
|
+
|
1688
|
+
// step 2: accumulate the mins
|
1689
|
+
__m512i acc_m = _mm512_setzero_si512();
|
1690
|
+
for (int k = 0; k < 4; ++k) {
|
1691
|
+
// dword k of va_bsum holds two int16 group sums; broadcast it to all lanes
__m512i vmask = _mm512_set1_epi32(k);
|
1692
|
+
__m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
|
1693
|
+
__m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
|
1694
|
+
acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
|
1695
|
+
}
|
1696
|
+
const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
|
1697
|
+
vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
|
1698
|
+
};
|
1699
|
+
|
1700
|
+
for (int i = 0; i < KB; ++i) {
|
1701
|
+
Unroll<COLS>{}(compute, i);
|
1702
|
+
}
|
1703
|
+
|
1704
|
+
//store to C
|
1705
|
+
auto storec = [&](int col) {
|
1706
|
+
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
|
1707
|
+
};
|
1708
|
+
Unroll<COLS>{}(storec);
|
1709
|
+
}
|
1710
|
+
};
|
1711
|
+
|
1712
|
+
// VNNI kernel for block_q8_K activations against packed block_q5_K weights.
// Only one row of A is processed (A[0 * KB + i]); COLS = BLOCK_N / 16 packed column tiles
// of 16 columns each are produced. For every super-block i in [0, KB):
//   step 1: for each of the QK_K / 32 groups, the low 4 bits come from the packed nibbles
//           and the 5th bit from the `qh` section; the resulting values are multiplied
//           with the int8 activations (_mm512_dpbusd_epi32), the group sum is multiplied by
//           the group's integer scales, and the total is scaled by vd0 * vd1 and added to vc;
//   step 2: the pairwise-added activation bsums are multiplied with the integer mins
//           (_mm512_dpwssds_epi32) and the result, scaled by vdmin * vd1, is subtracted
//           from vc (_mm512_fnmadd_ps).
// The result overwrites 16 floats per column tile at C + col * 16.
template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
|
1713
|
+
struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
|
1714
|
+
static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
|
1715
|
+
|
1716
|
+
constexpr int COLS = BLOCK_N / 16;
|
1717
|
+
const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4;
|
1718
|
+
|
1719
|
+
const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
|
1720
|
+
const char * RESTRICT B = static_cast<const char *>(_B);
|
1721
|
+
|
1722
|
+
// a.qs: 8 groups, 32 bytes each group (m256i)
|
1723
|
+
__m512i va[8];
|
1724
|
+
// a.bsum: 8 groups, 2 bytes each group (m128i)
|
1725
|
+
__m512i va_bsum;
|
1726
|
+
__m512 vc[COLS];
|
1727
|
+
__m512 vd1;
|
1728
|
+
|
1729
|
+
// packed_B: byte offsets of the sections inside one packed tile
// (low nibbles, high bits, scales, mins, fp16 d, fp16 dmin)
|
1730
|
+
const int offset_qh = (QK_K / 2) * TILE_N;
|
1731
|
+
const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;
|
1732
|
+
const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N;
|
1733
|
+
const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N;
|
1734
|
+
const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);
|
1735
|
+
|
1736
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
1737
|
+
|
1738
|
+
auto loadc = [&](int col) {
|
1739
|
+
vc[col] = _mm512_setzero_ps();
|
1740
|
+
};
|
1741
|
+
Unroll<COLS>{}(loadc);
|
1742
|
+
|
1743
|
+
// Q5_K and Q4_K shares the same vnni formats, refer to notes above.
|
1744
|
+
auto compute = [&](int col, int i) {
|
1745
|
+
// load a
|
1746
|
+
if (col == 0) {
|
1747
|
+
for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
|
1748
|
+
va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
|
1749
|
+
}
|
1750
|
+
// add adjacent pairs of the 16 int16 bsums, giving 8 sums (one per 32-value group)
const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
|
1751
|
+
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
|
1752
|
+
va_bsum = _mm512_castsi128_si512(q8s);
|
1753
|
+
vd1 = _mm512_set1_ps(A[0 * KB + i].d);
|
1754
|
+
}
|
1755
|
+
|
1756
|
+
// step 1: accumulate the quants
|
1757
|
+
__m512i acc = _mm512_setzero_si512();
|
1758
|
+
const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
|
1759
|
+
const char * b_qs = b_ptr;
|
1760
|
+
const char * b_qh = b_ptr + offset_qh;
|
1761
|
+
for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
|
1762
|
+
__m512i vsum = _mm512_setzero_si512();
|
1763
|
+
// hmask0 / hmask1 select bit k / bit k + 1 of every high-bits byte; both move up by
// two bit positions after each iteration of the inner loop
__m512i hmask0 = _mm512_set1_epi8(0x1);
|
1764
|
+
__m512i hmask1 = _mm512_set1_epi8(0x2);
|
1765
|
+
__m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64));
|
1766
|
+
for (int k = 0; k < 8; k += 2) {
|
1767
|
+
__m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
|
1768
|
+
__m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
|
1769
|
+
|
1770
|
+
__m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
|
1771
|
+
__m512i vb0 = _mm512_and_si512(bytes, lowMask);
|
1772
|
+
__m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
1773
|
+
|
1774
|
+
// move the selected high bit down to bit 0, then up to bit 4 (value 16)
__m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4);
|
1775
|
+
__m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4);
|
1776
|
+
|
1777
|
+
hmask0 = _mm512_slli_epi16(hmask0, 2);
|
1778
|
+
hmask1 = _mm512_slli_epi16(hmask1, 2);
|
1779
|
+
vb0 = _mm512_add_epi8(vb0, vh0);
|
1780
|
+
vb1 = _mm512_add_epi8(vb1, vh1);
|
1781
|
+
|
1782
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
|
1783
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
|
1784
|
+
|
1785
|
+
b_qs += 64;
|
1786
|
+
}
|
1787
|
+
// vacc += scale * (q8 @ q5)
|
1788
|
+
const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
|
1789
|
+
acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
|
1790
|
+
}
|
1791
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
|
1792
|
+
vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
|
1793
|
+
|
1794
|
+
// step 2: accumulate the mins
|
1795
|
+
__m512i acc_m = _mm512_setzero_si512();
|
1796
|
+
for (int k = 0; k < 4; ++k) {
|
1797
|
+
// dword k of va_bsum holds two int16 group sums; broadcast it to all lanes
__m512i vmask = _mm512_set1_epi32(k);
|
1798
|
+
__m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
|
1799
|
+
__m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
|
1800
|
+
acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
|
1801
|
+
}
|
1802
|
+
const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
|
1803
|
+
vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
|
1804
|
+
};
|
1805
|
+
|
1806
|
+
for (int i = 0; i < KB; ++i) {
|
1807
|
+
Unroll<COLS>{}(compute, i);
|
1808
|
+
}
|
1809
|
+
|
1810
|
+
//store to C
|
1811
|
+
auto storec = [&](int col) {
|
1812
|
+
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
|
1813
|
+
};
|
1814
|
+
Unroll<COLS>{}(storec);
|
1815
|
+
}
|
1816
|
+
};
|
1817
|
+
|
1818
|
+
// VNNI kernel for block_q8_K activations against packed block_q6_K weights.
// Only one row of A is processed (A[0 * KB + i]); COLS = BLOCK_N / 16 packed column tiles
// of 16 columns each are produced. For every super-block i in [0, KB) and each of the
// QK_K / 16 groups:
//   - two 64-byte loads of packed nibbles and one 64-byte load of high bits (four 2-bit
//     fields per byte) are combined into 6-bit values;
//   - they are multiplied with the int8 activations using _mm512_dpbusd_epi32;
//   - 32 * bsums[k_group] is subtracted (see the "B * A - 32 * A" note below);
//   - the group sum is multiplied by the group's int8 scales and accumulated.
// The total is scaled by vd0 * vd1 and added to vc. The result overwrites 16 floats per
// column tile at C + col * 16.
template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
|
1819
|
+
struct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
|
1820
|
+
static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
|
1821
|
+
|
1822
|
+
constexpr int COLS = BLOCK_N / 16;
|
1823
|
+
const int TILE_SIZE = TILE_N * sizeof(block_q6_K);
|
1824
|
+
|
1825
|
+
const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
|
1826
|
+
const char * RESTRICT B = static_cast<const char *>(_B);
|
1827
|
+
|
1828
|
+
// load the 256 bytes from A to 4 avx512 vectors
|
1829
|
+
__m512i va[4];
|
1830
|
+
__m512 vc[COLS];
|
1831
|
+
__m512 vd1;
|
1832
|
+
|
1833
|
+
// packed_B: byte offsets of the sections inside one packed tile
// (low nibbles, high bits, scales, fp16 d)
|
1834
|
+
const int offset_qh = (QK_K / 2) * TILE_N;
|
1835
|
+
const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;
|
1836
|
+
const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N;
|
1837
|
+
|
1838
|
+
// compensation
|
1839
|
+
__m512i vcomp;
|
1840
|
+
|
1841
|
+
const __m512i m32s = _mm512_set1_epi32(32);
|
1842
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
1843
|
+
|
1844
|
+
auto loadc = [&](int col) {
|
1845
|
+
vc[col] = _mm512_setzero_ps();
|
1846
|
+
};
|
1847
|
+
Unroll<COLS>{}(loadc);
|
1848
|
+
|
1849
|
+
auto compute = [&](int col, int i) {
|
1850
|
+
if (col == 0) {
|
1851
|
+
// load a
|
1852
|
+
va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
|
1853
|
+
va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
|
1854
|
+
va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
|
1855
|
+
va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
|
1856
|
+
|
1857
|
+
// vcomp lane g = 32 * bsums[g] for the 16 groups of 16 activations
const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
|
1858
|
+
vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s);
|
1859
|
+
vd1 = _mm512_set1_ps(A[0 * KB + i].d);
|
1860
|
+
}
|
1861
|
+
|
1862
|
+
// accumulate the quants
|
1863
|
+
__m512i acc = _mm512_setzero_si512();
|
1864
|
+
const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
|
1865
|
+
const char * b_qs = b_ptr;
|
1866
|
+
const char * b_qh = b_ptr + offset_qh;
|
1867
|
+
// running dword index into `va`; _mm512_permutexvar_epi32 only uses its low 4 bits,
// so it wraps to the start of the next vector every 16 dwords
int mask = 0;
|
1868
|
+
for (int k_group = 0; k_group < QK_K / 16; ++k_group) {
|
1869
|
+
// four groups (64 activation bytes) per `va` vector
int r = k_group >> 2;
|
1870
|
+
__m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
|
1871
|
+
__m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
|
1872
|
+
|
1873
|
+
__m512i vsum = _mm512_setzero_si512();
|
1874
|
+
__m512i hmask = _mm512_set1_epi8(0x3);
|
1875
|
+
|
1876
|
+
__m512i bytes = _mm512_loadu_si512(b_qs);
|
1877
|
+
__m512i hbits = _mm512_loadu_si512(b_qh);
|
1878
|
+
__m512i vb0 = _mm512_and_si512(bytes, lowMask);
|
1879
|
+
__m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
1880
|
+
// bits 0-1 and bits 2-3 of hbits are moved to bit positions 4-5
__m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4);
|
1881
|
+
__m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2);
|
1882
|
+
|
1883
|
+
vb0 = _mm512_add_epi8(vb0, vh0);
|
1884
|
+
vb1 = _mm512_add_epi8(vb1, vh1);
|
1885
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
|
1886
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
|
1887
|
+
b_qs += 64;
|
1888
|
+
|
1889
|
+
va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
|
1890
|
+
va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
|
1891
|
+
|
1892
|
+
bytes = _mm512_loadu_si512(b_qs);
|
1893
|
+
vb0 = _mm512_and_si512(bytes, lowMask);
|
1894
|
+
vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
|
1895
|
+
// bits 4-5 are already in place; bits 6-7 are moved down to bit positions 4-5
vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4));
|
1896
|
+
vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2);
|
1897
|
+
vb0 = _mm512_add_epi8(vb0, vh0);
|
1898
|
+
vb1 = _mm512_add_epi8(vb1, vh1);
|
1899
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
|
1900
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
|
1901
|
+
b_qs += 64;
|
1902
|
+
b_qh += 64;
|
1903
|
+
|
1904
|
+
// B * A - 32 * A
|
1905
|
+
__m512i vmask = _mm512_set1_epi32(k_group);
|
1906
|
+
vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
|
1907
|
+
|
1908
|
+
// vacc += scale * (q8 @ q6)
|
1909
|
+
const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
|
1910
|
+
acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
|
1911
|
+
}
|
1912
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
|
1913
|
+
vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
|
1914
|
+
};
|
1915
|
+
|
1916
|
+
for (int i = 0; i < KB; ++i) {
|
1917
|
+
Unroll<COLS>{}(compute, i);
|
1918
|
+
}
|
1919
|
+
|
1920
|
+
//store to C
|
1921
|
+
auto storec = [&](int col) {
|
1922
|
+
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
|
1923
|
+
};
|
1924
|
+
Unroll<COLS>{}(storec);
|
1925
|
+
}
|
1926
|
+
};
|
1927
|
+
|
1928
|
+
// VNNI kernel for block_q8_K activations against packed block_iq4_xs weights.
// Only one row of A is processed (A[0 * KB + i]); COLS = BLOCK_N / 16 packed column tiles
// of 16 columns each are produced. For every super-block i in [0, KB) and each of the
// QK_K / 32 groups:
//   - each 4-bit index is mapped through the 16-entry table `values256` (the signed table
//     `values128` with 0x80 added to every byte) using _mm512_shuffle_epi8;
//   - the table values are multiplied with the int8 activations (_mm512_dpbusd_epi32);
//   - 128 * (sum of the group's activations) is subtracted to undo the +128 offset;
//   - the group sum is multiplied by the group's int8 scales and accumulated.
// The total is scaled by vd0 * vd1 and added to vc. The result overwrites 16 floats per
// column tile at C + col * 16.
template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
|
1929
|
+
struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> {
|
1930
|
+
static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
|
1931
|
+
|
1932
|
+
constexpr int COLS = BLOCK_N / 16;
|
1933
|
+
const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2;
|
1934
|
+
|
1935
|
+
const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
|
1936
|
+
const char * RESTRICT B = static_cast<const char *>(_B);
|
1937
|
+
|
1938
|
+
// load the 256 bytes from A to 4 avx512 vectors
|
1939
|
+
__m512i va[4];
|
1940
|
+
__m512 vc[COLS];
|
1941
|
+
__m512 vd1;
|
1942
|
+
|
1943
|
+
// packed_B: byte offsets of the sections inside one packed tile (quants, scales, fp16 d)
|
1944
|
+
const int offset_scales = (QK_K / 2) * TILE_N ;
|
1945
|
+
const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N;
|
1946
|
+
|
1947
|
+
// compensation
|
1948
|
+
__m512i vcomp;
|
1949
|
+
|
1950
|
+
const __m256i m128s = _mm256_set1_epi16(128);
|
1951
|
+
const __m512i lowMask = _mm512_set1_epi8(0xF);
|
1952
|
+
|
1953
|
+
// 16-entry signed lookup table repeated in each 128-bit lane; _mm512_set_epi8 takes its
// arguments from the highest byte to the lowest, so index 0 maps to -127 and 15 to 113
const __m512i values128 = _mm512_set_epi8(
|
1954
|
+
113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
|
1955
|
+
113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
|
1956
|
+
113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
|
1957
|
+
113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
|
1958
|
+
);
|
1959
|
+
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
|
1960
|
+
// table shifted by +128 so every entry can be used as the unsigned operand of dpbusd
const __m512i values256 = _mm512_add_epi8(values128, off);
|
1961
|
+
|
1962
|
+
auto loadc = [&](int col) {
|
1963
|
+
vc[col] = _mm512_setzero_ps();
|
1964
|
+
};
|
1965
|
+
Unroll<COLS>{}(loadc);
|
1966
|
+
|
1967
|
+
auto compute = [&](int col, int i) {
|
1968
|
+
if (col == 0) {
|
1969
|
+
// load a
|
1970
|
+
va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
|
1971
|
+
va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
|
1972
|
+
va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
|
1973
|
+
va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
|
1974
|
+
|
1975
|
+
// compensation: 128 * A
// (_mm256_madd_epi16 adds adjacent pairs, so lane g = 128 * (bsums[2g] + bsums[2g + 1]))
|
1976
|
+
const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
|
1977
|
+
vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s));
|
1978
|
+
vd1 = _mm512_set1_ps(A[0 * KB + i].d);
|
1979
|
+
}
|
1980
|
+
|
1981
|
+
// accumulate the quants
|
1982
|
+
__m512i acc = _mm512_setzero_si512();
|
1983
|
+
const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
|
1984
|
+
const char * b_qs = b_ptr;
|
1985
|
+
// running dword index into `va`; _mm512_permutexvar_epi32 only uses its low 4 bits,
// so it wraps to the start of the next vector every 16 dwords
int mask = 0;
|
1986
|
+
for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
|
1987
|
+
// two groups (64 activation bytes) per `va` vector
int r = k_group >> 1;
|
1988
|
+
__m512i vmask = _mm512_set1_epi32(k_group);
|
1989
|
+
__m512i vsum = _mm512_setzero_si512();
|
1990
|
+
for (int k = 0; k < 8; k += 2) {
|
1991
|
+
__m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
|
1992
|
+
__m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
|
1993
|
+
|
1994
|
+
__m512i bytes = _mm512_loadu_si512(b_qs);
|
1995
|
+
__m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask));
|
1996
|
+
__m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
|
1997
|
+
|
1998
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
|
1999
|
+
vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
|
2000
|
+
b_qs += 64;
|
2001
|
+
}
|
2002
|
+
// (B + 128) * A - 128 * A
|
2003
|
+
vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
|
2004
|
+
|
2005
|
+
// vacc += scale * (q8 @ q4)
|
2006
|
+
const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
|
2007
|
+
acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
|
2008
|
+
}
|
2009
|
+
const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
|
2010
|
+
vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
|
2011
|
+
};
|
2012
|
+
|
2013
|
+
for (int i = 0; i < KB; ++i) {
|
2014
|
+
Unroll<COLS>{}(compute, i);
|
2015
|
+
}
|
2016
|
+
|
2017
|
+
//store to C
|
2018
|
+
auto storec = [&](int col) {
|
2019
|
+
_mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
|
2020
|
+
};
|
2021
|
+
Unroll<COLS>{}(storec);
|
2022
|
+
}
|
2023
|
+
};
|
2024
|
+
|
2025
|
+
// Launches tinygemm_kernel_vnni for one strip of NB_SIZE output columns.
//
// The macro expands against names of the calling scope: vec_dot_type, type,
// blck_size, KB, wdata, row_size_A, src0, dst, nb, kTilesN, TILE_SIZE, N,
// nb_start and ldc. The `0 *` terms spell out that row 0 of the quantized A
// buffer and row 0 of dst are addressed.
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                    \
    tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::    \
        apply(KB,                                                               \
              (const char *)wdata + 0 * row_size_A,                             \
              (const char *)src0->data +                                        \
                  PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE),                 \
              (float *) dst->data + 0 * N + nb_start,                           \
              ldc)
|
2030
|
+
|
2031
|
+
template <typename TA, typename TB, typename TC, int BLOCK_K,
|
2032
|
+
typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
|
2033
|
+
void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) {
|
2034
|
+
using packed_B_t = packed_B_type<TB>;
|
2035
|
+
const int TILE_SIZE = get_tile_size<TB>();
|
2036
|
+
const bool need_unpack = do_unpack<TB>::value;
|
2037
|
+
|
2038
|
+
GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
|
2039
|
+
const TA * RESTRICT A = static_cast<const TA *>(_A);
|
2040
|
+
const char * RESTRICT B = static_cast<const char *>(_B);
|
2041
|
+
|
2042
|
+
const int m0 = std::min(M, TILE_M);
|
2043
|
+
const int m1 = std::max(M - TILE_M, 0);
|
2044
|
+
const int lda = KB * sizeof(TA);
|
2045
|
+
//const int ldb = KB * sizeof(TB);
|
2046
|
+
|
2047
|
+
static thread_local packed_B_t Tile0[TILE_N * TILE_K];
|
2048
|
+
static thread_local packed_B_t Tile1[TILE_N * TILE_K];
|
2049
|
+
static thread_local int8_t Tile23[TILE_M * TILE_K];
|
2050
|
+
|
2051
|
+
static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
|
2052
|
+
static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
|
2053
|
+
|
2054
|
+
// double buffering C to interleave avx512 and amx
|
2055
|
+
int32_t * C_cur = TileC0;
|
2056
|
+
int32_t * C_pre = TileC1;
|
2057
|
+
|
2058
|
+
auto Tile4 = [&](int32_t * base) { return base; };
|
2059
|
+
auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; };
|
2060
|
+
auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; };
|
2061
|
+
auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; };
|
2062
|
+
|
2063
|
+
if (M == 2 * TILE_M) {
|
2064
|
+
// i = 0
|
2065
|
+
const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE);
|
2066
|
+
const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE);
|
2067
|
+
if (need_unpack) {
|
2068
|
+
unpack_B<TB>(Tile0, B_blk0);
|
2069
|
+
_tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
|
2070
|
+
} else {
|
2071
|
+
_tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
|
2072
|
+
}
|
2073
|
+
|
2074
|
+
_tile_zero(TMM4);
|
2075
|
+
_tile_loadd(TMM2, A[0].qs, lda);
|
2076
|
+
_tile_dpbssd(TMM4, TMM2, TMM0);
|
2077
|
+
_tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t));
|
2078
|
+
|
2079
|
+
_tile_zero(TMM5);
|
2080
|
+
_tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda);
|
2081
|
+
_tile_dpbssd(TMM5, TMM3, TMM0);
|
2082
|
+
_tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
|
2083
|
+
|
2084
|
+
if (need_unpack) {
|
2085
|
+
unpack_B<TB>(Tile1, B_blk0);
|
2086
|
+
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
|
2087
|
+
} else {
|
2088
|
+
_tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
|
2089
|
+
}
|
2090
|
+
|
2091
|
+
_tile_zero(TMM6);
|
2092
|
+
_tile_dpbssd(TMM6, TMM2, TMM1);
|
2093
|
+
_tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t));
|
2094
|
+
|
2095
|
+
_tile_zero(TMM7);
|
2096
|
+
_tile_dpbssd(TMM7, TMM3, TMM1);
|
2097
|
+
_tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t));
|
2098
|
+
|
2099
|
+
for (int i = 1; i < KB; ++i) {
|
2100
|
+
// index of previous iter
|
2101
|
+
const int ii = i - 1;
|
2102
|
+
const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
|
2103
|
+
const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
|
2104
|
+
GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] {
|
2105
|
+
if (need_unpack) {
|
2106
|
+
unpack_B<TB>(Tile0, B_blk0);
|
2107
|
+
_tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
|
2108
|
+
} else {
|
2109
|
+
_tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
|
2110
|
+
}
|
2111
|
+
_tile_zero(TMM4);
|
2112
|
+
_tile_loadd(TMM2, A[i].qs, lda);
|
2113
|
+
acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
|
2114
|
+
|
2115
|
+
_tile_dpbssd(TMM4, TMM2, TMM0);
|
2116
|
+
_tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
|
2117
|
+
|
2118
|
+
_tile_zero(TMM5);
|
2119
|
+
_tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda);
|
2120
|
+
acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
|
2121
|
+
|
2122
|
+
_tile_dpbssd(TMM5, TMM3, TMM0);
|
2123
|
+
_tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
|
2124
|
+
|
2125
|
+
if (need_unpack) {
|
2126
|
+
unpack_B<TB>(Tile1, B_blk1);
|
2127
|
+
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
|
2128
|
+
} else {
|
2129
|
+
_tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
|
2130
|
+
}
|
2131
|
+
_tile_zero(TMM6);
|
2132
|
+
acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
|
2133
|
+
|
2134
|
+
_tile_dpbssd(TMM6, TMM2, TMM1);
|
2135
|
+
_tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
|
2136
|
+
|
2137
|
+
_tile_zero(TMM7);
|
2138
|
+
acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
|
2139
|
+
|
2140
|
+
_tile_dpbssd(TMM7, TMM3, TMM1);
|
2141
|
+
_tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
|
2142
|
+
|
2143
|
+
std::swap(C_cur, C_pre);
|
2144
|
+
});
|
2145
|
+
}
|
2146
|
+
// final accumulation
|
2147
|
+
{
|
2148
|
+
int ii = KB - 1;
|
2149
|
+
acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
|
2150
|
+
acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
|
2151
|
+
acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
|
2152
|
+
acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
|
2153
|
+
}
|
2154
|
+
} else {
|
2155
|
+
for (int i = 0; i < KB; ++i) {
|
2156
|
+
_tile_zero(TMM4);
|
2157
|
+
_tile_zero(TMM6);
|
2158
|
+
if (m1 != 0) {
|
2159
|
+
_tile_zero(TMM5);
|
2160
|
+
_tile_zero(TMM7);
|
2161
|
+
}
|
2162
|
+
|
2163
|
+
const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
|
2164
|
+
const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
|
2165
|
+
if (need_unpack) {
|
2166
|
+
unpack_B<TB>(Tile0, B_blk0);
|
2167
|
+
_tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
|
2168
|
+
} else {
|
2169
|
+
_tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
|
2170
|
+
}
|
2171
|
+
|
2172
|
+
if (need_unpack) {
|
2173
|
+
unpack_B<TB>(Tile1, B_blk1);
|
2174
|
+
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
|
2175
|
+
} else {
|
2176
|
+
_tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
|
2177
|
+
}
|
2178
|
+
|
2179
|
+
if (m0 == TILE_M) {
|
2180
|
+
_tile_loadd(TMM2, A[i].qs, lda);
|
2181
|
+
} else {
|
2182
|
+
unpack_A(Tile23, &A[i], KB, m0);
|
2183
|
+
_tile_loadd(TMM2, Tile23, TILE_K);
|
2184
|
+
}
|
2185
|
+
|
2186
|
+
_tile_dpbssd(TMM4, TMM2, TMM0);
|
2187
|
+
_tile_dpbssd(TMM6, TMM2, TMM1);
|
2188
|
+
|
2189
|
+
_tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
|
2190
|
+
_tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
|
2191
|
+
|
2192
|
+
GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
|
2193
|
+
acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
|
2194
|
+
acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
|
2195
|
+
});
|
2196
|
+
|
2197
|
+
if (m1 != 0) {
|
2198
|
+
unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1);
|
2199
|
+
_tile_loadd(TMM3, Tile23, TILE_K);
|
2200
|
+
|
2201
|
+
_tile_dpbssd(TMM5, TMM3, TMM0);
|
2202
|
+
_tile_dpbssd(TMM7, TMM3, TMM1);
|
2203
|
+
_tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
|
2204
|
+
_tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
|
2205
|
+
GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
|
2206
|
+
acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
|
2207
|
+
acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
|
2208
|
+
});
|
2209
|
+
}
|
2210
|
+
}
|
2211
|
+
}
|
2212
|
+
return;
|
2213
|
+
}
|
2214
|
+
|
2215
|
+
template <typename TA, typename TB, typename TC, int BLOCK_K,
|
2216
|
+
typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
|
2217
|
+
void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
|
2218
|
+
static_assert(std::is_same<TA, block_q8_K>::value);
|
2219
|
+
const int TILE_SIZE = get_tile_size<TB>();
|
2220
|
+
|
2221
|
+
GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
|
2222
|
+
const TA * RESTRICT A = static_cast<const TA *>(_A);
|
2223
|
+
const char * RESTRICT B = static_cast<const char *>(_B);
|
2224
|
+
|
2225
|
+
const int m0 = std::min(M, TILE_M);
|
2226
|
+
const int m1 = std::max(M - TILE_M, 0);
|
2227
|
+
//const int lda = KB * sizeof(TA);
|
2228
|
+
|
2229
|
+
static thread_local int8_t Tile0[TILE_N * TILE_K];
|
2230
|
+
static thread_local int8_t Tile1[TILE_N * TILE_K];
|
2231
|
+
static thread_local int8_t Tile23[TILE_M * TILE_K];
|
2232
|
+
|
2233
|
+
// mat mul result for each group
|
2234
|
+
static thread_local int32_t Tile4[TILE_M * TILE_N];
|
2235
|
+
static thread_local int32_t Tile5[TILE_M * TILE_N];
|
2236
|
+
static thread_local int32_t Tile6[TILE_M * TILE_N];
|
2237
|
+
static thread_local int32_t Tile7[TILE_M * TILE_N];
|
2238
|
+
|
2239
|
+
// sum of each QK_K block, contains 8 groups, int32
|
2240
|
+
static thread_local int32_t Sumi4[TILE_M * TILE_N];
|
2241
|
+
static thread_local int32_t Sumi5[TILE_M * TILE_N];
|
2242
|
+
static thread_local int32_t Sumi6[TILE_M * TILE_N];
|
2243
|
+
static thread_local int32_t Sumi7[TILE_M * TILE_N];
|
2244
|
+
|
2245
|
+
const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;
|
2246
|
+
for (int i = 0; i < KB; ++i) {
|
2247
|
+
// step 1: accumulate the quants across 8 groups, each group with 32
|
2248
|
+
for (int k = 0; k < QK_K / k_group_size; ++k) {
|
2249
|
+
GGML_DISPATCH_BOOL(k > 0, is_acc, [&] {
|
2250
|
+
_tile_zero(TMM4);
|
2251
|
+
_tile_zero(TMM6);
|
2252
|
+
|
2253
|
+
unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k);
|
2254
|
+
_tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
|
2255
|
+
|
2256
|
+
unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k);
|
2257
|
+
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
|
2258
|
+
|
2259
|
+
unpack_A<TB>(Tile23, &A[i], KB, k, m0);
|
2260
|
+
_tile_loadd(TMM2, Tile23, TILE_K);
|
2261
|
+
|
2262
|
+
_tile_dpbssd(TMM4, TMM2, TMM0);
|
2263
|
+
_tile_dpbssd(TMM6, TMM2, TMM1);
|
2264
|
+
|
2265
|
+
_tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
|
2266
|
+
_tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
|
2267
|
+
|
2268
|
+
scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0);
|
2269
|
+
scale_C<TB, is_acc>(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0);
|
2270
|
+
|
2271
|
+
if (m1 != 0) {
|
2272
|
+
_tile_zero(TMM5);
|
2273
|
+
_tile_zero(TMM7);
|
2274
|
+
|
2275
|
+
unpack_A<TB>(Tile23, &A[TILE_M * KB + i], KB, k, m1);
|
2276
|
+
_tile_loadd(TMM3, Tile23, TILE_K);
|
2277
|
+
|
2278
|
+
_tile_dpbssd(TMM5, TMM3, TMM0);
|
2279
|
+
_tile_dpbssd(TMM7, TMM3, TMM1);
|
2280
|
+
|
2281
|
+
_tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
|
2282
|
+
_tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
|
2283
|
+
|
2284
|
+
scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1);
|
2285
|
+
scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1);
|
2286
|
+
}
|
2287
|
+
});
|
2288
|
+
}
|
2289
|
+
|
2290
|
+
// step 2: accmulate the mins
|
2291
|
+
GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
|
2292
|
+
acc_C<TA, TB, is_acc>::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
|
2293
|
+
acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
|
2294
|
+
if (m1 != 0) {
|
2295
|
+
acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
|
2296
|
+
acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
|
2297
|
+
}
|
2298
|
+
});
|
2299
|
+
}
|
2300
|
+
return;
|
2301
|
+
}
|
2302
|
+
|
2303
|
+
} // anonymous namespace
|
2304
|
+
|
2305
|
+
// get the packed tensor size for quantized weights
|
2306
|
+
size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) {
|
2307
|
+
const enum ggml_type TYPE = tensor->type;
|
2308
|
+
|
2309
|
+
const int K = tensor->ne[0]; // ne0: in_features
|
2310
|
+
const int N = tensor->ne[1]; // ne1: out_features
|
2311
|
+
|
2312
|
+
auto get_tensor_size = [&] {
|
2313
|
+
size_t row_size_B{0};
|
2314
|
+
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
2315
|
+
row_size_B = get_row_size<type, blck_size>(K);
|
2316
|
+
});
|
2317
|
+
return N * row_size_B;
|
2318
|
+
};
|
2319
|
+
|
2320
|
+
if (qtype_has_amx_kernels(TYPE)) {
|
2321
|
+
return get_tensor_size();
|
2322
|
+
} else {
|
2323
|
+
// for f16, bf16 we don't do packing
|
2324
|
+
return ggml_nbytes(tensor);
|
2325
|
+
}
|
2326
|
+
}
|
2327
|
+
|
2328
|
+
// pack weight to vnni format
|
2329
|
+
void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
2330
|
+
|
2331
|
+
size_t alloc_size = ggml_backend_amx_get_alloc_size(tensor);
|
2332
|
+
GGML_ASSERT(alloc_size == size);
|
2333
|
+
|
2334
|
+
const enum ggml_type TYPE = tensor->type;
|
2335
|
+
|
2336
|
+
const int K = tensor->ne[0]; // ne0: in_features
|
2337
|
+
const int N = tensor->ne[1]; // ne1: out_features
|
2338
|
+
|
2339
|
+
#if defined(_OPENMP)
|
2340
|
+
// the buffer ctx is not initialized when .set_tensor is called
|
2341
|
+
int n_threads = omp_get_num_threads();
|
2342
|
+
#else
|
2343
|
+
int n_threads = 1;
|
2344
|
+
#endif
|
2345
|
+
|
2346
|
+
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
2347
|
+
convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K, n_threads);
|
2348
|
+
});
|
2349
|
+
}
|
2350
|
+
|
2351
|
+
// NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)
|
2352
|
+
//
|
2353
|
+
// src0: weight in shape of {N, K}, quantized
|
2354
|
+
// src1: input in shape of {M, K}, float32
|
2355
|
+
// dst: output in shape of {M, N}, float32
|
2356
|
+
//
|
2357
|
+
// the function performs: dst = src1 @ src0.T
|
2358
|
+
//
|
2359
|
+
void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) {
|
2360
|
+
struct ggml_tensor * src0 = dst->src[0];
|
2361
|
+
struct ggml_tensor * src1 = dst->src[1];
|
2362
|
+
|
2363
|
+
const enum ggml_type TYPE = src0->type;
|
2364
|
+
|
2365
|
+
const int n_threads = ctx->n_threads;
|
2366
|
+
|
2367
|
+
// f16 only has avx512 kernels for now,
|
2368
|
+
// amx kernels will be added once 6th gen xeon is released.
|
2369
|
+
const bool is_floating_type = TYPE == GGML_TYPE_F16;
|
2370
|
+
|
2371
|
+
const int M = dst->ne[1];
|
2372
|
+
const int N = dst->ne[0];
|
2373
|
+
const int K = src0->ne[0];
|
2374
|
+
const int ldc = dst->nb[1] / dst->nb[0];
|
2375
|
+
|
2376
|
+
if (is_floating_type) {
|
2377
|
+
constexpr int BLOCK_M = 4;
|
2378
|
+
constexpr int BLOCK_N = 6;
|
2379
|
+
const int MB = div_up(M, BLOCK_M);
|
2380
|
+
const int NB = div_up(N, BLOCK_N);
|
2381
|
+
|
2382
|
+
parallel_for(n_threads, MB * NB, [&](int begin, int end) {
|
2383
|
+
GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
|
2384
|
+
for (int i = begin; i < end; ++i) {
|
2385
|
+
int mb = i / NB;
|
2386
|
+
int nb = i % NB;
|
2387
|
+
|
2388
|
+
int mb_start = mb * BLOCK_M;
|
2389
|
+
int mb_size = std::min(BLOCK_M, M - mb_start);
|
2390
|
+
int nb_start = nb * BLOCK_N;
|
2391
|
+
int nb_size = std::min(BLOCK_N, N - nb_start);
|
2392
|
+
|
2393
|
+
switch (mb_size << 4 | nb_size) {
|
2394
|
+
case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break;
|
2395
|
+
case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break;
|
2396
|
+
case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break;
|
2397
|
+
case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break;
|
2398
|
+
case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break;
|
2399
|
+
case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break;
|
2400
|
+
case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break;
|
2401
|
+
case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break;
|
2402
|
+
case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break;
|
2403
|
+
case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break;
|
2404
|
+
case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break;
|
2405
|
+
case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break;
|
2406
|
+
default: fprintf(stderr, "Unexpected block size!\n");
|
2407
|
+
}
|
2408
|
+
}
|
2409
|
+
});
|
2410
|
+
});
|
2411
|
+
return;
|
2412
|
+
}
|
2413
|
+
|
2414
|
+
// pointer to work space, used convert A from float to quantized type
|
2415
|
+
void * wdata = nullptr;
|
2416
|
+
|
2417
|
+
//TODO: performance improvement: merge quant A
|
2418
|
+
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
2419
|
+
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
|
2420
|
+
const size_t desired_wsize = M * row_size_A;
|
2421
|
+
if (ctx->work_size < desired_wsize) {
|
2422
|
+
ctx->work_data.reset(new char[desired_wsize]);
|
2423
|
+
ctx->work_size = desired_wsize;
|
2424
|
+
}
|
2425
|
+
wdata = ctx->work_data.get();
|
2426
|
+
|
2427
|
+
// Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size
|
2428
|
+
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
|
2429
|
+
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
|
2430
|
+
|
2431
|
+
const float * A_data = static_cast<const float *>(src1->data);
|
2432
|
+
for (int m = 0; m < M; ++m) {
|
2433
|
+
from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
|
2434
|
+
}
|
2435
|
+
});
|
2436
|
+
|
2437
|
+
if (M == 1) {
|
2438
|
+
// MB = 1 and handle 8 tiles in each block
|
2439
|
+
constexpr int kTilesN = 4;
|
2440
|
+
constexpr int BLOCK_N = TILE_N * kTilesN;
|
2441
|
+
const int NB = div_up(N, BLOCK_N);
|
2442
|
+
|
2443
|
+
parallel_for(n_threads, NB, [&](int begin, int end) {
|
2444
|
+
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
2445
|
+
const int KB = K / blck_size;
|
2446
|
+
const int TILE_SIZE = get_tile_size<type>();
|
2447
|
+
const int row_size_A = KB * sizeof(vec_dot_type);
|
2448
|
+
for (int i = begin; i < end; ++i) {
|
2449
|
+
int nb = i;
|
2450
|
+
int nb_start = nb * BLOCK_N;
|
2451
|
+
int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
|
2452
|
+
|
2453
|
+
switch (nb_size) {
|
2454
|
+
//case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break;
|
2455
|
+
case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break;
|
2456
|
+
case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break;
|
2457
|
+
case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break;
|
2458
|
+
case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break;
|
2459
|
+
default: fprintf(stderr, "Unexpected n block size!\n");
|
2460
|
+
}
|
2461
|
+
}
|
2462
|
+
});
|
2463
|
+
});
|
2464
|
+
return;
|
2465
|
+
}
|
2466
|
+
|
2467
|
+
// handle 4 tiles at a tile
|
2468
|
+
constexpr int BLOCK_M = TILE_M * 2;
|
2469
|
+
constexpr int BLOCK_N = TILE_N * 2;
|
2470
|
+
const int MB = div_up(M, BLOCK_M);
|
2471
|
+
const int NB = div_up(N, BLOCK_N);
|
2472
|
+
|
2473
|
+
parallel_for(n_threads, MB * NB, [&](int begin, int end) {
|
2474
|
+
// init tile config for each thread
|
2475
|
+
ggml_tile_config_init();
|
2476
|
+
|
2477
|
+
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
2478
|
+
const int KB = K / blck_size;
|
2479
|
+
const int TILE_SIZE = get_tile_size<type>();
|
2480
|
+
const int row_size_A = KB * sizeof(vec_dot_type);
|
2481
|
+
|
2482
|
+
for (int i = begin; i < end; ++i) {
|
2483
|
+
int mb = i / NB;
|
2484
|
+
int nb = i % NB;
|
2485
|
+
|
2486
|
+
int mb_start = mb * BLOCK_M;
|
2487
|
+
int mb_size = std::min(BLOCK_M, M - mb_start);
|
2488
|
+
int nb_start = nb * BLOCK_N;
|
2489
|
+
int nb_size = BLOCK_N;
|
2490
|
+
|
2491
|
+
tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
|
2492
|
+
mb_size, nb_size, KB,
|
2493
|
+
(const char *)wdata + mb_start * row_size_A,
|
2494
|
+
(const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
|
2495
|
+
(float *) dst->data + mb_start * N + nb_start, ldc);
|
2496
|
+
}
|
2497
|
+
});
|
2498
|
+
});
|
2499
|
+
}
|
2500
|
+
|
2501
|
+
#else // if defined(__AMX_INT8__)
|
2502
|
+
|
2503
|
+
// Stand-in for builds without __AMX_INT8__: does no computation and only
// reports the missing AMX support on stderr.
void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) {
    // neither argument is touched in this configuration
    GGML_UNUSED(ctx);
    GGML_UNUSED(dst);

    fprintf(stderr, "GGML is not compiled with AMX support!\n");
}
|
2509
|
+
|
2510
|
+
#endif // if defined(__AMX_INT8__)
|