llama_cpp 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -304,34 +304,22 @@ kernel void kernel_mul_mat_q4_0_f32(
304
304
  device const float * src1,
305
305
  device float * dst,
306
306
  constant int64_t & ne00,
307
- constant int64_t & ne01,
308
- constant uint64_t & nb00,
309
- constant uint64_t & nb01,
310
- constant uint64_t & nb02,
311
307
  constant int64_t & ne10,
312
- constant int64_t & ne11,
313
- constant uint64_t & nb10,
314
- constant uint64_t & nb11,
315
- constant uint64_t & nb12,
316
308
  constant int64_t & ne0,
317
- constant int64_t & ne1,
318
309
  threadgroup float * sum [[threadgroup(0)]],
319
310
  uint2 tgpig[[threadgroup_position_in_grid]],
320
- uint2 tpig[[thread_position_in_grid]],
321
311
  uint2 tpitg[[thread_position_in_threadgroup]],
322
312
  uint2 tptg[[threads_per_threadgroup]]) {
323
313
  const int nb = ne00/QK4_0;
324
314
 
325
- const int8_t m8 = 8;
326
-
327
315
  const int64_t r0 = tgpig.x;
328
316
  const int64_t r1 = tgpig.y;
329
317
 
330
318
  device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
331
319
  device const float * y = (device const float *) src1 + r1*ne10;
332
320
 
333
- const uint nth = tptg.x*tptg.y;
334
- const uint ith = tptg.y*tpitg.x + tpitg.y;
321
+ const int nth = tptg.x*tptg.y;
322
+ const int ith = tptg.y*tpitg.x + tpitg.y;
335
323
 
336
324
  const int ix = tpitg.y/4; // 0 or 1
337
325
  const int iy = tpitg.y - 4*ix; // 0...3
@@ -351,47 +339,32 @@ kernel void kernel_mul_mat_q4_0_f32(
351
339
 
352
340
  for (int j = 0; j < 4; ++j) {
353
341
 
354
- acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
355
- acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8);
342
+ acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
343
+ acc[1] += yl[j] + yl[j+16];
356
344
 
357
345
  }
358
346
 
359
- sumf += d * (acc[0] + acc[1]);
347
+ sumf += d * (acc[0] - 8.f*acc[1]);
360
348
  }
361
349
 
362
350
  sum[ith] = sumf;
363
351
 
364
352
  //
365
353
  // Accumulate the sum from all threads in the threadgroup
366
- // This version is slightly faster than the commented out one below,
367
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
368
354
  //
369
355
  threadgroup_barrier(mem_flags::mem_threadgroup);
370
356
  if (ith%4 == 0) {
371
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
357
+ sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
372
358
  }
373
359
  threadgroup_barrier(mem_flags::mem_threadgroup);
374
360
  if (ith%16 == 0) {
375
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
361
+ sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
376
362
  }
377
363
  threadgroup_barrier(mem_flags::mem_threadgroup);
378
364
  if (ith == 0) {
379
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
365
+ for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
380
366
  dst[r1*ne0 + r0] = sum[0];
381
367
  }
382
-
383
- //// accumulate the sum from all threads in the threadgroup
384
- //threadgroup_barrier(mem_flags::mem_threadgroup);
385
- //for (uint i = nth/2; i > 0; i /= 2) {
386
- // if (ith < i) {
387
- // sum[ith] += sum[ith + i];
388
- // }
389
- // threadgroup_barrier(mem_flags::mem_threadgroup);
390
- //}
391
-
392
- //if (ith == 0) {
393
- // dst[r1*ne0 + r0] = sum[0];
394
- //}
395
368
  }
396
369
 
