llama_cpp 0.1.4 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +36 -0
- data/examples/README.md +60 -0
- data/examples/chat.rb +195 -0
- data/ext/llama_cpp/extconf.rb +26 -1
- data/ext/llama_cpp/llama_cpp.cpp +262 -13
- data/ext/llama_cpp/src/ggml-cuda.cu +2483 -0
- data/ext/llama_cpp/src/ggml-cuda.h +18 -2
- data/ext/llama_cpp/src/ggml-metal.h +64 -0
- data/ext/llama_cpp/src/ggml-metal.m +834 -0
- data/ext/llama_cpp/src/ggml-metal.metal +1436 -0
- data/ext/llama_cpp/src/ggml-opencl.cpp +207 -40
- data/ext/llama_cpp/src/ggml-opencl.h +4 -1
- data/ext/llama_cpp/src/ggml.c +2236 -404
- data/ext/llama_cpp/src/ggml.h +170 -8
- data/ext/llama_cpp/src/k_quants.c +2244 -0
- data/ext/llama_cpp/src/k_quants.h +122 -0
- data/ext/llama_cpp/src/llama-util.h +16 -0
- data/ext/llama_cpp/src/llama.cpp +631 -179
- data/ext/llama_cpp/src/llama.h +51 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +36 -1
- metadata +10 -2
@@ -0,0 +1,2244 @@
|
|
1
|
+
#include "k_quants.h"
|
2
|
+
#include "ggml.h"
|
3
|
+
|
4
|
+
#include <math.h>
|
5
|
+
#include <string.h>
|
6
|
+
#include <assert.h>
|
7
|
+
|
8
|
+
#ifdef __ARM_NEON
|
9
|
+
|
10
|
+
// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
|
11
|
+
//
|
12
|
+
// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
|
13
|
+
//
|
14
|
+
#include <arm_neon.h>
|
15
|
+
|
16
|
+
#else
|
17
|
+
|
18
|
+
#ifdef __wasm_simd128__
|
19
|
+
#include <wasm_simd128.h>
|
20
|
+
#else
|
21
|
+
#ifdef __POWER9_VECTOR__
|
22
|
+
#include <altivec.h>
|
23
|
+
#undef bool
|
24
|
+
#define bool _Bool
|
25
|
+
#else
|
26
|
+
#if defined(_MSC_VER) || defined(__MINGW32__)
|
27
|
+
#include <intrin.h>
|
28
|
+
#else
|
29
|
+
#if !defined(__riscv)
|
30
|
+
#include <immintrin.h>
|
31
|
+
#endif
|
32
|
+
#endif
|
33
|
+
#endif
|
34
|
+
#endif
|
35
|
+
#endif
|
36
|
+
|
37
|
+
// Local min/max helpers; any earlier definitions of these names are replaced.
// Both arguments may be evaluated twice, so pass only side-effect-free expressions.
#undef MIN
#undef MAX
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
|
41
|
+
|
42
|
+
//
|
43
|
+
// 2-6 bit quantization in super-blocks
|
44
|
+
//
|
45
|
+
|
46
|
+
|
47
|
+
//
|
48
|
+
// ===================== Helper functions
|
49
|
+
//
|
50
|
+
// Rounds fval to the nearest integer without a float->int conversion instruction.
//
// Adding 12582912 (1.5 * 2^23) moves the value into [2^23, 2^24), where adjacent floats
// are exactly 1 apart, so the addition itself performs the rounding (current FP rounding
// mode; round-to-nearest-even by default). The integer then sits in the low 23 mantissa
// bits, offset by 0x400000 (2^22).
//
// This only works while the sum stays inside [2^23, 2^24), i.e. for
// -4194304 <= fval <= 4194303. The assert checks both ends of that range; below the lower
// end the result would silently be wrong.
static inline int nearest_int(float fval) {
    assert(fval >= -4194304.f && fval <= 4194303.f);
    float val = fval + 12582912.f;
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;
}
|
56
|
+
|
57
|
+
// Symmetric quantization of x[0..n-1] onto the signed integer grid [-nmax, nmax-1].
// Levels are written to L with a +nmax offset (L[i] in [0, 2*nmax-1]). The return value
// is the scale s such that x[i] ~= s * (L[i] - nmax). The initial scale maps the
// largest-magnitude input to level -nmax, so the returned scale may be negative.
//
// rmse_type:
//   0        no refinement; return the initial scale.
//   odd      refine by least squares with per-element weight x[i]^2.
//   even > 0 refine by unweighted least squares.
//   >= 3     after refining, also try 8 perturbed initial scales and keep the best.
// If every input is zero, L is zero-filled and 0 is returned.
static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
    // Largest-magnitude element, keeping its sign.
    float signed_max = 0;
    float abs_max = 0;
    for (int i = 0; i < n; ++i) {
        const float a = fabsf(x[i]);
        if (a > abs_max) {
            abs_max = a;
            signed_max = x[i];
        }
    }

    if (!abs_max) { // all zero
        for (int i = 0; i < n; ++i) {
            L[i] = 0;
        }
        return 0.f;
    }

    float iscale = -nmax / signed_max;

    if (rmse_type == 0) {
        for (int i = 0; i < n; ++i) {
            const int q = nearest_int(iscale * x[i]);
            L[i] = nmax + MAX(-nmax, MIN(nmax-1, q));
        }
        return 1/iscale;
    }

    const bool squared_weights = (rmse_type % 2) == 1;

    // Initial rounding and its weighted sums: sum_xl = sum(w*x*l), sum_ll = sum(w*l*l).
    float sum_xl = 0;
    float sum_ll = 0;
    for (int i = 0; i < n; ++i) {
        int q = nearest_int(iscale * x[i]);
        q = MAX(-nmax, MIN(nmax-1, q));
        L[i] = q + nmax;
        const float w = squared_weights ? x[i] * x[i] : 1;
        sum_xl += w*x[i]*q;
        sum_ll += w*q*q;
    }

    // Least-squares scale for the current levels. `best` (= sum_xl^2 / sum_ll) is the
    // quantity every later step must increase before it is accepted.
    float scale = sum_xl/sum_ll;
    float best = scale * sum_xl;

    // Phase 1: re-round with the least-squares scale while that helps (at most 3 passes).
    for (int pass = 0; pass < 3; ++pass) {
        iscale = 1/scale;
        float trial_xl = 0;
        float trial_ll = 0;
        bool changed = false;
        for (int i = 0; i < n; ++i) {
            int q = nearest_int(iscale * x[i]);
            q = MAX(-nmax, MIN(nmax-1, q));
            if (q + nmax != L[i]) {
                changed = true;
            }
            const float w = squared_weights ? x[i] * x[i] : 1.f;
            trial_xl += w*x[i]*q;
            trial_ll += w*q*q;
        }
        if (!changed || trial_ll == 0 || trial_xl*trial_xl <= best*trial_ll) {
            break;
        }
        for (int i = 0; i < n; ++i) {
            const int q = nearest_int(iscale * x[i]);
            L[i] = nmax + MAX(-nmax, MIN(nmax-1, q));
        }
        sum_xl = trial_xl;
        sum_ll = trial_ll;
        scale = sum_xl/sum_ll;
        best = scale * sum_xl;
    }

    // Phase 2: coordinate descent, adjusting one element's level at a time (at most 5 sweeps).
    for (int sweep = 0; sweep < 5; ++sweep) {
        int n_changed = 0;
        for (int i = 0; i < n; ++i) {
            const float w = squared_weights ? x[i]*x[i] : 1;
            const int q = L[i] - nmax;
            float trial_xl = sum_xl - w*x[i]*q;
            if (trial_xl > 0) {
                float trial_ll = sum_ll - w*q*q;
                int q_new = nearest_int(x[i] * trial_ll / trial_xl);
                q_new = MAX(-nmax, MIN(nmax-1, q_new));
                if (q_new != q) {
                    trial_xl += w*x[i]*q_new;
                    trial_ll += w*q_new*q_new;
                    if (trial_ll > 0 && trial_xl*trial_xl*sum_ll > sum_xl*sum_xl*trial_ll) {
                        L[i] = nmax + q_new;
                        sum_xl = trial_xl;
                        sum_ll = trial_ll;
                        scale = sum_xl / sum_ll;
                        best = scale * sum_xl;
                        ++n_changed;
                    }
                }
            }
        }
        if (!n_changed) {
            break;
        }
    }

    // Phase 3 (rmse_type >= 3 only): restart from slightly perturbed initial scales.
    if (rmse_type >= 3) {
        for (int step = -4; step <= 4; ++step) {
            if (step == 0) {
                continue;
            }
            iscale = -(nmax + 0.1f*step) / signed_max;
            float trial_xl = 0;
            float trial_ll = 0;
            for (int i = 0; i < n; ++i) {
                int q = nearest_int(iscale * x[i]);
                q = MAX(-nmax, MIN(nmax-1, q));
                const float w = squared_weights ? x[i] * x[i] : 1;
                trial_xl += w*x[i]*q;
                trial_ll += w*q*q;
            }
            if (trial_ll > 0 && trial_xl*trial_xl > best*trial_ll) {
                for (int i = 0; i < n; ++i) {
                    const int q = nearest_int(iscale * x[i]);
                    L[i] = nmax + MAX(-nmax, MIN(nmax-1, q));
                }
                scale = trial_xl/trial_ll;
                best = scale*trial_xl;
            }
        }
    }
    return scale;
}
|
162
|
+
|
163
|
+
// Symmetric quantization of x[0..n-1] onto [-nmax, nmax-1]. Levels are written to L with
// a +nmax offset, and the return value is the scale s with x[i] ~= s * (L[i] - nmax).
//
// With do_rmse, the levels are refined by coordinate descent on an x[i]^2-weighted
// least-squares fit (at most 5 sweeps), and the least-squares scale is returned.
// Otherwise the scale that maps the largest-magnitude input to -nmax is returned.
// If every input is zero, L is zero-filled and 0 is returned.
static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
    float signed_max = 0;
    float abs_max = 0;
    for (int i = 0; i < n; ++i) {
        const float a = fabsf(x[i]);
        if (a > abs_max) {
            abs_max = a;
            signed_max = x[i];
        }
    }
    if (!abs_max) { // all zero
        for (int i = 0; i < n; ++i) {
            L[i] = 0;
        }
        return 0.f;
    }

    const float iscale = -nmax / signed_max;

    if (!do_rmse) {
        for (int i = 0; i < n; ++i) {
            int q = nearest_int(iscale * x[i]);
            q = MAX(-nmax, MIN(nmax-1, q));
            L[i] = q + nmax;
        }
        return 1/iscale;
    }

    // While refining, L holds the signed levels; the +nmax offset is applied at the end.
    float sum_xl = 0;
    float sum_ll = 0;
    for (int i = 0; i < n; ++i) {
        int q = nearest_int(iscale * x[i]);
        q = MAX(-nmax, MIN(nmax-1, q));
        L[i] = q;
        const float w = x[i]*x[i];
        sum_xl += w*x[i]*q;
        sum_ll += w*q*q;
    }

    for (int sweep = 0; sweep < 5; ++sweep) {
        int n_changed = 0;
        for (int i = 0; i < n; ++i) {
            const float w = x[i]*x[i];
            float trial_xl = sum_xl - w*x[i]*L[i];
            if (trial_xl > 0) {
                float trial_ll = sum_ll - w*L[i]*L[i];
                int q_new = nearest_int(x[i] * trial_ll / trial_xl);
                q_new = MAX(-nmax, MIN(nmax-1, q_new));
                if (q_new != L[i]) {
                    trial_xl += w*x[i]*q_new;
                    trial_ll += w*q_new*q_new;
                    if (trial_ll > 0 && trial_xl*trial_xl*sum_ll > sum_xl*sum_xl*trial_ll) {
                        L[i] = q_new;
                        sum_xl = trial_xl;
                        sum_ll = trial_ll;
                        ++n_changed;
                    }
                }
            }
        }
        if (!n_changed) {
            break;
        }
    }

    for (int i = 0; i < n; ++i) {
        L[i] += nmax;
    }
    return sum_xl / sum_ll;
}
|
221
|
+
|
222
|
+
// Asymmetric quantization of x[0..n-1]. Finds levels L[i] in [0, nmax], a scale, and an
// offset min (never positive) such that x[i] ~= scale*L[i] + min.
//
// It alternates between re-rounding the levels and re-fitting the scale (least squares)
// and min (mean residual). It stops after ntry passes, or earlier once a pass leaves
// every level unchanged.
//
// The contents of L on entry are ignored. On return, *the_min holds -min (>= 0) and the
// scale is returned. If all inputs are equal, L is zero-filled, *the_min is 0 and 0 is
// returned.
static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) {
    float min = x[0];
    float max = x[0];
    for (int i = 1; i < n; ++i) {
        if (x[i] < min) min = x[i];
        if (x[i] > max) max = x[i];
    }
    if (max == min) {
        for (int i = 0; i < n; ++i) L[i] = 0;
        *the_min = 0;
        return 0.f;
    }
    if (min > 0) min = 0;
    float iscale = nmax/(max - min);
    float scale = 1/iscale;
    for (int itry = 0; itry < ntry; ++itry) {
        float sumlx = 0; int suml2 = 0;
        bool did_change = false;
        for (int i = 0; i < n; ++i) {
            int l = nearest_int(iscale*(x[i] - min));
            l = MAX(0, MIN(nmax, l));
            // On the first pass L still holds whatever the caller left in it. Callers in this
            // file pass a stack buffer that is uninitialized for the first block and holds the
            // previous block's levels afterwards. Always write and count the first pass, so
            // those contents can neither be read nor cut the refinement short.
            if (itry == 0 || l != L[i]) {
                L[i] = l;
                did_change = true;
            }
            sumlx += (x[i] - min)*l;
            suml2 += l*l;
        }
        scale = sumlx/suml2;
        float sum = 0;
        for (int i = 0; i < n; ++i) {
            sum += x[i] - scale*L[i];
        }
        min = sum/n;
        if (min > 0) min = 0;
        iscale = 1/scale;
        if (!did_change) break;
    }
    *the_min = -min;
    return scale;
}
|
263
|
+
|
264
|
+
// Unpacks the 6-bit scale (*d) and 6-bit min (*m) of sub-block j from the packed array q.
// Callers in this file pass j in [0, QK_K/32).
//   j < 4 : the low 6 bits of q[j] (scale) and of q[j+4] (min).
//   j >= 4: the low / high nibble of q[j+4] gives bits 0..3 of the scale / min; the top
//           2 bits of q[j-4] / q[j] supply bits 4..5.
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
    if (j >= 4) {
        const uint8_t shared = q[j + 4];
        *d = (uint8_t)((shared & 0x0F) | ((q[j - 4] >> 6) << 4));
        *m = (uint8_t)((shared >> 4) | ((q[j] >> 6) << 4));
        return;
    }
    *d = q[j] & 0x3F;
    *m = q[j + 4] & 0x3F;
}
|
272
|
+
|
273
|
+
//========================= 2-bit (de)-quantization
|
274
|
+
|
275
|
+
// Reference (scalar) quantization of k floats (k a multiple of QK_K) into Q2_K blocks.
//
// Each QK_K-value super-block is split into QK_K/16 sub-blocks of 16 values. Every
// sub-block gets 2-bit levels (0..3), a 4-bit scale code (low nibble of scales[j]) and a
// 4-bit min code (high nibble). Both codes are relative to the super-block's fp16
// factors d and dmin.
void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    const float q4scale = 15.f; // largest 4-bit scale/min code

    uint8_t L[QK_K];
    float mins[QK_K/16];
    float scales[QK_K/16];

    for (int ib = 0; ib < nb; ++ib) {
        const float * xb = x + ib*QK_K;
        block_q2_K * out = &y[ib];

        // Per-sub-block fit; track the largest scale and min.
        float max_scale = 0; // as we are deducting the min, scales are always positive
        float max_min = 0;
        for (int j = 0; j < QK_K/16; ++j) {
            scales[j] = make_qkx1_quants(16, 3, xb + 16*j, L + 16*j, &mins[j], 5);
            max_scale = scales[j] > max_scale ? scales[j] : max_scale;
            max_min = mins[j] > max_min ? mins[j] : max_min;
        }

        // 4-bit scale codes go into the low nibbles.
        if (max_scale > 0) {
            const float inv = q4scale/max_scale;
            for (int j = 0; j < QK_K/16; ++j) {
                out->scales[j] = nearest_int(inv*scales[j]);
            }
            out->d = ggml_fp32_to_fp16(max_scale/q4scale);
        } else {
            for (int j = 0; j < QK_K/16; ++j) {
                out->scales[j] = 0;
            }
            out->d = ggml_fp32_to_fp16(0.f);
        }

        // 4-bit min codes go into the high nibbles.
        if (max_min > 0) {
            const float inv = q4scale/max_min;
            for (int j = 0; j < QK_K/16; ++j) {
                const int l = nearest_int(inv*mins[j]);
                out->scales[j] |= (l << 4);
            }
            out->dmin = ggml_fp32_to_fp16(max_min/q4scale);
        } else {
            out->dmin = ggml_fp32_to_fp16(0.f);
        }

        // Re-round every value against the scale/min exactly as they will be decoded.
        const float d_all = ggml_fp16_to_fp32(out->d);
        const float dmin_all = ggml_fp16_to_fp32(out->dmin);
        for (int j = 0; j < QK_K/16; ++j) {
            const float d = d_all * (out->scales[j] & 0xF);
            if (!d) continue;
            const float dm = dmin_all * (out->scales[j] >> 4);
            for (int ii = 0; ii < 16; ++ii) {
                int l = nearest_int((xb[16*j + ii] + dm)/d);
                l = MAX(0, MIN(3, l));
                L[16*j + ii] = l;
            }
        }

        // Pack 4 levels per byte: byte l of each 32-byte group holds the values at offsets
        // l, l+32, l+64 and l+96 of the corresponding 128-value chunk.
        for (int j = 0; j < QK_K; j += 128) {
            uint8_t * dst = out->qs + j/4;
            const uint8_t * src = L + j;
            for (int l = 0; l < 32; ++l) {
                dst[l] = src[l] | (src[l + 32] << 2) | (src[l + 64] << 4) | (src[l + 96] << 6);
            }
        }
    }
}
|
343
|
+
|
344
|
+
void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
|
345
|
+
assert(k % QK_K == 0);
|
346
|
+
const int nb = k / QK_K;
|
347
|
+
|
348
|
+
for (int i = 0; i < nb; i++) {
|
349
|
+
|
350
|
+
const float d = ggml_fp16_to_fp32(x[i].d);
|
351
|
+
const float min = ggml_fp16_to_fp32(x[i].dmin);
|
352
|
+
|
353
|
+
const uint8_t * q = x[i].qs;
|
354
|
+
|
355
|
+
int is = 0;
|
356
|
+
float dl, ml;
|
357
|
+
for (int n = 0; n < QK_K; n += 128) {
|
358
|
+
int shift = 0;
|
359
|
+
for (int j = 0; j < 4; ++j) {
|
360
|
+
|
361
|
+
uint8_t sc = x[i].scales[is++];
|
362
|
+
dl = d * (sc & 0xF); ml = min * (sc >> 4);
|
363
|
+
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
|
364
|
+
|
365
|
+
sc = x[i].scales[is++];
|
366
|
+
dl = d * (sc & 0xF); ml = min * (sc >> 4);
|
367
|
+
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
|
368
|
+
|
369
|
+
shift += 2;
|
370
|
+
}
|
371
|
+
q += 32;
|
372
|
+
}
|
373
|
+
|
374
|
+
}
|
375
|
+
}
|
376
|
+
|
377
|
+
void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
|
378
|
+
quantize_row_q2_K_reference(x, vy, k);
|
379
|
+
}
|
380
|
+
|
381
|
+
size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
|
382
|
+
const int nb = k / QK_K;
|
383
|
+
|
384
|
+
// TODO - collect histograms - although, at a second thought, I don't really care about them
|
385
|
+
(void)hist;
|
386
|
+
|
387
|
+
for (int j = 0; j < nb; j += k) {
|
388
|
+
block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
|
389
|
+
quantize_row_q2_K_reference(src + j, y, k);
|
390
|
+
}
|
391
|
+
return (n/QK_K*sizeof(block_q2_K));
|
392
|
+
}
|
393
|
+
|
394
|
+
//========================= 3-bit (de)-quantization
|
395
|
+
|
396
|
+
// Reference (scalar) quantization of k floats (k a multiple of QK_K) into Q3_K blocks.
//
// Each super-block has QK_K/16 sub-blocks of 16 values. Levels are 3 bits (-4..3,
// stored +4), and each sub-block has a signed 6-bit scale relative to the fp16 factor d.
//
// Storage layout:
//   - the low 2 bits of each level go to qs (4 per byte);
//   - the third bit goes to hmask;
//   - the 6-bit scales (+32 bias) are split into a low nibble (bytes 0..7 of scales)
//     and 2 high bits (bytes 8..11).
void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    int8_t L[QK_K];
    float scales[QK_K / 16];

    for (int ib = 0; ib < nb; ++ib) {
        const float * xb = x + ib*QK_K;
        block_q3_K * out = &y[ib];

        // Per-sub-block fit; remember the scale with the largest magnitude (sign kept).
        float max_scale = 0;
        float amax = 0;
        for (int j = 0; j < QK_K/16; ++j) {
            scales[j] = make_q3_quants(16, 4, xb + 16*j, L + 16*j, true);
            const float a = fabsf(scales[j]);
            if (a > amax) {
                amax = a;
                max_scale = scales[j];
            }
        }

        // Encode the sub-block scales as 6-bit codes with a +32 bias.
        memset(out->scales, 0, 12);
        if (max_scale) {
            const float iscale = -32.f/max_scale;
            for (int j = 0; j < QK_K/16; ++j) {
                const int8_t raw = nearest_int(iscale*scales[j]);
                const int8_t biased = MAX(-32, MIN(31, raw)) + 32; // 0..63
                const uint8_t low4 = biased & 0xF;
                const uint8_t high2 = biased >> 4;
                if (j < 8) {
                    out->scales[j] = low4;
                } else {
                    out->scales[j-8] |= (low4 << 4);
                }
                out->scales[8 + j%4] |= (high2 << (2*(j/4)));
            }
            out->d = ggml_fp32_to_fp16(1/iscale);
        } else {
            out->d = ggml_fp32_to_fp16(0.f);
        }

        // Re-round every value against the scales exactly as they will be decoded.
        const float d_all = ggml_fp16_to_fp32(out->d);
        for (int j = 0; j < QK_K/16; ++j) {
            const int low4 = j < 8 ? out->scales[j] & 0xF : out->scales[j-8] >> 4;
            const int high2 = (out->scales[8 + j%4] >> (2*(j/4))) & 3;
            const int sc = (low4 | (high2 << 4)) - 32;
            const float d = d_all * sc;
            if (!d) {
                continue;
            }
            for (int ii = 0; ii < 16; ++ii) {
                int l = nearest_int(xb[16*j + ii]/d);
                l = MAX(-4, MIN(3, l));
                L[16*j + ii] = l + 4;
            }
        }

        // The third bit of value j goes to bit j / (QK_K/8) of hmask[j % (QK_K/8)].
        memset(out->hmask, 0, QK_K/8);
        for (int j = 0; j < QK_K; ++j) {
            if (L[j] > 3) {
                out->hmask[j % (QK_K/8)] |= (uint8_t)(1u << (j / (QK_K/8)));
                L[j] -= 4;
            }
        }

        // Pack the remaining 2-bit levels, 4 per byte (offsets l, l+32, l+64, l+96).
        for (int j = 0; j < QK_K; j += 128) {
            uint8_t * dst = out->qs + j/4;
            for (int l = 0; l < 32; ++l) {
                dst[l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
            }
        }
    }
}
|
471
|
+
|
472
|
+
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
|
473
|
+
assert(k % QK_K == 0);
|
474
|
+
assert(QK_K == 256);
|
475
|
+
const int nb = k / QK_K;
|
476
|
+
|
477
|
+
const uint32_t kmask1 = 0x03030303;
|
478
|
+
const uint32_t kmask2 = 0x0f0f0f0f;
|
479
|
+
|
480
|
+
uint32_t aux[4];
|
481
|
+
const int8_t * scales = (const int8_t*)aux;
|
482
|
+
|
483
|
+
for (int i = 0; i < nb; i++) {
|
484
|
+
|
485
|
+
const float d_all = ggml_fp16_to_fp32(x[i].d);
|
486
|
+
|
487
|
+
const uint8_t * restrict q = x[i].qs;
|
488
|
+
const uint8_t * restrict hm = x[i].hmask;
|
489
|
+
uint8_t m = 1;
|
490
|
+
|
491
|
+
memcpy(aux, x[i].scales, 12);
|
492
|
+
uint32_t tmp = aux[2];
|
493
|
+
aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
494
|
+
aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
495
|
+
aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
|
496
|
+
aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
497
|
+
|
498
|
+
int is = 0;
|
499
|
+
float dl;
|
500
|
+
for (int n = 0; n < QK_K; n += 128) {
|
501
|
+
int shift = 0;
|
502
|
+
for (int j = 0; j < 4; ++j) {
|
503
|
+
|
504
|
+
dl = d_all * (scales[is++] - 32);
|
505
|
+
for (int l = 0; l < 16; ++l) {
|
506
|
+
*y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
|
507
|
+
}
|
508
|
+
|
509
|
+
dl = d_all * (scales[is++] - 32);
|
510
|
+
for (int l = 0; l < 16; ++l) {
|
511
|
+
*y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
|
512
|
+
}
|
513
|
+
|
514
|
+
shift += 2;
|
515
|
+
m <<= 1;
|
516
|
+
}
|
517
|
+
q += 32;
|
518
|
+
}
|
519
|
+
|
520
|
+
}
|
521
|
+
}
|
522
|
+
|
523
|
+
void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
|
524
|
+
quantize_row_q3_K_reference(x, vy, k);
|
525
|
+
}
|
526
|
+
|
527
|
+
size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
|
528
|
+
const int nb = k / QK_K;
|
529
|
+
|
530
|
+
// TODO - collect histograms - although, at a second thought, I don't really care about them
|
531
|
+
(void)hist;
|
532
|
+
|
533
|
+
for (int j = 0; j < nb; j += k) {
|
534
|
+
block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
|
535
|
+
quantize_row_q3_K_reference(src + j, y, k);
|
536
|
+
}
|
537
|
+
return (n/QK_K*sizeof(block_q3_K));
|
538
|
+
}
|
539
|
+
|
540
|
+
// ====================== 4-bit (de)-quantization
|
541
|
+
|
542
|
+
// Reference (scalar) quantization of k floats (k a multiple of QK_K) into Q4_K blocks.
//
// Each super-block has QK_K/32 sub-blocks of 32 values. Levels are 4 bits (0..15), and
// each sub-block has a 6-bit scale code and a 6-bit min code, relative to the fp16
// factors d and dmin. The scale/min codes are packed into the 12-byte scales array in
// the layout read by get_scale_min_k4.
void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    uint8_t L[QK_K];
    float mins[QK_K/32];
    float scales[QK_K/32];

    for (int ib = 0; ib < nb; ++ib) {
        const float * xb = x + ib*QK_K;
        block_q4_K * out = &y[ib];

        float max_scale = 0; // as we are deducting the min, scales are always positive
        float max_min = 0;
        for (int j = 0; j < QK_K/32; ++j) {
            scales[j] = make_qkx1_quants(32, 15, xb + 32*j, L + 32*j, &mins[j], 5);
            max_scale = scales[j] > max_scale ? scales[j] : max_scale;
            max_min = mins[j] > max_min ? mins[j] : max_min;
        }

        // 6-bit codes for every scale and min.
        const float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
        const float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
        for (int j = 0; j < QK_K/32; ++j) {
            uint8_t ls = nearest_int(inv_scale*scales[j]);
            uint8_t lm = nearest_int(inv_min*mins[j]);
            ls = MIN(63, ls);
            lm = MIN(63, lm);
            if (j >= 4) {
                // The low nibbles share byte j+4; the top 2 bits go into bits 6-7 of
                // byte j-4 (scale) and byte j (min).
                out->scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
                out->scales[j-4] |= ((ls >> 4) << 6);
                out->scales[j] |= ((lm >> 4) << 6);
            } else {
                out->scales[j] = ls;
                out->scales[j+4] = lm;
            }
        }
        out->d = ggml_fp32_to_fp16(max_scale/63.f);
        out->dmin = ggml_fp32_to_fp16(max_min/63.f);

        // Re-round every value against the decoded scale/min of its sub-block.
        const float d_all = ggml_fp16_to_fp32(out->d);
        const float dmin_all = ggml_fp16_to_fp32(out->dmin);
        for (int j = 0; j < QK_K/32; ++j) {
            uint8_t sc, m;
            get_scale_min_k4(j, out->scales, &sc, &m);
            const float d = d_all * sc;
            if (!d) continue;
            const float dm = dmin_all * m;
            for (int ii = 0; ii < 32; ++ii) {
                int l = nearest_int((xb[32*j + ii] + dm)/d);
                l = MAX(0, MIN(15, l));
                L[32*j + ii] = l;
            }
        }

        // Two levels per byte: the low nibble comes from the first 32 values of each
        // 64-value chunk, the high nibble from the next 32.
        for (int j = 0; j < QK_K; j += 64) {
            uint8_t * dst = out->qs + j/2;
            for (int l = 0; l < 32; ++l) {
                dst[l] = L[j + l] | (L[j + l + 32] << 4);
            }
        }
    }
}
|
606
|
+
|
607
|
+
void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
|
608
|
+
assert(k % QK_K == 0);
|
609
|
+
const int nb = k / QK_K;
|
610
|
+
|
611
|
+
for (int i = 0; i < nb; i++) {
|
612
|
+
|
613
|
+
const float d = ggml_fp16_to_fp32(x[i].d);
|
614
|
+
const float min = ggml_fp16_to_fp32(x[i].dmin);
|
615
|
+
|
616
|
+
const uint8_t * q = x[i].qs;
|
617
|
+
|
618
|
+
int is = 0;
|
619
|
+
uint8_t sc, m;
|
620
|
+
for (int j = 0; j < QK_K; j += 64) {
|
621
|
+
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
|
622
|
+
const float d1 = d * sc; const float m1 = min * m;
|
623
|
+
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
|
624
|
+
const float d2 = d * sc; const float m2 = min * m;
|
625
|
+
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
|
626
|
+
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
627
|
+
q += 32; is += 2;
|
628
|
+
}
|
629
|
+
|
630
|
+
}
|
631
|
+
}
|
632
|
+
|
633
|
+
void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
|
634
|
+
assert(k % QK_K == 0);
|
635
|
+
block_q4_K * restrict y = vy;
|
636
|
+
quantize_row_q4_K_reference(x, y, k);
|
637
|
+
}
|
638
|
+
|
639
|
+
size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
|
640
|
+
assert(k % QK_K == 0);
|
641
|
+
const int nb = k / QK_K;
|
642
|
+
(void)hist; // TODO: collect histograms
|
643
|
+
for (int j = 0; j < nb; j += k) {
|
644
|
+
block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
|
645
|
+
quantize_row_q4_K_reference(src + j, y, k);
|
646
|
+
}
|
647
|
+
return (n/QK_K*sizeof(block_q4_K));
|
648
|
+
}
|
649
|
+
|
650
|
+
// ====================== 5-bit (de)-quantization
|
651
|
+
|
652
|
+
// Reference (scalar) quantization of k floats (k a multiple of QK_K) into Q5_K blocks.
//
// Uses the same scale/min scheme as Q4_K: 6-bit codes packed in the layout read by
// get_scale_min_k4, with fp16 d/dmin. Levels are 5 bits (0..31): the low 4 bits go to
// qs and the fifth bit to qh.
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    uint8_t L[QK_K];
    float mins[QK_K/32];
    float scales[QK_K/32];

    for (int ib = 0; ib < nb; ++ib) {
        const float * xb = x + ib*QK_K;
        block_q5_K * out = &y[ib];

        float max_scale = 0; // as we are deducting the min, scales are always positive
        float max_min = 0;
        for (int j = 0; j < QK_K/32; ++j) {
            scales[j] = make_qkx1_quants(32, 31, xb + 32*j, L + 32*j, &mins[j], 5);
            max_scale = scales[j] > max_scale ? scales[j] : max_scale;
            max_min = mins[j] > max_min ? mins[j] : max_min;
        }

        // 6-bit codes for every scale and min.
        const float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
        const float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
        for (int j = 0; j < QK_K/32; ++j) {
            uint8_t ls = nearest_int(inv_scale*scales[j]);
            uint8_t lm = nearest_int(inv_min*mins[j]);
            ls = MIN(63, ls);
            lm = MIN(63, lm);
            if (j >= 4) {
                out->scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
                out->scales[j-4] |= ((ls >> 4) << 6);
                out->scales[j] |= ((lm >> 4) << 6);
            } else {
                out->scales[j] = ls;
                out->scales[j+4] = lm;
            }
        }
        out->d = ggml_fp32_to_fp16(max_scale/63.f);
        out->dmin = ggml_fp32_to_fp16(max_min/63.f);

        // Re-round every value against the decoded scale/min of its sub-block.
        const float d_all = ggml_fp16_to_fp32(out->d);
        const float dmin_all = ggml_fp16_to_fp32(out->dmin);
        for (int j = 0; j < QK_K/32; ++j) {
            uint8_t sc, m;
            get_scale_min_k4(j, out->scales, &sc, &m);
            const float d = d_all * sc;
            if (!d) continue;
            const float dm = dmin_all * m;
            for (int ii = 0; ii < 32; ++ii) {
                int l = nearest_int((xb[32*j + ii] + dm)/d);
                l = MAX(0, MIN(31, l));
                L[32*j + ii] = l;
            }
        }

        // For the c-th 64-value chunk, the fifth bit of value j goes to bit 2c of qh[j] and
        // that of value j+32 to bit 2c+1. The low nibbles are packed two per byte into qs.
        memset(out->qh, 0, QK_K/8);
        uint8_t * ql = out->qs;
        uint8_t bit_lo = 1, bit_hi = 2;
        for (int n = 0; n < QK_K; n += 64) {
            for (int j = 0; j < 32; ++j) {
                int l1 = L[n + j];
                int l2 = L[n + j + 32];
                if (l1 > 15) {
                    l1 -= 16;
                    out->qh[j] |= bit_lo;
                }
                if (l2 > 15) {
                    l2 -= 16;
                    out->qh[j] |= bit_hi;
                }
                ql[j] = l1 | (l2 << 4);
            }
            bit_lo <<= 2;
            bit_hi <<= 2;
            ql += 32;
        }
    }
}
|
733
|
+
|
734
|
+
void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
|
735
|
+
assert(k % QK_K == 0);
|
736
|
+
const int nb = k / QK_K;
|
737
|
+
|
738
|
+
for (int i = 0; i < nb; i++) {
|
739
|
+
|
740
|
+
const float d = ggml_fp16_to_fp32(x[i].d);
|
741
|
+
const float min = ggml_fp16_to_fp32(x[i].dmin);
|
742
|
+
|
743
|
+
const uint8_t * ql = x[i].qs;
|
744
|
+
const uint8_t * qh = x[i].qh;
|
745
|
+
|
746
|
+
int is = 0;
|
747
|
+
uint8_t sc, m;
|
748
|
+
uint8_t u1 = 1, u2 = 2;
|
749
|
+
for (int j = 0; j < QK_K; j += 64) {
|
750
|
+
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
|
751
|
+
const float d1 = d * sc; const float m1 = min * m;
|
752
|
+
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
|
753
|
+
const float d2 = d * sc; const float m2 = min * m;
|
754
|
+
for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
|
755
|
+
for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
|
756
|
+
ql += 32; is += 2;
|
757
|
+
u1 <<= 2; u2 <<= 2;
|
758
|
+
}
|
759
|
+
}
|
760
|
+
}
|
761
|
+
|
762
|
+
void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
|
763
|
+
assert(k % QK_K == 0);
|
764
|
+
block_q5_K * restrict y = vy;
|
765
|
+
quantize_row_q5_K_reference(x, y, k);
|
766
|
+
}
|
767
|
+
|
768
|
+
size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
|
769
|
+
assert(k % QK_K == 0);
|
770
|
+
const int nb = k / QK_K;
|
771
|
+
(void)hist;
|
772
|
+
for (int j = 0; j < nb; j += k) {
|
773
|
+
block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
|
774
|
+
quantize_row_q5_K_reference(src + j, y, k);
|
775
|
+
}
|
776
|
+
return (n/QK_K*sizeof(block_q5_K));
|
777
|
+
}
|
778
|
+
|
779
|
+
// ====================== 6-bit (de)-quantization
|
780
|
+
|
781
|
+
// Reference (scalar) quantization of k floats (k a multiple of QK_K) into Q6_K blocks.
//
// Each super-block has QK_K/16 sub-blocks of 16 values. Levels are 6 bits (-32..31,
// stored +32), and each sub-block has a signed 8-bit scale relative to the fp16 factor d.
// The low 4 bits of each level go to ql and the top 2 bits to qh.
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    int8_t L[QK_K];
    float scales[QK_K/16];

    for (int i = 0; i < nb; i++) {

        float max_scale = 0;
        float max_abs_scale = 0;

        for (int ib = 0; ib < QK_K/16; ++ib) {

            const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
            scales[ib] = scale;

            const float abs_scale = fabsf(scale);
            if (abs_scale > max_abs_scale) {
                max_abs_scale = abs_scale;
                max_scale = scale;
            }

        }

        if (!max_abs_scale) {
            // Every sub-block scale is zero (e.g. an all-zero super-block). Dividing by
            // max_scale below would give an infinite iscale and feed NaN into nearest_int.
            // Store an all-zero block instead; it dequantizes to zeros.
            memset(&y[i], 0, sizeof(block_q6_K));
            y[i].d = ggml_fp32_to_fp16(0.f);
            x += QK_K;
            continue;
        }

        float iscale = -128.f/max_scale;
        y[i].d = ggml_fp32_to_fp16(1/iscale);
        for (int ib = 0; ib < QK_K/16; ++ib) {
            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
        }

        // Re-round every value against the sub-block scale exactly as it will be decoded.
        for (int j = 0; j < QK_K/16; ++j) {
            float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
            if (!d) {
                continue;
            }
            for (int ii = 0; ii < 16; ++ii) {
                int l = nearest_int(x[16*j + ii]/d);
                l = MAX(-32, MIN(31, l));
                L[16*j + ii] = l + 32;
            }
        }

        // Low nibbles: ql[l] holds values l / l+64, ql[l+32] holds l+32 / l+96.
        // High 2-bit pairs of values l, l+32, l+64, l+96 share qh[l].
        uint8_t * restrict ql = y[i].ql;
        uint8_t * restrict qh = y[i].qh;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                const uint8_t q1 = L[j + l +  0] & 0xF;
                const uint8_t q2 = L[j + l + 32] & 0xF;
                const uint8_t q3 = L[j + l + 64] & 0xF;
                const uint8_t q4 = L[j + l + 96] & 0xF;
                ql[l+ 0] = q1 | (q3 << 4);
                ql[l+32] = q2 | (q4 << 4);
                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
            }
            ql += 64;
            qh += 32;
        }

        x += QK_K;

    }
}
|
844
|
+
|
845
|
+
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
|
846
|
+
assert(k % QK_K == 0);
|
847
|
+
const int nb = k / QK_K;
|
848
|
+
|
849
|
+
for (int i = 0; i < nb; i++) {
|
850
|
+
|
851
|
+
const float d = ggml_fp16_to_fp32(x[i].d);
|
852
|
+
|
853
|
+
const uint8_t * restrict ql = x[i].ql;
|
854
|
+
const uint8_t * restrict qh = x[i].qh;
|
855
|
+
const int8_t * restrict sc = x[i].scales;
|
856
|
+
|
857
|
+
for (int n = 0; n < QK_K; n += 128) {
|
858
|
+
for (int l = 0; l < 32; ++l) {
|
859
|
+
int is = l/16;
|
860
|
+
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
861
|
+
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
862
|
+
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
863
|
+
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
864
|
+
y[l + 0] = d * sc[is + 0] * q1;
|
865
|
+
y[l + 32] = d * sc[is + 2] * q2;
|
866
|
+
y[l + 64] = d * sc[is + 4] * q3;
|
867
|
+
y[l + 96] = d * sc[is + 6] * q4;
|
868
|
+
}
|
869
|
+
y += 128;
|
870
|
+
ql += 64;
|
871
|
+
qh += 32;
|
872
|
+
sc += 8;
|
873
|
+
}
|
874
|
+
|
875
|
+
}
|
876
|
+
}
|
877
|
+
|
878
|
+
void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
|
879
|
+
assert(k % QK_K == 0);
|
880
|
+
block_q6_K * restrict y = vy;
|
881
|
+
quantize_row_q6_K_reference(x, y, k);
|
882
|
+
}
|
883
|
+
|
884
|
+
size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
|
885
|
+
assert(k % QK_K == 0);
|
886
|
+
const int nb = k / QK_K;
|
887
|
+
|
888
|
+
(void)hist; // TODO
|
889
|
+
|
890
|
+
for (int j = 0; j < nb; j += k) {
|
891
|
+
block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
|
892
|
+
quantize_row_q6_K_reference(src + j, y, k);
|
893
|
+
}
|
894
|
+
return (n/QK_K*sizeof(block_q6_K));
|
895
|
+
}
|
896
|
+
|
897
|
+
//===================================== Q8_K ==============================================
|
898
|
+
|
899
|
+
// Reference quantization of k floats (k a multiple of QK_K) into Q8_K blocks.
//
// Each block stores:
//   - 8-bit levels;
//   - one float scale d, chosen so the largest-magnitude value maps to -128;
//   - bsums[j], the sum of the levels in the j-th group of 16.
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    for (int i = 0; i < nb; i++) {

        float max = 0;
        float amax = 0;
        for (int j = 0; j < QK_K; ++j) {
            float ax = fabsf(x[j]);
            if (ax > amax) {
                amax = ax; max = x[j];
            }
        }
        if (!amax) {
            // All-zero block: write every field, including bsums, so that no uninitialized
            // or stale group sums are left in the output.
            y[i].d = 0;
            memset(y[i].qs, 0, QK_K);
            for (int j = 0; j < QK_K/16; ++j) {
                y[i].bsums[j] = 0;
            }
            x += QK_K;
            continue;
        }
        const float iscale = -128.f/max;
        for (int j = 0; j < QK_K; ++j) {
            int v = nearest_int(iscale*x[j]);
            y[i].qs[j] = MIN(127, v);
        }
        for (int j = 0; j < QK_K/16; ++j) {
            int sum = 0;
            for (int ii = 0; ii < 16; ++ii) {
                sum += y[i].qs[j*16 + ii];
            }
            y[i].bsums[j] = sum;
        }
        y[i].d = 1/iscale;
        x += QK_K;
    }
}
|
935
|
+
|
936
|
+
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
|
937
|
+
assert(k % QK_K == 0);
|
938
|
+
const int nb = k / QK_K;
|
939
|
+
|
940
|
+
for (int i = 0; i < nb; i++) {
|
941
|
+
for (int j = 0; j < QK_K; ++j) {
|
942
|
+
*y++ = x[i].d * x[i].qs[j];
|
943
|
+
}
|
944
|
+
}
|
945
|
+
}
|
946
|
+
|
947
|
+
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
|
948
|
+
quantize_row_q8_K_reference(x, y, k);
|
949
|
+
}
|
950
|
+
|
951
|
+
//===================================== Dot products =================================
|
952
|
+
|
953
|
+
//
|
954
|
+
// Helper functions
|
955
|
+
//
|
956
|
+
#if __AVX__ || __AVX2__ || __AVX512F__
|
957
|
+
|
958
|
+
// horizontally add 8 floats
|
959
|
+
static inline float hsum_float_8(const __m256 x) {
|
960
|
+
__m128 res = _mm256_extractf128_ps(x, 1);
|
961
|
+
res = _mm_add_ps(res, _mm256_castps256_ps128(x));
|
962
|
+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
|
963
|
+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
964
|
+
return _mm_cvtss_f32(res);
|
965
|
+
}
|
966
|
+
|
967
|
+
// shuffles to pick the required scales in dot products
|
968
|
+
static inline __m256i get_scale_shuffle_q3k(int i) {
|
969
|
+
static const uint8_t k_shuffle[128] = {
|
970
|
+
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
|
971
|
+
4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
|
972
|
+
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
973
|
+
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
|
974
|
+
};
|
975
|
+
return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
|
976
|
+
}
|
977
|
+
static inline __m256i get_scale_shuffle_k4(int i) {
|
978
|
+
static const uint8_t k_shuffle[256] = {
|
979
|
+
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
|
980
|
+
2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
|
981
|
+
4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
|
982
|
+
6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
|
983
|
+
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
|
984
|
+
10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
985
|
+
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
|
986
|
+
14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
|
987
|
+
};
|
988
|
+
return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
|
989
|
+
}
|
990
|
+
static inline __m128i get_scale_shuffle(int i) {
|
991
|
+
static const uint8_t k_shuffle[128] = {
|
992
|
+
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
|
993
|
+
2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
|
994
|
+
4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
|
995
|
+
6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
|
996
|
+
8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
|
997
|
+
10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
|
998
|
+
12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
|
999
|
+
14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
|
1000
|
+
};
|
1001
|
+
return _mm_loadu_si128((const __m128i*)k_shuffle + i);
|
1002
|
+
}
|
1003
|
+
#endif
|
1004
|
+
|
1005
|
+
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1006
|
+
|
1007
|
+
const block_q2_K * restrict x = vx;
|
1008
|
+
const block_q8_K * restrict y = vy;
|
1009
|
+
|
1010
|
+
const int nb = n / QK_K;
|
1011
|
+
|
1012
|
+
#ifdef __ARM_NEON
|
1013
|
+
|
1014
|
+
const uint8x16_t m3 = vdupq_n_u8(0x3);
|
1015
|
+
const uint8x16_t m4 = vdupq_n_u8(0xF);
|
1016
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
1017
|
+
|
1018
|
+
int8x16x2_t q2bytes;
|
1019
|
+
uint8_t aux[16];
|
1020
|
+
|
1021
|
+
float sum = 0;
|
1022
|
+
|
1023
|
+
for (int i = 0; i < nb; ++i) {
|
1024
|
+
|
1025
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1026
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1027
|
+
|
1028
|
+
const uint8_t * restrict q2 = x[i].qs;
|
1029
|
+
const int8_t * restrict q8 = y[i].qs;
|
1030
|
+
const uint8_t * restrict sc = x[i].scales;
|
1031
|
+
|
1032
|
+
const uint8x16_t mins_and_scales = vld1q_u8(sc);
|
1033
|
+
const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
|
1034
|
+
vst1q_u8(aux, scales);
|
1035
|
+
|
1036
|
+
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
|
1037
|
+
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
|
1038
|
+
const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
|
1039
|
+
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
|
1040
|
+
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
|
1041
|
+
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
|
1042
|
+
vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
|
1043
|
+
sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
|
1044
|
+
|
1045
|
+
int isum = 0;
|
1046
|
+
int is = 0;
|
1047
|
+
|
1048
|
+
// We use this macro instead of a function call because for some reason
|
1049
|
+
// the code runs 2-3% slower, even if the function is declared inline
|
1050
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
1051
|
+
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
|
1052
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
|
1053
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
|
1054
|
+
#else
|
1055
|
+
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
|
1056
|
+
{\
|
1057
|
+
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
|
1058
|
+
vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
|
1059
|
+
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
|
1060
|
+
vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
|
1061
|
+
isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
|
1062
|
+
}
|
1063
|
+
#endif
|
1064
|
+
|
1065
|
+
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
|
1066
|
+
q8bytes = vld1q_s8_x2(q8); q8 += 32;\
|
1067
|
+
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
|
1068
|
+
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
|
1069
|
+
MULTIPLY_ACCUM_WITH_SCALE((index));
|
1070
|
+
|
1071
|
+
|
1072
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
1073
|
+
|
1074
|
+
const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
|
1075
|
+
|
1076
|
+
int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
1077
|
+
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
|
1078
|
+
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
|
1079
|
+
MULTIPLY_ACCUM_WITH_SCALE(0);
|
1080
|
+
|
1081
|
+
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
|
1082
|
+
|
1083
|
+
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
|
1084
|
+
|
1085
|
+
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
|
1086
|
+
|
1087
|
+
is += 8;
|
1088
|
+
}
|
1089
|
+
sum += d * isum;
|
1090
|
+
|
1091
|
+
}
|
1092
|
+
|
1093
|
+
*s = sum;
|
1094
|
+
|
1095
|
+
#elif defined __AVX2__
|
1096
|
+
|
1097
|
+
const __m256i m3 = _mm256_set1_epi8(3);
|
1098
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
1099
|
+
|
1100
|
+
__m256 acc = _mm256_setzero_ps();
|
1101
|
+
|
1102
|
+
for (int i = 0; i < nb; ++i) {
|
1103
|
+
|
1104
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1105
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1106
|
+
|
1107
|
+
const uint8_t * restrict q2 = x[i].qs;
|
1108
|
+
const int8_t * restrict q8 = y[i].qs;
|
1109
|
+
|
1110
|
+
const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
|
1111
|
+
const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
|
1112
|
+
const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
|
1113
|
+
const __m256i mins = _mm256_cvtepi8_epi16(mins8);
|
1114
|
+
const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
|
1115
|
+
|
1116
|
+
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
|
1117
|
+
|
1118
|
+
const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
|
1119
|
+
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
|
1120
|
+
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
|
1121
|
+
const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
|
1122
|
+
|
1123
|
+
__m256i sumi = _mm256_setzero_si256();
|
1124
|
+
|
1125
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
1126
|
+
|
1127
|
+
const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
|
1128
|
+
|
1129
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
1130
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
1131
|
+
const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
1132
|
+
const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
1133
|
+
|
1134
|
+
const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
|
1135
|
+
const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
|
1136
|
+
const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
|
1137
|
+
const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
|
1138
|
+
|
1139
|
+
__m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
|
1140
|
+
__m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
|
1141
|
+
__m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
|
1142
|
+
__m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
|
1143
|
+
|
1144
|
+
p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
|
1145
|
+
p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
|
1146
|
+
p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
|
1147
|
+
p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
|
1148
|
+
|
1149
|
+
p0 = _mm256_add_epi32(p0, p1);
|
1150
|
+
p2 = _mm256_add_epi32(p2, p3);
|
1151
|
+
|
1152
|
+
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
|
1153
|
+
}
|
1154
|
+
|
1155
|
+
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
1156
|
+
|
1157
|
+
}
|
1158
|
+
|
1159
|
+
*s = hsum_float_8(acc);
|
1160
|
+
|
1161
|
+
#else
|
1162
|
+
|
1163
|
+
float sumf = 0;
|
1164
|
+
|
1165
|
+
for (int i = 0; i < nb; ++i) {
|
1166
|
+
|
1167
|
+
const uint8_t * q2 = x[i].qs;
|
1168
|
+
const int8_t * q8 = y[i].qs;
|
1169
|
+
const uint8_t * sc = x[i].scales;
|
1170
|
+
|
1171
|
+
int summs = 0;
|
1172
|
+
for (int j = 0; j < 16; ++j) {
|
1173
|
+
summs += y[i].bsums[j] * (sc[j] >> 4);
|
1174
|
+
}
|
1175
|
+
|
1176
|
+
const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1177
|
+
const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1178
|
+
|
1179
|
+
int isum = 0;
|
1180
|
+
int is = 0;
|
1181
|
+
int d;
|
1182
|
+
for (int k = 0; k < QK_K/128; ++k) {
|
1183
|
+
int shift = 0;
|
1184
|
+
for (int j = 0; j < 4; ++j) {
|
1185
|
+
d = sc[is++] & 0xF;
|
1186
|
+
int isuml = 0;
|
1187
|
+
for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
|
1188
|
+
isum += d * isuml;
|
1189
|
+
d = sc[is++] & 0xF;
|
1190
|
+
isuml = 0;
|
1191
|
+
for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
|
1192
|
+
isum += d * isuml;
|
1193
|
+
shift += 2;
|
1194
|
+
q8 += 32;
|
1195
|
+
}
|
1196
|
+
q2 += 32;
|
1197
|
+
}
|
1198
|
+
sumf += dall * isum - dmin * summs;
|
1199
|
+
}
|
1200
|
+
*s = sumf;
|
1201
|
+
#endif
|
1202
|
+
}
|
1203
|
+
|
1204
|
+
// Dot product of one row of Q3_K blocks (vx) with one row of Q8_K blocks (vy).
//
// n: number of elements in each row; must be a multiple of QK_K (asserted).
//    nb = n/QK_K block pairs are processed.
// s: receives the scalar result.
//
// A Q3_K quant combines a 2-bit field from x[i].qs (four fields per byte) with
// one bit from x[i].hmask. The value is low2 when the hmask bit is set and
// low2 - 4 when it is clear, i.e. a signed value in [-4, 3].
// The 12 bytes of x[i].scales pack 16 6-bit scales: the low 4 bits come from the
// nibbles of bytes 0..7 and the high 2 bits from bytes 8..11. They are unpacked
// and offset by -32; each scale covers 16 consecutive quants. Each block contributes
//     d_y * d_x * sum_j scale_j * sum_{l in group j} q3_l * q8_l
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    assert(n % QK_K == 0);

    // per-byte masks for unpacking 2-bit and 4-bit fields from a 32-bit word
    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    const block_q3_K * restrict x = vx;
    const block_q8_K * restrict y = vy;

    const int nb = n / QK_K;

