llama_cpp 0.2.1 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
256
256
  (device float *) ((device char *) dst + i*nb1), ne00);
257
257
  }
258
258
 
259
// Normalizes one row per threadgroup: subtracts the row mean, then scales by
// 1/sqrt(variance + eps), writing ne00 floats to dst.
//
// Layout, as used by the code below:
//   - threadgroup `tgpig` reads the source row at byte offset tgpig*nb01 and
//     writes the contiguous output row starting at dst + tgpig*ne00
//   - thread `tpitg` handles elements tpitg, tpitg + ntg, tpitg + 2*ntg, ...
//   - `sum` is threadgroup scratch indexed by tpitg (one float per thread)
//
// NOTE(review): the halving tree reduction only folds every partial sum when
// ntg is a power of two - confirm against the host-side dispatch.
kernel void kernel_norm(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
        threadgroup float  * sum [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);

    // MEAN
    // per-thread partial sum, accumulated in a register and published once
    float partial = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        partial += x[i00];
    }
    sum[tpitg] = partial;

    // reduce: fold the upper half of the scratch buffer onto the lower half
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg/2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // broadcast: thread 0 turns the total into the mean, everyone reads it
    if (tpitg == 0) {
        sum[0] /= ne00;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float mean = sum[0];

    // All threads must have read the mean before the scratch buffer is reused
    // for the variance; without this barrier thread 0 can overwrite sum[0]
    // while other threads are still about to read it.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // recenter
    device float * y = dst + tgpig*ne00;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        y[i00] = x[i00] - mean;
    }

    // VARIANCE
    // each thread re-reads only the elements it wrote itself just above
    partial = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        partial += y[i00] * y[i00];
    }
    sum[tpitg] = partial;

    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg/2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // broadcast
    if (tpitg == 0) {
        sum[0] /= ne00;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float variance = sum[0];

    // scale the recentered row in place
    const float scale = 1.0f/sqrt(variance + eps);
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        y[i00] = y[i00] * scale;
    }
}
323
+
324
+
259
325
  kernel void kernel_rms_norm(
260
326
  device const void * src0,
261
327
  device float * dst,
@@ -362,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
362
428
  }
363
429
  threadgroup_barrier(mem_flags::mem_threadgroup);
364
430
  if (ith == 0) {
365
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
431
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
366
432
  dst[r1*ne0 + r0] = sum[0];
367
433
  }
368
434
  }
@@ -431,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
431
497
  }
432
498
  threadgroup_barrier(mem_flags::mem_threadgroup);
433
499
  if (ith == 0) {
434
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
500
+ for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
435
501
  dst[r1*ne0 + r0] = sum[0];
436
502
  }
437
503
  }
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
485
551
  }
486
552
  }
487
553
 
554
// Copies one source row per threadgroup into dst while adding a linear bias
// along the row: dst[i00] = src[i00] + m_k * (i00 - ne00 + 1), where
// m_k = m0^(i2 + 1) and i2 is the destination index along dimension 2.
// The bias term is zero at the last column (i00 == ne00 - 1).
//
// The threadgroup position selects the source row (i01, i02, i03); threads in
// the group stride across the row's ne00 elements along x.
//
// NOTE(review): dst_data is indexed in float units, so destination rows are
// presumably contiguous floats (nb0 == sizeof(float)) - confirm with caller.
kernel void kernel_alibi_f32(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        constant     float & m0,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates
    const int64_t i01 = tgpig[0];
    const int64_t i02 = tgpig[1];
    const int64_t i03 = tgpig[2];

    // flat index of the first element of this source row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unflatten n into destination coordinates, peeling one dimension at a time
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    const int64_t i3   = n / volume;
    const int64_t rem3 = n - i3*volume;
    const int64_t i2   = rem3 / plane;
    const int64_t rem2 = rem3 - i2*plane;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = rem2 - i1*ne0;

    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    // slope for this destination slice
    float m_k = pow(m0, i2 + 1);

    // byte address of the source row; only the i00 term varies inside the loop
    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const float * src = (device const float *) (src_row + i00*nb00);
        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
    }
}
595
+
488
596
  kernel void kernel_rope(
489
597
  device const void * src0,
490
598
  device float * dst,
@@ -540,6 +648,47 @@ kernel void kernel_rope(
540
648
  }
541
649
  }
