llama_cpp 0.2.1 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
256
256
  (device float *) ((device char *) dst + i*nb1), ne00);
257
257
  }
258
258
 
259
// Layer normalization of one row per threadgroup:
//   dst_row[i] = (x[i] - mean(x)) / sqrt(var(x) + eps)
//
// Layout implied by the indexing below:
//   - threadgroup index `tgpig` selects the row; the input row starts at
//     src0 + tgpig*nb01 (byte stride) and the output row at dst + tgpig*ne00,
//     so dst rows are assumed to be packed floats.
//   - the `ntg` threads of the group cover the row with a stride-ntg loop.
//   - `sum` is threadgroup scratch indexed by tpitg, so it must hold >= ntg floats.
// NOTE(review): the tree reduction starts at ntg/2 and halves, so every partial
// sum is only included when ntg is a power of two — confirm the host launch.
kernel void kernel_norm(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
        threadgroup float  * sum [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);

    // MEAN
    // parallel sum: each thread accumulates a strided slice of the row
    sum[tpitg] = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        sum[tpitg] += x[i00];
    }
    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg/2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    // broadcast
    if (tpitg == 0) {
        sum[0] /= ne00;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float mean = sum[0];

    // Every thread must finish reading sum[0] before any thread reuses `sum`
    // for the variance pass; without this barrier thread 0 can zero sum[0]
    // while slower threads are still loading the mean.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // RECENTER + VARIANCE
    // Each thread writes its recentered elements and accumulates their squares
    // in the same pass (it only ever touches its own strided elements).
    device float * y = dst + tgpig*ne00;
    sum[tpitg] = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        const float v = x[i00] - mean;
        y[i00] = v;
        sum[tpitg] += v * v;
    }
    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg/2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    // broadcast
    if (tpitg == 0) {
        sum[0] /= ne00;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float variance = sum[0];

    // SCALE
    const float scale = 1.0f/sqrt(variance + eps);
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        y[i00] = y[i00] * scale;
    }
}
323
+
324
+
259
325
  kernel void kernel_rms_norm(
260
326
  device const void * src0,
261
327
  device float * dst,
@@ -362,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
362
428
  }
363
429
  threadgroup_barrier(mem_flags::mem_threadgroup);
364
430
  if (ith == 0) {
365
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
431
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
366
432
  dst[r1*ne0 + r0] = sum[0];
367
433
  }
368
434
  }
@@ -431,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
431
497
  }
432
498
  threadgroup_barrier(mem_flags::mem_threadgroup);
433
499
  if (ith == 0) {
434
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
500
+ for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
435
501
  dst[r1*ne0 + r0] = sum[0];
436
502
  }
437
503
  }
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
485
551
  }
486
552
  }
487
553
 
554
// Adds an ALiBi positional bias to each element of a row:
//   dst[i00] = src0[i00] + m_k * (i00 - ne00 + 1),  with slope m_k = m0^(i2 + 1)
// where i2 is the dim-2 coordinate of the row in dst.
//
// Threadgroup coordinates (x, y, z) select the source row (i01, i02, i03); the
// threads of the group walk that row's ne00 elements with a stride of ntg.x.
// The row's flat element offset in src0 is re-expressed as (i0, i1, i2, i3)
// coordinates of dst, so dst's ne0..ne3 need not match src0's ne00..ne03.
// NOTE(review): dst_data is indexed with [i00], i.e. dst is assumed contiguous
// along dim 0 — confirm against the host-side launch.
kernel void kernel_alibi_f32(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        constant     float & m0,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row owned by this threadgroup
    const int64_t i01 = tgpig.x;
    const int64_t i02 = tgpig.y;
    const int64_t i03 = tgpig.z;

    // element counts of one dst plane (dims 0-1) and one dst volume (dims 0-2)
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    // flat element offset of the row start, peeled into dst coordinates
    int64_t rem = ((i03*ne02 + i02)*ne01 + i01)*ne00;
    const int64_t i3 = rem / volume;  rem -= i3*volume;
    const int64_t i2 = rem / plane;   rem -= i2*plane;
    const int64_t i1 = rem / ne0;     rem -= i1*ne0;
    const int64_t i0 = rem;

    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    // per-head slope
    const float m_k = pow(m0, i2 + 1);

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        const float v = *((device const float *) (src_row + i00*nb00));
        dst_data[i00] = v + m_k * (i00 - ne00 + 1);
    }
}
595
+
488
596
  kernel void kernel_rope(
489
597
  device const void * src0,
490
598
  device float * dst,
@@ -540,6 +648,47 @@ kernel void kernel_rope(
540
648
  }
541
649
  }