#ifdef __ARM_NEON

    uint32_t aux[3];    // raw 12 scale bytes
    uint32_t utmp[4];   // 16 unpacked scales, one per byte

    const uint8x16_t m3b = vdupq_n_u8(0x3);
#ifdef __ARM_FEATURE_DOTPROD
    const int32x4_t  vzero = vdupq_n_s32(0);
#endif

    // per-byte single-bit masks 1, 2, 4, 8 selecting successive hmask bits
    const uint8x16_t m0 = vdupq_n_u8(1);
    const uint8x16_t m1 = vshlq_n_u8(m0, 1);
    const uint8x16_t m2 = vshlq_n_u8(m0, 2);
    const uint8x16_t m3 = vshlq_n_u8(m0, 3);
    const int8_t m32 = 32;

    int8x16x4_t q3bytes;

    float sum = 0;

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);

        const uint8_t * restrict q3 = x[i].qs;
        const uint8_t * restrict qh = x[i].hmask;
        const int8_t  * restrict q8 = y[i].qs;

        uint8x16x2_t qhbits = vld1q_u8_x2(qh);

        uint8x16x4_t q3h;

        int32_t isum = 0;

        // Set up scales: combine the low nibbles (aux[0], aux[1]) with the
        // 2-bit high parts (aux[2]) into 16 bytes, then remove the +32 offset.
        memcpy(aux, x[i].scales, 12);
        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);

        int8_t * scale = (int8_t *)utmp;
        for (int j = 0; j < 16; ++j) scale[j] -= m32;

        // each iteration: 32 packed bytes = 128 quants = 8 scale groups
        for (int j = 0; j < QK_K/128; ++j) {

            const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
            const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
            const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;

            // m & ~hmask is nonzero where the high bit is clear; scale it to 4
            // so that subtracting it yields low2 - 4 for those quants.
            q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
            q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
            q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
            q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);

            // bit-planes 0 and 1 of the 32 bytes -> 64 signed quants
            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