542
650
 
651
// Copies one row of half values per threadgroup from a strided source tensor
// into a destination tensor that may have a different shape.
//
// The threadgroup position selects the source row (i01, i02, i03). The row's
// flat element index is re-expressed in destination coordinates (i0..i3) to
// find where the row starts in dst; threads in the group then stride across
// the row's ne00 elements along x.
//
// NOTE(review): dst_data is indexed in half units, so destination rows are
// presumably contiguous halves (nb0 == sizeof(half)) - confirm with caller.
kernel void kernel_cpy_f16_f16(
        device const half * src0,
        device       half * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates
    const int64_t i01 = tgpig[0];
    const int64_t i02 = tgpig[1];
    const int64_t i03 = tgpig[2];

    // flat index of the first element of this source row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unflatten n into destination coordinates, peeling one dimension at a time
    const int64_t plane  = ne1*ne0;
    const int64_t volume = ne2*plane;

    const int64_t i3   = n / volume;
    const int64_t rem3 = n - i3*volume;
    const int64_t i2   = rem3 / plane;
    const int64_t rem2 = rem3 - i2*plane;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = rem2 - i1*ne0;

    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    // byte address of the source row; only the i00 term varies inside the loop
    device const char * src_row = (device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const half * src = (device const half *) (src_row + i00*nb00);
        dst_data[i00] = src[0];
    }
}
691
+
543
692
  kernel void kernel_cpy_f32_f16(
544
693
  device const float * src0,
545
694
  device half * dst,
@@ -626,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
626
775
 
627
776
  //============================================ k-quants ======================================================
628
777
 
778
+ #ifndef QK_K
629
779
  #define QK_K 256
780
+ #else
781
+ static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
782
+ #endif
783
+
784
+ #if QK_K == 256
785
+ #define K_SCALE_SIZE 12
786
+ #else
787
+ #define K_SCALE_SIZE 4
788
+ #endif
630
789
 
631
790
  typedef struct {
632
791
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
633
792
  uint8_t qs[QK_K/4]; // quants
634
793
  half d; // super-block scale for quantized scales
635
794
  half dmin; // super-block scale for quantized mins
636
- } block_q2_k;
795
+ } block_q2_K;
637
796
  // 84 bytes / block
638
797
 
639
798
  typedef struct {
640
799
  uint8_t hmask[QK_K/8]; // quants - high bit
641
800
  uint8_t qs[QK_K/4]; // quants - low 2 bits
642
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
643
- half d; // super-block scale
644
- } block_q3_k;
645
- // 110 bytes / block
646
-
801
+ #if QK_K == 64
802
+ uint8_t scales[2];
803
+ #else
804
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
805
+ #endif
806
+ half d; // super-block scale
807
+ } block_q3_K;
808
+
809
+ #if QK_K == 64
810
+ typedef struct { // Q4_K super-block layout used when QK_K == 64
811
+ half d[2]; // d[0]: super-block scale applied to the 4-bit block scales, d[1]: super-block scale applied to the 4-bit block mins
812
+ uint8_t scales[2]; // one byte per 32-value half: low nibble = block scale, high nibble = block min
813
+ uint8_t qs[QK_K/2]; // 4-bit quants: low nibble of qs[l] is value l, high nibble is value l+32
814
+ } block_q4_K;
815
+ #else
647
816
  typedef struct {
648
817
  half d; // super-block scale for quantized scales
649
818
  half dmin; // super-block scale for quantized mins
650
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
819
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
651
820
  uint8_t qs[QK_K/2]; // 4--bit quants
652
- } block_q4_k;
653
- // 144 bytes / block
821
+ } block_q4_K;
822
+ #endif
654
823
 
