llama_cpp 0.2.2 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
428
428
  }
429
429
  threadgroup_barrier(mem_flags::mem_threadgroup);
430
430
  if (ith == 0) {
431
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
431
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
432
432
  dst[r1*ne0 + r0] = sum[0];
433
433
  }
434
434
  }
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
497
497
  }
498
498
  threadgroup_barrier(mem_flags::mem_threadgroup);
499
499
  if (ith == 0) {
500
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
500
+ for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
501
501
  dst[r1*ne0 + r0] = sum[0];
502
502
  }
503
503
  }
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
775
775
 
776
776
  //============================================ k-quants ======================================================
777
777
 
778
+ #ifndef QK_K
778
779
  #define QK_K 256
780
+ #else
781
+ static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
782
+ #endif
783
+
784
+ #if QK_K == 256
785
+ #define K_SCALE_SIZE 12
786
+ #else
787
+ #define K_SCALE_SIZE 4
788
+ #endif
779
789
 
780
790
  typedef struct {
781
791
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
782
792
  uint8_t qs[QK_K/4]; // quants
783
793
  half d; // super-block scale for quantized scales
784
794
  half dmin; // super-block scale for quantized mins
785
- } block_q2_k;
795
+ } block_q2_K;
786
796
  // 84 bytes / block
787
797
 
788
798
  typedef struct {
789
799
  uint8_t hmask[QK_K/8]; // quants - high bit
790
800
  uint8_t qs[QK_K/4]; // quants - low 2 bits
791
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
792
- half d; // super-block scale
793
- } block_q3_k;
794
- // 110 bytes / block
795
-
801
+ #if QK_K == 64
802
+ uint8_t scales[2];
803
+ #else
804
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
805
+ #endif
806
+ half d; // super-block scale
807
+ } block_q3_K;
808
+
809
+ #if QK_K == 64
810
+ typedef struct {
811
+ half d[2]; // super-block scales/mins
812
+ uint8_t scales[2];
813
+ uint8_t qs[QK_K/2]; // 4-bit quants
814
+ } block_q4_K;
815
+ #else
796
816
  typedef struct {
797
817
  half d; // super-block scale for quantized scales
798
818
  half dmin; // super-block scale for quantized mins
799
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
819
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
800
820
  uint8_t qs[QK_K/2]; // 4--bit quants
801
- } block_q4_k;
802
- // 144 bytes / block
821
+ } block_q4_K;
822
+ #endif
803
823
 
824
+ #if QK_K == 64
825
+ typedef struct {
826
+ half d; // super-block scales/mins
827
+ int8_t scales[QK_K/16]; // 8-bit block scales
828
+ uint8_t qh[QK_K/8]; // quants, high bit
829
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
830
+ } block_q5_K;
831
+ #else
804
832
  typedef struct {
805
833
  half d; // super-block scale for quantized scales
806
834
  half dmin; // super-block scale for quantized mins
807
835
  uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
808
836
  uint8_t qh[QK_K/8]; // quants, high bit
809
837
  uint8_t qs[QK_K/2]; // quants, low 4 bits
810
- } block_q5_k;
838
+ } block_q5_K;
811
839
  // 176 bytes / block
840
+ #endif
812
841
 
813
842
  typedef struct {
814
843
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
815
844
  uint8_t qh[QK_K/4]; // quants, upper 2 bits
816
845
  int8_t scales[QK_K/16]; // scales, quantized with 8 bits
817
846
  half d; // super-block scale
818
- } block_q6_k;
847
+ } block_q6_K;
819
848
  // 210 bytes / block
820
849
 
821
850
  static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
836
865
 
837
866
  //========================================== dequantization =============================
838
867
 
839
- static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
868
+ static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
840
869
  assert(k % QK_K == 0);
841
870
  const int nb = k / QK_K;
842
871
 
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
847
876
 
848
877
  device const uint8_t * q = x[i].qs;
849
878
 
