llama_cpp 0.2.1 → 0.3.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +32 -0
- data/README.md +39 -6
- data/examples/README.md +32 -0
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +38 -0
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +231 -132
- data/ext/llama_cpp/src/ggml-cuda.cu +844 -337
- data/ext/llama_cpp/src/ggml-metal.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.m +193 -49
- data/ext/llama_cpp/src/ggml-metal.metal +477 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +493 -4
- data/ext/llama_cpp/src/ggml.c +1565 -430
- data/ext/llama_cpp/src/ggml.h +208 -14
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +194 -101
- data/ext/llama_cpp/src/llama.h +41 -14
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +12 -17
- metadata +3 -3
- data/lib/llama_cpp/client.rb +0 -172
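The bulk of this release is an update of the vendored ggml k-quants sources: `k_quants.c`/`k_quants.h` gain `#if QK_K == 256 ... #else ... #endif` branches so the k-quant formats can also be built with 64-element super-blocks. The hunks below show those branches for the q2_K through q6_K reference (de)quantization routines. As a point of reference, here is a minimal, self-contained sketch (not part of the gem; the function names and the `main` driver are illustrative) of the 2-bit packing used by the new QK_K == 64 branch of `quantize_row_q2_K_reference`:

```c
#include <stdint.h>
#include <stdio.h>

/* Illustration only: pack 64 two-bit values (0..3) into 16 bytes, the layout
 * used by the new QK_K == 64 branch:
 *   y[i].qs[l] = L[l] | (L[l+16] << 2) | (L[l+32] << 4) | (L[l+48] << 6); */
static void pack_2bit_64(const uint8_t L[64], uint8_t qs[16]) {
    for (int l = 0; l < 16; ++l) {
        qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
    }
}

static void unpack_2bit_64(const uint8_t qs[16], uint8_t L[64]) {
    for (int l = 0; l < 16; ++l) {
        L[l]      = (qs[l] >> 0) & 3;
        L[l + 16] = (qs[l] >> 2) & 3;
        L[l + 32] = (qs[l] >> 4) & 3;
        L[l + 48] = (qs[l] >> 6) & 3;
    }
}

int main(void) {
    uint8_t L[64], qs[16], out[64];
    for (int i = 0; i < 64; ++i) L[i] = (uint8_t)(i & 3);
    pack_2bit_64(L, qs);
    unpack_2bit_64(qs, out);
    for (int i = 0; i < 64; ++i) {
        if (out[i] != L[i]) { puts("mismatch"); return 1; }
    }
    puts("round trip ok");
    return 0;
}
```

Each byte carries one value from each 16-element quarter of the block, which mirrors the interleaving the 256-element path applies per 128-value chunk.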
@@ -261,6 +261,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
     return scale;
 }
 
+#if QK_K == 256
 static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
     if (j < 4) {
         *d = q[j] & 63; *m = q[j + 4] & 63;
@@ -269,6 +270,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
         *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
     }
 }
+#endif
 
 //========================- 2-bit (de)-quantization
 
@@ -330,11 +332,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
             }
         }
 
+#if QK_K == 256
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
                 y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
             }
         }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
 
         x += QK_K;
 
@@ -352,6 +360,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
 
         const uint8_t * q = x[i].qs;
 
+#if QK_K == 256
         int is = 0;
         float dl, ml;
         for (int n = 0; n < QK_K; n += 128) {
@@ -370,7 +379,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
             }
             q += 32;
         }
-
+#else
+        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+        for (int l = 0; l < 16; ++l) {
+            y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
+            y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
+            y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
+            y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
+        }
+        y += QK_K;
+#endif
     }
 }
 
@@ -412,6 +433,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
             }
         }
 
+#if QK_K == 256
         memset(y[i].scales, 0, 12);
         if (max_scale) {
             float iscale = -32.f/max_scale;
@@ -445,9 +467,39 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
                 L[16*j + ii] = l + 4;
             }
         }
+#else
+        if (max_scale) {
+            float iscale = -8.f/max_scale;
+            for (int j = 0; j < QK_K/16; j+=2) {
+                int l1 = nearest_int(iscale*scales[j]);
+                l1 = 8 + MAX(-8, MIN(7, l1));
+                int l2 = nearest_int(iscale*scales[j+1]);
+                l2 = 8 + MAX(-8, MIN(7, l2));
+                y[i].scales[j/2] = l1 | (l2 << 4);
+            }
+            y[i].d = ggml_fp32_to_fp16(1/iscale);
+        } else {
+            for (int j = 0; j < QK_K/16; j+=2) {
+                y[i].scales[j/2] = 0;
+            }
+            y[i].d = ggml_fp32_to_fp16(0.f);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
+            float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+#endif
 
         memset(y[i].hmask, 0, QK_K/8);
-        // We put the high-bit for the 1st
+        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
         int m = 0;
         uint8_t hm = 1;
         for (int j = 0; j < QK_K; ++j) {
@@ -459,19 +511,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
                 m = 0; hm <<= 1;
             }
         }
+#if QK_K == 256
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
                 y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
             }
         }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
 
         x += QK_K;
     }
 }
 
+#if QK_K == 256
 void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
     assert(k % QK_K == 0);
-    assert(QK_K == 256);
     const int nb = k / QK_K;
 
     const uint32_t kmask1 = 0x03030303;
@@ -519,6 +577,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
 
     }
 }
+#else
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+    assert(k % QK_K == 0);
+    assert(QK_K == 64);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
+
+        for (int l=0; l<8; ++l) {
+            uint8_t h = hm[l];
+            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        }
+        y += QK_K;
+    }
+}
+#endif
 
 void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
     quantize_row_q3_K_reference(x, vy, k);
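In the q3_K hunks above, each 3-bit quant is stored as two low bits in `qs` plus one bit in the `hmask` bitfield; dequantization recovers a signed value in [-4, 3] via `(q & 3) - (hbit ? 0 : 4)`. A small illustrative round trip of that split (layout as in the new `#else` branch; the variable names and driver are mine):

```c
#include <stdint.h>
#include <stdio.h>

/* A 3-bit value v in [-4, 3] is stored offset as L = v + 4 in [0, 7]: the low
 * two bits go into qs, the high bit into the hmask bitfield.  Dequantization
 * computes (low2) - (highbit ? 0 : 4), exactly as in dequantize_row_q3_K. */
int main(void) {
    for (int v = -4; v <= 3; ++v) {
        int L   = v + 4;          // quantize_row_q3_K_reference stores l + 4
        int low = L & 3;          // would go into x[i].qs
        int hi  = (L >> 2) & 1;   // would go into x[i].hmask
        int rec = low - (hi ? 0 : 4);
        printf("v=%2d low=%d hi=%d rec=%2d\n", v, low, hi, rec);
    }
    return 0;
}
```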
@@ -563,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
             }
         }
 
+#if QK_K == 256
         float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
         float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -594,9 +686,43 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
                 L[32*j + ii] = l;
             }
         }
+#else
+        const float s_factor = 15.f;
+        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+        float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
+        int d1 = nearest_int(inv_scale*scales[0]);
+        int m1 = nearest_int(inv_min*mins[0]);
+        int d2 = nearest_int(inv_scale*scales[1]);
+        int m2 = nearest_int(inv_min*mins[1]);
+        y[i].scales[0] = d1 | (m1 << 4);
+        y[i].scales[1] = d2 | (m2 << 4);
+        y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
+        y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
+
+        float sumlx = 0;
+        int suml2 = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            const uint8_t sd = y[i].scales[j] & 0xF;
+            const uint8_t sm = y[i].scales[j] >> 4;
+            const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
+            if (!d) continue;
+            const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + m)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+                sumlx += (x[32*j + ii] + m)*l*sd;
+                suml2 += l*l*sd*sd;
+            }
+        }
+        if (suml2) {
+            y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
+        }
+#endif
         uint8_t * q = y[i].qs;
         for (int j = 0; j < QK_K; j += 64) {
-            for (int l = 0; l < 32; ++l)
+            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+            q += 32;
         }
 
         x += QK_K;
@@ -610,11 +736,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
 
     for (int i = 0; i < nb; i++) {
 
-        const float d = ggml_fp16_to_fp32(x[i].d);
-        const float min = ggml_fp16_to_fp32(x[i].dmin);
-
         const uint8_t * q = x[i].qs;
 
+#if QK_K == 256
+
+        const float d = ggml_fp16_to_fp32(x[i].d);
+        const float min = ggml_fp16_to_fp32(x[i].dmin);
+
         int is = 0;
         uint8_t sc, m;
         for (int j = 0; j < QK_K; j += 64) {
@@ -626,6 +754,17 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
             for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
             q += 32; is += 2;
         }
+#else
+        const float dall = ggml_fp16_to_fp32(x[i].d[0]);
+        const float mall = ggml_fp16_to_fp32(x[i].d[1]);
+        const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+        const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
+        for (int l = 0; l < 32; ++l) {
+            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+            y[l+32] = d2 * (q[l] >> 4) - m2;
+        }
+        y += QK_K;
+#endif
 
     }
 }
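The q4_K hunks store the two 32-value halves of a 64-element block as nibbles (`q[l] = L[j + l] | (L[j + l + 32] << 4)`) and dequantize them as `d * (q & 0xF) - m` and `d * (q >> 4) - m`, with `d` and `m` taken from the packed 4-bit scale/min pairs. A minimal sketch of that pack-and-dequantize step (a single illustrative scale and min rather than the per-group values the real code derives):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    uint8_t L[64], q[32];
    for (int i = 0; i < 64; ++i) L[i] = (uint8_t)(i % 16);

    // pack: low nibble holds the first half of the block, high nibble the second
    for (int l = 0; l < 32; ++l) q[l] = L[l] | (L[l + 32] << 4);

    const float d = 0.1f, m = 0.05f;   // illustrative per-group scale and min
    float y[64];
    for (int l = 0; l < 32; ++l) {
        y[l]      = d * (q[l] & 0xF) - m;
        y[l + 32] = d * (q[l] >> 4)  - m;
    }
    printf("y[0]=%.3f y[32]=%.3f\n", y[0], y[32]);
    return 0;
}
```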
@@ -653,12 +792,19 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+#if QK_K == 256
     uint8_t L[QK_K];
     float mins[QK_K/32];
     float scales[QK_K/32];
+#else
+    int8_t L[QK_K];
+    float scales[QK_K/16];
+#endif
 
     for (int i = 0; i < nb; i++) {
 
+#if QK_K == 256
+
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -725,6 +871,52 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
             m1 <<= 2; m2 <<= 2;
             ql += 32;
         }
+#else
+        float max_scale = 0, amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+            float abs_scale = fabsf(scales[j]);
+            if (abs_scale > amax) {
+                amax = abs_scale;
+                max_scale = scales[j];
+            }
+        }
+
+        float iscale = -128.f/max_scale;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = nearest_int(iscale*scales[j]);
+            y[i].scales[j] = MAX(-128, MIN(127, l));
+        }
+        y[i].d = ggml_fp32_to_fp16(1/iscale);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
+            if (!d) continue;
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-16, MIN(15, l));
+                L[16*j + ii] = l + 16;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        for (int j = 0; j < 32; ++j) {
+            int jm = j%8;
+            int is = j/8;
+            int l1 = L[j];
+            if (l1 > 15) {
+                l1 -= 16; qh[jm] |= (1 << is);
+            }
+            int l2 = L[j + 32];
+            if (l2 > 15) {
+                l2 -= 16; qh[jm] |= (1 << (4 + is));
+            }
+            ql[j] = l1 | (l2 << 4);
+        }
+#endif
 
         x += QK_K;
 
@@ -737,12 +929,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
 
     for (int i = 0; i < nb; i++) {
 
-        const float d = ggml_fp16_to_fp32(x[i].d);
-        const float min = ggml_fp16_to_fp32(x[i].dmin);
-
         const uint8_t * ql = x[i].qs;
         const uint8_t * qh = x[i].qh;
 
+#if QK_K == 256
+
+        const float d = ggml_fp16_to_fp32(x[i].d);
+        const float min = ggml_fp16_to_fp32(x[i].dmin);
+
         int is = 0;
         uint8_t sc, m;
         uint8_t u1 = 1, u2 = 2;
@@ -756,6 +950,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
             ql += 32; is += 2;
             u1 <<= 2; u2 <<= 2;
         }
+#else
+        float d = ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict s = x[i].scales;
+        for (int l = 0; l < 8; ++l) {
+            y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
+        }
+        y += QK_K;
+#endif
     }
 }
 
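In the q5_K hunks above, the QK_K == 64 path keeps a 5-bit index L in [0, 31] as its low four bits in `ql` and its fifth bit in the `qh` bitmask; dequantization then reads `(ql & 0xF) - (bit set ? 0 : 16)` and multiplies by a signed per-16 scale. A tiny illustration of that split and recovery (names and loop are mine):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    for (int L = 0; L < 32; L += 5) {
        int l   = L;
        int bit = 0;
        if (l > 15) { l -= 16; bit = 1; }       // what quantize_row_q5_K_reference stores
        int rec = (l & 0xF) - (bit ? 0 : 16);   // what dequantize_row_q5_K recovers (== L - 16)
        printf("L=%2d stored=(%2d,%d) rec=%3d\n", L, l, bit, rec);
    }
    return 0;
}
```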
@@ -823,6 +1032,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
 
         uint8_t * restrict ql = y[i].ql;
         uint8_t * restrict qh = y[i].qh;
+#if QK_K == 256
         for (int j = 0; j < QK_K; j += 128) {
             for (int l = 0; l < 32; ++l) {
                 const uint8_t q1 = L[j + l + 0] & 0xF;
@@ -836,6 +1046,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
             ql += 64;
             qh += 32;
         }
+#else
+        for (int l = 0; l < 32; ++l) {
+            const uint8_t q1 = L[l + 0] & 0xF;
+            const uint8_t q2 = L[l + 32] & 0xF;
+            ql[l] = q1 | (q2 << 4);
+        }
+        for (int l = 0; l < 16; ++l) {
+            qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
+        }
+#endif
 
         x += QK_K;
 
@@ -854,6 +1074,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
         const uint8_t * restrict qh = x[i].qh;
         const int8_t * restrict sc = x[i].scales;
 
+#if QK_K == 256
         for (int n = 0; n < QK_K; n += 128) {
             for (int l = 0; l < 32; ++l) {
                 int is = l/16;
@@ -871,6 +1092,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
             qh += 32;
             sc += 8;
         }
+#else
+        for (int l = 0; l < 16; ++l) {
+            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            y[l+ 0] = d * sc[0] * q1;
+            y[l+16] = d * sc[1] * q2;
+            y[l+32] = d * sc[2] * q3;
+            y[l+48] = d * sc[3] * q4;
+        }
+        y += 64;
+#endif
 
     }
 }
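The q6_K reconstruction used by the new branch combines a low nibble from `ql` with two bits from `qh` and recenters the result by 32, giving a signed 6-bit value in [-32, 31]. A short illustration of that bit recombination (driver and values are illustrative, not the gem's API):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    for (int v6 = 0; v6 < 64; v6 += 9) {
        uint8_t ql = v6 & 0xF;        // low 4 bits, stored as a nibble in x[i].ql
        uint8_t qh = (v6 >> 4) & 3;   // high 2 bits, packed four-per-byte in x[i].qh
        int8_t  q  = (int8_t)(ql | (qh << 4)) - 32;
        printf("v6=%2d -> q=%3d\n", v6, q);
    }
    return 0;
}
```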
@@ -1002,6 +1236,7 @@ static inline __m128i get_scale_shuffle(int i) {
|
|
1002
1236
|
}
|
1003
1237
|
#endif
|
1004
1238
|
|
1239
|
+
#if QK_K == 256
|
1005
1240
|
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1006
1241
|
|
1007
1242
|
const block_q2_K * restrict x = vx;
|
@@ -1158,6 +1393,112 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1158
1393
|
|
1159
1394
|
*s = hsum_float_8(acc);
|
1160
1395
|
|
1396
|
+
#elif defined __AVX__
|
1397
|
+
|
1398
|
+
const __m128i m3 = _mm_set1_epi8(0x3);
|
1399
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
1400
|
+
const __m128i m2 = _mm_set1_epi8(0x2);
|
1401
|
+
|
1402
|
+
__m256 acc = _mm256_setzero_ps();
|
1403
|
+
|
1404
|
+
for (int i = 0; i < nb; ++i) {
|
1405
|
+
|
1406
|
+
const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1407
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1408
|
+
|
1409
|
+
const uint8_t * restrict q2 = x[i].qs;
|
1410
|
+
const int8_t * restrict q8 = y[i].qs;
|
1411
|
+
|
1412
|
+
// load mins and scales from block_q2_K.scales[QK_K/16]
|
1413
|
+
const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
|
1414
|
+
const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
|
1415
|
+
const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
|
1416
|
+
const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
|
1417
|
+
const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
|
1418
|
+
|
1419
|
+
// summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
|
1420
|
+
const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
|
1421
|
+
const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
|
1422
|
+
|
1423
|
+
// sumf += -dmin * summs in 32bits*8
|
1424
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
|
1425
|
+
|
1426
|
+
const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
|
1427
|
+
const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
|
1428
|
+
const __m128i scales[2] = { scales_0, scales_1 };
|
1429
|
+
|
1430
|
+
__m128i sumi_0 = _mm_setzero_si128();
|
1431
|
+
__m128i sumi_1 = _mm_setzero_si128();
|
1432
|
+
|
1433
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
1434
|
+
|
1435
|
+
// load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
|
1436
|
+
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1437
|
+
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1438
|
+
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1439
|
+
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1440
|
+
const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1441
|
+
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1442
|
+
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1443
|
+
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
1444
|
+
|
1445
|
+
// load 2bits*16*8 from block_q2_K.qs[QK_K/4]
|
1446
|
+
__m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
|
1447
|
+
const __m128i q2_0 = _mm_and_si128(q2bits, m3);
|
1448
|
+
const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
|
1449
|
+
const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
|
1450
|
+
const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
|
1451
|
+
q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
|
1452
|
+
const __m128i q2_1 = _mm_and_si128(q2bits, m3);
|
1453
|
+
const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
|
1454
|
+
const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
|
1455
|
+
const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
|
1456
|
+
|
1457
|
+
// isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
|
1458
|
+
__m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
|
1459
|
+
__m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
|
1460
|
+
__m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
|
1461
|
+
__m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
|
1462
|
+
__m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
|
1463
|
+
__m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
|
1464
|
+
__m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
|
1465
|
+
__m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
|
1466
|
+
|
1467
|
+
// isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
|
1468
|
+
__m128i shuffle = _mm_set1_epi16(0x0100);
|
1469
|
+
p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
|
1470
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1471
|
+
p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
|
1472
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1473
|
+
p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
|
1474
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1475
|
+
p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
|
1476
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1477
|
+
p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
|
1478
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1479
|
+
p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
|
1480
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1481
|
+
p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
|
1482
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
1483
|
+
p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
|
1484
|
+
|
1485
|
+
p0 = _mm_add_epi32(p0, p1);
|
1486
|
+
p2 = _mm_add_epi32(p2, p3);
|
1487
|
+
p4 = _mm_add_epi32(p4, p5);
|
1488
|
+
p6 = _mm_add_epi32(p6, p7);
|
1489
|
+
|
1490
|
+
// isum in 32bits*4*2
|
1491
|
+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
|
1492
|
+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
|
1493
|
+
}
|
1494
|
+
|
1495
|
+
// sumf += dall * isum - dmin * summs in 32bits
|
1496
|
+
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
1497
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
|
1498
|
+
}
|
1499
|
+
|
1500
|
+
*s = hsum_float_8(acc);
|
1501
|
+
|
1161
1502
|
#else
|
1162
1503
|
|
1163
1504
|
float sumf = 0;
|
@@ -1201,6 +1542,168 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1201
1542
|
#endif
|
1202
1543
|
}
|
1203
1544
|
|
1545
|
+
#else
|
1546
|
+
|
1547
|
+
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1548
|
+
|
1549
|
+
const block_q2_K * restrict x = vx;
|
1550
|
+
const block_q8_K * restrict y = vy;
|
1551
|
+
|
1552
|
+
const int nb = n / QK_K;
|
1553
|
+
|
1554
|
+
#ifdef __ARM_NEON
|
1555
|
+
|
1556
|
+
const uint8x16_t m3 = vdupq_n_u8(0x3);
|
1557
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
1558
|
+
|
1559
|
+
int8x16x4_t q2bytes;
|
1560
|
+
|
1561
|
+
uint32_t aux32[2];
|
1562
|
+
const uint8_t * scales = (const uint8_t *)aux32;
|
1563
|
+
|
1564
|
+
float sum = 0;
|
1565
|
+
|
1566
|
+
for (int i = 0; i < nb; ++i) {
|
1567
|
+
|
1568
|
+
const float d = y[i].d * (float)x[i].d;
|
1569
|
+
const float dmin = -y[i].d * (float)x[i].dmin;
|
1570
|
+
|
1571
|
+
const uint8_t * restrict q2 = x[i].qs;
|
1572
|
+
const int8_t * restrict q8 = y[i].qs;
|
1573
|
+
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
|
1574
|
+
|
1575
|
+
aux32[0] = sc[0] & 0x0f0f0f0f;
|
1576
|
+
aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
|
1577
|
+
|
1578
|
+
sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
|
1579
|
+
|
1580
|
+
int isum1 = 0, isum2 = 0;
|
1581
|
+
|
1582
|
+
const uint8x16_t q2bits = vld1q_u8(q2);
|
1583
|
+
|
1584
|
+
const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
|
1585
|
+
|
1586
|
+
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
|
1587
|
+
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
|
1588
|
+
q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
|
1589
|
+
q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
|
1590
|
+
|
1591
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
1592
|
+
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
|
1593
|
+
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
|
1594
|
+
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
|
1595
|
+
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
|
1596
|
+
#else
|
1597
|
+
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
1598
|
+
vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
1599
|
+
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
1600
|
+
vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
1601
|
+
isum1 += vaddvq_s16(p1) * scales[0];
|
1602
|
+
isum2 += vaddvq_s16(p2) * scales[1];
|
1603
|
+
|
1604
|
+
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
1605
|
+
vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
1606
|
+
const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
1607
|
+
vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
1608
|
+
isum1 += vaddvq_s16(p3) * scales[2];
|
1609
|
+
isum2 += vaddvq_s16(p4) * scales[3];
|
1610
|
+
#endif
|
1611
|
+
sum += d * (isum1 + isum2);
|
1612
|
+
|
1613
|
+
}
|
1614
|
+
|
1615
|
+
*s = sum;
|
1616
|
+
|
1617
|
+
#elif defined __AVX2__
|
1618
|
+
|
1619
|
+
const __m256i m3 = _mm256_set1_epi8(3);
|
1620
|
+
|
1621
|
+
__m256 acc = _mm256_setzero_ps();
|
1622
|
+
|
1623
|
+
uint32_t ud, um;
|
1624
|
+
const uint8_t * restrict db = (const uint8_t *)&ud;
|
1625
|
+
const uint8_t * restrict mb = (const uint8_t *)&um;
|
1626
|
+
|
1627
|
+
float summs = 0;
|
1628
|
+
|
1629
|
+
// TODO: optimize this
|
1630
|
+
|
1631
|
+
for (int i = 0; i < nb; ++i) {
|
1632
|
+
|
1633
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1634
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1635
|
+
|
1636
|
+
const uint8_t * restrict q2 = x[i].qs;
|
1637
|
+
const int8_t * restrict q8 = y[i].qs;
|
1638
|
+
|
1639
|
+
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
|
1640
|
+
ud = (sc[0] >> 0) & 0x0f0f0f0f;
|
1641
|
+
um = (sc[0] >> 4) & 0x0f0f0f0f;
|
1642
|
+
|
1643
|
+
int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
|
1644
|
+
summs += dmin * smin;
|
1645
|
+
|
1646
|
+
const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
|
1647
|
+
const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
|
1648
|
+
const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
|
1649
|
+
|
1650
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
1651
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
1652
|
+
|
1653
|
+
const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
|
1654
|
+
const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
|
1655
|
+
|
1656
|
+
const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
|
1657
|
+
const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
|
1658
|
+
const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
|
1659
|
+
const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
|
1660
|
+
|
1661
|
+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
|
1662
|
+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
|
1663
|
+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
|
1664
|
+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
|
1665
|
+
}
|
1666
|
+
|
1667
|
+
*s = hsum_float_8(acc) + summs;
|
1668
|
+
|
1669
|
+
#else
|
1670
|
+
|
1671
|
+
float sumf = 0;
|
1672
|
+
|
1673
|
+
int isum[4];
|
1674
|
+
|
1675
|
+
for (int i = 0; i < nb; ++i) {
|
1676
|
+
|
1677
|
+
const uint8_t * q2 = x[i].qs;
|
1678
|
+
const int8_t * q8 = y[i].qs;
|
1679
|
+
const uint8_t * sc = x[i].scales;
|
1680
|
+
|
1681
|
+
int summs = 0;
|
1682
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
1683
|
+
summs += y[i].bsums[j] * (sc[j] >> 4);
|
1684
|
+
}
|
1685
|
+
|
1686
|
+
const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1687
|
+
const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1688
|
+
|
1689
|
+
isum[0] = isum[1] = isum[2] = isum[3] = 0;
|
1690
|
+
for (int l = 0; l < 16; ++l) {
|
1691
|
+
isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
|
1692
|
+
isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
|
1693
|
+
isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
|
1694
|
+
isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
|
1695
|
+
}
|
1696
|
+
for (int l = 0; l < 4; ++l) {
|
1697
|
+
isum[l] *= (sc[l] & 0xF);
|
1698
|
+
}
|
1699
|
+
sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
|
1700
|
+
}
|
1701
|
+
*s = sumf;
|
1702
|
+
#endif
|
1703
|
+
}
|
1704
|
+
#endif
|
1705
|
+
|
1706
|
+
#if QK_K == 256
|
1204
1707
|
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1205
1708
|
assert(n % QK_K == 0);
|
1206
1709
|
|
@@ -1434,34 +1937,176 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1434
1937
|
|
1435
1938
|
*s = hsum_float_8(acc);
|
1436
1939
|
|
1437
|
-
#
|
1438
|
-
// scalar version
|
1439
|
-
// This function is written like this so the compiler can manage to vectorize most of it
|
1440
|
-
// Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
|
1441
|
-
// manually vectorized version above. Every other version I tried would run at least 4 times slower.
|
1442
|
-
// The ideal situation would be if we could just write the code once, and the compiler would
|
1443
|
-
// automatically produce the best possible set of machine instructions, instead of us having to manually
|
1444
|
-
// write vectorized versions for AVX, ARM_NEON, etc.
|
1940
|
+
#elif defined __AVX__
|
1445
1941
|
|
1446
|
-
|
1447
|
-
|
1448
|
-
|
1449
|
-
|
1450
|
-
memset(sums, 0, 8*sizeof(float));
|
1942
|
+
const __m128i m3 = _mm_set1_epi8(3);
|
1943
|
+
const __m128i mone = _mm_set1_epi8(1);
|
1944
|
+
const __m128i m32 = _mm_set1_epi8(32);
|
1945
|
+
const __m128i m2 = _mm_set1_epi8(2);
|
1451
1946
|
|
1452
|
-
|
1453
|
-
|
1947
|
+
__m256 acc = _mm256_setzero_ps();
|
1948
|
+
|
1949
|
+
uint32_t *aux;
|
1454
1950
|
|
1455
|
-
float sumf = 0;
|
1456
1951
|
for (int i = 0; i < nb; ++i) {
|
1952
|
+
|
1953
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1954
|
+
|
1457
1955
|
const uint8_t * restrict q3 = x[i].qs;
|
1458
|
-
const
|
1459
|
-
|
1460
|
-
|
1461
|
-
|
1462
|
-
|
1463
|
-
|
1464
|
-
|
1956
|
+
const int8_t * restrict q8 = y[i].qs;
|
1957
|
+
|
1958
|
+
// Set up scales
|
1959
|
+
aux = (uint32_t *)x[i].scales;
|
1960
|
+
__m128i scales128 = _mm_set_epi32(
|
1961
|
+
((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
|
1962
|
+
((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
|
1963
|
+
(aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
|
1964
|
+
(aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
|
1965
|
+
scales128 = _mm_sub_epi8(scales128, m32);
|
1966
|
+
const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
|
1967
|
+
const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
|
1968
|
+
const __m128i scales[2] = { scales_0, scales_1 };
|
1969
|
+
|
1970
|
+
// high bit *128*2 from block_q3_K.hmask[QK_K/8]
|
1971
|
+
const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
|
1972
|
+
const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
|
1973
|
+
|
1974
|
+
// integer accumulator
|
1975
|
+
__m128i sumi_0 = _mm_setzero_si128();
|
1976
|
+
__m128i sumi_1 = _mm_setzero_si128();
|
1977
|
+
|
1978
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
1979
|
+
// load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
|
1980
|
+
const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
|
1981
|
+
const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
|
1982
|
+
|
1983
|
+
// prepare low and high bits
|
1984
|
+
const int bit = j << 2;
|
1985
|
+
|
1986
|
+
const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
|
1987
|
+
const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
|
1988
|
+
const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
|
1989
|
+
const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
|
1990
|
+
|
1991
|
+
const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
|
1992
|
+
const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
|
1993
|
+
const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
|
1994
|
+
const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
|
1995
|
+
|
1996
|
+
const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
|
1997
|
+
const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
|
1998
|
+
const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
|
1999
|
+
const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
|
2000
|
+
|
2001
|
+
const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
|
2002
|
+
const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
|
2003
|
+
const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
|
2004
|
+
const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
|
2005
|
+
|
2006
|
+
// load Q8 quants from block_q8_K.qs[QK_K]
|
2007
|
+
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2008
|
+
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2009
|
+
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2010
|
+
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2011
|
+
const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2012
|
+
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2013
|
+
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2014
|
+
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2015
|
+
|
2016
|
+
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
|
2017
|
+
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
|
2018
|
+
// and 2 if the high bit was set)
|
2019
|
+
__m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
|
2020
|
+
__m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
|
2021
|
+
__m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
|
2022
|
+
__m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
|
2023
|
+
__m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
|
2024
|
+
__m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
|
2025
|
+
__m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
|
2026
|
+
__m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
|
2027
|
+
|
2028
|
+
__m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
|
2029
|
+
__m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
|
2030
|
+
__m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
|
2031
|
+
__m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
|
2032
|
+
__m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
|
2033
|
+
__m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
|
2034
|
+
__m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
|
2035
|
+
__m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
|
2036
|
+
|
2037
|
+
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
|
2038
|
+
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
|
2039
|
+
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
|
2040
|
+
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
|
2041
|
+
p16_4 = _mm_sub_epi16(p16_4, q8s_4);
|
2042
|
+
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
|
2043
|
+
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
|
2044
|
+
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
|
2045
|
+
|
2046
|
+
// multiply with scales
|
2047
|
+
__m128i shuffle = _mm_set1_epi16(0x0100);
|
2048
|
+
p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
|
2049
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2050
|
+
p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
|
2051
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2052
|
+
p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
|
2053
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2054
|
+
p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
|
2055
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2056
|
+
p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
|
2057
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2058
|
+
p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
|
2059
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2060
|
+
p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
|
2061
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2062
|
+
p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
|
2063
|
+
|
2064
|
+
// accumulate
|
2065
|
+
p16_0 = _mm_add_epi32(p16_0, p16_1);
|
2066
|
+
p16_2 = _mm_add_epi32(p16_2, p16_3);
|
2067
|
+
p16_4 = _mm_add_epi32(p16_4, p16_5);
|
2068
|
+
p16_6 = _mm_add_epi32(p16_6, p16_7);
|
2069
|
+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
|
2070
|
+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
|
2071
|
+
|
2072
|
+
}
|
2073
|
+
|
2074
|
+
// multiply with block scale and accumulate
|
2075
|
+
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
2076
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
|
2077
|
+
|
2078
|
+
}
|
2079
|
+
|
2080
|
+
*s = hsum_float_8(acc);
|
2081
|
+
|
2082
|
+
#else
|
2083
|
+
// scalar version
|
2084
|
+
// This function is written like this so the compiler can manage to vectorize most of it
|
2085
|
+
// Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
|
2086
|
+
// manually vectorized version above. Every other version I tried would run at least 4 times slower.
|
2087
|
+
// The ideal situation would be if we could just write the code once, and the compiler would
|
2088
|
+
// automatically produce the best possible set of machine instructions, instead of us having to manually
|
2089
|
+
// write vectorized versions for AVX, ARM_NEON, etc.
|
2090
|
+
|
2091
|
+
int8_t aux8[QK_K];
|
2092
|
+
int16_t aux16[8];
|
2093
|
+
float sums [8];
|
2094
|
+
int32_t aux32[8];
|
2095
|
+
memset(sums, 0, 8*sizeof(float));
|
2096
|
+
|
2097
|
+
uint32_t auxs[4];
|
2098
|
+
const int8_t * scales = (const int8_t*)auxs;
|
2099
|
+
|
2100
|
+
float sumf = 0;
|
2101
|
+
for (int i = 0; i < nb; ++i) {
|
2102
|
+
const uint8_t * restrict q3 = x[i].qs;
|
2103
|
+
const uint8_t * restrict hm = x[i].hmask;
|
2104
|
+
const int8_t * restrict q8 = y[i].qs;
|
2105
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
2106
|
+
int8_t * restrict a = aux8;
|
2107
|
+
uint8_t m = 1;
|
2108
|
+
for (int j = 0; j < QK_K; j += 128) {
|
2109
|
+
for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
|
1465
2110
|
for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
|
1466
2111
|
a += 32; m <<= 1;
|
1467
2112
|
for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
|
@@ -1501,6 +2146,206 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1501
2146
|
|
1502
2147
|
}
|
1503
2148
|
|
2149
|
+
#else
|
2150
|
+
|
2151
|
+
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2152
|
+
assert(n % QK_K == 0);
|
2153
|
+
|
2154
|
+
const block_q3_K * restrict x = vx;
|
2155
|
+
const block_q8_K * restrict y = vy;
|
2156
|
+
|
2157
|
+
const int nb = n / QK_K;
|
2158
|
+
|
2159
|
+
#ifdef __ARM_NEON
|
2160
|
+
|
2161
|
+
#ifdef __ARM_FEATURE_DOTPROD
|
2162
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
2163
|
+
#endif
|
2164
|
+
|
2165
|
+
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
2166
|
+
const uint8x16_t mh = vdupq_n_u8(4);
|
2167
|
+
|
2168
|
+
int8x16x4_t q3bytes;
|
2169
|
+
|
2170
|
+
uint16_t aux16[2];
|
2171
|
+
int8_t * scales = (int8_t *)aux16;
|
2172
|
+
|
2173
|
+
float sum = 0;
|
2174
|
+
|
2175
|
+
for (int i = 0; i < nb; ++i) {
|
2176
|
+
|
2177
|
+
uint8x16x4_t q3h;
|
2178
|
+
|
2179
|
+
const uint8x8_t hbits = vld1_u8(x[i].hmask);
|
2180
|
+
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
|
2181
|
+
const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
|
2182
|
+
|
2183
|
+
const uint16_t a = *(const uint16_t *)x[i].scales;
|
2184
|
+
aux16[0] = a & 0x0f0f;
|
2185
|
+
aux16[1] = (a >> 4) & 0x0f0f;
|
2186
|
+
|
2187
|
+
for (int j = 0; j < 4; ++j) scales[j] -= 8;
|
2188
|
+
|
2189
|
+
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
|
2190
|
+
|
2191
|
+
const float d = y[i].d * (float)x[i].d;
|
2192
|
+
|
2193
|
+
const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
|
2194
|
+
q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
|
2195
|
+
q3h.val[1] = vandq_u8(mh, htmp);
|
2196
|
+
q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
|
2197
|
+
q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
|
2198
|
+
|
2199
|
+
q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
|
2200
|
+
q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
|
2201
|
+
q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
|
2202
|
+
q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
|
2203
|
+
|
2204
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2205
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
|
2206
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
|
2207
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
|
2208
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
|
2209
|
+
#else
|
2210
|
+
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
2211
|
+
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
2212
|
+
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
2213
|
+
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
2214
|
+
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
2215
|
+
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
2216
|
+
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
2217
|
+
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
2218
|
+
isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
|
2219
|
+
#endif
|
2220
|
+
|
2221
|
+
sum += d * isum;
|
2222
|
+
|
2223
|
+
}
|
2224
|
+
|
2225
|
+
*s = sum;
|
2226
|
+
|
2227
|
+
#elif defined __AVX2__
|
2228
|
+
|
2229
|
+
const __m256i m3 = _mm256_set1_epi8(3);
|
2230
|
+
const __m256i m1 = _mm256_set1_epi8(1);
|
2231
|
+
|
2232
|
+
__m256 acc = _mm256_setzero_ps();
|
2233
|
+
|
2234
|
+
uint64_t aux64;
|
2235
|
+
|
2236
|
+
uint16_t aux16[2];
|
2237
|
+
const int8_t * aux8 = (const int8_t *)aux16;
|
2238
|
+
|
2239
|
+
for (int i = 0; i < nb; ++i) {
|
2240
|
+
|
2241
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
2242
|
+
|
2243
|
+
const uint8_t * restrict q3 = x[i].qs;
|
2244
|
+
const int8_t * restrict q8 = y[i].qs;
|
2245
|
+
|
2246
|
+
const uint16_t a = *(const uint16_t *)x[i].scales;
|
2247
|
+
aux16[0] = a & 0x0f0f;
|
2248
|
+
aux16[1] = (a >> 4) & 0x0f0f;
|
2249
|
+
|
2250
|
+
const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
|
2251
|
+
const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
|
2252
|
+
|
2253
|
+
memcpy(&aux64, x[i].hmask, 8);
|
2254
|
+
|
2255
|
+
const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
|
2256
|
+
__m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
|
2257
|
+
__m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
|
2258
|
+
q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
|
2259
|
+
q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
|
2260
|
+
|
2261
|
+
// load low 2 bits
|
2262
|
+
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
2263
|
+
|
2264
|
+
// prepare low and high bits
|
2265
|
+
const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
|
2266
|
+
const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
|
2267
|
+
const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
|
2268
|
+
|
2269
|
+
// load Q8 quants
|
2270
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
2271
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
2272
|
+
|
2273
|
+
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
|
2274
|
+
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
|
2275
|
+
// and 2 if the high bit was set)
|
2276
|
+
const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
|
2277
|
+
const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
|
2278
|
+
|
2279
|
+
__m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
|
2280
|
+
__m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
|
2281
|
+
|
2282
|
+
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
|
2283
|
+
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
|
2284
|
+
|
2285
|
+
// multiply with scales
|
2286
|
+
p16_0 = _mm256_madd_epi16(scale_0, p16_0);
|
2287
|
+
p16_1 = _mm256_madd_epi16(scale_1, p16_1);
|
2288
|
+
|
2289
|
+
p16_0 = _mm256_add_epi32(p16_0, p16_1);
|
2290
|
+
|
2291
|
+
// multiply with block scale and accumulate
|
2292
|
+
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
|
2293
|
+
|
2294
|
+
}
|
2295
|
+
|
2296
|
+
*s = hsum_float_8(acc);
|
2297
|
+
|
2298
|
+
#else
|
2299
|
+
|
2300
|
+
int8_t aux8[QK_K];
|
2301
|
+
int16_t aux16[8];
|
2302
|
+
float sums [8];
|
2303
|
+
int32_t aux32[8];
|
2304
|
+
int32_t scales[4];
|
2305
|
+
memset(sums, 0, 8*sizeof(float));
|
2306
|
+
|
2307
|
+
float sumf = 0;
|
2308
|
+
for (int i = 0; i < nb; ++i) {
|
2309
|
+
const uint8_t * restrict q3 = x[i].qs;
|
2310
|
+
const uint8_t * restrict hm = x[i].hmask;
|
2311
|
+
const int8_t * restrict q8 = y[i].qs;
|
2312
|
+
int8_t * restrict a = aux8;
|
2313
|
+
for (int l = 0; l < 8; ++l) {
|
2314
|
+
a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
|
2315
|
+
a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
|
2316
|
+
a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
|
2317
|
+
a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
|
2318
|
+
a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
|
2319
|
+
a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
|
2320
|
+
a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
|
2321
|
+
a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
|
2322
|
+
}
|
2323
|
+
|
2324
|
+
scales[0] = (x[i].scales[0] & 0xF) - 8;
|
2325
|
+
scales[1] = (x[i].scales[0] >> 4) - 8;
|
2326
|
+
scales[2] = (x[i].scales[1] & 0xF) - 8;
|
2327
|
+
scales[3] = (x[i].scales[1] >> 4) - 8;
|
2328
|
+
|
2329
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
2330
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
2331
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2332
|
+
q8 += 8; a += 8;
|
2333
|
+
for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
|
2334
|
+
q8 += 8; a += 8;
|
2335
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
|
2336
|
+
}
|
2337
|
+
const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
|
2338
|
+
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
2339
|
+
}
|
2340
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
2341
|
+
*s = sumf;
|
2342
|
+
|
2343
|
+
#endif
|
2344
|
+
|
2345
|
+
}
|
2346
|
+
#endif
|
2347
|
+
|
2348
|
+
#if QK_K == 256
|
1504
2349
|
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1505
2350
|
assert(n % QK_K == 0);
|
1506
2351
|
|
@@ -1614,9 +2459,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1614
2459
|
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1615
2460
|
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1616
2461
|
|
1617
|
-
const uint8_t * restrict q4 = x[i].qs;
|
1618
|
-
const int8_t * restrict q8 = y[i].qs;
|
1619
|
-
|
1620
2462
|
memcpy(utmp, x[i].scales, 12);
|
1621
2463
|
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
1622
2464
|
const uint32_t uaux = utmp[1] & kmask1;
|
@@ -1624,6 +2466,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1624
2466
|
utmp[2] = uaux;
|
1625
2467
|
utmp[0] &= kmask1;
|
1626
2468
|
|
2469
|
+
const uint8_t * restrict q4 = x[i].qs;
|
2470
|
+
const int8_t * restrict q8 = y[i].qs;
|
2471
|
+
|
1627
2472
|
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
|
1628
2473
|
|
1629
2474
|
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
|
@@ -1667,6 +2512,88 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1667
2512
|
|
1668
2513
|
*s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
|
1669
2514
|
|
2515
|
+
#elif defined __AVX__
|
2516
|
+
|
2517
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
2518
|
+
const __m128i m2 = _mm_set1_epi8(0x2);
|
2519
|
+
|
2520
|
+
__m256 acc = _mm256_setzero_ps();
|
2521
|
+
__m128 acc_m = _mm_setzero_ps();
|
2522
|
+
|
2523
|
+
for (int i = 0; i < nb; ++i) {
|
2524
|
+
|
2525
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
2526
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
2527
|
+
|
2528
|
+
const uint8_t * restrict q4 = x[i].qs;
|
2529
|
+
const int8_t * restrict q8 = y[i].qs;
|
2530
|
+
|
2531
|
+
memcpy(utmp, x[i].scales, 12);
|
2532
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
2533
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
2534
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
2535
|
+
utmp[2] = uaux;
|
2536
|
+
utmp[0] &= kmask1;
|
2537
|
+
|
2538
|
+
const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
|
2539
|
+
const __m128i scales = _mm_cvtepu8_epi16(utmps);
|
2540
|
+
const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
|
2541
|
+
|
2542
|
+
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
|
2543
|
+
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
|
2544
|
+
const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
|
2545
|
+
const __m128i prod = _mm_madd_epi16(mins, q8s);
|
2546
|
+
acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
|
2547
|
+
|
2548
|
+
__m128i sumi_0 = _mm_setzero_si128();
|
2549
|
+
__m128i sumi_1 = _mm_setzero_si128();
|
2550
|
+
|
2551
|
+
__m128i shuffle = _mm_set1_epi16(0x0100);
|
2552
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
2553
|
+
|
2554
|
+
const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
|
2555
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2556
|
+
const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
|
2557
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
2558
|
+
|
2559
|
+
__m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
2560
|
+
const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
|
2561
|
+
const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
|
2562
|
+
q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
2563
|
+
const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
|
2564
|
+
const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
|
2565
|
+
|
2566
|
+
const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2567
|
+
__m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
|
2568
|
+
p16l = _mm_madd_epi16(scale_l, p16l);
|
2569
|
+
sumi_0 = _mm_add_epi32(sumi_0, p16l);
|
2570
|
+
const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2571
|
+
p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
|
2572
|
+
p16l = _mm_madd_epi16(scale_l, p16l);
|
2573
|
+
sumi_1 = _mm_add_epi32(sumi_1, p16l);
|
2574
|
+
|
2575
|
+
const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2576
|
+
__m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
|
2577
|
+
p16h = _mm_madd_epi16(scale_h, p16h);
|
2578
|
+
sumi_0 = _mm_add_epi32(sumi_0, p16h);
|
2579
|
+
const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
2580
|
+
p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
|
2581
|
+
p16h = _mm_madd_epi16(scale_h, p16h);
|
2582
|
+
sumi_1 = _mm_add_epi32(sumi_1, p16h);
|
2583
|
+
|
2584
|
+
}
|
2585
|
+
|
2586
|
+
__m256 vd = _mm256_set1_ps(d);
|
2587
|
+
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
2588
|
+
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
|
2589
|
+
|
2590
|
+
}
|
2591
|
+
|
2592
|
+
acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
|
2593
|
+
acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
|
2594
|
+
|
2595
|
+
*s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
|
2596
|
+
|
1670
2597
|
#else
|
1671
2598
|
|
1672
2599
|
|
@@ -1726,7 +2653,176 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1726
2653
|
*s = sumf;
|
1727
2654
|
#endif
|
1728
2655
|
}
|
2656
|
+
#else
|
2657
|
+
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2658
|
+
assert(n % QK_K == 0);
|
2659
|
+
|
2660
|
+
const block_q4_K * restrict x = vx;
|
2661
|
+
const block_q8_K * restrict y = vy;
|
2662
|
+
|
2663
|
+
const int nb = n / QK_K;
|
2664
|
+
|
2665
|
+
#ifdef __ARM_NEON
|
2666
|
+
|
2667
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2668
|
+
|
2669
|
+
#ifdef __ARM_FEATURE_DOTPROD
|
2670
|
+
const int32x4_t mzero = vdupq_n_s32(0);
|
2671
|
+
#endif
|
2672
|
+
|
2673
|
+
float sumf = 0;
|
2674
|
+
|
2675
|
+
int8x16x2_t q4bytes;
|
2676
|
+
int8x16x4_t q8bytes;
|
2677
|
+
|
2678
|
+
float sum_mins = 0.f;
|
2679
|
+
|
2680
|
+
uint16_t aux16[2];
|
2681
|
+
const uint8_t * restrict scales = (const uint8_t *)aux16;
|
2682
|
+
|
2683
|
+
for (int i = 0; i < nb; ++i) {
|
2684
|
+
|
2685
|
+
const uint8_t * restrict q4 = x[i].qs;
|
2686
|
+
const int8_t * restrict q8 = y[i].qs;
|
2687
|
+
|
2688
|
+
const uint16_t * restrict a = (const uint16_t *)x[i].scales;
|
2689
|
+
aux16[0] = a[0] & 0x0f0f;
|
2690
|
+
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
2691
|
+
|
2692
|
+
const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
|
2693
|
+
sum_mins += y[i].d * (float)x[i].d[1] * summi;
|
2694
|
+
|
2695
|
+
const float d = y[i].d * (float)x[i].d[0];
|
2696
|
+
|
2697
|
+
const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
|
2698
|
+
|
2699
|
+
#ifdef __ARM_FEATURE_DOTPROD
|
2700
|
+
q8bytes = vld1q_s8_x4(q8);
|
2701
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
2702
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
2703
|
+
|
2704
|
+
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
2705
|
+
const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
|
2706
|
+
|
2707
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
2708
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
2709
|
+
|
2710
|
+
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
|
2711
|
+
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
|
2712
|
+
|
2713
|
+
#else
|
2714
|
+
q8bytes = vld1q_s8_x4(q8);
|
2715
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
2716
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
2717
|
+
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
2718
|
+
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
2719
|
+
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
2720
|
+
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
2721
|
+
int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
|
2722
|
+
|
2723
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
2724
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
2725
|
+
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
|
2726
|
+
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
|
2727
|
+
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
|
2728
|
+
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
|
2729
|
+
int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
|
2730
|
+
|
2731
|
+
#endif
|
2732
|
+
sumf += d * (sumi1 + sumi2);
|
2733
|
+
|
2734
|
+
}
|
2735
|
+
|
2736
|
+
*s = sumf - sum_mins;
|
2737
|
+
|
2738
|
+
#elif defined __AVX2__
|
2739
|
+
|
2740
|
+
const __m256i m4 = _mm256_set1_epi8(0xF);
|
2741
|
+
|
2742
|
+
__m256 acc = _mm256_setzero_ps();
|
2743
|
+
|
2744
|
+
float summs = 0;
|
2745
|
+
|
2746
|
+
uint16_t aux16[2];
|
2747
|
+
const uint8_t * scales = (const uint8_t *)aux16;
|
2748
|
+
|
2749
|
+
for (int i = 0; i < nb; ++i) {
|
2750
|
+
|
2751
|
+
const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
|
2752
|
+
const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
|
2753
|
+
const __m256 vd = _mm256_set1_ps(d);
|
2754
|
+
|
2755
|
+
const uint16_t * a = (const uint16_t *)x[i].scales;
|
2756
|
+
aux16[0] = a[0] & 0x0f0f;
|
2757
|
+
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
2758
|
+
|
2759
|
+
summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
|
2760
|
+
|
2761
|
+
const uint8_t * restrict q4 = x[i].qs;
|
2762
|
+
const int8_t * restrict q8 = y[i].qs;
|
2763
|
+
|
2764
|
+
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
|
2765
|
+
const __m256i q4l = _mm256_and_si256(q4bits, m4);
|
2766
|
+
const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
|
1729
2767
|
|
2768
|
+
const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
2769
|
+
const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
|
2770
|
+
|
2771
|
+
const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
|
2772
|
+
const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
|
2773
|
+
|
2774
|
+
const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
|
2775
|
+
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
|
2776
|
+
|
2777
|
+
const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
|
2778
|
+
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
|
2779
|
+
|
2780
|
+
}
|
2781
|
+
|
2782
|
+
*s = hsum_float_8(acc) - summs;
|
2783
|
+
|
2784
|
+
#else
|
2785
|
+
|
2786
|
+
uint8_t aux8[QK_K];
|
2787
|
+
int16_t aux16[16];
|
2788
|
+
float sums [8];
|
2789
|
+
memset(sums, 0, 8*sizeof(float));
|
2790
|
+
|
2791
|
+
uint16_t s16[2];
|
2792
|
+
const uint8_t * restrict scales = (const uint8_t *)s16;
|
2793
|
+
|
2794
|
+
float sumf = 0;
|
2795
|
+
for (int i = 0; i < nb; ++i) {
|
2796
|
+
const uint8_t * restrict q4 = x[i].qs;
|
2797
|
+
const int8_t * restrict q8 = y[i].qs;
|
2798
|
+
uint8_t * restrict a = aux8;
|
2799
|
+
for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
|
2800
|
+
for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;
|
2801
|
+
|
2802
|
+
const uint16_t * restrict b = (const uint16_t *)x[i].scales;
|
2803
|
+
s16[0] = b[0] & 0x0f0f;
|
2804
|
+
s16[1] = (b[0] >> 4) & 0x0f0f;
|
2805
|
+
|
2806
|
+
sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
|
2807
|
+
|
2808
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
|
2809
|
+
|
2810
|
+
for (int j = 0; j < QK_K/32; ++j) {
|
2811
|
+
for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
|
2812
|
+
q8 += 16; a += 16;
|
2813
|
+
for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
|
2814
|
+
q8 += 16; a += 16;
|
2815
|
+
const float dl = d * scales[j];
|
2816
|
+
for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
|
2817
|
+
}
|
2818
|
+
}
|
2819
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
2820
|
+
*s = sumf;
|
2821
|
+
#endif
|
2822
|
+
}
|
2823
|
+
#endif
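
For QK_K == 64 the q4_K block stores only two fp16 factors (x[i].d[0] as the scale, x[i].d[1] as the min) and four 4-bit values packed into two scale bytes: scales[0] and scales[1] scale the two 32-value halves, while scales[2] and scales[3] are the mins used with y[i].bsums. The aux16 trick above splits those packed bytes with two masks. A minimal sketch of that split on made-up bytes (packed[] is hypothetical input, not data from the diff); the resulting byte order assumes a little-endian host, as the kernel itself does:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint8_t packed[2] = { 0xA3, 0xC5 };      // hypothetical x[i].scales bytes
    uint16_t aux16[2];
    const uint8_t *scales = (const uint8_t *)aux16;

    // the same 16-bit word a little-endian load of the two bytes produces
    const uint16_t a = (uint16_t)packed[0] | ((uint16_t)packed[1] << 8);

    aux16[0] = a & 0x0f0f;            // low nibbles  -> scales[0], scales[1]
    aux16[1] = (a >> 4) & 0x0f0f;     // high nibbles -> scales[2], scales[3]

    printf("%d %d %d %d\n", scales[0], scales[1], scales[2], scales[3]);
    // little-endian result: 3 5 10 12
    return 0;
}
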
|
2824
|
+
|
2825
|
+
#if QK_K == 256
|
1730
2826
|
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1731
2827
|
assert(n % QK_K == 0);
|
1732
2828
|
|
@@ -1840,18 +2936,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1840
2936
|
|
1841
2937
|
for (int i = 0; i < nb; ++i) {
|
1842
2938
|
|
1843
|
-
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1844
|
-
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1845
|
-
|
1846
2939
|
const uint8_t * restrict q5 = x[i].qs;
|
1847
2940
|
const int8_t * restrict q8 = y[i].qs;
|
1848
2941
|
|
2942
|
+
#if QK_K == 256
|
2943
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
2944
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
2945
|
+
|
1849
2946
|
memcpy(utmp, x[i].scales, 12);
|
1850
2947
|
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
1851
2948
|
const uint32_t uaux = utmp[1] & kmask1;
|
1852
2949
|
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
1853
2950
|
utmp[2] = uaux;
|
1854
2951
|
utmp[0] &= kmask1;
|
2952
|
+
#else
|
2953
|
+
// TODO
|
2954
|
+
const float d = 0, dmin = 0;
|
2955
|
+
#endif
|
1855
2956
|
|
1856
2957
|
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
|
1857
2958
|
|
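
Both q4_K and q5_K fold in a min correction of the form sum over sub-blocks of min_j times the sum of the q8 values in that sub-block; y[i].bsums already stores the sum of every 16 consecutive q8 values, two entries per 32-wide sub-block, which is why the SSE path in the hunk below pairs adjacent bsums entries with _mm_hadd_epi16 before the _mm_madd_epi16 against mins. A scalar sketch of that correction term with made-up numbers (mins[] and bsums[] here are hypothetical, not values from the diff):

#include <stdio.h>

int main(void) {
    int mins[8]   = { 1, 2, 3, 4, 5, 6, 7, 8 };                              // per-sub-block mins
    int bsums[16] = { 10, -3, 4, 4, 0, 9, -2, -2, 7, 1, 3, 3, -5, 5, 8, 0 }; // 16-wide q8 sums

    int prod = 0;
    for (int j = 0; j < 8; ++j) {
        prod += mins[j] * (bsums[2*j] + bsums[2*j + 1]);   // min_j * q8 sum of sub-block j
    }
    printf("%d\n", prod);   // the kernel scales this by dmin before folding it into the result
    return 0;
}
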
@@ -1876,33 +2977,133 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1876
2977
|
const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
|
1877
2978
|
const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
|
1878
2979
|
|
1879
|
-
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
|
2980
|
+
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
|
2981
|
+
|
2982
|
+
const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
|
2983
|
+
const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
|
2984
|
+
const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
|
2985
|
+
hmask = _mm256_slli_epi16(hmask, 1);
|
2986
|
+
|
2987
|
+
const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
|
2988
|
+
const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
|
2989
|
+
const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
|
2990
|
+
hmask = _mm256_slli_epi16(hmask, 1);
|
2991
|
+
|
2992
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
2993
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
2994
|
+
|
2995
|
+
__m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
|
2996
|
+
__m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
|
2997
|
+
|
2998
|
+
p16_0 = _mm256_madd_epi16(scale_0, p16_0);
|
2999
|
+
p16_1 = _mm256_madd_epi16(scale_1, p16_1);
|
3000
|
+
|
3001
|
+
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
|
3002
|
+
|
3003
|
+
}
|
3004
|
+
|
3005
|
+
__m256 vd = _mm256_set1_ps(d);
|
3006
|
+
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
|
3007
|
+
|
3008
|
+
}
|
3009
|
+
|
3010
|
+
*s = hsum_float_8(acc) + summs;
|
3011
|
+
|
3012
|
+
#elif defined __AVX__
|
3013
|
+
|
3014
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
3015
|
+
const __m128i mzero = _mm_setzero_si128();
|
3016
|
+
const __m128i mone = _mm_set1_epi8(1);
|
3017
|
+
const __m128i m2 = _mm_set1_epi8(2);
|
3018
|
+
|
3019
|
+
__m256 acc = _mm256_setzero_ps();
|
3020
|
+
|
3021
|
+
float summs = 0.f;
|
3022
|
+
|
3023
|
+
for (int i = 0; i < nb; ++i) {
|
3024
|
+
|
3025
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
3026
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
3027
|
+
|
3028
|
+
const uint8_t * restrict q5 = x[i].qs;
|
3029
|
+
const int8_t * restrict q8 = y[i].qs;
|
3030
|
+
|
3031
|
+
memcpy(utmp, x[i].scales, 12);
|
3032
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
3033
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
3034
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
3035
|
+
utmp[2] = uaux;
|
3036
|
+
utmp[0] &= kmask1;
|
3037
|
+
|
3038
|
+
const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
|
3039
|
+
const __m128i scales = _mm_cvtepu8_epi16(utmps);
|
3040
|
+
const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
|
1880
3041
|
|
1881
|
-
|
1882
|
-
|
1883
|
-
|
1884
|
-
|
3042
|
+
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
|
3043
|
+
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
|
3044
|
+
const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
|
3045
|
+
const __m128i prod = _mm_madd_epi16(mins, q8s);
|
3046
|
+
const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
|
3047
|
+
summs += dmin * _mm_extract_epi32(hsum, 0);
|
1885
3048
|
|
1886
|
-
|
1887
|
-
|
1888
|
-
|
1889
|
-
hmask = _mm256_slli_epi16(hmask, 1);
|
3049
|
+
const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
|
3050
|
+
const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
|
3051
|
+
__m128i hmask = mone;
|
1890
3052
|
|
1891
|
-
|
1892
|
-
|
3053
|
+
__m128i sumi_0 = _mm_setzero_si128();
|
3054
|
+
__m128i sumi_1 = _mm_setzero_si128();
|
1893
3055
|
|
1894
|
-
|
1895
|
-
__m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
|
3056
|
+
int bit = 0;
|
1896
3057
|
|
1897
|
-
|
1898
|
-
|
3058
|
+
__m128i shuffle = _mm_set1_epi16(0x0100);
|
3059
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
1899
3060
|
|
1900
|
-
|
3061
|
+
const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
|
3062
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
3063
|
+
const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
|
3064
|
+
shuffle = _mm_add_epi16(shuffle, m2);
|
3065
|
+
|
3066
|
+
const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
|
3067
|
+
const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
|
3068
|
+
|
3069
|
+
__m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
|
3070
|
+
__m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
|
3071
|
+
__m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
|
3072
|
+
__m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
|
3073
|
+
__m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
|
3074
|
+
__m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
|
3075
|
+
hmask = _mm_slli_epi16(hmask, 1);
|
3076
|
+
|
3077
|
+
__m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3078
|
+
__m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3079
|
+
__m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
|
3080
|
+
__m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
|
3081
|
+
p16_0 = _mm_madd_epi16(scale_0, p16_0);
|
3082
|
+
p16_1 = _mm_madd_epi16(scale_0, p16_1);
|
3083
|
+
|
3084
|
+
q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
|
3085
|
+
q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
|
3086
|
+
q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
|
3087
|
+
q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
|
3088
|
+
q5_0 = _mm_add_epi8(q5l_0, q5h_0);
|
3089
|
+
q5_1 = _mm_add_epi8(q5l_1, q5h_1);
|
3090
|
+
hmask = _mm_slli_epi16(hmask, 1);
|
3091
|
+
|
3092
|
+
q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3093
|
+
q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3094
|
+
__m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
|
3095
|
+
__m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
|
3096
|
+
p16_2 = _mm_madd_epi16(scale_1, p16_2);
|
3097
|
+
p16_3 = _mm_madd_epi16(scale_1, p16_3);
|
3098
|
+
|
3099
|
+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
|
3100
|
+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
|
1901
3101
|
|
1902
3102
|
}
|
1903
3103
|
|
1904
3104
|
__m256 vd = _mm256_set1_ps(d);
|
1905
|
-
|
3105
|
+
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
3106
|
+
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
|
1906
3107
|
|
1907
3108
|
}
|
1908
3109
|
|
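
q5_K stores each quant as a low nibble in qs plus one extra high bit in qh; the SSE loop above isolates that bit with hmask, shifts it down by bit, then shifts it up into bit 4 before combining it with the nibble (_mm_add_epi8 and a bitwise OR are interchangeable here because the bit positions never overlap). A minimal scalar sketch of one such reconstruction, on made-up bytes:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint8_t qs_byte = 0x7A;   // hypothetical packed low nibbles (this value uses bits 0-3)
    const uint8_t qh_byte = 0x05;   // hypothetical high-bit byte
    const int bit = 0;              // which high bit belongs to this value

    const uint8_t low  = qs_byte & 0xF;                 // 4 low bits of the quant
    const uint8_t high = ((qh_byte >> bit) & 1) << 4;   // 5th bit moved into position 4
    const uint8_t q5   = low | high;                    // full 5-bit quant in [0, 31]

    printf("q5 = %d\n", q5);        // 0xA | 0x10 = 26
    return 0;
}
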
@@ -1972,8 +3173,169 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1972
3173
|
#endif
|
1973
3174
|
}
|
1974
3175
|
|
3176
|
+
#else
|
3177
|
+
|
3178
|
+
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3179
|
+
assert(n % QK_K == 0);
|
3180
|
+
|
3181
|
+
const block_q5_K * restrict x = vx;
|
3182
|
+
const block_q8_K * restrict y = vy;
|
3183
|
+
|
3184
|
+
const int nb = n / QK_K;
|
3185
|
+
|
3186
|
+
#ifdef __ARM_NEON
|
3187
|
+
|
3188
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
3189
|
+
const int32x4_t mzero = vdupq_n_s32(0);
|
3190
|
+
const uint8x16_t mh = vdupq_n_u8(16);
|
3191
|
+
|
3192
|
+
int8x16x4_t q5bytes;
|
3193
|
+
uint8x16x4_t q5h;
|
3194
|
+
|
3195
|
+
float sumf = 0;
|
3196
|
+
|
3197
|
+
for (int i = 0; i < nb; ++i) {
|
3198
|
+
|
3199
|
+
const float d = y[i].d * (float)x[i].d;
|
3200
|
+
const int8_t * sc = x[i].scales;
|
3201
|
+
|
3202
|
+
const uint8_t * restrict q5 = x[i].qs;
|
3203
|
+
const uint8_t * restrict qh = x[i].qh;
|
3204
|
+
const int8_t * restrict q8 = y[i].qs;
|
3205
|
+
|
3206
|
+
const uint8x8_t qhbits = vld1_u8(qh);
|
3207
|
+
|
3208
|
+
const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
|
3209
|
+
const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
|
3210
|
+
|
3211
|
+
const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
|
3212
|
+
q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
|
3213
|
+
q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
|
3214
|
+
q5h.val[2] = vbicq_u8(mh, htmp);
|
3215
|
+
q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
|
3216
|
+
|
3217
|
+
q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
|
3218
|
+
q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
|
3219
|
+
q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
|
3220
|
+
q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
|
3221
|
+
|
3222
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3223
|
+
|
3224
|
+
int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
|
3225
|
+
int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
|
3226
|
+
int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
|
3227
|
+
int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
|
3228
|
+
|
3229
|
+
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
|
3230
|
+
|
3231
|
+
#else
|
3232
|
+
|
3233
|
+
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
3234
|
+
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
3235
|
+
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
3236
|
+
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
3237
|
+
int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
|
3238
|
+
|
3239
|
+
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
3240
|
+
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
3241
|
+
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
3242
|
+
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
3243
|
+
sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
|
3244
|
+
|
3245
|
+
sumf += d*sumi;
|
3246
|
+
#endif
|
3247
|
+
|
3248
|
+
}
|
3249
|
+
|
3250
|
+
*s = sumf;
|
3251
|
+
|
3252
|
+
#elif defined __AVX2__
|
3253
|
+
|
3254
|
+
const __m256i m4 = _mm256_set1_epi8(0xF);
|
3255
|
+
const __m256i mone = _mm256_set1_epi8(1);
|
3256
|
+
|
3257
|
+
__m256 acc = _mm256_setzero_ps();
|
3258
|
+
|
3259
|
+
for (int i = 0; i < nb; ++i) {
|
3260
|
+
|
3261
|
+
const uint8_t * restrict q5 = x[i].qs;
|
3262
|
+
const int8_t * restrict q8 = y[i].qs;
|
3263
|
+
|
3264
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
3265
|
+
|
3266
|
+
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
|
3267
|
+
|
3268
|
+
const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
|
3269
|
+
const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
|
3270
|
+
|
3271
|
+
int64_t aux64;
|
3272
|
+
memcpy(&aux64, x[i].qh, 8);
|
3273
|
+
const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
|
3274
|
+
const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
|
3275
|
+
|
3276
|
+
const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
|
3277
|
+
const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
|
3278
|
+
|
3279
|
+
const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
|
3280
|
+
const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
|
3281
|
+
|
3282
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
3283
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
3284
|
+
|
3285
|
+
const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
|
3286
|
+
const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
|
3287
|
+
const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
|
3288
|
+
const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
|
3289
|
+
|
3290
|
+
const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
|
3291
|
+
|
3292
|
+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
|
3293
|
+
|
3294
|
+
}
|
3295
|
+
|
3296
|
+
*s = hsum_float_8(acc);
|
3297
|
+
|
3298
|
+
#else
|
3299
|
+
|
3300
|
+
|
3301
|
+
uint8_t aux8[QK_K];
|
3302
|
+
int16_t aux16[16];
|
3303
|
+
float sums [8];
|
3304
|
+
memset(sums, 0, 8*sizeof(float));
|
3305
|
+
|
3306
|
+
float sumf = 0;
|
3307
|
+
for (int i = 0; i < nb; ++i) {
|
3308
|
+
const uint8_t * restrict q4 = x[i].qs;
|
3309
|
+
const uint8_t * restrict hm = x[i].qh;
|
3310
|
+
const int8_t * restrict q8 = y[i].qs;
|
3311
|
+
uint8_t * restrict a = aux8;
|
3312
|
+
for (int l = 0; l < 32; ++l) {
|
3313
|
+
a[l+ 0] = q4[l] & 0xF;
|
3314
|
+
a[l+32] = q4[l] >> 4;
|
3315
|
+
}
|
3316
|
+
for (int is = 0; is < 8; ++is) {
|
3317
|
+
uint8_t m = 1 << is;
|
3318
|
+
for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
|
3319
|
+
}
|
3320
|
+
|
3321
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
3322
|
+
const int8_t * restrict sc = x[i].scales;
|
3323
|
+
|
3324
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
3325
|
+
const float dl = d * sc[j];
|
3326
|
+
for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
|
3327
|
+
for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
|
3328
|
+
q8 += 16; a += 16;
|
3329
|
+
}
|
3330
|
+
}
|
3331
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
3332
|
+
*s = sumf;
|
3333
|
+
#endif
|
3334
|
+
}
|
3335
|
+
#endif
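
In the QK_K == 64 q5_K kernel above the high bit acts as an offset flag rather than a fifth magnitude bit: the NEON path builds a 16-where-clear mask with vbicq_u8 and subtracts it, the AVX2 path subtracts the s16 products, and the scalar fallback does a[8*is + l] -= (hm[l] & m ? 0 : 16). A minimal scalar sketch of that decode with a made-up nibble:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint8_t nibble = 0x9;                  // hypothetical low 4 bits from x[i].qs
    for (int hbit = 0; hbit <= 1; ++hbit) {
        const int8_t v = (int8_t)nibble - (hbit ? 0 : 16);   // subtract 16 when the flag is clear
        printf("hbit=%d -> %d\n", hbit, v);      // prints -7, then 9
    }
    return 0;
}
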
|
1975
3336
|
|
1976
3337
|
|
3338
|
+
#if QK_K == 256
|
1977
3339
|
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
1978
3340
|
assert(n % QK_K == 0);
|
1979
3341
|
|
@@ -2198,6 +3560,124 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2198
3560
|
|
2199
3561
|
*s = hsum_float_8(acc);
|
2200
3562
|
|
3563
|
+
#elif defined __AVX__
|
3564
|
+
|
3565
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
3566
|
+
const __m128i m3 = _mm_set1_epi8(3);
|
3567
|
+
const __m128i m32s = _mm_set1_epi8(32);
|
3568
|
+
const __m128i m2 = _mm_set1_epi8(2);
|
3569
|
+
|
3570
|
+
__m256 acc = _mm256_setzero_ps();
|
3571
|
+
|
3572
|
+
for (int i = 0; i < nb; ++i) {
|
3573
|
+
|
3574
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
3575
|
+
|
3576
|
+
const uint8_t * restrict q4 = x[i].ql;
|
3577
|
+
const uint8_t * restrict qh = x[i].qh;
|
3578
|
+
const int8_t * restrict q8 = y[i].qs;
|
3579
|
+
|
3580
|
+
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
|
3581
|
+
|
3582
|
+
__m128i sumi_0 = _mm_setzero_si128();
|
3583
|
+
__m128i sumi_1 = _mm_setzero_si128();
|
3584
|
+
|
3585
|
+
__m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
|
3586
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
3587
|
+
|
3588
|
+
const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
|
3589
|
+
const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
|
3590
|
+
|
3591
|
+
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
|
3592
|
+
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
|
3593
|
+
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
|
3594
|
+
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
|
3595
|
+
const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
|
3596
|
+
const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
|
3597
|
+
const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
|
3598
|
+
const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
|
3599
|
+
|
3600
|
+
const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
3601
|
+
const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
3602
|
+
const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
3603
|
+
const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
3604
|
+
|
3605
|
+
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
|
3606
|
+
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
|
3607
|
+
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
|
3608
|
+
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
|
3609
|
+
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
|
3610
|
+
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
|
3611
|
+
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
|
3612
|
+
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
|
3613
|
+
|
3614
|
+
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3615
|
+
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3616
|
+
const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3617
|
+
const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3618
|
+
const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3619
|
+
const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3620
|
+
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3621
|
+
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
3622
|
+
|
3623
|
+
__m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
|
3624
|
+
__m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
|
3625
|
+
__m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
|
3626
|
+
__m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
|
3627
|
+
__m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
|
3628
|
+
__m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
|
3629
|
+
__m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
|
3630
|
+
__m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
|
3631
|
+
|
3632
|
+
__m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
|
3633
|
+
__m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
|
3634
|
+
__m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
|
3635
|
+
__m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
|
3636
|
+
__m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
|
3637
|
+
__m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
|
3638
|
+
__m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
|
3639
|
+
__m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
|
3640
|
+
|
3641
|
+
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
|
3642
|
+
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
|
3643
|
+
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
|
3644
|
+
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
|
3645
|
+
p16_4 = _mm_sub_epi16(p16_4, q8s_4);
|
3646
|
+
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
|
3647
|
+
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
|
3648
|
+
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
|
3649
|
+
|
3650
|
+
const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
|
3651
|
+
shuffle = _mm_add_epi8(shuffle, m2);
|
3652
|
+
const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
|
3653
|
+
shuffle = _mm_add_epi8(shuffle, m2);
|
3654
|
+
const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
|
3655
|
+
shuffle = _mm_add_epi8(shuffle, m2);
|
3656
|
+
const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
|
3657
|
+
shuffle = _mm_add_epi8(shuffle, m2);
|
3658
|
+
|
3659
|
+
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
|
3660
|
+
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
|
3661
|
+
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
|
3662
|
+
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
|
3663
|
+
p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
|
3664
|
+
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
|
3665
|
+
p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
|
3666
|
+
p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
|
3667
|
+
|
3668
|
+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
|
3669
|
+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
|
3670
|
+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
|
3671
|
+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
|
3672
|
+
|
3673
|
+
}
|
3674
|
+
|
3675
|
+
__m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
|
3676
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
|
3677
|
+
}
|
3678
|
+
|
3679
|
+
*s = hsum_float_8(acc);
|
3680
|
+
|
2201
3681
|
#else
|
2202
3682
|
|
2203
3683
|
int8_t aux8[QK_K];
|
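
The __AVX__ q6_K hunk above multiplies the stored quants while they are still biased into 0..63 (so _mm_maddubs_epi16 sees unsigned bytes) and subtracts the _mm_maddubs_epi16(m32s, q8) products afterwards; algebraically that equals multiplying the unbiased value in -32..31 directly. A one-value scalar sketch with made-up numbers:

#include <stdio.h>

int main(void) {
    const int q6_biased = 45;   // hypothetical stored quant; the true value is 45 - 32 = 13
    const int q8        = -7;   // hypothetical activation quant

    const int direct    = (q6_biased - 32) * q8;      // what the dot product needs
    const int two_step  = q6_biased * q8 - 32 * q8;   // what the SIMD path computes
    printf("%d %d\n", direct, two_step);              // both print -91
    return 0;
}
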
@@ -2242,3 +3722,179 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2242
3722
|
*s = sumf;
|
2243
3723
|
#endif
|
2244
3724
|
}
|
3725
|
+
|
3726
|
+
#else
|
3727
|
+
|
3728
|
+
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
3729
|
+
assert(n % QK_K == 0);
|
3730
|
+
|
3731
|
+
const block_q6_K * restrict x = vx;
|
3732
|
+
const block_q8_K * restrict y = vy;
|
3733
|
+
|
3734
|
+
const int nb = n / QK_K;
|
3735
|
+
|
3736
|
+
#ifdef __ARM_NEON
|
3737
|
+
|
3738
|
+
float sum = 0;
|
3739
|
+
|
3740
|
+
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
3741
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
3742
|
+
const int8x16_t m32s = vdupq_n_s8(32);
|
3743
|
+
|
3744
|
+
const uint8x16_t mone = vdupq_n_u8(3);
|
3745
|
+
|
3746
|
+
int8x16x4_t q6bytes;
|
3747
|
+
uint8x16x4_t q6h;
|
3748
|
+
|
3749
|
+
for (int i = 0; i < nb; ++i) {
|
3750
|
+
|
3751
|
+
const float d_all = (float)x[i].d;
|
3752
|
+
|
3753
|
+
const uint8_t * restrict q6 = x[i].ql;
|
3754
|
+
const uint8_t * restrict qh = x[i].qh;
|
3755
|
+
const int8_t * restrict q8 = y[i].qs;
|
3756
|
+
|
3757
|
+
const int8_t * restrict scale = x[i].scales;
|
3758
|
+
|
3759
|
+
int32_t isum = 0;
|
3760
|
+
|
3761
|
+
uint8x16_t qhbits = vld1q_u8(qh);
|
3762
|
+
uint8x16x2_t q6bits = vld1q_u8_x2(q6);
|
3763
|
+
int8x16x4_t q8bytes = vld1q_s8_x4(q8);
|
3764
|
+
|
3765
|
+
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
|
3766
|
+
uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
|
3767
|
+
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
3768
|
+
shifted = vshrq_n_u8(qhbits, 4);
|
3769
|
+
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
3770
|
+
shifted = vshrq_n_u8(qhbits, 6);
|
3771
|
+
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
3772
|
+
|
3773
|
+
q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
|
3774
|
+
q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
|
3775
|
+
q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
|
3776
|
+
q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
|
3777
|
+
|
3778
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
3779
|
+
|
3780
|
+
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
3781
|
+
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
3782
|
+
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
3783
|
+
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
3784
|
+
#else
|
3785
|
+
|
3786
|
+
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
3787
|
+
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
3788
|
+
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
3789
|
+
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
3790
|
+
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
|
3791
|
+
|
3792
|
+
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
3793
|
+
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
3794
|
+
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
3795
|
+
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
3796
|
+
isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
|
3797
|
+
#endif
|
3798
|
+
|
3799
|
+
sum += isum * d_all * y[i].d;
|
3800
|
+
|
3801
|
+
}
|
3802
|
+
*s = sum;
|
3803
|
+
|
3804
|
+
#elif defined __AVX2__
|
3805
|
+
|
3806
|
+
const __m256i m4 = _mm256_set1_epi8(0xF);
|
3807
|
+
const __m256i m2 = _mm256_set1_epi8(3);
|
3808
|
+
const __m256i m32s = _mm256_set1_epi8(32);
|
3809
|
+
|
3810
|
+
__m256 acc = _mm256_setzero_ps();
|
3811
|
+
|
3812
|
+
for (int i = 0; i < nb; ++i) {
|
3813
|
+
|
3814
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
3815
|
+
|
3816
|
+
const uint8_t * restrict q4 = x[i].ql;
|
3817
|
+
const uint8_t * restrict qh = x[i].qh;
|
3818
|
+
const int8_t * restrict q8 = y[i].qs;
|
3819
|
+
|
3820
|
+
const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
|
3821
|
+
const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
|
3822
|
+
const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
|
3823
|
+
const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
|
3824
|
+
|
3825
|
+
__m256i sumi = _mm256_setzero_si256();
|
3826
|
+
|
3827
|
+
const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
|
3828
|
+
const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
|
3829
|
+
|
3830
|
+
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
|
3831
|
+
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
|
3832
|
+
|
3833
|
+
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
|
3834
|
+
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
|
3835
|
+
|
3836
|
+
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
|
3837
|
+
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
|
3838
|
+
|
3839
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
3840
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
3841
|
+
|
3842
|
+
__m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
|
3843
|
+
__m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
|
3844
|
+
|
3845
|
+
__m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
|
3846
|
+
__m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
|
3847
|
+
|
3848
|
+
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
|
3849
|
+
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
|
3850
|
+
|
3851
|
+
p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
|
3852
|
+
p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
|
3853
|
+
|
3854
|
+
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
|
3855
|
+
|
3856
|
+
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
|
3857
|
+
}
|
3858
|
+
|
3859
|
+
*s = hsum_float_8(acc);
|
3860
|
+
|
3861
|
+
#else
|
3862
|
+
|
3863
|
+
int8_t aux8[QK_K];
|
3864
|
+
int16_t aux16[8];
|
3865
|
+
float sums [8];
|
3866
|
+
int32_t aux32[8];
|
3867
|
+
memset(sums, 0, 8*sizeof(float));
|
3868
|
+
|
3869
|
+
float sumf = 0;
|
3870
|
+
for (int i = 0; i < nb; ++i) {
|
3871
|
+
const uint8_t * restrict q4 = x[i].ql;
|
3872
|
+
const uint8_t * restrict qh = x[i].qh;
|
3873
|
+
const int8_t * restrict q8 = y[i].qs;
|
3874
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
3875
|
+
int8_t * restrict a = aux8;
|
3876
|
+
for (int l = 0; l < 16; ++l) {
|
3877
|
+
a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
3878
|
+
a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
3879
|
+
a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
3880
|
+
a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
3881
|
+
}
|
3882
|
+
int is = 0;
|
3883
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
3884
|
+
int scale = x[i].scales[is++];
|
3885
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
3886
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
3887
|
+
q8 += 8; a += 8;
|
3888
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
3889
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
3890
|
+
q8 += 8; a += 8;
|
3891
|
+
}
|
3892
|
+
const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
|
3893
|
+
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
3894
|
+
}
|
3895
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
3896
|
+
*s = sumf;
|
3897
|
+
#endif
|
3898
|
+
}
|
3899
|
+
|
3900
|
+
#endif
|
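
The QK_K == 64 q6_K scalar fallback above rebuilds every quant from a low nibble in ql and two high bits in qh, then removes the +32 bias. A minimal sketch of that assembly for a single value, with made-up input bytes:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint8_t ql = 0x2B;    // hypothetical low-nibble byte (this value uses bits 0-3)
    const uint8_t qh = 0xD2;    // hypothetical high-bits byte (this value uses bits 0-1)

    const int8_t v = (int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32;
    printf("%d\n", v);          // (0xB | (2 << 4)) - 32 = 43 - 32 = 11
    return 0;
}
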