824
+ #if QK_K == 64
825
+ typedef struct { // Q5_K super-block layout used when QK_K == 64
826
+ half d; // super-block scale (this layout has no separate min term; the dequantizer subtracts 16 when the high bit is clear)
827
+ int8_t scales[QK_K/16]; // signed 8-bit block scales, one per 16 values
828
+ uint8_t qh[QK_K/8]; // quants, high bit: bit j of qh[l] belongs to value l + 8*j
829
+ uint8_t qs[QK_K/2]; // quants, low 4 bits: low nibble of qs[l] is value l, high nibble is value l+32
830
+ } block_q5_K;
831
+ #else
655
832
  typedef struct {
656
833
  half d; // super-block scale for quantized scales
657
834
  half dmin; // super-block scale for quantized mins
658
835
  uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
659
836
  uint8_t qh[QK_K/8]; // quants, high bit
660
837
  uint8_t qs[QK_K/2]; // quants, low 4 bits
661
- } block_q5_k;
838
+ } block_q5_K;
662
839
  // 176 bytes / block
840
+ #endif
663
841
 
664
842
  typedef struct {
665
843
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
666
844
  uint8_t qh[QK_K/4]; // quants, upper 2 bits
667
845
  int8_t scales[QK_K/16]; // scales, quantized with 8 bits
668
846
  half d; // super-block scale
669
- } block_q6_k;
847
+ } block_q6_K;
670
848
  // 210 bytes / block
671
849
 
672
850
  static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -687,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
687
865
 
688
866
  //========================================== dequantization =============================
689
867
 
690
- static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
868
+ static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
691
869
  assert(k % QK_K == 0);
692
870
  const int nb = k / QK_K;
693
871
 
@@ -698,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
698
876
 
699
877
  device const uint8_t * q = x[i].qs;
700
878
 
879
+ #if QK_K == 256
701
880
  int is = 0;
702
881
  float dl, ml;
703
882
  for (int n = 0; n < QK_K; n += 128) {
@@ -716,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
716
895
  }
717
896
  q += 32;
718
897
  }
898
+ #else
899
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
900
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
901
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
902
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
903
+ for (int l = 0; l < 16; ++l) {
904
+ y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
905
+ y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
906
+ y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
907
+ y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
908
+ }
909
+ y += QK_K;
910
+ #endif
719
911
 
720
912
  }
721
913
  }
722
914
 
723
- static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
915
+ static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
724
916
  assert(k % QK_K == 0);
725
917
  const int nb = k / QK_K;
726
918
 
919
+ #if QK_K == 256
920
+
727
921
  const uint16_t kmask1 = 0x0303;
728
922
  const uint16_t kmask2 = 0x0f0f;
729
923
 
@@ -769,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
769
963
  }
770
964
  q += 32;
771
965
  }
966
+ }
967
+ #else
968
+ for (int i = 0; i < nb; i++) {
969
+
970
+ const float d_all = (float)(x[i].d);
772
971
 
972
+ device const uint8_t * q = x[i].qs;
973
+ device const uint8_t * hm = x[i].hmask;
974
+
975
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
976
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
977
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
978
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
979
+
980
+ for (int l = 0; l < 8; ++l) {
981
+ uint8_t h = hm[l];
982
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
983
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
984
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
985
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
986
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
987
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
988
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
989
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
990
+ }
991
+ y += QK_K;
773
992
  }
993
+ #endif
774
994
 
775
995
  }
776
996
 
777
- static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
997
+ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
778
998
  assert(k % QK_K == 0);
779
999
  const int nb = k / QK_K;
780
1000
 
781
-
782
1001
  for (int i = 0; i < nb; i++) {
783
1002
 
1003
+ device const uint8_t * q = x[i].qs;
1004
+
1005
+ #if QK_K == 256
784
1006
  const float d = x[i].d;
785
1007
  const float min = x[i].dmin;
786
1008
 
787
- device const uint8_t * q = x[i].qs;
788
1009
  device const uint8_t * scales = x[i].scales;
789
1010
 
790
1011
  int is = 0;
@@ -796,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
796
1017
  for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
797
1018
  q += 32; is += 2;
798
1019
  }
1020
+ #else
1021
+ device const uint8_t * s = x[i].scales;
1022
+ device const half2 * dh = (device const half2 *)x[i].d;
1023
+ const float2 d = (float2)dh[0];
1024
+ const float d1 = d[0] * (s[0] & 0xF);
1025
+ const float d2 = d[0] * (s[1] & 0xF);
1026
+ const float m1 = d[1] * (s[0] >> 4);
1027
+ const float m2 = d[1] * (s[1] >> 4);
1028
+ for (int l = 0; l < 32; ++l) {
1029
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1030
+ y[l+32] = d2 * (q[l] >> 4) - m2;
1031
+ }
1032
+ y += QK_K;
1033
+ #endif
799
1034
 