879
+ #if QK_K == 256
850
880
  int is = 0;
851
881
  float dl, ml;
852
882
  for (int n = 0; n < QK_K; n += 128) {
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
865
895
  }
866
896
  q += 32;
867
897
  }
898
+ #else
899
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
900
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
901
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
902
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
903
+ for (int l = 0; l < 16; ++l) {
904
+ y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
905
+ y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
906
+ y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
907
+ y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
908
+ }
909
+ y += QK_K;
910
+ #endif
868
911
 
869
912
  }
870
913
  }
871
914
 
872
- static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
915
+ static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
873
916
  assert(k % QK_K == 0);
874
917
  const int nb = k / QK_K;
875
918
 
919
+ #if QK_K == 256
920
+
876
921
  const uint16_t kmask1 = 0x0303;
877
922
  const uint16_t kmask2 = 0x0f0f;
878
923
 
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
918
963
  }
919
964
  q += 32;
920
965
  }
966
+ }
967
+ #else
968
+ for (int i = 0; i < nb; i++) {
921
969
 
970
+ const float d_all = (float)(x[i].d);
971
+
972
+ device const uint8_t * q = x[i].qs;
973
+ device const uint8_t * hm = x[i].hmask;
974
+
975
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
976
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
977
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
978
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
979
+
980
+ for (int l = 0; l < 8; ++l) {
981
+ uint8_t h = hm[l];
982
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
983
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
984
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
985
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
986
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
987
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
988
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
989
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
990
+ }
991
+ y += QK_K;
922
992
  }
993
+ #endif
923
994
 
924
995
  }
925
996
 
926
- static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
997
+ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
927
998
  assert(k % QK_K == 0);
928
999
  const int nb = k / QK_K;
929
1000
 
930
-
931
1001
  for (int i = 0; i < nb; i++) {
932
1002
 
1003
+ device const uint8_t * q = x[i].qs;
1004
+
1005
+ #if QK_K == 256
933
1006
  const float d = x[i].d;
934
1007
  const float min = x[i].dmin;
935
1008
 
936
- device const uint8_t * q = x[i].qs;
937
1009
  device const uint8_t * scales = x[i].scales;
938
1010
 
939
1011
  int is = 0;
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
945
1017
  for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
946
1018
  q += 32; is += 2;
947
1019
  }
1020
+ #else
1021
+ device const uint8_t * s = x[i].scales;
1022
+ device const half2 * dh = (device const half2 *)x[i].d;
1023
+ const float2 d = (float2)dh[0];
1024
+ const float d1 = d[0] * (s[0] & 0xF);
1025
+ const float d2 = d[0] * (s[1] & 0xF);
1026
+ const float m1 = d[1] * (s[0] >> 4);
1027
+ const float m2 = d[1] * (s[1] >> 4);
1028
+ for (int l = 0; l < 32; ++l) {
1029
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1030
+ y[l+32] = d2 * (q[l] >> 4) - m2;
1031
+ }
1032
+ y += QK_K;
1033
+ #endif
948
1034
 
949
1035
  }
950
1036
  }
951
1037
 
952
- static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
1038
+ static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
953
1039
  assert(k % QK_K == 0);
954
1040
  const int nb = k / QK_K;
955
1041
 
1042
+ #if QK_K == 256
956
1043
  for (int i = 0; i < nb; i++) {
957
1044
 
958
1045
  const float d = (float)(x[i].d);
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
973
1060
  u1 <<= 2; u2 <<= 2;
974
1061
  }
975
1062
  }
1063
+ #else
1064
+ for (int i = 0; i < nb; i++) {
1065
+
1066
+ const float d = (float)x[i].d;
1067
+
1068
+ device const uint8_t * ql = x[i].qs;
1069
+ device const uint8_t * qh = x[i].qh;
1070
+ device const int8_t * sc = x[i].scales;
1071
+
1072
+ for (int l = 0; l < 8; ++l) {
1073
+ y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1074
+ y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1075
+ y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1076
+ y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1077
+ y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1078
+ y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1079
+ y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1080
+ y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1081
+ }
1082
+ y += QK_K;
1083
+ }
1084
+ #endif
976
1085
 
