llama_cpp 0.2.2 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +34 -0
- data/README.md +39 -6
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +3 -2
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +305 -133
- data/ext/llama_cpp/src/ggml-cuda.cu +367 -69
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +36 -30
- data/ext/llama_cpp/src/ggml-metal.metal +328 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +352 -175
- data/ext/llama_cpp/src/ggml.c +800 -303
- data/ext/llama_cpp/src/ggml.h +68 -5
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +262 -291
- data/ext/llama_cpp/src/llama.h +49 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +14 -17
- metadata +2 -3
- data/lib/llama_cpp/client.rb +0 -172
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
428
428
|
}
|
429
429
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
430
430
|
if (ith == 0) {
|
431
|
-
for (
|
431
|
+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
432
432
|
dst[r1*ne0 + r0] = sum[0];
|
433
433
|
}
|
434
434
|
}
|
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
497
497
|
}
|
498
498
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
499
499
|
if (ith == 0) {
|
500
|
-
for (
|
500
|
+
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
501
501
|
dst[r1*ne0 + r0] = sum[0];
|
502
502
|
}
|
503
503
|
}
|
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
|
|
775
775
|
|
776
776
|
//============================================ k-quants ======================================================
|
777
777
|
|
778
|
+
#ifndef QK_K
|
778
779
|
#define QK_K 256
|
780
|
+
#else
|
781
|
+
static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
|
782
|
+
#endif
|
783
|
+
|
784
|
+
#if QK_K == 256
|
785
|
+
#define K_SCALE_SIZE 12
|
786
|
+
#else
|
787
|
+
#define K_SCALE_SIZE 4
|
788
|
+
#endif
|
779
789
|
|
780
790
|
typedef struct {
|
781
791
|
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
782
792
|
uint8_t qs[QK_K/4]; // quants
|
783
793
|
half d; // super-block scale for quantized scales
|
784
794
|
half dmin; // super-block scale for quantized mins
|
785
|
-
}
|
795
|
+
} block_q2_K;
|
786
796
|
// 84 bytes / block
|
787
797
|
|
788
798
|
typedef struct {
|
789
799
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
790
800
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
//
|
795
|
-
|
801
|
+
#if QK_K == 64
|
802
|
+
uint8_t scales[2];
|
803
|
+
#else
|
804
|
+
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
805
|
+
#endif
|
806
|
+
half d; // super-block scale
|
807
|
+
} block_q3_K;
|
808
|
+
|
809
|
+
#if QK_K == 64
|
810
|
+
typedef struct {
|
811
|
+
half d[2]; // super-block scales/mins
|
812
|
+
uint8_t scales[2];
|
813
|
+
uint8_t qs[QK_K/2]; // 4-bit quants
|
814
|
+
} block_q4_K;
|
815
|
+
#else
|
796
816
|
typedef struct {
|
797
817
|
half d; // super-block scale for quantized scales
|
798
818
|
half dmin; // super-block scale for quantized mins
|
799
|
-
uint8_t scales[
|
819
|
+
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
800
820
|
uint8_t qs[QK_K/2]; // 4--bit quants
|
801
|
-
}
|
802
|
-
|
821
|
+
} block_q4_K;
|
822
|
+
#endif
|
803
823
|
|
824
|
+
#if QK_K == 64
|
825
|
+
typedef struct {
|
826
|
+
half d; // super-block scales/mins
|
827
|
+
int8_t scales[QK_K/16]; // 8-bit block scales
|
828
|
+
uint8_t qh[QK_K/8]; // quants, high bit
|
829
|
+
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
830
|
+
} block_q5_K;
|
831
|
+
#else
|
804
832
|
typedef struct {
|
805
833
|
half d; // super-block scale for quantized scales
|
806
834
|
half dmin; // super-block scale for quantized mins
|
807
835
|
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
808
836
|
uint8_t qh[QK_K/8]; // quants, high bit
|
809
837
|
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
810
|
-
}
|
838
|
+
} block_q5_K;
|
811
839
|
// 176 bytes / block
|
840
|
+
#endif
|
812
841
|
|
813
842
|
typedef struct {
|
814
843
|
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
815
844
|
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
816
845
|
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
817
846
|
half d; // super-block scale
|
818
|
-
}
|
847
|
+
} block_q6_K;
|
819
848
|
// 210 bytes / block
|
820
849
|
|
821
850
|
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
836
865
|
|
837
866
|
//========================================== dequantization =============================
|
838
867
|
|
839
|
-
static void
|
868
|
+
static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
|
840
869
|
assert(k % QK_K == 0);
|
841
870
|
const int nb = k / QK_K;
|
842
871
|
|
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
847
876
|
|
848
877
|
device const uint8_t * q = x[i].qs;
|
849
878
|
|
879
|
+
#if QK_K == 256
|
850
880
|
int is = 0;
|
851
881
|
float dl, ml;
|
852
882
|
for (int n = 0; n < QK_K; n += 128) {
|
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
865
895
|
}
|
866
896
|
q += 32;
|
867
897
|
}
|
898
|
+
#else
|
899
|
+
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
|
900
|
+
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
|
901
|
+
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
|
902
|
+
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
|
903
|
+
for (int l = 0; l < 16; ++l) {
|
904
|
+
y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
|
905
|
+
y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
|
906
|
+
y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
|
907
|
+
y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
|
908
|
+
}
|
909
|
+
y += QK_K;
|
910
|
+
#endif
|
868
911
|
|
869
912
|
}
|
870
913
|
}
|
871
914
|
|
872
|
-
static void
|
915
|
+
static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
|
873
916
|
assert(k % QK_K == 0);
|
874
917
|
const int nb = k / QK_K;
|
875
918
|
|
919
|
+
#if QK_K == 256
|
920
|
+
|
876
921
|
const uint16_t kmask1 = 0x0303;
|
877
922
|
const uint16_t kmask2 = 0x0f0f;
|
878
923
|
|
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
|
|
918
963
|
}
|
919
964
|
q += 32;
|
920
965
|
}
|
966
|
+
}
|
967
|
+
#else
|
968
|
+
for (int i = 0; i < nb; i++) {
|
921
969
|
|
970
|
+
const float d_all = (float)(x[i].d);
|
971
|
+
|
972
|
+
device const uint8_t * q = x[i].qs;
|
973
|
+
device const uint8_t * hm = x[i].hmask;
|
974
|
+
|
975
|
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
976
|
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
977
|
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
978
|
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
979
|
+
|
980
|
+
for (int l = 0; l < 8; ++l) {
|
981
|
+
uint8_t h = hm[l];
|
982
|
+
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
|
983
|
+
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
|
984
|
+
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
|
985
|
+
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
|
986
|
+
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
|
987
|
+
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
|
988
|
+
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
|
989
|
+
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
|
990
|
+
}
|
991
|
+
y += QK_K;
|
922
992
|
}
|
993
|
+
#endif
|
923
994
|
|
924
995
|
}
|
925
996
|
|
926
|
-
static void
|
997
|
+
static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
|
927
998
|
assert(k % QK_K == 0);
|
928
999
|
const int nb = k / QK_K;
|
929
1000
|
|
930
|
-
|
931
1001
|
for (int i = 0; i < nb; i++) {
|
932
1002
|
|
1003
|
+
device const uint8_t * q = x[i].qs;
|
1004
|
+
|
1005
|
+
#if QK_K == 256
|
933
1006
|
const float d = x[i].d;
|
934
1007
|
const float min = x[i].dmin;
|
935
1008
|
|
936
|
-
device const uint8_t * q = x[i].qs;
|
937
1009
|
device const uint8_t * scales = x[i].scales;
|
938
1010
|
|
939
1011
|
int is = 0;
|
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
|
|
945
1017
|
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
946
1018
|
q += 32; is += 2;
|
947
1019
|
}
|
1020
|
+
#else
|
1021
|
+
device const uint8_t * s = x[i].scales;
|
1022
|
+
device const half2 * dh = (device const half2 *)x[i].d;
|
1023
|
+
const float2 d = (float2)dh[0];
|
1024
|
+
const float d1 = d[0] * (s[0] & 0xF);
|
1025
|
+
const float d2 = d[0] * (s[1] & 0xF);
|
1026
|
+
const float m1 = d[1] * (s[0] >> 4);
|
1027
|
+
const float m2 = d[1] * (s[1] >> 4);
|
1028
|
+
for (int l = 0; l < 32; ++l) {
|
1029
|
+
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
|
1030
|
+
y[l+32] = d2 * (q[l] >> 4) - m2;
|
1031
|
+
}
|
1032
|
+
y += QK_K;
|
1033
|
+
#endif
|
948
1034
|
|
949
1035
|
}
|
950
1036
|
}
|
951
1037
|
|
952
|
-
static void
|
1038
|
+
static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
|
953
1039
|
assert(k % QK_K == 0);
|
954
1040
|
const int nb = k / QK_K;
|
955
1041
|
|
1042
|
+
#if QK_K == 256
|
956
1043
|
for (int i = 0; i < nb; i++) {
|
957
1044
|
|
958
1045
|
const float d = (float)(x[i].d);
|
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
|
|
973
1060
|
u1 <<= 2; u2 <<= 2;
|
974
1061
|
}
|
975
1062
|
}
|
1063
|
+
#else
|
1064
|
+
for (int i = 0; i < nb; i++) {
|
1065
|
+
|
1066
|
+
const float d = (float)x[i].d;
|
1067
|
+
|
1068
|
+
device const uint8_t * ql = x[i].qs;
|
1069
|
+
device const uint8_t * qh = x[i].qh;
|
1070
|
+
device const int8_t * sc = x[i].scales;
|
1071
|
+
|
1072
|
+
for (int l = 0; l < 8; ++l) {
|
1073
|
+
y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
|
1074
|
+
y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
|
1075
|
+
y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
|
1076
|
+
y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
|
1077
|
+
y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
|
1078
|
+
y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
|
1079
|
+
y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
|
1080
|
+
y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
|
1081
|
+
}
|
1082
|
+
y += QK_K;
|
1083
|
+
}
|
1084
|
+
#endif
|
976
1085
|
|
977
1086
|
}
|
978
1087
|
|
979
|
-
static void
|
1088
|
+
static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
|
980
1089
|
assert(k % QK_K == 0);
|
981
1090
|
const int nb = k / QK_K;
|
982
1091
|
|
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
988
1097
|
|
989
1098
|
const float d = x[i].d;
|
990
1099
|
|
1100
|
+
#if QK_K == 256
|
991
1101
|
for (int n = 0; n < QK_K; n += 128) {
|
992
1102
|
for (int l = 0; l < 32; ++l) {
|
993
1103
|
int is = l/16;
|
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
1005
1115
|
qh += 32;
|
1006
1116
|
sc += 8;
|
1007
1117
|
}
|
1118
|
+
#else
|
1119
|
+
for (int l = 0; l < 16; ++l) {
|
1120
|
+
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
1121
|
+
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
1122
|
+
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
1123
|
+
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
1124
|
+
y[l+ 0] = d * sc[0] * q1;
|
1125
|
+
y[l+16] = d * sc[1] * q2;
|
1126
|
+
y[l+32] = d * sc[2] * q3;
|
1127
|
+
y[l+48] = d * sc[3] * q4;
|
1128
|
+
}
|
1129
|
+
y += 64;
|
1130
|
+
#endif
|
1008
1131
|
}
|
1009
1132
|
}
|
1010
1133
|
|
1011
|
-
kernel void
|
1134
|
+
kernel void kernel_get_rows_q2_K(
|
1012
1135
|
device const void * src0,
|
1013
1136
|
device const int * src1,
|
1014
1137
|
device float * dst,
|
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
|
|
1019
1142
|
const int i = tpig;
|
1020
1143
|
const int r = ((device int32_t *) src1)[i];
|
1021
1144
|
|
1022
|
-
|
1023
|
-
(device const
|
1145
|
+
dequantize_row_q2_K(
|
1146
|
+
(device const block_q2_K *) ((device char *) src0 + r*nb01),
|
1024
1147
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1025
1148
|
}
|
1026
1149
|
|
1027
|
-
kernel void
|
1150
|
+
kernel void kernel_get_rows_q3_K(
|
1028
1151
|
device const void * src0,
|
1029
1152
|
device const int * src1,
|
1030
1153
|
device float * dst,
|
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
|
|
1035
1158
|
const int i = tpig;
|
1036
1159
|
const int r = ((device int32_t *) src1)[i];
|
1037
1160
|
|
1038
|
-
|
1039
|
-
(device const
|
1161
|
+
dequantize_row_q3_K(
|
1162
|
+
(device const block_q3_K *) ((device char *) src0 + r*nb01),
|
1040
1163
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1041
1164
|
}
|
1042
1165
|
|
1043
|
-
kernel void
|
1166
|
+
kernel void kernel_get_rows_q4_K(
|
1044
1167
|
device const void * src0,
|
1045
1168
|
device const int * src1,
|
1046
1169
|
device float * dst,
|
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
|
|
1051
1174
|
const int i = tpig;
|
1052
1175
|
const int r = ((device int32_t *) src1)[i];
|
1053
1176
|
|
1054
|
-
|
1055
|
-
(device const
|
1177
|
+
dequantize_row_q4_K(
|
1178
|
+
(device const block_q4_K *) ((device char *) src0 + r*nb01),
|
1056
1179
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1057
1180
|
}
|
1058
1181
|
|
1059
|
-
kernel void
|
1182
|
+
kernel void kernel_get_rows_q5_K(
|
1060
1183
|
device const void * src0,
|
1061
1184
|
device const int * src1,
|
1062
1185
|
device float * dst,
|
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
|
|
1067
1190
|
const int i = tpig;
|
1068
1191
|
const int r = ((device int32_t *) src1)[i];
|
1069
1192
|
|
1070
|
-
|
1071
|
-
(device const
|
1193
|
+
dequantize_row_q5_K(
|
1194
|
+
(device const block_q5_K *) ((device char *) src0 + r*nb01),
|
1072
1195
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1073
1196
|
}
|
1074
1197
|
|
1075
|
-
kernel void
|
1198
|
+
kernel void kernel_get_rows_q6_K(
|
1076
1199
|
device const void * src0,
|
1077
1200
|
device const int * src1,
|
1078
1201
|
device float * dst,
|
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
|
|
1083
1206
|
const int i = tpig;
|
1084
1207
|
const int r = ((device int32_t *) src1)[i];
|
1085
1208
|
|
1086
|
-
|
1087
|
-
(device const
|
1209
|
+
dequantize_row_q6_K(
|
1210
|
+
(device const block_q6_K *) ((device char *) src0 + r*nb01),
|
1088
1211
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1089
1212
|
}
|
1090
1213
|
|
1091
1214
|
//====================================== dot products =========================
|
1092
1215
|
|
1093
|
-
kernel void
|
1216
|
+
kernel void kernel_mul_mat_q2_K_f32(
|
1094
1217
|
device const void * src0,
|
1095
1218
|
device const float * src1,
|
1096
1219
|
device float * dst,
|
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1107
1230
|
const int64_t r0 = tgpig.x;
|
1108
1231
|
const int64_t r1 = tgpig.y;
|
1109
1232
|
|
1110
|
-
device const
|
1233
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
|
1111
1234
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1112
1235
|
|
1113
1236
|
const int nth = tptg.x*tptg.y;
|
1114
1237
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1115
1238
|
|
1239
|
+
float sumf = 0;
|
1240
|
+
|
1241
|
+
#if QK_K == 256
|
1116
1242
|
const int tid = tpitg.y; // 0...16
|
1117
1243
|
const int il = tid/4; // 0...3
|
1118
1244
|
const int ir = tid%4; // 0...3
|
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1125
1251
|
const int y_offset = 64*il + n*ir;
|
1126
1252
|
const int q_offset = 32*ip + n*ir;
|
1127
1253
|
|
1128
|
-
sum[ith] = 0.0f;
|
1129
|
-
|
1130
|
-
float sumf = 0;
|
1131
1254
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1132
1255
|
|
1133
1256
|
device const uint8_t * q = x[i].qs + q_offset;
|
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1140
1263
|
|
1141
1264
|
device const float * y = yy + i*QK_K + y_offset;
|
1142
1265
|
|
1143
|
-
//float4 s = {0.f, 0.f, 0.f, 0.f};
|
1144
1266
|
float2 s = {0.f, 0.f};
|
1145
1267
|
float smin = 0;
|
1146
1268
|
for (int l = 0; l < n; ++l) {
|
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1155
1277
|
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
1156
1278
|
|
1157
1279
|
}
|
1158
|
-
|
1280
|
+
#else
|
1281
|
+
const int il = 4 * tpitg.x;
|
1159
1282
|
|
1160
|
-
|
1161
|
-
|
1283
|
+
uint32_t aux[2];
|
1284
|
+
thread const uint8_t * d = (thread const uint8_t *)aux;
|
1285
|
+
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
|
1162
1286
|
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1287
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1288
|
+
|
1289
|
+
device const uint8_t * q = x[i].qs + il;
|
1290
|
+
device const float * y = yy + i*QK_K + il;
|
1291
|
+
|
1292
|
+
const float dall = (float)x[i].d;
|
1293
|
+
const float dmin = (float)x[i].dmin;
|
1294
|
+
|
1295
|
+
device const uint32_t * a = (device const uint32_t *)x[i].scales;
|
1296
|
+
aux[0] = a[0] & 0x0f0f0f0f;
|
1297
|
+
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
|
1298
|
+
|
1299
|
+
for (int l = 0; l < 4; ++l) {
|
1300
|
+
sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
|
1301
|
+
+ y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
|
1302
|
+
+ y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
|
1303
|
+
+ y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
|
1304
|
+
}
|
1305
|
+
}
|
1306
|
+
#endif
|
1307
|
+
|
1308
|
+
sum[ith] = sumf;
|
1172
1309
|
|
1173
1310
|
//
|
1174
1311
|
// Accumulate the sum from all threads in the threadgroup
|
1175
|
-
// This version is slightly faster than the commented out one below,
|
1176
|
-
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
|
1177
1312
|
//
|
1178
1313
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1179
1314
|
if (ith%4 == 0) {
|
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1190
1325
|
}
|
1191
1326
|
}
|
1192
1327
|
|
1193
|
-
kernel void
|
1328
|
+
kernel void kernel_mul_mat_q3_K_f32(
|
1194
1329
|
device const void * src0,
|
1195
1330
|
device const float * src1,
|
1196
1331
|
device float * dst,
|
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1203
1338
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1204
1339
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1205
1340
|
|
1206
|
-
const uint16_t kmask1 = 0x0303;
|
1207
|
-
const uint16_t kmask2 = 0x0f0f;
|
1208
|
-
|
1209
|
-
const uint8_t m3 = 3;
|
1210
|
-
const int8_t m4 = 4;
|
1211
|
-
|
1212
1341
|
const int nb = ne00/QK_K;
|
1213
1342
|
|
1214
1343
|
const int64_t r0 = tgpig.x;
|
1215
1344
|
const int64_t r1 = tgpig.y;
|
1216
1345
|
|
1217
|
-
device const
|
1346
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
|
1218
1347
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1219
1348
|
|
1220
1349
|
const int nth = tptg.x*tptg.y;
|
1221
1350
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1222
1351
|
|
1352
|
+
#if QK_K == 256
|
1353
|
+
|
1354
|
+
const uint8_t m3 = 3;
|
1355
|
+
const int8_t m4 = 4;
|
1356
|
+
|
1357
|
+
const uint16_t kmask1 = 0x0303;
|
1358
|
+
const uint16_t kmask2 = 0x0f0f;
|
1359
|
+
|
1223
1360
|
const int tid = tpitg.y; // expecting 16
|
1224
1361
|
const int ip = tid/8; // 0 or 1
|
1225
1362
|
const int il = tid/2 - 4*ip; // 0...3
|
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1273
1410
|
|
1274
1411
|
//sum[ith] = sumf;
|
1275
1412
|
sum[ith] = sumf1 - 32.f*sumf2;
|
1413
|
+
#else
|
1414
|
+
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1415
|
+
const int im = il/8; // 0, 0, 1, 1
|
1416
|
+
const int in = il%8; // 0, 4, 0, 4
|
1417
|
+
|
1418
|
+
float sumf = 0;
|
1419
|
+
|
1420
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1421
|
+
|
1422
|
+
const float d_all = (float)(x[i].d);
|
1423
|
+
|
1424
|
+
device const uint8_t * q = x[i].qs + il;
|
1425
|
+
device const uint8_t * h = x[i].hmask + in;
|
1426
|
+
device const float * y = yy + i * QK_K + il;
|
1427
|
+
|
1428
|
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
1429
|
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
1430
|
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
1431
|
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
1432
|
+
|
1433
|
+
for (int l = 0; l < 4; ++l) {
|
1434
|
+
const uint8_t hm = h[l] >> im;
|
1435
|
+
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
|
1436
|
+
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
|
1437
|
+
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
|
1438
|
+
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
|
1439
|
+
}
|
1440
|
+
|
1441
|
+
}
|
1442
|
+
|
1443
|
+
sum[ith] = sumf;
|
1444
|
+
|
1445
|
+
#endif
|
1276
1446
|
|
1277
1447
|
//
|
1278
1448
|
// Accumulate the sum from all threads in the threadgroup
|
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1293
1463
|
|
1294
1464
|
}
|
1295
1465
|
|
1296
|
-
kernel void
|
1466
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1297
1467
|
device const void * src0,
|
1298
1468
|
device const float * src1,
|
1299
1469
|
device float * dst,
|
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1305
1475
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1306
1476
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1307
1477
|
|
1308
|
-
const uint16_t kmask1 = 0x3f3f;
|
1309
|
-
const uint16_t kmask2 = 0x0f0f;
|
1310
|
-
const uint16_t kmask3 = 0xc0c0;
|
1311
|
-
|
1312
1478
|
const int nb = ne00/QK_K;
|
1313
1479
|
|
1314
1480
|
const int64_t r0 = tgpig.x;
|
1315
1481
|
const int64_t r1 = tgpig.y;
|
1316
1482
|
|
1317
|
-
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
|
1318
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1319
|
-
|
1320
1483
|
const int nth = tptg.x*tptg.y;
|
1321
1484
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1322
1485
|
|
1486
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
|
1487
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1488
|
+
|
1489
|
+
float sumf = 0;
|
1490
|
+
|
1491
|
+
#if QK_K == 256
|
1492
|
+
|
1493
|
+
const uint16_t kmask1 = 0x3f3f;
|
1494
|
+
const uint16_t kmask2 = 0x0f0f;
|
1495
|
+
const uint16_t kmask3 = 0xc0c0;
|
1496
|
+
|
1323
1497
|
const int tid = tpitg.y; // 0...16
|
1324
1498
|
const int il = tid/4; // 0...3
|
1325
1499
|
const int ir = tid - 4*il;// 0...3
|
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1332
1506
|
const int q_offset = 32*im + l0;
|
1333
1507
|
const int y_offset = 64*im + l0;
|
1334
1508
|
|
1335
|
-
sum[ith] = 0.0f;
|
1336
|
-
|
1337
1509
|
uchar2 sc1, sc2, sc3, sc4;
|
1338
1510
|
|
1339
|
-
float sumf = 0;
|
1340
1511
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1341
1512
|
|
1342
1513
|
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1365
1536
|
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1366
1537
|
|
1367
1538
|
}
|
1539
|
+
#else
|
1540
|
+
uint16_t aux16[2];
|
1541
|
+
thread const uint8_t * scales = (thread const uint8_t *)aux16;
|
1542
|
+
|
1543
|
+
const int il = 4*tpitg.x;
|
1544
|
+
|
1545
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1546
|
+
|
1547
|
+
device const uint8_t * q = x[i].qs + il;
|
1548
|
+
device const float * y = yy + i * QK_K + il;
|
1549
|
+
|
1550
|
+
const float d = (float)x[i].d[0];
|
1551
|
+
const float m = (float)x[i].d[1];
|
1552
|
+
|
1553
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
1554
|
+
aux16[0] = a[0] & 0x0f0f;
|
1555
|
+
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
1556
|
+
|
1557
|
+
for (int l = 0; l < 4; ++l) {
|
1558
|
+
sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
|
1559
|
+
+ d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
|
1560
|
+
}
|
1561
|
+
}
|
1562
|
+
#endif
|
1368
1563
|
|
1369
1564
|
sum[ith] = sumf;
|
1370
1565
|
|
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1401
1596
|
//}
|
1402
1597
|
}
|
1403
1598
|
|
1404
|
-
kernel void
|
1599
|
+
kernel void kernel_mul_mat_q5_K_f32(
|
1405
1600
|
device const void * src0,
|
1406
1601
|
device const float * src1,
|
1407
1602
|
device float * dst,
|
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1413
1608
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1414
1609
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1415
1610
|
|
1416
|
-
const uint16_t kmask1 = 0x3f3f;
|
1417
|
-
const uint16_t kmask2 = 0x0f0f;
|
1418
|
-
const uint16_t kmask3 = 0xc0c0;
|
1419
|
-
|
1420
1611
|
const int nb = ne00/QK_K;
|
1421
1612
|
|
1422
1613
|
const int64_t r0 = tgpig.x;
|
1423
1614
|
const int64_t r1 = tgpig.y;
|
1424
1615
|
|
1425
|
-
device const
|
1616
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
|
1426
1617
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1427
1618
|
|
1428
1619
|
const int nth = tptg.x*tptg.y;
|
1429
1620
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1430
1621
|
|
1622
|
+
float sumf = 0;
|
1623
|
+
|
1624
|
+
#if QK_K == 256
|
1625
|
+
|
1626
|
+
const uint16_t kmask1 = 0x3f3f;
|
1627
|
+
const uint16_t kmask2 = 0x0f0f;
|
1628
|
+
const uint16_t kmask3 = 0xc0c0;
|
1629
|
+
|
1431
1630
|
const int tid = tpitg.y; // 0...16
|
1432
1631
|
const int il = tid/4; // 0...3
|
1433
1632
|
const int ir = tid - 4*il;// 0...3
|
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1447
1646
|
|
1448
1647
|
uchar2 sc1, sc2, sc3, sc4;
|
1449
1648
|
|
1450
|
-
float sumf = 0;
|
1451
1649
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1452
1650
|
|
1453
1651
|
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1479
1677
|
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1480
1678
|
|
1481
1679
|
}
|
1680
|
+
#else
|
1681
|
+
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1682
|
+
const int im = il/8; // 0, 0, 1, 1
|
1683
|
+
const int in = il%8; // 0, 4, 0, 4
|
1684
|
+
|
1685
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1686
|
+
|
1687
|
+
const float d = (float)x[i].d;
|
1688
|
+
device const uint8_t * q = x[i].qs + il;
|
1689
|
+
device const uint8_t * h = x[i].qh + in;
|
1690
|
+
device const int8_t * s = x[i].scales;
|
1691
|
+
device const float * y = yy + i*QK_K + il;
|
1692
|
+
|
1693
|
+
for (int l = 0; l < 4; ++l) {
|
1694
|
+
const uint8_t hl = h[l] >> im;
|
1695
|
+
sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
|
1696
|
+
+ y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
|
1697
|
+
+ y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
|
1698
|
+
+ y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
|
1699
|
+
}
|
1700
|
+
}
|
1701
|
+
#endif
|
1482
1702
|
sum[ith] = sumf;
|
1483
1703
|
|
1484
1704
|
//
|
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1500
1720
|
|
1501
1721
|
}
|
1502
1722
|
|
1503
|
-
kernel void
|
1723
|
+
kernel void kernel_mul_mat_q6_K_f32(
|
1504
1724
|
device const void * src0,
|
1505
1725
|
device const float * src1,
|
1506
1726
|
device float * dst,
|
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1522
1742
|
const int64_t r0 = tgpig.x;
|
1523
1743
|
const int64_t r1 = tgpig.y;
|
1524
1744
|
|
1525
|
-
device const
|
1745
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
|
1526
1746
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1527
1747
|
|
1528
1748
|
const int nth = tptg.x*tptg.y;
|
1529
1749
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1530
1750
|
|
1751
|
+
float sumf = 0;
|
1752
|
+
|
1753
|
+
#if QK_K == 256
|
1531
1754
|
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
|
1532
1755
|
const int iqs = 16 * tpitg.y;
|
1533
1756
|
const int ip = iqs / 128; // 0 or 1
|
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1540
1763
|
const int q_offset_l = 64*ip + l0;
|
1541
1764
|
const int q_offset_h = 32*ip + l0;
|
1542
1765
|
|
1543
|
-
float sumf = 0;
|
1544
1766
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1545
1767
|
|
1546
1768
|
device const uint8_t * ql = x[i].ql + q_offset_l;
|
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1562
1784
|
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
1563
1785
|
|
1564
1786
|
}
|
1787
|
+
#else
|
1788
|
+
const int il = 4*tpitg.x; // 0, 4, 8, 12
|
1789
|
+
|
1790
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1791
|
+
device const float * y = yy + i * QK_K + il;
|
1792
|
+
device const uint8_t * ql = x[i].ql + il;
|
1793
|
+
device const uint8_t * qh = x[i].qh + il;
|
1794
|
+
device const int8_t * s = x[i].scales;
|
1795
|
+
|
1796
|
+
const float d = x[i].d;
|
1797
|
+
|
1798
|
+
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
1799
|
+
for (int l = 0; l < 4; ++l) {
|
1800
|
+
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
1801
|
+
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
1802
|
+
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
|
1803
|
+
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
1804
|
+
}
|
1805
|
+
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
1806
|
+
}
|
1807
|
+
|
1808
|
+
#endif
|
1565
1809
|
|
1566
1810
|
sum[ith] = sumf;
|
1567
1811
|
|