800
1035
  }
801
1036
  }
802
1037
 
803
- static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
1038
+ static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
804
1039
  assert(k % QK_K == 0);
805
1040
  const int nb = k / QK_K;
806
1041
 
1042
+ #if QK_K == 256
807
1043
  for (int i = 0; i < nb; i++) {
808
1044
 
809
1045
  const float d = (float)(x[i].d);
@@ -824,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
824
1060
  u1 <<= 2; u2 <<= 2;
825
1061
  }
826
1062
  }
1063
+ #else
1064
+ for (int i = 0; i < nb; i++) {
1065
+
1066
+ const float d = (float)x[i].d;
1067
+
1068
+ device const uint8_t * ql = x[i].qs;
1069
+ device const uint8_t * qh = x[i].qh;
1070
+ device const int8_t * sc = x[i].scales;
1071
+
1072
+ for (int l = 0; l < 8; ++l) {
1073
+ y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1074
+ y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1075
+ y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1076
+ y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1077
+ y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1078
+ y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1079
+ y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1080
+ y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1081
+ }
1082
+ y += QK_K;
1083
+ }
1084
+ #endif
827
1085
 
828
1086
  }
829
1087
 
830
- static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
1088
+ static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
831
1089
  assert(k % QK_K == 0);
832
1090
  const int nb = k / QK_K;
833
1091
 
@@ -839,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
839
1097
 
840
1098
  const float d = x[i].d;
841
1099
 
1100
+ #if QK_K == 256
842
1101
  for (int n = 0; n < QK_K; n += 128) {
843
1102
  for (int l = 0; l < 32; ++l) {
844
1103
  int is = l/16;
@@ -856,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
856
1115
  qh += 32;
857
1116
  sc += 8;
858
1117
  }
1118
+ #else
1119
+ for (int l = 0; l < 16; ++l) {
1120
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1121
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1122
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1123
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1124
+ y[l+ 0] = d * sc[0] * q1;
1125
+ y[l+16] = d * sc[1] * q2;
1126
+ y[l+32] = d * sc[2] * q3;
1127
+ y[l+48] = d * sc[3] * q4;
1128
+ }
1129
+ y += 64;
1130
+ #endif
859
1131
  }
860
1132
  }
861
1133
 
