llama_cpp 0.2.2 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +39 -6
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +3 -2
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +231 -132
- data/ext/llama_cpp/src/ggml-cuda.cu +319 -52
- data/ext/llama_cpp/src/ggml-metal.m +36 -30
- data/ext/llama_cpp/src/ggml-metal.metal +328 -84
- data/ext/llama_cpp/src/ggml.c +800 -303
- data/ext/llama_cpp/src/ggml.h +68 -5
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +138 -72
- data/ext/llama_cpp/src/llama.h +33 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +12 -17
- metadata +2 -3
- data/lib/llama_cpp/client.rb +0 -172
@@ -261,6 +261,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
|
|
261
261
|
return scale;
|
262
262
|
}
|
263
263
|
|
264
|
+
#if QK_K == 256
|
264
265
|
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
|
265
266
|
if (j < 4) {
|
266
267
|
*d = q[j] & 63; *m = q[j + 4] & 63;
|
@@ -269,6 +270,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
|
|
269
270
|
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
270
271
|
}
|
271
272
|
}
|
273
|
+
#endif
|
272
274
|
|
273
275
|
//========================- 2-bit (de)-quantization
|
274
276
|
|
@@ -330,11 +332,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
|
|
330
332
|
}
|
331
333
|
}
|
332
334
|
|
335
|
+
#if QK_K == 256
|
333
336
|
for (int j = 0; j < QK_K; j += 128) {
|
334
337
|
for (int l = 0; l < 32; ++l) {
|
335
338
|
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
|
336
339
|
}
|
337
340
|
}
|
341
|
+
#else
|
342
|
+
for (int l = 0; l < 16; ++l) {
|
343
|
+
y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
|
344
|
+
}
|
345
|
+
#endif
|
338
346
|
|
339
347
|
x += QK_K;
|
340
348
|
|
@@ -352,6 +360,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
|
|
352
360
|
|
353
361
|
const uint8_t * q = x[i].qs;
|
354
362
|
|
363
|
+
#if QK_K == 256
|
355
364
|
int is = 0;
|
356
365
|
float dl, ml;
|
357
366
|
for (int n = 0; n < QK_K; n += 128) {
|
@@ -370,7 +379,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
|
|
370
379
|
}
|
371
380
|
q += 32;
|
372
381
|
}
|
373
|
-
|
382
|
+
#else
|
383
|
+
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
|
384
|
+
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
|
385
|
+
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
|
386
|
+
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
|
387
|
+
for (int l = 0; l < 16; ++l) {
|
388
|
+
y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
|
389
|
+
y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
|
390
|
+
y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
|
391
|
+
y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
|
392
|
+
}
|
393
|
+
y += QK_K;
|
394
|
+
#endif
|
374
395
|
}
|
375
396
|
}
|
376
397
|
|
@@ -412,6 +433,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
|
|
412
433
|
}
|
413
434
|
}
|
414
435
|
|
436
|
+
#if QK_K == 256
|
415
437
|
memset(y[i].scales, 0, 12);
|
416
438
|
if (max_scale) {
|
417
439
|
float iscale = -32.f/max_scale;
|
@@ -445,9 +467,39 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
|
|
445
467
|
L[16*j + ii] = l + 4;
|
446
468
|
}
|
447
469
|
}
|
470
|
+
#else
|
471
|
+
if (max_scale) {
|
472
|
+
float iscale = -8.f/max_scale;
|
473
|
+
for (int j = 0; j < QK_K/16; j+=2) {
|
474
|
+
int l1 = nearest_int(iscale*scales[j]);
|
475
|
+
l1 = 8 + MAX(-8, MIN(7, l1));
|
476
|
+
int l2 = nearest_int(iscale*scales[j+1]);
|
477
|
+
l2 = 8 + MAX(-8, MIN(7, l2));
|
478
|
+
y[i].scales[j/2] = l1 | (l2 << 4);
|
479
|
+
}
|
480
|
+
y[i].d = ggml_fp32_to_fp16(1/iscale);
|
481
|
+
} else {
|
482
|
+
for (int j = 0; j < QK_K/16; j+=2) {
|
483
|
+
y[i].scales[j/2] = 0;
|
484
|
+
}
|
485
|
+
y[i].d = ggml_fp32_to_fp16(0.f);
|
486
|
+
}
|
487
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
488
|
+
int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
|
489
|
+
float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
|
490
|
+
if (!d) {
|
491
|
+
continue;
|
492
|
+
}
|
493
|
+
for (int ii = 0; ii < 16; ++ii) {
|
494
|
+
int l = nearest_int(x[16*j + ii]/d);
|
495
|
+
l = MAX(-4, MIN(3, l));
|
496
|
+
L[16*j + ii] = l + 4;
|
497
|
+
}
|
498
|
+
}
|
499
|
+
#endif
|
448
500
|
|
449
501
|
memset(y[i].hmask, 0, QK_K/8);
|
450
|
-
// We put the high-bit for the 1st
|
502
|
+
// We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
|
451
503
|
int m = 0;
|
452
504
|
uint8_t hm = 1;
|
453
505
|
for (int j = 0; j < QK_K; ++j) {
|
@@ -459,19 +511,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
|
|
459
511
|
m = 0; hm <<= 1;
|
460
512
|
}
|
461
513
|
}
|
514
|
+
#if QK_K == 256
|
462
515
|
for (int j = 0; j < QK_K; j += 128) {
|
463
516
|
for (int l = 0; l < 32; ++l) {
|
464
517
|
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
|
465
518
|
}
|
466
519
|
}
|
520
|
+
#else
|
521
|
+
for (int l = 0; l < 16; ++l) {
|
522
|
+
y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
|
523
|
+
}
|
524
|
+
#endif
|
467
525
|
|
468
526
|
x += QK_K;
|
469
527
|
}
|
470
528
|
}
|
471
529
|
|
530
|
+
#if QK_K == 256
|
472
531
|
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
|
473
532
|
assert(k % QK_K == 0);
|
474
|
-
assert(QK_K == 256);
|
475
533
|
const int nb = k / QK_K;
|
476
534
|
|
477
535
|
const uint32_t kmask1 = 0x03030303;
|
@@ -519,6 +577,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
|
|
519
577
|
|
520
578
|
}
|
521
579
|
}
|
580
|
+
#else
|
581
|
+
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
|
582
|
+
assert(k % QK_K == 0);
|
583
|
+
assert(QK_K == 64);
|
584
|
+
const int nb = k / QK_K;
|
585
|
+
|
586
|
+
for (int i = 0; i < nb; i++) {
|
587
|
+
|
588
|
+
const float d_all = ggml_fp16_to_fp32(x[i].d);
|
589
|
+
|
590
|
+
const uint8_t * restrict q = x[i].qs;
|
591
|
+
const uint8_t * restrict hm = x[i].hmask;
|
592
|
+
|
593
|
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
594
|
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
595
|
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
596
|
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
597
|
+
|
598
|
+
for (int l=0; l<8; ++l) {
|
599
|
+
uint8_t h = hm[l];
|
600
|
+
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
|
601
|
+
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
|
602
|
+
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
|
603
|
+
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
|
604
|
+
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
|
605
|
+
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
|
606
|
+
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
|
607
|
+
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
|
608
|
+
}
|
609
|
+
y += QK_K;
|
610
|
+
}
|
611
|
+
}
|
612
|
+
#endif
|
522
613
|
|
523
614
|
void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
|
524
615
|
quantize_row_q3_K_reference(x, vy, k);
|
@@ -563,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
|
563
654
|
}
|
564
655
|
}
|
565
656
|
|
657
|
+
#if QK_K == 256
|
566
658
|
float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
|
567
659
|
float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
|
568
660
|
for (int j = 0; j < QK_K/32; ++j) {
|
@@ -594,9 +686,43 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
|
|
594
686
|
L[32*j + ii] = l;
|
595
687
|
}
|
596
688
|
}
|
689
|
+
#else
|
690
|
+
const float s_factor = 15.f;
|
691
|
+
float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
|
692
|
+
float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
|
693
|
+
int d1 = nearest_int(inv_scale*scales[0]);
|
694
|
+
int m1 = nearest_int(inv_min*mins[0]);
|
695
|
+
int d2 = nearest_int(inv_scale*scales[1]);
|
696
|
+
int m2 = nearest_int(inv_min*mins[1]);
|
697
|
+
y[i].scales[0] = d1 | (m1 << 4);
|
698
|
+
y[i].scales[1] = d2 | (m2 << 4);
|
699
|
+
y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
|
700
|
+
y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
|
701
|
+
|
702
|
+
float sumlx = 0;
|
703
|
+
int suml2 = 0;
|
704
|
+
for (int j = 0; j < QK_K/32; ++j) {
|
705
|
+
const uint8_t sd = y[i].scales[j] & 0xF;
|
706
|
+
const uint8_t sm = y[i].scales[j] >> 4;
|
707
|
+
const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
|
708
|
+
if (!d) continue;
|
709
|
+
const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
|
710
|
+
for (int ii = 0; ii < 32; ++ii) {
|
711
|
+
int l = nearest_int((x[32*j + ii] + m)/d);
|
712
|
+
l = MAX(0, MIN(15, l));
|
713
|
+
L[32*j + ii] = l;
|
714
|
+
sumlx += (x[32*j + ii] + m)*l*sd;
|
715
|
+
suml2 += l*l*sd*sd;
|
716
|
+
}
|
717
|
+
}
|
718
|
+
if (suml2) {
|
719
|
+
y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
|
720
|
+
}
|
721
|
+
#endif
|
597
722
|
uint8_t * q = y[i].qs;
|
598
723
|
for (int j = 0; j < QK_K; j += 64) {
|
599
|
-
for (int l = 0; l < 32; ++l)
|
724
|
+
for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
|
725
|
+
q += 32;
|
600
726
|
}
|
601
727
|
|
602
728
|
x += QK_K;
|
@@ -610,11 +736,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
|
|
610
736
|
|
611
737
|
for (int i = 0; i < nb; i++) {
|
612
738
|
|
613
|
-
const float d = ggml_fp16_to_fp32(x[i].d);
|
614
|
-
const float min = ggml_fp16_to_fp32(x[i].dmin);
|
615
|
-
|
616
739
|
const uint8_t * q = x[i].qs;
|
617
740
|
|
741
|
+
#if QK_K == 256
|
742
|
+
|
743
|
+
const float d = ggml_fp16_to_fp32(x[i].d);
|
744
|
+
const float min = ggml_fp16_to_fp32(x[i].dmin);
|
745
|
+
|
618
746
|
int is = 0;
|
619
747
|
uint8_t sc, m;
|
620
748
|
for (int j = 0; j < QK_K; j += 64) {
|
@@ -626,6 +754,17 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
|
|
626
754
|
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
627
755
|
q += 32; is += 2;
|
628
756
|
}
|
757
|
+
#else
|
758
|
+
const float dall = ggml_fp16_to_fp32(x[i].d[0]);
|
759
|
+
const float mall = ggml_fp16_to_fp32(x[i].d[1]);
|
760
|
+
const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
|
761
|
+
const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
|
762
|
+
for (int l = 0; l < 32; ++l) {
|
763
|
+
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
|
764
|
+
y[l+32] = d2 * (q[l] >> 4) - m2;
|
765
|
+
}
|
766
|
+
y += QK_K;
|
767
|
+
#endif
|
629
768
|
|
630
769
|
}
|
631
770
|
}
|
@@ -653,12 +792,19 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
|
653
792
|
assert(k % QK_K == 0);
|
654
793
|
const int nb = k / QK_K;
|
655
794
|
|
795
|
+
#if QK_K == 256
|
656
796
|
uint8_t L[QK_K];
|
657
797
|
float mins[QK_K/32];
|
658
798
|
float scales[QK_K/32];
|
799
|
+
#else
|
800
|
+
int8_t L[QK_K];
|
801
|
+
float scales[QK_K/16];
|
802
|
+
#endif
|
659
803
|
|
660
804
|
for (int i = 0; i < nb; i++) {
|
661
805
|
|
806
|
+
#if QK_K == 256
|
807
|
+
|
662
808
|
float max_scale = 0; // as we are deducting the min, scales are always positive
|
663
809
|
float max_min = 0;
|
664
810
|
for (int j = 0; j < QK_K/32; ++j) {
|
@@ -725,6 +871,52 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
|
|
725
871
|
m1 <<= 2; m2 <<= 2;
|
726
872
|
ql += 32;
|
727
873
|
}
|
874
|
+
#else
|
875
|
+
float max_scale = 0, amax = 0;
|
876
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
877
|
+
scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
|
878
|
+
float abs_scale = fabsf(scales[j]);
|
879
|
+
if (abs_scale > amax) {
|
880
|
+
amax = abs_scale;
|
881
|
+
max_scale = scales[j];
|
882
|
+
}
|
883
|
+
}
|
884
|
+
|
885
|
+
float iscale = -128.f/max_scale;
|
886
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
887
|
+
int l = nearest_int(iscale*scales[j]);
|
888
|
+
y[i].scales[j] = MAX(-128, MIN(127, l));
|
889
|
+
}
|
890
|
+
y[i].d = ggml_fp32_to_fp16(1/iscale);
|
891
|
+
|
892
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
893
|
+
const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
|
894
|
+
if (!d) continue;
|
895
|
+
for (int ii = 0; ii < 16; ++ii) {
|
896
|
+
int l = nearest_int(x[16*j + ii]/d);
|
897
|
+
l = MAX(-16, MIN(15, l));
|
898
|
+
L[16*j + ii] = l + 16;
|
899
|
+
}
|
900
|
+
}
|
901
|
+
|
902
|
+
uint8_t * restrict qh = y[i].qh;
|
903
|
+
uint8_t * restrict ql = y[i].qs;
|
904
|
+
memset(qh, 0, QK_K/8);
|
905
|
+
|
906
|
+
for (int j = 0; j < 32; ++j) {
|
907
|
+
int jm = j%8;
|
908
|
+
int is = j/8;
|
909
|
+
int l1 = L[j];
|
910
|
+
if (l1 > 15) {
|
911
|
+
l1 -= 16; qh[jm] |= (1 << is);
|
912
|
+
}
|
913
|
+
int l2 = L[j + 32];
|
914
|
+
if (l2 > 15) {
|
915
|
+
l2 -= 16; qh[jm] |= (1 << (4 + is));
|
916
|
+
}
|
917
|
+
ql[j] = l1 | (l2 << 4);
|
918
|
+
}
|
919
|
+
#endif
|
728
920
|
|
729
921
|
x += QK_K;
|
730
922
|
|
@@ -737,12 +929,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
|
|
737
929
|
|
738
930
|
for (int i = 0; i < nb; i++) {
|
739
931
|
|
740
|
-
const float d = ggml_fp16_to_fp32(x[i].d);
|
741
|
-
const float min = ggml_fp16_to_fp32(x[i].dmin);
|
742
|
-
|
743
932
|
const uint8_t * ql = x[i].qs;
|
744
933
|
const uint8_t * qh = x[i].qh;
|
745
934
|
|
935
|
+
#if QK_K == 256
|
936
|
+
|
937
|
+
const float d = ggml_fp16_to_fp32(x[i].d);
|
938
|
+
const float min = ggml_fp16_to_fp32(x[i].dmin);
|
939
|
+
|
746
940
|
int is = 0;
|
747
941
|
uint8_t sc, m;
|
748
942
|
uint8_t u1 = 1, u2 = 2;
|
@@ -756,6 +950,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
|
|
756
950
|
ql += 32; is += 2;
|
757
951
|
u1 <<= 2; u2 <<= 2;
|
758
952
|
}
|
953
|
+
#else
|
954
|
+
float d = ggml_fp16_to_fp32(x[i].d);
|
955
|
+
const int8_t * restrict s = x[i].scales;
|
956
|
+
for (int l = 0; l < 8; ++l) {
|
957
|
+
y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
|
958
|
+
y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
|
959
|
+
y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
|
960
|
+
y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
|
961
|
+
y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
|
962
|
+
y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
|
963
|
+
y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
|
964
|
+
y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
|
965
|
+
}
|
966
|
+
y += QK_K;
|
967
|
+
#endif
|
759
968
|
}
|
760
969
|
}
|
761
970
|
|
@@ -823,6 +1032,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
|
|
823
1032
|
|
824
1033
|
uint8_t * restrict ql = y[i].ql;
|
825
1034
|
uint8_t * restrict qh = y[i].qh;
|
1035
|
+
#if QK_K == 256
|
826
1036
|
for (int j = 0; j < QK_K; j += 128) {
|
827
1037
|
for (int l = 0; l < 32; ++l) {
|
828
1038
|
const uint8_t q1 = L[j + l + 0] & 0xF;
|
@@ -836,6 +1046,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
|
|
836
1046
|
ql += 64;
|
837
1047
|
qh += 32;
|
838
1048
|
}
|
1049
|
+
#else
|
1050
|
+
for (int l = 0; l < 32; ++l) {
|
1051
|
+
const uint8_t q1 = L[l + 0] & 0xF;
|
1052
|
+
const uint8_t q2 = L[l + 32] & 0xF;
|
1053
|
+
ql[l] = q1 | (q2 << 4);
|
1054
|
+
}
|
1055
|
+
for (int l = 0; l < 16; ++l) {
|
1056
|
+
qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
|
1057
|
+
}
|
1058
|
+
#endif
|
839
1059
|
|
840
1060
|
x += QK_K;
|
841
1061
|
|
@@ -854,6 +1074,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
|
|
854
1074
|
const uint8_t * restrict qh = x[i].qh;
|
855
1075
|
const int8_t * restrict sc = x[i].scales;
|
856
1076
|
|
1077
|
+
#if QK_K == 256
|
857
1078
|
for (int n = 0; n < QK_K; n += 128) {
|
858
1079
|
for (int l = 0; l < 32; ++l) {
|
859
1080
|
int is = l/16;
|
@@ -871,6 +1092,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
|
|
871
1092
|
qh += 32;
|
872
1093
|
sc += 8;
|
873
1094
|
}
|
1095
|
+
#else
|
1096
|
+
for (int l = 0; l < 16; ++l) {
|
1097
|
+
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
1098
|
+
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
1099
|
+
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
1100
|
+
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
1101
|
+
y[l+ 0] = d * sc[0] * q1;
|
1102
|
+
y[l+16] = d * sc[1] * q2;
|
1103
|
+
y[l+32] = d * sc[2] * q3;
|
1104
|
+
y[l+48] = d * sc[3] * q4;
|
1105
|
+
}
|
1106
|
+
y += 64;
|
1107
|
+
#endif
|
874
1108
|
|
875
1109
|
}
|
876
1110
|
}
|
@@ -1002,6 +1236,7 @@ static inline __m128i get_scale_shuffle(int i) {
|
|
1002
1236
|
}
|
1003
1237
|
#endif
|
1004
1238
|
|
1239
|
+
#if QK_K == 256
|
1005
1240
|
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1006
1241
|
|
1007
1242
|
const block_q2_K * restrict x = vx;
|
@@ -1158,6 +1393,112 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1158
1393
|
|
1159
1394
|
*s = hsum_float_8(acc);
|
1160
1395
|
|
1396
|
+
#elif defined __AVX__
|
1397
|
+
|
1398
|
+
const __m128i m3 = _mm_set1_epi8(0x3);
|
1399
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
1400
|
+
const __m128i m2 = _mm_set1_epi8(0x2);
|
1401
|
+
|
1402
|
+
__m256 acc = _mm256_setzero_ps();
|
1403
|
+
|
1404
|
+
for (int i = 0; i < nb; ++i) {
|
1405
|
+
|
1406
|
+
const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1407
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1408
|
+
|
1409
|
+
const uint8_t * restrict q2 = x[i].qs;
|
1410
|
+
const int8_t * restrict q8 = y[i].qs;
|
1411
|
+
|
1412
|
+
// load mins and scales from block_q2_K.scales[QK_K/16]
|
1413
|
+
const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
|
1414
|
+
const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
|
1415
|
+
const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
|
1416
|
+
const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
|
1417
|
+
const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
|
1418
|
+
|
1419
|
+
// summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
|
1420
|
+
const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
|
1421
|
+
const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
|
1422
|
+
|
1423
|
+
// sumf += -dmin * summs in 32bits*8
|
1424
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
|
1425
|
+
|
1426
|
+
const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
|
1427
|
+
const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
|
1428
|
+
const __m128i scales[2] = { scales_0, scales_1 };
|
1429
|
+
|
1430
|
+
__m128i sumi_0 = _mm_setzero_si128();
|
1431
|
+
__m128i sumi_1 = _mm_setzero_si128();
|
1432
|
+
|
1433
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
1434
|
+
|
1435
|
+
// load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
|
1436
|
+
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1437
|
+
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1438
|
+
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1439
|
+
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1440
|
+
const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1441
|
+
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1442
|
+
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1443
|
+
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1444
|
+
|
1445
|
+
// load 2bits*16*8 from block_q2_K.qs[QK_K/4]
|
1446
|
+
__m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
|
1447
|
+
const __m128i q2_0 = _mm_and_si128(q2bits, m3);
|
1448
|
+
const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
|
1449
|
+
const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
|
1450
|
+
const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
|
1451
|
+
q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
|
1452
|
+
const __m128i q2_1 = _mm_and_si128(q2bits, m3);
|
1453
|
+
const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
|
1454
|
+
const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
|
1455
|
+
const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
|
1456
|
+
|
1457
|
+
// isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
|
1458
|
+
__m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
|
1459
|
+
__m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
|
1460
|
+
__m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
|
1461
|
+
__m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
|
1462
|
+
__m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
|
1463
|
+
__m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
|
1464
|
+
__m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
|
1465
|
+
__m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
|
1466
|
+
|
1467
|
+
// isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
|
1468
|
+
__m128i shuffle = _mm_set1_epi16(0x0100);
|
1469
|
+
p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
|
1470
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1471
|
+
p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
|
1472
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1473
|
+
p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
|
1474
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1475
|
+
p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
|
1476
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1477
|
+
p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
|
1478
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1479
|
+
p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
|
1480
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1481
|
+
p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
|
1482
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1483
|
+
p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
|
1484
|
+
|
1485
|
+
p0 = _mm_add_epi32(p0, p1);
|
1486
|
+
p2 = _mm_add_epi32(p2, p3);
|
1487
|
+
p4 = _mm_add_epi32(p4, p5);
|
1488
|
+
p6 = _mm_add_epi32(p6, p7);
|
1489
|
+
|
1490
|
+
// isum in 32bits*4*2
|
1491
|
+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
|
1492
|
+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
|
1493
|
+
}
|
1494
|
+
|
1495
|
+
// sumf += dall * isum - dmin * summs in 32bits
|
1496
|
+
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
1497
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
|
1498
|
+
}
|
1499
|
+
|
1500
|
+
*s = hsum_float_8(acc);
|
1501
|
+
|
1161
1502
|
#else
|
1162
1503
|
|
1163
1504
|
float sumf = 0;
|
@@ -1201,6 +1542,168 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1201
1542
|
#endif
|
1202
1543
|
}
|
1203
1544
|
|
1545
|
+
#else
|
1546
|
+
|
1547
|
+
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1548
|
+
|
1549
|
+
const block_q2_K * restrict x = vx;
|
1550
|
+
const block_q8_K * restrict y = vy;
|
1551
|
+
|
1552
|
+
const int nb = n / QK_K;
|
1553
|
+
|
1554
|
+
#ifdef __ARM_NEON
|
1555
|
+
|
1556
|
+
const uint8x16_t m3 = vdupq_n_u8(0x3);
|
1557
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
1558
|
+
|
1559
|
+
int8x16x4_t q2bytes;
|
1560
|
+
|
1561
|
+
uint32_t aux32[2];
|
1562
|
+
const uint8_t * scales = (const uint8_t *)aux32;
|
1563
|
+
|
1564
|
+
float sum = 0;
|
1565
|
+
|
1566
|
+
for (int i = 0; i < nb; ++i) {
|
1567
|
+
|
1568
|
+
const float d = y[i].d * (float)x[i].d;
|
1569
|
+
const float dmin = -y[i].d * (float)x[i].dmin;
|
1570
|
+
|
1571
|
+
const uint8_t * restrict q2 = x[i].qs;
|
1572
|
+
const int8_t * restrict q8 = y[i].qs;
|
1573
|
+
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
|
1574
|
+
|
1575
|
+
aux32[0] = sc[0] & 0x0f0f0f0f;
|
1576
|
+
aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
|
1577
|
+
|
1578
|
+
sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
|
1579
|
+
|
1580
|
+
int isum1 = 0, isum2 = 0;
|
1581
|
+
|
1582
|
+
const uint8x16_t q2bits = vld1q_u8(q2);
|
1583
|
+
|
1584
|
+
const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
|
1585
|
+
|
1586
|
+
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
|
1587
|
+
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
|
1588
|
+
q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
|
1589
|
+
q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
|
1590
|
+
|
1591
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
1592
|
+
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
|
1593
|
+
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
|
1594
|
+
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
|
1595
|
+
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
|
1596
|
+
#else
|
1597
|
+
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
1598
|
+
vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
1599
|
+
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
1600
|
+
vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
1601
|
+
isum1 += vaddvq_s16(p1) * scales[0];
|
1602
|
+
isum2 += vaddvq_s16(p2) * scales[1];
|
1603
|
+
|
1604
|
+
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
1605
|
+
vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
1606
|
+
const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
1607
|
+
vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
1608
|
+
isum1 += vaddvq_s16(p3) * scales[2];
|
1609
|
+
isum2 += vaddvq_s16(p4) * scales[3];
|
1610
|
+
#endif
|
1611
|
+
sum += d * (isum1 + isum2);
|
1612
|
+
|
1613
|
+
}
|
1614
|
+
|
1615
|
+
*s = sum;
|
1616
|
+
|
1617
|
+
#elif defined __AVX2__
|
1618
|
+
|
1619
|
+
const __m256i m3 = _mm256_set1_epi8(3);
|
1620
|
+
|
1621
|
+
__m256 acc = _mm256_setzero_ps();
|
1622
|
+
|
1623
|
+
uint32_t ud, um;
|
1624
|
+
const uint8_t * restrict db = (const uint8_t *)&ud;
|
1625
|
+
const uint8_t * restrict mb = (const uint8_t *)&um;
|
1626
|
+
|
1627
|
+
float summs = 0;
|
1628
|
+
|
1629
|
+
// TODO: optimize this
|
1630
|
+
|
1631
|
+
for (int i = 0; i < nb; ++i) {
|
1632
|
+
|
1633
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1634
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1635
|
+
|
1636
|
+
const uint8_t * restrict q2 = x[i].qs;
|
1637
|
+
const int8_t * restrict q8 = y[i].qs;
|
1638
|
+
|
1639
|
+
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
|
1640
|
+
ud = (sc[0] >> 0) & 0x0f0f0f0f;
|
1641
|
+
um = (sc[0] >> 4) & 0x0f0f0f0f;
|
1642
|
+
|
1643
|
+
int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
|
1644
|
+
summs += dmin * smin;
|
1645
|
+
|
1646
|
+
const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
|
1647
|
+
const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
|
1648
|
+
const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
|
1649
|
+
|
1650
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
1651
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
1652
|
+
|
1653
|
+
const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
|
1654
|
+
const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
|
1655
|
+
|
1656
|
+
const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
|
1657
|
+
const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
|
1658
|
+
const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
|
1659
|
+
const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
|
1660
|
+
|
1661
|
+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
|
1662
|
+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
|
1663
|
+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
|
1664
|
+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
|
1665
|
+
}
|
1666
|
+
|
1667
|
+
*s = hsum_float_8(acc) + summs;
|
1668
|
+
|
1669
|
+
#else
|
1670
|
+
|
1671
|
+
float sumf = 0;
|
1672
|
+
|
1673
|
+
int isum[4];
|
1674
|
+
|
1675
|
+
for (int i = 0; i < nb; ++i) {
|
1676
|
+
|
1677
|
+
const uint8_t * q2 = x[i].qs;
|
1678
|
+
const int8_t * q8 = y[i].qs;
|
1679
|
+
const uint8_t * sc = x[i].scales;
|
1680
|
+
|
1681
|
+
int summs = 0;
|
1682
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
1683
|
+
summs += y[i].bsums[j] * (sc[j] >> 4);
|
1684
|
+
}
|
1685
|
+
|
1686
|
+
const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1687
|
+
const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1688
|
+
|
1689
|
+
isum[0] = isum[1] = isum[2] = isum[3] = 0;
|
1690
|
+
for (int l = 0; l < 16; ++l) {
|
1691
|
+
isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
|
1692
|
+
isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
|
1693
|
+
isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
|
1694
|
+
isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
|
1695
|
+
}
|
1696
|
+
for (int l = 0; l < 4; ++l) {
|
1697
|
+
isum[l] *= (sc[l] & 0xF);
|
1698
|
+
}
|
1699
|
+
sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
|
1700
|
+
}
|
1701
|
+
*s = sumf;
|
1702
|
+
#endif
|
1703
|
+
}
|
1704
|
+
#endif
|
1705
|
+
|
1706
|
+
#if QK_K == 256
|
1204
1707
|
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1205
1708
|
assert(n % QK_K == 0);
|
1206
1709
|
|
@@ -1434,34 +1937,176 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1434
1937
|
|
1435
1938
|
*s = hsum_float_8(acc);
|
1436
1939
|
|
1437
|
-
#
|
1438
|
-
// scalar version
|
1439
|
-
// This function is written like this so the compiler can manage to vectorize most of it
|
1440
|
-
// Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
|
1441
|
-
// manually vectorized version above. Every other version I tried would run at least 4 times slower.
|
1442
|
-
// The ideal situation would be if we could just write the code once, and the compiler would
|
1443
|
-
// automatically produce the best possible set of machine instructions, instead of us having to manually
|
1444
|
-
// write vectorized versions for AVX, ARM_NEON, etc.
|
1940
|
+
#elif defined __AVX__
|
1445
1941
|
|
1446
|
-
|
1447
|
-
|
1448
|
-
|
1449
|
-
|
1450
|
-
memset(sums, 0, 8*sizeof(float));
|
1942
|
+
const __m128i m3 = _mm_set1_epi8(3);
|
1943
|
+
const __m128i mone = _mm_set1_epi8(1);
|
1944
|
+
const __m128i m32 = _mm_set1_epi8(32);
|
1945
|
+
const __m128i m2 = _mm_set1_epi8(2);
|
1451
1946
|
|
1452
|
-
|
1453
|
-
|
1947
|
+
__m256 acc = _mm256_setzero_ps();
|
1948
|
+
|
1949
|
+
uint32_t *aux;
|
1454
1950
|
|
1455
|
-
float sumf = 0;
|
1456
1951
|
for (int i = 0; i < nb; ++i) {
|
1952
|
+
|
1953
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1954
|
+
|
1457
1955
|
const uint8_t * restrict q3 = x[i].qs;
|
1458
|
-
const
|
1459
|
-
|
1460
|
-
|
1461
|
-
|
1462
|
-
|
1463
|
-
|
1464
|
-
|
1956
|
+
const int8_t * restrict q8 = y[i].qs;
|
1957
|
+
|
1958
|
+
// Set up scales
|
1959
|
+
aux = (uint32_t *)x[i].scales;
|
1960
|
+
__m128i scales128 = _mm_set_epi32(
|
1961
|
+
((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
|
1962
|
+
((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
|
1963
|
+
(aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
|
1964
|
+
(aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
|
1965
|
+
scales128 = _mm_sub_epi8(scales128, m32);
|
1966
|
+
const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
|
1967
|
+
const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
|
1968
|
+
const __m128i scales[2] = { scales_0, scales_1 };
|
1969
|
+
|
1970
|
+
// high bit *128*2 from block_q3_K.hmask[QK_K/8]
|
1971
|
+
const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
|
1972
|
+
const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
|
1973
|
+
|
1974
|
+
// integer accumulator
|
1975
|
+
__m128i sumi_0 = _mm_setzero_si128();
|
1976
|
+
__m128i sumi_1 = _mm_setzero_si128();
|
1977
|
+
|
1978
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
1979
|
+
// load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
|
1980
|
+
const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
|
1981
|
+
const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
|
1982
|
+
|
1983
|
+
// prepare low and high bits
|
1984
|
+
const int bit = j << 2;
|
1985
|
+
|
1986
|
+
const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
|
1987
|
+
const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
|
1988
|
+
const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
|
1989
|
+
const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
|
1990
|
+
|
1991
|
+
const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
|
1992
|
+
const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
|
1993
|
+
const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
|
1994
|
+
const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
|
1995
|
+
|
1996
|
+
const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
|
1997
|
+
const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
|
1998
|
+
const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
|
1999
|
+
const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
|
2000
|
+
|
2001
|
+
const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
|
2002
|
+
const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
|
2003
|
+
const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
|
2004
|
+
const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
|
2005
|
+
|
2006
|
+
// load Q8 quants from block_q8_K.qs[QK_K]
|
2007
|
+
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2008
|
+
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2009
|
+
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2010
|
+
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2011
|
+
const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2012
|
+
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2013
|
+
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2014
|
+
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2015
|
+
|
2016
|
+
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
|
2017
|
+
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
|
2018
|
+
// and 2 if the high bit was set)
|
2019
|
+
__m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
|
2020
|
+
__m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
|
2021
|
+
__m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
|
2022
|
+
__m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
|
2023
|
+
__m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
|
2024
|
+
__m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
|
2025
|
+
__m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
|
2026
|
+
__m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
|
2027
|
+
|
2028
|
+
__m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
|
2029
|
+
__m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
|
2030
|
+
__m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
|
2031
|
+
__m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
|
2032
|
+
__m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
|
2033
|
+
__m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
|
2034
|
+
__m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
|
2035
|
+
__m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
|
2036
|
+
|
2037
|
+
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
|
2038
|
+
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
|
2039
|
+
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
|
2040
|
+
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
|
2041
|
+
p16_4 = _mm_sub_epi16(p16_4, q8s_4);
|
2042
|
+
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
|
2043
|
+
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
|
2044
|
+
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
|
2045
|
+
|
2046
|
+
// multiply with scales
|
2047
|
+
__m128i shuffle = _mm_set1_epi16(0x0100);
|
2048
|
+
p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
|
2049
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2050
|
+
p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
|
2051
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2052
|
+
p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
|
2053
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2054
|
+
p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
|
2055
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2056
|
+
p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
|
2057
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2058
|
+
p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
|
2059
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2060
|
+
p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
|
2061
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2062
|
+
p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
|
2063
|
+
|
2064
|
+
// accumulate
|
2065
|
+
p16_0 = _mm_add_epi32(p16_0, p16_1);
|
2066
|
+
p16_2 = _mm_add_epi32(p16_2, p16_3);
|
2067
|
+
p16_4 = _mm_add_epi32(p16_4, p16_5);
|
2068
|
+
p16_6 = _mm_add_epi32(p16_6, p16_7);
|
2069
|
+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
|
2070
|
+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
|
2071
|
+
|
2072
|
+
}
|
2073
|
+
|
2074
|
+
// multiply with block scale and accumulate
|
2075
|
+
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
2076
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
|
2077
|
+
|
2078
|
+
}
|
2079
|
+
|
2080
|
+
*s = hsum_float_8(acc);
|
2081
|
+
|
2082
|
+
#else
|
2083
|
+
// scalar version
|
2084
|
+
// This function is written like this so the compiler can manage to vectorize most of it
|
2085
|
+
// Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
|
2086
|
+
// manually vectorized version above. Every other version I tried would run at least 4 times slower.
|
2087
|
+
// The ideal situation would be if we could just write the code once, and the compiler would
|
2088
|
+
// automatically produce the best possible set of machine instructions, instead of us having to manually
|
2089
|
+
// write vectorized versions for AVX, ARM_NEON, etc.
|
2090
|
+
|
2091
|
+
int8_t aux8[QK_K];
|
2092
|
+
int16_t aux16[8];
|
2093
|
+
float sums [8];
|
2094
|
+
int32_t aux32[8];
|
2095
|
+
memset(sums, 0, 8*sizeof(float));
|
2096
|
+
|
2097
|
+
uint32_t auxs[4];
|
2098
|
+
const int8_t * scales = (const int8_t*)auxs;
|
2099
|
+
|
2100
|
+
float sumf = 0;
|
2101
|
+
for (int i = 0; i < nb; ++i) {
|
2102
|
+
const uint8_t * restrict q3 = x[i].qs;
|
2103
|
+
const uint8_t * restrict hm = x[i].hmask;
|
2104
|
+
const int8_t * restrict q8 = y[i].qs;
|
2105
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
2106
|
+
int8_t * restrict a = aux8;
|
2107
|
+
uint8_t m = 1;
|
2108
|
+
for (int j = 0; j < QK_K; j += 128) {
|
2109
|
+
for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
|
1465
2110
|
for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
|
1466
2111
|
a += 32; m <<= 1;
|
1467
2112
|
for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
|
@@ -1501,6 +2146,206 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1501
2146
|
|
1502
2147
|
}
|
1503
2148
|
|
2149
|
+
#else
|
2150
|
+
|
2151
|
+
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2152
|
+
assert(n % QK_K == 0);
|
2153
|
+
|
2154
|
+
const block_q3_K * restrict x = vx;
|
2155
|
+
const block_q8_K * restrict y = vy;
|
2156
|
+
|
2157
|
+
const int nb = n / QK_K;
|
2158
|
+
|
2159
|
+
#ifdef __ARM_NEON
|
2160
|
+
|
2161
|
+
#ifdef __ARM_FEATURE_DOTPROD
|
2162
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
2163
|
+
#endif
|
2164
|
+
|
2165
|
+
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
2166
|
+
const uint8x16_t mh = vdupq_n_u8(4);
|
2167
|
+
|
2168
|
+
int8x16x4_t q3bytes;
|
2169
|
+
|
2170
|
+
uint16_t aux16[2];
|
2171
|
+
int8_t * scales = (int8_t *)aux16;
|
2172
|
+
|
2173
|
+
float sum = 0;
|
2174
|
+
|
2175
|
+
for (int i = 0; i < nb; ++i) {
|
2176
|
+
|
2177
|
+
uint8x16x4_t q3h;
|
2178
|
+
|
2179
|
+
const uint8x8_t hbits = vld1_u8(x[i].hmask);
|
2180
|
+
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
|
2181
|
+
const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
|
2182
|
+
|
2183
|
+
const uint16_t a = *(const uint16_t *)x[i].scales;
|
2184
|
+
aux16[0] = a & 0x0f0f;
|
2185
|
+
aux16[1] = (a >> 4) & 0x0f0f;
|
2186
|
+
|
2187
|
+
for (int j = 0; j < 4; ++j) scales[j] -= 8;
|
2188
|
+
|
2189
|
+
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
|
2190
|
+
|
2191
|
+
const float d = y[i].d * (float)x[i].d;
|
2192
|
+
|
2193
|
+
const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
|
2194
|
+
q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
|
2195
|
+
q3h.val[1] = vandq_u8(mh, htmp);
|
2196
|
+
q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
|
2197
|
+
q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
|
2198
|
+
|
2199
|
+
q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
|
2200
|
+
q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
|
2201
|
+
q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
|
2202
|
+
q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
|
2203
|
+
|
2204
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2205
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
|
2206
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
|
2207
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
|
2208
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
|
2209
|
+
#else
|
2210
|
+
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
2211
|
+
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
2212
|
+
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
2213
|
+
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
2214
|
+
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
2215
|
+
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
2216
|
+
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
2217
|
+
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
2218
|
+
isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
|
2219
|
+
#endif
|
2220
|
+
|
2221
|
+
sum += d * isum;
|
2222
|
+
|
2223
|
+
}
|
2224
|
+
|
2225
|
+
*s = sum;
|
2226
|
+
|
2227
|
+
#elif defined __AVX2__
|
2228
|
+
|
2229
|
+
const __m256i m3 = _mm256_set1_epi8(3);
|
2230
|
+
const __m256i m1 = _mm256_set1_epi8(1);
|
2231
|
+
|
2232
|
+
__m256 acc = _mm256_setzero_ps();
|
2233
|
+
|
2234
|
+
uint64_t aux64;
|
2235
|
+
|
2236
|
+
uint16_t aux16[2];
|
2237
|
+
const int8_t * aux8 = (const int8_t *)aux16;
|
2238
|
+
|
2239
|
+
for (int i = 0; i < nb; ++i) {
|
2240
|
+
|
2241
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
2242
|
+
|
2243
|
+
const uint8_t * restrict q3 = x[i].qs;
|
2244
|
+
const int8_t * restrict q8 = y[i].qs;
|
2245
|
+
|
2246
|
+
const uint16_t a = *(const uint16_t *)x[i].scales;
|
2247
|
+
aux16[0] = a & 0x0f0f;
|
2248
|
+
aux16[1] = (a >> 4) & 0x0f0f;
|
2249
|
+
|
2250
|
+
const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
|
2251
|
+
const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
|
2252
|
+
|
2253
|
+
memcpy(&aux64, x[i].hmask, 8);
|
2254
|
+
|
2255
|
+
const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
|
2256
|
+
__m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
|
2257
|
+
__m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
|
2258
|
+
q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
|
2259
|
+
q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
|
2260
|
+
|
2261
|
+
// load low 2 bits
|
2262
|
+
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
2263
|
+
|
2264
|
+
// prepare low and high bits
|
2265
|
+
const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
|
2266
|
+
const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
|
2267
|
+
const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
|
2268
|
+
|
2269
|
+
// load Q8 quants
|
2270
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
2271
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
2272
|
+
|
2273
|
+
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
|
2274
|
+
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
|
2275
|
+
// and 2 if the high bit was set)
|
2276
|
+
const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
|
2277
|
+
const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
|
2278
|
+
|
2279
|
+
__m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
|
2280
|
+
__m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
|
2281
|
+
|
2282
|
+
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
|
2283
|
+
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
|
2284
|
+
|
2285
|
+
// multiply with scales
|
2286
|
+
p16_0 = _mm256_madd_epi16(scale_0, p16_0);
|
2287
|
+
p16_1 = _mm256_madd_epi16(scale_1, p16_1);
|
2288
|
+
|
2289
|
+
p16_0 = _mm256_add_epi32(p16_0, p16_1);
|
2290
|
+
|
2291
|
+
// multiply with block scale and accumulate
|
2292
|
+
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
|
2293
|
+
|
2294
|
+
}
|
2295
|
+
|
2296
|
+
*s = hsum_float_8(acc);
|
2297
|
+
|
2298
|
+
#else
|
2299
|
+
|
2300
|
+
int8_t aux8[QK_K];
|
2301
|
+
int16_t aux16[8];
|
2302
|
+
float sums [8];
|
2303
|
+
int32_t aux32[8];
|
2304
|
+
int32_t scales[4];
|
2305
|
+
memset(sums, 0, 8*sizeof(float));
|
2306
|
+
|
2307
|
+
float sumf = 0;
|
2308
|
+
for (int i = 0; i < nb; ++i) {
|
2309
|
+
const uint8_t * restrict q3 = x[i].qs;
|
2310
|
+
const uint8_t * restrict hm = x[i].hmask;
|
2311
|
+
const int8_t * restrict q8 = y[i].qs;
|
2312
|
+
int8_t * restrict a = aux8;
|
2313
|
+
for (int l = 0; l < 8; ++l) {
|
2314
|
+
a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
|
2315
|
+
a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
|
2316
|
+
a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
|
2317
|
+
a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
|
2318
|
+
a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
|
2319
|
+
a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
|
2320
|
+
a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
|
2321
|
+
a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
|
2322
|
+
}
|
2323
|
+
|
2324
|
+
scales[0] = (x[i].scales[0] & 0xF) - 8;
|
2325
|
+
scales[1] = (x[i].scales[0] >> 4) - 8;
|
2326
|
+
scales[2] = (x[i].scales[1] & 0xF) - 8;
|
2327
|
+
scales[3] = (x[i].scales[1] >> 4) - 8;
|
2328
|
+
|
2329
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
2330
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
2331
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2332
|
+
q8 += 8; a += 8;
|
2333
|
+
for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
|
2334
|
+
q8 += 8; a += 8;
|
2335
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
|
2336
|
+
}
|
2337
|
+
const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
|
2338
|
+
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
2339
|
+
}
|
2340
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
2341
|
+
*s = sumf;
|
2342
|
+
|
2343
|
+
#endif
|
2344
|
+
|
2345
|
+
}
|
2346
|
+
#endif
|
2347
|
+
|
2348
|
+
#if QK_K == 256
|
1504
2349
|
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1505
2350
|
assert(n % QK_K == 0);
|
1506
2351
|
|
@@ -1614,9 +2459,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1614
2459
|
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1615
2460
|
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1616
2461
|
|
1617
|
-
const uint8_t * restrict q4 = x[i].qs;
|
1618
|
-
const int8_t * restrict q8 = y[i].qs;
|
1619
|
-
|
1620
2462
|
memcpy(utmp, x[i].scales, 12);
|
1621
2463
|
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
1622
2464
|
const uint32_t uaux = utmp[1] & kmask1;
|
@@ -1624,6 +2466,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1624
2466
|
utmp[2] = uaux;
|
1625
2467
|
utmp[0] &= kmask1;
|
1626
2468
|
|
2469
|
+
const uint8_t * restrict q4 = x[i].qs;
|
2470
|
+
const int8_t * restrict q8 = y[i].qs;
|
2471
|
+
|
1627
2472
|
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
|
1628
2473
|
|
1629
2474
|
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
|
@@ -1667,6 +2512,88 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1667
2512
|
|
1668
2513
|
*s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
|
1669
2514
|
|
2515
|
+
#elif defined __AVX__
|
2516
|
+
|
2517
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
2518
|
+
const __m128i m2 = _mm_set1_epi8(0x2);
|
2519
|
+
|
2520
|
+
__m256 acc = _mm256_setzero_ps();
|
2521
|
+
__m128 acc_m = _mm_setzero_ps();
|
2522
|
+
|
2523
|
+
for (int i = 0; i < nb; ++i) {
|
2524
|
+
|
2525
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
2526
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
2527
|
+
|
2528
|
+
const uint8_t * restrict q4 = x[i].qs;
|
2529
|
+
const int8_t * restrict q8 = y[i].qs;
|
2530
|
+
|
2531
|
+
memcpy(utmp, x[i].scales, 12);
|
2532
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
2533
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
2534
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
2535
|
+
utmp[2] = uaux;
|
2536
|
+
utmp[0] &= kmask1;
|
2537
|
+
|
2538
|
+
const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
|
2539
|
+
const __m128i scales = _mm_cvtepu8_epi16(utmps);
|
2540
|
+
const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
|
2541
|
+
|
2542
|
+
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
|
2543
|
+
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
|
2544
|
+
const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
|
2545
|
+
const __m128i prod = _mm_madd_epi16(mins, q8s);
|
2546
|
+
acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
|
2547
|
+
|
2548
|
+
__m128i sumi_0 = _mm_setzero_si128();
|
2549
|
+
__m128i sumi_1 = _mm_setzero_si128();
|
2550
|
+
|
2551
|
+
__m128i shuffle = _mm_set1_epi16(0x0100);
|
2552
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
2553
|
+
|
2554
|
+
const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
|
2555
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2556
|
+
const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
|
2557
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2558
|
+
|
2559
|
+
__m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
2560
|
+
const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
|
2561
|
+
const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
|
2562
|
+
q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
2563
|
+
const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
|
2564
|
+
const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
|
2565
|
+
|
2566
|
+
const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2567
|
+
__m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
|
2568
|
+
p16l = _mm_madd_epi16(scale_l, p16l);
|
2569
|
+
sumi_0 = _mm_add_epi32(sumi_0, p16l);
|
2570
|
+
const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2571
|
+
p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
|
2572
|
+
p16l = _mm_madd_epi16(scale_l, p16l);
|
2573
|
+
sumi_1 = _mm_add_epi32(sumi_1, p16l);
|
2574
|
+
|
2575
|
+
const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2576
|
+
__m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
|
2577
|
+
p16h = _mm_madd_epi16(scale_h, p16h);
|
2578
|
+
sumi_0 = _mm_add_epi32(sumi_0, p16h);
|
2579
|
+
const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2580
|
+
p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
|
2581
|
+
p16h = _mm_madd_epi16(scale_h, p16h);
|
2582
|
+
sumi_1 = _mm_add_epi32(sumi_1, p16h);
|
2583
|
+
|
2584
|
+
}
|
2585
|
+
|
2586
|
+
__m256 vd = _mm256_set1_ps(d);
|
2587
|
+
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
2588
|
+
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
|
2589
|
+
|
2590
|
+
}
|
2591
|
+
|
2592
|
+
acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
|
2593
|
+
acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
|
2594
|
+
|
2595
|
+
*s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
|
2596
|
+
|
1670
2597
|
#else
|
1671
2598
|
|
1672
2599
|
|
@@ -1726,7 +2653,176 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     *s = sumf;
 #endif
 }
+#else
+void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+#ifdef __ARM_FEATURE_DOTPROD
+    const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+    float sumf = 0;
+
+    int8x16x2_t q4bytes;
+    int8x16x4_t q8bytes;
+
+    float sum_mins = 0.f;
+
+    uint16_t aux16[2];
+    const uint8_t * restrict scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t * restrict a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
+        sum_mins += y[i].d * (float)x[i].d[1] * summi;
+
+        const float d = y[i].d * (float)x[i].d[0];
+
+        const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+
+#ifdef __ARM_FEATURE_DOTPROD
+        q8bytes = vld1q_s8_x4(q8);
+        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
+
+        const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+        const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
+
+        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+        const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+        const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
+
+#else
+        q8bytes = vld1q_s8_x4(q8);
+        q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
+
+        q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+        q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
+        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
+
+#endif
+        sumf += d * (sumi1 + sumi2);
+
+    }
+
+    *s = sumf - sum_mins;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m256i q4l = _mm256_and_si256(q4bits, m4);
+        const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);

+        const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+        const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+
+        const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
+
+        const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
+#else
+
+    uint8_t aux8[QK_K];
+    int16_t aux16[16];
+    float   sums [8];
+    memset(sums, 0, 8*sizeof(float));
+
+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        uint8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
+        for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;
+
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
+
+        for (int j = 0; j < QK_K/32; ++j) {
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            q8 += 16; a += 16;
+            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
+            q8 += 16; a += 16;
+            const float dl = d * scales[j];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
+        }
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#endif
+
+#if QK_K == 256
 void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);

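The hunk above supplies an alternate ggml_vec_dot_q4_K_q8_K for builds where QK_K is 64. Judging from the scalar fallback, each block then carries two fp16 values (d[0] the super-block scale, d[1] the super-block min), two bytes of packed sub-scales/mins and 32 bytes of 4-bit quants. The sketch below dequantizes one such block in plain C; the struct layout and names are assumptions that merely mirror the accesses in the fallback, and ggml_fp16_to_fp32 is the conversion helper already used above (ggml_fp16_t is assumed to be its fp16 storage type).

// Illustrative only: assumed QK_K == 64 q4_K block, matching the accesses above.
typedef struct {
    ggml_fp16_t d[2];   // d[0] = super-block scale, d[1] = super-block min
    uint8_t scales[2];  // per 32-value half: scale in the low nibble, min in the high nibble
    uint8_t qs[32];     // 64 x 4-bit quants; low nibbles = first half, high nibbles = second half
} block_q4_K_64_sketch;

static void dequantize_q4_K_64_sketch(const block_q4_K_64_sketch * x, float * y) {
    const float d = ggml_fp16_to_fp32(x->d[0]);
    const float m = ggml_fp16_to_fp32(x->d[1]);
    for (int j = 0; j < 2; ++j) {
        const float dl = d * (x->scales[j] & 0xF);   // sub-block scale
        const float ml = m * (x->scales[j] >>  4);   // sub-block min
        for (int l = 0; l < 32; ++l) {
            const int q = j == 0 ? (x->qs[l] & 0xF) : (x->qs[l] >> 4);
            y[32*j + l] = dl * q - ml;
        }
    }
}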
@@ -1840,18 +2936,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

     for (int i = 0; i < nb; ++i) {
 
-        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
-        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
-
         const uint8_t * restrict q5 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
 
+#if QK_K == 256
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
         memcpy(utmp, x[i].scales, 12);
         utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
         const uint32_t uaux = utmp[1] & kmask1;
         utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
         utmp[2] = uaux;
         utmp[0] &= kmask1;
+#else
+        // TODO
+        const float d = 0, dmin = 0;
+#endif
 
         const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
 
@@ -1876,33 +2977,133 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
             const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
             const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
 
-            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
+            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
+
+            const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+            const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+            const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+
+            p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mzero = _mm_setzero_si128();
+    const __m128i mone = _mm_set1_epi8(1);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i scales = _mm_cvtepu8_epi16(utmps);
+        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
 
-
-
-
-
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+        const __m128i prod = _mm_madd_epi16(mins, q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
 
-
-
-
-            hmask = _mm256_slli_epi16(hmask, 1);
+        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
+        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
+        __m128i hmask = mone;
 
-
-
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
 
-
-            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+        int bit = 0;
 
-
-
+        __m128i shuffle = _mm_set1_epi16(0x0100);
+        for (int j = 0; j < QK_K/64; ++j) {
 
-
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+
+            const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+            const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+
+            __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
+            __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
+            __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
+            __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_0 = _mm_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm_madd_epi16(scale_0, p16_1);
+
+            q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
+            q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
+            q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            q5_0 = _mm_add_epi8(q5l_0, q5h_0);
+            q5_1 = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_2 = _mm_madd_epi16(scale_1, p16_2);
+            p16_3 = _mm_madd_epi16(scale_1, p16_3);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
 
         }
 
         __m256 vd = _mm256_set1_ps(d);
-
+        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
 
     }
 
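Both vector paths in the hunk above rebuild a 5-bit quant as the 4-bit nibble from qs plus the matching bit of qh shifted into bit 4, with the per-sub-block scales and mins unpacked from the same 12-byte scales field as q4_K. A scalar sketch of that reconstruction for one 32-value sub-block, written only to make the bit layout explicit (the function and its parameters are hypothetical, not the file's API):

// Group k (0..7) of a QK_K == 256 q5_K block covers 32 values; its fifth bits
// live in bit k of qh[0..31], and two consecutive groups share one 32-byte nibble plane of qs.
static void dequantize_q5_group_sketch(const uint8_t * qs, const uint8_t * qh,
                                       int k, float dl, float ml, float * y) {
    const uint8_t * ql = qs + 32 * (k / 2);
    for (int l = 0; l < 32; ++l) {
        const int lo = (k % 2) ? (ql[l] >> 4) : (ql[l] & 0xF);
        const int hi = (qh[l] >> k) & 1;
        y[l] = dl * (lo | (hi << 4)) - ml;   // dl = d * scale_k, ml = min * min_k
    }
}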
@@ -1972,8 +3173,169 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 #endif
 }
 
+#else
+
+void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q5_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const int32x4_t mzero = vdupq_n_s32(0);
+    const uint8x16_t mh = vdupq_n_u8(16);
+
+    int8x16x4_t q5bytes;
+    uint8x16x4_t q5h;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint8x8_t qhbits = vld1_u8(qh);
+
+        const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
+        const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+        const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
+        q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
+        q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
+        q5h.val[2] = vbicq_u8(mh, htmp);
+        q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
+
+        q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
+        q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
+        q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
+        q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
+#else
+
+        const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
+
+        const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                       vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
+
+        sumf += d*sumi;
+#endif
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i mone = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+        const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
+
+        const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
+        const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
+
+        const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+        const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
+        const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
+        const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
+        const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
+
+        const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
+
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+
+    uint8_t aux8[QK_K];
+    int16_t aux16[16];
+    float   sums [8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        uint8_t * restrict a = aux8;
+        for (int l = 0; l < 32; ++l) {
+            a[l+ 0] = q4[l] & 0xF;
+            a[l+32] = q4[l] >> 4;
+        }
+        for (int is = 0; is < 8; ++is) {
+            uint8_t m = 1 << is;
+            for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
+        }
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict sc = x[i].scales;
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float dl = d * sc[j];
+            for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
+            q8 += 16; a += 16;
+        }
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+#endif
 
 
+#if QK_K == 256
 void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % QK_K == 0);
 
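For QK_K == 64 the q5_K path added in the hunk above has no per-sub-block mins: the scalar fallback rebuilds each value from a nibble of qs plus one bit of qh, subtracting 16 when the fifth bit is clear, and scales each group of 16 values by a signed scales[j]. A minimal dequantization sketch consistent with that fallback (the struct and its field order are assumptions; only the accesses mirror the code above, with ggml_fp16_t/ggml_fp16_to_fp32 as before):

// Illustrative only: assumed QK_K == 64 q5_K block contents.
typedef struct {
    ggml_fp16_t d;     // super-block scale
    int8_t scales[4];  // signed scale per group of 16 values
    uint8_t qh[8];     // fifth bits: bit `is` of qh[l] belongs to element 8*is + l
    uint8_t qs[32];    // low 4 bits of the 64 quants
} block_q5_K_64_sketch;

static void dequantize_q5_K_64_sketch(const block_q5_K_64_sketch * x, float * y) {
    const float d = ggml_fp16_to_fp32(x->d);
    int8_t q[64];
    for (int l = 0; l < 32; ++l) {
        q[l]      = x->qs[l] & 0xF;
        q[l + 32] = x->qs[l] >> 4;
    }
    for (int is = 0; is < 8; ++is) {
        const uint8_t m = 1u << is;
        for (int l = 0; l < 8; ++l) {
            if (!(x->qh[l] & m)) q[8*is + l] -= 16;   // fifth bit clear -> lower half of the range
        }
    }
    for (int j = 0; j < 4; ++j) {
        const float dl = d * x->scales[j];
        for (int l = 0; l < 16; ++l) y[16*j + l] = dl * q[16*j + l];
    }
}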
@@ -2198,6 +3560,124 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);
 
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+
+            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
+            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
+            const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
+            const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
+            const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
+            const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+
+            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+
+            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
+            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
+            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
+            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
+            __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
+            __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
+            __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
+            __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
+            __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
+            __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
+            __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
+
+            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
+            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
+            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
+            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
+            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
+
+            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi8(shuffle, m2);
+
+            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
+            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
+
+        }
+
+        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
 #else
 
     int8_t aux8[QK_K];
@@ -2242,3 +3722,179 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
     *s = sumf;
 #endif
 }
+
+#else
+
+void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    assert(n % QK_K == 0);
+
+    const block_q6_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+    float sum = 0;
+
+    const uint8x16_t m4b = vdupq_n_u8(0xF);
+    const int32x4_t vzero = vdupq_n_s32(0);
+    const int8x16_t m32s = vdupq_n_s8(32);
+
+    const uint8x16_t mone = vdupq_n_u8(3);
+
+    int8x16x4_t q6bytes;
+    uint8x16x4_t q6h;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = (float)x[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        int32_t isum = 0;
+
+        uint8x16_t qhbits = vld1q_u8(qh);
+        uint8x16x2_t q6bits = vld1q_u8_x2(q6);
+        int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+        q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
+        uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
+        q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+        shifted = vshrq_n_u8(qhbits, 4);
+        q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+        shifted = vshrq_n_u8(qhbits, 6);
+        q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+        q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+        q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+        q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
+        q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+        isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+#else
+
+        int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+        int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+        isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
+
+        int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+        int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+                                 vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+        isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
+#endif
+
+        sum += isum * d_all * y[i].d;
+
+    }
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m2 = _mm256_set1_epi8(3);
+    const __m256i m32s = _mm256_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
+        const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+
+        const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+        const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+        __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+
+        __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+        __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+
+        p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+
+        p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+
+        sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int l = 0; l < 16; ++l) {
+            a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            a[l+32] = (int8_t)((q4[l+ 0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            a[l+48] = (int8_t)((q4[l+16] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+        }
+        int is = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int scale = x[i].scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#endif
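The final hunk gives ggml_vec_dot_q6_K_q8_K the same QK_K == 64 treatment: each 6-bit value is rebuilt from a nibble of ql and two bits of qh, re-centred by 32, and each group of 16 values is weighted by a signed scale. The helper below restates the unpacking done by the scalar fallback above as a stand-alone sketch (the function name and signature are illustrative, not part of the file):

// Rebuild the 64 signed 6-bit values of one QK_K == 64 q6_K block, exactly as the
// scalar fallback does: the four bit-pairs of qh[l] supply bits 4..5 for the four
// values that share index l across the four 16-value groups.
static void unpack_q6_K_64_sketch(const uint8_t ql[32], const uint8_t qh[16], int8_t out[64]) {
    for (int l = 0; l < 16; ++l) {
        out[l +  0] = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
        out[l + 16] = (int8_t)((ql[l + 16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
        out[l + 32] = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
        out[l + 48] = (int8_t)((ql[l + 16] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
    }
}

The dot product then multiplies each group of 16 unpacked values by d * scales[j], as in the loop that follows in the fallback.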