397
370
  kernel void kernel_mul_mat_q4_1_f32(
@@ -399,20 +372,10 @@ kernel void kernel_mul_mat_q4_1_f32(
399
372
  device const float * src1,
400
373
  device float * dst,
401
374
  constant int64_t & ne00,
402
- constant int64_t & ne01,
403
- constant uint64_t & nb00,
404
- constant uint64_t & nb01,
405
- constant uint64_t & nb02,
406
375
  constant int64_t & ne10,
407
- constant int64_t & ne11,
408
- constant uint64_t & nb10,
409
- constant uint64_t & nb11,
410
- constant uint64_t & nb12,
411
376
  constant int64_t & ne0,
412
- constant int64_t & ne1,
413
377
  threadgroup float * sum [[threadgroup(0)]],
414
378
  uint2 tgpig[[threadgroup_position_in_grid]],
415
- uint2 tpig[[thread_position_in_grid]],
416
379
  uint2 tpitg[[thread_position_in_threadgroup]],
417
380
  uint2 tptg[[threads_per_threadgroup]]) {
418
381
  const int nb = ne00/QK4_1;
@@ -460,11 +423,11 @@ kernel void kernel_mul_mat_q4_1_f32(
460
423
  //
461
424
  threadgroup_barrier(mem_flags::mem_threadgroup);
462
425
  if (ith%4 == 0) {
463
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
426
+ sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
464
427
  }
465
428
  threadgroup_barrier(mem_flags::mem_threadgroup);
466
429
  if (ith%16 == 0) {
467
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
430
+ sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
468
431
  }
469
432
  threadgroup_barrier(mem_flags::mem_threadgroup);
470
433
  if (ith == 0) {
@@ -671,6 +634,15 @@ typedef struct {
671
634
  half d; // super-block scale for quantized scales
672
635
  half dmin; // super-block scale for quantized mins
673
636
  } block_q2_k;
637
+ // 84 bytes / block
638
+
639
// Q3_K super-block covering QK_K weights. Each weight is stored as 2 low bits
// in `qs` plus 1 high bit in `hmask`; `scales` holds packed 6-bit sub-block scales.
// Size: QK_K/8 + QK_K/4 + 3*QK_K/64 + 2 bytes, i.e. 110 bytes / block for QK_K = 256.
// The field order defines the memory layout and must not be changed.
typedef struct {
    uint8_t hmask[QK_K/8];      // quants - high bit
    uint8_t qs[QK_K/4];         // quants - low 2 bits
    uint8_t scales[3*QK_K/64];  // scales, quantized with 6 bits
    half    d;                  // super-block scale
} block_q3_k;
674
646
 
675
647
  typedef struct {
676
648
  half d; // super-block scale for quantized scales
@@ -678,6 +650,16 @@ typedef struct {
678
650
  uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
679
651
  uint8_t qs[QK_K/2]; // 4--bit quants
680
652
  } block_q4_k;
653
+ // 144 bytes / block
654
+
655
// Q5_K super-block covering QK_K weights. Each weight is stored as 4 low bits
// in `qs` plus 1 high bit in `qh`; `scales` packs 6-bit (scale, min) pairs.
// Size: 2 + 2 + 3*QK_K/64 + QK_K/8 + QK_K/2 bytes, i.e. 176 bytes / block for QK_K = 256.
// The field order defines the memory layout and must not be changed.
typedef struct {
    half    d;                  // super-block scale for quantized scales
    half    dmin;               // super-block scale for quantized mins
    uint8_t scales[3*QK_K/64];  // scales and mins, quantized with 6 bits
    uint8_t qh[QK_K/8];         // quants, high bit
    uint8_t qs[QK_K/2];         // quants, low 4 bits
} block_q5_k;
681
663
 
682
664
  typedef struct {
683
665
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
@@ -685,16 +667,19 @@ typedef struct {
685
667
  int8_t scales[QK_K/16]; // scales, quantized with 8 bits
686
668
  half d; // super-block scale
687
669
  } block_q6_k;
670
+ // 210 bytes / block
688
671
 
689
672
  static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
690
673
  uchar4 r;
691
674
  if (j < 4) {
692
- r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
693
- r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
675
+ r[0] = q[j+0] & 63;
676
+ r[2] = q[j+1] & 63;
677
+ r[1] = q[j+4] & 63;
678
+ r[3] = q[j+5] & 63;
694
679
  } else {
695
680
  r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
696
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
697
681
  r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
682
+ r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
698
683
  r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
699
684
  }
700
685
  return r;
@@ -735,10 +720,65 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
735
720
  }
736
721
  }
737
722
 
723
// Expand k values stored as Q3_K super-blocks at `x` into floats at `y`.
// k must be a multiple of QK_K (asserted). Each output is
//     d * (scale - 32) * (low2 - (high bit set ? 0 : 4))
// where `scale` is one of the 16 six-bit sub-block scales of the super-block.
static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    const uint16_t kmask1 = 0x0303;
    const uint16_t kmask2 = 0x0f0f;

    // The 16 unpacked scales are assembled two at a time as 16-bit words and
    // then read back one byte at a time through `scales`.
    uint16_t aux[8];
    thread const int8_t * scales = (thread const int8_t*)aux;

    for (int ib = 0; ib < nb; ++ib) {
        device const block_q3_k * blk = x + ib;

        const float d_all = (float)(blk->d);

        device const uint8_t  * q = blk->qs;
        device const uint8_t  * h = blk->hmask;
        device const uint16_t * a = (device const uint16_t *)blk->scales;

        // Word w takes its low 4 bits from a[w%4] (low nibbles for w < 4, high
        // nibbles otherwise) and its upper 2 bits from a[4] / a[5], 2 bits per step.
        for (int w = 0; w < 8; ++w) {
            const uint16_t lo = (a[w%4] >> (4*(w/4))) & kmask2;
            const uint16_t hi = ((a[4 + w%2] >> (2*(w/2))) & kmask1) << 4;
            aux[w] = lo | hi;
        }

        int     is = 0;  // index of the next sub-block scale
        uint8_t m  = 1;  // current bit plane of hmask; keeps advancing across both 128-value groups

        for (int n = 0; n < QK_K; n += 128) {
            for (int j = 0; j < 4; ++j) {
                const int shift = 2*j;  // which 2-bit field of each qs byte is decoded

                // Two 16-value sub-blocks share the same shift and mask bit.
                for (int off = 0; off < 32; off += 16) {
                    const float dl = d_all * (scales[is++] - 32);
                    for (int l = 0; l < 16; ++l) {
                        const int low2 = (int8_t)((q[l + off] >> shift) & 3);
                        const int bias = (h[l + off] & m) ? 0 : 4;
                        *y++ = dl * (low2 - bias);
                    }
                }

                m <<= 1;
            }
            q += 32;
        }
    }
}
776
+
738
777
  static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
739
778
  assert(k % QK_K == 0);
740
779
  const int nb = k / QK_K;
741
780
 
781
+
742
782
  for (int i = 0; i < nb; i++) {
743
783
 
744
784
  const float d = x[i].d;
@@ -760,6 +800,33 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
760
800
  }
761
801
  }
762
802
 
803
// Expand k values stored as Q5_K super-blocks at `x` into floats at `y`.
// k must be a multiple of QK_K (asserted). Each output is
//     d * scale * (low4 + (high bit set ? 16 : 0)) - dmin * min
// with the (scale, min) pairs obtained from get_scale_min_k4.
static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    for (int ib = 0; ib < nb; ++ib) {
        device const block_q5_k * blk = x + ib;

        const float d   = (float)(blk->d);
        const float min = (float)(blk->dmin);

        device const uint8_t * ql = blk->qs;
        device const uint8_t * qh = blk->qh;

        // Every 64 outputs consume 32 bytes of qs (low nibbles first, then the
        // high nibbles of the same bytes) and two consecutive bit planes of qh.
        uint8_t bit_lo = 1;
        uint8_t bit_hi = 2;

        for (int j = 0; j < QK_K; j += 64) {
            // sc = (scale, min) of sub-block j/32 followed by (scale, min) of sub-block j/32 + 1
            const uchar4 sc = get_scale_min_k4(j/32, blk->scales);

            const float d1 = d   * sc[0];
            const float m1 = min * sc[1];
            const float d2 = d   * sc[2];
            const float m2 = min * sc[3];

            for (int l = 0; l < 32; ++l) {
                *y++ = d1 * ((ql[l] & 0xF) + ((qh[l] & bit_lo) ? 16 : 0)) - m1;
            }
            for (int l = 0; l < 32; ++l) {
                *y++ = d2 * ((ql[l] >>  4) + ((qh[l] & bit_hi) ? 16 : 0)) - m2;
            }

            ql     += 32;
            bit_lo <<= 2;
            bit_hi <<= 2;
        }
    }
}
829
+
763
830
  static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
764
831
  assert(k % QK_K == 0);
765
832
  const int nb = k / QK_K;
@@ -808,6 +875,22 @@ kernel void kernel_get_rows_q2_k(
808
875
  (device float *) ((device char *) dst + i*nb1), ne00);
809
876
  }
