llama_cpp 0.2.2 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
428
428
  }
429
429
  threadgroup_barrier(mem_flags::mem_threadgroup);
430
430
  if (ith == 0) {
431
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
431
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
432
432
  dst[r1*ne0 + r0] = sum[0];
433
433
  }
434
434
  }
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
497
497
  }
498
498
  threadgroup_barrier(mem_flags::mem_threadgroup);
499
499
  if (ith == 0) {
500
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
500
+ for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
501
501
  dst[r1*ne0 + r0] = sum[0];
502
502
  }
503
503
  }
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
775
775
 
776
776
  //============================================ k-quants ======================================================
777
777
 
778
+ #ifndef QK_K
778
779
  #define QK_K 256
780
+ #else
781
+ static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
782
+ #endif
783
+
784
+ #if QK_K == 256
785
+ #define K_SCALE_SIZE 12
786
+ #else
787
+ #define K_SCALE_SIZE 4
788
+ #endif
779
789
 
780
790
  typedef struct {
781
791
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
782
792
  uint8_t qs[QK_K/4]; // quants
783
793
  half d; // super-block scale for quantized scales
784
794
  half dmin; // super-block scale for quantized mins
785
- } block_q2_k;
795
+ } block_q2_K;
786
796
  // 84 bytes / block
787
797
 
788
798
  typedef struct {
789
799
  uint8_t hmask[QK_K/8]; // quants - high bit
790
800
  uint8_t qs[QK_K/4]; // quants - low 2 bits
791
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
792
- half d; // super-block scale
793
- } block_q3_k;
794
- // 110 bytes / block
795
-
801
+ #if QK_K == 64
802
+ uint8_t scales[2];
803
+ #else
804
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
805
+ #endif
806
+ half d; // super-block scale
807
+ } block_q3_K;
808
+
809
+ #if QK_K == 64
810
+ typedef struct {
811
+ half d[2]; // super-block scales/mins
812
+ uint8_t scales[2];
813
+ uint8_t qs[QK_K/2]; // 4-bit quants
814
+ } block_q4_K;
815
+ #else
796
816
  typedef struct {
797
817
  half d; // super-block scale for quantized scales
798
818
  half dmin; // super-block scale for quantized mins
799
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
819
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
800
820
  uint8_t qs[QK_K/2]; // 4--bit quants
801
- } block_q4_k;
802
- // 144 bytes / block
821
+ } block_q4_K;
822
+ #endif
803
823
 
824
+ #if QK_K == 64
825
+ typedef struct {
826
+ half d; // super-block scales/mins
827
+ int8_t scales[QK_K/16]; // 8-bit block scales
828
+ uint8_t qh[QK_K/8]; // quants, high bit
829
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
830
+ } block_q5_K;
831
+ #else
804
832
  typedef struct {
805
833
  half d; // super-block scale for quantized scales
806
834
  half dmin; // super-block scale for quantized mins
807
835
  uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
808
836
  uint8_t qh[QK_K/8]; // quants, high bit
809
837
  uint8_t qs[QK_K/2]; // quants, low 4 bits
810
- } block_q5_k;
838
+ } block_q5_K;
811
839
  // 176 bytes / block
840
+ #endif
812
841
 
813
842
  typedef struct {
814
843
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
815
844
  uint8_t qh[QK_K/4]; // quants, upper 2 bits
816
845
  int8_t scales[QK_K/16]; // scales, quantized with 8 bits
817
846
  half d; // super-block scale
818
- } block_q6_k;
847
+ } block_q6_K;
819
848
  // 210 bytes / block
820
849
 
821
850
  static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
836
865
 
837
866
  //========================================== dequantization =============================
838
867
 
839
- static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
868
+ static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
840
869
  assert(k % QK_K == 0);
841
870
  const int nb = k / QK_K;
842
871
 
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
847
876
 
848
877
  device const uint8_t * q = x[i].qs;
849
878
 
