llama_cpp 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -304,34 +304,22 @@ kernel void kernel_mul_mat_q4_0_f32(
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
- constant int64_t & ne01,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
  constant int64_t & ne10,
- constant int64_t & ne11,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
  constant int64_t & ne0,
- constant int64_t & ne1,
  threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpig[[thread_position_in_grid]],
  uint2 tpitg[[thread_position_in_threadgroup]],
  uint2 tptg[[threads_per_threadgroup]]) {
  const int nb = ne00/QK4_0;
 
- const int8_t m8 = 8;
-
  const int64_t r0 = tgpig.x;
  const int64_t r1 = tgpig.y;
 
  device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
  device const float * y = (device const float *) src1 + r1*ne10;
 
- const uint nth = tptg.x*tptg.y;
- const uint ith = tptg.y*tpitg.x + tpitg.y;
+ const int nth = tptg.x*tptg.y;
+ const int ith = tptg.y*tpitg.x + tpitg.y;
 
  const int ix = tpitg.y/4; // 0 or 1
  const int iy = tpitg.y - 4*ix; // 0...3
@@ -351,47 +339,32 @@ kernel void kernel_mul_mat_q4_0_f32(
 
  for (int j = 0; j < 4; ++j) {
 
- acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
- acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8);
+ acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
+ acc[1] += yl[j] + yl[j+16];
 
  }
 
- sumf += d * (acc[0] + acc[1]);
+ sumf += d * (acc[0] - 8.f*acc[1]);
 
  }
 
  sum[ith] = sumf;
 
  //
  // Accumulate the sum from all threads in the threadgroup
- // This version is slightly faster than the commented out one below,
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
  //
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+ sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+ sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+ for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
  dst[r1*ne0 + r0] = sum[0];
  }
-
- //// accumulate the sum from all threads in the threadgroup
- //threadgroup_barrier(mem_flags::mem_threadgroup);
- //for (uint i = nth/2; i > 0; i /= 2) {
- // if (ith < i) {
- // sum[ith] += sum[ith + i];
- // }
- // threadgroup_barrier(mem_flags::mem_threadgroup);
- //}
-
- //if (ith == 0) {
- // dst[r1*ne0 + r0] = sum[0];
- //}
  }
 
  kernel void kernel_mul_mat_q4_1_f32(
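
Note: the q4_0 hunks above factor the constant -8 weight offset out of the inner loop. Instead of widening each nibble to int8 and subtracting 8 before the multiply, the kernel now accumulates the raw weighted nibbles in acc[0] and the plain sum of the activations in acc[1], using the identity sum_j y_j*(q_j - 8) = sum_j y_j*q_j - 8*sum_j y_j, so the offset is applied once per block in `sumf += d * (acc[0] - 8.f*acc[1])`. A minimal C check of the identity (plain host code, not part of the package; the kernel handles 4 nibbles per thread, while the sketch does a whole block of QK4_0/2 = 16 nibble bytes):

#include <assert.h>
#include <math.h>
#include <stdint.h>

#define NIB 16 /* nibble bytes per q4_0 block (QK4_0/2) */

static float dot_old(float d, const uint8_t *xl, const float *yl) {
    float acc0 = 0, acc1 = 0;
    for (int j = 0; j < NIB; ++j) {
        acc0 += yl[j +   0] * ((int8_t)(xl[j] & 0xF) - 8); /* low nibbles  */
        acc1 += yl[j + NIB] * ((int8_t)(xl[j] >> 4) - 8);  /* high nibbles */
    }
    return d * (acc0 + acc1);
}

static float dot_new(float d, const uint8_t *xl, const float *yl) {
    float acc0 = 0, acc1 = 0;
    for (int j = 0; j < NIB; ++j) {
        acc0 += yl[j] * (xl[j] & 0xF) + yl[j + NIB] * (xl[j] >> 4);
        acc1 += yl[j] + yl[j + NIB]; /* plain activation sum */
    }
    return d * (acc0 - 8.f * acc1); /* offset applied once */
}

int main(void) {
    uint8_t xl[NIB];
    float   yl[2 * NIB];
    for (int j = 0; j < NIB; ++j)     xl[j] = (uint8_t)(17 * j + 5);
    for (int j = 0; j < 2 * NIB; ++j) yl[j] = 0.25f * j - 3.0f;
    assert(fabsf(dot_old(0.5f, xl, yl) - dot_new(0.5f, xl, yl)) < 1e-3f);
    return 0;
}
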
@@ -399,20 +372,10 @@ kernel void kernel_mul_mat_q4_1_f32(
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
- constant int64_t & ne01,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
  constant int64_t & ne10,
- constant int64_t & ne11,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
  constant int64_t & ne0,
- constant int64_t & ne1,
  threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpig[[thread_position_in_grid]],
  uint2 tpitg[[thread_position_in_threadgroup]],
  uint2 tptg[[threads_per_threadgroup]]) {
  const int nb = ne00/QK4_1;
@@ -460,11 +423,11 @@ kernel void kernel_mul_mat_q4_1_f32(
  //
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+ sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+ sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (ith == 0) {
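
Note: the unrolled reduction above is shared by the q4_0 and q4_1 kernels: each thread parks its partial sum in threadgroup memory, threads with ith%4 == 0 fold their group of four, threads with ith%16 == 0 fold four such groups, and thread 0 folds every 16th slot. A serial C model of the three phases (a sketch, not part of the package; the barriers between phases are implicit because the loops run sequentially, and nth is assumed to be a multiple of 16):

#include <assert.h>

/* Three-phase fold of nth partial sums, mirroring the barriers above. */
static float reduce(float *sum, int nth) {
    for (int ith = 0; ith < nth; ith += 4)  /* phase 1: ith%4 == 0  */
        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
    for (int ith = 0; ith < nth; ith += 16) /* phase 2: ith%16 == 0 */
        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
    for (int i = 16; i < nth; i += 16)      /* phase 3: ith == 0    */
        sum[0] += sum[i];
    return sum[0];
}

int main(void) {
    float sum[32];
    for (int i = 0; i < 32; ++i) sum[i] = 1.0f;
    assert(reduce(sum, 32) == 32.0f);
    return 0;
}
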
@@ -671,6 +634,15 @@ typedef struct {
  half d; // super-block scale for quantized scales
  half dmin; // super-block scale for quantized mins
  } block_q2_k;
+ // 84 bytes / block
+
+ typedef struct {
+ uint8_t hmask[QK_K/8]; // quants - high bit
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
+ uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+ half d; // super-block scale
+ } block_q3_k;
+ // 110 bytes / block
 
  typedef struct {
  half d; // super-block scale for quantized scales
@@ -678,6 +650,16 @@ typedef struct {
  uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
  uint8_t qs[QK_K/2]; // 4--bit quants
  } block_q4_k;
+ // 144 bytes / block
+
+ typedef struct {
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+ uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+ } block_q5_k;
+ // 176 bytes / block
 
  typedef struct {
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
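
Note: the "bytes / block" comments on the new super-block structs check out for QK_K = 256, the K-quant super-block size used by ggml: block_q3_k is 256/8 + 256/4 + 12 + 2 = 110 bytes and block_q5_k is 2 + 2 + 12 + 32 + 128 = 176 bytes. A compile-time size check in plain C (a sketch, not part of the package; half is modelled as a 2-byte integer):

#include <stdint.h>

#define QK_K 256
typedef uint16_t half; /* stand-in for the 16-bit float type */

typedef struct {
    uint8_t hmask[QK_K/8];     /* 32 bytes: high bit of each quant   */
    uint8_t qs[QK_K/4];        /* 64 bytes: low 2 bits of each quant */
    uint8_t scales[3*QK_K/64]; /* 12 bytes: 6-bit scales             */
    half d;                    /*  2 bytes: super-block scale        */
} block_q3_k;

typedef struct {
    half d;                    /*   2 bytes: scale for quantized scales */
    half dmin;                 /*   2 bytes: scale for quantized mins   */
    uint8_t scales[3*QK_K/64]; /*  12 bytes: 6-bit scales and mins      */
    uint8_t qh[QK_K/8];        /*  32 bytes: high bit of each quant     */
    uint8_t qs[QK_K/2];        /* 128 bytes: low 4 bits of each quant   */
} block_q5_k;

_Static_assert(sizeof(block_q3_k) == 110, "q3_k is 110 bytes / block");
_Static_assert(sizeof(block_q5_k) == 176, "q5_k is 176 bytes / block");
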
@@ -685,16 +667,19 @@ typedef struct {
  int8_t scales[QK_K/16]; // scales, quantized with 8 bits
  half d; // super-block scale
  } block_q6_k;
+ // 210 bytes / block
 
  static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
  uchar4 r;
  if (j < 4) {
- r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
- r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
+ r[0] = q[j+0] & 63;
+ r[2] = q[j+1] & 63;
+ r[1] = q[j+4] & 63;
+ r[3] = q[j+5] & 63;
  } else {
  r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
  r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+ r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
  r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
  }
  return r;
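
Note: get_scale_min_k4 unpacks the 12-byte K-quant scale area into 8 pairs of 6-bit (scale, min) values: bytes 0-3 hold scales 0-3 in their low 6 bits plus the top 2 bits of scales 4-7, bytes 4-7 do the same for the mins, and bytes 8-11 hold the low nibbles of scales 4-7 and mins 4-7. The hunk only regroups the loads by source byte; the result is unchanged. A round-trip model in plain C (a sketch inferred from the unpacking code; unpack_pair and the packing loop are illustrative, not upstream API):

#include <assert.h>
#include <stdint.h>

static void unpack_pair(int j, const uint8_t *q, uint8_t *sc, uint8_t *m) {
    if (j < 4) {
        *sc = q[j] & 63;
        *m  = q[j + 4] & 63;
    } else { /* low nibble from bytes 8-11, high 2 bits from bytes 0-7 */
        *sc = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m  = (q[j + 4] >> 4)  | ((q[j - 0] >> 6) << 4);
    }
}

int main(void) {
    uint8_t sc[8], m[8], q[12] = {0};
    for (int j = 0; j < 8; ++j) { sc[j] = (7*j + 3) & 63; m[j] = (5*j + 1) & 63; }
    for (int j = 0; j < 4; ++j) { /* pack 8 six-bit pairs into 12 bytes */
        q[j]     = sc[j] | (uint8_t)((sc[j + 4] >> 4) << 6);
        q[j + 4] = m[j]  | (uint8_t)((m[j + 4] >> 4) << 6);
        q[j + 8] = (uint8_t)((sc[j + 4] & 0xF) | ((m[j + 4] & 0xF) << 4));
    }
    for (int j = 0; j < 8; ++j) { /* unpack and verify the round-trip */
        uint8_t s, mm;
        unpack_pair(j, q, &s, &mm);
        assert(s == sc[j] && mm == m[j]);
    }
    return 0;
}
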
@@ -735,10 +720,65 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
  }
  }
 
+ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask2 = 0x0f0f;
+
+ uint16_t aux[8];
+ thread const int8_t * scales = (thread const int8_t*)aux;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d_all = (float)(x[i].d);
+
+ device const uint8_t * q = x[i].qs;
+ device const uint8_t * h = x[i].hmask;
+ uint8_t m = 1;
+
+ device const uint16_t * a = (device const uint16_t *)x[i].scales;
+ aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
+ aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
+ aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
+ aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
+ aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
+ aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
+ aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
+ aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
+
+ int is = 0;
+ float dl;
+ for (int n = 0; n < QK_K; n += 128) {
+ int shift = 0;
+ for (int j = 0; j < 4; ++j) {
+
+ dl = d_all * (scales[is++] - 32);
+ for (int l = 0; l < 16; ++l) {
+ *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
+ }
+
+ dl = d_all * (scales[is++] - 32);
+ for (int l = 0; l < 16; ++l) {
+ *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
+ }
+
+ shift += 2;
+ m <<= 1;
+ }
+ q += 32;
+ }
+
+ }
+
+ }
+
  static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
  assert(k % QK_K == 0);
  const int nb = k / QK_K;
 
+
  for (int i = 0; i < nb; i++) {
 
  const float d = x[i].d;
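
Note: in dequantize_row_q3_k each weight is rebuilt from a 2-bit quant plus one high bit taken from hmask; a cleared high bit subtracts 4, giving an effective signed range of -4..3, and each 16-element group is scaled by a 6-bit scale carrying a -32 offset. The aux/kmask shuffle only rearranges the 12 packed scale bytes into 16 contiguous int8 scales. The per-element formula as a C one-liner (a sketch; q3k_weight is an illustrative name):

#include <stdint.h>

/* value = d_all * (scale - 32) * (2-bit quant, extended by the hmask bit) */
static inline float q3k_weight(float d_all, int8_t scale6, uint8_t q2bits, int hbit) {
    const int q = (int)q2bits - (hbit ? 0 : 4); /* effective range -4..3 */
    return d_all * (float)(scale6 - 32) * (float)q;
}
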
@@ -760,6 +800,33 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
  }
  }
 
+ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = (float)(x[i].d);
+ const float min = (float)(x[i].dmin);
+
+ device const uint8_t * ql = x[i].qs;
+ device const uint8_t * qh = x[i].qh;
+
+ int is = 0;
+ uint8_t u1 = 1, u2 = 2;
+ for (int j = 0; j < QK_K; j += 64) {
+ const uchar4 sc = get_scale_min_k4(is, x[i].scales);
+ const float d1 = d * sc[0]; const float m1 = min * sc[1];
+ const float d2 = d * sc[2]; const float m2 = min * sc[3];
+ for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+ for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+ ql += 32; is += 2;
+ u1 <<= 2; u2 <<= 2;
+ }
+ }
+
+ }
+
  static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
  assert(k % QK_K == 0);
  const int nb = k / QK_K;
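
Note: dequantize_row_q5_k follows the q4_k pattern but restores a fifth bit from qh, widening the quant range from 0..15 to 0..31; u1 and u2 walk the qh bit positions two 32-element sub-blocks at a time. Each weight takes the usual K-quant affine form (a sketch; q5k_weight is an illustrative name):

#include <stdint.h>

/* value = (d * sc) * (low 4 bits + 16 * high bit) - (min * m) */
static inline float q5k_weight(float d, float min, uint8_t sc, uint8_t m,
                               uint8_t q4bits, int hbit) {
    return d * (float)sc * (float)(q4bits + (hbit ? 16 : 0)) - min * (float)m;
}
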
@@ -808,6 +875,22 @@ kernel void kernel_get_rows_q2_k(
  (device float *) ((device char *) dst + i*nb1), ne00);
  }
 
+ kernel void kernel_get_rows_q3_k(
+ device const void * src0,
+ device const int * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb1,
+ uint tpig[[thread_position_in_grid]]) {
+ const int i = tpig;
+ const int r = ((device int32_t *) src1)[i];
+
+ dequantize_row_q3_k(
+ (device const block_q3_k *) ((device char *) src0 + r*nb01),
+ (device float *) ((device char *) dst + i*nb1), ne00);
+ }
+
  kernel void kernel_get_rows_q4_k(
  device const void * src0,
  device const int * src1,
@@ -824,6 +907,22 @@ kernel void kernel_get_rows_q4_k(
  (device float *) ((device char *) dst + i*nb1), ne00);
  }
 
+ kernel void kernel_get_rows_q5_k(
+ device const void * src0,
+ device const int * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb1,
+ uint tpig[[thread_position_in_grid]]) {
+ const int i = tpig;
+ const int r = ((device int32_t *) src1)[i];
+
+ dequantize_row_q5_k(
+ (device const block_q5_k *) ((device char *) src0 + r*nb01),
+ (device float *) ((device char *) dst + i*nb1), ne00);
+ }
+
  kernel void kernel_get_rows_q6_k(
  device const void * src0,
  device const int * src1,
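
Note: the two new get_rows kernels mirror the existing q2_k/q4_k/q6_k ones: thread i reads a row index from src1 and dequantizes one ne00-element row, with nb01 and nb1 the source and destination row strides in bytes. A host-side C model of that dispatch (a sketch; get_rows and dequant_row_fn are illustrative names, not upstream API):

#include <stdint.h>

typedef void (*dequant_row_fn)(const void *x, float *y, int k);

/* i plays the role of tpig: one call per thread in the grid. */
static void get_rows(const void *src0, const int32_t *src1, float *dst,
                     int64_t ne00, uint64_t nb01, uint64_t nb1, int i,
                     dequant_row_fn dequantize_row) {
    const int r = src1[i];
    dequantize_row((const uint8_t *)src0 + r*nb01,
                   (float *)((uint8_t *)dst + i*nb1), (int)ne00);
}
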
@@ -847,20 +946,10 @@ kernel void kernel_mul_mat_q2_k_f32(
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
- constant int64_t & ne01,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
  constant int64_t & ne10,
- constant int64_t & ne11,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
  constant int64_t & ne0,
- constant int64_t & ne1,
  threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpig[[thread_position_in_grid]], // we don't use this for now
  uint2 tpitg[[thread_position_in_threadgroup]],
  uint2 tptg[[threads_per_threadgroup]]) {
 
@@ -875,7 +964,6 @@ kernel void kernel_mul_mat_q2_k_f32(
  const int nth = tptg.x*tptg.y;
  const int ith = tptg.y*tpitg.x + tpitg.y;
 
-
  const int tid = tpitg.y; // 0...16
  const int il = tid/4; // 0...3
  const int ir = tid%4; // 0...3
@@ -885,35 +973,54 @@ kernel void kernel_mul_mat_q2_k_f32(
  const int n = 8;
  const int is = 4*il + (n*ir)/16;
 
+ const int y_offset = 64*il + n*ir;
+ const int q_offset = 32*ip + n*ir;
+
  sum[ith] = 0.0f;
 
  float sumf = 0;
  for (int i = tpitg.x; i < nb; i += tptg.x) {
 
- device const uint8_t * q = x[i].qs + 32*ip + n*ir;
+ device const uint8_t * q = x[i].qs + q_offset;
  device const uint8_t * scales = x[i].scales + is;
 
  uint8_t d1 = scales[0] & 0xF;
- uint8_t m1 = scales[0] >> 4;
  uint8_t d2 = scales[2] & 0xF;
+ uint8_t m1 = scales[0] >> 4;
  uint8_t m2 = scales[2] >> 4;
 
- device const float * y = yy + i*QK_K + 64*il + n*ir;
-
- const float dall = (float)x[i].d;
- const float dmin = (float)x[i].dmin;
+ device const float * y = yy + i*QK_K + y_offset;
 
- float4 s = {0.f, 0.f, 0.f, 0.f};
+ //float4 s = {0.f, 0.f, 0.f, 0.f};
+ float2 s = {0.f, 0.f};
+ float smin = 0;
  for (int l = 0; l < n; ++l) {
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
- s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
+ s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
+ s[1] += y[l+32] * ((q[l] >> shift2) & 3);
+ smin += y[l+ 0] * m1 + y[l+32] * m2;
  }
- sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
 
+ const float dall = (float)x[i].d;
+ const float dmin = (float)x[i].dmin;
+
+ sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
 
  }
  sum[ith] = sumf;
 
+ //int mask1 = (ith%4 == 0);
+ //int mask2 = (ith%16 == 0);
+
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
+ //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
+ //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
+ //if (ith == 0) {
+ // for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+ // dst[r1*ne0 + r0] = sum[0];
+ //}
+
  //
  // Accumulate the sum from all threads in the threadgroup
  // This version is slightly faster than the commented out one below,
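
Note: the q2_k loop above keeps the same math with fewer live accumulators: the old form charged the mins as dmin * (sum(y_lo)*m1 + sum(y_hi)*m2), which equals the new running smin = sum(y_lo*m1 + y_hi*m2) folded into a single term. A small C check (plain host code, not part of the package):

#include <assert.h>
#include <math.h>
#include <stdint.h>

int main(void) {
    const uint8_t m1 = 5, m2 = 11;
    float y_lo[8], y_hi[8];
    for (int l = 0; l < 8; ++l) { y_lo[l] = 0.5f*l - 1; y_hi[l] = 0.25f*l + 2; }

    float s1 = 0, s3 = 0, smin = 0;
    for (int l = 0; l < 8; ++l) {
        s1 += y_lo[l];                   /* old: separate activation sums */
        s3 += y_hi[l];
        smin += y_lo[l]*m1 + y_hi[l]*m2; /* new: mins folded in the loop  */
    }
    assert(fabsf((s1*m1 + s3*m2) - smin) < 1e-3f);
    return 0;
}
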
@@ -932,19 +1039,109 @@ kernel void kernel_mul_mat_q2_k_f32(
  for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
  dst[r1*ne0 + r0] = sum[0];
  }
+ }
 
- //// accumulate the sum from all threads in the threadgroup
- //threadgroup_barrier(mem_flags::mem_threadgroup);
- //for (uint i = nth/2; i > 0; i /= 2) {
- // if (ith < i) {
- // sum[ith] += sum[ith + i];
- // }
- // threadgroup_barrier(mem_flags::mem_threadgroup);
- //}
+ kernel void kernel_mul_mat_q3_k_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ threadgroup float * sum [[threadgroup(0)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint2 tpitg[[thread_position_in_threadgroup]],
+ uint2 tptg[[threads_per_threadgroup]]) {
+
+ const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask2 = 0x0f0f;
+
+ const uint8_t m3 = 3;
+ const int8_t m4 = 4;
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+
+ device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+ device const float * yy = (device const float *) src1 + r1*ne10;
+
+ const int nth = tptg.x*tptg.y;
+ const int ith = tptg.y*tpitg.x + tpitg.y;
+
+ const int tid = tpitg.y; // expecting 16
+ const int ip = tid/8; // 0 or 1
+ const int il = tid/2 - 4*ip; // 0...3
+ const int ir = tid%2;
+ const int n = 8;
+ const int l0 = n*ir;
+
+ const uint8_t m = 1 << (4*ip + il);
+
+ const int shift = 2*il;
+
+ const uint16_t s_shift1 = 4*ip;
+ const uint16_t s_shift2 = s_shift1 + 2*(il/2);
+ const int ik = 4 + (il%2);
+
+ const int q_offset = 32*ip + l0;
+ const int y_offset = 128*ip + 32*il + l0;
+
+ //float sumf = 0;
+ float sumf1 = 0, sumf2 = 0;
+ for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+ const float d_all = (float)(x[i].d);
+
+ device const uint8_t * q = x[i].qs + q_offset;
+ device const uint8_t * h = x[i].hmask + l0;
+ device const float * y = yy + i * QK_K + y_offset;
+
+ device const uint16_t * a = (device const uint16_t *)x[i].scales;
+ const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+
+ float s = 0;
+ for (int l = 0; l < n; ++l) {
+ s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
+ }
+ float d = d_all * s;
+ sumf1 += d * scales[0];
+ sumf2 += d;
+ //sumf += d_all * s * (scales[0] - 32);
+
+ s = 0;
+ for (int l = 0; l < n; ++l) {
+ s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
+ }
+ d = d_all * s;
+ sumf1 += d * scales[1];
+ sumf2 += d;
+ //sumf += d_all * s * (scales[1] - 32);
+
+ }
+
+ //sum[ith] = sumf;
+ sum[ith] = sumf1 - 32.f*sumf2;
+
+ //
+ // Accumulate the sum from all threads in the threadgroup
+ //
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%4 == 0) {
+ for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%16 == 0) {
+ for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith == 0) {
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+ dst[r1*ne0 + r0] = sum[0];
+ }
 
- //if (ith == 0) {
- // dst[r1*ne0 + r0] = sum[0];
- //}
  }
 
  kernel void kernel_mul_mat_q4_k_f32(
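
Note: the new kernel_mul_mat_q3_k_f32 reuses the factoring trick from the q4_0 change at the top of this diff: every 6-bit scale carries a -32 offset, so the kernel tracks sumf1 (scale-weighted) and sumf2 (unweighted) and writes sumf1 - 32.f*sumf2 once, instead of subtracting 32 per group; the commented-out sumf lines show the unfactored form. The as_type<char2> line pulls two adjacent 6-bit scales out of the packed 12-byte area with a single 16-bit load. The identity, in C comment form:

/* With per-group products d_g = d_all * s_g and 6-bit scales c_g:
 *     sum_g d_g * (c_g - 32) == (sum_g d_g * c_g) - 32 * (sum_g d_g)
 * which is exactly what `sum[ith] = sumf1 - 32.f*sumf2;` computes. */
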
@@ -952,23 +1149,17 @@ kernel void kernel_mul_mat_q4_k_f32(
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
- constant int64_t & ne01,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
  constant int64_t & ne10,
- constant int64_t & ne11,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
  constant int64_t & ne0,
- constant int64_t & ne1,
  threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpig[[thread_position_in_grid]], // we don't use this for now
  uint2 tpitg[[thread_position_in_threadgroup]],
  uint2 tptg[[threads_per_threadgroup]]) {
 
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
  const int nb = ne00/QK_K;
 
  const int64_t r0 = tgpig.x;
@@ -977,37 +1168,55 @@ kernel void kernel_mul_mat_q4_k_f32(
  device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
  device const float * yy = (device const float *) src1 + r1*ne10;
 
- const uint nth = tptg.x*tptg.y;
- const uint ith = tptg.y*tpitg.x + tpitg.y;
+ const int nth = tptg.x*tptg.y;
+ const int ith = tptg.y*tpitg.x + tpitg.y;
 
  const int tid = tpitg.y; // 0...16
  const int il = tid/4; // 0...3
- const int ir = tid%4; // 0...3
- const int n = 8;
- const int is = 2*il;
+ const int ir = tid - 4*il;// 0...3
+ const int n = 4;
+
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const int in = il%2;
+
+ const int l0 = n*(2*ir + in);
+ const int q_offset = 32*im + l0;
+ const int y_offset = 64*im + l0;
 
  sum[ith] = 0.0f;
 
+ uchar2 sc1, sc2, sc3, sc4;
+
  float sumf = 0;
  for (int i = tpitg.x; i < nb; i += tptg.x) {
 
- device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
- device const float * y = yy + i*QK_K + 64*il + n*ir;
- device const uint8_t * scales = (x + i)->scales;
+ device const uint8_t * q1 = (x + i)->qs + q_offset;
+ device const uint8_t * q2 = q1 + 64;
+ device const float * y1 = yy + i*QK_K + y_offset;
+ device const float * y2 = y1 + 128;
 
  const float dall = (float)((x + i)->d);
  const float dmin = (float)((x + i)->dmin);
 
- const uchar4 sc = get_scale_min_k4(is, scales);
+ device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+ sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+ sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+ sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+ sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
 
  float4 s = {0.f, 0.f, 0.f, 0.f};
+ float smin = 0;
  for (int l = 0; l < n; ++l) {
- s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
- s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
+
+ s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
+ s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
+ smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
  }
- sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
+ sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
  }
+
  sum[ith] = sumf;
 
  //
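
Note: the reworked q4_k kernel now has each thread cover four 32-element sub-blocks (two at y1 and two at y2, 128 elements later) and replaces the get_scale_min_k4 call with branch-free 16-bit loads: kmask1 = 0x3f3f grabs two 6-bit values at once, while kmask2/kmask3 splice the low nibbles from bytes 8-11 with the high 2 bits from bytes 0-7. A C model that reproduces the sc1..sc4 pairs (a sketch; little-endian byte order is assumed, as on Apple GPUs, and unpack_sc is an illustrative name):

#include <stdint.h>
#include <string.h>

static void unpack_sc(const uint8_t scales[12], int im, uint8_t out[4][2]) {
    uint16_t a[6], w;
    memcpy(a, scales, 12); /* view the 12 scale bytes as 6 words */
    const uint16_t k1 = 0x3f3f, k2 = 0x0f0f, k3 = 0xc0c0;
    w = (uint16_t)(a[im+0] & k1);
    memcpy(out[0], &w, 2); /* sc1: scales 2*im, 2*im+1   */
    w = (uint16_t)(a[im+2] & k1);
    memcpy(out[1], &w, 2); /* sc2: mins   2*im, 2*im+1   */
    w = (uint16_t)(((a[im+4] >> 0) & k2) | ((a[im+0] & k3) >> 2));
    memcpy(out[2], &w, 2); /* sc3: scales 4+2*im, 5+2*im */
    w = (uint16_t)(((a[im+4] >> 4) & k2) | ((a[im+2] & k3) >> 2));
    memcpy(out[3], &w, 2); /* sc4: mins   4+2*im, 5+2*im */
}
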
@@ -1043,25 +1252,114 @@ kernel void kernel_mul_mat_q4_k_f32(
  //}
  }
 
+ kernel void kernel_mul_mat_q5_k_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ threadgroup float * sum [[threadgroup(0)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint2 tpitg[[thread_position_in_threadgroup]],
+ uint2 tptg[[threads_per_threadgroup]]) {
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+
+ device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+ device const float * yy = (device const float *) src1 + r1*ne10;
+
+ const int nth = tptg.x*tptg.y;
+ const int ith = tptg.y*tpitg.x + tpitg.y;
+
+ const int tid = tpitg.y; // 0...16
+ const int il = tid/4; // 0...3
+ const int ir = tid - 4*il;// 0...3
+ const int n = 4;
+
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const int in = il%2;
+
+ const int l0 = n*(2*ir + in);
+ const int q_offset = 32*im + l0;
+ const int y_offset = 64*im + l0;
+
+ const uint8_t hm1 = 1u << (2*im);
+ const uint8_t hm2 = hm1 << 1;
+ const uint8_t hm3 = hm1 << 4;
+ const uint8_t hm4 = hm2 << 4;
+
+ uchar2 sc1, sc2, sc3, sc4;
+
+ float sumf = 0;
+ for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+ device const uint8_t * q1 = (x + i)->qs + q_offset;
+ device const uint8_t * q2 = q1 + 64;
+ device const uint8_t * qh = (x + i)->qh + l0;
+ device const float * y1 = yy + i*QK_K + y_offset;
+ device const float * y2 = y1 + 128;
+
+ const float dall = (float)((x + i)->d);
+ const float dmin = (float)((x + i)->dmin);
+
+ device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
+ sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
+ sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
+ sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
+ sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+
+ float4 s = {0.f, 0.f, 0.f, 0.f};
+ float smin = 0;
+ for (int l = 0; l < n; ++l) {
+
+ s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
+ s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
+ s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
+ s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
+ smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+
+ }
+ sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+ }
+ sum[ith] = sumf;
+
+ //
+ // Accumulate the sum from all threads in the threadgroup
+ //
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%4 == 0) {
+ sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%16 == 0) {
+ sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith == 0) {
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+ dst[r1*ne0 + r0] = sum[0];
+ }
+
+ }
+
  kernel void kernel_mul_mat_q6_k_f32(
  device const void * src0,
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
- constant int64_t & ne01,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
  constant int64_t & ne10,
- constant int64_t & ne11,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
  constant int64_t & ne0,
- constant int64_t & ne1,
  threadgroup float * sum [[threadgroup(0)]],
  uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpig[[thread_position_in_grid]], // we don't use this for now
  uint2 tpitg[[thread_position_in_threadgroup]],
  uint2 tptg[[threads_per_threadgroup]]) {
 
@@ -1078,24 +1376,29 @@ kernel void kernel_mul_mat_q6_k_f32(
  device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
  device const float * yy = (device const float *) src1 + r1*ne10;
 
- const uint nth = tptg.x*tptg.y;
- const uint ith = tptg.y*tpitg.x + tpitg.y;
+ const int nth = tptg.x*tptg.y;
+ const int ith = tptg.y*tpitg.x + tpitg.y;
 
- const int step = QK_K / tptg.y; // we expect this to be 16
- const int iqs = step * tpitg.y; // 0...240 in steps of 16
+ // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
+ const int iqs = 16 * tpitg.y;
  const int ip = iqs / 128; // 0 or 1
  const int il = (iqs - 128*ip)/16; // 0...7
  const int n = 4;
- const int is = 8*ip + (n*il)/16;
+ const int l0 = n*il;
+ const int is = 8*ip + l0/16;
+
+ const int y_offset = 128*ip + l0;
+ const int q_offset_l = 64*ip + l0;
+ const int q_offset_h = 32*ip + l0;
 
  float sumf = 0;
  for (int i = tpitg.x; i < nb; i += tptg.x) {
 
- device const uint8_t * ql = x[i].ql + 64*ip + n*il;
- device const uint8_t * qh = x[i].qh + 32*ip + n*il;
+ device const uint8_t * ql = x[i].ql + q_offset_l;
+ device const uint8_t * qh = x[i].qh + q_offset_h;
  device const int8_t * sc = x[i].scales + is;
 
- device const float * y = yy + i * QK_K + 128*ip + n*il;
+ device const float * y = yy + i * QK_K + y_offset;
 
  const float dall = x[i].d;