810
877
 
878
// Row gather for Q3_K data: each thread dequantizes one row of src0 into dst.
//   src0 : quantized rows; row r starts r*nb01 bytes into the buffer
//   src1 : int32 row indices, one per thread (indexed by the thread's grid position)
//   dst  : float output; output row i starts i*nb1 bytes into the buffer
//   ne00 : number of values passed to dequantize_row_q3_k for the row
kernel void kernel_get_rows_q3_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint                 tpig[[thread_position_in_grid]]) {
    const int out_row = tpig;
    const int src_row = ((device int32_t *) src1)[out_row];

    device const block_q3_k * row_in  = (device const block_q3_k *) ((device char *) src0 + src_row*nb01);
    device float            * row_out = (device float            *) ((device char *) dst  + out_row*nb1);

    dequantize_row_q3_k(row_in, row_out, ne00);
}
893
+
811
894
  kernel void kernel_get_rows_q4_k(
812
895
  device const void * src0,
813
896
  device const int * src1,
@@ -824,6 +907,22 @@ kernel void kernel_get_rows_q4_k(
824
907
  (device float *) ((device char *) dst + i*nb1), ne00);
825
908
  }
826
909
 
910
// Row gather for Q5_K data: each thread dequantizes one row of src0 into dst.
//   src0 : quantized rows; row r starts r*nb01 bytes into the buffer
//   src1 : int32 row indices, one per thread (indexed by the thread's grid position)
//   dst  : float output; output row i starts i*nb1 bytes into the buffer
//   ne00 : number of values passed to dequantize_row_q5_k for the row
kernel void kernel_get_rows_q5_k(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint                 tpig[[thread_position_in_grid]]) {
    const int out_row = tpig;
    const int src_row = ((device int32_t *) src1)[out_row];

    device const block_q5_k * row_in  = (device const block_q5_k *) ((device char *) src0 + src_row*nb01);
    device float            * row_out = (device float            *) ((device char *) dst  + out_row*nb1);

    dequantize_row_q5_k(row_in, row_out, ne00);
}
925
+
827
926
  kernel void kernel_get_rows_q6_k(
828
927
  device const void * src0,
829
928
  device const int * src1,
@@ -847,20 +946,10 @@ kernel void kernel_mul_mat_q2_k_f32(
847
946
  device const float * src1,
848
947
  device float * dst,
849
948
  constant int64_t & ne00,
850
- constant int64_t & ne01,
851
- constant uint64_t & nb00,
852
- constant uint64_t & nb01,
853
- constant uint64_t & nb02,
854
949
  constant int64_t & ne10,
855
- constant int64_t & ne11,
856
- constant uint64_t & nb10,
857
- constant uint64_t & nb11,
858
- constant uint64_t & nb12,
859
950
  constant int64_t & ne0,
860
- constant int64_t & ne1,
861
951
  threadgroup float * sum [[threadgroup(0)]],
862
952
  uint2 tgpig[[threadgroup_position_in_grid]],
863
- uint2 tpig[[thread_position_in_grid]], // we don't use this for now
864
953
  uint2 tpitg[[thread_position_in_threadgroup]],
865
954
  uint2 tptg[[threads_per_threadgroup]]) {
866
955
 
@@ -875,7 +964,6 @@ kernel void kernel_mul_mat_q2_k_f32(
875
964
  const int nth = tptg.x*tptg.y;
876
965
  const int ith = tptg.y*tpitg.x + tpitg.y;
877
966
 
878
-
879
967
  const int tid = tpitg.y; // 0...16
880
968
  const int il = tid/4; // 0...3
881
969
  const int ir = tid%4; // 0...3
@@ -885,35 +973,54 @@ kernel void kernel_mul_mat_q2_k_f32(
885
973
  const int n = 8;
886
974
  const int is = 4*il + (n*ir)/16;
887
975
 
976
+ const int y_offset = 64*il + n*ir;
977
+ const int q_offset = 32*ip + n*ir;
978
+
888
979
  sum[ith] = 0.0f;
889
980
 
890
981
  float sumf = 0;
891
982
  for (int i = tpitg.x; i < nb; i += tptg.x) {
892
983
 
893
- device const uint8_t * q = x[i].qs + 32*ip + n*ir;
984
+ device const uint8_t * q = x[i].qs + q_offset;
894
985
  device const uint8_t * scales = x[i].scales + is;
895
986
 
896
987
  uint8_t d1 = scales[0] & 0xF;
897
- uint8_t m1 = scales[0] >> 4;
898
988
  uint8_t d2 = scales[2] & 0xF;
989
+ uint8_t m1 = scales[0] >> 4;
899
990
  uint8_t m2 = scales[2] >> 4;
900
991
 
901
- device const float * y = yy + i*QK_K + 64*il + n*ir;
902
-
903
- const float dall = (float)x[i].d;
904
- const float dmin = (float)x[i].dmin;
992
+ device const float * y = yy + i*QK_K + y_offset;
905
993
 
906
- float4 s = {0.f, 0.f, 0.f, 0.f};
994
+ //float4 s = {0.f, 0.f, 0.f, 0.f};
995
+ float2 s = {0.f, 0.f};
996
+ float smin = 0;
907
997
  for (int l = 0; l < n; ++l) {
908
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
909
- s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
998
+ s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
999
+ s[1] += y[l+32] * ((q[l] >> shift2) & 3);
1000
+ smin += y[l+ 0] * m1 + y[l+32] * m2;
910
1001
  }
911
- sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
912
1002
 
1003
+ const float dall = (float)x[i].d;
1004
+ const float dmin = (float)x[i].dmin;
1005
+
1006
+ sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
913
1007
 
914
1008
  }
915
1009
  sum[ith] = sumf;
916
1010
 
1011
+ //int mask1 = (ith%4 == 0);
1012
+ //int mask2 = (ith%16 == 0);
1013
+
1014
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
1015
+ //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
1016
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
1017
+ //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
1018
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
1019
+ //if (ith == 0) {
1020
+ // for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1021
+ // dst[r1*ne0 + r0] = sum[0];
1022
+ //}
1023
+
917
1024
  //
918
1025
  // Accumulate the sum from all threads in the threadgroup
919
1026
  // This version is slightly faster than the commented out one below,
@@ -932,19 +1039,109 @@ kernel void kernel_mul_mat_q2_k_f32(
932
1039
  for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
933
1040
  dst[r1*ne0 + r0] = sum[0];
934
1041
  }
1042
+ }
935
1043
 
936
- //// accumulate the sum from all threads in the threadgroup
937
- //threadgroup_barrier(mem_flags::mem_threadgroup);
938
- //for (uint i = nth/2; i > 0; i /= 2) {
939
- // if (ith < i) {
940
- // sum[ith] += sum[ith + i];
941
- // }
942
- // threadgroup_barrier(mem_flags::mem_threadgroup);
943
- //}
1044
// Dot product of one Q3_K row of src0 with one float row of src1.
//   threadgroup (r0, r1) = tgpig selects row r0 of src0 and row r1 of src1 and
//   writes dst[r1*ne0 + r0].
//   ne00 : row length of src0 in values (nb = ne00/QK_K super-blocks per row)
//   ne10 : row length of src1 in floats
//   ne1  : not used by this kernel
//   sum  : threadgroup scratch, one float per thread of the threadgroup
// Threads with different tpitg.x take different super-blocks; tpitg.y picks the
// slice inside a super-block.
// NOTE(review): the index math and the staged reduction assume tptg.y == 16
// (so the thread count is a multiple of 16) - confirm against the host dispatch.
kernel void kernel_mul_mat_q3_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    const uint16_t kmask1 = 0x0303;
    const uint16_t kmask2 = 0x0f0f;

    const uint8_t m3 = 3;
    const int8_t  m4 = 4;

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
    device const float     * yy = (device const float      *) src1 + r1*ne10;

    const int nth = tptg.x*tptg.y;
    const int ith = tptg.y*tpitg.x + tpitg.y;

    // Split of the y-index into (128-value half, 2-bit plane, 8-value run).
    const int tid = tpitg.y;        // expecting 16 threads along y
    const int ip  = tid/8;          // 0 or 1
    const int il  = tid/2 - 4*ip;   // 0...3
    const int ir  = tid%2;
    const int n   = 8;              // values handled per inner loop
    const int l0  = n*ir;

    const uint8_t m     = 1 << (4*ip + il);  // bit of hmask holding this slice's high bit
    const int     shift = 2*il;              // 2-bit field of each qs byte used by this slice

    // Where this slice's two 6-bit scales live inside the packed scales array.
    const uint16_t s_shift1 = 4*ip;
    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
    const int      ik       = 4 + (il%2);

    const int q_offset =  32*ip + l0;
    const int y_offset = 128*ip + 32*il + l0;

    // The "- 32" scale offset is factored out of the loop:
    //   sum of d*(scale - 32) == sumf1 - 32*sumf2
    float sumf1 = 0, sumf2 = 0;
    for (int i = tpitg.x; i < nb; i += tptg.x) {

        device const block_q3_k * blk = x + i;

        const float d_all = (float)(blk->d);

        device const uint8_t * q = blk->qs    + q_offset;
        device const uint8_t * h = blk->hmask + l0;
        device const float   * y = yy + i * QK_K + y_offset;

        device const uint16_t * a = (device const uint16_t *)blk->scales;
        const char2 sc = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));

        // first 16-value sub-block: elements l0 .. l0+7
        float acc = 0;
        for (int l = 0; l < n; ++l) {
            const int low2 = (int8_t)((q[l] >> shift) & m3);
            const int bias = (h[l] & m) ? 0 : m4;
            acc += y[l] * (low2 - bias);
        }
        float d = d_all * acc;
        sumf1 += d * sc[0];
        sumf2 += d;

        // second 16-value sub-block: same layout, 16 elements further on
        acc = 0;
        for (int l = 0; l < n; ++l) {
            const int low2 = (int8_t)((q[l+16] >> shift) & m3);
            const int bias = (h[l+16] & m) ? 0 : m4;
            acc += y[l+16] * (low2 - bias);
        }
        d = d_all * acc;
        sumf1 += d * sc[1];
        sumf2 += d;
    }

    sum[ith] = sumf1 - 32.f*sumf2;

    //
    // Accumulate the sum from all threads in the threadgroup:
    // groups of 4, then groups of 16, then thread 0 adds every 16th partial.
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        for (int step = 1; step < 4; ++step) {
            sum[ith] += sum[ith + step];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        for (int step = 4; step < 16; step += 4) {
            sum[ith] += sum[ith + step];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int j = 16; j < nth; j += 16) {
            sum[0] += sum[j];
        }
        dst[r1*ne0 + r0] = sum[0];
    }
}
949
1146
 