542
650
 
651
// Copies half-precision elements from src0 to dst, one source row per threadgroup.
//
// Threadgroup coordinates (x, y, z) select the source row (i01, i02, i03); the
// threads of the group walk that row's ne00 elements with a stride of ntg.x,
// reading through the byte strides nb00..nb03. The row's flat element offset
// is re-expressed as (i0, i1, i2, i3) coordinates of dst, so dst's ne0..ne3 need
// not match src0's ne00..ne03.
// NOTE(review): dst_data is indexed with [i00], i.e. dst is assumed contiguous
// along dim 0 — confirm against the host-side launch.
kernel void kernel_cpy_f16_f16(
        device const half * src0,
        device       half * dst,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant  int64_t & ne03,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant  int64_t & ne2,
        constant  int64_t & ne3,
        constant uint64_t & nb0,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        constant uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row owned by this threadgroup
    const int64_t i01 = tgpig.x;
    const int64_t i02 = tgpig.y;
    const int64_t i03 = tgpig.z;

    // element counts of one dst plane (dims 0-1) and one dst volume (dims 0-2)
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    // flat element offset of the row start, peeled into dst coordinates
    int64_t rem = ((i03*ne02 + i02)*ne01 + i01)*ne00;
    const int64_t i3 = rem / volume;  rem -= i3*volume;
    const int64_t i2 = rem / plane;   rem -= i2*plane;
    const int64_t i1 = rem / ne0;     rem -= i1*ne0;
    const int64_t i0 = rem;

    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        dst_data[i00] = *((device const half *) (src_row + i00*nb00));
    }
}
691
+
543
692
  kernel void kernel_cpy_f32_f16(
544
693
  device const float * src0,
545
694
  device half * dst,
@@ -626,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
626
775
 
627
776
  //============================================ k-quants ======================================================
628
777
 
778
+ #ifndef QK_K
629
779
  #define QK_K 256
780
+ #else
781
+ static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
782
+ #endif
783
+
784
+ #if QK_K == 256
785
+ #define K_SCALE_SIZE 12
786
+ #else
787
+ #define K_SCALE_SIZE 4
788
+ #endif
630
789
 
631
790
  typedef struct {
632
791
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
633
792
  uint8_t qs[QK_K/4]; // quants
634
793
  half d; // super-block scale for quantized scales
635
794
  half dmin; // super-block scale for quantized mins
636
- } block_q2_k;
795
+ } block_q2_K;
637
796
  // 84 bytes / block
638
797
 
639
798
  typedef struct {
640
799
  uint8_t hmask[QK_K/8]; // quants - high bit
641
800
  uint8_t qs[QK_K/4]; // quants - low 2 bits
642
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
643
- half d; // super-block scale
644
- } block_q3_k;
645
- // 110 bytes / block
646
-
801
+ #if QK_K == 64
802
+ uint8_t scales[2];
803
+ #else
804
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
805
+ #endif
806
+ half d; // super-block scale
807
+ } block_q3_K;
808
+
809
+ #if QK_K == 64
810
+ typedef struct {
811
+ half d[2]; // super-block scales/mins
812
+ uint8_t scales[2];
813
+ uint8_t qs[QK_K/2]; // 4-bit quants
814
+ } block_q4_K;
815
+ #else
647
816
  typedef struct {
648
817
  half d; // super-block scale for quantized scales
649
818
  half dmin; // super-block scale for quantized mins
650
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
819
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
651
820
  uint8_t qs[QK_K/2]; // 4--bit quants
652
- } block_q4_k;
653
- // 144 bytes / block
821
+ } block_q4_K;
822
+ #endif
654
823
 