977
1086
  }
978
1087
 
979
- static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
1088
+ static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
980
1089
  assert(k % QK_K == 0);
981
1090
  const int nb = k / QK_K;
982
1091
 
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
988
1097
 
989
1098
  const float d = x[i].d;
990
1099
 
1100
+ #if QK_K == 256
991
1101
  for (int n = 0; n < QK_K; n += 128) {
992
1102
  for (int l = 0; l < 32; ++l) {
993
1103
  int is = l/16;
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
1005
1115
  qh += 32;
1006
1116
  sc += 8;
1007
1117
  }
1118
+ #else
1119
+ for (int l = 0; l < 16; ++l) {
1120
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1121
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1122
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1123
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1124
+ y[l+ 0] = d * sc[0] * q1;
1125
+ y[l+16] = d * sc[1] * q2;
1126
+ y[l+32] = d * sc[2] * q3;
1127
+ y[l+48] = d * sc[3] * q4;
1128
+ }
1129
+ y += 64;
1130
+ #endif
1008
1131
  }
1009
1132
  }
1010
1133
 
1011
- kernel void kernel_get_rows_q2_k(
1134
+ kernel void kernel_get_rows_q2_K(
1012
1135
  device const void * src0,
1013
1136
  device const int * src1,
1014
1137
  device float * dst,
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
1019
1142
  const int i = tpig;
1020
1143
  const int r = ((device int32_t *) src1)[i];
1021
1144
 
1022
- dequantize_row_q2_k(
1023
- (device const block_q2_k *) ((device char *) src0 + r*nb01),
1145
+ dequantize_row_q2_K(
1146
+ (device const block_q2_K *) ((device char *) src0 + r*nb01),
1024
1147
  (device float *) ((device char *) dst + i*nb1), ne00);
1025
1148
  }
1026
1149
 
1027
- kernel void kernel_get_rows_q3_k(
1150
+ kernel void kernel_get_rows_q3_K(
1028
1151
  device const void * src0,
1029
1152
  device const int * src1,
1030
1153
  device float * dst,
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
1035
1158
  const int i = tpig;
1036
1159
  const int r = ((device int32_t *) src1)[i];
1037
1160
 
1038
- dequantize_row_q3_k(
1039
- (device const block_q3_k *) ((device char *) src0 + r*nb01),
1161
+ dequantize_row_q3_K(
1162
+ (device const block_q3_K *) ((device char *) src0 + r*nb01),
1040
1163
  (device float *) ((device char *) dst + i*nb1), ne00);
1041
1164
  }
1042
1165
 
1043
- kernel void kernel_get_rows_q4_k(
1166
+ kernel void kernel_get_rows_q4_K(
1044
1167
  device const void * src0,
1045
1168
  device const int * src1,
1046
1169
  device float * dst,
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
1051
1174
  const int i = tpig;
1052
1175
  const int r = ((device int32_t *) src1)[i];
1053
1176
 
1054
- dequantize_row_q4_k(
1055
- (device const block_q4_k *) ((device char *) src0 + r*nb01),
1177
+ dequantize_row_q4_K(
1178
+ (device const block_q4_K *) ((device char *) src0 + r*nb01),
1056
1179
  (device float *) ((device char *) dst + i*nb1), ne00);
1057
1180
  }
1058
1181
 
1059
- kernel void kernel_get_rows_q5_k(
1182
+ kernel void kernel_get_rows_q5_K(
1060
1183
  device const void * src0,
1061
1184
  device const int * src1,
1062
1185
  device float * dst,
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
1067
1190
  const int i = tpig;
1068
1191
  const int r = ((device int32_t *) src1)[i];
1069
1192
 
1070
- dequantize_row_q5_k(
1071
- (device const block_q5_k *) ((device char *) src0 + r*nb01),
1193
+ dequantize_row_q5_K(
1194
+ (device const block_q5_K *) ((device char *) src0 + r*nb01),
1072
1195
  (device float *) ((device char *) dst + i*nb1), ne00);
1073
1196
  }
1074
1197
 
1075
- kernel void kernel_get_rows_q6_k(
1198
+ kernel void kernel_get_rows_q6_K(
1076
1199
  device const void * src0,
1077
1200
  device const int * src1,
1078
1201
  device float * dst,
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
1083
1206
  const int i = tpig;
1084
1207
  const int r = ((device int32_t *) src1)[i];
1085
1208
 
1086
- dequantize_row_q6_k(
1087
- (device const block_q6_k *) ((device char *) src0 + r*nb01),
1209
+ dequantize_row_q6_K(
1210
+ (device const block_q6_K *) ((device char *) src0 + r*nb01),
1088
1211
  (device float *) ((device char *) dst + i*nb1), ne00);
1089
1212
  }
1090
1213
 
1091
1214
  //====================================== dot products =========================
1092
1215
 
1093
- kernel void kernel_mul_mat_q2_k_f32(
1216
+ kernel void kernel_mul_mat_q2_K_f32(
1094
1217
  device const void * src0,
1095
1218
  device const float * src1,
1096
1219
  device float * dst,
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
1107
1230
  const int64_t r0 = tgpig.x;
1108
1231
  const int64_t r1 = tgpig.y;
1109
1232
 
1110
- device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
1233
+ device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1111
1234
  device const float * yy = (device const float *) src1 + r1*ne10;
1112
1235
 
1113
1236
  const int nth = tptg.x*tptg.y;
1114
1237
  const int ith = tptg.y*tpitg.x + tpitg.y;
1115
1238
 
1239
+ float sumf = 0;
1240
+
1241
+ #if QK_K == 256
1116
1242
  const int tid = tpitg.y; // 0...16
1117
1243
  const int il = tid/4; // 0...3
1118
1244
  const int ir = tid%4; // 0...3
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
1125
1251
  const int y_offset = 64*il + n*ir;
1126
1252
  const int q_offset = 32*ip + n*ir;
1127
1253
 
1128
- sum[ith] = 0.0f;
1129
-
1130
- float sumf = 0;
1131
1254
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1132
1255
 
1133
1256
  device const uint8_t * q = x[i].qs + q_offset;
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
1140
1263
 
1141
1264
  device const float * y = yy + i*QK_K + y_offset;
1142
1265
 
1143
- //float4 s = {0.f, 0.f, 0.f, 0.f};
1144
1266
  float2 s = {0.f, 0.f};
1145
1267
  float smin = 0;
1146
1268
  for (int l = 0; l < n; ++l) {
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
1155
1277
  sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1156
1278
 
1157
1279
  }
1158
- sum[ith] = sumf;
1280
+ #else
1281
+ const int il = 4 * tpitg.x;
1159
1282
 
1160
- //int mask1 = (ith%4 == 0);
1161
- //int mask2 = (ith%16 == 0);
1283
+ uint32_t aux[2];
1284
+ thread const uint8_t * d = (thread const uint8_t *)aux;
1285
+ thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1162
1286
 
1163
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1164
- //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
1165
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1166
- //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
1167
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1168
- //if (ith == 0) {
1169
- // for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1170
- // dst[r1*ne0 + r0] = sum[0];
1171
- //}
1287
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1288
+
1289
+ device const uint8_t * q = x[i].qs + il;
1290
+ device const float * y = yy + i*QK_K + il;
1291
+
1292
+ const float dall = (float)x[i].d;
1293
+ const float dmin = (float)x[i].dmin;
1294
+
1295
+ device const uint32_t * a = (device const uint32_t *)x[i].scales;
1296
+ aux[0] = a[0] & 0x0f0f0f0f;
1297
+ aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1298
+
1299
+ for (int l = 0; l < 4; ++l) {
1300
+ sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1301
+ + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1302
+ + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1303
+ + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1304
+ }
1305
+ }
1306
+ #endif
1307
+
1308
+ sum[ith] = sumf;
1172
1309
 
1173
1310
  //
1174
1311
  // Accumulate the sum from all threads in the threadgroup
1175
- // This version is slightly faster than the commented out one below,
1176
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1177
1312
  //
1178
1313
  threadgroup_barrier(mem_flags::mem_threadgroup);
1179
1314
  if (ith%4 == 0) {
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
1190
1325
  }
1191
1326
  }
1192
1327
 
1193
- kernel void kernel_mul_mat_q3_k_f32(
1328
+ kernel void kernel_mul_mat_q3_K_f32(
1194
1329
  device const void * src0,
1195
1330
  device const float * src1,
1196
1331
  device float * dst,
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
1203
1338
  uint2 tpitg[[thread_position_in_threadgroup]],
1204
1339
  uint2 tptg[[threads_per_threadgroup]]) {
1205
1340
 
1206
- const uint16_t kmask1 = 0x0303;
1207
- const uint16_t kmask2 = 0x0f0f;
1208
-
1209
- const uint8_t m3 = 3;
1210
- const int8_t m4 = 4;
1211
-
1212
1341
  const int nb = ne00/QK_K;
1213
1342
 
1214
1343
  const int64_t r0 = tgpig.x;
1215
1344
  const int64_t r1 = tgpig.y;
1216
1345
 
1217
- device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
1346
+ device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1218
1347
  device const float * yy = (device const float *) src1 + r1*ne10;
1219
1348
 
1220
1349
  const int nth = tptg.x*tptg.y;
1221
1350
  const int ith = tptg.y*tpitg.x + tpitg.y;
1222
1351
 
1352
+ #if QK_K == 256
1353
+
1354
+ const uint8_t m3 = 3;
1355
+ const int8_t m4 = 4;
1356
+
1357
+ const uint16_t kmask1 = 0x0303;
1358
+ const uint16_t kmask2 = 0x0f0f;
1359
+
1223
1360
  const int tid = tpitg.y; // expecting 16
1224
1361
  const int ip = tid/8; // 0 or 1
1225
1362
  const int il = tid/2 - 4*ip; // 0...3
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
1273
1410
 
1274
1411
  //sum[ith] = sumf;
1275
1412
  sum[ith] = sumf1 - 32.f*sumf2;
1413
+ #else
1414
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1415
+ const int im = il/8; // 0, 0, 1, 1
1416
+ const int in = il%8; // 0, 4, 0, 4
1417
+
1418
+ float sumf = 0;
1419
+
1420
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1421
+
1422
+ const float d_all = (float)(x[i].d);
1423
+
1424
+ device const uint8_t * q = x[i].qs + il;
1425
+ device const uint8_t * h = x[i].hmask + in;
1426
+ device const float * y = yy + i * QK_K + il;
1427
+
1428
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1429
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1430
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1431
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1432
+
1433
+ for (int l = 0; l < 4; ++l) {
1434
+ const uint8_t hm = h[l] >> im;
1435
+ sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1436
+ + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1437
+ + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1438
+ + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1439
+ }
1440
+
1441
+ }
1442
+
1443
+ sum[ith] = sumf;
1444
+
1445
+ #endif
1276
1446
 
1277
1447
  //
1278
1448
  // Accumulate the sum from all threads in the threadgroup
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
1293
1463
 
1294
1464
  }
1295
1465
 
1296
- kernel void kernel_mul_mat_q4_k_f32(
1466
+ kernel void kernel_mul_mat_q4_K_f32(
1297
1467
  device const void * src0,
1298
1468
  device const float * src1,
1299
1469
  device float * dst,
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
1305
1475
  uint2 tpitg[[thread_position_in_threadgroup]],
1306
1476
  uint2 tptg[[threads_per_threadgroup]]) {
1307
1477
 
1308
- const uint16_t kmask1 = 0x3f3f;
1309
- const uint16_t kmask2 = 0x0f0f;
1310
- const uint16_t kmask3 = 0xc0c0;
1311
-
1312
1478
  const int nb = ne00/QK_K;
1313
1479
 
1314
1480
  const int64_t r0 = tgpig.x;
1315
1481
  const int64_t r1 = tgpig.y;
1316
1482
 
1317
- device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
1318
- device const float * yy = (device const float *) src1 + r1*ne10;
1319
-
1320
1483
  const int nth = tptg.x*tptg.y;
1321
1484
  const int ith = tptg.y*tpitg.x + tpitg.y;
1322
1485
 
1486
+ device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1487
+ device const float * yy = (device const float *) src1 + r1*ne10;
1488
+
1489
+ float sumf = 0;
1490
+
1491
+ #if QK_K == 256
1492
+
1493
+ const uint16_t kmask1 = 0x3f3f;
1494
+ const uint16_t kmask2 = 0x0f0f;
1495
+ const uint16_t kmask3 = 0xc0c0;
1496
+
1323
1497
  const int tid = tpitg.y; // 0...16
1324
1498
  const int il = tid/4; // 0...3
1325
1499
  const int ir = tid - 4*il;// 0...3
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
1332
1506
  const int q_offset = 32*im + l0;
1333
1507
  const int y_offset = 64*im + l0;
1334
1508
 
1335
- sum[ith] = 0.0f;
1336
-
1337
1509
  uchar2 sc1, sc2, sc3, sc4;
1338
1510
 
1339
- float sumf = 0;
1340
1511
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1341
1512
 
1342
1513
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
1365
1536
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1366
1537
 
1367
1538
  }
1539
+ #else
1540
+ uint16_t aux16[2];
1541
+ thread const uint8_t * scales = (thread const uint8_t *)aux16;
1542
+
1543
+ const int il = 4*tpitg.x;
1544
+
1545
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1546
+
1547
+ device const uint8_t * q = x[i].qs + il;
1548
+ device const float * y = yy + i * QK_K + il;
1549
+
1550
+ const float d = (float)x[i].d[0];
1551
+ const float m = (float)x[i].d[1];
1552
+
1553
+ device const uint16_t * a = (device const uint16_t *)x[i].scales;
1554
+ aux16[0] = a[0] & 0x0f0f;
1555
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
1556
+
1557
+ for (int l = 0; l < 4; ++l) {
1558
+ sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1559
+ + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1560
+ }
1561
+ }
1562
+ #endif
1368
1563
 
1369
1564
  sum[ith] = sumf;
1370
1565
 
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
1401
1596
  //}
1402
1597
  }
1403
1598
 
1404
- kernel void kernel_mul_mat_q5_k_f32(
1599
+ kernel void kernel_mul_mat_q5_K_f32(
1405
1600
  device const void * src0,
1406
1601
  device const float * src1,
1407
1602
  device float * dst,
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
1413
1608
  uint2 tpitg[[thread_position_in_threadgroup]],
1414
1609
  uint2 tptg[[threads_per_threadgroup]]) {
1415
1610
 
1416
- const uint16_t kmask1 = 0x3f3f;
1417
- const uint16_t kmask2 = 0x0f0f;
1418
- const uint16_t kmask3 = 0xc0c0;
1419
-
1420
1611
  const int nb = ne00/QK_K;
1421
1612
 
1422
1613
  const int64_t r0 = tgpig.x;
1423
1614
  const int64_t r1 = tgpig.y;
1424
1615
 
1425
- device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
1616
+ device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1426
1617
  device const float * yy = (device const float *) src1 + r1*ne10;
1427
1618
 
1428
1619
  const int nth = tptg.x*tptg.y;
1429
1620
  const int ith = tptg.y*tpitg.x + tpitg.y;
1430
1621
 
1622
+ float sumf = 0;
1623
+
1624
+ #if QK_K == 256
1625
+
1626
+ const uint16_t kmask1 = 0x3f3f;
1627
+ const uint16_t kmask2 = 0x0f0f;
1628
+ const uint16_t kmask3 = 0xc0c0;
1629
+
1431
1630
  const int tid = tpitg.y; // 0...16
1432
1631
  const int il = tid/4; // 0...3
1433
1632
  const int ir = tid - 4*il;// 0...3
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
1447
1646
 
1448
1647
  uchar2 sc1, sc2, sc3, sc4;
1449
1648
 
1450
- float sumf = 0;
1451
1649
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1452
1650
 
1453
1651
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
1479
1677
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1480
1678
 
1481
1679
  }
1680
+ #else
1681
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1682
+ const int im = il/8; // 0, 0, 1, 1
1683
+ const int in = il%8; // 0, 4, 0, 4
1684
+
1685
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1686
+
1687
+ const float d = (float)x[i].d;
1688
+ device const uint8_t * q = x[i].qs + il;
1689
+ device const uint8_t * h = x[i].qh + in;
1690
+ device const int8_t * s = x[i].scales;
1691
+ device const float * y = yy + i*QK_K + il;
1692
+
1693
+ for (int l = 0; l < 4; ++l) {
1694
+ const uint8_t hl = h[l] >> im;
1695
+ sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1696
+ + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1697
+ + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1698
+ + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1699
+ }
1700
+ }
1701
+ #endif
1482
1702
  sum[ith] = sumf;
1483
1703
 
1484
1704
  //
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
1500
1720
 
1501
1721
  }
1502
1722
 
1503
- kernel void kernel_mul_mat_q6_k_f32(
1723
+ kernel void kernel_mul_mat_q6_K_f32(
1504
1724
  device const void * src0,
1505
1725
  device const float * src1,
1506
1726
  device float * dst,
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
1522
1742
  const int64_t r0 = tgpig.x;
1523
1743
  const int64_t r1 = tgpig.y;
1524
1744
 
1525
- device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
1745
+ device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1526
1746
  device const float * yy = (device const float *) src1 + r1*ne10;
1527
1747
 
1528
1748
  const int nth = tptg.x*tptg.y;
1529
1749
  const int ith = tptg.y*tpitg.x + tpitg.y;
1530
1750
 
1751
+ float sumf = 0;
1752
+
1753
+ #if QK_K == 256
1531
1754
  // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1532
1755
  const int iqs = 16 * tpitg.y;
1533
1756
  const int ip = iqs / 128; // 0 or 1
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
1540
1763
  const int q_offset_l = 64*ip + l0;
1541
1764
  const int q_offset_h = 32*ip + l0;
1542
1765
 
1543
- float sumf = 0;
1544
1766
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1545
1767
 
1546
1768
  device const uint8_t * ql = x[i].ql + q_offset_l;
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
1562
1784
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1563
1785
 
1564
1786
  }
1787
+ #else
1788
+ const int il = 4*tpitg.x; // 0, 4, 8, 12
1789
+
1790
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1791
+ device const float * y = yy + i * QK_K + il;
1792
+ device const uint8_t * ql = x[i].ql + il;
1793
+ device const uint8_t * qh = x[i].qh + il;
1794
+ device const int8_t * s = x[i].scales;
1795
+
1796
+ const float d = x[i].d;
1797
+
1798
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
1799
+ for (int l = 0; l < 4; ++l) {
1800
+ sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1801
+ sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1802
+ sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
1803
+ sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1804
+ }
1805
+ sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
1806
+ }
1807
+
1808
+ #endif
1565
1809
 
1566
1810
  sum[ith] = sumf;
1567
1811