llama_cpp 0.2.2 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +39 -6
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +3 -2
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +231 -132
- data/ext/llama_cpp/src/ggml-cuda.cu +319 -52
- data/ext/llama_cpp/src/ggml-metal.m +36 -30
- data/ext/llama_cpp/src/ggml-metal.metal +328 -84
- data/ext/llama_cpp/src/ggml.c +800 -303
- data/ext/llama_cpp/src/ggml.h +68 -5
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +138 -72
- data/ext/llama_cpp/src/llama.h +33 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +12 -17
- metadata +2 -3
- data/lib/llama_cpp/client.rb +0 -172
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
428
428
|
}
|
429
429
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
430
430
|
if (ith == 0) {
|
431
|
-
for (
|
431
|
+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
432
432
|
dst[r1*ne0 + r0] = sum[0];
|
433
433
|
}
|
434
434
|
}
|
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
497
497
|
}
|
498
498
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
499
499
|
if (ith == 0) {
|
500
|
-
for (
|
500
|
+
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
|
501
501
|
dst[r1*ne0 + r0] = sum[0];
|
502
502
|
}
|
503
503
|
}
|
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
|
|
775
775
|
|
776
776
|
//============================================ k-quants ======================================================
|
777
777
|
|
778
|
+
#ifndef QK_K
|
778
779
|
#define QK_K 256
|
780
|
+
#else
|
781
|
+
static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
|
782
|
+
#endif
|
783
|
+
|
784
|
+
#if QK_K == 256
|
785
|
+
#define K_SCALE_SIZE 12
|
786
|
+
#else
|
787
|
+
#define K_SCALE_SIZE 4
|
788
|
+
#endif
|
779
789
|
|
780
790
|
typedef struct {
|
781
791
|
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
782
792
|
uint8_t qs[QK_K/4]; // quants
|
783
793
|
half d; // super-block scale for quantized scales
|
784
794
|
half dmin; // super-block scale for quantized mins
|
785
|
-
}
|
795
|
+
} block_q2_K;
|
786
796
|
// 84 bytes / block
|
787
797
|
|
788
798
|
typedef struct {
|
789
799
|
uint8_t hmask[QK_K/8]; // quants - high bit
|
790
800
|
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
//
|
795
|
-
|
801
|
+
#if QK_K == 64
|
802
|
+
uint8_t scales[2];
|
803
|
+
#else
|
804
|
+
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
805
|
+
#endif
|
806
|
+
half d; // super-block scale
|
807
|
+
} block_q3_K;
|
808
|
+
|
809
|
+
#if QK_K == 64
|
810
|
+
typedef struct {
|
811
|
+
half d[2]; // super-block scales/mins
|
812
|
+
uint8_t scales[2];
|
813
|
+
uint8_t qs[QK_K/2]; // 4-bit quants
|
814
|
+
} block_q4_K;
|
815
|
+
#else
|
796
816
|
typedef struct {
|
797
817
|
half d; // super-block scale for quantized scales
|
798
818
|
half dmin; // super-block scale for quantized mins
|
799
|
-
uint8_t scales[
|
819
|
+
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
800
820
|
uint8_t qs[QK_K/2]; // 4--bit quants
|
801
|
-
}
|
802
|
-
|
821
|
+
} block_q4_K;
|
822
|
+
#endif
|
803
823
|
|
824
|
+
#if QK_K == 64
|
825
|
+
typedef struct {
|
826
|
+
half d; // super-block scales/mins
|
827
|
+
int8_t scales[QK_K/16]; // 8-bit block scales
|
828
|
+
uint8_t qh[QK_K/8]; // quants, high bit
|
829
|
+
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
830
|
+
} block_q5_K;
|
831
|
+
#else
|
804
832
|
typedef struct {
|
805
833
|
half d; // super-block scale for quantized scales
|
806
834
|
half dmin; // super-block scale for quantized mins
|
807
835
|
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
808
836
|
uint8_t qh[QK_K/8]; // quants, high bit
|
809
837
|
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
810
|
-
}
|
838
|
+
} block_q5_K;
|
811
839
|
// 176 bytes / block
|
840
|
+
#endif
|
812
841
|
|
813
842
|
typedef struct {
|
814
843
|
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
815
844
|
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
816
845
|
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
817
846
|
half d; // super-block scale
|
818
|
-
}
|
847
|
+
} block_q6_K;
|
819
848
|
// 210 bytes / block
|
820
849
|
|
821
850
|
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
836
865
|
|
837
866
|
//========================================== dequantization =============================
|
838
867
|
|
839
|
-
static void
|
868
|
+
static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
|
840
869
|
assert(k % QK_K == 0);
|
841
870
|
const int nb = k / QK_K;
|
842
871
|
|
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
847
876
|
|
848
877
|
device const uint8_t * q = x[i].qs;
|
849
878
|
|
879
|
+
#if QK_K == 256
|
850
880
|
int is = 0;
|
851
881
|
float dl, ml;
|
852
882
|
for (int n = 0; n < QK_K; n += 128) {
|
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
865
895
|
}
|
866
896
|
q += 32;
|
867
897
|
}
|
898
|
+
#else
|
899
|
+
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
|
900
|
+
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
|
901
|
+
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
|
902
|
+
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
|
903
|
+
for (int l = 0; l < 16; ++l) {
|
904
|
+
y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
|
905
|
+
y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
|
906
|
+
y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
|
907
|
+
y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
|
908
|
+
}
|
909
|
+
y += QK_K;
|
910
|
+
#endif
|
868
911
|
|
869
912
|
}
|
870
913
|
}
|
871
914
|
|
872
|
-
static void
|
915
|
+
static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
|
873
916
|
assert(k % QK_K == 0);
|
874
917
|
const int nb = k / QK_K;
|
875
918
|
|
919
|
+
#if QK_K == 256
|
920
|
+
|
876
921
|
const uint16_t kmask1 = 0x0303;
|
877
922
|
const uint16_t kmask2 = 0x0f0f;
|
878
923
|
|
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
|
|
918
963
|
}
|
919
964
|
q += 32;
|
920
965
|
}
|
966
|
+
}
|
967
|
+
#else
|
968
|
+
for (int i = 0; i < nb; i++) {
|
921
969
|
|
970
|
+
const float d_all = (float)(x[i].d);
|
971
|
+
|
972
|
+
device const uint8_t * q = x[i].qs;
|
973
|
+
device const uint8_t * hm = x[i].hmask;
|
974
|
+
|
975
|
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
976
|
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
977
|
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
978
|
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
979
|
+
|
980
|
+
for (int l = 0; l < 8; ++l) {
|
981
|
+
uint8_t h = hm[l];
|
982
|
+
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
|
983
|
+
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
|
984
|
+
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
|
985
|
+
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
|
986
|
+
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
|
987
|
+
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
|
988
|
+
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
|
989
|
+
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
|
990
|
+
}
|
991
|
+
y += QK_K;
|
922
992
|
}
|
993
|
+
#endif
|
923
994
|
|
924
995
|
}
|
925
996
|
|
926
|
-
static void
|
997
|
+
static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
|
927
998
|
assert(k % QK_K == 0);
|
928
999
|
const int nb = k / QK_K;
|
929
1000
|
|
930
|
-
|
931
1001
|
for (int i = 0; i < nb; i++) {
|
932
1002
|
|
1003
|
+
device const uint8_t * q = x[i].qs;
|
1004
|
+
|
1005
|
+
#if QK_K == 256
|
933
1006
|
const float d = x[i].d;
|
934
1007
|
const float min = x[i].dmin;
|
935
1008
|
|
936
|
-
device const uint8_t * q = x[i].qs;
|
937
1009
|
device const uint8_t * scales = x[i].scales;
|
938
1010
|
|
939
1011
|
int is = 0;
|
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
|
|
945
1017
|
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
946
1018
|
q += 32; is += 2;
|
947
1019
|
}
|
1020
|
+
#else
|
1021
|
+
device const uint8_t * s = x[i].scales;
|
1022
|
+
device const half2 * dh = (device const half2 *)x[i].d;
|
1023
|
+
const float2 d = (float2)dh[0];
|
1024
|
+
const float d1 = d[0] * (s[0] & 0xF);
|
1025
|
+
const float d2 = d[0] * (s[1] & 0xF);
|
1026
|
+
const float m1 = d[1] * (s[0] >> 4);
|
1027
|
+
const float m2 = d[1] * (s[1] >> 4);
|
1028
|
+
for (int l = 0; l < 32; ++l) {
|
1029
|
+
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
|
1030
|
+
y[l+32] = d2 * (q[l] >> 4) - m2;
|
1031
|
+
}
|
1032
|
+
y += QK_K;
|
1033
|
+
#endif
|
948
1034
|
|
949
1035
|
}
|
950
1036
|
}
|
951
1037
|
|
952
|
-
static void
|
1038
|
+
static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
|
953
1039
|
assert(k % QK_K == 0);
|
954
1040
|
const int nb = k / QK_K;
|
955
1041
|
|
1042
|
+
#if QK_K == 256
|
956
1043
|
for (int i = 0; i < nb; i++) {
|
957
1044
|
|
958
1045
|
const float d = (float)(x[i].d);
|
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
|
|
973
1060
|
u1 <<= 2; u2 <<= 2;
|
974
1061
|
}
|
975
1062
|
}
|
1063
|
+
#else
|
1064
|
+
for (int i = 0; i < nb; i++) {
|
1065
|
+
|
1066
|
+
const float d = (float)x[i].d;
|
1067
|
+
|
1068
|
+
device const uint8_t * ql = x[i].qs;
|
1069
|
+
device const uint8_t * qh = x[i].qh;
|
1070
|
+
device const int8_t * sc = x[i].scales;
|
1071
|
+
|
1072
|
+
for (int l = 0; l < 8; ++l) {
|
1073
|
+
y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
|
1074
|
+
y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
|
1075
|
+
y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
|
1076
|
+
y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
|
1077
|
+
y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
|
1078
|
+
y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
|
1079
|
+
y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
|
1080
|
+
y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
|
1081
|
+
}
|
1082
|
+
y += QK_K;
|
1083
|
+
}
|
1084
|
+
#endif
|
976
1085
|
|
977
1086
|
}
|
978
1087
|
|
979
|
-
static void
|
1088
|
+
static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
|
980
1089
|
assert(k % QK_K == 0);
|
981
1090
|
const int nb = k / QK_K;
|
982
1091
|
|
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
988
1097
|
|
989
1098
|
const float d = x[i].d;
|
990
1099
|
|
1100
|
+
#if QK_K == 256
|
991
1101
|
for (int n = 0; n < QK_K; n += 128) {
|
992
1102
|
for (int l = 0; l < 32; ++l) {
|
993
1103
|
int is = l/16;
|
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
1005
1115
|
qh += 32;
|
1006
1116
|
sc += 8;
|
1007
1117
|
}
|
1118
|
+
#else
|
1119
|
+
for (int l = 0; l < 16; ++l) {
|
1120
|
+
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
1121
|
+
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
1122
|
+
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
1123
|
+
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
1124
|
+
y[l+ 0] = d * sc[0] * q1;
|
1125
|
+
y[l+16] = d * sc[1] * q2;
|
1126
|
+
y[l+32] = d * sc[2] * q3;
|
1127
|
+
y[l+48] = d * sc[3] * q4;
|
1128
|
+
}
|
1129
|
+
y += 64;
|
1130
|
+
#endif
|
1008
1131
|
}
|
1009
1132
|
}
|
1010
1133
|
|
1011
|
-
kernel void
|
1134
|
+
kernel void kernel_get_rows_q2_K(
|
1012
1135
|
device const void * src0,
|
1013
1136
|
device const int * src1,
|
1014
1137
|
device float * dst,
|
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
|
|
1019
1142
|
const int i = tpig;
|
1020
1143
|
const int r = ((device int32_t *) src1)[i];
|
1021
1144
|
|
1022
|
-
|
1023
|
-
(device const
|
1145
|
+
dequantize_row_q2_K(
|
1146
|
+
(device const block_q2_K *) ((device char *) src0 + r*nb01),
|
1024
1147
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1025
1148
|
}
|
1026
1149
|
|
1027
|
-
kernel void
|
1150
|
+
kernel void kernel_get_rows_q3_K(
|
1028
1151
|
device const void * src0,
|
1029
1152
|
device const int * src1,
|
1030
1153
|
device float * dst,
|
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
|
|
1035
1158
|
const int i = tpig;
|
1036
1159
|
const int r = ((device int32_t *) src1)[i];
|
1037
1160
|
|
1038
|
-
|
1039
|
-
(device const
|
1161
|
+
dequantize_row_q3_K(
|
1162
|
+
(device const block_q3_K *) ((device char *) src0 + r*nb01),
|
1040
1163
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1041
1164
|
}
|
1042
1165
|
|
1043
|
-
kernel void
|
1166
|
+
kernel void kernel_get_rows_q4_K(
|
1044
1167
|
device const void * src0,
|
1045
1168
|
device const int * src1,
|
1046
1169
|
device float * dst,
|
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
|
|
1051
1174
|
const int i = tpig;
|
1052
1175
|
const int r = ((device int32_t *) src1)[i];
|
1053
1176
|
|
1054
|
-
|
1055
|
-
(device const
|
1177
|
+
dequantize_row_q4_K(
|
1178
|
+
(device const block_q4_K *) ((device char *) src0 + r*nb01),
|
1056
1179
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1057
1180
|
}
|
1058
1181
|
|
1059
|
-
kernel void
|
1182
|
+
kernel void kernel_get_rows_q5_K(
|
1060
1183
|
device const void * src0,
|
1061
1184
|
device const int * src1,
|
1062
1185
|
device float * dst,
|
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
|
|
1067
1190
|
const int i = tpig;
|
1068
1191
|
const int r = ((device int32_t *) src1)[i];
|
1069
1192
|
|
1070
|
-
|
1071
|
-
(device const
|
1193
|
+
dequantize_row_q5_K(
|
1194
|
+
(device const block_q5_K *) ((device char *) src0 + r*nb01),
|
1072
1195
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1073
1196
|
}
|
1074
1197
|
|
1075
|
-
kernel void
|
1198
|
+
kernel void kernel_get_rows_q6_K(
|
1076
1199
|
device const void * src0,
|
1077
1200
|
device const int * src1,
|
1078
1201
|
device float * dst,
|
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
|
|
1083
1206
|
const int i = tpig;
|
1084
1207
|
const int r = ((device int32_t *) src1)[i];
|
1085
1208
|
|
1086
|
-
|
1087
|
-
(device const
|
1209
|
+
dequantize_row_q6_K(
|
1210
|
+
(device const block_q6_K *) ((device char *) src0 + r*nb01),
|
1088
1211
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
1089
1212
|
}
|
1090
1213
|
|
1091
1214
|
//====================================== dot products =========================
|
1092
1215
|
|
1093
|
-
kernel void
|
1216
|
+
kernel void kernel_mul_mat_q2_K_f32(
|
1094
1217
|
device const void * src0,
|
1095
1218
|
device const float * src1,
|
1096
1219
|
device float * dst,
|
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1107
1230
|
const int64_t r0 = tgpig.x;
|
1108
1231
|
const int64_t r1 = tgpig.y;
|
1109
1232
|
|
1110
|
-
device const
|
1233
|
+
device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
|
1111
1234
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1112
1235
|
|
1113
1236
|
const int nth = tptg.x*tptg.y;
|
1114
1237
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1115
1238
|
|
1239
|
+
float sumf = 0;
|
1240
|
+
|
1241
|
+
#if QK_K == 256
|
1116
1242
|
const int tid = tpitg.y; // 0...16
|
1117
1243
|
const int il = tid/4; // 0...3
|
1118
1244
|
const int ir = tid%4; // 0...3
|
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1125
1251
|
const int y_offset = 64*il + n*ir;
|
1126
1252
|
const int q_offset = 32*ip + n*ir;
|
1127
1253
|
|
1128
|
-
sum[ith] = 0.0f;
|
1129
|
-
|
1130
|
-
float sumf = 0;
|
1131
1254
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1132
1255
|
|
1133
1256
|
device const uint8_t * q = x[i].qs + q_offset;
|
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1140
1263
|
|
1141
1264
|
device const float * y = yy + i*QK_K + y_offset;
|
1142
1265
|
|
1143
|
-
//float4 s = {0.f, 0.f, 0.f, 0.f};
|
1144
1266
|
float2 s = {0.f, 0.f};
|
1145
1267
|
float smin = 0;
|
1146
1268
|
for (int l = 0; l < n; ++l) {
|
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1155
1277
|
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
1156
1278
|
|
1157
1279
|
}
|
1158
|
-
|
1280
|
+
#else
|
1281
|
+
const int il = 4 * tpitg.x;
|
1159
1282
|
|
1160
|
-
|
1161
|
-
|
1283
|
+
uint32_t aux[2];
|
1284
|
+
thread const uint8_t * d = (thread const uint8_t *)aux;
|
1285
|
+
thread const uint8_t * m = (thread const uint8_t *)aux + 4;
|
1162
1286
|
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1287
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1288
|
+
|
1289
|
+
device const uint8_t * q = x[i].qs + il;
|
1290
|
+
device const float * y = yy + i*QK_K + il;
|
1291
|
+
|
1292
|
+
const float dall = (float)x[i].d;
|
1293
|
+
const float dmin = (float)x[i].dmin;
|
1294
|
+
|
1295
|
+
device const uint32_t * a = (device const uint32_t *)x[i].scales;
|
1296
|
+
aux[0] = a[0] & 0x0f0f0f0f;
|
1297
|
+
aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
|
1298
|
+
|
1299
|
+
for (int l = 0; l < 4; ++l) {
|
1300
|
+
sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
|
1301
|
+
+ y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
|
1302
|
+
+ y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
|
1303
|
+
+ y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
|
1304
|
+
}
|
1305
|
+
}
|
1306
|
+
#endif
|
1307
|
+
|
1308
|
+
sum[ith] = sumf;
|
1172
1309
|
|
1173
1310
|
//
|
1174
1311
|
// Accumulate the sum from all threads in the threadgroup
|
1175
|
-
// This version is slightly faster than the commented out one below,
|
1176
|
-
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
|
1177
1312
|
//
|
1178
1313
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1179
1314
|
if (ith%4 == 0) {
|
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
1190
1325
|
}
|
1191
1326
|
}
|
1192
1327
|
|
1193
|
-
kernel void
|
1328
|
+
kernel void kernel_mul_mat_q3_K_f32(
|
1194
1329
|
device const void * src0,
|
1195
1330
|
device const float * src1,
|
1196
1331
|
device float * dst,
|
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1203
1338
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1204
1339
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1205
1340
|
|
1206
|
-
const uint16_t kmask1 = 0x0303;
|
1207
|
-
const uint16_t kmask2 = 0x0f0f;
|
1208
|
-
|
1209
|
-
const uint8_t m3 = 3;
|
1210
|
-
const int8_t m4 = 4;
|
1211
|
-
|
1212
1341
|
const int nb = ne00/QK_K;
|
1213
1342
|
|
1214
1343
|
const int64_t r0 = tgpig.x;
|
1215
1344
|
const int64_t r1 = tgpig.y;
|
1216
1345
|
|
1217
|
-
device const
|
1346
|
+
device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
|
1218
1347
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1219
1348
|
|
1220
1349
|
const int nth = tptg.x*tptg.y;
|
1221
1350
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1222
1351
|
|
1352
|
+
#if QK_K == 256
|
1353
|
+
|
1354
|
+
const uint8_t m3 = 3;
|
1355
|
+
const int8_t m4 = 4;
|
1356
|
+
|
1357
|
+
const uint16_t kmask1 = 0x0303;
|
1358
|
+
const uint16_t kmask2 = 0x0f0f;
|
1359
|
+
|
1223
1360
|
const int tid = tpitg.y; // expecting 16
|
1224
1361
|
const int ip = tid/8; // 0 or 1
|
1225
1362
|
const int il = tid/2 - 4*ip; // 0...3
|
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1273
1410
|
|
1274
1411
|
//sum[ith] = sumf;
|
1275
1412
|
sum[ith] = sumf1 - 32.f*sumf2;
|
1413
|
+
#else
|
1414
|
+
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1415
|
+
const int im = il/8; // 0, 0, 1, 1
|
1416
|
+
const int in = il%8; // 0, 4, 0, 4
|
1417
|
+
|
1418
|
+
float sumf = 0;
|
1419
|
+
|
1420
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1421
|
+
|
1422
|
+
const float d_all = (float)(x[i].d);
|
1423
|
+
|
1424
|
+
device const uint8_t * q = x[i].qs + il;
|
1425
|
+
device const uint8_t * h = x[i].hmask + in;
|
1426
|
+
device const float * y = yy + i * QK_K + il;
|
1427
|
+
|
1428
|
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
1429
|
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
1430
|
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
1431
|
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
1432
|
+
|
1433
|
+
for (int l = 0; l < 4; ++l) {
|
1434
|
+
const uint8_t hm = h[l] >> im;
|
1435
|
+
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
|
1436
|
+
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
|
1437
|
+
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
|
1438
|
+
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
|
1439
|
+
}
|
1440
|
+
|
1441
|
+
}
|
1442
|
+
|
1443
|
+
sum[ith] = sumf;
|
1444
|
+
|
1445
|
+
#endif
|
1276
1446
|
|
1277
1447
|
//
|
1278
1448
|
// Accumulate the sum from all threads in the threadgroup
|
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
1293
1463
|
|
1294
1464
|
}
|
1295
1465
|
|
1296
|
-
kernel void
|
1466
|
+
kernel void kernel_mul_mat_q4_K_f32(
|
1297
1467
|
device const void * src0,
|
1298
1468
|
device const float * src1,
|
1299
1469
|
device float * dst,
|
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1305
1475
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1306
1476
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1307
1477
|
|
1308
|
-
const uint16_t kmask1 = 0x3f3f;
|
1309
|
-
const uint16_t kmask2 = 0x0f0f;
|
1310
|
-
const uint16_t kmask3 = 0xc0c0;
|
1311
|
-
|
1312
1478
|
const int nb = ne00/QK_K;
|
1313
1479
|
|
1314
1480
|
const int64_t r0 = tgpig.x;
|
1315
1481
|
const int64_t r1 = tgpig.y;
|
1316
1482
|
|
1317
|
-
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
|
1318
|
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
1319
|
-
|
1320
1483
|
const int nth = tptg.x*tptg.y;
|
1321
1484
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1322
1485
|
|
1486
|
+
device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
|
1487
|
+
device const float * yy = (device const float *) src1 + r1*ne10;
|
1488
|
+
|
1489
|
+
float sumf = 0;
|
1490
|
+
|
1491
|
+
#if QK_K == 256
|
1492
|
+
|
1493
|
+
const uint16_t kmask1 = 0x3f3f;
|
1494
|
+
const uint16_t kmask2 = 0x0f0f;
|
1495
|
+
const uint16_t kmask3 = 0xc0c0;
|
1496
|
+
|
1323
1497
|
const int tid = tpitg.y; // 0...16
|
1324
1498
|
const int il = tid/4; // 0...3
|
1325
1499
|
const int ir = tid - 4*il;// 0...3
|
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1332
1506
|
const int q_offset = 32*im + l0;
|
1333
1507
|
const int y_offset = 64*im + l0;
|
1334
1508
|
|
1335
|
-
sum[ith] = 0.0f;
|
1336
|
-
|
1337
1509
|
uchar2 sc1, sc2, sc3, sc4;
|
1338
1510
|
|
1339
|
-
float sumf = 0;
|
1340
1511
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1341
1512
|
|
1342
1513
|
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1365
1536
|
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1366
1537
|
|
1367
1538
|
}
|
1539
|
+
#else
|
1540
|
+
uint16_t aux16[2];
|
1541
|
+
thread const uint8_t * scales = (thread const uint8_t *)aux16;
|
1542
|
+
|
1543
|
+
const int il = 4*tpitg.x;
|
1544
|
+
|
1545
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1546
|
+
|
1547
|
+
device const uint8_t * q = x[i].qs + il;
|
1548
|
+
device const float * y = yy + i * QK_K + il;
|
1549
|
+
|
1550
|
+
const float d = (float)x[i].d[0];
|
1551
|
+
const float m = (float)x[i].d[1];
|
1552
|
+
|
1553
|
+
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
1554
|
+
aux16[0] = a[0] & 0x0f0f;
|
1555
|
+
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
1556
|
+
|
1557
|
+
for (int l = 0; l < 4; ++l) {
|
1558
|
+
sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
|
1559
|
+
+ d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
|
1560
|
+
}
|
1561
|
+
}
|
1562
|
+
#endif
|
1368
1563
|
|
1369
1564
|
sum[ith] = sumf;
|
1370
1565
|
|
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
1401
1596
|
//}
|
1402
1597
|
}
|
1403
1598
|
|
1404
|
-
kernel void
|
1599
|
+
kernel void kernel_mul_mat_q5_K_f32(
|
1405
1600
|
device const void * src0,
|
1406
1601
|
device const float * src1,
|
1407
1602
|
device float * dst,
|
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1413
1608
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
1414
1609
|
uint2 tptg[[threads_per_threadgroup]]) {
|
1415
1610
|
|
1416
|
-
const uint16_t kmask1 = 0x3f3f;
|
1417
|
-
const uint16_t kmask2 = 0x0f0f;
|
1418
|
-
const uint16_t kmask3 = 0xc0c0;
|
1419
|
-
|
1420
1611
|
const int nb = ne00/QK_K;
|
1421
1612
|
|
1422
1613
|
const int64_t r0 = tgpig.x;
|
1423
1614
|
const int64_t r1 = tgpig.y;
|
1424
1615
|
|
1425
|
-
device const
|
1616
|
+
device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
|
1426
1617
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1427
1618
|
|
1428
1619
|
const int nth = tptg.x*tptg.y;
|
1429
1620
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1430
1621
|
|
1622
|
+
float sumf = 0;
|
1623
|
+
|
1624
|
+
#if QK_K == 256
|
1625
|
+
|
1626
|
+
const uint16_t kmask1 = 0x3f3f;
|
1627
|
+
const uint16_t kmask2 = 0x0f0f;
|
1628
|
+
const uint16_t kmask3 = 0xc0c0;
|
1629
|
+
|
1431
1630
|
const int tid = tpitg.y; // 0...16
|
1432
1631
|
const int il = tid/4; // 0...3
|
1433
1632
|
const int ir = tid - 4*il;// 0...3
|
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1447
1646
|
|
1448
1647
|
uchar2 sc1, sc2, sc3, sc4;
|
1449
1648
|
|
1450
|
-
float sumf = 0;
|
1451
1649
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1452
1650
|
|
1453
1651
|
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1479
1677
|
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
1480
1678
|
|
1481
1679
|
}
|
1680
|
+
#else
|
1681
|
+
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
1682
|
+
const int im = il/8; // 0, 0, 1, 1
|
1683
|
+
const int in = il%8; // 0, 4, 0, 4
|
1684
|
+
|
1685
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1686
|
+
|
1687
|
+
const float d = (float)x[i].d;
|
1688
|
+
device const uint8_t * q = x[i].qs + il;
|
1689
|
+
device const uint8_t * h = x[i].qh + in;
|
1690
|
+
device const int8_t * s = x[i].scales;
|
1691
|
+
device const float * y = yy + i*QK_K + il;
|
1692
|
+
|
1693
|
+
for (int l = 0; l < 4; ++l) {
|
1694
|
+
const uint8_t hl = h[l] >> im;
|
1695
|
+
sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
|
1696
|
+
+ y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
|
1697
|
+
+ y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
|
1698
|
+
+ y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
|
1699
|
+
}
|
1700
|
+
}
|
1701
|
+
#endif
|
1482
1702
|
sum[ith] = sumf;
|
1483
1703
|
|
1484
1704
|
//
|
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
1500
1720
|
|
1501
1721
|
}
|
1502
1722
|
|
1503
|
-
kernel void
|
1723
|
+
kernel void kernel_mul_mat_q6_K_f32(
|
1504
1724
|
device const void * src0,
|
1505
1725
|
device const float * src1,
|
1506
1726
|
device float * dst,
|
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1522
1742
|
const int64_t r0 = tgpig.x;
|
1523
1743
|
const int64_t r1 = tgpig.y;
|
1524
1744
|
|
1525
|
-
device const
|
1745
|
+
device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
|
1526
1746
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
1527
1747
|
|
1528
1748
|
const int nth = tptg.x*tptg.y;
|
1529
1749
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
1530
1750
|
|
1751
|
+
float sumf = 0;
|
1752
|
+
|
1753
|
+
#if QK_K == 256
|
1531
1754
|
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
|
1532
1755
|
const int iqs = 16 * tpitg.y;
|
1533
1756
|
const int ip = iqs / 128; // 0 or 1
|
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1540
1763
|
const int q_offset_l = 64*ip + l0;
|
1541
1764
|
const int q_offset_h = 32*ip + l0;
|
1542
1765
|
|
1543
|
-
float sumf = 0;
|
1544
1766
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
1545
1767
|
|
1546
1768
|
device const uint8_t * ql = x[i].ql + q_offset_l;
|
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
1562
1784
|
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
1563
1785
|
|
1564
1786
|
}
|
1787
|
+
#else
|
1788
|
+
const int il = 4*tpitg.x; // 0, 4, 8, 12
|
1789
|
+
|
1790
|
+
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
1791
|
+
device const float * y = yy + i * QK_K + il;
|
1792
|
+
device const uint8_t * ql = x[i].ql + il;
|
1793
|
+
device const uint8_t * qh = x[i].qh + il;
|
1794
|
+
device const int8_t * s = x[i].scales;
|
1795
|
+
|
1796
|
+
const float d = x[i].d;
|
1797
|
+
|
1798
|
+
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
1799
|
+
for (int l = 0; l < 4; ++l) {
|
1800
|
+
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
1801
|
+
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
1802
|
+
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
|
1803
|
+
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
1804
|
+
}
|
1805
|
+
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
1806
|
+
}
|
1807
|
+
|
1808
|
+
#endif
|
1565
1809
|
|
1566
1810
|
sum[ith] = sumf;
|
1567
1811
|
|