llama_cpp 0.2.2 → 0.3.1

@@ -261,6 +261,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
261
261
  return scale;
262
262
  }
263
263
 
264
+ #if QK_K == 256
264
265
  static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
265
266
  if (j < 4) {
266
267
  *d = q[j] & 63; *m = q[j + 4] & 63;
@@ -269,6 +270,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
269
270
  *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
270
271
  }
271
272
  }
273
+ #endif
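
The helper above only exists for the 256-element super-block, where eight 6-bit (scale, min) pairs are packed into a 12-byte array. Below is a minimal standalone sketch of that unpacking; the d-line for j >= 4 follows the same pattern as the m-line visible in the hunk, and the function name and example values are illustrative, not part of the diff.

    #include <stdint.h>
    #include <stdio.h>

    static void unpack_scale_min(int j, const uint8_t *q, uint8_t *d, uint8_t *m) {
        if (j < 4) {                      // pairs 0..3: low 6 bits of q[j] and q[j+4]
            *d = q[j] & 63;
            *m = q[j + 4] & 63;
        } else {                          // pairs 4..7: 4 low bits in q[j+4],
            *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);   // 2 high bits in q[j-4]
            *m = (q[j + 4] >>  4) | ((q[j    ] >> 6) << 4);   // 2 high bits in q[j]
        }
    }

    int main(void) {
        uint8_t q[12] = {0};
        // store scale = 45, min = 17 for sub-block j = 5 (hypothetical values)
        q[1] |= (45 >> 4) << 6;  q[9] |= 45 & 0xF;
        q[5] |= (17 >> 4) << 6;  q[9] |= (17 & 0xF) << 4;
        uint8_t d, m;
        unpack_scale_min(5, q, &d, &m);
        printf("scale=%d min=%d\n", d, m);   // expect: scale=45 min=17
        return 0;
    }
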
272
274
 
273
275
  //========================- 2-bit (de)-quantization
274
276
 
@@ -330,11 +332,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
330
332
  }
331
333
  }
332
334
 
335
+ #if QK_K == 256
333
336
  for (int j = 0; j < QK_K; j += 128) {
334
337
  for (int l = 0; l < 32; ++l) {
335
338
  y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
336
339
  }
337
340
  }
341
+ #else
342
+ for (int l = 0; l < 16; ++l) {
343
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
344
+ }
345
+ #endif
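
Both branches above pack four 2-bit quants per byte; only the stride differs (32 within each 128-quant chunk for QK_K == 256, 16 for QK_K == 64). A small round-trip check of the QK_K == 64 packing, as a sketch (standalone test, not part of the diff):

    #include <stdint.h>
    #include <assert.h>

    int main(void) {
        uint8_t L[64], qs[16];
        for (int i = 0; i < 64; ++i) L[i] = (uint8_t)(i & 3);   // toy 2-bit codes
        // pack exactly as the QK_K == 64 branch above does
        for (int l = 0; l < 16; ++l)
            qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
        // unpack and verify: quants l, l+16, l+32, l+48 share byte l
        for (int l = 0; l < 16; ++l) {
            assert(((qs[l] >> 0) & 3) == L[l]);
            assert(((qs[l] >> 2) & 3) == L[l + 16]);
            assert(((qs[l] >> 4) & 3) == L[l + 32]);
            assert(((qs[l] >> 6) & 3) == L[l + 48]);
        }
        return 0;
    }
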
338
346
 
339
347
  x += QK_K;
340
348
 
@@ -352,6 +360,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
352
360
 
353
361
  const uint8_t * q = x[i].qs;
354
362
 
363
+ #if QK_K == 256
355
364
  int is = 0;
356
365
  float dl, ml;
357
366
  for (int n = 0; n < QK_K; n += 128) {
@@ -370,7 +379,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
370
379
  }
371
380
  q += 32;
372
381
  }
373
-
382
+ #else
383
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
384
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
385
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
386
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
387
+ for (int l = 0; l < 16; ++l) {
388
+ y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
389
+ y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
390
+ y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
391
+ y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
392
+ }
393
+ y += QK_K;
394
+ #endif
374
395
  }
375
396
  }
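
In both branches the dequantization of a 16-quant sub-block is the same affine map: the low nibble of the scales byte scales the 2-bit code, the high nibble scales the shared minimum. A sketch of that map over already-unpacked codes (helper name is illustrative):

    #include <stdint.h>

    // y = d*(sc & 0xF) * q - dmin*(sc >> 4), applied to one 16-quant sub-block.
    // q2 is assumed to hold one unpacked 2-bit code per byte.
    static void dequant_q2_sub(float d, float dmin, uint8_t sc,
                               const uint8_t *q2, float *y) {
        const float dl = d    * (sc & 0xF);
        const float ml = dmin * (sc >> 4);
        for (int l = 0; l < 16; ++l) y[l] = dl * q2[l] - ml;
    }
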
376
397
 
@@ -412,6 +433,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
412
433
  }
413
434
  }
414
435
 
436
+ #if QK_K == 256
415
437
  memset(y[i].scales, 0, 12);
416
438
  if (max_scale) {
417
439
  float iscale = -32.f/max_scale;
@@ -445,9 +467,39 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
445
467
  L[16*j + ii] = l + 4;
446
468
  }
447
469
  }
470
+ #else
471
+ if (max_scale) {
472
+ float iscale = -8.f/max_scale;
473
+ for (int j = 0; j < QK_K/16; j+=2) {
474
+ int l1 = nearest_int(iscale*scales[j]);
475
+ l1 = 8 + MAX(-8, MIN(7, l1));
476
+ int l2 = nearest_int(iscale*scales[j+1]);
477
+ l2 = 8 + MAX(-8, MIN(7, l2));
478
+ y[i].scales[j/2] = l1 | (l2 << 4);
479
+ }
480
+ y[i].d = ggml_fp32_to_fp16(1/iscale);
481
+ } else {
482
+ for (int j = 0; j < QK_K/16; j+=2) {
483
+ y[i].scales[j/2] = 0;
484
+ }
485
+ y[i].d = ggml_fp32_to_fp16(0.f);
486
+ }
487
+ for (int j = 0; j < QK_K/16; ++j) {
488
+ int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
489
+ float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
490
+ if (!d) {
491
+ continue;
492
+ }
493
+ for (int ii = 0; ii < 16; ++ii) {
494
+ int l = nearest_int(x[16*j + ii]/d);
495
+ l = MAX(-4, MIN(3, l));
496
+ L[16*j + ii] = l + 4;
497
+ }
498
+ }
499
+ #endif
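
The QK_K == 64 branch above stores two sub-block scales per byte as 4-bit values biased by 8, so decoding is simply "(nibble) - 8" against the shared fp16 d. A tiny round-trip sketch (values are illustrative):

    #include <stdint.h>
    #include <assert.h>

    int main(void) {
        int s0 = -5, s1 = 3;                               // signed scales in [-8, 7]
        uint8_t packed = (uint8_t)((s0 + 8) | ((s1 + 8) << 4));
        assert(((packed & 0xF) - 8) == s0);                // decode as "(s - 8)" above
        assert(((packed >>  4) - 8) == s1);
        return 0;
    }
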
448
500
 
449
501
  memset(y[i].hmask, 0, QK_K/8);
450
- // We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc.
502
+ // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
451
503
  int m = 0;
452
504
  uint8_t hm = 1;
453
505
  for (int j = 0; j < QK_K; ++j) {
@@ -459,19 +511,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
459
511
  m = 0; hm <<= 1;
460
512
  }
461
513
  }
514
+ #if QK_K == 256
462
515
  for (int j = 0; j < QK_K; j += 128) {
463
516
  for (int l = 0; l < 32; ++l) {
464
517
  y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
465
518
  }
466
519
  }
520
+ #else
521
+ for (int l = 0; l < 16; ++l) {
522
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
523
+ }
524
+ #endif
467
525
 
468
526
  x += QK_K;
469
527
  }
470
528
  }
471
529
 
530
+ #if QK_K == 256
472
531
  void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
473
532
  assert(k % QK_K == 0);
474
- assert(QK_K == 256);
475
533
  const int nb = k / QK_K;
476
534
 
477
535
  const uint32_t kmask1 = 0x03030303;
@@ -519,6 +577,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
519
577
 
520
578
  }
521
579
  }
580
+ #else
581
+ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
582
+ assert(k % QK_K == 0);
583
+ assert(QK_K == 64);
584
+ const int nb = k / QK_K;
585
+
586
+ for (int i = 0; i < nb; i++) {
587
+
588
+ const float d_all = ggml_fp16_to_fp32(x[i].d);
589
+
590
+ const uint8_t * restrict q = x[i].qs;
591
+ const uint8_t * restrict hm = x[i].hmask;
592
+
593
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
594
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
595
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
596
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
597
+
598
+ for (int l=0; l<8; ++l) {
599
+ uint8_t h = hm[l];
600
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
601
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
602
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
603
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
604
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
605
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
606
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
607
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
608
+ }
609
+ y += QK_K;
610
+ }
611
+ }
612
+ #endif
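
For QK_K == 64 the dequantizer above addresses the packed data as: the 2 low bits of quant i sit in bit-pair i/16 of qs[i % 16], the high bit sits in bit i/8 of hmask[i % 8], and 4 is subtracted when that bit is clear. The same addressing written as one indexed helper (a sketch; the function name is illustrative):

    #include <stdint.h>

    // Reconstruct signed quant i (0 <= i < 64) of a QK_K == 64 q3_K block.
    static int unpack_q3_64(const uint8_t *qs, const uint8_t *hmask, int i) {
        const int lo = (qs[i % 16] >> (2 * (i / 16))) & 3;      // 2-bit low part
        const int hi = (hmask[i % 8] >> (i / 8)) & 1;           // 1-bit high part
        return lo - (hi ? 0 : 4);                               // range [-4, 3]
    }
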
522
613
 
523
614
  void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
524
615
  quantize_row_q3_K_reference(x, vy, k);
@@ -563,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
563
654
  }
564
655
  }
565
656
 
657
+ #if QK_K == 256
566
658
  float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
567
659
  float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
568
660
  for (int j = 0; j < QK_K/32; ++j) {
@@ -594,9 +686,43 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
594
686
  L[32*j + ii] = l;
595
687
  }
596
688
  }
689
+ #else
690
+ const float s_factor = 15.f;
691
+ float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
692
+ float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
693
+ int d1 = nearest_int(inv_scale*scales[0]);
694
+ int m1 = nearest_int(inv_min*mins[0]);
695
+ int d2 = nearest_int(inv_scale*scales[1]);
696
+ int m2 = nearest_int(inv_min*mins[1]);
697
+ y[i].scales[0] = d1 | (m1 << 4);
698
+ y[i].scales[1] = d2 | (m2 << 4);
699
+ y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
700
+ y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
701
+
702
+ float sumlx = 0;
703
+ int suml2 = 0;
704
+ for (int j = 0; j < QK_K/32; ++j) {
705
+ const uint8_t sd = y[i].scales[j] & 0xF;
706
+ const uint8_t sm = y[i].scales[j] >> 4;
707
+ const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
708
+ if (!d) continue;
709
+ const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
710
+ for (int ii = 0; ii < 32; ++ii) {
711
+ int l = nearest_int((x[32*j + ii] + m)/d);
712
+ l = MAX(0, MIN(15, l));
713
+ L[32*j + ii] = l;
714
+ sumlx += (x[32*j + ii] + m)*l*sd;
715
+ suml2 += l*l*sd*sd;
716
+ }
717
+ }
718
+ if (suml2) {
719
+ y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
720
+ }
721
+ #endif
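
The sumlx/suml2 update at the end of the QK_K == 64 branch is the closed-form least-squares choice of the super-block scale once the integer codes are fixed: minimizing the squared error between x + m and d*sd*l over d gives d = sum((x+m)*sd*l) / sum((sd*l)^2). A sketch of that refinement in isolation (names are illustrative; x + m is assumed precomputed):

    // xpm[i] = x[i] + m_i (value with its sub-block minimum added back),
    // l[i]   = chosen integer code, sd[i] = that value's 4-bit sub-scale.
    static float refine_d(const float *xpm, const int *l, const int *sd, int n) {
        float sumlx = 0.f;
        int   suml2 = 0;
        for (int i = 0; i < n; ++i) {
            sumlx += xpm[i] * l[i] * sd[i];
            suml2 += l[i] * l[i] * sd[i] * sd[i];
        }
        return suml2 ? sumlx / suml2 : 0.f;
    }
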
597
722
  uint8_t * q = y[i].qs;
598
723
  for (int j = 0; j < QK_K; j += 64) {
599
- for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4);
724
+ for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
725
+ q += 32;
600
726
  }
601
727
 
602
728
  x += QK_K;
@@ -610,11 +736,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
610
736
 
611
737
  for (int i = 0; i < nb; i++) {
612
738
 
613
- const float d = ggml_fp16_to_fp32(x[i].d);
614
- const float min = ggml_fp16_to_fp32(x[i].dmin);
615
-
616
739
  const uint8_t * q = x[i].qs;
617
740
 
741
+ #if QK_K == 256
742
+
743
+ const float d = ggml_fp16_to_fp32(x[i].d);
744
+ const float min = ggml_fp16_to_fp32(x[i].dmin);
745
+
618
746
  int is = 0;
619
747
  uint8_t sc, m;
620
748
  for (int j = 0; j < QK_K; j += 64) {
@@ -626,6 +754,17 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
626
754
  for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
627
755
  q += 32; is += 2;
628
756
  }
757
+ #else
758
+ const float dall = ggml_fp16_to_fp32(x[i].d[0]);
759
+ const float mall = ggml_fp16_to_fp32(x[i].d[1]);
760
+ const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
761
+ const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
762
+ for (int l = 0; l < 32; ++l) {
763
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
764
+ y[l+32] = d2 * (q[l] >> 4) - m2;
765
+ }
766
+ y += QK_K;
767
+ #endif
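
The QK_K == 64 q4_K layout read above is a single pass of nibbles: byte l of qs holds quant l in its low nibble and quant l + 32 in its high nibble. Round-trip sketch with toy values (not part of the diff):

    #include <stdint.h>
    #include <assert.h>

    int main(void) {
        uint8_t L[64], qs[32];
        for (int i = 0; i < 64; ++i) L[i] = (uint8_t)((i * 7) & 0xF);  // toy 4-bit codes
        for (int l = 0; l < 32; ++l) qs[l] = L[l] | (L[l + 32] << 4);
        for (int l = 0; l < 32; ++l) {
            assert((qs[l] & 0xF) == L[l]);        // what the y[l+ 0] line reads
            assert((qs[l] >>  4) == L[l + 32]);   // what the y[l+32] line reads
        }
        return 0;
    }
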
629
768
 
630
769
  }
631
770
  }
@@ -653,12 +792,19 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
653
792
  assert(k % QK_K == 0);
654
793
  const int nb = k / QK_K;
655
794
 
795
+ #if QK_K == 256
656
796
  uint8_t L[QK_K];
657
797
  float mins[QK_K/32];
658
798
  float scales[QK_K/32];
799
+ #else
800
+ int8_t L[QK_K];
801
+ float scales[QK_K/16];
802
+ #endif
659
803
 
660
804
  for (int i = 0; i < nb; i++) {
661
805
 
806
+ #if QK_K == 256
807
+
662
808
  float max_scale = 0; // as we are deducting the min, scales are always positive
663
809
  float max_min = 0;
664
810
  for (int j = 0; j < QK_K/32; ++j) {
@@ -725,6 +871,52 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
725
871
  m1 <<= 2; m2 <<= 2;
726
872
  ql += 32;
727
873
  }
874
+ #else
875
+ float max_scale = 0, amax = 0;
876
+ for (int j = 0; j < QK_K/16; ++j) {
877
+ scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
878
+ float abs_scale = fabsf(scales[j]);
879
+ if (abs_scale > amax) {
880
+ amax = abs_scale;
881
+ max_scale = scales[j];
882
+ }
883
+ }
884
+
885
+ float iscale = -128.f/max_scale;
886
+ for (int j = 0; j < QK_K/16; ++j) {
887
+ int l = nearest_int(iscale*scales[j]);
888
+ y[i].scales[j] = MAX(-128, MIN(127, l));
889
+ }
890
+ y[i].d = ggml_fp32_to_fp16(1/iscale);
891
+
892
+ for (int j = 0; j < QK_K/16; ++j) {
893
+ const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
894
+ if (!d) continue;
895
+ for (int ii = 0; ii < 16; ++ii) {
896
+ int l = nearest_int(x[16*j + ii]/d);
897
+ l = MAX(-16, MIN(15, l));
898
+ L[16*j + ii] = l + 16;
899
+ }
900
+ }
901
+
902
+ uint8_t * restrict qh = y[i].qh;
903
+ uint8_t * restrict ql = y[i].qs;
904
+ memset(qh, 0, QK_K/8);
905
+
906
+ for (int j = 0; j < 32; ++j) {
907
+ int jm = j%8;
908
+ int is = j/8;
909
+ int l1 = L[j];
910
+ if (l1 > 15) {
911
+ l1 -= 16; qh[jm] |= (1 << is);
912
+ }
913
+ int l2 = L[j + 32];
914
+ if (l2 > 15) {
915
+ l2 -= 16; qh[jm] |= (1 << (4 + is));
916
+ }
917
+ ql[j] = l1 | (l2 << 4);
918
+ }
919
+ #endif
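
In the QK_K == 64 q5_K branch a 5-bit code v in [0, 31] is split into a 4-bit nibble (ql) and a single high bit (qh); on the way back the dequantizer subtracts 16 only when the high bit is clear, which recovers v - 16. A standalone sketch of that identity (not part of the diff):

    #include <stdint.h>
    #include <assert.h>

    int main(void) {
        for (int v = 0; v < 32; ++v) {
            uint8_t nib = (uint8_t)(v & 0xF);
            int     hi  = v > 15;                     // the "if (l1 > 15)" rule above
            int dequant = (int)nib - (hi ? 0 : 16);   // mirrors the dequant branch
            assert(dequant == v - 16);                // signed value is recovered
        }
        return 0;
    }
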
728
920
 
729
921
  x += QK_K;
730
922
 
@@ -737,12 +929,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
737
929
 
738
930
  for (int i = 0; i < nb; i++) {
739
931
 
740
- const float d = ggml_fp16_to_fp32(x[i].d);
741
- const float min = ggml_fp16_to_fp32(x[i].dmin);
742
-
743
932
  const uint8_t * ql = x[i].qs;
744
933
  const uint8_t * qh = x[i].qh;
745
934
 
935
+ #if QK_K == 256
936
+
937
+ const float d = ggml_fp16_to_fp32(x[i].d);
938
+ const float min = ggml_fp16_to_fp32(x[i].dmin);
939
+
746
940
  int is = 0;
747
941
  uint8_t sc, m;
748
942
  uint8_t u1 = 1, u2 = 2;
@@ -756,6 +950,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
756
950
  ql += 32; is += 2;
757
951
  u1 <<= 2; u2 <<= 2;
758
952
  }
953
+ #else
954
+ float d = ggml_fp16_to_fp32(x[i].d);
955
+ const int8_t * restrict s = x[i].scales;
956
+ for (int l = 0; l < 8; ++l) {
957
+ y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
958
+ y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
959
+ y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
960
+ y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
961
+ y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
962
+ y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
963
+ y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
964
+ y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
965
+ }
966
+ y += QK_K;
967
+ #endif
759
968
  }
760
969
  }
761
970
 
@@ -823,6 +1032,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
823
1032
 
824
1033
  uint8_t * restrict ql = y[i].ql;
825
1034
  uint8_t * restrict qh = y[i].qh;
1035
+ #if QK_K == 256
826
1036
  for (int j = 0; j < QK_K; j += 128) {
827
1037
  for (int l = 0; l < 32; ++l) {
828
1038
  const uint8_t q1 = L[j + l + 0] & 0xF;
@@ -836,6 +1046,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
836
1046
  ql += 64;
837
1047
  qh += 32;
838
1048
  }
1049
+ #else
1050
+ for (int l = 0; l < 32; ++l) {
1051
+ const uint8_t q1 = L[l + 0] & 0xF;
1052
+ const uint8_t q2 = L[l + 32] & 0xF;
1053
+ ql[l] = q1 | (q2 << 4);
1054
+ }
1055
+ for (int l = 0; l < 16; ++l) {
1056
+ qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
1057
+ }
1058
+ #endif
839
1059
 
840
1060
  x += QK_K;
841
1061
 
@@ -854,6 +1074,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
854
1074
  const uint8_t * restrict qh = x[i].qh;
855
1075
  const int8_t * restrict sc = x[i].scales;
856
1076
 
1077
+ #if QK_K == 256
857
1078
  for (int n = 0; n < QK_K; n += 128) {
858
1079
  for (int l = 0; l < 32; ++l) {
859
1080
  int is = l/16;
@@ -871,6 +1092,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
871
1092
  qh += 32;
872
1093
  sc += 8;
873
1094
  }
1095
+ #else
1096
+ for (int l = 0; l < 16; ++l) {
1097
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1098
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1099
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1100
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1101
+ y[l+ 0] = d * sc[0] * q1;
1102
+ y[l+16] = d * sc[1] * q2;
1103
+ y[l+32] = d * sc[2] * q3;
1104
+ y[l+48] = d * sc[3] * q4;
1105
+ }
1106
+ y += 64;
1107
+ #endif
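
For QK_K == 64, q6_K keeps the 4 low bits of quant i in a nibble of ql (low nibbles for i < 32, high nibbles for i >= 32) and its 2 high bits in bit-pair i/16 of qh[i % 16], with a bias of 32. The same addressing as one helper (sketch; the name is illustrative):

    #include <stdint.h>

    // Reconstruct signed quant i (0 <= i < 64) of a QK_K == 64 q6_K block.
    static int unpack_q6_64(const uint8_t *ql, const uint8_t *qh, int i) {
        const int lo = (i < 32) ? (ql[i] & 0xF) : (ql[i - 32] >> 4);
        const int hi = (qh[i % 16] >> (2 * (i / 16))) & 3;
        return ((hi << 4) | lo) - 32;                 // range [-32, 31]
    }
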
874
1108
 
875
1109
  }
876
1110
  }
@@ -1002,6 +1236,7 @@ static inline __m128i get_scale_shuffle(int i) {
1002
1236
  }
1003
1237
  #endif
1004
1238
 
1239
+ #if QK_K == 256
1005
1240
  void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1006
1241
 
1007
1242
  const block_q2_K * restrict x = vx;
@@ -1158,6 +1393,112 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1158
1393
 
1159
1394
  *s = hsum_float_8(acc);
1160
1395
 
1396
+ #elif defined __AVX__
1397
+
1398
+ const __m128i m3 = _mm_set1_epi8(0x3);
1399
+ const __m128i m4 = _mm_set1_epi8(0xF);
1400
+ const __m128i m2 = _mm_set1_epi8(0x2);
1401
+
1402
+ __m256 acc = _mm256_setzero_ps();
1403
+
1404
+ for (int i = 0; i < nb; ++i) {
1405
+
1406
+ const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
1407
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1408
+
1409
+ const uint8_t * restrict q2 = x[i].qs;
1410
+ const int8_t * restrict q8 = y[i].qs;
1411
+
1412
+ // load mins and scales from block_q2_K.scales[QK_K/16]
1413
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1414
+ const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
1415
+ const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1416
+ const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
1417
+ const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
1418
+
1419
+ // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
1420
+ const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
1421
+ const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
1422
+
1423
+ // sumf += -dmin * summs in 32bits*8
1424
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
1425
+
1426
+ const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
1427
+ const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
1428
+ const __m128i scales[2] = { scales_0, scales_1 };
1429
+
1430
+ __m128i sumi_0 = _mm_setzero_si128();
1431
+ __m128i sumi_1 = _mm_setzero_si128();
1432
+
1433
+ for (int j = 0; j < QK_K/128; ++j) {
1434
+
1435
+ // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
1436
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1437
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1438
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1439
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1440
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1441
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1442
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1443
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1444
+
1445
+ // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
1446
+ __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1447
+ const __m128i q2_0 = _mm_and_si128(q2bits, m3);
1448
+ const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1449
+ const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1450
+ const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1451
+ q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1452
+ const __m128i q2_1 = _mm_and_si128(q2bits, m3);
1453
+ const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1454
+ const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1455
+ const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1456
+
1457
+ // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
1458
+ __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
1459
+ __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
1460
+ __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
1461
+ __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
1462
+ __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
1463
+ __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
1464
+ __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
1465
+ __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
1466
+
1467
+ // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
1468
+ __m128i shuffle = _mm_set1_epi16(0x0100);
1469
+ p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
1470
+ shuffle = _mm_add_epi16(shuffle, m2);
1471
+ p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
1472
+ shuffle = _mm_add_epi16(shuffle, m2);
1473
+ p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
1474
+ shuffle = _mm_add_epi16(shuffle, m2);
1475
+ p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
1476
+ shuffle = _mm_add_epi16(shuffle, m2);
1477
+ p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
1478
+ shuffle = _mm_add_epi16(shuffle, m2);
1479
+ p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
1480
+ shuffle = _mm_add_epi16(shuffle, m2);
1481
+ p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
1482
+ shuffle = _mm_add_epi16(shuffle, m2);
1483
+ p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
1484
+
1485
+ p0 = _mm_add_epi32(p0, p1);
1486
+ p2 = _mm_add_epi32(p2, p3);
1487
+ p4 = _mm_add_epi32(p4, p5);
1488
+ p6 = _mm_add_epi32(p6, p7);
1489
+
1490
+ // isum in 32bits*4*2
1491
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
1492
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
1493
+ }
1494
+
1495
+ // sumf += dall * isum - dmin * summs in 32bits
1496
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
1497
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
1498
+ }
1499
+
1500
+ *s = hsum_float_8(acc);
1501
+
1161
1502
  #else
1162
1503
 
1163
1504
  float sumf = 0;
@@ -1201,6 +1542,168 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1201
1542
  #endif
1202
1543
  }
1203
1544
 
1545
+ #else
1546
+
1547
+ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1548
+
1549
+ const block_q2_K * restrict x = vx;
1550
+ const block_q8_K * restrict y = vy;
1551
+
1552
+ const int nb = n / QK_K;
1553
+
1554
+ #ifdef __ARM_NEON
1555
+
1556
+ const uint8x16_t m3 = vdupq_n_u8(0x3);
1557
+ const int32x4_t vzero = vdupq_n_s32(0);
1558
+
1559
+ int8x16x4_t q2bytes;
1560
+
1561
+ uint32_t aux32[2];
1562
+ const uint8_t * scales = (const uint8_t *)aux32;
1563
+
1564
+ float sum = 0;
1565
+
1566
+ for (int i = 0; i < nb; ++i) {
1567
+
1568
+ const float d = y[i].d * (float)x[i].d;
1569
+ const float dmin = -y[i].d * (float)x[i].dmin;
1570
+
1571
+ const uint8_t * restrict q2 = x[i].qs;
1572
+ const int8_t * restrict q8 = y[i].qs;
1573
+ const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
1574
+
1575
+ aux32[0] = sc[0] & 0x0f0f0f0f;
1576
+ aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
1577
+
1578
+ sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
1579
+
1580
+ int isum1 = 0, isum2 = 0;
1581
+
1582
+ const uint8x16_t q2bits = vld1q_u8(q2);
1583
+
1584
+ const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
1585
+
1586
+ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
1587
+ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
1588
+ q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
1589
+ q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
1590
+
1591
+ #if defined(__ARM_FEATURE_DOTPROD)
1592
+ isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
1593
+ isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
1594
+ isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
1595
+ isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
1596
+ #else
1597
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
1598
+ vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
1599
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
1600
+ vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
1601
+ isum1 += vaddvq_s16(p1) * scales[0];
1602
+ isum2 += vaddvq_s16(p2) * scales[1];
1603
+
1604
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
1605
+ vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
1606
+ const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
1607
+ vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
1608
+ isum1 += vaddvq_s16(p3) * scales[2];
1609
+ isum2 += vaddvq_s16(p4) * scales[3];
1610
+ #endif
1611
+ sum += d * (isum1 + isum2);
1612
+
1613
+ }
1614
+
1615
+ *s = sum;
1616
+
1617
+ #elif defined __AVX2__
1618
+
1619
+ const __m256i m3 = _mm256_set1_epi8(3);
1620
+
1621
+ __m256 acc = _mm256_setzero_ps();
1622
+
1623
+ uint32_t ud, um;
1624
+ const uint8_t * restrict db = (const uint8_t *)&ud;
1625
+ const uint8_t * restrict mb = (const uint8_t *)&um;
1626
+
1627
+ float summs = 0;
1628
+
1629
+ // TODO: optimize this
1630
+
1631
+ for (int i = 0; i < nb; ++i) {
1632
+
1633
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1634
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1635
+
1636
+ const uint8_t * restrict q2 = x[i].qs;
1637
+ const int8_t * restrict q8 = y[i].qs;
1638
+
1639
+ const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
1640
+ ud = (sc[0] >> 0) & 0x0f0f0f0f;
1641
+ um = (sc[0] >> 4) & 0x0f0f0f0f;
1642
+
1643
+ int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
1644
+ summs += dmin * smin;
1645
+
1646
+ const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
1647
+ const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
1648
+ const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
1649
+
1650
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
1651
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
1652
+
1653
+ const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
1654
+ const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
1655
+
1656
+ const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
1657
+ const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
1658
+ const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
1659
+ const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
1660
+
1661
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
1662
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
1663
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
1664
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
1665
+ }
1666
+
1667
+ *s = hsum_float_8(acc) + summs;
1668
+
1669
+ #else
1670
+
1671
+ float sumf = 0;
1672
+
1673
+ int isum[4];
1674
+
1675
+ for (int i = 0; i < nb; ++i) {
1676
+
1677
+ const uint8_t * q2 = x[i].qs;
1678
+ const int8_t * q8 = y[i].qs;
1679
+ const uint8_t * sc = x[i].scales;
1680
+
1681
+ int summs = 0;
1682
+ for (int j = 0; j < QK_K/16; ++j) {
1683
+ summs += y[i].bsums[j] * (sc[j] >> 4);
1684
+ }
1685
+
1686
+ const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
1687
+ const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1688
+
1689
+ isum[0] = isum[1] = isum[2] = isum[3] = 0;
1690
+ for (int l = 0; l < 16; ++l) {
1691
+ isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
1692
+ isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
1693
+ isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
1694
+ isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
1695
+ }
1696
+ for (int l = 0; l < 4; ++l) {
1697
+ isum[l] *= (sc[l] & 0xF);
1698
+ }
1699
+ sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
1700
+ }
1701
+ *s = sumf;
1702
+ #endif
1703
+ }
1704
+ #endif
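
All of the dot-product kernels above use the same decomposition to pull the per-sub-block minimum out of the inner loop: each group of 16 quants shares one min, so the sum of q8[l]*min over that group equals min * bsums[group], where bsums are the per-16 partial sums stored in block_q8_K. The scalar form of that identity for one QK_K == 64 q2_K block, as a sketch (inputs assumed unpacked, function name illustrative):

    #include <stdint.h>

    // d, dmin: block scale factors; sc[j] packs (scale | min << 4) per 16 quants.
    static float dot_q2_block_64(float d, float dmin, const uint8_t *sc,
                                 const uint8_t *q2 /* 64 unpacked 2-bit codes */,
                                 const int8_t  *q8 /* 64 int8 activations     */,
                                 const int16_t *bsums /* 4 sums of 16 q8 each */) {
        int isum = 0, msum = 0;
        for (int j = 0; j < 4; ++j) {
            int p = 0;
            for (int l = 0; l < 16; ++l) p += q8[16*j + l] * (int)q2[16*j + l];
            isum += (sc[j] & 0xF) * p;          // quant term, scaled per sub-block
            msum += (sc[j] >> 4) * bsums[j];    // min term, folded via bsums
        }
        return d * isum - dmin * msum;
    }
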
1705
+
1706
+ #if QK_K == 256
1204
1707
  void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1205
1708
  assert(n % QK_K == 0);
1206
1709
 
@@ -1434,34 +1937,176 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
1434
1937
 
1435
1938
  *s = hsum_float_8(acc);
1436
1939
 
1437
- #else
1438
- // scalar version
1439
- // This function is written like this so the compiler can manage to vectorize most of it
1440
- // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
1441
- // manually vectorized version above. Every other version I tried would run at least 4 times slower.
1442
- // The ideal situation would be if we could just write the code once, and the compiler would
1443
- // automatically produce the best possible set of machine instructions, instead of us having to manually
1444
- // write vectorized versions for AVX, ARM_NEON, etc.
1940
+ #elif defined __AVX__
1445
1941
 
1446
- int8_t aux8[QK_K];
1447
- int16_t aux16[8];
1448
- float sums [8];
1449
- int32_t aux32[8];
1450
- memset(sums, 0, 8*sizeof(float));
1942
+ const __m128i m3 = _mm_set1_epi8(3);
1943
+ const __m128i mone = _mm_set1_epi8(1);
1944
+ const __m128i m32 = _mm_set1_epi8(32);
1945
+ const __m128i m2 = _mm_set1_epi8(2);
1451
1946
 
1452
- uint32_t auxs[4];
1453
- const int8_t * scales = (const int8_t*)auxs;
1947
+ __m256 acc = _mm256_setzero_ps();
1948
+
1949
+ uint32_t *aux;
1454
1950
 
1455
- float sumf = 0;
1456
1951
  for (int i = 0; i < nb; ++i) {
1952
+
1953
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1954
+
1457
1955
  const uint8_t * restrict q3 = x[i].qs;
1458
- const uint8_t * restrict hm = x[i].hmask;
1459
- const int8_t * restrict q8 = y[i].qs;
1460
- memset(aux32, 0, 8*sizeof(int32_t));
1461
- int8_t * restrict a = aux8;
1462
- uint8_t m = 1;
1463
- for (int j = 0; j < QK_K; j += 128) {
1464
- for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
1956
+ const int8_t * restrict q8 = y[i].qs;
1957
+
1958
+ // Set up scales
1959
+ aux = (uint32_t *)x[i].scales;
1960
+ __m128i scales128 = _mm_set_epi32(
1961
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1962
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1963
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1964
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1965
+ scales128 = _mm_sub_epi8(scales128, m32);
1966
+ const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
1967
+ const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
1968
+ const __m128i scales[2] = { scales_0, scales_1 };
1969
+
1970
+ // high bit *128*2 from block_q3_K.hmask[QK_K/8]
1971
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
1972
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
1973
+
1974
+ // integer accumulator
1975
+ __m128i sumi_0 = _mm_setzero_si128();
1976
+ __m128i sumi_1 = _mm_setzero_si128();
1977
+
1978
+ for (int j = 0; j < QK_K/128; ++j) {
1979
+ // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
1980
+ const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1981
+ const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1982
+
1983
+ // prepare low and high bits
1984
+ const int bit = j << 2;
1985
+
1986
+ const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
1987
+ const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
1988
+ const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
1989
+ const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
1990
+
1991
+ const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
1992
+ const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
1993
+ const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1994
+ const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1995
+
1996
+ const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
1997
+ const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
1998
+ const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1999
+ const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
2000
+
2001
+ const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
2002
+ const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
2003
+ const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
2004
+ const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
2005
+
2006
+ // load Q8 quants from block_q8_K.qs[QK_K]
2007
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2008
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2009
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2010
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2011
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2012
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2013
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2014
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2015
+
2016
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
2017
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
2018
+ // and 2 if the high bit was set)
2019
+ __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
2020
+ __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
2021
+ __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
2022
+ __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
2023
+ __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
2024
+ __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
2025
+ __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
2026
+ __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
2027
+
2028
+ __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
2029
+ __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
2030
+ __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
2031
+ __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
2032
+ __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
2033
+ __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
2034
+ __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
2035
+ __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
2036
+
2037
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
2038
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
2039
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
2040
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
2041
+ p16_4 = _mm_sub_epi16(p16_4, q8s_4);
2042
+ p16_5 = _mm_sub_epi16(p16_5, q8s_5);
2043
+ p16_6 = _mm_sub_epi16(p16_6, q8s_6);
2044
+ p16_7 = _mm_sub_epi16(p16_7, q8s_7);
2045
+
2046
+ // multiply with scales
2047
+ __m128i shuffle = _mm_set1_epi16(0x0100);
2048
+ p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
2049
+ shuffle = _mm_add_epi16(shuffle, m2);
2050
+ p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
2051
+ shuffle = _mm_add_epi16(shuffle, m2);
2052
+ p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
2053
+ shuffle = _mm_add_epi16(shuffle, m2);
2054
+ p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
2055
+ shuffle = _mm_add_epi16(shuffle, m2);
2056
+ p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
2057
+ shuffle = _mm_add_epi16(shuffle, m2);
2058
+ p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
2059
+ shuffle = _mm_add_epi16(shuffle, m2);
2060
+ p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
2061
+ shuffle = _mm_add_epi16(shuffle, m2);
2062
+ p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
2063
+
2064
+ // accumulate
2065
+ p16_0 = _mm_add_epi32(p16_0, p16_1);
2066
+ p16_2 = _mm_add_epi32(p16_2, p16_3);
2067
+ p16_4 = _mm_add_epi32(p16_4, p16_5);
2068
+ p16_6 = _mm_add_epi32(p16_6, p16_7);
2069
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2070
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
2071
+
2072
+ }
2073
+
2074
+ // multiply with block scale and accumulate
2075
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
2076
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
2077
+
2078
+ }
2079
+
2080
+ *s = hsum_float_8(acc);
2081
+
2082
+ #else
2083
+ // scalar version
2084
+ // This function is written like this so the compiler can manage to vectorize most of it
2085
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
2086
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
2087
+ // The ideal situation would be if we could just write the code once, and the compiler would
2088
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
2089
+ // write vectorized versions for AVX, ARM_NEON, etc.
2090
+
2091
+ int8_t aux8[QK_K];
2092
+ int16_t aux16[8];
2093
+ float sums [8];
2094
+ int32_t aux32[8];
2095
+ memset(sums, 0, 8*sizeof(float));
2096
+
2097
+ uint32_t auxs[4];
2098
+ const int8_t * scales = (const int8_t*)auxs;
2099
+
2100
+ float sumf = 0;
2101
+ for (int i = 0; i < nb; ++i) {
2102
+ const uint8_t * restrict q3 = x[i].qs;
2103
+ const uint8_t * restrict hm = x[i].hmask;
2104
+ const int8_t * restrict q8 = y[i].qs;
2105
+ memset(aux32, 0, 8*sizeof(int32_t));
2106
+ int8_t * restrict a = aux8;
2107
+ uint8_t m = 1;
2108
+ for (int j = 0; j < QK_K; j += 128) {
2109
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
1465
2110
  for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1466
2111
  a += 32; m <<= 1;
1467
2112
  for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
@@ -1501,6 +2146,206 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
1501
2146
 
1502
2147
  }
1503
2148
 
2149
+ #else
2150
+
2151
+ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2152
+ assert(n % QK_K == 0);
2153
+
2154
+ const block_q3_K * restrict x = vx;
2155
+ const block_q8_K * restrict y = vy;
2156
+
2157
+ const int nb = n / QK_K;
2158
+
2159
+ #ifdef __ARM_NEON
2160
+
2161
+ #ifdef __ARM_FEATURE_DOTPROD
2162
+ const int32x4_t vzero = vdupq_n_s32(0);
2163
+ #endif
2164
+
2165
+ const uint8x16_t m3b = vdupq_n_u8(0x3);
2166
+ const uint8x16_t mh = vdupq_n_u8(4);
2167
+
2168
+ int8x16x4_t q3bytes;
2169
+
2170
+ uint16_t aux16[2];
2171
+ int8_t * scales = (int8_t *)aux16;
2172
+
2173
+ float sum = 0;
2174
+
2175
+ for (int i = 0; i < nb; ++i) {
2176
+
2177
+ uint8x16x4_t q3h;
2178
+
2179
+ const uint8x8_t hbits = vld1_u8(x[i].hmask);
2180
+ const uint8x16_t q3bits = vld1q_u8(x[i].qs);
2181
+ const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
2182
+
2183
+ const uint16_t a = *(const uint16_t *)x[i].scales;
2184
+ aux16[0] = a & 0x0f0f;
2185
+ aux16[1] = (a >> 4) & 0x0f0f;
2186
+
2187
+ for (int j = 0; j < 4; ++j) scales[j] -= 8;
2188
+
2189
+ int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
2190
+
2191
+ const float d = y[i].d * (float)x[i].d;
2192
+
2193
+ const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
2194
+ q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
2195
+ q3h.val[1] = vandq_u8(mh, htmp);
2196
+ q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
2197
+ q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
2198
+
2199
+ q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
2200
+ q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
2201
+ q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
2202
+ q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
2203
+
2204
+ #if defined(__ARM_FEATURE_DOTPROD)
2205
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
2206
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
2207
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
2208
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
2209
+ #else
2210
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
2211
+ vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
2212
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
2213
+ vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
2214
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
2215
+ vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
2216
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
2217
+ vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
2218
+ isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
2219
+ #endif
2220
+
2221
+ sum += d * isum;
2222
+
2223
+ }
2224
+
2225
+ *s = sum;
2226
+
2227
+ #elif defined __AVX2__
2228
+
2229
+ const __m256i m3 = _mm256_set1_epi8(3);
2230
+ const __m256i m1 = _mm256_set1_epi8(1);
2231
+
2232
+ __m256 acc = _mm256_setzero_ps();
2233
+
2234
+ uint64_t aux64;
2235
+
2236
+ uint16_t aux16[2];
2237
+ const int8_t * aux8 = (const int8_t *)aux16;
2238
+
2239
+ for (int i = 0; i < nb; ++i) {
2240
+
2241
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
2242
+
2243
+ const uint8_t * restrict q3 = x[i].qs;
2244
+ const int8_t * restrict q8 = y[i].qs;
2245
+
2246
+ const uint16_t a = *(const uint16_t *)x[i].scales;
2247
+ aux16[0] = a & 0x0f0f;
2248
+ aux16[1] = (a >> 4) & 0x0f0f;
2249
+
2250
+ const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
2251
+ const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
2252
+
2253
+ memcpy(&aux64, x[i].hmask, 8);
2254
+
2255
+ const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
2256
+ __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
2257
+ __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
2258
+ q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
2259
+ q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
2260
+
2261
+ // load low 2 bits
2262
+ const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
2263
+
2264
+ // prepare low and high bits
2265
+ const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
2266
+ const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
2267
+ const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
2268
+
2269
+ // load Q8 quants
2270
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
2271
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
2272
+
2273
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
2274
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
2275
+ // and 2 if the high bit was set)
2276
+ const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
2277
+ const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
2278
+
2279
+ __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
2280
+ __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
2281
+
2282
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
2283
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
2284
+
2285
+ // multiply with scales
2286
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
2287
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
2288
+
2289
+ p16_0 = _mm256_add_epi32(p16_0, p16_1);
2290
+
2291
+ // multiply with block scale and accumulate
2292
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
2293
+
2294
+ }
2295
+
2296
+ *s = hsum_float_8(acc);
2297
+
2298
+ #else
2299
+
2300
+ int8_t aux8[QK_K];
2301
+ int16_t aux16[8];
2302
+ float sums [8];
2303
+ int32_t aux32[8];
2304
+ int32_t scales[4];
2305
+ memset(sums, 0, 8*sizeof(float));
2306
+
2307
+ float sumf = 0;
2308
+ for (int i = 0; i < nb; ++i) {
2309
+ const uint8_t * restrict q3 = x[i].qs;
2310
+ const uint8_t * restrict hm = x[i].hmask;
2311
+ const int8_t * restrict q8 = y[i].qs;
2312
+ int8_t * restrict a = aux8;
2313
+ for (int l = 0; l < 8; ++l) {
2314
+ a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
2315
+ a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
2316
+ a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
2317
+ a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
2318
+ a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
2319
+ a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
2320
+ a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
2321
+ a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
2322
+ }
2323
+
2324
+ scales[0] = (x[i].scales[0] & 0xF) - 8;
2325
+ scales[1] = (x[i].scales[0] >> 4) - 8;
2326
+ scales[2] = (x[i].scales[1] & 0xF) - 8;
2327
+ scales[3] = (x[i].scales[1] >> 4) - 8;
2328
+
2329
+ memset(aux32, 0, 8*sizeof(int32_t));
2330
+ for (int j = 0; j < QK_K/16; ++j) {
2331
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2332
+ q8 += 8; a += 8;
2333
+ for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
2334
+ q8 += 8; a += 8;
2335
+ for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
2336
+ }
2337
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
2338
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2339
+ }
2340
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2341
+ *s = sumf;
2342
+
2343
+ #endif
2344
+
2345
+ }
2346
+ #endif
2347
+
2348
+ #if QK_K == 256
1504
2349
  void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1505
2350
  assert(n % QK_K == 0);
1506
2351
 
@@ -1614,9 +2459,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
1614
2459
  const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1615
2460
  const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1616
2461
 
1617
- const uint8_t * restrict q4 = x[i].qs;
1618
- const int8_t * restrict q8 = y[i].qs;
1619
-
1620
2462
  memcpy(utmp, x[i].scales, 12);
1621
2463
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1622
2464
  const uint32_t uaux = utmp[1] & kmask1;
@@ -1624,6 +2466,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
1624
2466
  utmp[2] = uaux;
1625
2467
  utmp[0] &= kmask1;
1626
2468
 
2469
+ const uint8_t * restrict q4 = x[i].qs;
2470
+ const int8_t * restrict q8 = y[i].qs;
2471
+
1627
2472
  const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1628
2473
 
1629
2474
  const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
@@ -1667,6 +2512,88 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
1667
2512
 
1668
2513
  *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1669
2514
 
2515
+ #elif defined __AVX__
2516
+
2517
+ const __m128i m4 = _mm_set1_epi8(0xF);
2518
+ const __m128i m2 = _mm_set1_epi8(0x2);
2519
+
2520
+ __m256 acc = _mm256_setzero_ps();
2521
+ __m128 acc_m = _mm_setzero_ps();
2522
+
2523
+ for (int i = 0; i < nb; ++i) {
2524
+
2525
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
2526
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
2527
+
2528
+ const uint8_t * restrict q4 = x[i].qs;
2529
+ const int8_t * restrict q8 = y[i].qs;
2530
+
2531
+ memcpy(utmp, x[i].scales, 12);
2532
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2533
+ const uint32_t uaux = utmp[1] & kmask1;
2534
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2535
+ utmp[2] = uaux;
2536
+ utmp[0] &= kmask1;
2537
+
2538
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
2539
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
2540
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
2541
+
2542
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
2543
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
2544
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
2545
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
2546
+ acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
2547
+
2548
+ __m128i sumi_0 = _mm_setzero_si128();
2549
+ __m128i sumi_1 = _mm_setzero_si128();
2550
+
2551
+ __m128i shuffle = _mm_set1_epi16(0x0100);
2552
+ for (int j = 0; j < QK_K/64; ++j) {
2553
+
2554
+ const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
2555
+ shuffle = _mm_add_epi16(shuffle, m2);
2556
+ const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
2557
+ shuffle = _mm_add_epi16(shuffle, m2);
2558
+
2559
+ __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2560
+ const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
2561
+ const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
2562
+ q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2563
+ const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
2564
+ const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
2565
+
2566
+ const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2567
+ __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
2568
+ p16l = _mm_madd_epi16(scale_l, p16l);
2569
+ sumi_0 = _mm_add_epi32(sumi_0, p16l);
2570
+ const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2571
+ p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
2572
+ p16l = _mm_madd_epi16(scale_l, p16l);
2573
+ sumi_1 = _mm_add_epi32(sumi_1, p16l);
2574
+
2575
+ const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2576
+ __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
2577
+ p16h = _mm_madd_epi16(scale_h, p16h);
2578
+ sumi_0 = _mm_add_epi32(sumi_0, p16h);
2579
+ const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2580
+ p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
2581
+ p16h = _mm_madd_epi16(scale_h, p16h);
2582
+ sumi_1 = _mm_add_epi32(sumi_1, p16h);
2583
+
2584
+ }
2585
+
2586
+ __m256 vd = _mm256_set1_ps(d);
2587
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
2588
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
2589
+
2590
+ }
2591
+
2592
+ acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
2593
+ acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
2594
+
2595
+ *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
2596
+
1670
2597
  #else
1671
2598
 
1672
2599
 
@@ -1726,7 +2653,176 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
1726
2653
  *s = sumf;
1727
2654
  #endif
1728
2655
  }
2656
+ #else
2657
+ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2658
+ assert(n % QK_K == 0);
2659
+
2660
+ const block_q4_K * restrict x = vx;
2661
+ const block_q8_K * restrict y = vy;
2662
+
2663
+ const int nb = n / QK_K;
2664
+
2665
+ #ifdef __ARM_NEON
2666
+
2667
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2668
+
2669
+ #ifdef __ARM_FEATURE_DOTPROD
2670
+ const int32x4_t mzero = vdupq_n_s32(0);
2671
+ #endif
2672
+
2673
+ float sumf = 0;
2674
+
2675
+ int8x16x2_t q4bytes;
2676
+ int8x16x4_t q8bytes;
2677
+
2678
+ float sum_mins = 0.f;
2679
+
2680
+ uint16_t aux16[2];
2681
+ const uint8_t * restrict scales = (const uint8_t *)aux16;
2682
+
2683
+ for (int i = 0; i < nb; ++i) {
2684
+
2685
+ const uint8_t * restrict q4 = x[i].qs;
2686
+ const int8_t * restrict q8 = y[i].qs;
2687
+
2688
+ const uint16_t * restrict a = (const uint16_t *)x[i].scales;
2689
+ aux16[0] = a[0] & 0x0f0f;
2690
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
2691
+
2692
+ const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
2693
+ sum_mins += y[i].d * (float)x[i].d[1] * summi;
2694
+
2695
+ const float d = y[i].d * (float)x[i].d[0];
2696
+
2697
+ const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
2698
+
2699
+ #ifdef __ARM_FEATURE_DOTPROD
2700
+ q8bytes = vld1q_s8_x4(q8);
2701
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
2702
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
2703
+
2704
+ const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
2705
+ const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
2706
+
2707
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
2708
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
2709
+
2710
+ const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
2711
+ const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
2712
+
2713
+ #else
2714
+ q8bytes = vld1q_s8_x4(q8);
2715
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
2716
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
2717
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
2718
+ vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
2719
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
2720
+ vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
2721
+ int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
2722
+
2723
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
2724
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
2725
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
2726
+ vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
2727
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
2728
+ vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
2729
+ int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
2730
+
2731
+ #endif
2732
+ sumf += d * (sumi1 + sumi2);
2733
+
2734
+ }
2735
+
2736
+ *s = sumf - sum_mins;
2737
+
2738
+ #elif defined __AVX2__
2739
+
2740
+ const __m256i m4 = _mm256_set1_epi8(0xF);
2741
+
2742
+ __m256 acc = _mm256_setzero_ps();
2743
+
2744
+ float summs = 0;
2745
+
2746
+ uint16_t aux16[2];
2747
+ const uint8_t * scales = (const uint8_t *)aux16;
2748
+
2749
+ for (int i = 0; i < nb; ++i) {
2750
+
2751
+ const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
2752
+ const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
2753
+ const __m256 vd = _mm256_set1_ps(d);
2754
+
2755
+ const uint16_t * a = (const uint16_t *)x[i].scales;
2756
+ aux16[0] = a[0] & 0x0f0f;
2757
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
2758
+
2759
+ summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
2760
+
2761
+ const uint8_t * restrict q4 = x[i].qs;
2762
+ const int8_t * restrict q8 = y[i].qs;
2763
+
2764
+ const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
2765
+ const __m256i q4l = _mm256_and_si256(q4bits, m4);
2766
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
1729
2767
 
2768
+ const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
2769
+ const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
2770
+
2771
+ const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
2772
+ const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
2773
+
2774
+ const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
2775
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
2776
+
2777
+ const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
2778
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
2779
+
2780
+ }
2781
+
2782
+ *s = hsum_float_8(acc) - summs;
2783
+
2784
+ #else
2785
+
2786
+ uint8_t aux8[QK_K];
2787
+ int16_t aux16[16];
2788
+ float sums [8];
2789
+ memset(sums, 0, 8*sizeof(float));
2790
+
2791
+ uint16_t s16[2];
2792
+ const uint8_t * restrict scales = (const uint8_t *)s16;
2793
+
2794
+ float sumf = 0;
2795
+ for (int i = 0; i < nb; ++i) {
2796
+ const uint8_t * restrict q4 = x[i].qs;
2797
+ const int8_t * restrict q8 = y[i].qs;
2798
+ uint8_t * restrict a = aux8;
2799
+ for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
2800
+ for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;
2801
+
2802
+ const uint16_t * restrict b = (const uint16_t *)x[i].scales;
2803
+ s16[0] = b[0] & 0x0f0f;
2804
+ s16[1] = (b[0] >> 4) & 0x0f0f;
2805
+
2806
+ sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
2807
+
2808
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
2809
+
2810
+ for (int j = 0; j < QK_K/32; ++j) {
2811
+ for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
2812
+ q8 += 16; a += 16;
2813
+ for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
2814
+ q8 += 16; a += 16;
2815
+ const float dl = d * scales[j];
2816
+ for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
2817
+ }
2818
+ }
2819
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2820
+ *s = sumf;
2821
+ #endif
2822
+ }
2823
+ #endif
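
For reference, the QK_K == 64 q4_K paths above (both the AVX2 loop and the scalar fallback) recover two (scale, min) pairs from the two packed bytes in x[i].scales by viewing them as one 16-bit word and masking with 0x0f0f. A minimal standalone sketch of that unpacking follows; the function name and the test values are illustrative and not part of the library.

#include <stdint.h>
#include <stdio.h>

/* Each packed byte holds (scale | min << 4) for one 32-element half of the
 * 64-element block. Building the 16-bit word explicitly gives the same view
 * the code above gets from its uint16_t load on little-endian targets. */
static void unpack_q4_K64_scale_mins(const uint8_t packed[2], uint8_t sc[2], uint8_t mn[2]) {
    const uint16_t a  = (uint16_t)(packed[0] | (packed[1] << 8));
    const uint16_t lo = a & 0x0f0f;         /* both low nibbles: the two scales */
    const uint16_t hi = (a >> 4) & 0x0f0f;  /* both high nibbles: the two mins  */
    sc[0] = (uint8_t)(lo & 0xff);  sc[1] = (uint8_t)(lo >> 8);
    mn[0] = (uint8_t)(hi & 0xff);  mn[1] = (uint8_t)(hi >> 8);
}

int main(void) {
    const uint8_t packed[2] = { 0x3A, 0xC5 };   /* scale 10 / min 3, scale 5 / min 12 */
    uint8_t sc[2], mn[2];
    unpack_q4_K64_scale_mins(packed, sc, mn);
    printf("scales = %u %u, mins = %u %u\n", sc[0], sc[1], mn[0], mn[1]);
    return 0;
}

The two scales weight the q4*q8 sub-block products, while the two mins are applied through y[i].bsums, which is why every path above accumulates the dot product and then subtracts a separate mins term (sum_mins, summs, or the sumf -= ... line in the scalar fallback).
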
2824
+
2825
+ #if QK_K == 256
1730
2826
  void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1731
2827
  assert(n % QK_K == 0);
1732
2828
 
@@ -1840,18 +2936,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
1840
2936
 
1841
2937
  for (int i = 0; i < nb; ++i) {
1842
2938
 
1843
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1844
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
1845
-
1846
2939
  const uint8_t * restrict q5 = x[i].qs;
1847
2940
  const int8_t * restrict q8 = y[i].qs;
1848
2941
 
2942
+ #if QK_K == 256
2943
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
2944
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
2945
+
1849
2946
  memcpy(utmp, x[i].scales, 12);
1850
2947
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1851
2948
  const uint32_t uaux = utmp[1] & kmask1;
1852
2949
  utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1853
2950
  utmp[2] = uaux;
1854
2951
  utmp[0] &= kmask1;
2952
+ #else
2953
+ // TODO
2954
+ const float d = 0, dmin = 0;
2955
+ #endif
1855
2956
 
1856
2957
  const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
1857
2958
 
@@ -1876,33 +2977,133 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
1876
2977
  const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
1877
2978
  const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
1878
2979
 
1879
- const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
2980
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
2981
+
2982
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
2983
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
2984
+ const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
2985
+ hmask = _mm256_slli_epi16(hmask, 1);
2986
+
2987
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
2988
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
2989
+ const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
2990
+ hmask = _mm256_slli_epi16(hmask, 1);
2991
+
2992
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2993
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
2994
+
2995
+ __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
2996
+ __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
2997
+
2998
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
2999
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
3000
+
3001
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
3002
+
3003
+ }
3004
+
3005
+ __m256 vd = _mm256_set1_ps(d);
3006
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
3007
+
3008
+ }
3009
+
3010
+ *s = hsum_float_8(acc) + summs;
3011
+
3012
+ #elif defined __AVX__
3013
+
3014
+ const __m128i m4 = _mm_set1_epi8(0xF);
3015
+ const __m128i mzero = _mm_setzero_si128();
3016
+ const __m128i mone = _mm_set1_epi8(1);
3017
+ const __m128i m2 = _mm_set1_epi8(2);
3018
+
3019
+ __m256 acc = _mm256_setzero_ps();
3020
+
3021
+ float summs = 0.f;
3022
+
3023
+ for (int i = 0; i < nb; ++i) {
3024
+
3025
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3026
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
3027
+
3028
+ const uint8_t * restrict q5 = x[i].qs;
3029
+ const int8_t * restrict q8 = y[i].qs;
3030
+
3031
+ memcpy(utmp, x[i].scales, 12);
3032
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
3033
+ const uint32_t uaux = utmp[1] & kmask1;
3034
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
3035
+ utmp[2] = uaux;
3036
+ utmp[0] &= kmask1;
3037
+
3038
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
3039
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
3040
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
1880
3041
 
1881
- const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
1882
- const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1883
- const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
1884
- hmask = _mm256_slli_epi16(hmask, 1);
3042
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
3043
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
3044
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
3045
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
3046
+ const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
3047
+ summs += dmin * _mm_extract_epi32(hsum, 0);
1885
3048
 
1886
- const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
1887
- const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
1888
- const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
1889
- hmask = _mm256_slli_epi16(hmask, 1);
3049
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
3050
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
3051
+ __m128i hmask = mone;
1890
3052
 
1891
- const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
1892
- const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3053
+ __m128i sumi_0 = _mm_setzero_si128();
3054
+ __m128i sumi_1 = _mm_setzero_si128();
1893
3055
 
1894
- __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
1895
- __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
3056
+ int bit = 0;
1896
3057
 
1897
- p16_0 = _mm256_madd_epi16(scale_0, p16_0);
1898
- p16_1 = _mm256_madd_epi16(scale_1, p16_1);
3058
+ __m128i shuffle = _mm_set1_epi16(0x0100);
3059
+ for (int j = 0; j < QK_K/64; ++j) {
1899
3060
 
1900
- sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
3061
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
3062
+ shuffle = _mm_add_epi16(shuffle, m2);
3063
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
3064
+ shuffle = _mm_add_epi16(shuffle, m2);
3065
+
3066
+ const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
3067
+ const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
3068
+
3069
+ __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
3070
+ __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
3071
+ __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
3072
+ __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
3073
+ __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
3074
+ __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
3075
+ hmask = _mm_slli_epi16(hmask, 1);
3076
+
3077
+ __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3078
+ __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3079
+ __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
3080
+ __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
3081
+ p16_0 = _mm_madd_epi16(scale_0, p16_0);
3082
+ p16_1 = _mm_madd_epi16(scale_0, p16_1);
3083
+
3084
+ q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
3085
+ q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
3086
+ q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
3087
+ q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
3088
+ q5_0 = _mm_add_epi8(q5l_0, q5h_0);
3089
+ q5_1 = _mm_add_epi8(q5l_1, q5h_1);
3090
+ hmask = _mm_slli_epi16(hmask, 1);
3091
+
3092
+ q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3093
+ q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3094
+ __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
3095
+ __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
3096
+ p16_2 = _mm_madd_epi16(scale_1, p16_2);
3097
+ p16_3 = _mm_madd_epi16(scale_1, p16_3);
3098
+
3099
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
3100
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
1901
3101
 
1902
3102
  }
1903
3103
 
1904
3104
  __m256 vd = _mm256_set1_ps(d);
1905
- acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
3105
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
3106
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
1906
3107
 
1907
3108
  }
1908
3109
 
@@ -1972,8 +3173,169 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
1972
3173
  #endif
1973
3174
  }
1974
3175
 
3176
+ #else
3177
+
3178
+ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3179
+ assert(n % QK_K == 0);
3180
+
3181
+ const block_q5_K * restrict x = vx;
3182
+ const block_q8_K * restrict y = vy;
3183
+
3184
+ const int nb = n / QK_K;
3185
+
3186
+ #ifdef __ARM_NEON
3187
+
3188
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
3189
+ const int32x4_t mzero = vdupq_n_s32(0);
3190
+ const uint8x16_t mh = vdupq_n_u8(16);
3191
+
3192
+ int8x16x4_t q5bytes;
3193
+ uint8x16x4_t q5h;
3194
+
3195
+ float sumf = 0;
3196
+
3197
+ for (int i = 0; i < nb; ++i) {
3198
+
3199
+ const float d = y[i].d * (float)x[i].d;
3200
+ const int8_t * sc = x[i].scales;
3201
+
3202
+ const uint8_t * restrict q5 = x[i].qs;
3203
+ const uint8_t * restrict qh = x[i].qh;
3204
+ const int8_t * restrict q8 = y[i].qs;
3205
+
3206
+ const uint8x8_t qhbits = vld1_u8(qh);
3207
+
3208
+ const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
3209
+ const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
3210
+
3211
+ const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
3212
+ q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
3213
+ q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
3214
+ q5h.val[2] = vbicq_u8(mh, htmp);
3215
+ q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
3216
+
3217
+ q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
3218
+ q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
3219
+ q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
3220
+ q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
3221
+
3222
+ #if defined(__ARM_FEATURE_DOTPROD)
3223
+
3224
+ int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
3225
+ int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
3226
+ int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
3227
+ int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
3228
+
3229
+ sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
3230
+
3231
+ #else
3232
+
3233
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
3234
+ vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
3235
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
3236
+ vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
3237
+ int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
3238
+
3239
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
3240
+ vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
3241
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
3242
+ vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
3243
+ sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
3244
+
3245
+ sumf += d*sumi;
3246
+ #endif
3247
+
3248
+ }
3249
+
3250
+ *s = sumf;
3251
+
3252
+ #elif defined __AVX2__
3253
+
3254
+ const __m256i m4 = _mm256_set1_epi8(0xF);
3255
+ const __m256i mone = _mm256_set1_epi8(1);
3256
+
3257
+ __m256 acc = _mm256_setzero_ps();
3258
+
3259
+ for (int i = 0; i < nb; ++i) {
3260
+
3261
+ const uint8_t * restrict q5 = x[i].qs;
3262
+ const int8_t * restrict q8 = y[i].qs;
3263
+
3264
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3265
+
3266
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
3267
+
3268
+ const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
3269
+ const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
3270
+
3271
+ int64_t aux64;
3272
+ memcpy(&aux64, x[i].qh, 8);
3273
+ const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
3274
+ const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
3275
+
3276
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
3277
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
3278
+
3279
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
3280
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
3281
+
3282
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
3283
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
3284
+
3285
+ const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
3286
+ const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
3287
+ const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
3288
+ const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
3289
+
3290
+ const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
3291
+
3292
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
3293
+
3294
+ }
3295
+
3296
+ *s = hsum_float_8(acc);
3297
+
3298
+ #else
3299
+
3300
+
3301
+ int8_t aux8[QK_K]; // must be signed: the high-bit adjustment below pushes values negative
3302
+ int16_t aux16[16];
3303
+ float sums [8];
3304
+ memset(sums, 0, 8*sizeof(float));
3305
+
3306
+ float sumf = 0;
3307
+ for (int i = 0; i < nb; ++i) {
3308
+ const uint8_t * restrict q4 = x[i].qs;
3309
+ const uint8_t * restrict hm = x[i].qh;
3310
+ const int8_t * restrict q8 = y[i].qs;
3311
+ int8_t * restrict a = aux8;
3312
+ for (int l = 0; l < 32; ++l) {
3313
+ a[l+ 0] = q4[l] & 0xF;
3314
+ a[l+32] = q4[l] >> 4;
3315
+ }
3316
+ for (int is = 0; is < 8; ++is) {
3317
+ uint8_t m = 1 << is;
3318
+ for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
3319
+ }
3320
+
3321
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3322
+ const int8_t * restrict sc = x[i].scales;
3323
+
3324
+ for (int j = 0; j < QK_K/16; ++j) {
3325
+ const float dl = d * sc[j];
3326
+ for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
3327
+ for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
3328
+ q8 += 16; a += 16;
3329
+ }
3330
+ }
3331
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
3332
+ *s = sumf;
3333
+ #endif
3334
+ }
3335
+ #endif
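
All three QK_K == 64 q5_K paths above decode a value as its low 4 bits minus 16 whenever the matching high bit in qh is clear, then scale by d times a signed per-16-value scale (this variant has no separate min term). Below is a scalar sketch of that decode under the block layout the code implies; the struct and function names are illustrative, and d is kept as a plain float rather than ggml_fp16_t so the sketch stays self-contained.

#include <stdint.h>
#include <stdio.h>

/* Layout implied by the QK_K == 64 code above (illustrative names):
 *   d          super-block scale (ggml_fp16_t in the real block)
 *   scales[4]  signed scale per 16 values
 *   qh[8]      high bits: bit is of byte l covers value 8*is + l
 *   qs[32]     low 4 bits, two values per byte */
typedef struct {
    float   d;
    int8_t  scales[4];
    uint8_t qh[8];
    uint8_t qs[32];
} q5_K_64_sketch;

static void dequantize_q5_K_64_sketch(const q5_K_64_sketch * b, float y[64]) {
    int8_t q[64];
    for (int l = 0; l < 32; ++l) {
        q[l +  0] = (int8_t)(b->qs[l] & 0xF);   /* low nibbles  -> values  0..31 */
        q[l + 32] = (int8_t)(b->qs[l] >> 4);    /* high nibbles -> values 32..63 */
    }
    for (int is = 0; is < 8; ++is) {            /* the 5th bit acts as a +16 offset */
        const uint8_t m = (uint8_t)(1 << is);
        for (int l = 0; l < 8; ++l) {
            if (!(b->qh[l] & m)) q[8*is + l] -= 16;
        }
    }
    for (int j = 0; j < 4; ++j) {               /* one signed scale per 16 values */
        const float dl = b->d * b->scales[j];
        for (int l = 0; l < 16; ++l) y[16*j + l] = dl * q[16*j + l];
    }
}

int main(void) {
    q5_K_64_sketch b = {0};
    b.d = 0.5f;
    for (int j = 0; j < 4; ++j) b.scales[j] = (int8_t)(j + 1);
    for (int l = 0; l < 32; ++l) b.qs[l] = 0x21;   /* value 1 in the low half, 2 in the high half */
    for (int l = 0; l <  8; ++l) b.qh[l] = 0xFF;   /* all high bits set: no -16 offset */
    float y[64];
    dequantize_q5_K_64_sketch(&b, y);
    printf("y[0]=%g y[16]=%g y[32]=%g y[48]=%g\n", y[0], y[16], y[32], y[48]);
    return 0;
}

This is the same arithmetic the NEON path expresses with vbicq_u8 (build 16 where the high bit is absent, then subtract) and the AVX2 path expresses by subtracting the scaled 16*q8 terms from the scaled q4*q8 terms.
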
1975
3336
 
1976
3337
 
3338
+ #if QK_K == 256
1977
3339
  void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1978
3340
  assert(n % QK_K == 0);
1979
3341
 
@@ -2198,6 +3560,124 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
2198
3560
 
2199
3561
  *s = hsum_float_8(acc);
2200
3562
 
3563
+ #elif defined __AVX__
3564
+
3565
+ const __m128i m4 = _mm_set1_epi8(0xF);
3566
+ const __m128i m3 = _mm_set1_epi8(3);
3567
+ const __m128i m32s = _mm_set1_epi8(32);
3568
+ const __m128i m2 = _mm_set1_epi8(2);
3569
+
3570
+ __m256 acc = _mm256_setzero_ps();
3571
+
3572
+ for (int i = 0; i < nb; ++i) {
3573
+
3574
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3575
+
3576
+ const uint8_t * restrict q4 = x[i].ql;
3577
+ const uint8_t * restrict qh = x[i].qh;
3578
+ const int8_t * restrict q8 = y[i].qs;
3579
+
3580
+ const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
3581
+
3582
+ __m128i sumi_0 = _mm_setzero_si128();
3583
+ __m128i sumi_1 = _mm_setzero_si128();
3584
+
3585
+ __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
3586
+ for (int j = 0; j < QK_K/128; ++j) {
3587
+
3588
+ const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
3589
+ const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
3590
+
3591
+ const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
3592
+ const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
3593
+ const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
3594
+ const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
3595
+ const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
3596
+ const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
3597
+ const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
3598
+ const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
3599
+
3600
+ const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
3601
+ const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
3602
+ const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
3603
+ const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
3604
+
3605
+ const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
3606
+ const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
3607
+ const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
3608
+ const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
3609
+ const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
3610
+ const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
3611
+ const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
3612
+ const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
3613
+
3614
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3615
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3616
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3617
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3618
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3619
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3620
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3621
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
3622
+
3623
+ __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
3624
+ __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
3625
+ __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
3626
+ __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
3627
+ __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
3628
+ __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
3629
+ __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
3630
+ __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
3631
+
3632
+ __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
3633
+ __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
3634
+ __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
3635
+ __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
3636
+ __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
3637
+ __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
3638
+ __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
3639
+ __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
3640
+
3641
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
3642
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
3643
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
3644
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
3645
+ p16_4 = _mm_sub_epi16(p16_4, q8s_4);
3646
+ p16_5 = _mm_sub_epi16(p16_5, q8s_5);
3647
+ p16_6 = _mm_sub_epi16(p16_6, q8s_6);
3648
+ p16_7 = _mm_sub_epi16(p16_7, q8s_7);
3649
+
3650
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
3651
+ shuffle = _mm_add_epi8(shuffle, m2);
3652
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
3653
+ shuffle = _mm_add_epi8(shuffle, m2);
3654
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
3655
+ shuffle = _mm_add_epi8(shuffle, m2);
3656
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
3657
+ shuffle = _mm_add_epi8(shuffle, m2);
3658
+
3659
+ p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
3660
+ p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
3661
+ p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
3662
+ p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
3663
+ p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
3664
+ p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
3665
+ p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
3666
+ p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
3667
+
3668
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
3669
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
3670
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
3671
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
3672
+
3673
+ }
3674
+
3675
+ __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
3676
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
3677
+ }
3678
+
3679
+ *s = hsum_float_8(acc);
3680
+
2201
3681
  #else
2202
3682
 
2203
3683
  int8_t aux8[QK_K];
@@ -2242,3 +3722,179 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
2242
3722
  *s = sumf;
2243
3723
  #endif
2244
3724
  }
3725
+
3726
+ #else
3727
+
3728
+ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3729
+ assert(n % QK_K == 0);
3730
+
3731
+ const block_q6_K * restrict x = vx;
3732
+ const block_q8_K * restrict y = vy;
3733
+
3734
+ const int nb = n / QK_K;
3735
+
3736
+ #ifdef __ARM_NEON
3737
+
3738
+ float sum = 0;
3739
+
3740
+ const uint8x16_t m4b = vdupq_n_u8(0xF);
3741
+ const int32x4_t vzero = vdupq_n_s32(0);
3742
+ const int8x16_t m32s = vdupq_n_s8(32);
3743
+
3744
+ const uint8x16_t mone = vdupq_n_u8(3);
3745
+
3746
+ int8x16x4_t q6bytes;
3747
+ uint8x16x4_t q6h;
3748
+
3749
+ for (int i = 0; i < nb; ++i) {
3750
+
3751
+ const float d_all = (float)x[i].d;
3752
+
3753
+ const uint8_t * restrict q6 = x[i].ql;
3754
+ const uint8_t * restrict qh = x[i].qh;
3755
+ const int8_t * restrict q8 = y[i].qs;
3756
+
3757
+ const int8_t * restrict scale = x[i].scales;
3758
+
3759
+ int32_t isum = 0;
3760
+
3761
+ uint8x16_t qhbits = vld1q_u8(qh);
3762
+ uint8x16x2_t q6bits = vld1q_u8_x2(q6);
3763
+ int8x16x4_t q8bytes = vld1q_s8_x4(q8);
3764
+
3765
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
3766
+ uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
3767
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3768
+ shifted = vshrq_n_u8(qhbits, 4);
3769
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3770
+ shifted = vshrq_n_u8(qhbits, 6);
3771
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
3772
+
3773
+ q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
3774
+ q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
3775
+ q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
3776
+ q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
3777
+
3778
+ #if defined(__ARM_FEATURE_DOTPROD)
3779
+
3780
+ isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
3781
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
3782
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
3783
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
3784
+ #else
3785
+
3786
+ int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
3787
+ vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
3788
+ int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
3789
+ vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
3790
+ isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
3791
+
3792
+ int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
3793
+ vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
3794
+ int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
3795
+ vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
3796
+ isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
3797
+ #endif
3798
+
3799
+ sum += isum * d_all * y[i].d;
3800
+
3801
+ }
3802
+ *s = sum;
3803
+
3804
+ #elif defined __AVX2__
3805
+
3806
+ const __m256i m4 = _mm256_set1_epi8(0xF);
3807
+ const __m256i m2 = _mm256_set1_epi8(3);
3808
+ const __m256i m32s = _mm256_set1_epi8(32);
3809
+
3810
+ __m256 acc = _mm256_setzero_ps();
3811
+
3812
+ for (int i = 0; i < nb; ++i) {
3813
+
3814
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3815
+
3816
+ const uint8_t * restrict q4 = x[i].ql;
3817
+ const uint8_t * restrict qh = x[i].qh;
3818
+ const int8_t * restrict q8 = y[i].qs;
3819
+
3820
+ const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
3821
+ const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
3822
+ const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
3823
+ const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
3824
+
3825
+ __m256i sumi = _mm256_setzero_si256();
3826
+
3827
+ const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
3828
+ const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
3829
+
3830
+ const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
3831
+ const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
3832
+
3833
+ const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
3834
+ const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
3835
+
3836
+ const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
3837
+ const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
3838
+
3839
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
3840
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
3841
+
3842
+ __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
3843
+ __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
3844
+
3845
+ __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
3846
+ __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
3847
+
3848
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
3849
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
3850
+
3851
+ p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
3852
+ p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
3853
+
3854
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
3855
+
3856
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
3857
+ }
3858
+
3859
+ *s = hsum_float_8(acc);
3860
+
3861
+ #else
3862
+
3863
+ int8_t aux8[QK_K];
3864
+ int16_t aux16[8];
3865
+ float sums [8];
3866
+ int32_t aux32[8];
3867
+ memset(sums, 0, 8*sizeof(float));
3868
+
3869
+ float sumf = 0;
3870
+ for (int i = 0; i < nb; ++i) {
3871
+ const uint8_t * restrict q4 = x[i].ql;
3872
+ const uint8_t * restrict qh = x[i].qh;
3873
+ const int8_t * restrict q8 = y[i].qs;
3874
+ memset(aux32, 0, 8*sizeof(int32_t));
3875
+ int8_t * restrict a = aux8;
3876
+ for (int l = 0; l < 16; ++l) {
3877
+ a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
3878
+ a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
3879
+ a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
3880
+ a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
3881
+ }
3882
+ int is = 0;
3883
+ for (int j = 0; j < QK_K/16; ++j) {
3884
+ int scale = x[i].scales[is++];
3885
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
3886
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
3887
+ q8 += 8; a += 8;
3888
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
3889
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
3890
+ q8 += 8; a += 8;
3891
+ }
3892
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
3893
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
3894
+ }
3895
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
3896
+ *s = sumf;
3897
+ #endif
3898
+ }
3899
+
3900
+ #endif
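
The QK_K == 64 q6_K kernel above rebuilds each 6-bit value from 4 low bits in ql plus 2 high bits in qh, subtracts the fixed offset of 32, and applies one signed scale per 16 values. The AVX2 path keeps the first operand of _mm256_maddubs_epi16 unsigned by using sum((q - 32) * q8) = sum(q * q8) - 32 * sum(q8), which is what the extra _mm256_maddubs_epi16(m32s, q8_*) terms and the _mm256_sub_epi16 calls compute. A scalar sketch of the decode under the layout the code implies; names are illustrative and d is a plain float here instead of ggml_fp16_t so the sketch stays self-contained.

#include <stdint.h>
#include <stdio.h>

/* Layout implied by the QK_K == 64 q6_K code above (illustrative names):
 *   ql[32]     low 4 bits of the 6-bit values
 *   qh[16]     upper 2 bits, four values per byte
 *   scales[4]  signed scale per 16 values
 *   d          super-block scale (ggml_fp16_t in the real block) */
typedef struct {
    uint8_t ql[32];
    uint8_t qh[16];
    int8_t  scales[4];
    float   d;
} q6_K_64_sketch;

static void dequantize_q6_K_64_sketch(const q6_K_64_sketch * b, float y[64]) {
    int8_t q[64];
    for (int l = 0; l < 16; ++l) {
        q[l +  0] = (int8_t)((b->ql[l +  0] & 0xF) | (((b->qh[l] >> 0) & 3) << 4)) - 32;
        q[l + 16] = (int8_t)((b->ql[l + 16] & 0xF) | (((b->qh[l] >> 2) & 3) << 4)) - 32;
        q[l + 32] = (int8_t)((b->ql[l +  0] >>  4) | (((b->qh[l] >> 4) & 3) << 4)) - 32;
        q[l + 48] = (int8_t)((b->ql[l + 16] >>  4) | (((b->qh[l] >> 6) & 3) << 4)) - 32;
    }
    for (int j = 0; j < 4; ++j) {
        const float dl = b->d * b->scales[j];
        for (int l = 0; l < 16; ++l) y[16*j + l] = dl * q[16*j + l];
    }
}

int main(void) {
    q6_K_64_sketch b = { .d = 0.25f };
    for (int j = 0; j < 4; ++j) b.scales[j] = 2;
    for (int l = 0; l < 32; ++l) b.ql[l] = 0xFF;   /* low nibbles = 15 */
    for (int l = 0; l < 16; ++l) b.qh[l] = 0xFF;   /* upper two bits = 3 everywhere */
    float y[64];
    dequantize_q6_K_64_sketch(&b, y);
    printf("y[0]=%g (expect 0.25 * 2 * (63 - 32) = 15.5)\n", y[0]);
    return 0;
}
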