879
+ #if QK_K == 256
850
880
  int is = 0;
851
881
  float dl, ml;
852
882
  for (int n = 0; n < QK_K; n += 128) {
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
865
895
  }
866
896
  q += 32;
867
897
  }
898
+ #else
899
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
900
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
901
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
902
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
903
+ for (int l = 0; l < 16; ++l) {
904
+ y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
905
+ y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
906
+ y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
907
+ y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
908
+ }
909
+ y += QK_K;
910
+ #endif
868
911
 
869
912
  }
870
913
  }
871
914
 
872
- static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
915
+ static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
873
916
  assert(k % QK_K == 0);
874
917
  const int nb = k / QK_K;
875
918
 
919
+ #if QK_K == 256
920
+
876
921
  const uint16_t kmask1 = 0x0303;
877
922
  const uint16_t kmask2 = 0x0f0f;
878
923
 
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
918
963
  }
919
964
  q += 32;
920
965
  }
966
+ }
967
+ #else
968
+ for (int i = 0; i < nb; i++) {
921
969
 
970
+ const float d_all = (float)(x[i].d);
971
+
972
+ device const uint8_t * q = x[i].qs;
973
+ device const uint8_t * hm = x[i].hmask;
974
+
975
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
976
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
977
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
978
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
979
+
980
+ for (int l = 0; l < 8; ++l) {
981
+ uint8_t h = hm[l];
982
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
983
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
984
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
985
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
986
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
987
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
988
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
989
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
990
+ }
991
+ y += QK_K;
922
992
  }
993
+ #endif
923
994
 
924
995
  }
925
996
 
926
- static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
997
+ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
927
998
  assert(k % QK_K == 0);
928
999
  const int nb = k / QK_K;
929
1000
 
930
-
931
1001
  for (int i = 0; i < nb; i++) {
932
1002
 
1003
+ device const uint8_t * q = x[i].qs;
1004
+
1005
+ #if QK_K == 256
933
1006
  const float d = x[i].d;
934
1007
  const float min = x[i].dmin;
935
1008
 
936
- device const uint8_t * q = x[i].qs;
937
1009
  device const uint8_t * scales = x[i].scales;
938
1010
 
939
1011
  int is = 0;
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
945
1017
  for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
946
1018
  q += 32; is += 2;
947
1019
  }
1020
+ #else
1021
+ device const uint8_t * s = x[i].scales;
1022
+ device const half2 * dh = (device const half2 *)x[i].d;
1023
+ const float2 d = (float2)dh[0];
1024
+ const float d1 = d[0] * (s[0] & 0xF);
1025
+ const float d2 = d[0] * (s[1] & 0xF);
1026
+ const float m1 = d[1] * (s[0] >> 4);
1027
+ const float m2 = d[1] * (s[1] >> 4);
1028
+ for (int l = 0; l < 32; ++l) {
1029
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1030
+ y[l+32] = d2 * (q[l] >> 4) - m2;
1031
+ }
1032
+ y += QK_K;
1033
+ #endif
948
1034
 
949
1035
  }
950
1036
  }
951
1037
 
952
- static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
1038
+ static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
953
1039
  assert(k % QK_K == 0);
954
1040
  const int nb = k / QK_K;
955
1041
 
1042
+ #if QK_K == 256
956
1043
  for (int i = 0; i < nb; i++) {
957
1044
 
958
1045
  const float d = (float)(x[i].d);
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
973
1060
  u1 <<= 2; u2 <<= 2;
974
1061
  }
975
1062
  }
1063
+ #else
1064
+ for (int i = 0; i < nb; i++) {
1065
+
1066
+ const float d = (float)x[i].d;
1067
+
1068
+ device const uint8_t * ql = x[i].qs;
1069
+ device const uint8_t * qh = x[i].qh;
1070
+ device const int8_t * sc = x[i].scales;
1071
+
1072
+ for (int l = 0; l < 8; ++l) {
1073
+ y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1074
+ y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1075
+ y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1076
+ y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1077
+ y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1078
+ y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1079
+ y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1080
+ y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1081
+ }
1082
+ y += QK_K;
1083
+ }
1084
+ #endif
976
1085
 
