llama_cpp 0.1.4 → 0.2.1

@@ -0,0 +1,2244 @@
1
+ #include "k_quants.h"
2
+ #include "ggml.h"
3
+
4
+ #include <math.h>
5
+ #include <string.h>
6
+ #include <assert.h>
7
+
8
+ #ifdef __ARM_NEON
9
+
10
+ // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
11
+ //
12
+ // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
13
+ //
14
+ #include <arm_neon.h>
15
+
16
+ #else
17
+
18
+ #ifdef __wasm_simd128__
19
+ #include <wasm_simd128.h>
20
+ #else
21
+ #ifdef __POWER9_VECTOR__
22
+ #include <altivec.h>
23
+ #undef bool
24
+ #define bool _Bool
25
+ #else
26
+ #if defined(_MSC_VER) || defined(__MINGW32__)
27
+ #include <intrin.h>
28
+ #else
29
+ #if !defined(__riscv)
30
+ #include <immintrin.h>
31
+ #endif
32
+ #endif
33
+ #endif
34
+ #endif
35
+ #endif
36
+
37
+ #undef MIN
38
+ #undef MAX
39
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
40
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
41
+
42
+ //
43
+ // 2-6 bit quantization in super-blocks
44
+ //
45
+
46
+
47
+ //
48
+ // ===================== Helper functions
49
+ //
50
+ static inline int nearest_int(float fval) {
51
+ assert(fval <= 4194303.f);
52
+ float val = fval + 12582912.f;
53
+ int i; memcpy(&i, &val, sizeof(int));
54
+ return (i & 0x007fffff) - 0x00400000;
55
+ }
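+
+ // Added commentary (not part of the upstream k_quants.c source): nearest_int rounds by
+ // exploiting IEEE-754 float addition. Adding 12582912.0f (= 2^23 + 2^22) pushes the value
+ // into [2^23, 2^24), where the low 23 mantissa bits hold the rounded integer offset by
+ // 0x00400000; masking with 0x007fffff and subtracting that bias recovers the nearest int.
+ // The assert keeps |fval| small enough for the trick to stay exact.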
56
+
57
+ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
58
+ float max = 0;
59
+ float amax = 0;
60
+ for (int i = 0; i < n; ++i) {
61
+ float ax = fabsf(x[i]);
62
+ if (ax > amax) { amax = ax; max = x[i]; }
63
+ }
64
+ if (!amax) { // all zero
65
+ for (int i = 0; i < n; ++i) {
66
+ L[i] = 0;
67
+ }
68
+ return 0.f;
69
+ }
70
+ float iscale = -nmax / max;
71
+ if (rmse_type == 0) {
72
+ for (int i = 0; i < n; ++i) {
73
+ int l = nearest_int(iscale * x[i]);
74
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
75
+ }
76
+ return 1/iscale;
77
+ }
78
+ int weight_type = rmse_type%2;
79
+ float sumlx = 0;
80
+ float suml2 = 0;
81
+ for (int i = 0; i < n; ++i) {
82
+ int l = nearest_int(iscale * x[i]);
83
+ l = MAX(-nmax, MIN(nmax-1, l));
84
+ L[i] = l + nmax;
85
+ float w = weight_type == 1 ? x[i] * x[i] : 1;
86
+ sumlx += w*x[i]*l;
87
+ suml2 += w*l*l;
88
+ }
89
+ float scale = sumlx/suml2;
90
+ float best = scale * sumlx;
91
+ for (int itry = 0; itry < 3; ++itry) {
92
+ iscale = 1/scale;
93
+ float slx = 0;
94
+ float sl2 = 0;
95
+ bool changed = false;
96
+ for (int i = 0; i < n; ++i) {
97
+ int l = nearest_int(iscale * x[i]);
98
+ l = MAX(-nmax, MIN(nmax-1, l));
99
+ if (l + nmax != L[i]) { changed = true; }
100
+ float w = weight_type == 1 ? x[i] * x[i] : 1.f;
101
+ slx += w*x[i]*l;
102
+ sl2 += w*l*l;
103
+ }
104
+ if (!changed || sl2 == 0 || slx*slx <= best*sl2) { break; }
105
+ for (int i = 0; i < n; ++i) {
106
+ int l = nearest_int(iscale * x[i]);
107
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
108
+ }
109
+ sumlx = slx; suml2 = sl2;
110
+ scale = sumlx/suml2;
111
+ best = scale * sumlx;
112
+ }
113
+ for (int itry = 0; itry < 5; ++itry) {
114
+ int n_changed = 0;
115
+ for (int i = 0; i < n; ++i) {
116
+ float w = weight_type == 1 ? x[i]*x[i] : 1;
117
+ int l = L[i] - nmax;
118
+ float slx = sumlx - w*x[i]*l;
119
+ if (slx > 0) {
120
+ float sl2 = suml2 - w*l*l;
121
+ int new_l = nearest_int(x[i] * sl2 / slx);
122
+ new_l = MAX(-nmax, MIN(nmax-1, new_l));
123
+ if (new_l != l) {
124
+ slx += w*x[i]*new_l;
125
+ sl2 += w*new_l*new_l;
126
+ if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
127
+ L[i] = nmax + new_l; sumlx = slx; suml2 = sl2;
128
+ scale = sumlx / suml2; best = scale * sumlx;
129
+ ++n_changed;
130
+ }
131
+ }
132
+ }
133
+ }
134
+ if (!n_changed) { break; }
135
+ }
136
+ if (rmse_type < 3) {
137
+ return scale;
138
+ }
139
+ for (int is = -4; is <= 4; ++is) {
140
+ if (is == 0) {
141
+ continue;
142
+ }
143
+ iscale = -(nmax + 0.1f*is) / max;
144
+ sumlx = suml2 = 0;
145
+ for (int i = 0; i < n; ++i) {
146
+ int l = nearest_int(iscale * x[i]);
147
+ l = MAX(-nmax, MIN(nmax-1, l));
148
+ float w = weight_type == 1 ? x[i] * x[i] : 1;
149
+ sumlx += w*x[i]*l;
150
+ suml2 += w*l*l;
151
+ }
152
+ if (suml2 > 0 && sumlx*sumlx > best*suml2) {
153
+ for (int i = 0; i < n; ++i) {
154
+ int l = nearest_int(iscale * x[i]);
155
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
156
+ }
157
+ scale = sumlx/suml2; best = scale*sumlx;
158
+ }
159
+ }
160
+ return scale;
161
+ }
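+
+ // Added commentary (not from upstream): make_qx_quants maps n floats to the signed levels
+ // -nmax..nmax-1, stored in L with an offset of +nmax, and returns the scale so that
+ // x[i] is approximately scale * (L[i] - nmax). rmse_type 0 is plain rounding; other values
+ // run a weighted least-squares refinement (weights x[i]^2 when rmse_type is odd, 1 when
+ // even), and rmse_type >= 3 additionally retries from a few perturbed starting scales.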
162
+
163
+ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
164
+ float max = 0;
165
+ float amax = 0;
166
+ for (int i = 0; i < n; ++i) {
167
+ float ax = fabsf(x[i]);
168
+ if (ax > amax) { amax = ax; max = x[i]; }
169
+ }
170
+ if (!amax) { // all zero
171
+ for (int i = 0; i < n; ++i) { L[i] = 0; }
172
+ return 0.f;
173
+ }
174
+ float iscale = -nmax / max;
175
+ if (do_rmse) {
176
+ float sumlx = 0;
177
+ float suml2 = 0;
178
+ for (int i = 0; i < n; ++i) {
179
+ int l = nearest_int(iscale * x[i]);
180
+ l = MAX(-nmax, MIN(nmax-1, l));
181
+ L[i] = l;
182
+ float w = x[i]*x[i];
183
+ sumlx += w*x[i]*l;
184
+ suml2 += w*l*l;
185
+ }
186
+ for (int itry = 0; itry < 5; ++itry) {
187
+ int n_changed = 0;
188
+ for (int i = 0; i < n; ++i) {
189
+ float w = x[i]*x[i];
190
+ float slx = sumlx - w*x[i]*L[i];
191
+ if (slx > 0) {
192
+ float sl2 = suml2 - w*L[i]*L[i];
193
+ int new_l = nearest_int(x[i] * sl2 / slx);
194
+ new_l = MAX(-nmax, MIN(nmax-1, new_l));
195
+ if (new_l != L[i]) {
196
+ slx += w*x[i]*new_l;
197
+ sl2 += w*new_l*new_l;
198
+ if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
199
+ L[i] = new_l; sumlx = slx; suml2 = sl2;
200
+ ++n_changed;
201
+ }
202
+ }
203
+ }
204
+ }
205
+ if (!n_changed) {
206
+ break;
207
+ }
208
+ }
209
+ for (int i = 0; i < n; ++i) {
210
+ L[i] += nmax;
211
+ }
212
+ return sumlx / suml2;
213
+ }
214
+ for (int i = 0; i < n; ++i) {
215
+ int l = nearest_int(iscale * x[i]);
216
+ l = MAX(-nmax, MIN(nmax-1, l));
217
+ L[i] = l + nmax;
218
+ }
219
+ return 1/iscale;
220
+ }
221
+
222
+ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) {
223
+ float min = x[0];
224
+ float max = x[0];
225
+ for (int i = 1; i < n; ++i) {
226
+ if (x[i] < min) min = x[i];
227
+ if (x[i] > max) max = x[i];
228
+ }
229
+ if (max == min) {
230
+ for (int i = 0; i < n; ++i) L[i] = 0;
231
+ *the_min = 0;
232
+ return 0.f;
233
+ }
234
+ if (min > 0) min = 0;
235
+ float iscale = nmax/(max - min);
236
+ float scale = 1/iscale;
237
+ for (int itry = 0; itry < ntry; ++itry) {
238
+ float sumlx = 0; int suml2 = 0;
239
+ bool did_change = false;
240
+ for (int i = 0; i < n; ++i) {
241
+ int l = nearest_int(iscale*(x[i] - min));
242
+ l = MAX(0, MIN(nmax, l));
243
+ if (l != L[i]) {
244
+ L[i] = l;
245
+ did_change = true;
246
+ }
247
+ sumlx += (x[i] - min)*l;
248
+ suml2 += l*l;
249
+ }
250
+ scale = sumlx/suml2;
251
+ float sum = 0;
252
+ for (int i = 0; i < n; ++i) {
253
+ sum += x[i] - scale*L[i];
254
+ }
255
+ min = sum/n;
256
+ if (min > 0) min = 0;
257
+ iscale = 1/scale;
258
+ if (!did_change) break;
259
+ }
260
+ *the_min = -min;
261
+ return scale;
262
+ }
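+
+ // Added commentary (not from upstream): make_qkx1_quants fits an affine 0..nmax grid, i.e.
+ // x[i] is approximately scale * L[i] - the_min, with the_min kept non-negative. It alternates
+ // between rounding to levels for the current scale/min, a least-squares update of the scale,
+ // and re-estimating the min from the mean residual, for up to ntry rounds or until the
+ // level assignment stops changing.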
263
+
264
+ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
265
+ if (j < 4) {
266
+ *d = q[j] & 63; *m = q[j + 4] & 63;
267
+ } else {
268
+ *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
269
+ *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
270
+ }
271
+ }
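+
+ // Added commentary (not from upstream): the q4_K/q5_K blocks pack 8 scales and 8 mins of
+ // 6 bits each into the 12-byte scales[] array. Bytes 0-3 and 4-7 hold the low 6 bits of
+ // scales 0-3 and mins 0-3 (their top 2 bits hold the high 2 bits of scales/mins 4-7), and
+ // bytes 8-11 hold the low 4 bits of scales 4-7 and mins 4-7 in the low/high nibbles.
+ // get_scale_min_k4 undoes exactly this packing for sub-block j.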
272
+
273
+ //========================= 2-bit (de)-quantization
274
+
275
+ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
276
+ assert(k % QK_K == 0);
277
+ const int nb = k / QK_K;
278
+
279
+ uint8_t L[QK_K];
280
+ float mins[QK_K/16];
281
+ float scales[QK_K/16];
282
+
283
+ const float q4scale = 15.f;
284
+
285
+ for (int i = 0; i < nb; i++) {
286
+
287
+ float max_scale = 0; // as we are deducting the min, scales are always positive
288
+ float max_min = 0;
289
+ for (int j = 0; j < QK_K/16; ++j) {
290
+ scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5);
291
+ float scale = scales[j];
292
+ if (scale > max_scale) {
293
+ max_scale = scale;
294
+ }
295
+ float min = mins[j];
296
+ if (min > max_min) {
297
+ max_min = min;
298
+ }
299
+ }
300
+
301
+ if (max_scale > 0) {
302
+ float iscale = q4scale/max_scale;
303
+ for (int j = 0; j < QK_K/16; ++j) {
304
+ int l = nearest_int(iscale*scales[j]);
305
+ y[i].scales[j] = l;
306
+ }
307
+ y[i].d = ggml_fp32_to_fp16(max_scale/q4scale);
308
+ } else {
309
+ for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
310
+ y[i].d = ggml_fp32_to_fp16(0.f);
311
+ }
312
+ if (max_min > 0) {
313
+ float iscale = q4scale/max_min;
314
+ for (int j = 0; j < QK_K/16; ++j) {
315
+ int l = nearest_int(iscale*mins[j]);
316
+ y[i].scales[j] |= (l << 4);
317
+ }
318
+ y[i].dmin = ggml_fp32_to_fp16(max_min/q4scale);
319
+ } else {
320
+ y[i].dmin = ggml_fp32_to_fp16(0.f);
321
+ }
322
+ for (int j = 0; j < QK_K/16; ++j) {
323
+ const float d = ggml_fp16_to_fp32(y[i].d) * (y[i].scales[j] & 0xF);
324
+ if (!d) continue;
325
+ const float dm = ggml_fp16_to_fp32(y[i].dmin) * (y[i].scales[j] >> 4);
326
+ for (int ii = 0; ii < 16; ++ii) {
327
+ int l = nearest_int((x[16*j + ii] + dm)/d);
328
+ l = MAX(0, MIN(3, l));
329
+ L[16*j + ii] = l;
330
+ }
331
+ }
332
+
333
+ for (int j = 0; j < QK_K; j += 128) {
334
+ for (int l = 0; l < 32; ++l) {
335
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
336
+ }
337
+ }
338
+
339
+ x += QK_K;
340
+
341
+ }
342
+ }
343
+
344
+ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
345
+ assert(k % QK_K == 0);
346
+ const int nb = k / QK_K;
347
+
348
+ for (int i = 0; i < nb; i++) {
349
+
350
+ const float d = ggml_fp16_to_fp32(x[i].d);
351
+ const float min = ggml_fp16_to_fp32(x[i].dmin);
352
+
353
+ const uint8_t * q = x[i].qs;
354
+
355
+ int is = 0;
356
+ float dl, ml;
357
+ for (int n = 0; n < QK_K; n += 128) {
358
+ int shift = 0;
359
+ for (int j = 0; j < 4; ++j) {
360
+
361
+ uint8_t sc = x[i].scales[is++];
362
+ dl = d * (sc & 0xF); ml = min * (sc >> 4);
363
+ for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
364
+
365
+ sc = x[i].scales[is++];
366
+ dl = d * (sc & 0xF); ml = min * (sc >> 4);
367
+ for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
368
+
369
+ shift += 2;
370
+ }
371
+ q += 32;
372
+ }
373
+
374
+ }
375
+ }
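+
+ // Illustrative usage sketch (added commentary, not from the upstream file): quantizing and
+ // dequantizing a single super-block of QK_K values with the reference routines above would
+ // look roughly like this:
+ //
+ //   float src[QK_K];                               // filled by the caller
+ //   float out[QK_K];
+ //   block_q2_K blk;
+ //   quantize_row_q2_K_reference(src, &blk, QK_K);  // 2-bit quants + 4-bit scales/mins
+ //   dequantize_row_q2_K(&blk, out, QK_K);          // out[i] approximates src[i]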
376
+
377
+ void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
378
+ quantize_row_q2_K_reference(x, vy, k);
379
+ }
380
+
381
+ size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
382
+ const int nb = k / QK_K;
383
+
384
+ // TODO - collect histograms - although, on second thought, I don't really care about them
385
+ (void)hist;
386
+
387
+ for (int j = 0; j < nb; j += k) {
388
+ block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
389
+ quantize_row_q2_K_reference(src + j, y, k);
390
+ }
391
+ return (n/QK_K*sizeof(block_q2_K));
392
+ }
393
+
394
+ //========================= 3-bit (de)-quantization
395
+
396
+ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
397
+ assert(k % QK_K == 0);
398
+ const int nb = k / QK_K;
399
+
400
+ int8_t L[QK_K];
401
+ float scales[QK_K / 16];
402
+
403
+ for (int i = 0; i < nb; i++) {
404
+
405
+ float max_scale = 0;
406
+ float amax = 0;
407
+ for (int j = 0; j < QK_K/16; ++j) {
408
+ scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
409
+ float scale = fabsf(scales[j]);
410
+ if (scale > amax) {
411
+ amax = scale; max_scale = scales[j];
412
+ }
413
+ }
414
+
415
+ memset(y[i].scales, 0, 12);
416
+ if (max_scale) {
417
+ float iscale = -32.f/max_scale;
418
+ for (int j = 0; j < QK_K/16; ++j) {
419
+ int8_t l = nearest_int(iscale*scales[j]);
420
+ l = MAX(-32, MIN(31, l)) + 32;
421
+ if (j < 8) {
422
+ y[i].scales[j] = l & 0xF;
423
+ } else {
424
+ y[i].scales[j-8] |= ((l & 0xF) << 4);
425
+ }
426
+ l >>= 4;
427
+ y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
428
+ }
429
+ y[i].d = ggml_fp32_to_fp16(1/iscale);
430
+ } else {
431
+ y[i].d = ggml_fp32_to_fp16(0.f);
432
+ }
433
+
434
+ int8_t sc;
435
+ for (int j = 0; j < QK_K/16; ++j) {
436
+ sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
437
+ sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
438
+ float d = ggml_fp16_to_fp32(y[i].d) * sc;
439
+ if (!d) {
440
+ continue;
441
+ }
442
+ for (int ii = 0; ii < 16; ++ii) {
443
+ int l = nearest_int(x[16*j + ii]/d);
444
+ l = MAX(-4, MIN(3, l));
445
+ L[16*j + ii] = l + 4;
446
+ }
447
+ }
448
+
449
+ memset(y[i].hmask, 0, QK_K/8);
450
+ // We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc.
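+ // (Added note: with QK_K == 256 this packs 256 high bits into the 32-byte hmask;
+ //  quant j ends up in hmask[j % 32], bit (j / 32).)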
451
+ int m = 0;
452
+ uint8_t hm = 1;
453
+ for (int j = 0; j < QK_K; ++j) {
454
+ if (L[j] > 3) {
455
+ y[i].hmask[m] |= hm;
456
+ L[j] -= 4;
457
+ }
458
+ if (++m == QK_K/8) {
459
+ m = 0; hm <<= 1;
460
+ }
461
+ }
462
+ for (int j = 0; j < QK_K; j += 128) {
463
+ for (int l = 0; l < 32; ++l) {
464
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
465
+ }
466
+ }
467
+
468
+ x += QK_K;
469
+ }
470
+ }
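+
+ // Added commentary (not from upstream): q3_K keeps 16 six-bit sub-block scales in the
+ // 12-byte scales[] array. Bytes 0-7 store the low 4 bits (scale j in the low nibble of
+ // byte j for j < 8, and in the high nibble of byte j-8 for j >= 8); bytes 8-11 store the
+ // remaining top 2 bits, two bits per scale, at offset 2*(j/4) within byte 8 + (j%4).
+ // The dequantization below rebuilds the same 6-bit values with kmask1/kmask2.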
471
+
472
+ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
473
+ assert(k % QK_K == 0);
474
+ assert(QK_K == 256);
475
+ const int nb = k / QK_K;
476
+
477
+ const uint32_t kmask1 = 0x03030303;
478
+ const uint32_t kmask2 = 0x0f0f0f0f;
479
+
480
+ uint32_t aux[4];
481
+ const int8_t * scales = (const int8_t*)aux;
482
+
483
+ for (int i = 0; i < nb; i++) {
484
+
485
+ const float d_all = ggml_fp16_to_fp32(x[i].d);
486
+
487
+ const uint8_t * restrict q = x[i].qs;
488
+ const uint8_t * restrict hm = x[i].hmask;
489
+ uint8_t m = 1;
490
+
491
+ memcpy(aux, x[i].scales, 12);
492
+ uint32_t tmp = aux[2];
493
+ aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
494
+ aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
495
+ aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
496
+ aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
497
+
498
+ int is = 0;
499
+ float dl;
500
+ for (int n = 0; n < QK_K; n += 128) {
501
+ int shift = 0;
502
+ for (int j = 0; j < 4; ++j) {
503
+
504
+ dl = d_all * (scales[is++] - 32);
505
+ for (int l = 0; l < 16; ++l) {
506
+ *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
507
+ }
508
+
509
+ dl = d_all * (scales[is++] - 32);
510
+ for (int l = 0; l < 16; ++l) {
511
+ *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
512
+ }
513
+
514
+ shift += 2;
515
+ m <<= 1;
516
+ }
517
+ q += 32;
518
+ }
519
+
520
+ }
521
+ }
522
+
523
+ void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
524
+ quantize_row_q3_K_reference(x, vy, k);
525
+ }
526
+
527
+ size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
528
+ const int nb = k / QK_K;
529
+
530
+ // TODO - collect histograms - although, on second thought, I don't really care about them
531
+ (void)hist;
532
+
533
+ for (int j = 0; j < nb; j += k) {
534
+ block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
535
+ quantize_row_q3_K_reference(src + j, y, k);
536
+ }
537
+ return (n/QK_K*sizeof(block_q3_K));
538
+ }
539
+
540
+ // ====================== 4-bit (de)-quantization
541
+
542
+ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
543
+ assert(k % QK_K == 0);
544
+ const int nb = k / QK_K;
545
+
546
+ uint8_t L[QK_K];
547
+ float mins[QK_K/32];
548
+ float scales[QK_K/32];
549
+
550
+ for (int i = 0; i < nb; i++) {
551
+
552
+ float max_scale = 0; // as we are deducting the min, scales are always positive
553
+ float max_min = 0;
554
+ for (int j = 0; j < QK_K/32; ++j) {
555
+ scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 5);
556
+ float scale = scales[j];
557
+ if (scale > max_scale) {
558
+ max_scale = scale;
559
+ }
560
+ float min = mins[j];
561
+ if (min > max_min) {
562
+ max_min = min;
563
+ }
564
+ }
565
+
566
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
567
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
568
+ for (int j = 0; j < QK_K/32; ++j) {
569
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
570
+ uint8_t lm = nearest_int(inv_min*mins[j]);
571
+ ls = MIN(63, ls);
572
+ lm = MIN(63, lm);
573
+ if (j < 4) {
574
+ y[i].scales[j] = ls;
575
+ y[i].scales[j+4] = lm;
576
+ } else {
577
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
578
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
579
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
580
+ }
581
+ }
582
+ y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
583
+ y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
584
+
585
+ uint8_t sc, m;
586
+ for (int j = 0; j < QK_K/32; ++j) {
587
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
588
+ const float d = ggml_fp16_to_fp32(y[i].d) * sc;
589
+ if (!d) continue;
590
+ const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
591
+ for (int ii = 0; ii < 32; ++ii) {
592
+ int l = nearest_int((x[32*j + ii] + dm)/d);
593
+ l = MAX(0, MIN(15, l));
594
+ L[32*j + ii] = l;
595
+ }
596
+ }
597
+ uint8_t * q = y[i].qs;
598
+ for (int j = 0; j < QK_K; j += 64) {
599
+ for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4);
600
+ }
601
+
602
+ x += QK_K;
603
+
604
+ }
605
+ }
606
+
607
+ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
608
+ assert(k % QK_K == 0);
609
+ const int nb = k / QK_K;
610
+
611
+ for (int i = 0; i < nb; i++) {
612
+
613
+ const float d = ggml_fp16_to_fp32(x[i].d);
614
+ const float min = ggml_fp16_to_fp32(x[i].dmin);
615
+
616
+ const uint8_t * q = x[i].qs;
617
+
618
+ int is = 0;
619
+ uint8_t sc, m;
620
+ for (int j = 0; j < QK_K; j += 64) {
621
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
622
+ const float d1 = d * sc; const float m1 = min * m;
623
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
624
+ const float d2 = d * sc; const float m2 = min * m;
625
+ for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
626
+ for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
627
+ q += 32; is += 2;
628
+ }
629
+
630
+ }
631
+ }
632
+
633
+ void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
634
+ assert(k % QK_K == 0);
635
+ block_q4_K * restrict y = vy;
636
+ quantize_row_q4_K_reference(x, y, k);
637
+ }
638
+
639
+ size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
640
+ assert(k % QK_K == 0);
641
+ const int nb = k / QK_K;
642
+ (void)hist; // TODO: collect histograms
643
+ for (int j = 0; j < nb; j += k) {
644
+ block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
645
+ quantize_row_q4_K_reference(src + j, y, k);
646
+ }
647
+ return (n/QK_K*sizeof(block_q4_K));
648
+ }
649
+
650
+ // ====================== 5-bit (de)-quantization
651
+
652
+ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
653
+ assert(k % QK_K == 0);
654
+ const int nb = k / QK_K;
655
+
656
+ uint8_t L[QK_K];
657
+ float mins[QK_K/32];
658
+ float scales[QK_K/32];
659
+
660
+ for (int i = 0; i < nb; i++) {
661
+
662
+ float max_scale = 0; // as we are deducting the min, scales are always positive
663
+ float max_min = 0;
664
+ for (int j = 0; j < QK_K/32; ++j) {
665
+ scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 5);
666
+ float scale = scales[j];
667
+ if (scale > max_scale) {
668
+ max_scale = scale;
669
+ }
670
+ float min = mins[j];
671
+ if (min > max_min) {
672
+ max_min = min;
673
+ }
674
+ }
675
+
676
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
677
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
678
+ for (int j = 0; j < QK_K/32; ++j) {
679
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
680
+ uint8_t lm = nearest_int(inv_min*mins[j]);
681
+ ls = MIN(63, ls);
682
+ lm = MIN(63, lm);
683
+ if (j < 4) {
684
+ y[i].scales[j] = ls;
685
+ y[i].scales[j+4] = lm;
686
+ } else {
687
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
688
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
689
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
690
+ }
691
+ }
692
+ y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
693
+ y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
694
+
695
+ uint8_t sc, m;
696
+ for (int j = 0; j < QK_K/32; ++j) {
697
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
698
+ const float d = ggml_fp16_to_fp32(y[i].d) * sc;
699
+ if (!d) continue;
700
+ const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
701
+ for (int ii = 0; ii < 32; ++ii) {
702
+ int l = nearest_int((x[32*j + ii] + dm)/d);
703
+ l = MAX(0, MIN(31, l));
704
+ L[32*j + ii] = l;
705
+ }
706
+ }
707
+
708
+ uint8_t * restrict qh = y[i].qh;
709
+ uint8_t * restrict ql = y[i].qs;
710
+ memset(qh, 0, QK_K/8);
711
+
712
+ uint8_t m1 = 1, m2 = 2;
713
+ for (int n = 0; n < QK_K; n += 64) {
714
+ for (int j = 0; j < 32; ++j) {
715
+ int l1 = L[n + j];
716
+ if (l1 > 15) {
717
+ l1 -= 16; qh[j] |= m1;
718
+ }
719
+ int l2 = L[n + j + 32];
720
+ if (l2 > 15) {
721
+ l2 -= 16; qh[j] |= m2;
722
+ }
723
+ ql[j] = l1 | (l2 << 4);
724
+ }
725
+ m1 <<= 2; m2 <<= 2;
726
+ ql += 32;
727
+ }
728
+
729
+ x += QK_K;
730
+
731
+ }
732
+ }
733
+
734
+ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
735
+ assert(k % QK_K == 0);
736
+ const int nb = k / QK_K;
737
+
738
+ for (int i = 0; i < nb; i++) {
739
+
740
+ const float d = ggml_fp16_to_fp32(x[i].d);
741
+ const float min = ggml_fp16_to_fp32(x[i].dmin);
742
+
743
+ const uint8_t * ql = x[i].qs;
744
+ const uint8_t * qh = x[i].qh;
745
+
746
+ int is = 0;
747
+ uint8_t sc, m;
748
+ uint8_t u1 = 1, u2 = 2;
749
+ for (int j = 0; j < QK_K; j += 64) {
750
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
751
+ const float d1 = d * sc; const float m1 = min * m;
752
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
753
+ const float d2 = d * sc; const float m2 = min * m;
754
+ for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
755
+ for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
756
+ ql += 32; is += 2;
757
+ u1 <<= 2; u2 <<= 2;
758
+ }
759
+ }
760
+ }
761
+
762
+ void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
763
+ assert(k % QK_K == 0);
764
+ block_q5_K * restrict y = vy;
765
+ quantize_row_q5_K_reference(x, y, k);
766
+ }
767
+
768
+ size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
769
+ assert(k % QK_K == 0);
770
+ const int nb = k / QK_K;
771
+ (void)hist;
772
+ for (int j = 0; j < nb; j += k) {
773
+ block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
774
+ quantize_row_q5_K_reference(src + j, y, k);
775
+ }
776
+ return (n/QK_K*sizeof(block_q5_K));
777
+ }
778
+
779
+ // ====================== 6-bit (de)-quantization
780
+
781
+ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
782
+ assert(k % QK_K == 0);
783
+ const int nb = k / QK_K;
784
+
785
+ int8_t L[QK_K];
786
+ float scales[QK_K/16];
787
+
788
+ for (int i = 0; i < nb; i++) {
789
+
790
+ float max_scale = 0;
791
+ float max_abs_scale = 0;
792
+
793
+ for (int ib = 0; ib < QK_K/16; ++ib) {
794
+
795
+ const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
796
+ scales[ib] = scale;
797
+
798
+ const float abs_scale = fabsf(scale);
799
+ if (abs_scale > max_abs_scale) {
800
+ max_abs_scale = abs_scale;
801
+ max_scale = scale;
802
+ }
803
+
804
+ }
805
+
806
+ float iscale = -128.f/max_scale;
807
+ y[i].d = ggml_fp32_to_fp16(1/iscale);
808
+ for (int ib = 0; ib < QK_K/16; ++ib) {
809
+ y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
810
+ }
811
+
812
+ for (int j = 0; j < QK_K/16; ++j) {
813
+ float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
814
+ if (!d) {
815
+ continue;
816
+ }
817
+ for (int ii = 0; ii < 16; ++ii) {
818
+ int l = nearest_int(x[16*j + ii]/d);
819
+ l = MAX(-32, MIN(31, l));
820
+ L[16*j + ii] = l + 32;
821
+ }
822
+ }
823
+
824
+ uint8_t * restrict ql = y[i].ql;
825
+ uint8_t * restrict qh = y[i].qh;
826
+ for (int j = 0; j < QK_K; j += 128) {
827
+ for (int l = 0; l < 32; ++l) {
828
+ const uint8_t q1 = L[j + l + 0] & 0xF;
829
+ const uint8_t q2 = L[j + l + 32] & 0xF;
830
+ const uint8_t q3 = L[j + l + 64] & 0xF;
831
+ const uint8_t q4 = L[j + l + 96] & 0xF;
832
+ ql[l+ 0] = q1 | (q3 << 4);
833
+ ql[l+32] = q2 | (q4 << 4);
834
+ qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
835
+ }
836
+ ql += 64;
837
+ qh += 32;
838
+ }
839
+
840
+ x += QK_K;
841
+
842
+ }
843
+ }
844
+
845
+ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
846
+ assert(k % QK_K == 0);
847
+ const int nb = k / QK_K;
848
+
849
+ for (int i = 0; i < nb; i++) {
850
+
851
+ const float d = ggml_fp16_to_fp32(x[i].d);
852
+
853
+ const uint8_t * restrict ql = x[i].ql;
854
+ const uint8_t * restrict qh = x[i].qh;
855
+ const int8_t * restrict sc = x[i].scales;
856
+
857
+ for (int n = 0; n < QK_K; n += 128) {
858
+ for (int l = 0; l < 32; ++l) {
859
+ int is = l/16;
860
+ const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
861
+ const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
862
+ const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
863
+ const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
864
+ y[l + 0] = d * sc[is + 0] * q1;
865
+ y[l + 32] = d * sc[is + 2] * q2;
866
+ y[l + 64] = d * sc[is + 4] * q3;
867
+ y[l + 96] = d * sc[is + 6] * q4;
868
+ }
869
+ y += 128;
870
+ ql += 64;
871
+ qh += 32;
872
+ sc += 8;
873
+ }
874
+
875
+ }
876
+ }
877
+
878
+ void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
879
+ assert(k % QK_K == 0);
880
+ block_q6_K * restrict y = vy;
881
+ quantize_row_q6_K_reference(x, y, k);
882
+ }
883
+
884
+ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
885
+ assert(k % QK_K == 0);
886
+ const int nb = k / QK_K;
887
+
888
+ (void)hist; // TODO
889
+
890
+ for (int j = 0; j < nb; j += k) {
891
+ block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
892
+ quantize_row_q6_K_reference(src + j, y, k);
893
+ }
894
+ return (n/QK_K*sizeof(block_q6_K));
895
+ }
896
+
897
+ //===================================== Q8_K ==============================================
898
+
899
+ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
900
+ assert(k % QK_K == 0);
901
+ const int nb = k / QK_K;
902
+
903
+ for (int i = 0; i < nb; i++) {
904
+
905
+ float max = 0;
906
+ float amax = 0;
907
+ for (int j = 0; j < QK_K; ++j) {
908
+ float ax = fabsf(x[j]);
909
+ if (ax > amax) {
910
+ amax = ax; max = x[j];
911
+ }
912
+ }
913
+ if (!amax) {
914
+ y[i].d = 0;
915
+ memset(y[i].qs, 0, QK_K);
916
+ x += QK_K;
917
+ continue;
918
+ }
919
+ const float iscale = -128.f/max;
920
+ for (int j = 0; j < QK_K; ++j) {
921
+ int v = nearest_int(iscale*x[j]);
922
+ y[i].qs[j] = MIN(127, v);
923
+ }
924
+ for (int j = 0; j < QK_K/16; ++j) {
925
+ int sum = 0;
926
+ for (int ii = 0; ii < 16; ++ii) {
927
+ sum += y[i].qs[j*16 + ii];
928
+ }
929
+ y[i].bsums[j] = sum;
930
+ }
931
+ y[i].d = 1/iscale;
932
+ x += QK_K;
933
+ }
934
+ }
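+
+ // Added commentary (not from upstream): besides the float scale d and the 256 int8 quants,
+ // Q8_K stores bsums[16], the sum of every group of 16 quants. The K-quant dot products
+ // below use these partial sums to fold in the per-sub-block minimums of formats such as
+ // q2_K, q4_K and q5_K without revisiting the individual int8 values.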
935
+
936
+ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
937
+ assert(k % QK_K == 0);
938
+ const int nb = k / QK_K;
939
+
940
+ for (int i = 0; i < nb; i++) {
941
+ for (int j = 0; j < QK_K; ++j) {
942
+ *y++ = x[i].d * x[i].qs[j];
943
+ }
944
+ }
945
+ }
946
+
947
+ void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
948
+ quantize_row_q8_K_reference(x, y, k);
949
+ }
950
+
951
+ //===================================== Dot products =================================
952
+
953
+ //
954
+ // Helper functions
955
+ //
956
+ #if __AVX__ || __AVX2__ || __AVX512F__
957
+
958
+ // horizontally add 8 floats
959
+ static inline float hsum_float_8(const __m256 x) {
960
+ __m128 res = _mm256_extractf128_ps(x, 1);
961
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
962
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
963
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
964
+ return _mm_cvtss_f32(res);
965
+ }
966
+
967
+ // shuffles to pick the required scales in dot products
968
+ static inline __m256i get_scale_shuffle_q3k(int i) {
969
+ static const uint8_t k_shuffle[128] = {
970
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
971
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
972
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
973
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
974
+ };
975
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
976
+ }
977
+ static inline __m256i get_scale_shuffle_k4(int i) {
978
+ static const uint8_t k_shuffle[256] = {
979
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
980
+ 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
981
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
982
+ 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
983
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
984
+ 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
985
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
986
+ 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
987
+ };
988
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
989
+ }
990
+ static inline __m128i get_scale_shuffle(int i) {
991
+ static const uint8_t k_shuffle[128] = {
992
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
993
+ 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
994
+ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
995
+ 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
996
+ 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
997
+ 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
998
+ 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
999
+ 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
1000
+ };
1001
+ return _mm_loadu_si128((const __m128i*)k_shuffle + i);
1002
+ }
1003
+ #endif
1004
+
1005
+ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1006
+
1007
+ const block_q2_K * restrict x = vx;
1008
+ const block_q8_K * restrict y = vy;
1009
+
1010
+ const int nb = n / QK_K;
1011
+
1012
+ #ifdef __ARM_NEON
1013
+
1014
+ const uint8x16_t m3 = vdupq_n_u8(0x3);
1015
+ const uint8x16_t m4 = vdupq_n_u8(0xF);
1016
+ const int32x4_t vzero = vdupq_n_s32(0);
1017
+
1018
+ int8x16x2_t q2bytes;
1019
+ uint8_t aux[16];
1020
+
1021
+ float sum = 0;
1022
+
1023
+ for (int i = 0; i < nb; ++i) {
1024
+
1025
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1026
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1027
+
1028
+ const uint8_t * restrict q2 = x[i].qs;
1029
+ const int8_t * restrict q8 = y[i].qs;
1030
+ const uint8_t * restrict sc = x[i].scales;
1031
+
1032
+ const uint8x16_t mins_and_scales = vld1q_u8(sc);
1033
+ const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
1034
+ vst1q_u8(aux, scales);
1035
+
1036
+ const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
1037
+ const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
1038
+ const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
1039
+ const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
1040
+ vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
1041
+ const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
1042
+ vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
1043
+ sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
1044
+
1045
+ int isum = 0;
1046
+ int is = 0;
1047
+
1048
+ // We use this macro instead of a function call because, for some reason, using a function
1049
+ // makes the code run 2-3% slower, even if the function is declared inline
1050
+ #if defined(__ARM_FEATURE_DOTPROD)
1051
+ #define MULTIPLY_ACCUM_WITH_SCALE(index)\
1052
+ isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
1053
+ isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
1054
+ #else
1055
+ #define MULTIPLY_ACCUM_WITH_SCALE(index)\
1056
+ {\
1057
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
1058
+ vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
1059
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
1060
+ vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
1061
+ isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
1062
+ }
1063
+ #endif
1064
+
1065
+ #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
1066
+ q8bytes = vld1q_s8_x2(q8); q8 += 32;\
1067
+ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
1068
+ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
1069
+ MULTIPLY_ACCUM_WITH_SCALE((index));
1070
+
1071
+
1072
+ for (int j = 0; j < QK_K/128; ++j) {
1073
+
1074
+ const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
1075
+
1076
+ int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
1077
+ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
1078
+ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
1079
+ MULTIPLY_ACCUM_WITH_SCALE(0);
1080
+
1081
+ SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
1082
+
1083
+ SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
1084
+
1085
+ SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
1086
+
1087
+ is += 8;
1088
+ }
1089
+ sum += d * isum;
1090
+
1091
+ }
1092
+
1093
+ *s = sum;
1094
+
1095
+ #elif defined __AVX2__
1096
+
1097
+ const __m256i m3 = _mm256_set1_epi8(3);
1098
+ const __m128i m4 = _mm_set1_epi8(0xF);
1099
+
1100
+ __m256 acc = _mm256_setzero_ps();
1101
+
1102
+ for (int i = 0; i < nb; ++i) {
1103
+
1104
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1105
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1106
+
1107
+ const uint8_t * restrict q2 = x[i].qs;
1108
+ const int8_t * restrict q8 = y[i].qs;
1109
+
1110
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1111
+ const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
1112
+ const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1113
+ const __m256i mins = _mm256_cvtepi8_epi16(mins8);
1114
+ const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
1115
+
1116
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
1117
+
1118
+ const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
1119
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1120
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1121
+ const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
1122
+
1123
+ __m256i sumi = _mm256_setzero_si256();
1124
+
1125
+ for (int j = 0; j < QK_K/128; ++j) {
1126
+
1127
+ const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
1128
+
1129
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1130
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1131
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1132
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1133
+
1134
+ const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
1135
+ const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
1136
+ const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
1137
+ const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
1138
+
1139
+ __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
1140
+ __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
1141
+ __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
1142
+ __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
1143
+
1144
+ p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
1145
+ p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
1146
+ p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
1147
+ p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
1148
+
1149
+ p0 = _mm256_add_epi32(p0, p1);
1150
+ p2 = _mm256_add_epi32(p2, p3);
1151
+
1152
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
1153
+ }
1154
+
1155
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
1156
+
1157
+ }
1158
+
1159
+ *s = hsum_float_8(acc);
1160
+
1161
+ #else
1162
+
1163
+ float sumf = 0;
1164
+
1165
+ for (int i = 0; i < nb; ++i) {
1166
+
1167
+ const uint8_t * q2 = x[i].qs;
1168
+ const int8_t * q8 = y[i].qs;
1169
+ const uint8_t * sc = x[i].scales;
1170
+
1171
+ int summs = 0;
1172
+ for (int j = 0; j < 16; ++j) {
1173
+ summs += y[i].bsums[j] * (sc[j] >> 4);
1174
+ }
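+ // (Added note: the dequantized q2_K value is d*(sc & 0xF)*q - dmin*(sc >> 4), with d and
+ //  dmin the fp16 block scales, so the min part of the dot product only needs these
+ //  per-16 sums of q8, which come precomputed in bsums.)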
1175
+
1176
+ const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
1177
+ const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1178
+
1179
+ int isum = 0;
1180
+ int is = 0;
1181
+ int d;
1182
+ for (int k = 0; k < QK_K/128; ++k) {
1183
+ int shift = 0;
1184
+ for (int j = 0; j < 4; ++j) {
1185
+ d = sc[is++] & 0xF;
1186
+ int isuml = 0;
1187
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
1188
+ isum += d * isuml;
1189
+ d = sc[is++] & 0xF;
1190
+ isuml = 0;
1191
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
1192
+ isum += d * isuml;
1193
+ shift += 2;
1194
+ q8 += 32;
1195
+ }
1196
+ q2 += 32;
1197
+ }
1198
+ sumf += dall * isum - dmin * summs;
1199
+ }
1200
+ *s = sumf;
1201
+ #endif
1202
+ }
1203
+
1204
+ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1205
+ assert(n % QK_K == 0);
1206
+
1207
+ const uint32_t kmask1 = 0x03030303;
1208
+ const uint32_t kmask2 = 0x0f0f0f0f;
1209
+
1210
+ const block_q3_K * restrict x = vx;
1211
+ const block_q8_K * restrict y = vy;
1212
+
1213
+ const int nb = n / QK_K;
1214
+
1215
+ #ifdef __ARM_NEON
1216
+
1217
+ uint32_t aux[3];
1218
+ uint32_t utmp[4];
1219
+
1220
+ const uint8x16_t m3b = vdupq_n_u8(0x3);
1221
+ #ifdef __ARM_FEATURE_DOTPROD
1222
+ const int32x4_t vzero = vdupq_n_s32(0);
1223
+ #endif
1224
+
1225
+ const uint8x16_t m0 = vdupq_n_u8(1);
1226
+ const uint8x16_t m1 = vshlq_n_u8(m0, 1);
1227
+ const uint8x16_t m2 = vshlq_n_u8(m0, 2);
1228
+ const uint8x16_t m3 = vshlq_n_u8(m0, 3);
1229
+ const int8_t m32 = 32;
1230
+
1231
+ int8x16x4_t q3bytes;
1232
+
1233
+ float sum = 0;
1234
+
1235
+ for (int i = 0; i < nb; ++i) {
1236
+
1237
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1238
+
1239
+ const uint8_t * restrict q3 = x[i].qs;
1240
+ const uint8_t * restrict qh = x[i].hmask;
1241
+ const int8_t * restrict q8 = y[i].qs;
1242
+
1243
+ uint8x16x2_t qhbits = vld1q_u8_x2(qh);
1244
+
1245
+ uint8x16x4_t q3h;
1246
+
1247
+ int32_t isum = 0;
1248
+
1249
+ // Set up scales
1250
+ memcpy(aux, x[i].scales, 12);
1251
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
1252
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
1253
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
1254
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
1255
+
1256
+ int8_t * scale = (int8_t *)utmp;
1257
+ for (int j = 0; j < 16; ++j) scale[j] -= m32;
1258
+
1259
+ for (int j = 0; j < QK_K/128; ++j) {
1260
+
1261
+ const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
1262
+ const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
1263
+ const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
1264
+
1265
+ q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
1266
+ q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
1267
+ q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
1268
+ q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
1269
+
1270
+ q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
1271
+ q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
1272
+ q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
1273
+ q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
1274
+
1275
+ #if defined(__ARM_FEATURE_DOTPROD)
1276
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
1277
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
1278
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
1279
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
1280
+ #else
1281
+ int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
1282
+ vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
1283
+ int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
1284
+ vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
1285
+ int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
1286
+ vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
1287
+ int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
1288
+ vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
1289
+ isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
1290
+ #endif
1291
+ scale += 4;
1292
+
1293
+ q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
1294
+ q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
1295
+ q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
1296
+ q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
1297
+
1298
+ q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
1299
+ q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
1300
+ q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
1301
+ q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
1302
+
1303
+ #if defined(__ARM_FEATURE_DOTPROD)
1304
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
1305
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
1306
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
1307
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
1308
+ #else
1309
+ p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
1310
+ vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
1311
+ p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
1312
+ vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
1313
+ p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
1314
+ vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
1315
+ p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
1316
+ vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
1317
+ isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
1318
+ #endif
1319
+ scale += 4;
1320
+
1321
+ if (j == 0) {
1322
+ qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
1323
+ qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
1324
+ }
1325
+
1326
+ }
1327
+ sum += d * isum;
1328
+
1329
+ }
1330
+
1331
+ *s = sum;
1332
+
1333
+ #elif defined __AVX2__
1334
+
1335
+ const __m256i m3 = _mm256_set1_epi8(3);
1336
+ const __m256i mone = _mm256_set1_epi8(1);
1337
+ const __m128i m32 = _mm_set1_epi8(32);
1338
+
1339
+ __m256 acc = _mm256_setzero_ps();
1340
+
1341
+ uint32_t aux[3];
1342
+
1343
+ for (int i = 0; i < nb; ++i) {
1344
+
1345
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1346
+
1347
+ const uint8_t * restrict q3 = x[i].qs;
1348
+ const int8_t * restrict q8 = y[i].qs;
1349
+
1350
+ // Set up scales
1351
+ memcpy(aux, x[i].scales, 12);
1352
+ __m128i scales128 = _mm_set_epi32(
1353
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1354
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1355
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1356
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1357
+ scales128 = _mm_sub_epi8(scales128, m32);
1358
+ const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
1359
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1360
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1361
+ const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
1362
+
1363
+ // high bit
1364
+ const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
1365
+
1366
+ // integer accumulator
1367
+ __m256i sumi = _mm256_setzero_si256();
1368
+
1369
+ int bit = 0;
1370
+ int is = 0;
1371
+
1372
+ for (int j = 0; j < QK_K/128; ++j) {
1373
+ // load low 2 bits
1374
+ const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
1375
+
1376
+ // prepare low and high bits
1377
+ const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
1378
+ const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1379
+ ++bit;
1380
+
1381
+ const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
1382
+ const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1383
+ ++bit;
1384
+
1385
+ const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
1386
+ const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1387
+ ++bit;
1388
+
1389
+ const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
1390
+ const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1391
+ ++bit;
1392
+
1393
+ // load Q8 quants
1394
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1395
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1396
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1397
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1398
+
1399
+ // Dot product: we multiply the 2 low bits and the 1 high bit part separately, so we can use _mm256_maddubs_epi16,
1400
+ // and then subtract. The high bit part already carries the offset of 4 (it is 4 if the high bit was not set,
1401
+ // and 0 if it was set), so the subtraction matches the reference decode (q & 3) - (high bit set ? 0 : 4)
1402
+ __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
1403
+ __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
1404
+ __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
1405
+ __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
1406
+
1407
+ __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
1408
+ __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
1409
+ __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
1410
+ __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
1411
+
1412
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
1413
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
1414
+ p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
1415
+ p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
1416
+
1417
+ // multiply with scales
1418
+ p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
1419
+ p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
1420
+ p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
1421
+ p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
1422
+
1423
+ // accumulate
1424
+ p16_0 = _mm256_add_epi32(p16_0, p16_1);
1425
+ p16_2 = _mm256_add_epi32(p16_2, p16_3);
1426
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
1427
+
1428
+ }
1429
+
1430
+ // multiply with block scale and accumulate
1431
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
1432
+
1433
+ }
1434
+
1435
+ *s = hsum_float_8(acc);
1436
+
1437
+ #else
1438
+ // scalar version
1439
+ // This function is written like this so the compiler can manage to vectorize most of it
1440
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
1441
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
1442
+ // The ideal situation would be if we could just write the code once, and the compiler would
1443
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
1444
+ // write vectorized versions for AVX, ARM_NEON, etc.
1445
+
1446
+ int8_t aux8[QK_K];
1447
+ int16_t aux16[8];
1448
+ float sums [8];
1449
+ int32_t aux32[8];
1450
+ memset(sums, 0, 8*sizeof(float));
1451
+
1452
+ uint32_t auxs[4];
1453
+ const int8_t * scales = (const int8_t*)auxs;
1454
+
1455
+ float sumf = 0;
1456
+ for (int i = 0; i < nb; ++i) {
1457
+ const uint8_t * restrict q3 = x[i].qs;
1458
+ const uint8_t * restrict hm = x[i].hmask;
1459
+ const int8_t * restrict q8 = y[i].qs;
1460
+ memset(aux32, 0, 8*sizeof(int32_t));
1461
+ int8_t * restrict a = aux8;
1462
+ uint8_t m = 1;
1463
+ for (int j = 0; j < QK_K; j += 128) {
1464
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
1465
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1466
+ a += 32; m <<= 1;
1467
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
1468
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1469
+ a += 32; m <<= 1;
1470
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
1471
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1472
+ a += 32; m <<= 1;
1473
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
1474
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1475
+ a += 32; m <<= 1;
1476
+ q3 += 32;
1477
+ }
1478
+ a = aux8;
1479
+
1480
+ memcpy(auxs, x[i].scales, 12);
1481
+ uint32_t tmp = auxs[2];
1482
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
1483
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
1484
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
1485
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
1486
+ for (int j = 0; j < QK_K/16; ++j) {
1487
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1488
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1489
+ q8 += 8; a += 8;
1490
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1491
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1492
+ q8 += 8; a += 8;
1493
+ }
1494
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
1495
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1496
+ }
1497
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1498
+ *s = sumf;
1499
+
1500
+ #endif
1501
+
1502
+ }
1503
+
1504
+ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1505
+ assert(n % QK_K == 0);
1506
+
1507
+ const block_q4_K * restrict x = vx;
1508
+ const block_q8_K * restrict y = vy;
1509
+
1510
+ const int nb = n / QK_K;
1511
+
1512
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1513
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1514
+ static const uint32_t kmask3 = 0x03030303;
1515
+
1516
+ uint32_t utmp[4];
1517
+
1518
+ #ifdef __ARM_NEON
1519
+
1520
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
1521
+ #ifdef __ARM_FEATURE_DOTPROD
1522
+ const int32x4_t mzero = vdupq_n_s32(0);
1523
+ #endif
1524
+
1525
+ int8x16x2_t q4bytes;
1526
+ int8x16x2_t q8bytes;
1527
+
1528
+ float sumf = 0;
1529
+
1530
+ for (int i = 0; i < nb; ++i) {
1531
+
1532
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1533
+ const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1534
+
1535
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
1536
+
1537
+ memcpy(utmp, x[i].scales, 12);
1538
+
1539
+ const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)};
1540
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1541
+ utmp[0] &= kmask1;
1542
+
1543
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
1544
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
1545
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
1546
+ sumf -= dmin * vaddvq_s32(prod);
1547
+
1548
+ const uint8_t * scales = (const uint8_t *)utmp;
1549
+
1550
+ const uint8_t * restrict q4 = x[i].qs;
1551
+ const int8_t * restrict q8 = y[i].qs;
1552
+
1553
+ //int32x4_t isum = mzero;
1554
+
1555
+ int32_t sumi1 = 0;
1556
+ int32_t sumi2 = 0;
1557
+
1558
+ for (int j = 0; j < QK_K/64; ++j) {
1559
+
1560
+ const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
1561
+
1562
+ #ifdef __ARM_FEATURE_DOTPROD
1563
+ q8bytes = vld1q_s8_x2(q8); q8 += 32;
1564
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
1565
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
1566
+
1567
+ const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
1568
+ sumi1 += vaddvq_s32(p1) * scales[2*j+0];
1569
+
1570
+ q8bytes = vld1q_s8_x2(q8); q8 += 32;
1571
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
1572
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
1573
+
1574
+ const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
1575
+
1576
+ sumi2 += vaddvq_s32(p2) * scales[2*j+1];
1577
+ #else
1578
+ q8bytes = vld1q_s8_x2(q8); q8 += 32;
1579
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
1580
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
1581
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
1582
+ vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
1583
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
1584
+ vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
1585
+ sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
1586
+
1587
+ q8bytes = vld1q_s8_x2(q8); q8 += 32;
1588
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
1589
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
1590
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
1591
+ vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
1592
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
1593
+ vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
1594
+ sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
1595
+
1596
+ #endif
1597
+ }
1598
+
1599
+ sumf += d * (sumi1 + sumi2);
1600
+
1601
+ }
1602
+
1603
+ *s = sumf;
1604
+
1605
+ #elif defined __AVX2__
1606
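+ // AVX2 path: scales and mins are widened to 16-bit lanes once per super-block;
+ // the min contribution is folded in through the q8 block sums (note dmin is
+ // negated), and the 4-bit quants are multiplied with q8 via maddubs/madd.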
+
1607
+ const __m256i m4 = _mm256_set1_epi8(0xF);
1608
+
1609
+ __m256 acc = _mm256_setzero_ps();
1610
+ __m128 acc_m = _mm_setzero_ps();
1611
+
1612
+ for (int i = 0; i < nb; ++i) {
1613
+
1614
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1615
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1616
+
1617
+ const uint8_t * restrict q4 = x[i].qs;
1618
+ const int8_t * restrict q8 = y[i].qs;
1619
+
1620
+ memcpy(utmp, x[i].scales, 12);
1621
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1622
+ const uint32_t uaux = utmp[1] & kmask1;
1623
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1624
+ utmp[2] = uaux;
1625
+ utmp[0] &= kmask1;
1626
+
1627
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1628
+
1629
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
1630
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1631
+ const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
1632
+ acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
1633
+
1634
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
1635
+ const __m256i scales = _mm256_set_m128i(sc128, sc128);
1636
+
1637
+ __m256i sumi = _mm256_setzero_si256();
1638
+
1639
+ for (int j = 0; j < QK_K/64; ++j) {
1640
+
1641
+ const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
1642
+ const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
1643
+
1644
+ const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
1645
+ const __m256i q4l = _mm256_and_si256(q4bits, m4);
1646
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
1647
+
1648
+ const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1649
+ __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
1650
+ p16l = _mm256_madd_epi16(scale_l, p16l);
1651
+ sumi = _mm256_add_epi32(sumi, p16l);
1652
+
1653
+ const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1654
+ __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
1655
+ p16h = _mm256_madd_epi16(scale_h, p16h);
1656
+ sumi = _mm256_add_epi32(sumi, p16h);
1657
+
1658
+ }
1659
+
1660
+ __m256 vd = _mm256_set1_ps(d);
1661
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
1662
+
1663
+ }
1664
+
1665
+ acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
1666
+ acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
1667
+
1668
+ *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1669
+
1670
+ #else
1671
+
1672
+
1673
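+ // Reference path: expand the 4-bit quants into aux8, accumulate the scaled
+ // products sub-block by sub-block, and subtract the min contribution that
+ // was summed against bsums.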
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1674
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1675
+
1676
+ int8_t aux8[QK_K];
1677
+ int16_t aux16[8];
1678
+ float sums [8];
1679
+ int32_t aux32[8];
1680
+ memset(sums, 0, 8*sizeof(float));
1681
+
1682
+ float sumf = 0;
1683
+ for (int i = 0; i < nb; ++i) {
1684
+ const uint8_t * restrict q4 = x[i].qs;
1685
+ const int8_t * restrict q8 = y[i].qs;
1686
+ memset(aux32, 0, 8*sizeof(int32_t));
1687
+ int8_t * restrict a = aux8;
1688
+ for (int j = 0; j < QK_K/64; ++j) {
1689
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1690
+ a += 32;
1691
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1692
+ a += 32; q4 += 32;
1693
+ }
1694
+ memcpy(utmp, x[i].scales, 12);
1695
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1696
+ const uint32_t uaux = utmp[1] & kmask1;
1697
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1698
+ utmp[2] = uaux;
1699
+ utmp[0] &= kmask1;
1700
+
1701
+ int sumi = 0;
1702
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1703
+ a = aux8;
1704
+ int is = 0;
1705
+ for (int j = 0; j < QK_K/32; ++j) {
1706
+ int32_t scale = scales[is++];
1707
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1708
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1709
+ q8 += 8; a += 8;
1710
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1711
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1712
+ q8 += 8; a += 8;
1713
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1714
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1715
+ q8 += 8; a += 8;
1716
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1717
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1718
+ q8 += 8; a += 8;
1719
+ }
1720
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
1721
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1722
+ const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
1723
+ sumf -= dmin * sumi;
1724
+ }
1725
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1726
+ *s = sumf;
1727
+ #endif
1728
+ }
1729
+
1730
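+ // Dot product of a row of q5_K super-blocks with a row of q8_K blocks.
+ // Each q5_K value is stored as 4 low bits in qs plus one high bit in qh;
+ // the sub-block scales and mins use the same 12-byte packing as q4_K.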
+ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1731
+ assert(n % QK_K == 0);
1732
+
1733
+ const block_q5_K * restrict x = vx;
1734
+ const block_q8_K * restrict y = vy;
1735
+
1736
+ const int nb = n / QK_K;
1737
+
1738
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1739
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1740
+ static const uint32_t kmask3 = 0x03030303;
1741
+
1742
+ uint32_t utmp[4];
1743
+
1744
+
1745
+ #ifdef __ARM_NEON
1746
+
1747
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
1748
+ const int32x4_t mzero = vdupq_n_s32(0);
1749
+ const uint8x16_t mone = vdupq_n_u8(1);
1750
+ const uint8x16_t mtwo = vdupq_n_u8(2);
1751
+
1752
+ int8x16x4_t q5bytes;
1753
+
1754
+ float sumf = 0;
1755
+
1756
+ for (int i = 0; i < nb; ++i) {
1757
+
1758
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1759
+ const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1760
+
1761
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
1762
+
1763
+ memcpy(utmp, x[i].scales, 12);
1764
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1765
+ const uint32_t uaux = utmp[1] & kmask1;
1766
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1767
+ utmp[2] = uaux;
1768
+ utmp[0] &= kmask1;
1769
+
1770
+ const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
1771
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
1772
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
1773
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
1774
+ int32_t sumi_mins = vaddvq_s32(prod);
1775
+
1776
+ const uint8_t * scales = (const uint8_t *)utmp;
1777
+
1778
+ const uint8_t * restrict q5 = x[i].qs;
1779
+ const uint8_t * restrict qh = x[i].qh;
1780
+ const int8_t * restrict q8 = y[i].qs;
1781
+
1782
+ uint8x16x2_t qhbits = vld1q_u8_x2(qh);
1783
+
1784
+ uint8x16x4_t q5h;
1785
+
1786
+ int32_t sumi = 0;
1787
+
1788
+ for (int j = 0; j < QK_K/64; ++j) {
1789
+
1790
+ const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
1791
+ const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
1792
+
1793
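+ // qhbits holds the 5th bit of every quant: two bits per byte are consumed
+ // per iteration (masks mone/mtwo), moved to bit position 4 and OR-ed with
+ // the low nibbles, then qhbits is shifted right for the next iteration.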
+ q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
1794
+ q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
1795
+ q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
1796
+ q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
1797
+ qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
1798
+ qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
1799
+
1800
+ q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
1801
+ q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
1802
+ q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
1803
+ q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
1804
+
1805
+ #if defined(__ARM_FEATURE_DOTPROD)
1806
+
1807
+ sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
1808
+ sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
1809
+ #else
1810
+
1811
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
1812
+ vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
1813
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
1814
+ vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
1815
+ sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
1816
+
1817
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
1818
+ vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
1819
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
1820
+ vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
1821
+ sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
1822
+ #endif
1823
+ }
1824
+
1825
+ sumf += d * sumi - dmin * sumi_mins;
1826
+
1827
+ }
1828
+
1829
+ *s = sumf;
1830
+
1831
+ #elif defined __AVX2__
1832
+
1833
+ const __m256i m4 = _mm256_set1_epi8(0xF);
1834
+ const __m128i mzero = _mm_setzero_si128();
1835
+ const __m256i mone = _mm256_set1_epi8(1);
1836
+
1837
+ __m256 acc = _mm256_setzero_ps();
1838
+
1839
+ float summs = 0.f;
1840
+
1841
+ for (int i = 0; i < nb; ++i) {
1842
+
1843
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1844
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1845
+
1846
+ const uint8_t * restrict q5 = x[i].qs;
1847
+ const int8_t * restrict q8 = y[i].qs;
1848
+
1849
+ memcpy(utmp, x[i].scales, 12);
1850
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1851
+ const uint32_t uaux = utmp[1] & kmask1;
1852
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1853
+ utmp[2] = uaux;
1854
+ utmp[0] &= kmask1;
1855
+
1856
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1857
+
1858
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
1859
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
1860
+ const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
1861
+ const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
1862
+ summs += dmin * _mm_extract_epi32(hsum, 0);
1863
+
1864
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
1865
+ const __m256i scales = _mm256_set_m128i(sc128, sc128);
1866
+
1867
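+ // hbits carries the high (5th) bits; hmask selects one bit per 32-value
+ // chunk and is shifted left after each use, while `bit` tracks how far the
+ // selected bit must be shifted down before moving it into bit position 4.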
+ const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
1868
+ __m256i hmask = mone;
1869
+
1870
+ __m256i sumi = _mm256_setzero_si256();
1871
+
1872
+ int bit = 0;
1873
+
1874
+ for (int j = 0; j < QK_K/64; ++j) {
1875
+
1876
+ const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
1877
+ const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
1878
+
1879
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
1880
+
1881
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
1882
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1883
+ const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
1884
+ hmask = _mm256_slli_epi16(hmask, 1);
1885
+
1886
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
1887
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1888
+ const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
1889
+ hmask = _mm256_slli_epi16(hmask, 1);
1890
+
1891
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1892
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1893
+
1894
+ __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
1895
+ __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
1896
+
1897
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
1898
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
1899
+
1900
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
1901
+
1902
+ }
1903
+
1904
+ __m256 vd = _mm256_set1_ps(d);
1905
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
1906
+
1907
+ }
1908
+
1909
+ *s = hsum_float_8(acc) + summs;
1910
+
1911
+ #else
1912
+
1913
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1914
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1915
+
1916
+ int8_t aux8[QK_K];
1917
+ int16_t aux16[8];
1918
+ float sums [8];
1919
+ int32_t aux32[8];
1920
+ memset(sums, 0, 8*sizeof(float));
1921
+
1922
+ float sumf = 0;
1923
+ for (int i = 0; i < nb; ++i) {
1924
+ const uint8_t * restrict q4 = x[i].qs;
1925
+ const uint8_t * restrict hm = x[i].qh;
1926
+ const int8_t * restrict q8 = y[i].qs;
1927
+ memset(aux32, 0, 8*sizeof(int32_t));
1928
+ int8_t * restrict a = aux8;
1929
+ uint8_t m = 1;
1930
+ for (int j = 0; j < QK_K/64; ++j) {
1931
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1932
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1933
+ a += 32; m <<= 1;
1934
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1935
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1936
+ a += 32; m <<= 1;
1937
+ q4 += 32;
1938
+ }
1939
+ memcpy(utmp, x[i].scales, 12);
1940
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1941
+ const uint32_t uaux = utmp[1] & kmask1;
1942
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1943
+ utmp[2] = uaux;
1944
+ utmp[0] &= kmask1;
1945
+
1946
+ int sumi = 0;
1947
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1948
+ a = aux8;
1949
+ int is = 0;
1950
+ for (int j = 0; j < QK_K/32; ++j) {
1951
+ int32_t scale = scales[is++];
1952
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1953
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1954
+ q8 += 8; a += 8;
1955
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1956
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1957
+ q8 += 8; a += 8;
1958
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1959
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1960
+ q8 += 8; a += 8;
1961
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1962
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1963
+ q8 += 8; a += 8;
1964
+ }
1965
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
1966
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1967
+ const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
1968
+ sumf -= dmin * sumi;
1969
+ }
1970
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1971
+ *s = sumf;
1972
+ #endif
1973
+ }
1974
+
1975
+
1976
+
1977
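+ // Dot product of a row of q6_K super-blocks with a row of q8_K blocks.
+ // Each q6_K value is 4 low bits in ql plus 2 high bits in qh, with an
+ // implicit bias of 32 that the code paths below remove either per value
+ // or through the q8 sums.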
+ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1978
+ assert(n % QK_K == 0);
1979
+
1980
+ const block_q6_K * restrict x = vx;
1981
+ const block_q8_K * restrict y = vy;
1982
+
1983
+ const int nb = n / QK_K;
1984
+
1985
+ #ifdef __ARM_NEON
1986
+
1987
+ float sum = 0;
1988
+
1989
+ const uint8x16_t m4b = vdupq_n_u8(0xF);
1990
+ const int32x4_t vzero = vdupq_n_s32(0);
1991
+ //const int8x16_t m32s = vdupq_n_s8(32);
1992
+
1993
+ const uint8x16_t mone = vdupq_n_u8(3);
1994
+
1995
+ int8x16x4_t q6bytes;
1996
+ uint8x16x4_t q6h;
1997
+
1998
+ for (int i = 0; i < nb; ++i) {
1999
+
2000
+ const float d_all = ggml_fp16_to_fp32(x[i].d);
2001
+
2002
+ const uint8_t * restrict q6 = x[i].ql;
2003
+ const uint8_t * restrict qh = x[i].qh;
2004
+ const int8_t * restrict q8 = y[i].qs;
2005
+
2006
+ const int8_t * restrict scale = x[i].scales;
2007
+
2008
+ const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
2009
+ const int8x16_t scales = vld1q_s8(scale);
2010
+ const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
2011
+
2012
+ const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
2013
+ vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
2014
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
2015
+ vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
2016
+ int32_t isum_mins = vaddvq_s32(prod);
2017
+
2018
+ int32_t isum = 0;
2019
+
2020
+ for (int j = 0; j < QK_K/128; ++j) {
2021
+
2022
+ uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
2023
+ uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
2024
+ int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
2025
+
2026
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
2027
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
2028
+ uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
2029
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
2030
+ shifted = vshrq_n_u8(qhbits.val[1], 2);
2031
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
2032
+
2033
+ //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
2034
+ //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
2035
+ //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
2036
+ //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
2037
+ q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
2038
+ q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
2039
+ q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
2040
+ q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
2041
+
2042
+ #if defined(__ARM_FEATURE_DOTPROD)
2043
+
2044
+ isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
2045
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
2046
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
2047
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
2048
+ scale += 4;
2049
+
2050
+ #else
2051
+
2052
+ int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
2053
+ vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
2054
+ int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
2055
+ vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
2056
+ isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
2057
+ scale += 2;
2058
+
2059
+ int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
2060
+ vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
2061
+ int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
2062
+ vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
2063
+ isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
2064
+ scale += 2;
2065
+ #endif
2066
+
2067
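+ // Second half of the 128-value chunk: reload 64 q8 values and combine the
+ // high nibbles of q6bits with the next two bit-pairs from qhbits.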
+ q8bytes = vld1q_s8_x4(q8); q8 += 64;
2068
+
2069
+ shifted = vshrq_n_u8(qhbits.val[0], 4);
2070
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
2071
+ shifted = vshrq_n_u8(qhbits.val[1], 4);
2072
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
2073
+ shifted = vshrq_n_u8(qhbits.val[0], 6);
2074
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
2075
+ shifted = vshrq_n_u8(qhbits.val[1], 6);
2076
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
2077
+
2078
+ //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
2079
+ //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
2080
+ //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
2081
+ //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
2082
+ q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
2083
+ q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
2084
+ q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
2085
+ q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
2086
+
2087
+ #if defined(__ARM_FEATURE_DOTPROD)
2088
+
2089
+ isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
2090
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
2091
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
2092
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
2093
+ scale += 4;
2094
+
2095
+ //for (int l = 0; l < 4; ++l) {
2096
+ // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
2097
+ // isum += vaddvq_s32(p) * *scale++;
2098
+ //}
2099
+ #else
2100
+ p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
2101
+ vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
2102
+ p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
2103
+ vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
2104
+ isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
2105
+ scale += 2;
2106
+
2107
+ p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
2108
+ vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
2109
+ p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
2110
+ vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
2111
+ isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
2112
+ scale += 2;
2113
+ #endif
2114
+
2115
+ }
2116
+ //sum += isum * d_all * y[i].d;
2117
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
2118
+
2119
+ }
2120
+ *s = sum;
2121
+
2122
+ #elif defined __AVX2__
2123
+
2124
+ const __m256i m4 = _mm256_set1_epi8(0xF);
2125
+ const __m256i m2 = _mm256_set1_epi8(3);
2126
+ const __m256i m32s = _mm256_set1_epi8(32);
2127
+
2128
+ __m256 acc = _mm256_setzero_ps();
2129
+
2130
+ for (int i = 0; i < nb; ++i) {
2131
+
2132
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
2133
+
2134
+ const uint8_t * restrict q4 = x[i].ql;
2135
+ const uint8_t * restrict qh = x[i].qh;
2136
+ const int8_t * restrict q8 = y[i].qs;
2137
+
2138
+ const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2139
+
2140
+ __m256i sumi = _mm256_setzero_si256();
2141
+
2142
+ int is = 0;
2143
+
2144
+ for (int j = 0; j < QK_K/128; ++j) {
2145
+
2146
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2147
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2148
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2149
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2150
+ is += 4;
2151
+
2152
+ const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
2153
+ const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
2154
+ const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
2155
+
2156
+ const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
2157
+ const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
2158
+ const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
2159
+ const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
2160
+
2161
+ const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
2162
+ const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
2163
+ const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
2164
+ const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
2165
+
2166
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2167
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2168
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2169
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2170
+
2171
+ __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
2172
+ __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
2173
+ __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
2174
+ __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
2175
+
2176
+ __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
2177
+ __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
2178
+ __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
2179
+ __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
2180
+
2181
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
2182
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
2183
+ p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
2184
+ p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
2185
+
2186
+ p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
2187
+ p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
2188
+ p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
2189
+ p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
2190
+
2191
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
2192
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
2193
+
2194
+ }
2195
+
2196
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
2197
+ }
2198
+
2199
+ *s = hsum_float_8(acc);
2200
+
2201
+ #else
2202
+
2203
+ int8_t aux8[QK_K];
2204
+ int16_t aux16[8];
2205
+ float sums [8];
2206
+ int32_t aux32[8];
2207
+ memset(sums, 0, 8*sizeof(float));
2208
+
2209
+ float sumf = 0;
2210
+ for (int i = 0; i < nb; ++i) {
2211
+ const uint8_t * restrict q4 = x[i].ql;
2212
+ const uint8_t * restrict qh = x[i].qh;
2213
+ const int8_t * restrict q8 = y[i].qs;
2214
+ memset(aux32, 0, 8*sizeof(int32_t));
2215
+ int8_t * restrict a = aux8;
2216
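+ // Reference path: rebuild each 6-bit value from 4 bits of ql and 2 bits of
+ // qh, subtract the 32 bias, and run the scalar dot product over aux8.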
+ for (int j = 0; j < QK_K; j += 128) {
2217
+ for (int l = 0; l < 32; ++l) {
2218
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
2219
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
2220
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
2221
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
2222
+ }
2223
+ a += 128;
2224
+ q4 += 64;
2225
+ qh += 32;
2226
+ }
2227
+ a = aux8;
2228
+ int is = 0;
2229
+ for (int j = 0; j < QK_K/16; ++j) {
2230
+ int scale = x[i].scales[is++];
2231
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2232
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2233
+ q8 += 8; a += 8;
2234
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2235
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2236
+ q8 += 8; a += 8;
2237
+ }
2238
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
2239
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2240
+ }
2241
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2242
+ *s = sumf;
2243
+ #endif
2244
+ }