#if defined(__ARM_FEATURE_DOTPROD)
            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
#else
            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
                                     vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
                                     vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
                                     vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
                                     vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
#endif
            scale += 4;

            // hmask bits 2 and 3 for bit-planes 2 and 3 (m2 already equals 4;
            // m3 = 8 is shifted down to 4)
            q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
            q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
            q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
            q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);

            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

#if defined(__ARM_FEATURE_DOTPROD)
            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
            isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
#else
            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
                           vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
                           vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
                           vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
                           vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
#endif
            scale += 4;

            // the second 128-quant chunk uses hmask bits 4..7
            if (j == 0) {
                qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
                qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
            }

        }
        sum += d * isum;

    }

    *s = sum;

#elif defined __AVX2__

    const __m256i m3 = _mm256_set1_epi8(3);
    const __m256i mone = _mm256_set1_epi8(1);
    const __m128i m32 = _mm_set1_epi8(32);

    __m256 acc = _mm256_setzero_ps();

    uint32_t aux[3];

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);

        const uint8_t * restrict q3 = x[i].qs;
        const int8_t  * restrict q8 = y[i].qs;

        // Set up scales (same unpacking as the NEON path), then remove the +32 offset
        memcpy(aux, x[i].scales, 12);
        __m128i scales128 = _mm_set_epi32(
                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
        scales128 = _mm_sub_epi8(scales128, m32);
        // _mm256_shuffle_epi8 works within 128-bit lanes, so each 8-scale half is
        // duplicated into both lanes: scales[0] = scales 0..7, scales[1] = 8..15.
        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
        const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};

        // high bit
        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);

        // integer accumulator
        __m256i sumi = _mm256_setzero_si256();

        int bit = 0;   // hmask bit used by the next bit-plane (0..7)
        int is  = 0;   // stays 0: scales[j] already selects the right 8 scales

        for (int j = 0; j < QK_K/128; ++j) {
            // load low 2 bits
            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;

            // prepare low and high bits: q3h_k is 4 where hmask bit `bit` is clear, 0 where set
            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
            ++bit;

            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
            ++bit;

            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
            ++bit;

            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
            ++bit;

            // load Q8 quants
            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;

            // Dot product: _mm256_maddubs_epi16 needs an unsigned first operand, so the
            // non-negative 2-bit part (q3l) and the high-bit correction (q3h: 4 when the
            // hmask bit is clear, 0 when set) are multiplied separately and then
            // subtracted, giving (q3l - q3h) * q8 summed over adjacent pairs.
            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);

            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);

            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

            // multiply with scales: mask k weights the first 16 quants of the k-th
            // 32-quant run with scale 2k and the next 16 with scale 2k+1
            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);

            // accumulate
            p16_0 = _mm256_add_epi32(p16_0, p16_1);
            p16_2 = _mm256_add_epi32(p16_2, p16_3);
            sumi  = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));

        }

        // multiply with block scale and accumulate
        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);

    }

    *s = hsum_float_8(acc);