977
1086
  }
978
1087
 
979
- static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
1088
+ static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
980
1089
  assert(k % QK_K == 0);
981
1090
  const int nb = k / QK_K;
982
1091
 
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
988
1097
 
989
1098
  const float d = x[i].d;
990
1099
 
1100
+ #if QK_K == 256
991
1101
  for (int n = 0; n < QK_K; n += 128) {
992
1102
  for (int l = 0; l < 32; ++l) {
993
1103
  int is = l/16;
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
1005
1115
  qh += 32;
1006
1116
  sc += 8;
1007
1117
  }
1118
+ #else
1119
+ for (int l = 0; l < 16; ++l) {
1120
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1121
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1122
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1123
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1124
+ y[l+ 0] = d * sc[0] * q1;
1125
+ y[l+16] = d * sc[1] * q2;
1126
+ y[l+32] = d * sc[2] * q3;
1127
+ y[l+48] = d * sc[3] * q4;
1128
+ }
1129
+ y += 64;
1130
+ #endif
1008
1131
  }
1009
1132
  }
1010
1133
 
1011
- kernel void kernel_get_rows_q2_k(
1134
+ kernel void kernel_get_rows_q2_K(
1012
1135
  device const void * src0,
1013
1136
  device const int * src1,
1014
1137
  device float * dst,
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
1019
1142
  const int i = tpig;
1020
1143
  const int r = ((device int32_t *) src1)[i];
1021
1144
 
1022
- dequantize_row_q2_k(
1023
- (device const block_q2_k *) ((device char *) src0 + r*nb01),
1145
+ dequantize_row_q2_K(
1146
+ (device const block_q2_K *) ((device char *) src0 + r*nb01),
1024
1147
  (device float *) ((device char *) dst + i*nb1), ne00);
1025
1148
  }
1026
1149
 
1027
- kernel void kernel_get_rows_q3_k(
1150
+ kernel void kernel_get_rows_q3_K(
1028
1151
  device const void * src0,
1029
1152
  device const int * src1,
1030
1153
  device float * dst,
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
1035
1158
  const int i = tpig;
1036
1159
  const int r = ((device int32_t *) src1)[i];
1037
1160
 
1038
- dequantize_row_q3_k(
1039
- (device const block_q3_k *) ((device char *) src0 + r*nb01),
1161
+ dequantize_row_q3_K(
1162
+ (device const block_q3_K *) ((device char *) src0 + r*nb01),
1040
1163
  (device float *) ((device char *) dst + i*nb1), ne00);
1041
1164
  }
1042
1165
 
1043
- kernel void kernel_get_rows_q4_k(
1166
+ kernel void kernel_get_rows_q4_K(
1044
1167
  device const void * src0,
1045
1168
  device const int * src1,
1046
1169
  device float * dst,
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
1051
1174
  const int i = tpig;
1052
1175
  const int r = ((device int32_t *) src1)[i];
1053
1176
 
1054
- dequantize_row_q4_k(
1055
- (device const block_q4_k *) ((device char *) src0 + r*nb01),
1177
+ dequantize_row_q4_K(
1178
+ (device const block_q4_K *) ((device char *) src0 + r*nb01),
1056
1179
  (device float *) ((device char *) dst + i*nb1), ne00);
1057
1180
  }
1058
1181
 
1059
- kernel void kernel_get_rows_q5_k(
1182
+ kernel void kernel_get_rows_q5_K(
1060
1183
  device const void * src0,
1061
1184
  device const int * src1,
1062
1185
  device float * dst,
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
1067
1190
  const int i = tpig;
1068
1191
  const int r = ((device int32_t *) src1)[i];
1069
1192
 
1070
- dequantize_row_q5_k(
1071
- (device const block_q5_k *) ((device char *) src0 + r*nb01),
1193
+ dequantize_row_q5_K(
1194
+ (device const block_q5_K *) ((device char *) src0 + r*nb01),
1072
1195
  (device float *) ((device char *) dst + i*nb1), ne00);
1073
1196
  }
1074
1197
 
1075
- kernel void kernel_get_rows_q6_k(
1198
+ kernel void kernel_get_rows_q6_K(
1076
1199
  device const void * src0,
1077
1200
  device const int * src1,
1078
1201
  device float * dst,
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
1083
1206
  const int i = tpig;
1084
1207
  const int r = ((device int32_t *) src1)[i];
1085
1208
 
1086
- dequantize_row_q6_k(
1087
- (device const block_q6_k *) ((device char *) src0 + r*nb01),
1209
+ dequantize_row_q6_K(
1210
+ (device const block_q6_K *) ((device char *) src0 + r*nb01),
1088
1211
  (device float *) ((device char *) dst + i*nb1), ne00);
1089
1212
  }
1090
1213
 
1091
1214
  //====================================== dot products =========================
1092
1215
 
1093
- kernel void kernel_mul_mat_q2_k_f32(
1216
+ kernel void kernel_mul_mat_q2_K_f32(
1094
1217
  device const void * src0,
1095
1218
  device const float * src1,
1096
1219
  device float * dst,
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
1107
1230
  const int64_t r0 = tgpig.x;
1108
1231
  const int64_t r1 = tgpig.y;
1109
1232
 
1110
- device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
1233
+ device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1111
1234
  device const float * yy = (device const float *) src1 + r1*ne10;
1112
1235
 
1113
1236
  const int nth = tptg.x*tptg.y;
1114
1237
  const int ith = tptg.y*tpitg.x + tpitg.y;
1115
1238
 
1239
+ float sumf = 0;
1240
+
1241
+ #if QK_K == 256
1116
1242
  const int tid = tpitg.y; // 0...16
1117
1243
  const int il = tid/4; // 0...3
1118
1244
  const int ir = tid%4; // 0...3
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
1125
1251
  const int y_offset = 64*il + n*ir;
1126
1252
  const int q_offset = 32*ip + n*ir;
1127
1253
 
1128
- sum[ith] = 0.0f;
1129
-
1130
- float sumf = 0;
1131
1254
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1132
1255
 
1133
1256
  device const uint8_t * q = x[i].qs + q_offset;
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
1140
1263
 
1141
1264
  device const float * y = yy + i*QK_K + y_offset;
1142
1265
 
1143
- //float4 s = {0.f, 0.f, 0.f, 0.f};
1144
1266
  float2 s = {0.f, 0.f};
1145
1267
  float smin = 0;
1146
1268
  for (int l = 0; l < n; ++l) {
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
1155
1277
  sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1156
1278
 
1157
1279
  }
1158
- sum[ith] = sumf;
1280
+ #else
1281
+ const int il = 4 * tpitg.x;
1159
1282
 
1160
- //int mask1 = (ith%4 == 0);
1161
- //int mask2 = (ith%16 == 0);
1283
+ uint32_t aux[2];
1284
+ thread const uint8_t * d = (thread const uint8_t *)aux;
1285
+ thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1162
1286
 
1163
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1164
- //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
1165
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1166
- //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
1167
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1168
- //if (ith == 0) {
1169
- // for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1170
- // dst[r1*ne0 + r0] = sum[0];
1171
- //}
1287
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1288
+
1289
+ device const uint8_t * q = x[i].qs + il;
1290
+ device const float * y = yy + i*QK_K + il;
1291
+
1292
+ const float dall = (float)x[i].d;
1293
+ const float dmin = (float)x[i].dmin;
1294
+
1295
+ device const uint32_t * a = (device const uint32_t *)x[i].scales;
1296
+ aux[0] = a[0] & 0x0f0f0f0f;
1297
+ aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1298
+
1299
+ for (int l = 0; l < 4; ++l) {
1300
+ sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1301
+ + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1302
+ + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1303
+ + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1304
+ }
1305
+ }
1306
+ #endif
1307
+
1308
+ sum[ith] = sumf;
1172
1309
 
1173
1310
  //
1174
1311
  // Accumulate the sum from all threads in the threadgroup
1175
- // This version is slightly faster than the commented out one below,
1176
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1177
1312
  //
1178
1313
  threadgroup_barrier(mem_flags::mem_threadgroup);
1179
1314
  if (ith%4 == 0) {
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
1190
1325
  }
1191
1326
  }
1192
1327
 
1193
- kernel void kernel_mul_mat_q3_k_f32(
1328
+ kernel void kernel_mul_mat_q3_K_f32(
1194
1329
  device const void * src0,
1195
1330
  device const float * src1,
1196
1331
  device float * dst,
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
1203
1338
  uint2 tpitg[[thread_position_in_threadgroup]],
1204
1339
  uint2 tptg[[threads_per_threadgroup]]) {
1205
1340
 
1206
- const uint16_t kmask1 = 0x0303;
1207
- const uint16_t kmask2 = 0x0f0f;
1208
-
1209
- const uint8_t m3 = 3;
1210
- const int8_t m4 = 4;
1211
-
1212
1341
  const int nb = ne00/QK_K;
1213
1342
 
1214
1343
  const int64_t r0 = tgpig.x;
1215
1344
  const int64_t r1 = tgpig.y;
1216
1345
 
1217
- device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
1346
+ device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1218
1347
  device const float * yy = (device const float *) src1 + r1*ne10;
1219
1348
 
1220
1349
  const int nth = tptg.x*tptg.y;
1221
1350
  const int ith = tptg.y*tpitg.x + tpitg.y;
1222
1351
 
1352
+ #if QK_K == 256
1353
+
1354
+ const uint8_t m3 = 3;
1355
+ const int8_t m4 = 4;
1356
+
1357
+ const uint16_t kmask1 = 0x0303;
1358
+ const uint16_t kmask2 = 0x0f0f;
1359
+
1223
1360
  const int tid = tpitg.y; // expecting 16
1224
1361
  const int ip = tid/8; // 0 or 1
1225
1362
  const int il = tid/2 - 4*ip; // 0...3
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
1273
1410
 
1274
1411
  //sum[ith] = sumf;
1275
1412
  sum[ith] = sumf1 - 32.f*sumf2;
1413
+ #else
1414
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1415
+ const int im = il/8; // 0, 0, 1, 1
1416
+ const int in = il%8; // 0, 4, 0, 4
1417
+
1418
+ float sumf = 0;
1419
+
1420
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1421
+
1422
+ const float d_all = (float)(x[i].d);
1423
+
1424
+ device const uint8_t * q = x[i].qs + il;
1425
+ device const uint8_t * h = x[i].hmask + in;
1426
+ device const float * y = yy + i * QK_K + il;
1427
+
1428
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1429
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1430
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1431
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1432
+
1433
+ for (int l = 0; l < 4; ++l) {
1434
+ const uint8_t hm = h[l] >> im;
1435
+ sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1436
+ + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1437
+ + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1438
+ + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1439
+ }
1440
+
1441
+ }
1442
+
1443
+ sum[ith] = sumf;
1444
+
1445
+ #endif
1276
1446
 
1277
1447
  //
1278
1448
  // Accumulate the sum from all threads in the threadgroup
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
1293
1463
 
1294
1464
  }
1295
1465
 
1296
- kernel void kernel_mul_mat_q4_k_f32(
1466
+ kernel void kernel_mul_mat_q4_K_f32(
1297
1467
  device const void * src0,
1298
1468
  device const float * src1,
1299
1469
  device float * dst,
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
1305
1475
  uint2 tpitg[[thread_position_in_threadgroup]],
1306
1476
  uint2 tptg[[threads_per_threadgroup]]) {
1307
1477
 
1308
- const uint16_t kmask1 = 0x3f3f;
1309
- const uint16_t kmask2 = 0x0f0f;
1310
- const uint16_t kmask3 = 0xc0c0;
1311
-
1312
1478
  const int nb = ne00/QK_K;
1313
1479
 
1314
1480
  const int64_t r0 = tgpig.x;
1315
1481
  const int64_t r1 = tgpig.y;
1316
1482
 
1317
- device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
1318
- device const float * yy = (device const float *) src1 + r1*ne10;
1319
-
1320
1483
  const int nth = tptg.x*tptg.y;
1321
1484
  const int ith = tptg.y*tpitg.x + tpitg.y;
1322
1485
 
1486
+ device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1487
+ device const float * yy = (device const float *) src1 + r1*ne10;
1488
+
1489
+ float sumf = 0;
1490
+
1491
+ #if QK_K == 256
1492
+
1493
+ const uint16_t kmask1 = 0x3f3f;
1494
+ const uint16_t kmask2 = 0x0f0f;
1495
+ const uint16_t kmask3 = 0xc0c0;
1496
+
1323
1497
  const int tid = tpitg.y; // 0...16
1324
1498
  const int il = tid/4; // 0...3
1325
1499
  const int ir = tid - 4*il;// 0...3
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
1332
1506
  const int q_offset = 32*im + l0;
1333
1507
  const int y_offset = 64*im + l0;
1334
1508
 
1335
- sum[ith] = 0.0f;
1336
-
1337
1509
  uchar2 sc1, sc2, sc3, sc4;
1338
1510
 
1339
- float sumf = 0;
1340
1511
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1341
1512
 
1342
1513
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
1365
1536
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1366
1537
 
1367
1538
  }
1539
+ #else
1540
+ uint16_t aux16[2];
1541
+ thread const uint8_t * scales = (thread const uint8_t *)aux16;
1542
+
1543
+ const int il = 4*tpitg.x;
1544
+
1545
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1546
+
1547
+ device const uint8_t * q = x[i].qs + il;
1548
+ device const float * y = yy + i * QK_K + il;
1549
+
1550
+ const float d = (float)x[i].d[0];
1551
+ const float m = (float)x[i].d[1];
1552
+
1553
+ device const uint16_t * a = (device const uint16_t *)x[i].scales;
1554
+ aux16[0] = a[0] & 0x0f0f;
1555
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
1556
+
1557
+ for (int l = 0; l < 4; ++l) {
1558
+ sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1559
+ + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1560
+ }
1561
+ }
1562
+ #endif
1368
1563
 
1369
1564
  sum[ith] = sumf;
1370
1565
 
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
1401
1596
  //}
1402
1597
  }
1403
1598
 
1404
- kernel void kernel_mul_mat_q5_k_f32(
1599
+ kernel void kernel_mul_mat_q5_K_f32(
1405
1600
  device const void * src0,
1406
1601
  device const float * src1,
1407
1602
  device float * dst,
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
1413
1608
  uint2 tpitg[[thread_position_in_threadgroup]],
1414
1609
  uint2 tptg[[threads_per_threadgroup]]) {
1415
1610
 
1416
- const uint16_t kmask1 = 0x3f3f;
1417
- const uint16_t kmask2 = 0x0f0f;
1418
- const uint16_t kmask3 = 0xc0c0;
1419
-
1420
1611
  const int nb = ne00/QK_K;
1421
1612
 
1422
1613
  const int64_t r0 = tgpig.x;
1423
1614
  const int64_t r1 = tgpig.y;
1424
1615
 
1425
- device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
1616
+ device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1426
1617
  device const float * yy = (device const float *) src1 + r1*ne10;
1427
1618
 
1428
1619
  const int nth = tptg.x*tptg.y;
1429
1620
  const int ith = tptg.y*tpitg.x + tpitg.y;
1430
1621
 
1622
+ float sumf = 0;
1623
+
1624
+ #if QK_K == 256
1625
+
1626
+ const uint16_t kmask1 = 0x3f3f;
1627
+ const uint16_t kmask2 = 0x0f0f;
1628
+ const uint16_t kmask3 = 0xc0c0;
1629
+
1431
1630
  const int tid = tpitg.y; // 0...16
1432
1631
  const int il = tid/4; // 0...3
1433
1632
  const int ir = tid - 4*il;// 0...3
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
1447
1646
 
1448
1647
  uchar2 sc1, sc2, sc3, sc4;
1449
1648
 
1450
- float sumf = 0;
1451
1649
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1452
1650
 
1453
1651
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
1479
1677
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1480
1678
 
1481
1679
  }
1680
+ #else
1681
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1682
+ const int im = il/8; // 0, 0, 1, 1
1683
+ const int in = il%8; // 0, 4, 0, 4
1684
+
1685
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1686
+
1687
+ const float d = (float)x[i].d;
1688
+ device const uint8_t * q = x[i].qs + il;
1689
+ device const uint8_t * h = x[i].qh + in;
1690
+ device const int8_t * s = x[i].scales;
1691
+ device const float * y = yy + i*QK_K + il;
1692
+
1693
+ for (int l = 0; l < 4; ++l) {
1694
+ const uint8_t hl = h[l] >> im;
1695
+ sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1696
+ + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1697
+ + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1698
+ + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1699
+ }
1700
+ }
1701
+ #endif
1482
1702
  sum[ith] = sumf;
1483
1703
 
1484
1704
  //
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
1500
1720
 
1501
1721
  }
1502
1722
 
1503
- kernel void kernel_mul_mat_q6_k_f32(
1723
+ kernel void kernel_mul_mat_q6_K_f32(
1504
1724
  device const void * src0,
1505
1725
  device const float * src1,
1506
1726
  device float * dst,
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
1522
1742
  const int64_t r0 = tgpig.x;
1523
1743
  const int64_t r1 = tgpig.y;
1524
1744
 
1525
- device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
1745
+ device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1526
1746
  device const float * yy = (device const float *) src1 + r1*ne10;
1527
1747
 
1528
1748
  const int nth = tptg.x*tptg.y;
1529
1749
  const int ith = tptg.y*tpitg.x + tpitg.y;
1530
1750
 
1751
+ float sumf = 0;
1752
+
1753
+ #if QK_K == 256
1531
1754
  // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1532
1755
  const int iqs = 16 * tpitg.y;
1533
1756
  const int ip = iqs / 128; // 0 or 1
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
1540
1763
  const int q_offset_l = 64*ip + l0;
1541
1764
  const int q_offset_h = 32*ip + l0;
1542
1765
 
1543
- float sumf = 0;
1544
1766
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1545
1767
 
1546
1768
  device const uint8_t * ql = x[i].ql + q_offset_l;
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
1562
1784
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1563
1785
 
1564
1786
  }
1787
+ #else
1788
+ const int il = 4*tpitg.x; // 0, 4, 8, 12
1789
+
1790
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1791
+ device const float * y = yy + i * QK_K + il;
1792
+ device const uint8_t * ql = x[i].ql + il;
1793
+ device const uint8_t * qh = x[i].qh + il;
1794
+ device const int8_t * s = x[i].scales;
1795
+
1796
+ const float d = x[i].d;
1797
+
1798
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
1799
+ for (int l = 0; l < 4; ++l) {
1800
+ sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1801
+ sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1802
+ sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
1803
+ sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1804
+ }
1805
+ sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
1806
+ }
1807
+
1808
+ #endif
1565
1809
 
1566
1810
  sum[ith] = sumf;
1567
1811