950
1147
  kernel void kernel_mul_mat_q4_k_f32(
@@ -952,23 +1149,17 @@ kernel void kernel_mul_mat_q4_k_f32(
952
1149
  device const float * src1,
953
1150
  device float * dst,
954
1151
  constant int64_t & ne00,
955
- constant int64_t & ne01,
956
- constant uint64_t & nb00,
957
- constant uint64_t & nb01,
958
- constant uint64_t & nb02,
959
1152
  constant int64_t & ne10,
960
- constant int64_t & ne11,
961
- constant uint64_t & nb10,
962
- constant uint64_t & nb11,
963
- constant uint64_t & nb12,
964
1153
  constant int64_t & ne0,
965
- constant int64_t & ne1,
966
1154
  threadgroup float * sum [[threadgroup(0)]],
967
1155
  uint2 tgpig[[threadgroup_position_in_grid]],
968
- uint2 tpig[[thread_position_in_grid]], // we don't use this for now
969
1156
  uint2 tpitg[[thread_position_in_threadgroup]],
970
1157
  uint2 tptg[[threads_per_threadgroup]]) {
971
1158
 
1159
+ const uint16_t kmask1 = 0x3f3f;
1160
+ const uint16_t kmask2 = 0x0f0f;
1161
+ const uint16_t kmask3 = 0xc0c0;
1162
+
972
1163
  const int nb = ne00/QK_K;
973
1164
 
974
1165
  const int64_t r0 = tgpig.x;
@@ -977,37 +1168,55 @@ kernel void kernel_mul_mat_q4_k_f32(
977
1168
  device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
978
1169
  device const float * yy = (device const float *) src1 + r1*ne10;
979
1170
 
980
- const uint nth = tptg.x*tptg.y;
981
- const uint ith = tptg.y*tpitg.x + tpitg.y;
1171
+ const int nth = tptg.x*tptg.y;
1172
+ const int ith = tptg.y*tpitg.x + tpitg.y;
982
1173
 
983
1174
  const int tid = tpitg.y; // 0...16
984
1175
  const int il = tid/4; // 0...3
985
- const int ir = tid%4; // 0...3
986
- const int n = 8;
987
- const int is = 2*il;
1176
+ const int ir = tid - 4*il;// 0...3
1177
+ const int n = 4;
1178
+
1179
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1180
+ const int in = il%2;
1181
+
1182
+ const int l0 = n*(2*ir + in);
1183
+ const int q_offset = 32*im + l0;
1184
+ const int y_offset = 64*im + l0;
988
1185
 
989
1186
  sum[ith] = 0.0f;
990
1187
 
1188
+ uchar2 sc1, sc2, sc3, sc4;
1189
+
991
1190
  float sumf = 0;
992
1191
  for (int i = tpitg.x; i < nb; i += tptg.x) {
993
1192
 
994
- device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
995
- device const float * y = yy + i*QK_K + 64*il + n*ir;
996
- device const uint8_t * scales = (x + i)->scales;
1193
+ device const uint8_t * q1 = (x + i)->qs + q_offset;
1194
+ device const uint8_t * q2 = q1 + 64;
1195
+ device const float * y1 = yy + i*QK_K + y_offset;
1196
+ device const float * y2 = y1 + 128;
997
1197
 
998
1198
  const float dall = (float)((x + i)->d);
999
1199
  const float dmin = (float)((x + i)->dmin);
1000
1200
 
1001
- const uchar4 sc = get_scale_min_k4(is, scales);
1201
+ device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1202
+ sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1203
+ sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1204
+ sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1205
+ sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1002
1206
 
1003
1207
  float4 s = {0.f, 0.f, 0.f, 0.f};
1208
+ float smin = 0;
1004
1209
  for (int l = 0; l < n; ++l) {
1005
- s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
1006
- s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
1210
+
1211
+ s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
1212
+ s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
1213
+ smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1214
+
1007
1215
  }
1008
- sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
1216
+ sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1009
1217
 
1010
1218
  }
1219
+
1011
1220
  sum[ith] = sumf;
1012
1221
 
1013
1222
  //
@@ -1043,25 +1252,114 @@ kernel void kernel_mul_mat_q4_k_f32(
1043
1252
  //}
1044
1253
  }
1045
1254
 
1255
// Dot product of one Q5_K row of src0 with one float row of src1.
//   threadgroup (r0, r1) = tgpig selects row r0 of src0 and row r1 of src1 and
//   writes dst[r1*ne0 + r0].
//   ne00 : row length of src0 in values (nb = ne00/QK_K super-blocks per row)
//   ne10 : row length of src1 in floats
//   sum  : threadgroup scratch, one float per thread of the threadgroup
// Threads with different tpitg.x take different super-blocks; tpitg.y picks the
// slice inside a super-block.
// NOTE(review): the index math and the staged reduction assume tptg.y == 16
// (so the thread count is a multiple of 16) - confirm against the host dispatch.
kernel void kernel_mul_mat_q5_k_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        threadgroup float  * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2  tptg[[threads_per_threadgroup]]) {

    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
    device const float     * yy = (device const float      *) src1 + r1*ne10;

    const int nth = tptg.x*tptg.y;
    const int ith = tptg.y*tpitg.x + tpitg.y;

    const int tid = tpitg.y;        // expected 0..15
    const int il  = tid/4;          // 0...3
    const int ir  = tid - 4*il;     // 0...3
    const int n   = 4;              // values handled per accumulator per super-block

    // grp 0 covers the 32-value runs starting at 0, 32, 128, 160;
    // grp 1 covers the runs starting at 64, 96, 192, 224.
    const int grp = il/2;
    const int odd = il%2;

    const int l0       = n*(2*ir + odd);
    const int q_offset = 32*grp + l0;
    const int y_offset = 64*grp + l0;

    // Bits of a qh byte carrying the fifth bit for each of the four runs.
    const uint8_t hm1 = 1u << (2*grp);
    const uint8_t hm2 = hm1 << 1;
    const uint8_t hm3 = hm1 << 4;
    const uint8_t hm4 = hm2 << 4;

    float sumf = 0;
    for (int i = tpitg.x; i < nb; i += tptg.x) {

        device const block_q5_k * blk = x + i;

        device const uint8_t * q1 = blk->qs + q_offset;
        device const uint8_t * q2 = q1 + 64;
        device const uint8_t * qh = blk->qh + l0;
        device const float   * y1 = yy + i*QK_K + y_offset;
        device const float   * y2 = y1 + 128;

        const float dall = (float)(blk->d);
        const float dmin = (float)(blk->dmin);

        // Unpack the 6-bit values two at a time: sc1 / sc3 multiply the quants,
        // sc2 / sc4 multiply the activations in the dmin term.
        device const uint16_t * a = (device const uint16_t *)blk->scales;
        const uchar2 sc1 = as_type<uchar2>((uint16_t)(a[grp+0] & kmask1));
        const uchar2 sc2 = as_type<uchar2>((uint16_t)(a[grp+2] & kmask1));
        const uchar2 sc3 = as_type<uchar2>((uint16_t)(((a[grp+4] >> 0) & kmask2) | ((a[grp+0] & kmask3) >> 2)));
        const uchar2 sc4 = as_type<uchar2>((uint16_t)(((a[grp+4] >> 4) & kmask2) | ((a[grp+2] & kmask3) >> 2)));

        float s0 = 0.f, s1 = 0.f, s2 = 0.f, s3 = 0.f;
        float smin = 0;
        for (int l = 0; l < n; ++l) {

            s0 += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
            s1 += y1[l+32] * ((q1[l] >>  4) + (qh[l] & hm2 ? 16 : 0));
            s2 += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
            s3 += y2[l+32] * ((q2[l] >>  4) + (qh[l] & hm4 ? 16 : 0));
            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];

        }
        sumf += dall * (s0 * sc1[0] + s1 * sc1[1] + s2 * sc3[0] + s3 * sc3[1]) - dmin * smin;

    }
    sum[ith] = sumf;

    //
    // Accumulate the sum from all threads in the threadgroup:
    // groups of 4, then groups of 16, then thread 0 adds every 16th partial.
    //
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%4 == 0) {
        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith%16 == 0) {
        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
        for (int j = 16; j < nth; j += 16) {
            sum[0] += sum[j];
        }
        dst[r1*ne0 + r0] = sum[0];
    }

}
1353
+
1046
1354
  kernel void kernel_mul_mat_q6_k_f32(
1047
1355
  device const void * src0,
1048
1356
  device const float * src1,
1049
1357
  device float * dst,
1050
1358
  constant int64_t & ne00,
1051
- constant int64_t & ne01,
1052
- constant uint64_t & nb00,
1053
- constant uint64_t & nb01,
1054
- constant uint64_t & nb02,
1055
1359
  constant int64_t & ne10,
1056
- constant int64_t & ne11,
1057
- constant uint64_t & nb10,
1058
- constant uint64_t & nb11,
1059
- constant uint64_t & nb12,
1060
1360
  constant int64_t & ne0,
1061
- constant int64_t & ne1,
1062
1361
  threadgroup float * sum [[threadgroup(0)]],
1063
1362
  uint2 tgpig[[threadgroup_position_in_grid]],
1064
- uint2 tpig[[thread_position_in_grid]], // we don't use this for now
1065
1363
  uint2 tpitg[[thread_position_in_threadgroup]],
1066
1364
  uint2 tptg[[threads_per_threadgroup]]) {
1067
1365
 
@@ -1078,24 +1376,29 @@ kernel void kernel_mul_mat_q6_k_f32(
1078
1376
  device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
1079
1377
  device const float * yy = (device const float *) src1 + r1*ne10;
1080
1378
 
1081
- const uint nth = tptg.x*tptg.y;
1082
- const uint ith = tptg.y*tpitg.x + tpitg.y;
1379
+ const int nth = tptg.x*tptg.y;
1380
+ const int ith = tptg.y*tpitg.x + tpitg.y;
1083
1381
 
1084
- const int step = QK_K / tptg.y; // we expect this to be 16
1085
- const int iqs = step * tpitg.y; // 0...240 in steps of 16
1382
+ // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1383
+ const int iqs = 16 * tpitg.y;
1086
1384
  const int ip = iqs / 128; // 0 or 1
1087
1385
  const int il = (iqs - 128*ip)/16; // 0...7
1088
1386
  const int n = 4;
1089
- const int is = 8*ip + (n*il)/16;
1387
+ const int l0 = n*il;
1388
+ const int is = 8*ip + l0/16;
1389
+
1390
+ const int y_offset = 128*ip + l0;
1391
+ const int q_offset_l = 64*ip + l0;
1392
+ const int q_offset_h = 32*ip + l0;
1090
1393
 
1091
1394
  float sumf = 0;
1092
1395
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1093
1396
 
1094
- device const uint8_t * ql = x[i].ql + 64*ip + n*il;
1095
- device const uint8_t * qh = x[i].qh + 32*ip + n*il;
1397
+ device const uint8_t * ql = x[i].ql + q_offset_l;
1398
+ device const uint8_t * qh = x[i].qh + q_offset_h;
1096
1399
  device const int8_t * sc = x[i].scales + is;
1097
1400
 
1098
- device const float * y = yy + i * QK_K + 128*ip + n*il;
1401
+ device const float * y = yy + i * QK_K + y_offset;
1099
1402
 
1100
1403
  const float dall = x[i].d;
1101
1404