#else
    // scalar version
    // This function is written like this so the compiler can manage to vectorize most of it
    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
    // The ideal situation would be if we could just write the code once, and the compiler would
    // automatically produce the best possible set of machine instructions, instead of us having to manually
    // write vectorized versions for AVX, ARM_NEON, etc.

    int8_t  aux8[QK_K];   // all decoded signed quants of the current block
    int16_t aux16[8];
    float   sums [8];     // 8 float lane accumulators across blocks
    int32_t aux32[8];     // 8 int lane accumulators within a block
    memset(sums, 0, 8*sizeof(float));

    uint32_t auxs[4];
    const int8_t * scales = (const int8_t*)auxs;   // 16 unpacked scales (still offset by +32)

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * restrict q3 = x[i].qs;
        const uint8_t * restrict hm = x[i].hmask;
        const  int8_t * restrict q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * restrict a = aux8;
        uint8_t m = 1;   // hmask bit for the current bit-plane
        // decode: low2 - (hmask bit clear ? 4 : 0), 32 quants per bit-plane
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            q3 += 32;
        }
        a = aux8;

        // Unpack the 6-bit scales in place. auxs[2] and auxs[3] are computed
        // first because they read the original auxs[0] and auxs[1].
        memcpy(auxs, x[i].scales, 12);
        uint32_t tmp = auxs[2];
        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
        // one scale per 16 quants, processed as two runs of 8
        for (int j = 0; j < QK_K/16; ++j) {
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;

#endif

}
|
1503
|
+
|
1504
|
+
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1505
|
+
assert(n % QK_K == 0);
|
1506
|
+
|
1507
|
+
const block_q4_K * restrict x = vx;
|
1508
|
+
const block_q8_K * restrict y = vy;
|
1509
|
+
|
1510
|
+
const int nb = n / QK_K;
|
1511
|
+
|
1512
|
+
static const uint32_t kmask1 = 0x3f3f3f3f;
|
1513
|
+
static const uint32_t kmask2 = 0x0f0f0f0f;
|
1514
|
+
static const uint32_t kmask3 = 0x03030303;
|
1515
|
+
|
1516
|
+
uint32_t utmp[4];
|
1517
|
+
|
1518
|
+
#ifdef __ARM_NEON
|
1519
|
+
|
1520
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
1521
|
+
#ifdef __ARM_FEATURE_DOTPROD
|
1522
|
+
const int32x4_t mzero = vdupq_n_s32(0);
|
1523
|
+
#endif
|
1524
|
+
|
1525
|
+
int8x16x2_t q4bytes;
|
1526
|
+
int8x16x2_t q8bytes;
|
1527
|
+
|
1528
|
+
float sumf = 0;
|
1529
|
+
|
1530
|
+
for (int i = 0; i < nb; ++i) {
|
1531
|
+
|
1532
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1533
|
+
const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1534
|
+
|
1535
|
+
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
1536
|
+
|
1537
|
+
memcpy(utmp, x[i].scales, 12);
|
1538
|
+
|
1539
|
+
const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)};
|
1540
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
1541
|
+
utmp[0] &= kmask1;
|
1542
|
+
|
1543
|
+
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
|
1544
|
+
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
1545
|
+
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
1546
|
+
sumf -= dmin * vaddvq_s32(prod);
|
1547
|
+
|
1548
|
+
const uint8_t * scales = (const uint8_t *)utmp;
|
1549
|
+
|
1550
|
+
const uint8_t * restrict q4 = x[i].qs;
|
1551
|
+
const int8_t * restrict q8 = y[i].qs;
|
1552
|
+
|
1553
|
+
//int32x4_t isum = mzero;
|
1554
|
+
|
1555
|
+
int32_t sumi1 = 0;
|
1556
|
+
int32_t sumi2 = 0;
|
1557
|
+
|
1558
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
1559
|
+
|
1560
|
+
const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
|
1561
|
+
|
1562
|
+
#ifdef __ARM_FEATURE_DOTPROD
|
1563
|
+
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
1564
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
1565
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
1566
|
+
|
1567
|
+
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
1568
|
+
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
|
1569
|
+
|
1570
|
+
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
1571
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
1572
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
1573
|
+
|
1574
|
+
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
1575
|
+
|
1576
|
+
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
|
1577
|
+
#else
|
1578
|
+
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
1579
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
1580
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
1581
|
+
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
1582
|
+
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
1583
|
+
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
1584
|
+
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
1585
|
+
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
|
1586
|
+
|
1587
|
+
q8bytes = vld1q_s8_x2(q8); q8 += 32;
|
1588
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
1589
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
1590
|
+
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
1591
|
+
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
1592
|
+
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
1593
|
+
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
1594
|
+
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
|
1595
|
+
|
1596
|
+
#endif
|
1597
|
+
}
|
1598
|
+
|
1599
|
+
sumf += d * (sumi1 + sumi2);
|
1600
|
+
|
1601
|
+
}
|
1602
|
+
|
1603
|
+
*s = sumf;
|
1604
|
+
|
1605
|
+
#elif defined __AVX2__
|
1606
|
+
|
1607
|
+
const __m256i m4 = _mm256_set1_epi8(0xF);
|
1608
|
+
|
1609
|
+
__m256 acc = _mm256_setzero_ps();
|
1610
|
+
__m128 acc_m = _mm_setzero_ps();
|
1611
|
+
|
1612
|
+
for (int i = 0; i < nb; ++i) {
|
1613
|
+
|
1614
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1615
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1616
|
+
|
1617
|
+
const uint8_t * restrict q4 = x[i].qs;
|
1618
|
+
const int8_t * restrict q8 = y[i].qs;
|
1619
|
+
|
1620
|
+
memcpy(utmp, x[i].scales, 12);
|
1621
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
1622
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
1623
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
1624
|
+
utmp[2] = uaux;
|
1625
|
+
utmp[0] &= kmask1;
|
1626
|
+
|
1627
|
+
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
|
1628
|
+
|
1629
|
+
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
|
1630
|
+
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
|
1631
|
+
const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
|
1632
|
+
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
|
1633
|
+
|
1634
|
+
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
1635
|
+
const __m256i scales = _mm256_set_m128i(sc128, sc128);
|
1636
|
+
|
1637
|
+
__m256i sumi = _mm256_setzero_si256();
|
1638
|
+
|
1639
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
1640
|
+
|
1641
|
+
const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
|
1642
|
+
const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
|
1643
|
+
|
1644
|
+
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
|
1645
|
+
const __m256i q4l = _mm256_and_si256(q4bits, m4);
|
1646
|
+
const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
|
1647
|
+
|
1648
|
+
const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
1649
|
+
__m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
|
1650
|
+
p16l = _mm256_madd_epi16(scale_l, p16l);
|
1651
|
+
sumi = _mm256_add_epi32(sumi, p16l);
|
1652
|
+
|
1653
|
+
const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
1654
|
+
__m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
|
1655
|
+
p16h = _mm256_madd_epi16(scale_h, p16h);
|
1656
|
+
sumi = _mm256_add_epi32(sumi, p16h);
|
1657
|
+
|
1658
|
+
}
|
1659
|
+
|
1660
|
+
__m256 vd = _mm256_set1_ps(d);
|
1661
|
+
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
|
1662
|
+
|
1663
|
+
}
|
1664
|
+
|
1665
|
+
acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
|
1666
|
+
acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
|
1667
|
+
|
1668
|
+
*s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
|
1669
|
+
|
1670
|
+
#else
|
1671
|
+
|
1672
|
+
|
1673
|
+
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
1674
|
+
const uint8_t * mins = (const uint8_t*)&utmp[2];
|
1675
|
+
|
1676
|
+
int8_t aux8[QK_K];
|
1677
|
+
int16_t aux16[8];
|
1678
|
+
float sums [8];
|
1679
|
+
int32_t aux32[8];
|
1680
|
+
memset(sums, 0, 8*sizeof(float));
|
1681
|
+
|
1682
|
+
float sumf = 0;
|
1683
|
+
for (int i = 0; i < nb; ++i) {
|
1684
|
+
const uint8_t * restrict q4 = x[i].qs;
|
1685
|
+
const int8_t * restrict q8 = y[i].qs;
|
1686
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
1687
|
+
int8_t * restrict a = aux8;
|
1688
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
1689
|
+
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
|
1690
|
+
a += 32;
|
1691
|
+
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
|
1692
|
+
a += 32; q4 += 32;
|
1693
|
+
}
|
1694
|
+
memcpy(utmp, x[i].scales, 12);
|
1695
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
1696
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
1697
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
1698
|
+
utmp[2] = uaux;
|
1699
|
+
utmp[0] &= kmask1;
|
1700
|
+
|
1701
|
+
int sumi = 0;
|
1702
|
+
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
|
1703
|
+
a = aux8;
|
1704
|
+
int is = 0;
|
1705
|
+
for (int j = 0; j < QK_K/32; ++j) {
|
1706
|
+
int32_t scale = scales[is++];
|
1707
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
1708
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
1709
|
+
q8 += 8; a += 8;
|
1710
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
1711
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
1712
|
+
q8 += 8; a += 8;
|
1713
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
1714
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
1715
|
+
q8 += 8; a += 8;
|
1716
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
1717
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
1718
|
+
q8 += 8; a += 8;
|
1719
|
+
}
|
1720
|
+
const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
|
1721
|
+
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
1722
|
+
const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
|
1723
|
+
sumf -= dmin * sumi;
|
1724
|
+
}
|
1725
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
1726
|
+
*s = sumf;
|
1727
|
+
#endif
|
1728
|
+
}
|
1729
|
+
|
1730
|
+
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1731
|
+
assert(n % QK_K == 0);
|
1732
|
+
|
1733
|
+
const block_q5_K * restrict x = vx;
|
1734
|
+
const block_q8_K * restrict y = vy;
|
1735
|
+
|
1736
|
+
const int nb = n / QK_K;
|
1737
|
+
|
1738
|
+
static const uint32_t kmask1 = 0x3f3f3f3f;
|
1739
|
+
static const uint32_t kmask2 = 0x0f0f0f0f;
|
1740
|
+
static const uint32_t kmask3 = 0x03030303;
|
1741
|
+
|
1742
|
+
uint32_t utmp[4];
|
1743
|
+
|
1744
|
+
|
1745
|
+
#ifdef __ARM_NEON
|
1746
|
+
|
1747
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
1748
|
+
const int32x4_t mzero = vdupq_n_s32(0);
|
1749
|
+
const uint8x16_t mone = vdupq_n_u8(1);
|
1750
|
+
const uint8x16_t mtwo = vdupq_n_u8(2);
|
1751
|
+
|
1752
|
+
int8x16x4_t q5bytes;
|
1753
|
+
|
1754
|
+
float sumf = 0;
|
1755
|
+
|
1756
|
+
for (int i = 0; i < nb; ++i) {
|
1757
|
+
|
1758
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1759
|
+
const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1760
|
+
|
1761
|
+
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
1762
|
+
|
1763
|
+
memcpy(utmp, x[i].scales, 12);
|
1764
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
1765
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
1766
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
1767
|
+
utmp[2] = uaux;
|
1768
|
+
utmp[0] &= kmask1;
|
1769
|
+
|
1770
|
+
const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
|
1771
|
+
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
|
1772
|
+
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
1773
|
+
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
1774
|
+
int32_t sumi_mins = vaddvq_s32(prod);
|
1775
|
+
|
1776
|
+
const uint8_t * scales = (const uint8_t *)utmp;
|
1777
|
+
|
1778
|
+
const uint8_t * restrict q5 = x[i].qs;
|
1779
|
+
const uint8_t * restrict qh = x[i].qh;
|
1780
|
+
const int8_t * restrict q8 = y[i].qs;
|
1781
|
+
|
1782
|
+
uint8x16x2_t qhbits = vld1q_u8_x2(qh);
|
1783
|
+
|
1784
|
+
uint8x16x4_t q5h;
|
1785
|
+
|
1786
|
+
int32_t sumi = 0;
|
1787
|
+
|
1788
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
1789
|
+
|
1790
|
+
const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
|
1791
|
+
const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
|
1792
|
+
|
1793
|
+
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
|
1794
|
+
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
|
1795
|
+
q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
|
1796
|
+
q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
|
1797
|
+
qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
|
1798
|
+
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
|
1799
|
+
|
1800
|
+
q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
|
1801
|
+
q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
|
1802
|
+
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
|
1803
|
+
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
|
1804
|
+
|
1805
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
1806
|
+
|
1807
|
+
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
|
1808
|
+
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
|
1809
|
+
#else
|
1810
|
+
|
1811
|
+
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
1812
|
+
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
1813
|
+
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
1814
|
+
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
1815
|
+
sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
|
1816
|
+
|
1817
|
+
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
1818
|
+
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
1819
|
+
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
1820
|
+
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
1821
|
+
sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
|
1822
|
+
#endif
|
1823
|
+
}
|
1824
|
+
|
1825
|
+
sumf += d * sumi - dmin * sumi_mins;
|
1826
|
+
|
1827
|
+
}
|
1828
|
+
|
1829
|
+
*s = sumf;
|
1830
|
+
|
1831
|
+
#elif defined __AVX2__
|
1832
|
+
|
1833
|
+
const __m256i m4 = _mm256_set1_epi8(0xF);
|
1834
|
+
const __m128i mzero = _mm_setzero_si128();
|
1835
|
+
const __m256i mone = _mm256_set1_epi8(1);
|
1836
|
+
|
1837
|
+
__m256 acc = _mm256_setzero_ps();
|
1838
|
+
|
1839
|
+
float summs = 0.f;
|
1840
|
+
|
1841
|
+
for (int i = 0; i < nb; ++i) {
|
1842
|
+
|
1843
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1844
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1845
|
+
|
1846
|
+
const uint8_t * restrict q5 = x[i].qs;
|
1847
|
+
const int8_t * restrict q8 = y[i].qs;
|
1848
|
+
|
1849
|
+
memcpy(utmp, x[i].scales, 12);
|
1850
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
1851
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
1852
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
1853
|
+
utmp[2] = uaux;
|
1854
|
+
utmp[0] &= kmask1;
|
1855
|
+
|
1856
|
+
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
|
1857
|
+
|
1858
|
+
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
|
1859
|
+
const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
|
1860
|
+
const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
|
1861
|
+
const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
|
1862
|
+
summs += dmin * _mm_extract_epi32(hsum, 0);
|
1863
|
+
|
1864
|
+
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
1865
|
+
const __m256i scales = _mm256_set_m128i(sc128, sc128);
|
1866
|
+
|
1867
|
+
const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
|
1868
|
+
__m256i hmask = mone;
|
1869
|
+
|
1870
|
+
__m256i sumi = _mm256_setzero_si256();
|
1871
|
+
|
1872
|
+
int bit = 0;
|
1873
|
+
|
1874
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
1875
|
+
|
1876
|
+
const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
|
1877
|
+
const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
|
1878
|
+
|
1879
|
+
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
|
1880
|
+
|
1881
|
+
const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
|
1882
|
+
const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
|
1883
|
+
const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
|
1884
|
+
hmask = _mm256_slli_epi16(hmask, 1);
|
1885
|
+
|
1886
|
+
const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
|
1887
|
+
const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
|
1888
|
+
const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
|
1889
|
+
hmask = _mm256_slli_epi16(hmask, 1);
|
1890
|
+
|
1891
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
1892
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
1893
|
+
|
1894
|
+
__m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
|
1895
|
+
__m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
|
1896
|
+
|
1897
|
+
p16_0 = _mm256_madd_epi16(scale_0, p16_0);
|
1898
|
+
p16_1 = _mm256_madd_epi16(scale_1, p16_1);
|
1899
|
+
|
1900
|
+
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
|
1901
|
+
|
1902
|
+
}
|
1903
|
+
|
1904
|
+
__m256 vd = _mm256_set1_ps(d);
|
1905
|
+
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
|
1906
|
+
|
1907
|
+
}
|
1908
|
+
|
1909
|
+
*s = hsum_float_8(acc) + summs;
|
1910
|
+
|
1911
|
+
#else
|
1912
|
+
|
1913
|
+
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
1914
|
+
const uint8_t * mins = (const uint8_t*)&utmp[2];
|
1915
|
+
|
1916
|
+
int8_t aux8[QK_K];
|
1917
|
+
int16_t aux16[8];
|
1918
|
+
float sums [8];
|
1919
|
+
int32_t aux32[8];
|
1920
|
+
memset(sums, 0, 8*sizeof(float));
|
1921
|
+
|
1922
|
+
float sumf = 0;
|
1923
|
+
for (int i = 0; i < nb; ++i) {
|
1924
|
+
const uint8_t * restrict q4 = x[i].qs;
|
1925
|
+
const uint8_t * restrict hm = x[i].qh;
|
1926
|
+
const int8_t * restrict q8 = y[i].qs;
|
1927
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
1928
|
+
int8_t * restrict a = aux8;
|
1929
|
+
uint8_t m = 1;
|
1930
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
1931
|
+
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
|
1932
|
+
for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
|
1933
|
+
a += 32; m <<= 1;
|
1934
|
+
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
|
1935
|
+
for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
|
1936
|
+
a += 32; m <<= 1;
|
1937
|
+
q4 += 32;
|
1938
|
+
}
|
1939
|
+
memcpy(utmp, x[i].scales, 12);
|
1940
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
1941
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
1942
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
1943
|
+
utmp[2] = uaux;
|
1944
|
+
utmp[0] &= kmask1;
|
1945
|
+
|
1946
|
+
int sumi = 0;
|
1947
|
+
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
|
1948
|
+
a = aux8;
|
1949
|
+
int is = 0;
|
1950
|
+
for (int j = 0; j < QK_K/32; ++j) {
|
1951
|
+
int32_t scale = scales[is++];
|
1952
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
1953
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
1954
|
+
q8 += 8; a += 8;
|
1955
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
1956
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
1957
|
+
q8 += 8; a += 8;
|
1958
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
1959
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
1960
|
+
q8 += 8; a += 8;
|
1961
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
1962
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
1963
|
+
q8 += 8; a += 8;
|
1964
|
+
}
|
1965
|
+
const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
|
1966
|
+
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
1967
|
+
const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
|
1968
|
+
sumf -= dmin * sumi;
|
1969
|
+
}
|
1970
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
1971
|
+
*s = sumf;
|
1972
|
+
#endif
|
1973
|
+
}
|
1974
|
+
|
1975
|
+
|
1976
|
+
|
1977
|
+
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1978
|
+
assert(n % QK_K == 0);
|
1979
|
+
|
1980
|
+
const block_q6_K * restrict x = vx;
|
1981
|
+
const block_q8_K * restrict y = vy;
|
1982
|
+
|
1983
|
+
const int nb = n / QK_K;
|
1984
|
+
|
1985
|
+
#ifdef __ARM_NEON
|
1986
|
+
|
1987
|
+
float sum = 0;
|
1988
|
+
|
1989
|
+
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
1990
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
1991
|
+
//const int8x16_t m32s = vdupq_n_s8(32);
|
1992
|
+
|
1993
|
+
const uint8x16_t mone = vdupq_n_u8(3);
|
1994
|
+
|
1995
|
+
int8x16x4_t q6bytes;
|
1996
|
+
uint8x16x4_t q6h;
|
1997
|
+
|
1998
|
+
for (int i = 0; i < nb; ++i) {
|
1999
|
+
|
2000
|
+
const float d_all = ggml_fp16_to_fp32(x[i].d);
|
2001
|
+
|
2002
|
+
const uint8_t * restrict q6 = x[i].ql;
|
2003
|
+
const uint8_t * restrict qh = x[i].qh;
|
2004
|
+
const int8_t * restrict q8 = y[i].qs;
|
2005
|
+
|
2006
|
+
const int8_t * restrict scale = x[i].scales;
|
2007
|
+
|
2008
|
+
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
|
2009
|
+
const int8x16_t scales = vld1q_s8(scale);
|
2010
|
+
const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
|
2011
|
+
|
2012
|
+
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
|
2013
|
+
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
|
2014
|
+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
|
2015
|
+
vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
|
2016
|
+
int32_t isum_mins = vaddvq_s32(prod);
|
2017
|
+
|
2018
|
+
int32_t isum = 0;
|
2019
|
+
|
2020
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
2021
|
+
|
2022
|
+
uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
|
2023
|
+
uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
|
2024
|
+
int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
|
2025
|
+
|
2026
|
+
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
|
2027
|
+
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
|
2028
|
+
uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
|
2029
|
+
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
2030
|
+
shifted = vshrq_n_u8(qhbits.val[1], 2);
|
2031
|
+
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
2032
|
+
|
2033
|
+
//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
|
2034
|
+
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
|
2035
|
+
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
|
2036
|
+
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
|
2037
|
+
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
|
2038
|
+
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
|
2039
|
+
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
|
2040
|
+
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
|
2041
|
+
|
2042
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2043
|
+
|
2044
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
2045
|
+
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
2046
|
+
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
2047
|
+
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
2048
|
+
scale += 4;
|
2049
|
+
|
2050
|
+
#else
|
2051
|
+
|
2052
|
+
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
2053
|
+
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
2054
|
+
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
2055
|
+
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
2056
|
+
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
|
2057
|
+
scale += 2;
|
2058
|
+
|
2059
|
+
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
2060
|
+
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
2061
|
+
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
2062
|
+
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
2063
|
+
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
|
2064
|
+
scale += 2;
|
2065
|
+
#endif
|
2066
|
+
|
2067
|
+
q8bytes = vld1q_s8_x4(q8); q8 += 64;
|
2068
|
+
|
2069
|
+
shifted = vshrq_n_u8(qhbits.val[0], 4);
|
2070
|
+
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
2071
|
+
shifted = vshrq_n_u8(qhbits.val[1], 4);
|
2072
|
+
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
2073
|
+
shifted = vshrq_n_u8(qhbits.val[0], 6);
|
2074
|
+
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
2075
|
+
shifted = vshrq_n_u8(qhbits.val[1], 6);
|
2076
|
+
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
2077
|
+
|
2078
|
+
//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
|
2079
|
+
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
|
2080
|
+
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
|
2081
|
+
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
|
2082
|
+
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
|
2083
|
+
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
|
2084
|
+
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
|
2085
|
+
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
|
2086
|
+
|
2087
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2088
|
+
|
2089
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
2090
|
+
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
2091
|
+
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
2092
|
+
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
2093
|
+
scale += 4;
|
2094
|
+
|
2095
|
+
//for (int l = 0; l < 4; ++l) {
|
2096
|
+
// const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
|
2097
|
+
// isum += vaddvq_s32(p) * *scale++;
|
2098
|
+
//}
|
2099
|
+
#else
|
2100
|
+
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
2101
|
+
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
2102
|
+
p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
2103
|
+
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
2104
|
+
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
|
2105
|
+
scale += 2;
|
2106
|
+
|
2107
|
+
p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
2108
|
+
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
2109
|
+
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
2110
|
+
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
2111
|
+
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
|
2112
|
+
scale += 2;
|
2113
|
+
#endif
|
2114
|
+
|
2115
|
+
}
|
2116
|
+
//sum += isum * d_all * y[i].d;
|
2117
|
+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
2118
|
+
|
2119
|
+
}
|
2120
|
+
*s = sum;
|
2121
|
+
|
2122
|
+
#elif defined __AVX2__
|
2123
|
+
|
2124
|
+
const __m256i m4 = _mm256_set1_epi8(0xF);
|
2125
|
+
const __m256i m2 = _mm256_set1_epi8(3);
|
2126
|
+
const __m256i m32s = _mm256_set1_epi8(32);
|
2127
|
+
|
2128
|
+
__m256 acc = _mm256_setzero_ps();
|
2129
|
+
|
2130
|
+
for (int i = 0; i < nb; ++i) {
|
2131
|
+
|
2132
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
2133
|
+
|
2134
|
+
const uint8_t * restrict q4 = x[i].ql;
|
2135
|
+
const uint8_t * restrict qh = x[i].qh;
|
2136
|
+
const int8_t * restrict q8 = y[i].qs;
|
2137
|
+
|
2138
|
+
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
|
2139
|
+
|
2140
|
+
__m256i sumi = _mm256_setzero_si256();
|
2141
|
+
|
2142
|
+
int is = 0;
|
2143
|
+
|
2144
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
2145
|
+
|
2146
|
+
const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
|
2147
|
+
const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
|
2148
|
+
const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
|
2149
|
+
const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
|
2150
|
+
is += 4;
|
2151
|
+
|
2152
|
+
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
|
2153
|
+
const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
|
2154
|
+
const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
|
2155
|
+
|
2156
|
+
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
|
2157
|
+
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
|
2158
|
+
const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
|
2159
|
+
const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
|
2160
|
+
|
2161
|
+
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
|
2162
|
+
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
|
2163
|
+
const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
|
2164
|
+
const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
|
2165
|
+
|
2166
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
2167
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
2168
|
+
const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
2169
|
+
const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
2170
|
+
|
2171
|
+
__m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
|
2172
|
+
__m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
|
2173
|
+
__m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
|
2174
|
+
__m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
|
2175
|
+
|
2176
|
+
__m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
|
2177
|
+
__m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
|
2178
|
+
__m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
|
2179
|
+
__m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
|
2180
|
+
|
2181
|
+
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
|
2182
|
+
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
|
2183
|
+
p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
|
2184
|
+
p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
|
2185
|
+
|
2186
|
+
p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
|
2187
|
+
p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
|
2188
|
+
p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
|
2189
|
+
p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
|
2190
|
+
|
2191
|
+
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
|
2192
|
+
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
|
2193
|
+
|
2194
|
+
}
|
2195
|
+
|
2196
|
+
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
2197
|
+
}
|
2198
|
+
|
2199
|
+
*s = hsum_float_8(acc);
|
2200
|
+
|
2201
|
+
#else
|
2202
|
+
|
2203
|
+
int8_t aux8[QK_K];
|
2204
|
+
int16_t aux16[8];
|
2205
|
+
float sums [8];
|
2206
|
+
int32_t aux32[8];
|
2207
|
+
memset(sums, 0, 8*sizeof(float));
|
2208
|
+
|
2209
|
+
float sumf = 0;
|
2210
|
+
for (int i = 0; i < nb; ++i) {
|
2211
|
+
const uint8_t * restrict q4 = x[i].ql;
|
2212
|
+
const uint8_t * restrict qh = x[i].qh;
|
2213
|
+
const int8_t * restrict q8 = y[i].qs;
|
2214
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
2215
|
+
int8_t * restrict a = aux8;
|
2216
|
+
for (int j = 0; j < QK_K; j += 128) {
|
2217
|
+
for (int l = 0; l < 32; ++l) {
|
2218
|
+
a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
2219
|
+
a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
2220
|
+
a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
2221
|
+
a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
2222
|
+
}
|
2223
|
+
a += 128;
|
2224
|
+
q4 += 64;
|
2225
|
+
qh += 32;
|
2226
|
+
}
|
2227
|
+
a = aux8;
|
2228
|
+
int is = 0;
|
2229
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
2230
|
+
int scale = x[i].scales[is++];
|
2231
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2232
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2233
|
+
q8 += 8; a += 8;
|
2234
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2235
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2236
|
+
q8 += 8; a += 8;
|
2237
|
+
}
|
2238
|
+
const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
|
2239
|
+
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
2240
|
+
}
|
2241
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
2242
|
+
*s = sumf;
|
2243
|
+
#endif
|
2244
|
+
}
|