llama_cpp 0.2.2 → 0.3.0

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -261,6 +261,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
261
261
  return scale;
262
262
  }
263
263
 
264
+ #if QK_K == 256
264
265
  static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
265
266
  if (j < 4) {
266
267
  *d = q[j] & 63; *m = q[j + 4] & 63;
@@ -269,6 +270,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
269
270
  *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
270
271
  }
271
272
  }
273
+ #endif
272
274
 
273
275
  //========================- 2-bit (de)-quantization
274
276
 
@@ -330,11 +332,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
330
332
  }
331
333
  }
332
334
 
335
+ #if QK_K == 256
333
336
  for (int j = 0; j < QK_K; j += 128) {
334
337
  for (int l = 0; l < 32; ++l) {
335
338
  y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
336
339
  }
337
340
  }
341
+ #else
342
+ for (int l = 0; l < 16; ++l) {
343
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
344
+ }
345
+ #endif
338
346
 
339
347
  x += QK_K;
340
348
 
@@ -352,6 +360,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
352
360
 
353
361
  const uint8_t * q = x[i].qs;
354
362
 
363
+ #if QK_K == 256
355
364
  int is = 0;
356
365
  float dl, ml;
357
366
  for (int n = 0; n < QK_K; n += 128) {
@@ -370,7 +379,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
370
379
  }
371
380
  q += 32;
372
381
  }
373
-
382
+ #else
383
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
384
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
385
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
386
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
387
+ for (int l = 0; l < 16; ++l) {
388
+ y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
389
+ y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
390
+ y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
391
+ y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
392
+ }
393
+ y += QK_K;
394
+ #endif
374
395
  }
375
396
  }
376
397
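For reference, the QK_K == 64 branches above pack four 2-bit quants per byte with a stride of 16, and dequantize_row_q2_K undoes it with the matching shifts. A minimal standalone sketch of that round trip (illustrative only, not part of the package; names mirror the diff):

    #include <assert.h>
    #include <stdint.h>
    #include <stdlib.h>

    int main(void) {
        uint8_t L[64], qs[16];
        for (int l = 0; l < 64; ++l) L[l] = (uint8_t)(rand() & 3);     // 2-bit quants
        // pack: four 2-bit values per byte, strided by 16 as in quantize_row_q2_K_reference
        for (int l = 0; l < 16; ++l)
            qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
        // unpack with the shifts used in dequantize_row_q2_K
        for (int l = 0; l < 16; ++l) {
            assert(((qs[l] >> 0) & 3) == L[l]);
            assert(((qs[l] >> 2) & 3) == L[l + 16]);
            assert(((qs[l] >> 4) & 3) == L[l + 32]);
            assert(((qs[l] >> 6) & 3) == L[l + 48]);
        }
        return 0;
    }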
 
@@ -412,6 +433,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
412
433
  }
413
434
  }
414
435
 
436
+ #if QK_K == 256
415
437
  memset(y[i].scales, 0, 12);
416
438
  if (max_scale) {
417
439
  float iscale = -32.f/max_scale;
@@ -445,9 +467,39 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
445
467
  L[16*j + ii] = l + 4;
446
468
  }
447
469
  }
470
+ #else
471
+ if (max_scale) {
472
+ float iscale = -8.f/max_scale;
473
+ for (int j = 0; j < QK_K/16; j+=2) {
474
+ int l1 = nearest_int(iscale*scales[j]);
475
+ l1 = 8 + MAX(-8, MIN(7, l1));
476
+ int l2 = nearest_int(iscale*scales[j+1]);
477
+ l2 = 8 + MAX(-8, MIN(7, l2));
478
+ y[i].scales[j/2] = l1 | (l2 << 4);
479
+ }
480
+ y[i].d = ggml_fp32_to_fp16(1/iscale);
481
+ } else {
482
+ for (int j = 0; j < QK_K/16; j+=2) {
483
+ y[i].scales[j/2] = 0;
484
+ }
485
+ y[i].d = ggml_fp32_to_fp16(0.f);
486
+ }
487
+ for (int j = 0; j < QK_K/16; ++j) {
488
+ int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
489
+ float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
490
+ if (!d) {
491
+ continue;
492
+ }
493
+ for (int ii = 0; ii < 16; ++ii) {
494
+ int l = nearest_int(x[16*j + ii]/d);
495
+ l = MAX(-4, MIN(3, l));
496
+ L[16*j + ii] = l + 4;
497
+ }
498
+ }
499
+ #endif
448
500
 
449
501
  memset(y[i].hmask, 0, QK_K/8);
450
- // We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc.
502
+ // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
451
503
  int m = 0;
452
504
  uint8_t hm = 1;
453
505
  for (int j = 0; j < QK_K; ++j) {
@@ -459,19 +511,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
459
511
  m = 0; hm <<= 1;
460
512
  }
461
513
  }
514
+ #if QK_K == 256
462
515
  for (int j = 0; j < QK_K; j += 128) {
463
516
  for (int l = 0; l < 32; ++l) {
464
517
  y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
465
518
  }
466
519
  }
520
+ #else
521
+ for (int l = 0; l < 16; ++l) {
522
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
523
+ }
524
+ #endif
467
525
 
468
526
  x += QK_K;
469
527
  }
470
528
  }
471
529
 
530
+ #if QK_K == 256
472
531
  void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
473
532
  assert(k % QK_K == 0);
474
- assert(QK_K == 256);
475
533
  const int nb = k / QK_K;
476
534
 
477
535
  const uint32_t kmask1 = 0x03030303;
@@ -519,6 +577,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
519
577
 
520
578
  }
521
579
  }
580
+ #else
581
+ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
582
+ assert(k % QK_K == 0);
583
+ assert(QK_K == 64);
584
+ const int nb = k / QK_K;
585
+
586
+ for (int i = 0; i < nb; i++) {
587
+
588
+ const float d_all = ggml_fp16_to_fp32(x[i].d);
589
+
590
+ const uint8_t * restrict q = x[i].qs;
591
+ const uint8_t * restrict hm = x[i].hmask;
592
+
593
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
594
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
595
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
596
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
597
+
598
+ for (int l=0; l<8; ++l) {
599
+ uint8_t h = hm[l];
600
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
601
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
602
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
603
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
604
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
605
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
606
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
607
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
608
+ }
609
+ y += QK_K;
610
+ }
611
+ }
612
+ #endif
522
613
 
523
614
  void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
524
615
  quantize_row_q3_K_reference(x, vy, k);
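The QK_K == 64 q3_K branches above store one signed 4-bit scale per 16 quants (two per byte, with a +8 bias) and keep the third bit of each quant in hmask; dequantize_row_q3_K subtracts 4 exactly when that bit is clear. A standalone decoding sketch (illustrative only, not part of the package; helper names are made up):

    #include <assert.h>
    #include <stdint.h>

    // two 4-bit scales per byte, stored with a +8 bias
    static int unpack_scale(const uint8_t *scales, int j) {
        uint8_t s = (j % 2 == 0) ? (uint8_t)(scales[j/2] & 0xF) : (uint8_t)(scales[j/2] >> 4);
        return (int)s - 8;
    }

    // one quant: two low bits from qs, the missing -4 restored from the hmask bit
    static int unpack_quant(uint8_t low2, int high_bit_set) {
        return (int)(low2 & 3) - (high_bit_set ? 0 : 4);
    }

    int main(void) {
        uint8_t scales[2] = { (uint8_t)((3 + 8) | ((-5 + 8) << 4)), 0 };
        assert(unpack_scale(scales, 0) ==  3);
        assert(unpack_scale(scales, 1) == -5);
        assert(unpack_quant(2, 0) == -2);   // high bit clear: 2 - 4
        assert(unpack_quant(2, 1) ==  2);   // high bit set:   value kept
        return 0;
    }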
@@ -563,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
563
654
  }
564
655
  }
565
656
 
657
+ #if QK_K == 256
566
658
  float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
567
659
  float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
568
660
  for (int j = 0; j < QK_K/32; ++j) {
@@ -594,9 +686,43 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
594
686
  L[32*j + ii] = l;
595
687
  }
596
688
  }
689
+ #else
690
+ const float s_factor = 15.f;
691
+ float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
692
+ float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
693
+ int d1 = nearest_int(inv_scale*scales[0]);
694
+ int m1 = nearest_int(inv_min*mins[0]);
695
+ int d2 = nearest_int(inv_scale*scales[1]);
696
+ int m2 = nearest_int(inv_min*mins[1]);
697
+ y[i].scales[0] = d1 | (m1 << 4);
698
+ y[i].scales[1] = d2 | (m2 << 4);
699
+ y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
700
+ y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
701
+
702
+ float sumlx = 0;
703
+ int suml2 = 0;
704
+ for (int j = 0; j < QK_K/32; ++j) {
705
+ const uint8_t sd = y[i].scales[j] & 0xF;
706
+ const uint8_t sm = y[i].scales[j] >> 4;
707
+ const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
708
+ if (!d) continue;
709
+ const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
710
+ for (int ii = 0; ii < 32; ++ii) {
711
+ int l = nearest_int((x[32*j + ii] + m)/d);
712
+ l = MAX(0, MIN(15, l));
713
+ L[32*j + ii] = l;
714
+ sumlx += (x[32*j + ii] + m)*l*sd;
715
+ suml2 += l*l*sd*sd;
716
+ }
717
+ }
718
+ if (suml2) {
719
+ y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
720
+ }
721
+ #endif
597
722
  uint8_t * q = y[i].qs;
598
723
  for (int j = 0; j < QK_K; j += 64) {
599
- for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4);
724
+ for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
725
+ q += 32;
600
726
  }
601
727
 
602
728
  x += QK_K;
@@ -610,11 +736,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
610
736
 
611
737
  for (int i = 0; i < nb; i++) {
612
738
 
613
- const float d = ggml_fp16_to_fp32(x[i].d);
614
- const float min = ggml_fp16_to_fp32(x[i].dmin);
615
-
616
739
  const uint8_t * q = x[i].qs;
617
740
 
741
+ #if QK_K == 256
742
+
743
+ const float d = ggml_fp16_to_fp32(x[i].d);
744
+ const float min = ggml_fp16_to_fp32(x[i].dmin);
745
+
618
746
  int is = 0;
619
747
  uint8_t sc, m;
620
748
  for (int j = 0; j < QK_K; j += 64) {
@@ -626,6 +754,17 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
626
754
  for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
627
755
  q += 32; is += 2;
628
756
  }
757
+ #else
758
+ const float dall = ggml_fp16_to_fp32(x[i].d[0]);
759
+ const float mall = ggml_fp16_to_fp32(x[i].d[1]);
760
+ const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
761
+ const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
762
+ for (int l = 0; l < 32; ++l) {
763
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
764
+ y[l+32] = d2 * (q[l] >> 4) - m2;
765
+ }
766
+ y += QK_K;
767
+ #endif
629
768
 
630
769
  }
631
770
  }
@@ -653,12 +792,19 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
653
792
  assert(k % QK_K == 0);
654
793
  const int nb = k / QK_K;
655
794
 
795
+ #if QK_K == 256
656
796
  uint8_t L[QK_K];
657
797
  float mins[QK_K/32];
658
798
  float scales[QK_K/32];
799
+ #else
800
+ int8_t L[QK_K];
801
+ float scales[QK_K/16];
802
+ #endif
659
803
 
660
804
  for (int i = 0; i < nb; i++) {
661
805
 
806
+ #if QK_K == 256
807
+
662
808
  float max_scale = 0; // as we are deducting the min, scales are always positive
663
809
  float max_min = 0;
664
810
  for (int j = 0; j < QK_K/32; ++j) {
@@ -725,6 +871,52 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
725
871
  m1 <<= 2; m2 <<= 2;
726
872
  ql += 32;
727
873
  }
874
+ #else
875
+ float max_scale = 0, amax = 0;
876
+ for (int j = 0; j < QK_K/16; ++j) {
877
+ scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
878
+ float abs_scale = fabsf(scales[j]);
879
+ if (abs_scale > amax) {
880
+ amax = abs_scale;
881
+ max_scale = scales[j];
882
+ }
883
+ }
884
+
885
+ float iscale = -128.f/max_scale;
886
+ for (int j = 0; j < QK_K/16; ++j) {
887
+ int l = nearest_int(iscale*scales[j]);
888
+ y[i].scales[j] = MAX(-128, MIN(127, l));
889
+ }
890
+ y[i].d = ggml_fp32_to_fp16(1/iscale);
891
+
892
+ for (int j = 0; j < QK_K/16; ++j) {
893
+ const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
894
+ if (!d) continue;
895
+ for (int ii = 0; ii < 16; ++ii) {
896
+ int l = nearest_int(x[16*j + ii]/d);
897
+ l = MAX(-16, MIN(15, l));
898
+ L[16*j + ii] = l + 16;
899
+ }
900
+ }
901
+
902
+ uint8_t * restrict qh = y[i].qh;
903
+ uint8_t * restrict ql = y[i].qs;
904
+ memset(qh, 0, QK_K/8);
905
+
906
+ for (int j = 0; j < 32; ++j) {
907
+ int jm = j%8;
908
+ int is = j/8;
909
+ int l1 = L[j];
910
+ if (l1 > 15) {
911
+ l1 -= 16; qh[jm] |= (1 << is);
912
+ }
913
+ int l2 = L[j + 32];
914
+ if (l2 > 15) {
915
+ l2 -= 16; qh[jm] |= (1 << (4 + is));
916
+ }
917
+ ql[j] = l1 | (l2 << 4);
918
+ }
919
+ #endif
728
920
 
729
921
  x += QK_K;
730
922
 
@@ -737,12 +929,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
737
929
 
738
930
  for (int i = 0; i < nb; i++) {
739
931
 
740
- const float d = ggml_fp16_to_fp32(x[i].d);
741
- const float min = ggml_fp16_to_fp32(x[i].dmin);
742
-
743
932
  const uint8_t * ql = x[i].qs;
744
933
  const uint8_t * qh = x[i].qh;
745
934
 
935
+ #if QK_K == 256
936
+
937
+ const float d = ggml_fp16_to_fp32(x[i].d);
938
+ const float min = ggml_fp16_to_fp32(x[i].dmin);
939
+
746
940
  int is = 0;
747
941
  uint8_t sc, m;
748
942
  uint8_t u1 = 1, u2 = 2;
@@ -756,6 +950,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
756
950
  ql += 32; is += 2;
757
951
  u1 <<= 2; u2 <<= 2;
758
952
  }
953
+ #else
954
+ float d = ggml_fp16_to_fp32(x[i].d);
955
+ const int8_t * restrict s = x[i].scales;
956
+ for (int l = 0; l < 8; ++l) {
957
+ y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
958
+ y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
959
+ y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
960
+ y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
961
+ y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
962
+ y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
963
+ y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
964
+ y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
965
+ }
966
+ y += QK_K;
967
+ #endif
759
968
  }
760
969
  }
761
970
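The QK_K == 64 q5_K branch above splits each quant L in [0, 31] into a 4-bit nibble in ql and a fifth bit in qh; dequantize_row_q5_K subtracts 16 exactly when that bit is clear, recovering the signed value L - 16. A standalone round-trip check (illustrative only, not part of the package):

    #include <assert.h>

    int main(void) {
        for (int L = 0; L < 32; ++L) {
            int nibble = L, bit = 0;
            if (nibble > 15) { nibble -= 16; bit = 1; }   // encode, as in quantize_row_q5_K_reference
            int value = nibble - (bit ? 0 : 16);          // decode, as in dequantize_row_q5_K
            assert(value == L - 16);                      // signed quant in [-16, 15]
        }
        return 0;
    }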
 
@@ -823,6 +1032,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
823
1032
 
824
1033
  uint8_t * restrict ql = y[i].ql;
825
1034
  uint8_t * restrict qh = y[i].qh;
1035
+ #if QK_K == 256
826
1036
  for (int j = 0; j < QK_K; j += 128) {
827
1037
  for (int l = 0; l < 32; ++l) {
828
1038
  const uint8_t q1 = L[j + l + 0] & 0xF;
@@ -836,6 +1046,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
836
1046
  ql += 64;
837
1047
  qh += 32;
838
1048
  }
1049
+ #else
1050
+ for (int l = 0; l < 32; ++l) {
1051
+ const uint8_t q1 = L[l + 0] & 0xF;
1052
+ const uint8_t q2 = L[l + 32] & 0xF;
1053
+ ql[l] = q1 | (q2 << 4);
1054
+ }
1055
+ for (int l = 0; l < 16; ++l) {
1056
+ qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
1057
+ }
1058
+ #endif
839
1059
 
840
1060
  x += QK_K;
841
1061
 
@@ -854,6 +1074,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
854
1074
  const uint8_t * restrict qh = x[i].qh;
855
1075
  const int8_t * restrict sc = x[i].scales;
856
1076
 
1077
+ #if QK_K == 256
857
1078
  for (int n = 0; n < QK_K; n += 128) {
858
1079
  for (int l = 0; l < 32; ++l) {
859
1080
  int is = l/16;
@@ -871,6 +1092,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
871
1092
  qh += 32;
872
1093
  sc += 8;
873
1094
  }
1095
+ #else
1096
+ for (int l = 0; l < 16; ++l) {
1097
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1098
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1099
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1100
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1101
+ y[l+ 0] = d * sc[0] * q1;
1102
+ y[l+16] = d * sc[1] * q2;
1103
+ y[l+32] = d * sc[2] * q3;
1104
+ y[l+48] = d * sc[3] * q4;
1105
+ }
1106
+ y += 64;
1107
+ #endif
874
1108
 
875
1109
  }
876
1110
  }
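The QK_K == 64 q6_K branch above keeps the low 4 bits of each 6-bit quant in ql and its top 2 bits in qh (four per qh byte); dequantize_row_q6_K reassembles the value and recenters it by subtracting 32. A standalone round-trip check (illustrative only, not part of the package):

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        for (int L = 0; L < 64; ++L) {
            uint8_t low4  = (uint8_t)(L & 0xF);   // stored in ql
            uint8_t high2 = (uint8_t)(L >> 4);    // stored in qh, 2 bits per quant
            int q = (int)(low4 | (high2 << 4)) - 32;
            assert(q == L - 32);                  // signed 6-bit quant in [-32, 31]
        }
        return 0;
    }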
@@ -1002,6 +1236,7 @@ static inline __m128i get_scale_shuffle(int i) {
1002
1236
  }
1003
1237
  #endif
1004
1238
 
1239
+ #if QK_K == 256
1005
1240
  void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1006
1241
 
1007
1242
  const block_q2_K * restrict x = vx;
@@ -1158,6 +1393,112 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1158
1393
 
1159
1394
  *s = hsum_float_8(acc);
1160
1395
 
1396
+ #elif defined __AVX__
1397
+
1398
+ const __m128i m3 = _mm_set1_epi8(0x3);
1399
+ const __m128i m4 = _mm_set1_epi8(0xF);
1400
+ const __m128i m2 = _mm_set1_epi8(0x2);
1401
+
1402
+ __m256 acc = _mm256_setzero_ps();
1403
+
1404
+ for (int i = 0; i < nb; ++i) {
1405
+
1406
+ const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
1407
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1408
+
1409
+ const uint8_t * restrict q2 = x[i].qs;
1410
+ const int8_t * restrict q8 = y[i].qs;
1411
+
1412
+ // load mins and scales from block_q2_K.scales[QK_K/16]
1413
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1414
+ const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
1415
+ const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1416
+ const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
1417
+ const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
1418
+
1419
+ // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
1420
+ const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
1421
+ const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
1422
+
1423
+ // sumf += -dmin * summs in 32bits*8
1424
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
1425
+
1426
+ const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
1427
+ const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
1428
+ const __m128i scales[2] = { scales_0, scales_1 };
1429
+
1430
+ __m128i sumi_0 = _mm_setzero_si128();
1431
+ __m128i sumi_1 = _mm_setzero_si128();
1432
+
1433
+ for (int j = 0; j < QK_K/128; ++j) {
1434
+
1435
+ // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
1436
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1437
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1438
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1439
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1440
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1441
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1442
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1443
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1444
+
1445
+ // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
1446
+ __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1447
+ const __m128i q2_0 = _mm_and_si128(q2bits, m3);
1448
+ const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1449
+ const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1450
+ const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1451
+ q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1452
+ const __m128i q2_1 = _mm_and_si128(q2bits, m3);
1453
+ const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1454
+ const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1455
+ const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1456
+
1457
+ // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
1458
+ __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
1459
+ __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
1460
+ __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
1461
+ __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
1462
+ __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
1463
+ __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
1464
+ __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
1465
+ __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
1466
+
1467
+ // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
1468
+ __m128i shuffle = _mm_set1_epi16(0x0100);
1469
+ p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
1470
+ shuffle = _mm_add_epi16(shuffle, m2);
1471
+ p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
1472
+ shuffle = _mm_add_epi16(shuffle, m2);
1473
+ p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
1474
+ shuffle = _mm_add_epi16(shuffle, m2);
1475
+ p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
1476
+ shuffle = _mm_add_epi16(shuffle, m2);
1477
+ p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
1478
+ shuffle = _mm_add_epi16(shuffle, m2);
1479
+ p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
1480
+ shuffle = _mm_add_epi16(shuffle, m2);
1481
+ p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
1482
+ shuffle = _mm_add_epi16(shuffle, m2);
1483
+ p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
1484
+
1485
+ p0 = _mm_add_epi32(p0, p1);
1486
+ p2 = _mm_add_epi32(p2, p3);
1487
+ p4 = _mm_add_epi32(p4, p5);
1488
+ p6 = _mm_add_epi32(p6, p7);
1489
+
1490
+ // isum in 32bits*4*2
1491
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
1492
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
1493
+ }
1494
+
1495
+ // sumf += dall * isum - dmin * summs in 32bits
1496
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
1497
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
1498
+ }
1499
+
1500
+ *s = hsum_float_8(acc);
1501
+
1161
1502
  #else
1162
1503
 
1163
1504
  float sumf = 0;
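The new __AVX__ block above leans on one pattern throughout: _mm_maddubs_epi16(q2, q8) multiplies 16 unsigned quant bytes by 16 signed q8 bytes and adds adjacent pairs into 8 int16 lanes, and _mm_madd_epi16 with the shuffled scales then multiplies and adds adjacent pairs again into 4 int32 lanes. A scalar model of that pipeline (illustrative only, not part of the package; with 2-bit quants the pair sums stay far below the int16 saturation limit of maddubs):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        uint8_t q2[16]; int8_t q8[16];
        for (int l = 0; l < 16; ++l) { q2[l] = (uint8_t)(l & 3); q8[l] = (int8_t)(l - 8); }

        int16_t p[8];                                  // what _mm_maddubs_epi16 produces
        for (int k = 0; k < 8; ++k)
            p[k] = (int16_t)(q2[2*k] * q8[2*k] + q2[2*k + 1] * q8[2*k + 1]);

        int scale = 7;                                 // one 4-bit sub-block scale
        int32_t out[4];                                // what _mm_madd_epi16 with that scale produces
        for (int k = 0; k < 4; ++k)
            out[k] = scale * (p[2*k] + p[2*k + 1]);

        for (int k = 0; k < 4; ++k) printf("%d\n", (int)out[k]);
        return 0;
    }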
@@ -1201,6 +1542,168 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1201
1542
  #endif
1202
1543
  }
1203
1544
 
1545
+ #else
1546
+
1547
+ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1548
+
1549
+ const block_q2_K * restrict x = vx;
1550
+ const block_q8_K * restrict y = vy;
1551
+
1552
+ const int nb = n / QK_K;
1553
+
1554
+ #ifdef __ARM_NEON
1555
+
1556
+ const uint8x16_t m3 = vdupq_n_u8(0x3);
1557
+ const int32x4_t vzero = vdupq_n_s32(0);
1558
+
1559
+ int8x16x4_t q2bytes;
1560
+
1561
+ uint32_t aux32[2];
1562
+ const uint8_t * scales = (const uint8_t *)aux32;
1563
+
1564
+ float sum = 0;
1565
+
1566
+ for (int i = 0; i < nb; ++i) {
1567
+
1568
+ const float d = y[i].d * (float)x[i].d;
1569
+ const float dmin = -y[i].d * (float)x[i].dmin;
1570
+
1571
+ const uint8_t * restrict q2 = x[i].qs;
1572
+ const int8_t * restrict q8 = y[i].qs;
1573
+ const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
1574
+
1575
+ aux32[0] = sc[0] & 0x0f0f0f0f;
1576
+ aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
1577
+
1578
+ sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
1579
+
1580
+ int isum1 = 0, isum2 = 0;
1581
+
1582
+ const uint8x16_t q2bits = vld1q_u8(q2);
1583
+
1584
+ const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
1585
+
1586
+ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
1587
+ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
1588
+ q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
1589
+ q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
1590
+
1591
+ #if defined(__ARM_FEATURE_DOTPROD)
1592
+ isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
1593
+ isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
1594
+ isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
1595
+ isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
1596
+ #else
1597
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
1598
+ vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
1599
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
1600
+ vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
1601
+ isum1 += vaddvq_s16(p1) * scales[0];
1602
+ isum2 += vaddvq_s16(p2) * scales[1];
1603
+
1604
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
1605
+ vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
1606
+ const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
1607
+ vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
1608
+ isum1 += vaddvq_s16(p3) * scales[2];
1609
+ isum2 += vaddvq_s16(p4) * scales[3];
1610
+ #endif
1611
+ sum += d * (isum1 + isum2);
1612
+
1613
+ }
1614
+
1615
+ *s = sum;
1616
+
1617
+ #elif defined __AVX2__
1618
+
1619
+ const __m256i m3 = _mm256_set1_epi8(3);
1620
+
1621
+ __m256 acc = _mm256_setzero_ps();
1622
+
1623
+ uint32_t ud, um;
1624
+ const uint8_t * restrict db = (const uint8_t *)&ud;
1625
+ const uint8_t * restrict mb = (const uint8_t *)&um;
1626
+
1627
+ float summs = 0;
1628
+
1629
+ // TODO: optimize this
1630
+
1631
+ for (int i = 0; i < nb; ++i) {
1632
+
1633
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1634
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1635
+
1636
+ const uint8_t * restrict q2 = x[i].qs;
1637
+ const int8_t * restrict q8 = y[i].qs;
1638
+
1639
+ const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
1640
+ ud = (sc[0] >> 0) & 0x0f0f0f0f;
1641
+ um = (sc[0] >> 4) & 0x0f0f0f0f;
1642
+
1643
+ int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
1644
+ summs += dmin * smin;
1645
+
1646
+ const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
1647
+ const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
1648
+ const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
1649
+
1650
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
1651
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
1652
+
1653
+ const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
1654
+ const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
1655
+
1656
+ const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
1657
+ const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
1658
+ const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
1659
+ const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
1660
+
1661
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
1662
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
1663
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
1664
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
1665
+ }
1666
+
1667
+ *s = hsum_float_8(acc) + summs;
1668
+
1669
+ #else
1670
+
1671
+ float sumf = 0;
1672
+
1673
+ int isum[4];
1674
+
1675
+ for (int i = 0; i < nb; ++i) {
1676
+
1677
+ const uint8_t * q2 = x[i].qs;
1678
+ const int8_t * q8 = y[i].qs;
1679
+ const uint8_t * sc = x[i].scales;
1680
+
1681
+ int summs = 0;
1682
+ for (int j = 0; j < QK_K/16; ++j) {
1683
+ summs += y[i].bsums[j] * (sc[j] >> 4);
1684
+ }
1685
+
1686
+ const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
1687
+ const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1688
+
1689
+ isum[0] = isum[1] = isum[2] = isum[3] = 0;
1690
+ for (int l = 0; l < 16; ++l) {
1691
+ isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
1692
+ isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
1693
+ isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
1694
+ isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
1695
+ }
1696
+ for (int l = 0; l < 4; ++l) {
1697
+ isum[l] *= (sc[l] & 0xF);
1698
+ }
1699
+ sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
1700
+ }
1701
+ *s = sumf;
1702
+ #endif
1703
+ }
1704
+ #endif
1705
+
1706
+ #if QK_K == 256
1204
1707
  void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1205
1708
  assert(n % QK_K == 0);
1206
1709
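The QK_K == 64 scalar fallback added above splits the dot product the same way the vector paths do: the dequantized value is dall*scale*q - dmin*min per 16-value sub-block, so the min term only needs the per-sub-block sums of q8 that block_q8_K already caches in bsums. A small numeric check of that factoring (illustrative only, not part of the package):

    #include <assert.h>
    #include <math.h>
    #include <stdio.h>

    int main(void) {
        float dall = 0.02f, dmin = 0.01f;
        int   scale = 5, min = 3;                       // one sub-block's 4-bit scale and min
        int   q[16], q8[16];
        for (int l = 0; l < 16; ++l) { q[l] = l & 3; q8[l] = (l * 7) % 19 - 9; }

        float direct = 0;                               // dot of dequantized values with q8
        for (int l = 0; l < 16; ++l) direct += (dall * scale * q[l] - dmin * min) * q8[l];

        int isum = 0, bsum = 0;                         // factored form used in the diff
        for (int l = 0; l < 16; ++l) { isum += q8[l] * q[l]; bsum += q8[l]; }
        float factored = dall * (float)(scale * isum) - dmin * (float)(min * bsum);

        assert(fabsf(direct - factored) < 1e-4f);
        printf("%f %f\n", direct, factored);
        return 0;
    }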
 
@@ -1434,34 +1937,176 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
1434
1937
 
1435
1938
  *s = hsum_float_8(acc);
1436
1939
 
1437
- #else
1438
- // scalar version
1439
- // This function is written like this so the compiler can manage to vectorize most of it
1440
- // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
1441
- // manually vectorized version above. Every other version I tried would run at least 4 times slower.
1442
- // The ideal situation would be if we could just write the code once, and the compiler would
1443
- // automatically produce the best possible set of machine instructions, instead of us having to manually
1444
- // write vectorized versions for AVX, ARM_NEON, etc.
1940
+ #elif defined __AVX__
1445
1941
 
1446
- int8_t aux8[QK_K];
1447
- int16_t aux16[8];
1448
- float sums [8];
1449
- int32_t aux32[8];
1450
- memset(sums, 0, 8*sizeof(float));
1942
+ const __m128i m3 = _mm_set1_epi8(3);
1943
+ const __m128i mone = _mm_set1_epi8(1);
1944
+ const __m128i m32 = _mm_set1_epi8(32);
1945
+ const __m128i m2 = _mm_set1_epi8(2);
1451
1946
 
1452
- uint32_t auxs[4];
1453
- const int8_t * scales = (const int8_t*)auxs;
1947
+ __m256 acc = _mm256_setzero_ps();
1948
+
1949
+ uint32_t *aux;
1454
1950
 
1455
- float sumf = 0;
1456
1951
  for (int i = 0; i < nb; ++i) {
1952
+
1953
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1954
+
1457
1955
  const uint8_t * restrict q3 = x[i].qs;
1458
- const uint8_t * restrict hm = x[i].hmask;
1459
- const int8_t * restrict q8 = y[i].qs;
1460
- memset(aux32, 0, 8*sizeof(int32_t));
1461
- int8_t * restrict a = aux8;
1462
- uint8_t m = 1;
1463
- for (int j = 0; j < QK_K; j += 128) {
1464
- for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
1956
+ const int8_t * restrict q8 = y[i].qs;
1957
+
1958
+ // Set up scales
1959
+ aux = (uint32_t *)x[i].scales;
1960
+ __m128i scales128 = _mm_set_epi32(
1961
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1962
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1963
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1964
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1965
+ scales128 = _mm_sub_epi8(scales128, m32);
1966
+ const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
1967
+ const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
1968
+ const __m128i scales[2] = { scales_0, scales_1 };
1969
+
1970
+ // high bit *128*2 from block_q3_K.hmask[QK_K/8]
1971
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
1972
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
1973
+
1974
+ // integer accumulator
1975
+ __m128i sumi_0 = _mm_setzero_si128();
1976
+ __m128i sumi_1 = _mm_setzero_si128();
1977
+
1978
+ for (int j = 0; j < QK_K/128; ++j) {
1979
+ // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
1980
+ const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1981
+ const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1982
+
1983
+ // prepare low and high bits
1984
+ const int bit = j << 2;
1985
+
1986
+ const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
1987
+ const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
1988
+ const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
1989
+ const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
1990
+
1991
+ const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
1992
+ const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
1993
+ const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1994
+ const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1995
+
1996
+ const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
1997
+ const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
1998
+ const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1999
+ const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
2000
+
2001
+ const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
2002
+ const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
2003
+ const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
2004
+ const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
2005
+
2006
+ // load Q8 quants from block_q8_K.qs[QK_K]
2007
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2008
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2009
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2010
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2011
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2012
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2013
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2014
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2015
+
2016
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
2017
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
2018
+ // and 2 if the high bit was set)
2019
+ __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
2020
+ __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
2021
+ __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
2022
+ __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
2023
+ __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
2024
+ __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
2025
+ __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
2026
+ __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
2027
+
2028
+ __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
2029
+ __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
2030
+ __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
2031
+ __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
2032
+ __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
2033
+ __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
2034
+ __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
2035
+ __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
2036
+
2037
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
2038
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
2039
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
2040
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
2041
+ p16_4 = _mm_sub_epi16(p16_4, q8s_4);
2042
+ p16_5 = _mm_sub_epi16(p16_5, q8s_5);
2043
+ p16_6 = _mm_sub_epi16(p16_6, q8s_6);
2044
+ p16_7 = _mm_sub_epi16(p16_7, q8s_7);
2045
+
2046
+ // multiply with scales
2047
+ __m128i shuffle = _mm_set1_epi16(0x0100);
2048
+ p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
2049
+ shuffle = _mm_add_epi16(shuffle, m2);
2050
+ p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
2051
+ shuffle = _mm_add_epi16(shuffle, m2);
2052
+ p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
2053
+ shuffle = _mm_add_epi16(shuffle, m2);
2054
+ p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
2055
+ shuffle = _mm_add_epi16(shuffle, m2);
2056
+ p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
2057
+ shuffle = _mm_add_epi16(shuffle, m2);
2058
+ p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
2059
+ shuffle = _mm_add_epi16(shuffle, m2);
2060
+ p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
2061
+ shuffle = _mm_add_epi16(shuffle, m2);
2062
+ p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
2063
+
2064
+ // accumulate
2065
+ p16_0 = _mm_add_epi32(p16_0, p16_1);
2066
+ p16_2 = _mm_add_epi32(p16_2, p16_3);
2067
+ p16_4 = _mm_add_epi32(p16_4, p16_5);
2068
+ p16_6 = _mm_add_epi32(p16_6, p16_7);
2069
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2070
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
2071
+
2072
+ }
2073
+
2074
+ // multiply with block scale and accumulate
2075
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
2076
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
2077
+
2078
+ }
2079
+
2080
+ *s = hsum_float_8(acc);
2081
+
2082
+ #else
2083
+ // scalar version
2084
+ // This function is written like this so the compiler can manage to vectorize most of it
2085
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
2086
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
2087
+ // The ideal situation would be if we could just write the code once, and the compiler would
2088
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
2089
+ // write vectorized versions for AVX, ARM_NEON, etc.
2090
+
2091
+ int8_t aux8[QK_K];
2092
+ int16_t aux16[8];
2093
+ float sums [8];
2094
+ int32_t aux32[8];
2095
+ memset(sums, 0, 8*sizeof(float));
2096
+
2097
+ uint32_t auxs[4];
2098
+ const int8_t * scales = (const int8_t*)auxs;
2099
+
2100
+ float sumf = 0;
2101
+ for (int i = 0; i < nb; ++i) {
2102
+ const uint8_t * restrict q3 = x[i].qs;
2103
+ const uint8_t * restrict hm = x[i].hmask;
2104
+ const int8_t * restrict q8 = y[i].qs;
2105
+ memset(aux32, 0, 8*sizeof(int32_t));
2106
+ int8_t * restrict a = aux8;
2107
+ uint8_t m = 1;
2108
+ for (int j = 0; j < QK_K; j += 128) {
2109
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
1465
2110
  for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1466
2111
  a += 32; m <<= 1;
1467
2112
  for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
@@ -1501,6 +2146,206 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
1501
2146
 
1502
2147
  }
1503
2148
 
2149
+ #else
2150
+
2151
+ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2152
+ assert(n % QK_K == 0);
2153
+
2154
+ const block_q3_K * restrict x = vx;
2155
+ const block_q8_K * restrict y = vy;
2156
+
2157
+ const int nb = n / QK_K;
2158
+
2159
+ #ifdef __ARM_NEON
2160
+
2161
+ #ifdef __ARM_FEATURE_DOTPROD
2162
+ const int32x4_t vzero = vdupq_n_s32(0);
2163
+ #endif
2164
+
2165
+ const uint8x16_t m3b = vdupq_n_u8(0x3);
2166
+ const uint8x16_t mh = vdupq_n_u8(4);
2167
+
2168
+ int8x16x4_t q3bytes;
2169
+
2170
+ uint16_t aux16[2];
2171
+ int8_t * scales = (int8_t *)aux16;
2172
+
2173
+ float sum = 0;
2174
+
2175
+ for (int i = 0; i < nb; ++i) {
2176
+
2177
+ uint8x16x4_t q3h;
2178
+
2179
+ const uint8x8_t hbits = vld1_u8(x[i].hmask);
2180
+ const uint8x16_t q3bits = vld1q_u8(x[i].qs);
2181
+ const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
2182
+
2183
+ const uint16_t a = *(const uint16_t *)x[i].scales;
2184
+ aux16[0] = a & 0x0f0f;
2185
+ aux16[1] = (a >> 4) & 0x0f0f;
2186
+
2187
+ for (int j = 0; j < 4; ++j) scales[j] -= 8;
2188
+
2189
+ int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
2190
+
2191
+ const float d = y[i].d * (float)x[i].d;
2192
+
2193
+ const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
2194
+ q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
2195
+ q3h.val[1] = vandq_u8(mh, htmp);
2196
+ q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
2197
+ q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
2198
+
2199
+ q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
2200
+ q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
2201
+ q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
2202
+ q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
2203
+
2204
+ #if defined(__ARM_FEATURE_DOTPROD)
2205
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
2206
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
2207
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
2208
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
2209
+ #else
2210
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
2211
+ vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
2212
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
2213
+ vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
2214
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
2215
+ vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
2216
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
2217
+ vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
2218
+ isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
2219
+ #endif
2220
+
2221
+ sum += d * isum;
2222
+
2223
+ }
2224
+
2225
+ *s = sum;
2226
+
2227
+ #elif defined __AVX2__
2228
+
2229
+ const __m256i m3 = _mm256_set1_epi8(3);
2230
+ const __m256i m1 = _mm256_set1_epi8(1);
2231
+
2232
+ __m256 acc = _mm256_setzero_ps();
2233
+
2234
+ uint64_t aux64;
2235
+
2236
+ uint16_t aux16[2];
2237
+ const int8_t * aux8 = (const int8_t *)aux16;
2238
+
2239
+ for (int i = 0; i < nb; ++i) {
2240
+
2241
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
2242
+
2243
+ const uint8_t * restrict q3 = x[i].qs;
2244
+ const int8_t * restrict q8 = y[i].qs;
2245
+
2246
+ const uint16_t a = *(const uint16_t *)x[i].scales;
2247
+ aux16[0] = a & 0x0f0f;
2248
+ aux16[1] = (a >> 4) & 0x0f0f;
2249
+
2250
+ const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
2251
+ const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
2252
+
2253
+ memcpy(&aux64, x[i].hmask, 8);
2254
+
2255
+ const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
2256
+ __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
2257
+ __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
2258
+ q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
2259
+ q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
2260
+
2261
+ // load low 2 bits
2262
+ const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
2263
+
2264
+ // prepare low and high bits
2265
+ const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
2266
+ const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
2267
+ const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
2268
+
2269
+ // load Q8 quants
2270
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
2271
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
2272
+
2273
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
2274
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
2275
+ // and 2 if the high bit was set)
2276
+ const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
2277
+ const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
2278
+
2279
+ __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
2280
+ __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
2281
+
2282
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
2283
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
2284
+
2285
+ // multiply with scales
2286
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
2287
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
2288
+
2289
+ p16_0 = _mm256_add_epi32(p16_0, p16_1);
2290
+
2291
+ // multiply with block scale and accumulate
2292
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
2293
+
2294
+ }
2295
+
2296
+ *s = hsum_float_8(acc);
2297
+
2298
+ #else
2299
+
2300
+ int8_t aux8[QK_K];
2301
+ int16_t aux16[8];
2302
+ float sums [8];
2303
+ int32_t aux32[8];
2304
+ int32_t scales[4];
2305
+ memset(sums, 0, 8*sizeof(float));
2306
+
2307
+ float sumf = 0;
2308
+ for (int i = 0; i < nb; ++i) {
2309
+ const uint8_t * restrict q3 = x[i].qs;
2310
+ const uint8_t * restrict hm = x[i].hmask;
2311
+ const int8_t * restrict q8 = y[i].qs;
2312
+ int8_t * restrict a = aux8;
2313
+ for (int l = 0; l < 8; ++l) {
2314
+ a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
2315
+ a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
2316
+ a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
2317
+ a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
2318
+ a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
2319
+ a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
2320
+ a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
2321
+ a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
2322
+ }
2323
+
2324
+ scales[0] = (x[i].scales[0] & 0xF) - 8;
2325
+ scales[1] = (x[i].scales[0] >> 4) - 8;
2326
+ scales[2] = (x[i].scales[1] & 0xF) - 8;
2327
+ scales[3] = (x[i].scales[1] >> 4) - 8;
2328
+
2329
+ memset(aux32, 0, 8*sizeof(int32_t));
2330
+ for (int j = 0; j < QK_K/16; ++j) {
2331
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2332
+ q8 += 8; a += 8;
2333
+ for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
2334
+ q8 += 8; a += 8;
2335
+ for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
2336
+ }
2337
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
2338
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2339
+ }
2340
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2341
+ *s = sumf;
2342
+
2343
+ #endif
2344
+
2345
+ }
2346
+ #endif
2347
+
2348
+ #if QK_K == 256
1504
2349
  void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1505
2350
  assert(n % QK_K == 0);
1506
2351
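In the QK_K == 64 ggml_vec_dot_q3_K_q8_K added above, the ARM_NEON variant ORs the high bit into each quant as +4 and folds the uniform -4 offset into a single correction built from y[i].bsums, using the identity sum(s*q8*(v-4)) = sum(s*q8*v) - 4*s*sum(q8). A small check of that identity (illustrative only, not part of the package):

    #include <assert.h>

    int main(void) {
        int scale = -3;                            // signed 4-bit scale, bias already removed
        int v[16], q8[16];                         // v = low 2 bits plus 4 when the hmask bit is set
        for (int l = 0; l < 16; ++l) { v[l] = (l & 3) + ((l & 1) ? 4 : 0); q8[l] = l - 7; }

        int direct = 0, folded = 0, bsum = 0;
        for (int l = 0; l < 16; ++l) {
            direct += scale * q8[l] * (v[l] - 4);  // true signed quant is v - 4
            folded += scale * q8[l] * v[l];
            bsum   += q8[l];                       // this is what block_q8_K.bsums caches per 16 values
        }
        folded += -4 * scale * bsum;
        assert(direct == folded);
        return 0;
    }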
 
@@ -1614,9 +2459,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
1614
2459
  const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1615
2460
  const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1616
2461
 
1617
- const uint8_t * restrict q4 = x[i].qs;
1618
- const int8_t * restrict q8 = y[i].qs;
1619
-
1620
2462
  memcpy(utmp, x[i].scales, 12);
1621
2463
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1622
2464
  const uint32_t uaux = utmp[1] & kmask1;
@@ -1624,6 +2466,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
1624
2466
  utmp[2] = uaux;
1625
2467
  utmp[0] &= kmask1;
1626
2468
 
2469
+ const uint8_t * restrict q4 = x[i].qs;
2470
+ const int8_t * restrict q8 = y[i].qs;
2471
+
1627
2472
  const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1628
2473
 
1629
2474
  const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
@@ -1667,6 +2512,88 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
1667
2512
 
1668
2513
  *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1669
2514
 
2515
+ #elif defined __AVX__
2516
+
2517
+ const __m128i m4 = _mm_set1_epi8(0xF);
2518
+ const __m128i m2 = _mm_set1_epi8(0x2);
2519
+
2520
+ __m256 acc = _mm256_setzero_ps();
2521
+ __m128 acc_m = _mm_setzero_ps();
2522
+
2523
+ for (int i = 0; i < nb; ++i) {
2524
+
2525
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
2526
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
2527
+
2528
+ const uint8_t * restrict q4 = x[i].qs;
2529
+ const int8_t * restrict q8 = y[i].qs;
2530
+
2531
+ memcpy(utmp, x[i].scales, 12);
2532
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2533
+ const uint32_t uaux = utmp[1] & kmask1;
2534
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2535
+ utmp[2] = uaux;
2536
+ utmp[0] &= kmask1;
2537
+
2538
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
2539
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
2540
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
2541
+
2542
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
2543
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
2544
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
2545
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
2546
+ acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
2547
+
2548
+ __m128i sumi_0 = _mm_setzero_si128();
2549
+ __m128i sumi_1 = _mm_setzero_si128();
2550
+
2551
+ __m128i shuffle = _mm_set1_epi16(0x0100);
2552
+ for (int j = 0; j < QK_K/64; ++j) {
2553
+
2554
+ const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
2555
+ shuffle = _mm_add_epi16(shuffle, m2);
2556
+ const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
2557
+ shuffle = _mm_add_epi16(shuffle, m2);
2558
+
2559
+ __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2560
+ const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
2561
+ const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
2562
+ q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2563
+ const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
2564
+ const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
2565
+
2566
+ const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2567
+ __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
2568
+ p16l = _mm_madd_epi16(scale_l, p16l);
2569
+ sumi_0 = _mm_add_epi32(sumi_0, p16l);
2570
+ const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2571
+ p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
2572
+ p16l = _mm_madd_epi16(scale_l, p16l);
2573
+ sumi_1 = _mm_add_epi32(sumi_1, p16l);
2574
+
2575
+ const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2576
+ __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
2577
+ p16h = _mm_madd_epi16(scale_h, p16h);
2578
+ sumi_0 = _mm_add_epi32(sumi_0, p16h);
2579
+ const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2580
+ p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
2581
+ p16h = _mm_madd_epi16(scale_h, p16h);
2582
+ sumi_1 = _mm_add_epi32(sumi_1, p16h);
2583
+
2584
+ }
2585
+
2586
+ __m256 vd = _mm256_set1_ps(d);
2587
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
2588
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
2589
+
2590
+ }
2591
+
2592
+ acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
2593
+ acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
2594
+
2595
+ *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
2596
+
1670
2597
  #else
1671
2598
 
1672
2599
 
@@ -1726,7 +2653,176 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
1726
2653
  *s = sumf;
1727
2654
  #endif
1728
2655
  }
2656
+ #else
2657
+ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2658
+ assert(n % QK_K == 0);
2659
+
2660
+ const block_q4_K * restrict x = vx;
2661
+ const block_q8_K * restrict y = vy;
2662
+
2663
+ const int nb = n / QK_K;
2664
+
2665
+ #ifdef __ARM_NEON
2666
+
2667
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2668
+
2669
+ #ifdef __ARM_FEATURE_DOTPROD
2670
+ const int32x4_t mzero = vdupq_n_s32(0);
2671
+ #endif
2672
+
2673
+ float sumf = 0;
2674
+
2675
+ int8x16x2_t q4bytes;
2676
+ int8x16x4_t q8bytes;
2677
+
2678
+ float sum_mins = 0.f;
2679
+
2680
+ uint16_t aux16[2];
2681
+ const uint8_t * restrict scales = (const uint8_t *)aux16;
2682
+
2683
+ for (int i = 0; i < nb; ++i) {
2684
+
2685
+ const uint8_t * restrict q4 = x[i].qs;
2686
+ const int8_t * restrict q8 = y[i].qs;
2687
+
2688
+ const uint16_t * restrict a = (const uint16_t *)x[i].scales;
2689
+ aux16[0] = a[0] & 0x0f0f;
2690
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
2691
+
2692
+ const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
2693
+ sum_mins += y[i].d * (float)x[i].d[1] * summi;
2694
+
2695
+ const float d = y[i].d * (float)x[i].d[0];
2696
+
2697
+ const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
2698
+
2699
+ #ifdef __ARM_FEATURE_DOTPROD
2700
+ q8bytes = vld1q_s8_x4(q8);
2701
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
2702
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
2703
+
2704
+ const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
2705
+ const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
2706
+
2707
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
2708
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
2709
+
2710
+ const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
2711
+ const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
2712
+
2713
+ #else
2714
+ q8bytes = vld1q_s8_x4(q8);
2715
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
2716
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
2717
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
2718
+ vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
2719
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
2720
+ vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
2721
+ int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
2722
+
2723
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
2724
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
2725
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
2726
+ vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
2727
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
2728
+ vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
2729
+ int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
2730
+
2731
+ #endif
2732
+ sumf += d * (sumi1 + sumi2);
2733
+
2734
+ }
2735
+
2736
+ *s = sumf - sum_mins;
2737
+
2738
+ #elif defined __AVX2__
2739
+
2740
+ const __m256i m4 = _mm256_set1_epi8(0xF);
2741
+
2742
+ __m256 acc = _mm256_setzero_ps();
2743
+
2744
+ float summs = 0;
2745
+
2746
+ uint16_t aux16[2];
2747
+ const uint8_t * scales = (const uint8_t *)aux16;
2748
+
2749
+ for (int i = 0; i < nb; ++i) {
2750
+
2751
+ const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
2752
+ const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
2753
+ const __m256 vd = _mm256_set1_ps(d);
2754
+
2755
+ const uint16_t * a = (const uint16_t *)x[i].scales;
2756
+ aux16[0] = a[0] & 0x0f0f;
2757
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
2758
+
2759
+ summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
2760
+
2761
+ const uint8_t * restrict q4 = x[i].qs;
2762
+ const int8_t * restrict q8 = y[i].qs;
2763
+
2764
+ const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
2765
+ const __m256i q4l = _mm256_and_si256(q4bits, m4);
2766
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
1729
2767
 
2768
+ const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
2769
+ const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
2770
+
2771
+ const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
2772
+ const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
2773
+
2774
+ const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
2775
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
2776
+
2777
+ const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
2778
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
2779
+
2780
+ }
2781
+
2782
+ *s = hsum_float_8(acc) - summs;
2783
+
2784
+ #else
2785
+
2786
+ uint8_t aux8[QK_K];
2787
+ int16_t aux16[16];
2788
+ float sums [8];
2789
+ memset(sums, 0, 8*sizeof(float));
2790
+
2791
+ uint16_t s16[2];
2792
+ const uint8_t * restrict scales = (const uint8_t *)s16;
2793
+
2794
+ float sumf = 0;
2795
+ for (int i = 0; i < nb; ++i) {
2796
+ const uint8_t * restrict q4 = x[i].qs;
2797
+ const int8_t * restrict q8 = y[i].qs;
2798
+ uint8_t * restrict a = aux8;
2799
+ for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
2800
+ for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;
2801
+
2802
+ const uint16_t * restrict b = (const uint16_t *)x[i].scales;
2803
+ s16[0] = b[0] & 0x0f0f;
2804
+ s16[1] = (b[0] >> 4) & 0x0f0f;
2805
+
2806
+ sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
2807
+
2808
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
2809
+
2810
+ for (int j = 0; j < QK_K/32; ++j) {
2811
+ for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
2812
+ q8 += 16; a += 16;
2813
+ for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
2814
+ q8 += 16; a += 16;
2815
+ const float dl = d * scales[j];
2816
+ for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
2817
+ }
2818
+ }
2819
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2820
+ *s = sumf;
2821
+ #endif
2822
+ }
2823
+ #endif
2824
+
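A side note on the 0x0f0f masking used in the QK_K == 64 q4_K paths above: loading x[i].scales as a single 16-bit word and masking splits the four packed 4-bit fields into two sub-block scales (low nibbles, used with d) and two sub-block mins (high nibbles, used with m). A stand-alone sketch with hypothetical names, assuming the same little-endian view the kernels above take:

#include <stdint.h>
#include <string.h>

// Sketch: unpack the 2-byte scale field of a QK_K == 64 q4_K block the way
// the kernels above do. Low nibbles -> sub-block scales, high nibbles -> mins.
static void unpack_q4k64_scales(const uint8_t packed[2], uint8_t sc[2], uint8_t mn[2]) {
    uint16_t a;
    memcpy(&a, packed, sizeof a);            // same 16-bit view as the code above
    const uint16_t lo = a & 0x0f0f;          // scales for sub-blocks 0 and 1
    const uint16_t hi = (a >> 4) & 0x0f0f;   // mins for sub-blocks 0 and 1
    sc[0] = (uint8_t)(lo & 0xFF); sc[1] = (uint8_t)(lo >> 8);
    mn[0] = (uint8_t)(hi & 0xFF); mn[1] = (uint8_t)(hi >> 8);
}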
2825
+ #if QK_K == 256
1730
2826
  void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1731
2827
  assert(n % QK_K == 0);
1732
2828
 
@@ -1840,18 +2936,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
1840
2936
 
1841
2937
  for (int i = 0; i < nb; ++i) {
1842
2938
 
1843
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1844
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1845
-
1846
2939
  const uint8_t * restrict q5 = x[i].qs;
1847
2940
  const int8_t * restrict q8 = y[i].qs;
1848
2941
 
2942
+ #if QK_K == 256
2943
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
2944
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
2945
+
1849
2946
  memcpy(utmp, x[i].scales, 12);
1850
2947
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1851
2948
  const uint32_t uaux = utmp[1] & kmask1;
1852
2949
  utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1853
2950
  utmp[2] = uaux;
1854
2951
  utmp[0] &= kmask1;
2952
+ #else
2953
+ // TODO
2954
+ const float d = 0, dmin = 0;
2955
+ #endif
1855
2956
 
1856
2957
  const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1857
2958
 
@@ -1876,33 +2977,133 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
1876
2977
  const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
1877
2978
  const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
1878
2979
 
1879
- const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
2980
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
2981
+
2982
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
2983
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
2984
+ const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
2985
+ hmask = _mm256_slli_epi16(hmask, 1);
2986
+
2987
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
2988
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
2989
+ const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
2990
+ hmask = _mm256_slli_epi16(hmask, 1);
2991
+
2992
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2993
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2994
+
2995
+ __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
2996
+ __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
2997
+
2998
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
2999
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
3000
+
3001
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
3002
+
3003
+ }
3004
+
3005
+ __m256 vd = _mm256_set1_ps(d);
3006
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
3007
+
3008
+ }
3009
+
3010
+ *s = hsum_float_8(acc) + summs;
3011
+
3012
+ #elif defined __AVX__
3013
+
3014
+ const __m128i m4 = _mm_set1_epi8(0xF);
3015
+ const __m128i mzero = _mm_setzero_si128();
3016
+ const __m128i mone = _mm_set1_epi8(1);
3017
+ const __m128i m2 = _mm_set1_epi8(2);
3018
+
3019
+ __m256 acc = _mm256_setzero_ps();
3020
+
3021
+ float summs = 0.f;
3022
+
3023
+ for (int i = 0; i < nb; ++i) {
3024
+
3025
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3026
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
3027
+
3028
+ const uint8_t * restrict q5 = x[i].qs;
3029
+ const int8_t * restrict q8 = y[i].qs;
3030
+
3031
+ memcpy(utmp, x[i].scales, 12);
3032
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
3033
+ const uint32_t uaux = utmp[1] & kmask1;
3034
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
3035
+ utmp[2] = uaux;
3036
+ utmp[0] &= kmask1;
3037
+
3038
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
3039
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
3040
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
1880
3041
 
1881
- const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
1882
- const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1883
- const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
1884
- hmask = _mm256_slli_epi16(hmask, 1);
3042
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
3043
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
3044
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
3045
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
3046
+ const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
3047
+ summs += dmin * _mm_extract_epi32(hsum, 0);
1885
3048
 
1886
- const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
1887
- const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1888
- const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
1889
- hmask = _mm256_slli_epi16(hmask, 1);
3049
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
3050
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
3051
+ __m128i hmask = mone;
1890
3052
 
1891
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1892
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3053
+ __m128i sumi_0 = _mm_setzero_si128();
3054
+ __m128i sumi_1 = _mm_setzero_si128();
1893
3055
 
1894
- __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
1895
- __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
3056
+ int bit = 0;
1896
3057
 
1897
- p16_0 = _mm256_madd_epi16(scale_0, p16_0);
1898
- p16_1 = _mm256_madd_epi16(scale_1, p16_1);
3058
+ __m128i shuffle = _mm_set1_epi16(0x0100);
3059
+ for (int j = 0; j < QK_K/64; ++j) {
1899
3060
 
1900
- sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
3061
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
3062
+ shuffle = _mm_add_epi16(shuffle, m2);
3063
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
3064
+ shuffle = _mm_add_epi16(shuffle, m2);
3065
+
3066
+ const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
3067
+ const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
3068
+
3069
+ __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
3070
+ __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
3071
+ __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
3072
+ __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
3073
+ __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
3074
+ __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
3075
+ hmask = _mm_slli_epi16(hmask, 1);
3076
+
3077
+ __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3078
+ __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3079
+ __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
3080
+ __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
3081
+ p16_0 = _mm_madd_epi16(scale_0, p16_0);
3082
+ p16_1 = _mm_madd_epi16(scale_0, p16_1);
3083
+
3084
+ q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
3085
+ q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
3086
+ q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
3087
+ q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
3088
+ q5_0 = _mm_add_epi8(q5l_0, q5h_0);
3089
+ q5_1 = _mm_add_epi8(q5l_1, q5h_1);
3090
+ hmask = _mm_slli_epi16(hmask, 1);
3091
+
3092
+ q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3093
+ q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3094
+ __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
3095
+ __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
3096
+ p16_2 = _mm_madd_epi16(scale_1, p16_2);
3097
+ p16_3 = _mm_madd_epi16(scale_1, p16_3);
3098
+
3099
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
3100
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
1901
3101
 
1902
3102
  }
1903
3103
 
1904
3104
  __m256 vd = _mm256_set1_ps(d);
1905
- acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
3105
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
3106
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
1906
3107
 
1907
3108
  }
1908
3109
 
@@ -1972,8 +3173,169 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
1972
3173
  #endif
1973
3174
  }
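The utmp[]/kmask shuffle inside the loop above is a branch-free repack of the 12-byte scale field shared by the QK_K == 256 q5_K (and q4_K) blocks. As a hedged reference, assuming kmask1/kmask2/kmask3 are the 0x3f3f3f3f / 0x0f0f0f0f / 0x03030303 byte masks defined earlier in this file and a little-endian target, it leaves the eight 6-bit scales in bytes 0..7 and the eight 6-bit mins in bytes 8..15:

#include <stdint.h>
#include <string.h>

// Sketch of the scale/min repack used by the QK_K == 256 kernels above.
// Input: the 12-byte packed scales of a block; output: out[0..7] = scales,
// out[8..15] = mins (each a 6-bit value stored in its own byte).
static void repack_scales_mins(const uint8_t packed[12], uint8_t out[16]) {
    const uint32_t kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303; // assumed values
    uint32_t utmp[4];
    memcpy(utmp, packed, 12);
    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
    const uint32_t uaux = utmp[1] & kmask1;
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
    utmp[2] = uaux;
    utmp[0] &= kmask1;
    memcpy(out, utmp, 16);
}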
1974
3175
 
3176
+ #else
3177
+
3178
+ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3179
+ assert(n % QK_K == 0);
3180
+
3181
+ const block_q5_K * restrict x = vx;
3182
+ const block_q8_K * restrict y = vy;
3183
+
3184
+ const int nb = n / QK_K;
3185
+
3186
+ #ifdef __ARM_NEON
3187
+
3188
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
3189
+ const int32x4_t mzero = vdupq_n_s32(0);
3190
+ const uint8x16_t mh = vdupq_n_u8(16);
3191
+
3192
+ int8x16x4_t q5bytes;
3193
+ uint8x16x4_t q5h;
3194
+
3195
+ float sumf = 0;
3196
+
3197
+ for (int i = 0; i < nb; ++i) {
3198
+
3199
+ const float d = y[i].d * (float)x[i].d;
3200
+ const int8_t * sc = x[i].scales;
3201
+
3202
+ const uint8_t * restrict q5 = x[i].qs;
3203
+ const uint8_t * restrict qh = x[i].qh;
3204
+ const int8_t * restrict q8 = y[i].qs;
3205
+
3206
+ const uint8x8_t qhbits = vld1_u8(qh);
3207
+
3208
+ const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
3209
+ const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
3210
+
3211
+ const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
3212
+ q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
3213
+ q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
3214
+ q5h.val[2] = vbicq_u8(mh, htmp);
3215
+ q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
3216
+
3217
+ q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
3218
+ q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
3219
+ q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
3220
+ q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
3221
+
3222
+ #if defined(__ARM_FEATURE_DOTPROD)
3223
+
3224
+ int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
3225
+ int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
3226
+ int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
3227
+ int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
3228
+
3229
+ sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
3230
+
3231
+ #else
3232
+
3233
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
3234
+ vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
3235
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
3236
+ vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
3237
+ int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
3238
+
3239
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
3240
+ vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
3241
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
3242
+ vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
3243
+ sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
3244
+
3245
+ sumf += d*sumi;
3246
+ #endif
3247
+
3248
+ }
3249
+
3250
+ *s = sumf;
3251
+
3252
+ #elif defined __AVX2__
3253
+
3254
+ const __m256i m4 = _mm256_set1_epi8(0xF);
3255
+ const __m256i mone = _mm256_set1_epi8(1);
3256
+
3257
+ __m256 acc = _mm256_setzero_ps();
3258
+
3259
+ for (int i = 0; i < nb; ++i) {
3260
+
3261
+ const uint8_t * restrict q5 = x[i].qs;
3262
+ const int8_t * restrict q8 = y[i].qs;
3263
+
3264
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3265
+
3266
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
3267
+
3268
+ const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
3269
+ const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
3270
+
3271
+ int64_t aux64;
3272
+ memcpy(&aux64, x[i].qh, 8);
3273
+ const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
3274
+ const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
3275
+
3276
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
3277
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
3278
+
3279
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
3280
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
3281
+
3282
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
3283
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
3284
+
3285
+ const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
3286
+ const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
3287
+ const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
3288
+ const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
3289
+
3290
+ const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
3291
+
3292
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
3293
+
3294
+ }
3295
+
3296
+ *s = hsum_float_8(acc);
3297
+
3298
+ #else
3299
+
3300
+
3301
+ uint8_t aux8[QK_K];
3302
+ int16_t aux16[16];
3303
+ float sums [8];
3304
+ memset(sums, 0, 8*sizeof(float));
3305
+
3306
+ float sumf = 0;
3307
+ for (int i = 0; i < nb; ++i) {
3308
+ const uint8_t * restrict q4 = x[i].qs;
3309
+ const uint8_t * restrict hm = x[i].qh;
3310
+ const int8_t * restrict q8 = y[i].qs;
3311
+ uint8_t * restrict a = aux8;
3312
+ for (int l = 0; l < 32; ++l) {
3313
+ a[l+ 0] = q4[l] & 0xF;
3314
+ a[l+32] = q4[l] >> 4;
3315
+ }
3316
+ for (int is = 0; is < 8; ++is) {
3317
+ uint8_t m = 1 << is;
3318
+ for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
3319
+ }
3320
+
3321
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3322
+ const int8_t * restrict sc = x[i].scales;
3323
+
3324
+ for (int j = 0; j < QK_K/16; ++j) {
3325
+ const float dl = d * sc[j];
3326
+ for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
3327
+ for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
3328
+ q8 += 16; a += 16;
3329
+ }
3330
+ }
3331
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
3332
+ *s = sumf;
3333
+ #endif
3334
+ }
3335
+ #endif
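A compact restatement of the QK_K == 64 q5_K decode that the scalar fallback above spells out inline: each weight is a 4-bit value from qs plus one high bit from qh, re-centred so a cleared high bit subtracts 16. Sketch only; the helper name is illustrative:

#include <stdint.h>

// Expand one QK_K == 64 q5_K block (32 packed nibbles + 8 high-bit bytes)
// into 64 signed values in [-16, 15], mirroring the scalar fallback above.
static void q5k64_expand(const uint8_t qs[32], const uint8_t qh[8], int8_t out[64]) {
    for (int l = 0; l < 32; ++l) {
        out[l +  0] = (int8_t)(qs[l] & 0xF);   // low nibbles  -> values 0..31
        out[l + 32] = (int8_t)(qs[l] >> 4);    // high nibbles -> values 32..63
    }
    for (int is = 0; is < 8; ++is) {           // bit `is` of qh[l] belongs to value 8*is + l
        const uint8_t m = (uint8_t)(1 << is);
        for (int l = 0; l < 8; ++l) out[8*is + l] -= (qh[l] & m) ? 0 : 16;
    }
}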
1975
3336
 
1976
3337
 
3338
+ #if QK_K == 256
1977
3339
  void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1978
3340
  assert(n % QK_K == 0);
1979
3341
 
@@ -2198,6 +3560,124 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
2198
3560
 
2199
3561
  *s = hsum_float_8(acc);
2200
3562
 
3563
+ #elif defined __AVX__
3564
+
3565
+ const __m128i m4 = _mm_set1_epi8(0xF);
3566
+ const __m128i m3 = _mm_set1_epi8(3);
3567
+ const __m128i m32s = _mm_set1_epi8(32);
3568
+ const __m128i m2 = _mm_set1_epi8(2);
3569
+
3570
+ __m256 acc = _mm256_setzero_ps();
3571
+
3572
+ for (int i = 0; i < nb; ++i) {
3573
+
3574
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3575
+
3576
+ const uint8_t * restrict q4 = x[i].ql;
3577
+ const uint8_t * restrict qh = x[i].qh;
3578
+ const int8_t * restrict q8 = y[i].qs;
3579
+
3580
+ const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
3581
+
3582
+ __m128i sumi_0 = _mm_setzero_si128();
3583
+ __m128i sumi_1 = _mm_setzero_si128();
3584
+
3585
+ __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
3586
+ for (int j = 0; j < QK_K/128; ++j) {
3587
+
3588
+ const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
3589
+ const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
3590
+
3591
+ const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
3592
+ const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
3593
+ const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
3594
+ const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
3595
+ const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
3596
+ const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
3597
+ const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
3598
+ const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
3599
+
3600
+ const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
3601
+ const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
3602
+ const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
3603
+ const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
3604
+
3605
+ const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
3606
+ const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
3607
+ const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
3608
+ const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
3609
+ const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
3610
+ const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
3611
+ const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
3612
+ const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
3613
+
3614
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3615
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3616
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3617
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3618
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3619
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3620
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3621
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3622
+
3623
+ __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
3624
+ __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
3625
+ __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
3626
+ __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
3627
+ __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
3628
+ __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
3629
+ __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
3630
+ __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
3631
+
3632
+ __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
3633
+ __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
3634
+ __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
3635
+ __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
3636
+ __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
3637
+ __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
3638
+ __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
3639
+ __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
3640
+
3641
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
3642
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
3643
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
3644
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
3645
+ p16_4 = _mm_sub_epi16(p16_4, q8s_4);
3646
+ p16_5 = _mm_sub_epi16(p16_5, q8s_5);
3647
+ p16_6 = _mm_sub_epi16(p16_6, q8s_6);
3648
+ p16_7 = _mm_sub_epi16(p16_7, q8s_7);
3649
+
3650
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
3651
+ shuffle = _mm_add_epi8(shuffle, m2);
3652
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
3653
+ shuffle = _mm_add_epi8(shuffle, m2);
3654
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
3655
+ shuffle = _mm_add_epi8(shuffle, m2);
3656
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
3657
+ shuffle = _mm_add_epi8(shuffle, m2);
3658
+
3659
+ p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
3660
+ p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
3661
+ p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
3662
+ p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
3663
+ p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
3664
+ p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
3665
+ p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
3666
+ p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
3667
+
3668
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
3669
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
3670
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
3671
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
3672
+
3673
+ }
3674
+
3675
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
3676
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
3677
+ }
3678
+
3679
+ *s = hsum_float_8(acc);
3680
+
2201
3681
  #else
2202
3682
 
2203
3683
  int8_t aux8[QK_K];
@@ -2242,3 +3722,179 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
2242
3722
  *s = sumf;
2243
3723
  #endif
2244
3724
  }
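One detail of the AVX q6_K paths above worth spelling out: the 6-bit values are kept unsigned and the -32 bias is pushed onto the q8 side, via (q - 32) * y == q * y - 32 * y. That is what the _mm_maddubs_epi16(m32s, q8) / _mm_sub_epi16 pairs compute. A scalar sketch of the identity, with a hypothetical helper name:

#include <stdint.h>

// (q - 32) * y summed over a block, computed as (sum q*y) - (sum 32*y),
// the form the m32s trick above uses so q can stay unsigned for maddubs.
static int32_t dot_q6_minus32(const uint8_t *q, const int8_t *y, int n) {
    int32_t s = 0, s32 = 0;
    for (int k = 0; k < n; ++k) {
        s   += (int32_t)q[k] * y[k];
        s32 += 32 * (int32_t)y[k];
    }
    return s - s32;   // == sum over k of (q[k] - 32) * y[k]
}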
3725
+
3726
+ #else
3727
+
3728
+ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3729
+ assert(n % QK_K == 0);
3730
+
3731
+ const block_q6_K * restrict x = vx;
3732
+ const block_q8_K * restrict y = vy;
3733
+
3734
+ const int nb = n / QK_K;
3735
+
3736
+ #ifdef __ARM_NEON
3737
+
3738
+ float sum = 0;
3739
+
3740
+ const uint8x16_t m4b = vdupq_n_u8(0xF);
3741
+ const int32x4_t vzero = vdupq_n_s32(0);
3742
+ const int8x16_t m32s = vdupq_n_s8(32);
3743
+
3744
+ const uint8x16_t mone = vdupq_n_u8(3);
3745
+
3746
+ int8x16x4_t q6bytes;
3747
+ uint8x16x4_t q6h;
3748
+
3749
+ for (int i = 0; i < nb; ++i) {
3750
+
3751
+ const float d_all = (float)x[i].d;
3752
+
3753
+ const uint8_t * restrict q6 = x[i].ql;
3754
+ const uint8_t * restrict qh = x[i].qh;
3755
+ const int8_t * restrict q8 = y[i].qs;
3756
+
3757
+ const int8_t * restrict scale = x[i].scales;
3758
+
3759
+ int32_t isum = 0;
3760
+
3761
+ uint8x16_t qhbits = vld1q_u8(qh);
3762
+ uint8x16x2_t q6bits = vld1q_u8_x2(q6);
3763
+ int8x16x4_t q8bytes = vld1q_s8_x4(q8);
3764
+
3765
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
3766
+ uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
3767
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3768
+ shifted = vshrq_n_u8(qhbits, 4);
3769
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3770
+ shifted = vshrq_n_u8(qhbits, 6);
3771
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3772
+
3773
+ q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
3774
+ q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
3775
+ q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
3776
+ q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
3777
+
3778
+ #if defined(__ARM_FEATURE_DOTPROD)
3779
+
3780
+ isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
3781
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
3782
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
3783
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
3784
+ #else
3785
+
3786
+ int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
3787
+ vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
3788
+ int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
3789
+ vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
3790
+ isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
3791
+
3792
+ int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
3793
+ vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
3794
+ int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
3795
+ vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
3796
+ isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
3797
+ #endif
3798
+
3799
+ sum += isum * d_all * y[i].d;
3800
+
3801
+ }
3802
+ *s = sum;
3803
+
3804
+ #elif defined __AVX2__
3805
+
3806
+ const __m256i m4 = _mm256_set1_epi8(0xF);
3807
+ const __m256i m2 = _mm256_set1_epi8(3);
3808
+ const __m256i m32s = _mm256_set1_epi8(32);
3809
+
3810
+ __m256 acc = _mm256_setzero_ps();
3811
+
3812
+ for (int i = 0; i < nb; ++i) {
3813
+
3814
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3815
+
3816
+ const uint8_t * restrict q4 = x[i].ql;
3817
+ const uint8_t * restrict qh = x[i].qh;
3818
+ const int8_t * restrict q8 = y[i].qs;
3819
+
3820
+ const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
3821
+ const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
3822
+ const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
3823
+ const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
3824
+
3825
+ __m256i sumi = _mm256_setzero_si256();
3826
+
3827
+ const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
3828
+ const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
3829
+
3830
+ const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
3831
+ const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
3832
+
3833
+ const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
3834
+ const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
3835
+
3836
+ const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
3837
+ const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
3838
+
3839
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
3840
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
3841
+
3842
+ __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
3843
+ __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
3844
+
3845
+ __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
3846
+ __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
3847
+
3848
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
3849
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
3850
+
3851
+ p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
3852
+ p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
3853
+
3854
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
3855
+
3856
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
3857
+ }
3858
+
3859
+ *s = hsum_float_8(acc);
3860
+
3861
+ #else
3862
+
3863
+ int8_t aux8[QK_K];
3864
+ int16_t aux16[8];
3865
+ float sums [8];
3866
+ int32_t aux32[8];
3867
+ memset(sums, 0, 8*sizeof(float));
3868
+
3869
+ float sumf = 0;
3870
+ for (int i = 0; i < nb; ++i) {
3871
+ const uint8_t * restrict q4 = x[i].ql;
3872
+ const uint8_t * restrict qh = x[i].qh;
3873
+ const int8_t * restrict q8 = y[i].qs;
3874
+ memset(aux32, 0, 8*sizeof(int32_t));
3875
+ int8_t * restrict a = aux8;
3876
+ for (int l = 0; l < 16; ++l) {
3877
+ a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
3878
+ a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
3879
+ a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
3880
+ a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
3881
+ }
3882
+ int is = 0;
3883
+ for (int j = 0; j < QK_K/16; ++j) {
3884
+ int scale = x[i].scales[is++];
3885
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
3886
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
3887
+ q8 += 8; a += 8;
3888
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
3889
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
3890
+ q8 += 8; a += 8;
3891
+ }
3892
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
3893
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
3894
+ }
3895
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
3896
+ *s = sumf;
3897
+ #endif
3898
+ }
3899
+
3900
+ #endif
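Finally, the QK_K == 64 q6_K fallback above builds each weight from a ql nibble plus two qh bits, minus 32. A stand-alone sketch of that expansion (illustrative helper name, not part of the package):

#include <stdint.h>

// Expand one QK_K == 64 q6_K block (32 ql bytes, 16 qh bytes) into 64 signed
// values in [-32, 31], exactly as the scalar fallback above does inline.
static void q6k64_expand(const uint8_t ql[32], const uint8_t qh[16], int8_t out[64]) {
    for (int l = 0; l < 16; ++l) {
        out[l +  0] = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
        out[l + 16] = (int8_t)((ql[l + 16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
        out[l + 32] = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
        out[l + 48] = (int8_t)((ql[l + 16] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
    }
}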