824
+ #if QK_K == 64
825
+ typedef struct {
826
+ half d; // super-block scales/mins
827
+ int8_t scales[QK_K/16]; // 8-bit block scales
828
+ uint8_t qh[QK_K/8]; // quants, high bit
829
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
830
+ } block_q5_K;
831
+ #else
655
832
  typedef struct {
656
833
  half d; // super-block scale for quantized scales
657
834
  half dmin; // super-block scale for quantized mins
658
835
  uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
659
836
  uint8_t qh[QK_K/8]; // quants, high bit
660
837
  uint8_t qs[QK_K/2]; // quants, low 4 bits
661
- } block_q5_k;
838
+ } block_q5_K;
662
839
  // 176 bytes / block
840
+ #endif
663
841
 
664
842
  typedef struct {
665
843
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
666
844
  uint8_t qh[QK_K/4]; // quants, upper 2 bits
667
845
  int8_t scales[QK_K/16]; // scales, quantized with 8 bits
668
846
  half d; // super-block scale
669
- } block_q6_k;
847
+ } block_q6_K;
670
848
  // 210 bytes / block
671
849
 
672
850
  static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -687,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
687
865
 
688
866
  //========================================== dequantization =============================
689
867
 
690
- static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
868
+ static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
691
869
  assert(k % QK_K == 0);
692
870
  const int nb = k / QK_K;
693
871
 
@@ -698,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
698
876
 
699
877
  device const uint8_t * q = x[i].qs;
700
878
 
879
+ #if QK_K == 256
701
880
  int is = 0;
702
881
  float dl, ml;
703
882
  for (int n = 0; n < QK_K; n += 128) {
@@ -716,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
716
895
  }
717
896
  q += 32;
718
897
  }
898
+ #else
899
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
900
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
901
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
902
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
903
+ for (int l = 0; l < 16; ++l) {
904
+ y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
905
+ y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
906
+ y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
907
+ y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
908
+ }
909
+ y += QK_K;
910
+ #endif
719
911
 
720
912
  }
721
913
  }
722
914
 
723
- static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
915
+ static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
724
916
  assert(k % QK_K == 0);
725
917
  const int nb = k / QK_K;
726
918
 
919
+ #if QK_K == 256
920
+
727
921
  const uint16_t kmask1 = 0x0303;
728
922
  const uint16_t kmask2 = 0x0f0f;
729
923
 
@@ -769,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
769
963
  }
770
964
  q += 32;
771
965
  }
966
+ }
967
+ #else
968
+ for (int i = 0; i < nb; i++) {
969
+
970
+ const float d_all = (float)(x[i].d);
772
971
 
972
+ device const uint8_t * q = x[i].qs;
973
+ device const uint8_t * hm = x[i].hmask;
974
+
975
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
976
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
977
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
978
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
979
+
980
+ for (int l = 0; l < 8; ++l) {
981
+ uint8_t h = hm[l];
982
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
983
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
984
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
985
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
986
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
987
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
988
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
989
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
990
+ }
991
+ y += QK_K;
773
992
  }
993
+ #endif
774
994
 
775
995
  }
776
996
 
777
- static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
997
+ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
778
998
  assert(k % QK_K == 0);
779
999
  const int nb = k / QK_K;
780
1000
 
781
-
782
1001
  for (int i = 0; i < nb; i++) {
783
1002
 
1003
+ device const uint8_t * q = x[i].qs;
1004
+
1005
+ #if QK_K == 256
784
1006
  const float d = x[i].d;
785
1007
  const float min = x[i].dmin;
786
1008
 
787
- device const uint8_t * q = x[i].qs;
788
1009
  device const uint8_t * scales = x[i].scales;
789
1010
 
790
1011
  int is = 0;
@@ -796,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
796
1017
  for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
797
1018
  q += 32; is += 2;
798
1019
  }
1020
+ #else
1021
+ device const uint8_t * s = x[i].scales;
1022
+ device const half2 * dh = (device const half2 *)x[i].d;
1023
+ const float2 d = (float2)dh[0];
1024
+ const float d1 = d[0] * (s[0] & 0xF);
1025
+ const float d2 = d[0] * (s[1] & 0xF);
1026
+ const float m1 = d[1] * (s[0] >> 4);
1027
+ const float m2 = d[1] * (s[1] >> 4);
1028
+ for (int l = 0; l < 32; ++l) {
1029
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1030
+ y[l+32] = d2 * (q[l] >> 4) - m2;
1031
+ }
1032
+ y += QK_K;
1033
+ #endif
799
1034
 
