llama_cpp 0.10.3 → 0.11.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/LICENSE.txt +1 -1
- data/ext/llama_cpp/extconf.rb +35 -110
- data/ext/llama_cpp/llama_cpp.cpp +52 -28
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -1
- data/vendor/include/.gitkeep +0 -0
- data/vendor/lib/.gitkeep +0 -0
- data/vendor/tmp/llama.cpp/Makefile +758 -0
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.c +6 -2
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.cu +73 -63
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-impl.h +1 -0
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.m +43 -20
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.metal +464 -245
- data/vendor/tmp/llama.cpp/ggml-opencl.h +25 -0
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.c +61 -57
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.c +171 -5
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml.h +1 -0
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.cpp +222 -105
- data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/llama.h +31 -32
- data/vendor/tmp/llama.cpp/scripts/get-flags.mk +38 -0
- metadata +30 -27
- data/ext/llama_cpp/src/ggml-opencl.h +0 -25
- data/ext/llama_cpp/src/llama-util.h +0 -546
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/LICENSE +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.c +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-alloc.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend-impl.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-backend.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-cuda.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-metal.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.c +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-mpi.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-opencl.cpp +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/ggml-quants.h +0 -0
- /data/{ext/llama_cpp/src → vendor/tmp/llama.cpp}/unicode.h +0 -0
@@ -59,26 +59,26 @@ kernel void kernel_add(
|
|
59
59
|
constant int64_t & ne01,
|
60
60
|
constant int64_t & ne02,
|
61
61
|
constant int64_t & ne03,
|
62
|
-
constant
|
63
|
-
constant
|
64
|
-
constant
|
65
|
-
constant
|
62
|
+
constant uint64_t & nb00,
|
63
|
+
constant uint64_t & nb01,
|
64
|
+
constant uint64_t & nb02,
|
65
|
+
constant uint64_t & nb03,
|
66
66
|
constant int64_t & ne10,
|
67
67
|
constant int64_t & ne11,
|
68
68
|
constant int64_t & ne12,
|
69
69
|
constant int64_t & ne13,
|
70
|
-
constant
|
71
|
-
constant
|
72
|
-
constant
|
73
|
-
constant
|
70
|
+
constant uint64_t & nb10,
|
71
|
+
constant uint64_t & nb11,
|
72
|
+
constant uint64_t & nb12,
|
73
|
+
constant uint64_t & nb13,
|
74
74
|
constant int64_t & ne0,
|
75
75
|
constant int64_t & ne1,
|
76
76
|
constant int64_t & ne2,
|
77
77
|
constant int64_t & ne3,
|
78
|
-
constant
|
79
|
-
constant
|
80
|
-
constant
|
81
|
-
constant
|
78
|
+
constant uint64_t & nb0,
|
79
|
+
constant uint64_t & nb1,
|
80
|
+
constant uint64_t & nb2,
|
81
|
+
constant uint64_t & nb3,
|
82
82
|
constant int64_t & offs,
|
83
83
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
84
84
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -109,26 +109,26 @@ kernel void kernel_mul(
|
|
109
109
|
constant int64_t & ne01,
|
110
110
|
constant int64_t & ne02,
|
111
111
|
constant int64_t & ne03,
|
112
|
-
constant
|
113
|
-
constant
|
114
|
-
constant
|
115
|
-
constant
|
112
|
+
constant uint64_t & nb00,
|
113
|
+
constant uint64_t & nb01,
|
114
|
+
constant uint64_t & nb02,
|
115
|
+
constant uint64_t & nb03,
|
116
116
|
constant int64_t & ne10,
|
117
117
|
constant int64_t & ne11,
|
118
118
|
constant int64_t & ne12,
|
119
119
|
constant int64_t & ne13,
|
120
|
-
constant
|
121
|
-
constant
|
122
|
-
constant
|
123
|
-
constant
|
120
|
+
constant uint64_t & nb10,
|
121
|
+
constant uint64_t & nb11,
|
122
|
+
constant uint64_t & nb12,
|
123
|
+
constant uint64_t & nb13,
|
124
124
|
constant int64_t & ne0,
|
125
125
|
constant int64_t & ne1,
|
126
126
|
constant int64_t & ne2,
|
127
127
|
constant int64_t & ne3,
|
128
|
-
constant
|
129
|
-
constant
|
130
|
-
constant
|
131
|
-
constant
|
128
|
+
constant uint64_t & nb0,
|
129
|
+
constant uint64_t & nb1,
|
130
|
+
constant uint64_t & nb2,
|
131
|
+
constant uint64_t & nb3,
|
132
132
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
133
133
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
134
134
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -158,26 +158,26 @@ kernel void kernel_div(
|
|
158
158
|
constant int64_t & ne01,
|
159
159
|
constant int64_t & ne02,
|
160
160
|
constant int64_t & ne03,
|
161
|
-
constant
|
162
|
-
constant
|
163
|
-
constant
|
164
|
-
constant
|
161
|
+
constant uint64_t & nb00,
|
162
|
+
constant uint64_t & nb01,
|
163
|
+
constant uint64_t & nb02,
|
164
|
+
constant uint64_t & nb03,
|
165
165
|
constant int64_t & ne10,
|
166
166
|
constant int64_t & ne11,
|
167
167
|
constant int64_t & ne12,
|
168
168
|
constant int64_t & ne13,
|
169
|
-
constant
|
170
|
-
constant
|
171
|
-
constant
|
172
|
-
constant
|
169
|
+
constant uint64_t & nb10,
|
170
|
+
constant uint64_t & nb11,
|
171
|
+
constant uint64_t & nb12,
|
172
|
+
constant uint64_t & nb13,
|
173
173
|
constant int64_t & ne0,
|
174
174
|
constant int64_t & ne1,
|
175
175
|
constant int64_t & ne2,
|
176
176
|
constant int64_t & ne3,
|
177
|
-
constant
|
178
|
-
constant
|
179
|
-
constant
|
180
|
-
constant
|
177
|
+
constant uint64_t & nb0,
|
178
|
+
constant uint64_t & nb1,
|
179
|
+
constant uint64_t & nb2,
|
180
|
+
constant uint64_t & nb3,
|
181
181
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
182
182
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
183
183
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
|
|
205
205
|
device const float4 * src0,
|
206
206
|
device const float4 * src1,
|
207
207
|
device float4 * dst,
|
208
|
-
constant
|
208
|
+
constant uint64_t & nb [[buffer(28)]],
|
209
209
|
uint tpig[[thread_position_in_grid]]) {
|
210
210
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
211
211
|
}
|
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
|
|
214
214
|
device const float4 * src0,
|
215
215
|
device const float4 * src1,
|
216
216
|
device float4 * dst,
|
217
|
-
constant
|
217
|
+
constant uint64_t & nb [[buffer(28)]],
|
218
218
|
uint tpig[[thread_position_in_grid]]) {
|
219
219
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
220
220
|
}
|
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
|
|
223
223
|
device const float4 * src0,
|
224
224
|
device const float4 * src1,
|
225
225
|
device float4 * dst,
|
226
|
-
constant
|
226
|
+
constant uint64_t & nb [[buffer(28)]],
|
227
227
|
uint tpig[[thread_position_in_grid]]) {
|
228
228
|
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
229
229
|
}
|
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
|
|
307
307
|
constant int64_t & ne01,
|
308
308
|
constant int64_t & ne02,
|
309
309
|
constant int64_t & ne03,
|
310
|
-
constant
|
311
|
-
constant
|
312
|
-
constant
|
313
|
-
constant
|
310
|
+
constant uint64_t & nb00,
|
311
|
+
constant uint64_t & nb01,
|
312
|
+
constant uint64_t & nb02,
|
313
|
+
constant uint64_t & nb03,
|
314
314
|
constant int64_t & ne10,
|
315
315
|
constant int64_t & ne11,
|
316
316
|
constant int64_t & ne12,
|
317
317
|
constant int64_t & ne13,
|
318
|
-
constant
|
319
|
-
constant
|
320
|
-
constant
|
321
|
-
constant
|
318
|
+
constant uint64_t & nb10,
|
319
|
+
constant uint64_t & nb11,
|
320
|
+
constant uint64_t & nb12,
|
321
|
+
constant uint64_t & nb13,
|
322
322
|
constant int64_t & ne0,
|
323
323
|
constant int64_t & ne1,
|
324
324
|
constant int64_t & ne2,
|
325
325
|
constant int64_t & ne3,
|
326
|
-
constant
|
327
|
-
constant
|
328
|
-
constant
|
329
|
-
constant
|
326
|
+
constant uint64_t & nb0,
|
327
|
+
constant uint64_t & nb1,
|
328
|
+
constant uint64_t & nb2,
|
329
|
+
constant uint64_t & nb3,
|
330
330
|
uint3 tpig[[thread_position_in_grid]]) {
|
331
331
|
int64_t i3 = tpig.z;
|
332
332
|
int64_t i2 = tpig.y;
|
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
846
846
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
847
847
|
//Note: This is a template, but strictly speaking it only applies to
|
848
848
|
// quantizations where the block size is 32. It also does not
|
849
|
-
//
|
849
|
+
// guard against the number of rows not being divisible by
|
850
850
|
// N_DST, so this is another explicit assumption of the implementation.
|
851
851
|
template<typename block_q_type, int nr, int nsg, int nw>
|
852
852
|
void mul_vec_q_n_f32_impl(
|
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
920
920
|
device const float * src1,
|
921
921
|
device float * dst,
|
922
922
|
constant int64_t & ne00,
|
923
|
-
constant int64_t & ne01
|
924
|
-
constant int64_t & ne02
|
925
|
-
constant
|
926
|
-
constant
|
927
|
-
constant
|
928
|
-
constant int64_t &
|
929
|
-
constant
|
930
|
-
constant
|
923
|
+
constant int64_t & ne01,
|
924
|
+
constant int64_t & ne02,
|
925
|
+
constant uint64_t & nb00,
|
926
|
+
constant uint64_t & nb01,
|
927
|
+
constant uint64_t & nb02,
|
928
|
+
constant int64_t & ne10,
|
929
|
+
constant int64_t & ne11,
|
930
|
+
constant int64_t & ne12,
|
931
|
+
constant uint64_t & nb10,
|
932
|
+
constant uint64_t & nb11,
|
933
|
+
constant uint64_t & nb12,
|
934
|
+
constant int64_t & ne0,
|
935
|
+
constant int64_t & ne1,
|
936
|
+
constant uint & r2,
|
937
|
+
constant uint & r3,
|
931
938
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
932
939
|
uint tiisg[[thread_index_in_simdgroup]],
|
933
940
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
939
946
|
device const float * src1,
|
940
947
|
device float * dst,
|
941
948
|
constant int64_t & ne00,
|
942
|
-
constant int64_t & ne01
|
943
|
-
constant int64_t & ne02
|
944
|
-
constant
|
945
|
-
constant
|
946
|
-
constant
|
947
|
-
constant int64_t &
|
948
|
-
constant
|
949
|
-
constant
|
949
|
+
constant int64_t & ne01,
|
950
|
+
constant int64_t & ne02,
|
951
|
+
constant uint64_t & nb00,
|
952
|
+
constant uint64_t & nb01,
|
953
|
+
constant uint64_t & nb02,
|
954
|
+
constant int64_t & ne10,
|
955
|
+
constant int64_t & ne11,
|
956
|
+
constant int64_t & ne12,
|
957
|
+
constant uint64_t & nb10,
|
958
|
+
constant uint64_t & nb11,
|
959
|
+
constant uint64_t & nb12,
|
960
|
+
constant int64_t & ne0,
|
961
|
+
constant int64_t & ne1,
|
962
|
+
constant uint & r2,
|
963
|
+
constant uint & r3,
|
950
964
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
951
965
|
uint tiisg[[thread_index_in_simdgroup]],
|
952
966
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
958
972
|
device const float * src1,
|
959
973
|
device float * dst,
|
960
974
|
constant int64_t & ne00,
|
961
|
-
constant int64_t & ne01
|
962
|
-
constant int64_t & ne02
|
963
|
-
constant
|
964
|
-
constant
|
965
|
-
constant
|
966
|
-
constant int64_t &
|
967
|
-
constant
|
968
|
-
constant
|
975
|
+
constant int64_t & ne01,
|
976
|
+
constant int64_t & ne02,
|
977
|
+
constant uint64_t & nb00,
|
978
|
+
constant uint64_t & nb01,
|
979
|
+
constant uint64_t & nb02,
|
980
|
+
constant int64_t & ne10,
|
981
|
+
constant int64_t & ne11,
|
982
|
+
constant int64_t & ne12,
|
983
|
+
constant uint64_t & nb10,
|
984
|
+
constant uint64_t & nb11,
|
985
|
+
constant uint64_t & nb12,
|
986
|
+
constant int64_t & ne0,
|
987
|
+
constant int64_t & ne1,
|
988
|
+
constant uint & r2,
|
989
|
+
constant uint & r3,
|
969
990
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
970
991
|
uint tiisg[[thread_index_in_simdgroup]],
|
971
992
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
977
998
|
device const float * src1,
|
978
999
|
device float * dst,
|
979
1000
|
constant int64_t & ne00,
|
980
|
-
constant int64_t & ne01
|
981
|
-
constant int64_t & ne02
|
982
|
-
constant
|
983
|
-
constant
|
984
|
-
constant
|
985
|
-
constant int64_t &
|
986
|
-
constant
|
987
|
-
constant
|
1001
|
+
constant int64_t & ne01,
|
1002
|
+
constant int64_t & ne02,
|
1003
|
+
constant uint64_t & nb00,
|
1004
|
+
constant uint64_t & nb01,
|
1005
|
+
constant uint64_t & nb02,
|
1006
|
+
constant int64_t & ne10,
|
1007
|
+
constant int64_t & ne11,
|
1008
|
+
constant int64_t & ne12,
|
1009
|
+
constant uint64_t & nb10,
|
1010
|
+
constant uint64_t & nb11,
|
1011
|
+
constant uint64_t & nb12,
|
1012
|
+
constant int64_t & ne0,
|
1013
|
+
constant int64_t & ne1,
|
1014
|
+
constant uint & r2,
|
1015
|
+
constant uint & r3,
|
988
1016
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
989
1017
|
uint tiisg[[thread_index_in_simdgroup]],
|
990
1018
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
1071
1099
|
constant int64_t & ne00,
|
1072
1100
|
constant int64_t & ne01,
|
1073
1101
|
constant int64_t & ne02,
|
1102
|
+
constant uint64_t & nb00,
|
1103
|
+
constant uint64_t & nb01,
|
1104
|
+
constant uint64_t & nb02,
|
1074
1105
|
constant int64_t & ne10,
|
1106
|
+
constant int64_t & ne11,
|
1075
1107
|
constant int64_t & ne12,
|
1108
|
+
constant uint64_t & nb10,
|
1109
|
+
constant uint64_t & nb11,
|
1110
|
+
constant uint64_t & nb12,
|
1076
1111
|
constant int64_t & ne0,
|
1077
1112
|
constant int64_t & ne1,
|
1078
|
-
constant uint & r2
|
1079
|
-
constant uint & r3
|
1113
|
+
constant uint & r2,
|
1114
|
+
constant uint & r3,
|
1080
1115
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1081
1116
|
uint tiisg[[thread_index_in_simdgroup]],
|
1082
1117
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
1182
1217
|
constant uint64_t & nb12,
|
1183
1218
|
constant int64_t & ne0,
|
1184
1219
|
constant int64_t & ne1,
|
1185
|
-
constant uint & r2
|
1186
|
-
constant uint & r3
|
1220
|
+
constant uint & r2,
|
1221
|
+
constant uint & r3,
|
1187
1222
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1188
1223
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1189
1224
|
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|
1209
1244
|
constant uint64_t & nb12,
|
1210
1245
|
constant int64_t & ne0,
|
1211
1246
|
constant int64_t & ne1,
|
1212
|
-
constant uint & r2
|
1213
|
-
constant uint & r3
|
1247
|
+
constant uint & r2,
|
1248
|
+
constant uint & r3,
|
1214
1249
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1215
1250
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1216
1251
|
|
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
1346
1381
|
constant uint64_t & nb12,
|
1347
1382
|
constant int64_t & ne0,
|
1348
1383
|
constant int64_t & ne1,
|
1349
|
-
constant uint & r2
|
1350
|
-
constant uint & r3
|
1384
|
+
constant uint & r2,
|
1385
|
+
constant uint & r3,
|
1351
1386
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1352
1387
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1353
1388
|
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
1452
1487
|
constant uint64_t & nb12,
|
1453
1488
|
constant int64_t & ne0,
|
1454
1489
|
constant int64_t & ne1,
|
1455
|
-
constant uint & r2
|
1456
|
-
constant uint & r3
|
1490
|
+
constant uint & r2,
|
1491
|
+
constant uint & r3,
|
1457
1492
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1458
1493
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1459
1494
|
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1478
1513
|
constant uint64_t & nb12,
|
1479
1514
|
constant int64_t & ne0,
|
1480
1515
|
constant int64_t & ne1,
|
1481
|
-
constant uint & r2
|
1482
|
-
constant uint & r3
|
1516
|
+
constant uint & r2,
|
1517
|
+
constant uint & r3,
|
1483
1518
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1484
1519
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1485
1520
|
|
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
|
|
1543
1578
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
1544
1579
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1545
1580
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1546
|
-
|
1581
|
+
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1582
|
+
|
1547
1583
|
const int64_t k = i3*ne3 + i2;
|
1548
1584
|
|
1549
1585
|
float m_k;
|
@@ -2410,22 +2446,6 @@ typedef struct {
|
|
2410
2446
|
} block_q6_K;
|
2411
2447
|
// 210 bytes / block
|
2412
2448
|
|
2413
|
-
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
2414
|
-
uchar4 r;
|
2415
|
-
if (j < 4) {
|
2416
|
-
r[0] = q[j+0] & 63;
|
2417
|
-
r[2] = q[j+1] & 63;
|
2418
|
-
r[1] = q[j+4] & 63;
|
2419
|
-
r[3] = q[j+5] & 63;
|
2420
|
-
} else {
|
2421
|
-
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
2422
|
-
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
|
2423
|
-
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
2424
|
-
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
|
2425
|
-
}
|
2426
|
-
return r;
|
2427
|
-
}
|
2428
|
-
|
2429
2449
|
//====================================== dot products =========================
|
2430
2450
|
|
2431
2451
|
void kernel_mul_mv_q2_K_f32_impl(
|
@@ -2584,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
2584
2604
|
device const float * src1,
|
2585
2605
|
device float * dst,
|
2586
2606
|
constant int64_t & ne00,
|
2587
|
-
constant int64_t & ne01
|
2588
|
-
constant int64_t & ne02
|
2589
|
-
constant
|
2590
|
-
constant
|
2591
|
-
constant
|
2592
|
-
constant int64_t &
|
2593
|
-
constant
|
2594
|
-
constant
|
2607
|
+
constant int64_t & ne01,
|
2608
|
+
constant int64_t & ne02,
|
2609
|
+
constant uint64_t & nb00,
|
2610
|
+
constant uint64_t & nb01,
|
2611
|
+
constant uint64_t & nb02,
|
2612
|
+
constant int64_t & ne10,
|
2613
|
+
constant int64_t & ne11,
|
2614
|
+
constant int64_t & ne12,
|
2615
|
+
constant uint64_t & nb10,
|
2616
|
+
constant uint64_t & nb11,
|
2617
|
+
constant uint64_t & nb12,
|
2618
|
+
constant int64_t & ne0,
|
2619
|
+
constant int64_t & ne1,
|
2620
|
+
constant uint & r2,
|
2621
|
+
constant uint & r3,
|
2595
2622
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2596
2623
|
uint tiisg[[thread_index_in_simdgroup]],
|
2597
2624
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2841,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
2841
2868
|
device const float * src1,
|
2842
2869
|
device float * dst,
|
2843
2870
|
constant int64_t & ne00,
|
2844
|
-
constant int64_t & ne01
|
2845
|
-
constant int64_t & ne02
|
2846
|
-
constant
|
2847
|
-
constant
|
2848
|
-
constant
|
2849
|
-
constant int64_t &
|
2850
|
-
constant
|
2851
|
-
constant
|
2871
|
+
constant int64_t & ne01,
|
2872
|
+
constant int64_t & ne02,
|
2873
|
+
constant uint64_t & nb00,
|
2874
|
+
constant uint64_t & nb01,
|
2875
|
+
constant uint64_t & nb02,
|
2876
|
+
constant int64_t & ne10,
|
2877
|
+
constant int64_t & ne11,
|
2878
|
+
constant int64_t & ne12,
|
2879
|
+
constant uint64_t & nb10,
|
2880
|
+
constant uint64_t & nb11,
|
2881
|
+
constant uint64_t & nb12,
|
2882
|
+
constant int64_t & ne0,
|
2883
|
+
constant int64_t & ne1,
|
2884
|
+
constant uint & r2,
|
2885
|
+
constant uint & r3,
|
2852
2886
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2853
2887
|
uint tiisg[[thread_index_in_simdgroup]],
|
2854
2888
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2984,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
2984
3018
|
constant uint & r2,
|
2985
3019
|
constant uint & r3,
|
2986
3020
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2987
|
-
uint
|
2988
|
-
uint
|
3021
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3022
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2989
3023
|
|
2990
3024
|
const int ix = tiisg/4; // 0...7
|
2991
3025
|
const int it = tiisg%4; // 0...3
|
@@ -2994,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
2994
3028
|
const int r0 = tgpig.x;
|
2995
3029
|
const int r1 = tgpig.y;
|
2996
3030
|
const int im = tgpig.z;
|
2997
|
-
const int first_row =
|
3031
|
+
const int first_row = r0 * N_DST;
|
2998
3032
|
const int ib_row = first_row * nb;
|
2999
3033
|
|
3000
3034
|
const uint i12 = im%ne12;
|
@@ -3060,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3060
3094
|
for (int row = 0; row < N_DST; ++row) {
|
3061
3095
|
all_sum = simd_sum(sumf[row]);
|
3062
3096
|
if (tiisg == 0) {
|
3063
|
-
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
3097
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
3064
3098
|
}
|
3065
3099
|
}
|
3066
3100
|
}
|
@@ -3072,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
3072
3106
|
device const float * src1,
|
3073
3107
|
device float * dst,
|
3074
3108
|
constant int64_t & ne00,
|
3075
|
-
constant int64_t & ne01
|
3076
|
-
constant int64_t & ne02
|
3077
|
-
constant
|
3078
|
-
constant
|
3079
|
-
constant
|
3080
|
-
constant int64_t &
|
3081
|
-
constant
|
3082
|
-
constant
|
3109
|
+
constant int64_t & ne01,
|
3110
|
+
constant int64_t & ne02,
|
3111
|
+
constant uint64_t & nb00,
|
3112
|
+
constant uint64_t & nb01,
|
3113
|
+
constant uint64_t & nb02,
|
3114
|
+
constant int64_t & ne10,
|
3115
|
+
constant int64_t & ne11,
|
3116
|
+
constant int64_t & ne12,
|
3117
|
+
constant uint64_t & nb10,
|
3118
|
+
constant uint64_t & nb11,
|
3119
|
+
constant uint64_t & nb12,
|
3120
|
+
constant int64_t & ne0,
|
3121
|
+
constant int64_t & ne1,
|
3122
|
+
constant uint & r2,
|
3123
|
+
constant uint & r3,
|
3083
3124
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3084
3125
|
uint tiisg[[thread_index_in_simdgroup]],
|
3085
3126
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3271,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
3271
3312
|
device const float * src1,
|
3272
3313
|
device float * dst,
|
3273
3314
|
constant int64_t & ne00,
|
3274
|
-
constant int64_t & ne01
|
3275
|
-
constant int64_t & ne02
|
3276
|
-
constant
|
3277
|
-
constant
|
3278
|
-
constant
|
3279
|
-
constant int64_t &
|
3280
|
-
constant
|
3281
|
-
constant
|
3315
|
+
constant int64_t & ne01,
|
3316
|
+
constant int64_t & ne02,
|
3317
|
+
constant uint64_t & nb00,
|
3318
|
+
constant uint64_t & nb01,
|
3319
|
+
constant uint64_t & nb02,
|
3320
|
+
constant int64_t & ne10,
|
3321
|
+
constant int64_t & ne11,
|
3322
|
+
constant int64_t & ne12,
|
3323
|
+
constant uint64_t & nb10,
|
3324
|
+
constant uint64_t & nb11,
|
3325
|
+
constant uint64_t & nb12,
|
3326
|
+
constant int64_t & ne0,
|
3327
|
+
constant int64_t & ne1,
|
3328
|
+
constant uint & r2,
|
3329
|
+
constant uint & r3,
|
3282
3330
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3283
3331
|
uint tiisg[[thread_index_in_simdgroup]],
|
3284
3332
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3398,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
3398
3446
|
device const float * src1,
|
3399
3447
|
device float * dst,
|
3400
3448
|
constant int64_t & ne00,
|
3401
|
-
constant int64_t & ne01
|
3402
|
-
constant int64_t & ne02
|
3403
|
-
constant
|
3404
|
-
constant
|
3405
|
-
constant
|
3406
|
-
constant int64_t &
|
3407
|
-
constant
|
3408
|
-
constant
|
3449
|
+
constant int64_t & ne01,
|
3450
|
+
constant int64_t & ne02,
|
3451
|
+
constant uint64_t & nb00,
|
3452
|
+
constant uint64_t & nb01,
|
3453
|
+
constant uint64_t & nb02,
|
3454
|
+
constant int64_t & ne10,
|
3455
|
+
constant int64_t & ne11,
|
3456
|
+
constant int64_t & ne12,
|
3457
|
+
constant uint64_t & nb10,
|
3458
|
+
constant uint64_t & nb11,
|
3459
|
+
constant uint64_t & nb12,
|
3460
|
+
constant int64_t & ne0,
|
3461
|
+
constant int64_t & ne1,
|
3462
|
+
constant uint & r2,
|
3463
|
+
constant uint & r3,
|
3409
3464
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3410
3465
|
uint tiisg[[thread_index_in_simdgroup]],
|
3411
3466
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3523,7 +3578,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
3523
3578
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
3524
3579
|
const half d = xb->d;
|
3525
3580
|
|
3526
|
-
for (int i=0;i<16;i++) {
|
3581
|
+
for (int i = 0; i < 16; i++) {
|
3527
3582
|
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
3528
3583
|
}
|
3529
3584
|
}
|
@@ -3774,6 +3829,35 @@ kernel void kernel_get_rows_f16(
|
|
3774
3829
|
}
|
3775
3830
|
}
|
3776
3831
|
|
3832
|
+
kernel void kernel_get_rows_i32(
|
3833
|
+
device const void * src0,
|
3834
|
+
device const char * src1,
|
3835
|
+
device int32_t * dst,
|
3836
|
+
constant int64_t & ne00,
|
3837
|
+
constant uint64_t & nb01,
|
3838
|
+
constant uint64_t & nb02,
|
3839
|
+
constant int64_t & ne10,
|
3840
|
+
constant uint64_t & nb10,
|
3841
|
+
constant uint64_t & nb11,
|
3842
|
+
constant uint64_t & nb1,
|
3843
|
+
constant uint64_t & nb2,
|
3844
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3845
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3846
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3847
|
+
const int64_t i10 = tgpig.x;
|
3848
|
+
const int64_t i11 = tgpig.y;
|
3849
|
+
|
3850
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3851
|
+
|
3852
|
+
const int64_t i02 = i11;
|
3853
|
+
|
3854
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3855
|
+
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3856
|
+
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
3857
|
+
}
|
3858
|
+
}
|
3859
|
+
|
3860
|
+
|
3777
3861
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
3778
3862
|
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
3779
3863
|
#define BLOCK_SIZE_K 32
|
@@ -3792,12 +3876,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
3792
3876
|
device float * dst,
|
3793
3877
|
constant int64_t & ne00,
|
3794
3878
|
constant int64_t & ne02,
|
3795
|
-
constant
|
3796
|
-
constant
|
3879
|
+
constant uint64_t & nb01,
|
3880
|
+
constant uint64_t & nb02,
|
3797
3881
|
constant int64_t & ne12,
|
3798
|
-
constant
|
3799
|
-
constant
|
3800
|
-
constant
|
3882
|
+
constant uint64_t & nb10,
|
3883
|
+
constant uint64_t & nb11,
|
3884
|
+
constant uint64_t & nb12,
|
3801
3885
|
constant int64_t & ne0,
|
3802
3886
|
constant int64_t & ne1,
|
3803
3887
|
constant uint & r2,
|
@@ -3918,18 +4002,143 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
3918
4002
|
}
|
3919
4003
|
}
|
3920
4004
|
|
4005
|
+
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
|
4006
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
4007
|
+
void kernel_mul_mm_id_impl(
|
4008
|
+
device const uchar * src0,
|
4009
|
+
device const uchar * src1,
|
4010
|
+
thread short * src1ids,
|
4011
|
+
device float * dst,
|
4012
|
+
constant int64_t & ne00,
|
4013
|
+
constant int64_t & ne02,
|
4014
|
+
constant uint64_t & nb01,
|
4015
|
+
constant uint64_t & nb02,
|
4016
|
+
constant int64_t & ne12,
|
4017
|
+
constant uint64_t & nb10,
|
4018
|
+
constant uint64_t & nb11,
|
4019
|
+
constant uint64_t & nb12,
|
4020
|
+
constant int64_t & ne0,
|
4021
|
+
int64_t ne1,
|
4022
|
+
constant uint & r2,
|
4023
|
+
constant uint & r3,
|
4024
|
+
threadgroup uchar * shared_memory,
|
4025
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4026
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4027
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4028
|
+
|
4029
|
+
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
4030
|
+
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
4031
|
+
|
4032
|
+
const uint r0 = tgpig.y;
|
4033
|
+
const uint r1 = tgpig.x;
|
4034
|
+
const uint im = tgpig.z;
|
4035
|
+
|
4036
|
+
if (r1 * BLOCK_SIZE_N >= ne1) return;
|
4037
|
+
|
4038
|
+
// if this block is of 64x32 shape or smaller
|
4039
|
+
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
4040
|
+
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
4041
|
+
|
4042
|
+
// a thread shouldn't load data outside of the matrix
|
4043
|
+
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
4044
|
+
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
4045
|
+
|
4046
|
+
simdgroup_half8x8 ma[4];
|
4047
|
+
simdgroup_float8x8 mb[2];
|
4048
|
+
simdgroup_float8x8 c_res[8];
|
4049
|
+
for (int i = 0; i < 8; i++){
|
4050
|
+
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
4051
|
+
}
|
4052
|
+
|
4053
|
+
short il = (tiitg % THREAD_PER_ROW);
|
4054
|
+
|
4055
|
+
const uint i12 = im%ne12;
|
4056
|
+
const uint i13 = im/ne12;
|
4057
|
+
|
4058
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
4059
|
+
ushort offset1 = il/nl;
|
4060
|
+
|
4061
|
+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
4062
|
+
device const float * y = (device const float *)(src1
|
4063
|
+
+ nb12 * im
|
4064
|
+
+ nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
|
4065
|
+
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
4066
|
+
|
4067
|
+
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
4068
|
+
// load data and store to threadgroup memory
|
4069
|
+
half4x4 temp_a;
|
4070
|
+
dequantize_func(x, il, temp_a);
|
4071
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4072
|
+
|
4073
|
+
for (int i = 0; i < 16; i++) {
|
4074
|
+
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
4075
|
+
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
4076
|
+
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
4077
|
+
}
|
4078
|
+
|
4079
|
+
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
4080
|
+
|
4081
|
+
il = (il + 2 < nl) ? il + 2 : il % 2;
|
4082
|
+
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
4083
|
+
y += BLOCK_SIZE_K;
|
4084
|
+
|
4085
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4086
|
+
|
4087
|
+
// load matrices from threadgroup memory and conduct outer products
|
4088
|
+
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
4089
|
+
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
4090
|
+
|
4091
|
+
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
4092
|
+
for (int i = 0; i < 4; i++) {
|
4093
|
+
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
4094
|
+
}
|
4095
|
+
simdgroup_barrier(mem_flags::mem_none);
|
4096
|
+
for (int i = 0; i < 2; i++) {
|
4097
|
+
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
4098
|
+
}
|
4099
|
+
|
4100
|
+
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
4101
|
+
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
4102
|
+
|
4103
|
+
for (int i = 0; i < 8; i++){
|
4104
|
+
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
4105
|
+
}
|
4106
|
+
}
|
4107
|
+
}
|
4108
|
+
|
4109
|
+
{
|
4110
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4111
|
+
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
4112
|
+
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
4113
|
+
for (int i = 0; i < 8; i++) {
|
4114
|
+
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
4115
|
+
}
|
4116
|
+
|
4117
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4118
|
+
|
4119
|
+
device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
|
4120
|
+
if (sgitg == 0) {
|
4121
|
+
for (int i = 0; i < n_rows; i++) {
|
4122
|
+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
4123
|
+
*(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
4124
|
+
}
|
4125
|
+
}
|
4126
|
+
}
|
4127
|
+
}
|
4128
|
+
}
|
4129
|
+
|
3921
4130
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3922
4131
|
kernel void kernel_mul_mm(device const uchar * src0,
|
3923
4132
|
device const uchar * src1,
|
3924
4133
|
device float * dst,
|
3925
4134
|
constant int64_t & ne00,
|
3926
4135
|
constant int64_t & ne02,
|
3927
|
-
constant
|
3928
|
-
constant
|
4136
|
+
constant uint64_t & nb01,
|
4137
|
+
constant uint64_t & nb02,
|
3929
4138
|
constant int64_t & ne12,
|
3930
|
-
constant
|
3931
|
-
constant
|
3932
|
-
constant
|
4139
|
+
constant uint64_t & nb10,
|
4140
|
+
constant uint64_t & nb11,
|
4141
|
+
constant uint64_t & nb12,
|
3933
4142
|
constant int64_t & ne0,
|
3934
4143
|
constant int64_t & ne1,
|
3935
4144
|
constant uint & r2,
|
@@ -3964,20 +4173,20 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
|
|
3964
4173
|
kernel void kernel_mul_mm_id(
|
3965
4174
|
device const uchar * ids,
|
3966
4175
|
device const uchar * src1,
|
3967
|
-
device
|
3968
|
-
constant
|
4176
|
+
device float * dst,
|
4177
|
+
constant uint64_t & nbi1,
|
3969
4178
|
constant int64_t & ne00,
|
3970
4179
|
constant int64_t & ne02,
|
3971
|
-
constant
|
3972
|
-
constant
|
4180
|
+
constant uint64_t & nb01,
|
4181
|
+
constant uint64_t & nb02,
|
3973
4182
|
constant int64_t & ne12,
|
3974
4183
|
constant int64_t & ne13,
|
3975
|
-
constant
|
3976
|
-
constant
|
3977
|
-
constant
|
4184
|
+
constant uint64_t & nb10,
|
4185
|
+
constant uint64_t & nb11,
|
4186
|
+
constant uint64_t & nb12,
|
3978
4187
|
constant int64_t & ne0,
|
3979
4188
|
constant int64_t & ne1,
|
3980
|
-
constant
|
4189
|
+
constant uint64_t & nb1,
|
3981
4190
|
constant uint & r2,
|
3982
4191
|
constant uint & r3,
|
3983
4192
|
constant int & idx,
|
@@ -3993,18 +4202,28 @@ kernel void kernel_mul_mm_id(
|
|
3993
4202
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3994
4203
|
uint tiitg[[thread_index_in_threadgroup]],
|
3995
4204
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3996
|
-
device const uchar *
|
4205
|
+
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
3997
4206
|
|
3998
|
-
|
4207
|
+
// expert id
|
4208
|
+
const int32_t id = tgpig.z/(ne12*ne13);
|
3999
4209
|
|
4000
4210
|
tgpig.z = tgpig.z%(ne12*ne13);
|
4001
4211
|
|
4002
|
-
|
4212
|
+
// row indices of src1 for expert id
|
4213
|
+
int64_t _ne1 = 0;
|
4214
|
+
short src1ids[512];
|
4003
4215
|
|
4004
|
-
|
4005
|
-
|
4006
|
-
|
4007
|
-
|
4216
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
4217
|
+
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
4218
|
+
src1ids[_ne1++] = i1;
|
4219
|
+
}
|
4220
|
+
}
|
4221
|
+
|
4222
|
+
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
4223
|
+
src0s[id],
|
4224
|
+
src1,
|
4225
|
+
src1ids,
|
4226
|
+
dst,
|
4008
4227
|
ne00,
|
4009
4228
|
ne02,
|
4010
4229
|
nb01,
|
@@ -4014,7 +4233,7 @@ kernel void kernel_mul_mm_id(
|
|
4014
4233
|
nb11,
|
4015
4234
|
nb12,
|
4016
4235
|
ne0,
|
4017
|
-
|
4236
|
+
_ne1,
|
4018
4237
|
r2,
|
4019
4238
|
r3,
|
4020
4239
|
shared_memory,
|
@@ -4070,12 +4289,12 @@ typedef void (mat_mm_t)(
|
|
4070
4289
|
device float * dst,
|
4071
4290
|
constant int64_t & ne00,
|
4072
4291
|
constant int64_t & ne02,
|
4073
|
-
constant
|
4074
|
-
constant
|
4292
|
+
constant uint64_t & nb01,
|
4293
|
+
constant uint64_t & nb02,
|
4075
4294
|
constant int64_t & ne12,
|
4076
|
-
constant
|
4077
|
-
constant
|
4078
|
-
constant
|
4295
|
+
constant uint64_t & nb10,
|
4296
|
+
constant uint64_t & nb11,
|
4297
|
+
constant uint64_t & nb12,
|
4079
4298
|
constant int64_t & ne0,
|
4080
4299
|
constant int64_t & ne1,
|
4081
4300
|
constant uint & r2,
|
@@ -4103,20 +4322,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
4103
4322
|
typedef void (mat_mm_id_t)(
|
4104
4323
|
device const uchar * ids,
|
4105
4324
|
device const uchar * src1,
|
4106
|
-
device
|
4107
|
-
constant
|
4325
|
+
device float * dst,
|
4326
|
+
constant uint64_t & nbi1,
|
4108
4327
|
constant int64_t & ne00,
|
4109
4328
|
constant int64_t & ne02,
|
4110
|
-
constant
|
4111
|
-
constant
|
4329
|
+
constant uint64_t & nb01,
|
4330
|
+
constant uint64_t & nb02,
|
4112
4331
|
constant int64_t & ne12,
|
4113
4332
|
constant int64_t & ne13,
|
4114
|
-
constant
|
4115
|
-
constant
|
4116
|
-
constant
|
4333
|
+
constant uint64_t & nb10,
|
4334
|
+
constant uint64_t & nb11,
|
4335
|
+
constant uint64_t & nb12,
|
4117
4336
|
constant int64_t & ne0,
|
4118
4337
|
constant int64_t & ne1,
|
4119
|
-
constant
|
4338
|
+
constant uint64_t & nb1,
|
4120
4339
|
constant uint & r2,
|
4121
4340
|
constant uint & r3,
|
4122
4341
|
constant int & idx,
|
@@ -4152,8 +4371,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
4152
4371
|
kernel void kernel_mul_mv_id_f32_f32(
|
4153
4372
|
device const char * ids,
|
4154
4373
|
device const char * src1,
|
4155
|
-
device
|
4156
|
-
constant
|
4374
|
+
device float * dst,
|
4375
|
+
constant uint64_t & nbi1,
|
4157
4376
|
constant int64_t & ne00,
|
4158
4377
|
constant int64_t & ne01,
|
4159
4378
|
constant int64_t & ne02,
|
@@ -4169,7 +4388,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
4169
4388
|
constant uint64_t & nb12,
|
4170
4389
|
constant int64_t & ne0,
|
4171
4390
|
constant int64_t & ne1,
|
4172
|
-
constant
|
4391
|
+
constant uint64_t & nb1,
|
4173
4392
|
constant uint & r2,
|
4174
4393
|
constant uint & r3,
|
4175
4394
|
constant int & idx,
|
@@ -4196,7 +4415,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
4196
4415
|
kernel_mul_mv_f32_f32_impl(
|
4197
4416
|
src0[id],
|
4198
4417
|
src1 + bid*nb11,
|
4199
|
-
|
4418
|
+
dst + bid*ne0,
|
4200
4419
|
ne00,
|
4201
4420
|
ne01,
|
4202
4421
|
ne02,
|
@@ -4221,8 +4440,8 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
4221
4440
|
kernel void kernel_mul_mv_id_f16_f32(
|
4222
4441
|
device const char * ids,
|
4223
4442
|
device const char * src1,
|
4224
|
-
device
|
4225
|
-
constant
|
4443
|
+
device float * dst,
|
4444
|
+
constant uint64_t & nbi1,
|
4226
4445
|
constant int64_t & ne00,
|
4227
4446
|
constant int64_t & ne01,
|
4228
4447
|
constant int64_t & ne02,
|
@@ -4238,7 +4457,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
4238
4457
|
constant uint64_t & nb12,
|
4239
4458
|
constant int64_t & ne0,
|
4240
4459
|
constant int64_t & ne1,
|
4241
|
-
constant
|
4460
|
+
constant uint64_t & nb1,
|
4242
4461
|
constant uint & r2,
|
4243
4462
|
constant uint & r3,
|
4244
4463
|
constant int & idx,
|
@@ -4265,7 +4484,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
4265
4484
|
kernel_mul_mv_f16_f32_impl(
|
4266
4485
|
src0[id],
|
4267
4486
|
src1 + bid*nb11,
|
4268
|
-
|
4487
|
+
dst + bid*ne0,
|
4269
4488
|
ne00,
|
4270
4489
|
ne01,
|
4271
4490
|
ne02,
|
@@ -4290,8 +4509,8 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
4290
4509
|
kernel void kernel_mul_mv_id_q8_0_f32(
|
4291
4510
|
device const char * ids,
|
4292
4511
|
device const char * src1,
|
4293
|
-
device
|
4294
|
-
constant
|
4512
|
+
device float * dst,
|
4513
|
+
constant uint64_t & nbi1,
|
4295
4514
|
constant int64_t & ne00,
|
4296
4515
|
constant int64_t & ne01,
|
4297
4516
|
constant int64_t & ne02,
|
@@ -4307,7 +4526,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
4307
4526
|
constant uint64_t & nb12,
|
4308
4527
|
constant int64_t & ne0,
|
4309
4528
|
constant int64_t & ne1,
|
4310
|
-
constant
|
4529
|
+
constant uint64_t & nb1,
|
4311
4530
|
constant uint & r2,
|
4312
4531
|
constant uint & r3,
|
4313
4532
|
constant int & idx,
|
@@ -4334,7 +4553,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
4334
4553
|
kernel_mul_mv_q8_0_f32_impl(
|
4335
4554
|
src0[id],
|
4336
4555
|
(device const float *) (src1 + bid*nb11),
|
4337
|
-
|
4556
|
+
dst + bid*ne0,
|
4338
4557
|
ne00,
|
4339
4558
|
ne01,
|
4340
4559
|
ne02,
|
@@ -4353,8 +4572,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
4353
4572
|
kernel void kernel_mul_mv_id_q4_0_f32(
|
4354
4573
|
device const char * ids,
|
4355
4574
|
device const char * src1,
|
4356
|
-
device
|
4357
|
-
constant
|
4575
|
+
device float * dst,
|
4576
|
+
constant uint64_t & nbi1,
|
4358
4577
|
constant int64_t & ne00,
|
4359
4578
|
constant int64_t & ne01,
|
4360
4579
|
constant int64_t & ne02,
|
@@ -4370,7 +4589,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
4370
4589
|
constant uint64_t & nb12,
|
4371
4590
|
constant int64_t & ne0,
|
4372
4591
|
constant int64_t & ne1,
|
4373
|
-
constant
|
4592
|
+
constant uint64_t & nb1,
|
4374
4593
|
constant uint & r2,
|
4375
4594
|
constant uint & r3,
|
4376
4595
|
constant int & idx,
|
@@ -4397,7 +4616,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
4397
4616
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4398
4617
|
src0[id],
|
4399
4618
|
(device const float *) (src1 + bid*nb11),
|
4400
|
-
|
4619
|
+
dst + bid*ne0,
|
4401
4620
|
ne00,
|
4402
4621
|
ne01,
|
4403
4622
|
ne02,
|
@@ -4416,8 +4635,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
4416
4635
|
kernel void kernel_mul_mv_id_q4_1_f32(
|
4417
4636
|
device const char * ids,
|
4418
4637
|
device const char * src1,
|
4419
|
-
device
|
4420
|
-
constant
|
4638
|
+
device float * dst,
|
4639
|
+
constant uint64_t & nbi1,
|
4421
4640
|
constant int64_t & ne00,
|
4422
4641
|
constant int64_t & ne01,
|
4423
4642
|
constant int64_t & ne02,
|
@@ -4433,7 +4652,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
4433
4652
|
constant uint64_t & nb12,
|
4434
4653
|
constant int64_t & ne0,
|
4435
4654
|
constant int64_t & ne1,
|
4436
|
-
constant
|
4655
|
+
constant uint64_t & nb1,
|
4437
4656
|
constant uint & r2,
|
4438
4657
|
constant uint & r3,
|
4439
4658
|
constant int & idx,
|
@@ -4460,7 +4679,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
4460
4679
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4461
4680
|
src0[id],
|
4462
4681
|
(device const float *) (src1 + bid*nb11),
|
4463
|
-
|
4682
|
+
dst + bid*ne0,
|
4464
4683
|
ne00,
|
4465
4684
|
ne01,
|
4466
4685
|
ne02,
|
@@ -4479,8 +4698,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
4479
4698
|
kernel void kernel_mul_mv_id_q5_0_f32(
|
4480
4699
|
device const char * ids,
|
4481
4700
|
device const char * src1,
|
4482
|
-
device
|
4483
|
-
constant
|
4701
|
+
device float * dst,
|
4702
|
+
constant uint64_t & nbi1,
|
4484
4703
|
constant int64_t & ne00,
|
4485
4704
|
constant int64_t & ne01,
|
4486
4705
|
constant int64_t & ne02,
|
@@ -4496,7 +4715,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
4496
4715
|
constant uint64_t & nb12,
|
4497
4716
|
constant int64_t & ne0,
|
4498
4717
|
constant int64_t & ne1,
|
4499
|
-
constant
|
4718
|
+
constant uint64_t & nb1,
|
4500
4719
|
constant uint & r2,
|
4501
4720
|
constant uint & r3,
|
4502
4721
|
constant int & idx,
|
@@ -4523,7 +4742,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
4523
4742
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4524
4743
|
src0[id],
|
4525
4744
|
(device const float *) (src1 + bid*nb11),
|
4526
|
-
|
4745
|
+
dst + bid*ne0,
|
4527
4746
|
ne00,
|
4528
4747
|
ne01,
|
4529
4748
|
ne02,
|
@@ -4542,8 +4761,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
4542
4761
|
kernel void kernel_mul_mv_id_q5_1_f32(
|
4543
4762
|
device const char * ids,
|
4544
4763
|
device const char * src1,
|
4545
|
-
device
|
4546
|
-
constant
|
4764
|
+
device float * dst,
|
4765
|
+
constant uint64_t & nbi1,
|
4547
4766
|
constant int64_t & ne00,
|
4548
4767
|
constant int64_t & ne01,
|
4549
4768
|
constant int64_t & ne02,
|
@@ -4559,7 +4778,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
4559
4778
|
constant uint64_t & nb12,
|
4560
4779
|
constant int64_t & ne0,
|
4561
4780
|
constant int64_t & ne1,
|
4562
|
-
constant
|
4781
|
+
constant uint64_t & nb1,
|
4563
4782
|
constant uint & r2,
|
4564
4783
|
constant uint & r3,
|
4565
4784
|
constant int & idx,
|
@@ -4586,7 +4805,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
4586
4805
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4587
4806
|
src0[id],
|
4588
4807
|
(device const float *) (src1 + bid*nb11),
|
4589
|
-
|
4808
|
+
dst + bid*ne0,
|
4590
4809
|
ne00,
|
4591
4810
|
ne01,
|
4592
4811
|
ne02,
|
@@ -4605,8 +4824,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
4605
4824
|
kernel void kernel_mul_mv_id_q2_K_f32(
|
4606
4825
|
device const char * ids,
|
4607
4826
|
device const char * src1,
|
4608
|
-
device
|
4609
|
-
constant
|
4827
|
+
device float * dst,
|
4828
|
+
constant uint64_t & nbi1,
|
4610
4829
|
constant int64_t & ne00,
|
4611
4830
|
constant int64_t & ne01,
|
4612
4831
|
constant int64_t & ne02,
|
@@ -4622,7 +4841,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
4622
4841
|
constant uint64_t & nb12,
|
4623
4842
|
constant int64_t & ne0,
|
4624
4843
|
constant int64_t & ne1,
|
4625
|
-
constant
|
4844
|
+
constant uint64_t & nb1,
|
4626
4845
|
constant uint & r2,
|
4627
4846
|
constant uint & r3,
|
4628
4847
|
constant int & idx,
|
@@ -4649,7 +4868,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
4649
4868
|
kernel_mul_mv_q2_K_f32_impl(
|
4650
4869
|
src0[id],
|
4651
4870
|
(device const float *) (src1 + bid*nb11),
|
4652
|
-
|
4871
|
+
dst + bid*ne0,
|
4653
4872
|
ne00,
|
4654
4873
|
ne01,
|
4655
4874
|
ne02,
|
@@ -4668,8 +4887,8 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
4668
4887
|
kernel void kernel_mul_mv_id_q3_K_f32(
|
4669
4888
|
device const char * ids,
|
4670
4889
|
device const char * src1,
|
4671
|
-
device
|
4672
|
-
constant
|
4890
|
+
device float * dst,
|
4891
|
+
constant uint64_t & nbi1,
|
4673
4892
|
constant int64_t & ne00,
|
4674
4893
|
constant int64_t & ne01,
|
4675
4894
|
constant int64_t & ne02,
|
@@ -4685,7 +4904,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
4685
4904
|
constant uint64_t & nb12,
|
4686
4905
|
constant int64_t & ne0,
|
4687
4906
|
constant int64_t & ne1,
|
4688
|
-
constant
|
4907
|
+
constant uint64_t & nb1,
|
4689
4908
|
constant uint & r2,
|
4690
4909
|
constant uint & r3,
|
4691
4910
|
constant int & idx,
|
@@ -4712,7 +4931,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
4712
4931
|
kernel_mul_mv_q3_K_f32_impl(
|
4713
4932
|
src0[id],
|
4714
4933
|
(device const float *) (src1 + bid*nb11),
|
4715
|
-
|
4934
|
+
dst + bid*ne0,
|
4716
4935
|
ne00,
|
4717
4936
|
ne01,
|
4718
4937
|
ne02,
|
@@ -4731,8 +4950,8 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
4731
4950
|
kernel void kernel_mul_mv_id_q4_K_f32(
|
4732
4951
|
device const char * ids,
|
4733
4952
|
device const char * src1,
|
4734
|
-
device
|
4735
|
-
constant
|
4953
|
+
device float * dst,
|
4954
|
+
constant uint64_t & nbi1,
|
4736
4955
|
constant int64_t & ne00,
|
4737
4956
|
constant int64_t & ne01,
|
4738
4957
|
constant int64_t & ne02,
|
@@ -4748,7 +4967,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
4748
4967
|
constant uint64_t & nb12,
|
4749
4968
|
constant int64_t & ne0,
|
4750
4969
|
constant int64_t & ne1,
|
4751
|
-
constant
|
4970
|
+
constant uint64_t & nb1,
|
4752
4971
|
constant uint & r2,
|
4753
4972
|
constant uint & r3,
|
4754
4973
|
constant int & idx,
|
@@ -4775,7 +4994,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
4775
4994
|
kernel_mul_mv_q4_K_f32_impl(
|
4776
4995
|
src0[id],
|
4777
4996
|
(device const float *) (src1 + bid*nb11),
|
4778
|
-
|
4997
|
+
dst + bid*ne0,
|
4779
4998
|
ne00,
|
4780
4999
|
ne01,
|
4781
5000
|
ne02,
|
@@ -4794,8 +5013,8 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
4794
5013
|
kernel void kernel_mul_mv_id_q5_K_f32(
|
4795
5014
|
device const char * ids,
|
4796
5015
|
device const char * src1,
|
4797
|
-
device
|
4798
|
-
constant
|
5016
|
+
device float * dst,
|
5017
|
+
constant uint64_t & nbi1,
|
4799
5018
|
constant int64_t & ne00,
|
4800
5019
|
constant int64_t & ne01,
|
4801
5020
|
constant int64_t & ne02,
|
@@ -4811,7 +5030,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
4811
5030
|
constant uint64_t & nb12,
|
4812
5031
|
constant int64_t & ne0,
|
4813
5032
|
constant int64_t & ne1,
|
4814
|
-
constant
|
5033
|
+
constant uint64_t & nb1,
|
4815
5034
|
constant uint & r2,
|
4816
5035
|
constant uint & r3,
|
4817
5036
|
constant int & idx,
|
@@ -4838,7 +5057,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
4838
5057
|
kernel_mul_mv_q5_K_f32_impl(
|
4839
5058
|
src0[id],
|
4840
5059
|
(device const float *) (src1 + bid*nb11),
|
4841
|
-
|
5060
|
+
dst + bid*ne0,
|
4842
5061
|
ne00,
|
4843
5062
|
ne01,
|
4844
5063
|
ne02,
|
@@ -4857,8 +5076,8 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
4857
5076
|
kernel void kernel_mul_mv_id_q6_K_f32(
|
4858
5077
|
device const char * ids,
|
4859
5078
|
device const char * src1,
|
4860
|
-
device
|
4861
|
-
constant
|
5079
|
+
device float * dst,
|
5080
|
+
constant uint64_t & nbi1,
|
4862
5081
|
constant int64_t & ne00,
|
4863
5082
|
constant int64_t & ne01,
|
4864
5083
|
constant int64_t & ne02,
|
@@ -4874,7 +5093,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
4874
5093
|
constant uint64_t & nb12,
|
4875
5094
|
constant int64_t & ne0,
|
4876
5095
|
constant int64_t & ne1,
|
4877
|
-
constant
|
5096
|
+
constant uint64_t & nb1,
|
4878
5097
|
constant uint & r2,
|
4879
5098
|
constant uint & r3,
|
4880
5099
|
constant int & idx,
|
@@ -4901,7 +5120,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
4901
5120
|
kernel_mul_mv_q6_K_f32_impl(
|
4902
5121
|
src0[id],
|
4903
5122
|
(device const float *) (src1 + bid*nb11),
|
4904
|
-
|
5123
|
+
dst + bid*ne0,
|
4905
5124
|
ne00,
|
4906
5125
|
ne01,
|
4907
5126
|
ne02,
|