862
- kernel void kernel_get_rows_q2_k(
1134
+ kernel void kernel_get_rows_q2_K(
863
1135
  device const void * src0,
864
1136
  device const int * src1,
865
1137
  device float * dst,
@@ -870,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
870
1142
  const int i = tpig;
871
1143
  const int r = ((device int32_t *) src1)[i];
872
1144
 
873
- dequantize_row_q2_k(
874
- (device const block_q2_k *) ((device char *) src0 + r*nb01),
1145
+ dequantize_row_q2_K(
1146
+ (device const block_q2_K *) ((device char *) src0 + r*nb01),
875
1147
  (device float *) ((device char *) dst + i*nb1), ne00);
876
1148
  }
877
1149
 
878
- kernel void kernel_get_rows_q3_k(
1150
+ kernel void kernel_get_rows_q3_K(
879
1151
  device const void * src0,
880
1152
  device const int * src1,
881
1153
  device float * dst,
@@ -886,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
886
1158
  const int i = tpig;
887
1159
  const int r = ((device int32_t *) src1)[i];
888
1160
 
889
- dequantize_row_q3_k(
890
- (device const block_q3_k *) ((device char *) src0 + r*nb01),
1161
+ dequantize_row_q3_K(
1162
+ (device const block_q3_K *) ((device char *) src0 + r*nb01),
891
1163
  (device float *) ((device char *) dst + i*nb1), ne00);
892
1164
  }
893
1165
 
894
- kernel void kernel_get_rows_q4_k(
1166
+ kernel void kernel_get_rows_q4_K(
895
1167
  device const void * src0,
896
1168
  device const int * src1,
897
1169
  device float * dst,
@@ -902,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
902
1174
  const int i = tpig;
903
1175
  const int r = ((device int32_t *) src1)[i];
904
1176
 
905
- dequantize_row_q4_k(
906
- (device const block_q4_k *) ((device char *) src0 + r*nb01),
1177
+ dequantize_row_q4_K(
1178
+ (device const block_q4_K *) ((device char *) src0 + r*nb01),
907
1179
  (device float *) ((device char *) dst + i*nb1), ne00);
908
1180
  }
909
1181
 
910
- kernel void kernel_get_rows_q5_k(
1182
+ kernel void kernel_get_rows_q5_K(
911
1183
  device const void * src0,
912
1184
  device const int * src1,
913
1185
  device float * dst,
@@ -918,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
918
1190
  const int i = tpig;
919
1191
  const int r = ((device int32_t *) src1)[i];
920
1192
 
921
- dequantize_row_q5_k(
922
- (device const block_q5_k *) ((device char *) src0 + r*nb01),
1193
+ dequantize_row_q5_K(
1194
+ (device const block_q5_K *) ((device char *) src0 + r*nb01),
923
1195
  (device float *) ((device char *) dst + i*nb1), ne00);
924
1196
  }
925
1197
 
926
- kernel void kernel_get_rows_q6_k(
1198
+ kernel void kernel_get_rows_q6_K(
927
1199
  device const void * src0,
928
1200
  device const int * src1,
929
1201
  device float * dst,
@@ -934,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
934
1206
  const int i = tpig;
935
1207
  const int r = ((device int32_t *) src1)[i];
936
1208
 
937
- dequantize_row_q6_k(
938
- (device const block_q6_k *) ((device char *) src0 + r*nb01),
1209
+ dequantize_row_q6_K(
1210
+ (device const block_q6_K *) ((device char *) src0 + r*nb01),
939
1211
  (device float *) ((device char *) dst + i*nb1), ne00);
940
1212
  }
941
1213
 
942
1214
  //====================================== dot products =========================
943
1215
 
944
- kernel void kernel_mul_mat_q2_k_f32(
1216
+ kernel void kernel_mul_mat_q2_K_f32(
945
1217
  device const void * src0,
946
1218
  device const float * src1,
947
1219
  device float * dst,
@@ -958,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
958
1230
  const int64_t r0 = tgpig.x;
959
1231
  const int64_t r1 = tgpig.y;
960
1232
 
961
- device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
1233
+ device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
962
1234
  device const float * yy = (device const float *) src1 + r1*ne10;
963
1235
 
964
1236
  const int nth = tptg.x*tptg.y;
965
1237
  const int ith = tptg.y*tpitg.x + tpitg.y;
966
1238
 
1239
+ float sumf = 0;
1240
+
1241
+ #if QK_K == 256
967
1242
  const int tid = tpitg.y; // 0...16
968
1243
  const int il = tid/4; // 0...3
969
1244
  const int ir = tid%4; // 0...3
@@ -976,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
976
1251
  const int y_offset = 64*il + n*ir;
977
1252
  const int q_offset = 32*ip + n*ir;
978
1253
 
979
- sum[ith] = 0.0f;
980
-
981
- float sumf = 0;
982
1254
  for (int i = tpitg.x; i < nb; i += tptg.x) {
983
1255
 
984
1256
  device const uint8_t * q = x[i].qs + q_offset;
@@ -991,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
991
1263
 
992
1264
  device const float * y = yy + i*QK_K + y_offset;
993
1265
 
994
- //float4 s = {0.f, 0.f, 0.f, 0.f};
995
1266
  float2 s = {0.f, 0.f};
996
1267
  float smin = 0;
997
1268
  for (int l = 0; l < n; ++l) {
@@ -1006,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
1006
1277
  sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1007
1278
 
1008
1279
  }
1009
- sum[ith] = sumf;
1280
+ #else
1281
+ const int il = 4 * tpitg.x;
1010
1282
 
1011
- //int mask1 = (ith%4 == 0);
1012
- //int mask2 = (ith%16 == 0);
1283
+ uint32_t aux[2];
1284
+ thread const uint8_t * d = (thread const uint8_t *)aux;
1285
+ thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1013
1286
 
1014
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1015
- //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
1016
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1017
- //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
1018
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1019
- //if (ith == 0) {
1020
- // for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1021
- // dst[r1*ne0 + r0] = sum[0];
1022
- //}
1287
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1288
+
1289
+ device const uint8_t * q = x[i].qs + il;
1290
+ device const float * y = yy + i*QK_K + il;
1291
+
1292
+ const float dall = (float)x[i].d;
1293
+ const float dmin = (float)x[i].dmin;
1294
+
1295
+ device const uint32_t * a = (device const uint32_t *)x[i].scales;
1296
+ aux[0] = a[0] & 0x0f0f0f0f;
1297
+ aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1298
+
1299
+ for (int l = 0; l < 4; ++l) {
1300
+ sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1301
+ + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1302
+ + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1303
+ + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1304
+ }
1305
+ }
1306
+ #endif
1307
+
1308
+ sum[ith] = sumf;
1023
1309
 
1024
1310
  //
1025
1311
  // Accumulate the sum from all threads in the threadgroup
1026
- // This version is slightly faster than the commented out one below,
1027
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1028
1312
  //
1029
1313
  threadgroup_barrier(mem_flags::mem_threadgroup);
1030
1314
  if (ith%4 == 0) {
@@ -1041,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
1041
1325
  }
1042
1326
  }
1043
1327
 
1044
- kernel void kernel_mul_mat_q3_k_f32(
1328
+ kernel void kernel_mul_mat_q3_K_f32(
1045
1329
  device const void * src0,
1046
1330
  device const float * src1,
1047
1331
  device float * dst,
@@ -1054,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
1054
1338
  uint2 tpitg[[thread_position_in_threadgroup]],
1055
1339
  uint2 tptg[[threads_per_threadgroup]]) {
1056
1340
 
1057
- const uint16_t kmask1 = 0x0303;
1058
- const uint16_t kmask2 = 0x0f0f;
1059
-
1060
- const uint8_t m3 = 3;
1061
- const int8_t m4 = 4;
1062
-
1063
1341
  const int nb = ne00/QK_K;
1064
1342
 
1065
1343
  const int64_t r0 = tgpig.x;
1066
1344
  const int64_t r1 = tgpig.y;
1067
1345
 
1068
- device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
1346
+ device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1069
1347
  device const float * yy = (device const float *) src1 + r1*ne10;
1070
1348
 
1071
1349
  const int nth = tptg.x*tptg.y;
1072
1350
  const int ith = tptg.y*tpitg.x + tpitg.y;
1073
1351
 
1352
+ #if QK_K == 256
1353
+
1354
+ const uint8_t m3 = 3;
1355
+ const int8_t m4 = 4;
1356
+
1357
+ const uint16_t kmask1 = 0x0303;
1358
+ const uint16_t kmask2 = 0x0f0f;
1359
+
1074
1360
  const int tid = tpitg.y; // expecting 16
1075
1361
  const int ip = tid/8; // 0 or 1
1076
1362
  const int il = tid/2 - 4*ip; // 0...3
@@ -1124,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
1124
1410
 
1125
1411
  //sum[ith] = sumf;
1126
1412
  sum[ith] = sumf1 - 32.f*sumf2;
1413
+ #else
1414
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1415
+ const int im = il/8; // 0, 0, 1, 1
1416
+ const int in = il%8; // 0, 4, 0, 4
1417
+
1418
+ float sumf = 0;
1419
+
1420
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1421
+
1422
+ const float d_all = (float)(x[i].d);
1423
+
1424
+ device const uint8_t * q = x[i].qs + il;
1425
+ device const uint8_t * h = x[i].hmask + in;
1426
+ device const float * y = yy + i * QK_K + il;
1427
+
1428
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1429
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1430
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1431
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1432
+
1433
+ for (int l = 0; l < 4; ++l) {
1434
+ const uint8_t hm = h[l] >> im;
1435
+ sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1436
+ + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1437
+ + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1438
+ + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1439
+ }
1440
+
1441
+ }
1442
+
1443
+ sum[ith] = sumf;
1444
+
1445
+ #endif
1127
1446
 
1128
1447
  //
1129
1448
  // Accumulate the sum from all threads in the threadgroup
@@ -1144,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
1144
1463
 
1145
1464
  }
1146
1465
 
1147
- kernel void kernel_mul_mat_q4_k_f32(
1466
+ kernel void kernel_mul_mat_q4_K_f32(
1148
1467
  device const void * src0,
1149
1468
  device const float * src1,
1150
1469
  device float * dst,
@@ -1156,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
1156
1475
  uint2 tpitg[[thread_position_in_threadgroup]],
1157
1476
  uint2 tptg[[threads_per_threadgroup]]) {
1158
1477
 
1159
- const uint16_t kmask1 = 0x3f3f;
1160
- const uint16_t kmask2 = 0x0f0f;
1161
- const uint16_t kmask3 = 0xc0c0;
1162
-
1163
1478
  const int nb = ne00/QK_K;
1164
1479
 
1165
1480
  const int64_t r0 = tgpig.x;
1166
1481
  const int64_t r1 = tgpig.y;
1167
1482
 
1168
- device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
1169
- device const float * yy = (device const float *) src1 + r1*ne10;
1170
-
1171
1483
  const int nth = tptg.x*tptg.y;
1172
1484
  const int ith = tptg.y*tpitg.x + tpitg.y;
1173
1485
 
1486
+ device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1487
+ device const float * yy = (device const float *) src1 + r1*ne10;
1488
+
1489
+ float sumf = 0;
1490
+
1491
+ #if QK_K == 256
1492
+
1493
+ const uint16_t kmask1 = 0x3f3f;
1494
+ const uint16_t kmask2 = 0x0f0f;
1495
+ const uint16_t kmask3 = 0xc0c0;
1496
+
1174
1497
  const int tid = tpitg.y; // 0...16
1175
1498
  const int il = tid/4; // 0...3
1176
1499
  const int ir = tid - 4*il;// 0...3
@@ -1183,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
1183
1506
  const int q_offset = 32*im + l0;
1184
1507
  const int y_offset = 64*im + l0;
1185
1508
 
1186
- sum[ith] = 0.0f;
1187
-
1188
1509
  uchar2 sc1, sc2, sc3, sc4;
1189
1510
 
1190
- float sumf = 0;
1191
1511
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1192
1512
 
1193
1513
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1216,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
1216
1536
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1217
1537
 
1218
1538
  }
1539
+ #else
1540
+ uint16_t aux16[2];
1541
+ thread const uint8_t * scales = (thread const uint8_t *)aux16;
1542
+
1543
+ const int il = 4*tpitg.x;
1544
+
1545
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1546
+
1547
+ device const uint8_t * q = x[i].qs + il;
1548
+ device const float * y = yy + i * QK_K + il;
1549
+
1550
+ const float d = (float)x[i].d[0];
1551
+ const float m = (float)x[i].d[1];
1552
+
1553
+ device const uint16_t * a = (device const uint16_t *)x[i].scales;
1554
+ aux16[0] = a[0] & 0x0f0f;
1555
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
1556
+
1557
+ for (int l = 0; l < 4; ++l) {
1558
+ sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1559
+ + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1560
+ }
1561
+ }
1562
+ #endif
1219
1563
 
1220
1564
  sum[ith] = sumf;
1221
1565
 
@@ -1252,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
1252
1596
  //}
1253
1597
  }
1254
1598
 
1255
- kernel void kernel_mul_mat_q5_k_f32(
1599
+ kernel void kernel_mul_mat_q5_K_f32(
1256
1600
  device const void * src0,
1257
1601
  device const float * src1,
1258
1602
  device float * dst,
@@ -1264,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
1264
1608
  uint2 tpitg[[thread_position_in_threadgroup]],
1265
1609
  uint2 tptg[[threads_per_threadgroup]]) {
1266
1610
 
1267
- const uint16_t kmask1 = 0x3f3f;
1268
- const uint16_t kmask2 = 0x0f0f;
1269
- const uint16_t kmask3 = 0xc0c0;
1270
-
1271
1611
  const int nb = ne00/QK_K;
1272
1612
 
1273
1613
  const int64_t r0 = tgpig.x;
1274
1614
  const int64_t r1 = tgpig.y;
1275
1615
 
1276
- device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
1616
+ device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1277
1617
  device const float * yy = (device const float *) src1 + r1*ne10;
1278
1618
 
1279
1619
  const int nth = tptg.x*tptg.y;
1280
1620
  const int ith = tptg.y*tpitg.x + tpitg.y;
1281
1621
 
1622
+ float sumf = 0;
1623
+
1624
+ #if QK_K == 256
1625
+
1626
+ const uint16_t kmask1 = 0x3f3f;
1627
+ const uint16_t kmask2 = 0x0f0f;
1628
+ const uint16_t kmask3 = 0xc0c0;
1629
+
1282
1630
  const int tid = tpitg.y; // 0...16
1283
1631
  const int il = tid/4; // 0...3
1284
1632
  const int ir = tid - 4*il;// 0...3
@@ -1298,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
1298
1646
 
1299
1647
  uchar2 sc1, sc2, sc3, sc4;
1300
1648
 
1301
- float sumf = 0;
1302
1649
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1303
1650
 
1304
1651
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1330,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
1330
1677
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1331
1678
 
1332
1679
  }
1680
+ #else
1681
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1682
+ const int im = il/8; // 0, 0, 1, 1
1683
+ const int in = il%8; // 0, 4, 0, 4
1684
+
1685
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1686
+
1687
+ const float d = (float)x[i].d;
1688
+ device const uint8_t * q = x[i].qs + il;
1689
+ device const uint8_t * h = x[i].qh + in;
1690
+ device const int8_t * s = x[i].scales;
1691
+ device const float * y = yy + i*QK_K + il;
1692
+
1693
+ for (int l = 0; l < 4; ++l) {
1694
+ const uint8_t hl = h[l] >> im;
1695
+ sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1696
+ + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1697
+ + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1698
+ + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1699
+ }
1700
+ }
1701
+ #endif
1333
1702
  sum[ith] = sumf;
1334
1703
 
1335
1704
  //
@@ -1351,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
1351
1720
 
1352
1721
  }
1353
1722
 
1354
- kernel void kernel_mul_mat_q6_k_f32(
1723
+ kernel void kernel_mul_mat_q6_K_f32(
1355
1724
  device const void * src0,
1356
1725
  device const float * src1,
1357
1726
  device float * dst,
@@ -1373,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
1373
1742
  const int64_t r0 = tgpig.x;
1374
1743
  const int64_t r1 = tgpig.y;
1375
1744
 
1376
- device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
1745
+ device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1377
1746
  device const float * yy = (device const float *) src1 + r1*ne10;
1378
1747
 
1379
1748
  const int nth = tptg.x*tptg.y;
1380
1749
  const int ith = tptg.y*tpitg.x + tpitg.y;
1381
1750
 
1751
+ float sumf = 0;
1752
+
1753
+ #if QK_K == 256
1382
1754
  // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1383
1755
  const int iqs = 16 * tpitg.y;
1384
1756
  const int ip = iqs / 128; // 0 or 1
@@ -1391,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
1391
1763
  const int q_offset_l = 64*ip + l0;
1392
1764
  const int q_offset_h = 32*ip + l0;
1393
1765
 
1394
- float sumf = 0;
1395
1766
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1396
1767
 
1397
1768
  device const uint8_t * ql = x[i].ql + q_offset_l;
@@ -1413,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
1413
1784
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1414
1785
 
1415
1786
  }
1787
+ #else
1788
+ const int il = 4*tpitg.x; // 0, 4, 8, 12
1789
+
1790
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1791
+ device const float * y = yy + i * QK_K + il;
1792
+ device const uint8_t * ql = x[i].ql + il;
1793
+ device const uint8_t * qh = x[i].qh + il;
1794
+ device const int8_t * s = x[i].scales;
1795
+
1796
+ const float d = x[i].d;
1797
+
1798
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
1799
+ for (int l = 0; l < 4; ++l) {
1800
+ sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1801
+ sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1802
+ sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
1803
+ sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1804
+ }
1805
+ sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
1806
+ }
1807
+
1808
+ #endif
1416
1809
 
1417
1810
  sum[ith] = sumf;
1418
1811