800
1035
  }
801
1036
  }
802
1037
 
803
- static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
1038
+ static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
804
1039
  assert(k % QK_K == 0);
805
1040
  const int nb = k / QK_K;
806
1041
 
1042
+ #if QK_K == 256
807
1043
  for (int i = 0; i < nb; i++) {
808
1044
 
809
1045
  const float d = (float)(x[i].d);
@@ -824,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
824
1060
  u1 <<= 2; u2 <<= 2;
825
1061
  }
826
1062
  }
1063
+ #else
1064
+ for (int i = 0; i < nb; i++) {
1065
+
1066
+ const float d = (float)x[i].d;
1067
+
1068
+ device const uint8_t * ql = x[i].qs;
1069
+ device const uint8_t * qh = x[i].qh;
1070
+ device const int8_t * sc = x[i].scales;
1071
+
1072
+ for (int l = 0; l < 8; ++l) {
1073
+ y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1074
+ y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1075
+ y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1076
+ y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1077
+ y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1078
+ y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1079
+ y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1080
+ y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1081
+ }
1082
+ y += QK_K;
1083
+ }
1084
+ #endif
827
1085
 
828
1086
  }
829
1087
 
830
- static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
1088
+ static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
831
1089
  assert(k % QK_K == 0);
832
1090
  const int nb = k / QK_K;
833
1091
 
@@ -839,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
839
1097
 
840
1098
  const float d = x[i].d;
841
1099
 
1100
+ #if QK_K == 256
842
1101
  for (int n = 0; n < QK_K; n += 128) {
843
1102
  for (int l = 0; l < 32; ++l) {
844
1103
  int is = l/16;
@@ -856,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
856
1115
  qh += 32;
857
1116
  sc += 8;
858
1117
  }
1118
+ #else
1119
+ for (int l = 0; l < 16; ++l) {
1120
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1121
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1122
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1123
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1124
+ y[l+ 0] = d * sc[0] * q1;
1125
+ y[l+16] = d * sc[1] * q2;
1126
+ y[l+32] = d * sc[2] * q3;
1127
+ y[l+48] = d * sc[3] * q4;
1128
+ }
1129
+ y += 64;
1130
+ #endif
859
1131
  }
860
1132
  }
861
1133
 
862
- kernel void kernel_get_rows_q2_k(
1134
+ kernel void kernel_get_rows_q2_K(
863
1135
  device const void * src0,
864
1136
  device const int * src1,
865
1137
  device float * dst,
@@ -870,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
870
1142
  const int i = tpig;
871
1143
  const int r = ((device int32_t *) src1)[i];
872
1144
 
873
- dequantize_row_q2_k(
874
- (device const block_q2_k *) ((device char *) src0 + r*nb01),
1145
+ dequantize_row_q2_K(
1146
+ (device const block_q2_K *) ((device char *) src0 + r*nb01),
875
1147
  (device float *) ((device char *) dst + i*nb1), ne00);
876
1148
  }
877
1149
 
878
- kernel void kernel_get_rows_q3_k(
1150
+ kernel void kernel_get_rows_q3_K(
879
1151
  device const void * src0,
880
1152
  device const int * src1,
881
1153
  device float * dst,
@@ -886,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
886
1158
  const int i = tpig;
887
1159
  const int r = ((device int32_t *) src1)[i];
888
1160
 
889
- dequantize_row_q3_k(
890
- (device const block_q3_k *) ((device char *) src0 + r*nb01),
1161
+ dequantize_row_q3_K(
1162
+ (device const block_q3_K *) ((device char *) src0 + r*nb01),
891
1163
  (device float *) ((device char *) dst + i*nb1), ne00);
892
1164
  }
893
1165
 
894
- kernel void kernel_get_rows_q4_k(
1166
+ kernel void kernel_get_rows_q4_K(
895
1167
  device const void * src0,
896
1168
  device const int * src1,
897
1169
  device float * dst,
@@ -902,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
902
1174
  const int i = tpig;
903
1175
  const int r = ((device int32_t *) src1)[i];
904
1176
 
905
- dequantize_row_q4_k(
906
- (device const block_q4_k *) ((device char *) src0 + r*nb01),
1177
+ dequantize_row_q4_K(
1178
+ (device const block_q4_K *) ((device char *) src0 + r*nb01),
907
1179
  (device float *) ((device char *) dst + i*nb1), ne00);
908
1180
  }
909
1181
 
910
- kernel void kernel_get_rows_q5_k(
1182
+ kernel void kernel_get_rows_q5_K(
911
1183
  device const void * src0,
912
1184
  device const int * src1,
913
1185
  device float * dst,
@@ -918,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
918
1190
  const int i = tpig;
919
1191
  const int r = ((device int32_t *) src1)[i];
920
1192
 
921
- dequantize_row_q5_k(
922
- (device const block_q5_k *) ((device char *) src0 + r*nb01),
1193
+ dequantize_row_q5_K(
1194
+ (device const block_q5_K *) ((device char *) src0 + r*nb01),
923
1195
  (device float *) ((device char *) dst + i*nb1), ne00);
924
1196
  }
925
1197
 
926
- kernel void kernel_get_rows_q6_k(
1198
+ kernel void kernel_get_rows_q6_K(
927
1199
  device const void * src0,
928
1200
  device const int * src1,
929
1201
  device float * dst,
@@ -934,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
934
1206
  const int i = tpig;
935
1207
  const int r = ((device int32_t *) src1)[i];
936
1208
 
937
- dequantize_row_q6_k(
938
- (device const block_q6_k *) ((device char *) src0 + r*nb01),
1209
+ dequantize_row_q6_K(
1210
+ (device const block_q6_K *) ((device char *) src0 + r*nb01),
939
1211
  (device float *) ((device char *) dst + i*nb1), ne00);
940
1212
  }
941
1213
 
942
1214
  //====================================== dot products =========================
943
1215
 
944
- kernel void kernel_mul_mat_q2_k_f32(
1216
+ kernel void kernel_mul_mat_q2_K_f32(
945
1217
  device const void * src0,
946
1218
  device const float * src1,
947
1219
  device float * dst,
@@ -958,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
958
1230
  const int64_t r0 = tgpig.x;
959
1231
  const int64_t r1 = tgpig.y;
960
1232
 
961
- device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
1233
+ device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
962
1234
  device const float * yy = (device const float *) src1 + r1*ne10;
963
1235
 
964
1236
  const int nth = tptg.x*tptg.y;
965
1237
  const int ith = tptg.y*tpitg.x + tpitg.y;
966
1238
 
1239
+ float sumf = 0;
1240
+
1241
+ #if QK_K == 256
967
1242
  const int tid = tpitg.y; // 0...16
968
1243
  const int il = tid/4; // 0...3
969
1244
  const int ir = tid%4; // 0...3
@@ -976,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
976
1251
  const int y_offset = 64*il + n*ir;
977
1252
  const int q_offset = 32*ip + n*ir;
978
1253
 
979
- sum[ith] = 0.0f;
980
-
981
- float sumf = 0;
982
1254
  for (int i = tpitg.x; i < nb; i += tptg.x) {
983
1255
 
984
1256
  device const uint8_t * q = x[i].qs + q_offset;
@@ -991,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
991
1263
 
992
1264
  device const float * y = yy + i*QK_K + y_offset;
993
1265
 
994
- //float4 s = {0.f, 0.f, 0.f, 0.f};
995
1266
  float2 s = {0.f, 0.f};
996
1267
  float smin = 0;
997
1268
  for (int l = 0; l < n; ++l) {
@@ -1006,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
1006
1277
  sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1007
1278
 
1008
1279
  }
1009
- sum[ith] = sumf;
1280
+ #else
1281
+ const int il = 4 * tpitg.x;
1010
1282
 
1011
- //int mask1 = (ith%4 == 0);
1012
- //int mask2 = (ith%16 == 0);
1283
+ uint32_t aux[2];
1284
+ thread const uint8_t * d = (thread const uint8_t *)aux;
1285
+ thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1013
1286
 
1014
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1015
- //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
1016
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1017
- //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
1018
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1019
- //if (ith == 0) {
1020
- // for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1021
- // dst[r1*ne0 + r0] = sum[0];
1022
- //}
1287
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1288
+
1289
+ device const uint8_t * q = x[i].qs + il;
1290
+ device const float * y = yy + i*QK_K + il;
1291
+
1292
+ const float dall = (float)x[i].d;
1293
+ const float dmin = (float)x[i].dmin;
1294
+
1295
+ device const uint32_t * a = (device const uint32_t *)x[i].scales;
1296
+ aux[0] = a[0] & 0x0f0f0f0f;
1297
+ aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1298
+
1299
+ for (int l = 0; l < 4; ++l) {
1300
+ sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1301
+ + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1302
+ + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1303
+ + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1304
+ }
1305
+ }
1306
+ #endif
1307
+
1308
+ sum[ith] = sumf;
1023
1309
 
1024
1310
  //
1025
1311
  // Accumulate the sum from all threads in the threadgroup
1026
- // This version is slightly faster than the commented out one below,
1027
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1028
1312
  //
1029
1313
  threadgroup_barrier(mem_flags::mem_threadgroup);
1030
1314
  if (ith%4 == 0) {
@@ -1041,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
1041
1325
  }
1042
1326
  }
1043
1327
 
1044
- kernel void kernel_mul_mat_q3_k_f32(
1328
+ kernel void kernel_mul_mat_q3_K_f32(
1045
1329
  device const void * src0,
1046
1330
  device const float * src1,
1047
1331
  device float * dst,
@@ -1054,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
1054
1338
  uint2 tpitg[[thread_position_in_threadgroup]],
1055
1339
  uint2 tptg[[threads_per_threadgroup]]) {
1056
1340
 
1057
- const uint16_t kmask1 = 0x0303;
1058
- const uint16_t kmask2 = 0x0f0f;
1059
-
1060
- const uint8_t m3 = 3;
1061
- const int8_t m4 = 4;
1062
-
1063
1341
  const int nb = ne00/QK_K;
1064
1342
 
1065
1343
  const int64_t r0 = tgpig.x;
1066
1344
  const int64_t r1 = tgpig.y;
1067
1345
 
1068
- device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
1346
+ device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1069
1347
  device const float * yy = (device const float *) src1 + r1*ne10;
1070
1348
 
1071
1349
  const int nth = tptg.x*tptg.y;
1072
1350
  const int ith = tptg.y*tpitg.x + tpitg.y;
1073
1351
 
1352
+ #if QK_K == 256
1353
+
1354
+ const uint8_t m3 = 3;
1355
+ const int8_t m4 = 4;
1356
+
1357
+ const uint16_t kmask1 = 0x0303;
1358
+ const uint16_t kmask2 = 0x0f0f;
1359
+
1074
1360
  const int tid = tpitg.y; // expecting 16
1075
1361
  const int ip = tid/8; // 0 or 1
1076
1362
  const int il = tid/2 - 4*ip; // 0...3
@@ -1124,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
1124
1410
 
1125
1411
  //sum[ith] = sumf;
1126
1412
  sum[ith] = sumf1 - 32.f*sumf2;
1413
+ #else
1414
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1415
+ const int im = il/8; // 0, 0, 1, 1
1416
+ const int in = il%8; // 0, 4, 0, 4
1417
+
1418
+ float sumf = 0;
1419
+
1420
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1421
+
1422
+ const float d_all = (float)(x[i].d);
1423
+
1424
+ device const uint8_t * q = x[i].qs + il;
1425
+ device const uint8_t * h = x[i].hmask + in;
1426
+ device const float * y = yy + i * QK_K + il;
1427
+
1428
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1429
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1430
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1431
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1432
+
1433
+ for (int l = 0; l < 4; ++l) {
1434
+ const uint8_t hm = h[l] >> im;
1435
+ sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1436
+ + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1437
+ + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1438
+ + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1439
+ }
1440
+
1441
+ }
1442
+
1443
+ sum[ith] = sumf;
1444
+
1445
+ #endif
1127
1446
 
1128
1447
  //
1129
1448
  // Accumulate the sum from all threads in the threadgroup
@@ -1144,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
1144
1463
 
1145
1464
  }
1146
1465
 
1147
- kernel void kernel_mul_mat_q4_k_f32(
1466
+ kernel void kernel_mul_mat_q4_K_f32(
1148
1467
  device const void * src0,
1149
1468
  device const float * src1,
1150
1469
  device float * dst,
@@ -1156,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
1156
1475
  uint2 tpitg[[thread_position_in_threadgroup]],
1157
1476
  uint2 tptg[[threads_per_threadgroup]]) {
1158
1477
 
1159
- const uint16_t kmask1 = 0x3f3f;
1160
- const uint16_t kmask2 = 0x0f0f;
1161
- const uint16_t kmask3 = 0xc0c0;
1162
-
1163
1478
  const int nb = ne00/QK_K;
1164
1479
 
1165
1480
  const int64_t r0 = tgpig.x;
1166
1481
  const int64_t r1 = tgpig.y;
1167
1482
 
1168
- device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
1169
- device const float * yy = (device const float *) src1 + r1*ne10;
1170
-
1171
1483
  const int nth = tptg.x*tptg.y;
1172
1484
  const int ith = tptg.y*tpitg.x + tpitg.y;
1173
1485
 
1486
+ device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1487
+ device const float * yy = (device const float *) src1 + r1*ne10;
1488
+
1489
+ float sumf = 0;
1490
+
1491
+ #if QK_K == 256
1492
+
1493
+ const uint16_t kmask1 = 0x3f3f;
1494
+ const uint16_t kmask2 = 0x0f0f;
1495
+ const uint16_t kmask3 = 0xc0c0;
1496
+
1174
1497
  const int tid = tpitg.y; // 0...16
1175
1498
  const int il = tid/4; // 0...3
1176
1499
  const int ir = tid - 4*il;// 0...3
@@ -1183,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
1183
1506
  const int q_offset = 32*im + l0;
1184
1507
  const int y_offset = 64*im + l0;
1185
1508
 
1186
- sum[ith] = 0.0f;
1187
-
1188
1509
  uchar2 sc1, sc2, sc3, sc4;
1189
1510
 
1190
- float sumf = 0;
1191
1511
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1192
1512
 
1193
1513
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1216,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
1216
1536
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1217
1537
 
1218
1538
  }
1539
+ #else
1540
+ uint16_t aux16[2];
1541
+ thread const uint8_t * scales = (thread const uint8_t *)aux16;
1542
+
1543
+ const int il = 4*tpitg.x;
1544
+
1545
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1546
+
1547
+ device const uint8_t * q = x[i].qs + il;
1548
+ device const float * y = yy + i * QK_K + il;
1549
+
1550
+ const float d = (float)x[i].d[0];
1551
+ const float m = (float)x[i].d[1];
1552
+
1553
+ device const uint16_t * a = (device const uint16_t *)x[i].scales;
1554
+ aux16[0] = a[0] & 0x0f0f;
1555
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
1556
+
1557
+ for (int l = 0; l < 4; ++l) {
1558
+ sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1559
+ + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1560
+ }
1561
+ }
1562
+ #endif
1219
1563
 
1220
1564
  sum[ith] = sumf;
1221
1565
 
@@ -1252,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
1252
1596
  //}
1253
1597
  }
1254
1598
 
1255
- kernel void kernel_mul_mat_q5_k_f32(
1599
+ kernel void kernel_mul_mat_q5_K_f32(
1256
1600
  device const void * src0,
1257
1601
  device const float * src1,
1258
1602
  device float * dst,
@@ -1264,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
1264
1608
  uint2 tpitg[[thread_position_in_threadgroup]],
1265
1609
  uint2 tptg[[threads_per_threadgroup]]) {
1266
1610
 
1267
- const uint16_t kmask1 = 0x3f3f;
1268
- const uint16_t kmask2 = 0x0f0f;
1269
- const uint16_t kmask3 = 0xc0c0;
1270
-
1271
1611
  const int nb = ne00/QK_K;
1272
1612
 
1273
1613
  const int64_t r0 = tgpig.x;
1274
1614
  const int64_t r1 = tgpig.y;
1275
1615
 
1276
- device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
1616
+ device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1277
1617
  device const float * yy = (device const float *) src1 + r1*ne10;
1278
1618
 
1279
1619
  const int nth = tptg.x*tptg.y;
1280
1620
  const int ith = tptg.y*tpitg.x + tpitg.y;
1281
1621
 
1622
+ float sumf = 0;
1623
+
1624
+ #if QK_K == 256
1625
+
1626
+ const uint16_t kmask1 = 0x3f3f;
1627
+ const uint16_t kmask2 = 0x0f0f;
1628
+ const uint16_t kmask3 = 0xc0c0;
1629
+
1282
1630
  const int tid = tpitg.y; // 0...16
1283
1631
  const int il = tid/4; // 0...3
1284
1632
  const int ir = tid - 4*il;// 0...3
@@ -1298,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
1298
1646
 
1299
1647
  uchar2 sc1, sc2, sc3, sc4;
1300
1648
 
1301
- float sumf = 0;
1302
1649
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1303
1650
 
1304
1651
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1330,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
1330
1677
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1331
1678
 
1332
1679
  }
1680
+ #else
1681
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1682
+ const int im = il/8; // 0, 0, 1, 1
1683
+ const int in = il%8; // 0, 4, 0, 4
1684
+
1685
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1686
+
1687
+ const float d = (float)x[i].d;
1688
+ device const uint8_t * q = x[i].qs + il;
1689
+ device const uint8_t * h = x[i].qh + in;
1690
+ device const int8_t * s = x[i].scales;
1691
+ device const float * y = yy + i*QK_K + il;
1692
+
1693
+ for (int l = 0; l < 4; ++l) {
1694
+ const uint8_t hl = h[l] >> im;
1695
+ sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1696
+ + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1697
+ + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1698
+ + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1699
+ }
1700
+ }
1701
+ #endif
1333
1702
  sum[ith] = sumf;
1334
1703
 
1335
1704
  //
@@ -1351,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
1351
1720
 
1352
1721
  }
1353
1722
 
1354
- kernel void kernel_mul_mat_q6_k_f32(
1723
+ kernel void kernel_mul_mat_q6_K_f32(
1355
1724
  device const void * src0,
1356
1725
  device const float * src1,
1357
1726
  device float * dst,
@@ -1373,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
1373
1742
  const int64_t r0 = tgpig.x;
1374
1743
  const int64_t r1 = tgpig.y;
1375
1744
 
1376
- device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
1745
+ device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1377
1746
  device const float * yy = (device const float *) src1 + r1*ne10;
1378
1747
 
1379
1748
  const int nth = tptg.x*tptg.y;
1380
1749
  const int ith = tptg.y*tpitg.x + tpitg.y;
1381
1750
 
1751
+ float sumf = 0;
1752
+
1753
+ #if QK_K == 256
1382
1754
  // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1383
1755
  const int iqs = 16 * tpitg.y;
1384
1756
  const int ip = iqs / 128; // 0 or 1
@@ -1391,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
1391
1763
  const int q_offset_l = 64*ip + l0;
1392
1764
  const int q_offset_h = 32*ip + l0;
1393
1765
 
1394
- float sumf = 0;
1395
1766
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1396
1767
 
1397
1768
  device const uint8_t * ql = x[i].ql + q_offset_l;
@@ -1413,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
1413
1784
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1414
1785
 
1415
1786
  }
1787
+ #else
1788
+ const int il = 4*tpitg.x; // 0, 4, 8, 12
1789
+
1790
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1791
+ device const float * y = yy + i * QK_K + il;
1792
+ device const uint8_t * ql = x[i].ql + il;
1793
+ device const uint8_t * qh = x[i].qh + il;
1794
+ device const int8_t * s = x[i].scales;
1795
+
1796
+ const float d = x[i].d;
1797
+
1798
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
1799
+ for (int l = 0; l < 4; ++l) {
1800
+ sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1801
+ sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1802
+ sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
1803
+ sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1804
+ }
1805
+ sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
1806
+ }
1807
+
1808
+ #endif
1416
1809
 
1417
1810
  sum[ith] = sumf;
1418
1811