llama_cpp 0.10.3 → 0.10.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/ext/llama_cpp/src/ggml-backend.c +6 -2
- data/ext/llama_cpp/src/ggml-cuda.cu +73 -63
- data/ext/llama_cpp/src/ggml-impl.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +43 -20
- data/ext/llama_cpp/src/ggml-metal.metal +464 -245
- data/ext/llama_cpp/src/ggml-opencl.h +9 -9
- data/ext/llama_cpp/src/ggml-quants.c +61 -57
- data/ext/llama_cpp/src/ggml.c +171 -5
- data/ext/llama_cpp/src/ggml.h +1 -0
- data/ext/llama_cpp/src/llama.cpp +222 -105
- data/ext/llama_cpp/src/llama.h +31 -32
- data/lib/llama_cpp/version.rb +2 -2
- metadata +3 -3
@@ -59,26 +59,26 @@ kernel void kernel_add(
|
|
59
59
|
constant int64_t & ne01,
|
60
60
|
constant int64_t & ne02,
|
61
61
|
constant int64_t & ne03,
|
62
|
-
constant
|
63
|
-
constant
|
64
|
-
constant
|
65
|
-
constant
|
62
|
+
constant uint64_t & nb00,
|
63
|
+
constant uint64_t & nb01,
|
64
|
+
constant uint64_t & nb02,
|
65
|
+
constant uint64_t & nb03,
|
66
66
|
constant int64_t & ne10,
|
67
67
|
constant int64_t & ne11,
|
68
68
|
constant int64_t & ne12,
|
69
69
|
constant int64_t & ne13,
|
70
|
-
constant
|
71
|
-
constant
|
72
|
-
constant
|
73
|
-
constant
|
70
|
+
constant uint64_t & nb10,
|
71
|
+
constant uint64_t & nb11,
|
72
|
+
constant uint64_t & nb12,
|
73
|
+
constant uint64_t & nb13,
|
74
74
|
constant int64_t & ne0,
|
75
75
|
constant int64_t & ne1,
|
76
76
|
constant int64_t & ne2,
|
77
77
|
constant int64_t & ne3,
|
78
|
-
constant
|
79
|
-
constant
|
80
|
-
constant
|
81
|
-
constant
|
78
|
+
constant uint64_t & nb0,
|
79
|
+
constant uint64_t & nb1,
|
80
|
+
constant uint64_t & nb2,
|
81
|
+
constant uint64_t & nb3,
|
82
82
|
constant int64_t & offs,
|
83
83
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
84
84
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -109,26 +109,26 @@ kernel void kernel_mul(
|
|
109
109
|
constant int64_t & ne01,
|
110
110
|
constant int64_t & ne02,
|
111
111
|
constant int64_t & ne03,
|
112
|
-
constant
|
113
|
-
constant
|
114
|
-
constant
|
115
|
-
constant
|
112
|
+
constant uint64_t & nb00,
|
113
|
+
constant uint64_t & nb01,
|
114
|
+
constant uint64_t & nb02,
|
115
|
+
constant uint64_t & nb03,
|
116
116
|
constant int64_t & ne10,
|
117
117
|
constant int64_t & ne11,
|
118
118
|
constant int64_t & ne12,
|
119
119
|
constant int64_t & ne13,
|
120
|
-
constant
|
121
|
-
constant
|
122
|
-
constant
|
123
|
-
constant
|
120
|
+
constant uint64_t & nb10,
|
121
|
+
constant uint64_t & nb11,
|
122
|
+
constant uint64_t & nb12,
|
123
|
+
constant uint64_t & nb13,
|
124
124
|
constant int64_t & ne0,
|
125
125
|
constant int64_t & ne1,
|
126
126
|
constant int64_t & ne2,
|
127
127
|
constant int64_t & ne3,
|
128
|
-
constant
|
129
|
-
constant
|
130
|
-
constant
|
131
|
-
constant
|
128
|
+
constant uint64_t & nb0,
|
129
|
+
constant uint64_t & nb1,
|
130
|
+
constant uint64_t & nb2,
|
131
|
+
constant uint64_t & nb3,
|
132
132
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
133
133
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
134
134
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -158,26 +158,26 @@ kernel void kernel_div(
|
|
158
158
|
constant int64_t & ne01,
|
159
159
|
constant int64_t & ne02,
|
160
160
|
constant int64_t & ne03,
|
161
|
-
constant
|
162
|
-
constant
|
163
|
-
constant
|
164
|
-
constant
|
161
|
+
constant uint64_t & nb00,
|
162
|
+
constant uint64_t & nb01,
|
163
|
+
constant uint64_t & nb02,
|
164
|
+
constant uint64_t & nb03,
|
165
165
|
constant int64_t & ne10,
|
166
166
|
constant int64_t & ne11,
|
167
167
|
constant int64_t & ne12,
|
168
168
|
constant int64_t & ne13,
|
169
|
-
constant
|
170
|
-
constant
|
171
|
-
constant
|
172
|
-
constant
|
169
|
+
constant uint64_t & nb10,
|
170
|
+
constant uint64_t & nb11,
|
171
|
+
constant uint64_t & nb12,
|
172
|
+
constant uint64_t & nb13,
|
173
173
|
constant int64_t & ne0,
|
174
174
|
constant int64_t & ne1,
|
175
175
|
constant int64_t & ne2,
|
176
176
|
constant int64_t & ne3,
|
177
|
-
constant
|
178
|
-
constant
|
179
|
-
constant
|
180
|
-
constant
|
177
|
+
constant uint64_t & nb0,
|
178
|
+
constant uint64_t & nb1,
|
179
|
+
constant uint64_t & nb2,
|
180
|
+
constant uint64_t & nb3,
|
181
181
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
182
182
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
183
183
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
|
|
205
205
|
device const float4 * src0,
|
206
206
|
device const float4 * src1,
|
207
207
|
device float4 * dst,
|
208
|
-
constant
|
208
|
+
constant uint64_t & nb [[buffer(28)]],
|
209
209
|
uint tpig[[thread_position_in_grid]]) {
|
210
210
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
211
211
|
}
|
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
|
|
214
214
|
device const float4 * src0,
|
215
215
|
device const float4 * src1,
|
216
216
|
device float4 * dst,
|
217
|
-
constant
|
217
|
+
constant uint64_t & nb [[buffer(28)]],
|
218
218
|
uint tpig[[thread_position_in_grid]]) {
|
219
219
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
220
220
|
}
|
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
|
|
223
223
|
device const float4 * src0,
|
224
224
|
device const float4 * src1,
|
225
225
|
device float4 * dst,
|
226
|
-
constant
|
226
|
+
constant uint64_t & nb [[buffer(28)]],
|
227
227
|
uint tpig[[thread_position_in_grid]]) {
|
228
228
|
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
229
229
|
}
|
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
|
|
307
307
|
constant int64_t & ne01,
|
308
308
|
constant int64_t & ne02,
|
309
309
|
constant int64_t & ne03,
|
310
|
-
constant
|
311
|
-
constant
|
312
|
-
constant
|
313
|
-
constant
|
310
|
+
constant uint64_t & nb00,
|
311
|
+
constant uint64_t & nb01,
|
312
|
+
constant uint64_t & nb02,
|
313
|
+
constant uint64_t & nb03,
|
314
314
|
constant int64_t & ne10,
|
315
315
|
constant int64_t & ne11,
|
316
316
|
constant int64_t & ne12,
|
317
317
|
constant int64_t & ne13,
|
318
|
-
constant
|
319
|
-
constant
|
320
|
-
constant
|
321
|
-
constant
|
318
|
+
constant uint64_t & nb10,
|
319
|
+
constant uint64_t & nb11,
|
320
|
+
constant uint64_t & nb12,
|
321
|
+
constant uint64_t & nb13,
|
322
322
|
constant int64_t & ne0,
|
323
323
|
constant int64_t & ne1,
|
324
324
|
constant int64_t & ne2,
|
325
325
|
constant int64_t & ne3,
|
326
|
-
constant
|
327
|
-
constant
|
328
|
-
constant
|
329
|
-
constant
|
326
|
+
constant uint64_t & nb0,
|
327
|
+
constant uint64_t & nb1,
|
328
|
+
constant uint64_t & nb2,
|
329
|
+
constant uint64_t & nb3,
|
330
330
|
uint3 tpig[[thread_position_in_grid]]) {
|
331
331
|
int64_t i3 = tpig.z;
|
332
332
|
int64_t i2 = tpig.y;
|
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
846
846
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
847
847
|
//Note: This is a template, but strictly speaking it only applies to
|
848
848
|
// quantizations where the block size is 32. It also does not
|
849
|
-
//
|
849
|
+
// guard against the number of rows not being divisible by
|
850
850
|
// N_DST, so this is another explicit assumption of the implementation.
|
851
851
|
template<typename block_q_type, int nr, int nsg, int nw>
|
852
852
|
void mul_vec_q_n_f32_impl(
|
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
920
920
|
device const float * src1,
|
921
921
|
device float * dst,
|
922
922
|
constant int64_t & ne00,
|
923
|
-
constant int64_t & ne01
|
924
|
-
constant int64_t & ne02
|
925
|
-
constant
|
926
|
-
constant
|
927
|
-
constant
|
928
|
-
constant int64_t &
|
929
|
-
constant
|
930
|
-
constant
|
923
|
+
constant int64_t & ne01,
|
924
|
+
constant int64_t & ne02,
|
925
|
+
constant uint64_t & nb00,
|
926
|
+
constant uint64_t & nb01,
|
927
|
+
constant uint64_t & nb02,
|
928
|
+
constant int64_t & ne10,
|
929
|
+
constant int64_t & ne11,
|
930
|
+
constant int64_t & ne12,
|
931
|
+
constant uint64_t & nb10,
|
932
|
+
constant uint64_t & nb11,
|
933
|
+
constant uint64_t & nb12,
|
934
|
+
constant int64_t & ne0,
|
935
|
+
constant int64_t & ne1,
|
936
|
+
constant uint & r2,
|
937
|
+
constant uint & r3,
|
931
938
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
932
939
|
uint tiisg[[thread_index_in_simdgroup]],
|
933
940
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
939
946
|
device const float * src1,
|
940
947
|
device float * dst,
|
941
948
|
constant int64_t & ne00,
|
942
|
-
constant int64_t & ne01
|
943
|
-
constant int64_t & ne02
|
944
|
-
constant
|
945
|
-
constant
|
946
|
-
constant
|
947
|
-
constant int64_t &
|
948
|
-
constant
|
949
|
-
constant
|
949
|
+
constant int64_t & ne01,
|
950
|
+
constant int64_t & ne02,
|
951
|
+
constant uint64_t & nb00,
|
952
|
+
constant uint64_t & nb01,
|
953
|
+
constant uint64_t & nb02,
|
954
|
+
constant int64_t & ne10,
|
955
|
+
constant int64_t & ne11,
|
956
|
+
constant int64_t & ne12,
|
957
|
+
constant uint64_t & nb10,
|
958
|
+
constant uint64_t & nb11,
|
959
|
+
constant uint64_t & nb12,
|
960
|
+
constant int64_t & ne0,
|
961
|
+
constant int64_t & ne1,
|
962
|
+
constant uint & r2,
|
963
|
+
constant uint & r3,
|
950
964
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
951
965
|
uint tiisg[[thread_index_in_simdgroup]],
|
952
966
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
958
972
|
device const float * src1,
|
959
973
|
device float * dst,
|
960
974
|
constant int64_t & ne00,
|
961
|
-
constant int64_t & ne01
|
962
|
-
constant int64_t & ne02
|
963
|
-
constant
|
964
|
-
constant
|
965
|
-
constant
|
966
|
-
constant int64_t &
|
967
|
-
constant
|
968
|
-
constant
|
975
|
+
constant int64_t & ne01,
|
976
|
+
constant int64_t & ne02,
|
977
|
+
constant uint64_t & nb00,
|
978
|
+
constant uint64_t & nb01,
|
979
|
+
constant uint64_t & nb02,
|
980
|
+
constant int64_t & ne10,
|
981
|
+
constant int64_t & ne11,
|
982
|
+
constant int64_t & ne12,
|
983
|
+
constant uint64_t & nb10,
|
984
|
+
constant uint64_t & nb11,
|
985
|
+
constant uint64_t & nb12,
|
986
|
+
constant int64_t & ne0,
|
987
|
+
constant int64_t & ne1,
|
988
|
+
constant uint & r2,
|
989
|
+
constant uint & r3,
|
969
990
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
970
991
|
uint tiisg[[thread_index_in_simdgroup]],
|
971
992
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
977
998
|
device const float * src1,
|
978
999
|
device float * dst,
|
979
1000
|
constant int64_t & ne00,
|
980
|
-
constant int64_t & ne01
|
981
|
-
constant int64_t & ne02
|
982
|
-
constant
|
983
|
-
constant
|
984
|
-
constant
|
985
|
-
constant int64_t &
|
986
|
-
constant
|
987
|
-
constant
|
1001
|
+
constant int64_t & ne01,
|
1002
|
+
constant int64_t & ne02,
|
1003
|
+
constant uint64_t & nb00,
|
1004
|
+
constant uint64_t & nb01,
|
1005
|
+
constant uint64_t & nb02,
|
1006
|
+
constant int64_t & ne10,
|
1007
|
+
constant int64_t & ne11,
|
1008
|
+
constant int64_t & ne12,
|
1009
|
+
constant uint64_t & nb10,
|
1010
|
+
constant uint64_t & nb11,
|
1011
|
+
constant uint64_t & nb12,
|
1012
|
+
constant int64_t & ne0,
|
1013
|
+
constant int64_t & ne1,
|
1014
|
+
constant uint & r2,
|
1015
|
+
constant uint & r3,
|
988
1016
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
989
1017
|
uint tiisg[[thread_index_in_simdgroup]],
|
990
1018
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
1071
1099
|
constant int64_t & ne00,
|
1072
1100
|
constant int64_t & ne01,
|
1073
1101
|
constant int64_t & ne02,
|
1102
|
+
constant uint64_t & nb00,
|
1103
|
+
constant uint64_t & nb01,
|
1104
|
+
constant uint64_t & nb02,
|
1074
1105
|
constant int64_t & ne10,
|
1106
|
+
constant int64_t & ne11,
|
1075
1107
|
constant int64_t & ne12,
|
1108
|
+
constant uint64_t & nb10,
|
1109
|
+
constant uint64_t & nb11,
|
1110
|
+
constant uint64_t & nb12,
|
1076
1111
|
constant int64_t & ne0,
|
1077
1112
|
constant int64_t & ne1,
|
1078
|
-
constant uint & r2
|
1079
|
-
constant uint & r3
|
1113
|
+
constant uint & r2,
|
1114
|
+
constant uint & r3,
|
1080
1115
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1081
1116
|
uint tiisg[[thread_index_in_simdgroup]],
|
1082
1117
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
1182
1217
|
constant uint64_t & nb12,
|
1183
1218
|
constant int64_t & ne0,
|
1184
1219
|
constant int64_t & ne1,
|
1185
|
-
constant uint & r2
|
1186
|
-
constant uint & r3
|
1220
|
+
constant uint & r2,
|
1221
|
+
constant uint & r3,
|
1187
1222
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1188
1223
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1189
1224
|
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|
1209
1244
|
constant uint64_t & nb12,
|
1210
1245
|
constant int64_t & ne0,
|
1211
1246
|
constant int64_t & ne1,
|
1212
|
-
constant uint & r2
|
1213
|
-
constant uint & r3
|
1247
|
+
constant uint & r2,
|
1248
|
+
constant uint & r3,
|
1214
1249
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1215
1250
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1216
1251
|
|
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
1346
1381
|
constant uint64_t & nb12,
|
1347
1382
|
constant int64_t & ne0,
|
1348
1383
|
constant int64_t & ne1,
|
1349
|
-
constant uint & r2
|
1350
|
-
constant uint & r3
|
1384
|
+
constant uint & r2,
|
1385
|
+
constant uint & r3,
|
1351
1386
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1352
1387
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1353
1388
|
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
1452
1487
|
constant uint64_t & nb12,
|
1453
1488
|
constant int64_t & ne0,
|
1454
1489
|
constant int64_t & ne1,
|
1455
|
-
constant uint & r2
|
1456
|
-
constant uint & r3
|
1490
|
+
constant uint & r2,
|
1491
|
+
constant uint & r3,
|
1457
1492
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1458
1493
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1459
1494
|
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1478
1513
|
constant uint64_t & nb12,
|
1479
1514
|
constant int64_t & ne0,
|
1480
1515
|
constant int64_t & ne1,
|
1481
|
-
constant uint & r2
|
1482
|
-
constant uint & r3
|
1516
|
+
constant uint & r2,
|
1517
|
+
constant uint & r3,
|
1483
1518
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1484
1519
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
1485
1520
|
|
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
|
|
1543
1578
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
1544
1579
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1545
1580
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1546
|
-
|
1581
|
+
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1582
|
+
|
1547
1583
|
const int64_t k = i3*ne3 + i2;
|
1548
1584
|
|
1549
1585
|
float m_k;
|
@@ -2410,22 +2446,6 @@ typedef struct {
|
|
2410
2446
|
} block_q6_K;
|
2411
2447
|
// 210 bytes / block
|
2412
2448
|
|
2413
|
-
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
2414
|
-
uchar4 r;
|
2415
|
-
if (j < 4) {
|
2416
|
-
r[0] = q[j+0] & 63;
|
2417
|
-
r[2] = q[j+1] & 63;
|
2418
|
-
r[1] = q[j+4] & 63;
|
2419
|
-
r[3] = q[j+5] & 63;
|
2420
|
-
} else {
|
2421
|
-
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
2422
|
-
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
|
2423
|
-
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
2424
|
-
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
|
2425
|
-
}
|
2426
|
-
return r;
|
2427
|
-
}
|
2428
|
-
|
2429
2449
|
//====================================== dot products =========================
|
2430
2450
|
|
2431
2451
|
void kernel_mul_mv_q2_K_f32_impl(
|
@@ -2584,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
2584
2604
|
device const float * src1,
|
2585
2605
|
device float * dst,
|
2586
2606
|
constant int64_t & ne00,
|
2587
|
-
constant int64_t & ne01
|
2588
|
-
constant int64_t & ne02
|
2589
|
-
constant
|
2590
|
-
constant
|
2591
|
-
constant
|
2592
|
-
constant int64_t &
|
2593
|
-
constant
|
2594
|
-
constant
|
2607
|
+
constant int64_t & ne01,
|
2608
|
+
constant int64_t & ne02,
|
2609
|
+
constant uint64_t & nb00,
|
2610
|
+
constant uint64_t & nb01,
|
2611
|
+
constant uint64_t & nb02,
|
2612
|
+
constant int64_t & ne10,
|
2613
|
+
constant int64_t & ne11,
|
2614
|
+
constant int64_t & ne12,
|
2615
|
+
constant uint64_t & nb10,
|
2616
|
+
constant uint64_t & nb11,
|
2617
|
+
constant uint64_t & nb12,
|
2618
|
+
constant int64_t & ne0,
|
2619
|
+
constant int64_t & ne1,
|
2620
|
+
constant uint & r2,
|
2621
|
+
constant uint & r3,
|
2595
2622
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2596
2623
|
uint tiisg[[thread_index_in_simdgroup]],
|
2597
2624
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2841,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
2841
2868
|
device const float * src1,
|
2842
2869
|
device float * dst,
|
2843
2870
|
constant int64_t & ne00,
|
2844
|
-
constant int64_t & ne01
|
2845
|
-
constant int64_t & ne02
|
2846
|
-
constant
|
2847
|
-
constant
|
2848
|
-
constant
|
2849
|
-
constant int64_t &
|
2850
|
-
constant
|
2851
|
-
constant
|
2871
|
+
constant int64_t & ne01,
|
2872
|
+
constant int64_t & ne02,
|
2873
|
+
constant uint64_t & nb00,
|
2874
|
+
constant uint64_t & nb01,
|
2875
|
+
constant uint64_t & nb02,
|
2876
|
+
constant int64_t & ne10,
|
2877
|
+
constant int64_t & ne11,
|
2878
|
+
constant int64_t & ne12,
|
2879
|
+
constant uint64_t & nb10,
|
2880
|
+
constant uint64_t & nb11,
|
2881
|
+
constant uint64_t & nb12,
|
2882
|
+
constant int64_t & ne0,
|
2883
|
+
constant int64_t & ne1,
|
2884
|
+
constant uint & r2,
|
2885
|
+
constant uint & r3,
|
2852
2886
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2853
2887
|
uint tiisg[[thread_index_in_simdgroup]],
|
2854
2888
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -2984,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
2984
3018
|
constant uint & r2,
|
2985
3019
|
constant uint & r3,
|
2986
3020
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2987
|
-
uint
|
2988
|
-
uint
|
3021
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
3022
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2989
3023
|
|
2990
3024
|
const int ix = tiisg/4; // 0...7
|
2991
3025
|
const int it = tiisg%4; // 0...3
|
@@ -2994,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
2994
3028
|
const int r0 = tgpig.x;
|
2995
3029
|
const int r1 = tgpig.y;
|
2996
3030
|
const int im = tgpig.z;
|
2997
|
-
const int first_row =
|
3031
|
+
const int first_row = r0 * N_DST;
|
2998
3032
|
const int ib_row = first_row * nb;
|
2999
3033
|
|
3000
3034
|
const uint i12 = im%ne12;
|
@@ -3060,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3060
3094
|
for (int row = 0; row < N_DST; ++row) {
|
3061
3095
|
all_sum = simd_sum(sumf[row]);
|
3062
3096
|
if (tiisg == 0) {
|
3063
|
-
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
3097
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
3064
3098
|
}
|
3065
3099
|
}
|
3066
3100
|
}
|
@@ -3072,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
3072
3106
|
device const float * src1,
|
3073
3107
|
device float * dst,
|
3074
3108
|
constant int64_t & ne00,
|
3075
|
-
constant int64_t & ne01
|
3076
|
-
constant int64_t & ne02
|
3077
|
-
constant
|
3078
|
-
constant
|
3079
|
-
constant
|
3080
|
-
constant int64_t &
|
3081
|
-
constant
|
3082
|
-
constant
|
3109
|
+
constant int64_t & ne01,
|
3110
|
+
constant int64_t & ne02,
|
3111
|
+
constant uint64_t & nb00,
|
3112
|
+
constant uint64_t & nb01,
|
3113
|
+
constant uint64_t & nb02,
|
3114
|
+
constant int64_t & ne10,
|
3115
|
+
constant int64_t & ne11,
|
3116
|
+
constant int64_t & ne12,
|
3117
|
+
constant uint64_t & nb10,
|
3118
|
+
constant uint64_t & nb11,
|
3119
|
+
constant uint64_t & nb12,
|
3120
|
+
constant int64_t & ne0,
|
3121
|
+
constant int64_t & ne1,
|
3122
|
+
constant uint & r2,
|
3123
|
+
constant uint & r3,
|
3083
3124
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3084
3125
|
uint tiisg[[thread_index_in_simdgroup]],
|
3085
3126
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3271,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
3271
3312
|
device const float * src1,
|
3272
3313
|
device float * dst,
|
3273
3314
|
constant int64_t & ne00,
|
3274
|
-
constant int64_t & ne01
|
3275
|
-
constant int64_t & ne02
|
3276
|
-
constant
|
3277
|
-
constant
|
3278
|
-
constant
|
3279
|
-
constant int64_t &
|
3280
|
-
constant
|
3281
|
-
constant
|
3315
|
+
constant int64_t & ne01,
|
3316
|
+
constant int64_t & ne02,
|
3317
|
+
constant uint64_t & nb00,
|
3318
|
+
constant uint64_t & nb01,
|
3319
|
+
constant uint64_t & nb02,
|
3320
|
+
constant int64_t & ne10,
|
3321
|
+
constant int64_t & ne11,
|
3322
|
+
constant int64_t & ne12,
|
3323
|
+
constant uint64_t & nb10,
|
3324
|
+
constant uint64_t & nb11,
|
3325
|
+
constant uint64_t & nb12,
|
3326
|
+
constant int64_t & ne0,
|
3327
|
+
constant int64_t & ne1,
|
3328
|
+
constant uint & r2,
|
3329
|
+
constant uint & r3,
|
3282
3330
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3283
3331
|
uint tiisg[[thread_index_in_simdgroup]],
|
3284
3332
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3398,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
3398
3446
|
device const float * src1,
|
3399
3447
|
device float * dst,
|
3400
3448
|
constant int64_t & ne00,
|
3401
|
-
constant int64_t & ne01
|
3402
|
-
constant int64_t & ne02
|
3403
|
-
constant
|
3404
|
-
constant
|
3405
|
-
constant
|
3406
|
-
constant int64_t &
|
3407
|
-
constant
|
3408
|
-
constant
|
3449
|
+
constant int64_t & ne01,
|
3450
|
+
constant int64_t & ne02,
|
3451
|
+
constant uint64_t & nb00,
|
3452
|
+
constant uint64_t & nb01,
|
3453
|
+
constant uint64_t & nb02,
|
3454
|
+
constant int64_t & ne10,
|
3455
|
+
constant int64_t & ne11,
|
3456
|
+
constant int64_t & ne12,
|
3457
|
+
constant uint64_t & nb10,
|
3458
|
+
constant uint64_t & nb11,
|
3459
|
+
constant uint64_t & nb12,
|
3460
|
+
constant int64_t & ne0,
|
3461
|
+
constant int64_t & ne1,
|
3462
|
+
constant uint & r2,
|
3463
|
+
constant uint & r3,
|
3409
3464
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3410
3465
|
uint tiisg[[thread_index_in_simdgroup]],
|
3411
3466
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -3523,7 +3578,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
3523
3578
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
3524
3579
|
const half d = xb->d;
|
3525
3580
|
|
3526
|
-
for (int i=0;i<16;i++) {
|
3581
|
+
for (int i = 0; i < 16; i++) {
|
3527
3582
|
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
3528
3583
|
}
|
3529
3584
|
}
|
@@ -3774,6 +3829,35 @@ kernel void kernel_get_rows_f16(
|
|
3774
3829
|
}
|
3775
3830
|
}
|
3776
3831
|
|
3832
|
+
kernel void kernel_get_rows_i32(
|
3833
|
+
device const void * src0,
|
3834
|
+
device const char * src1,
|
3835
|
+
device int32_t * dst,
|
3836
|
+
constant int64_t & ne00,
|
3837
|
+
constant uint64_t & nb01,
|
3838
|
+
constant uint64_t & nb02,
|
3839
|
+
constant int64_t & ne10,
|
3840
|
+
constant uint64_t & nb10,
|
3841
|
+
constant uint64_t & nb11,
|
3842
|
+
constant uint64_t & nb1,
|
3843
|
+
constant uint64_t & nb2,
|
3844
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
3845
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
3846
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
3847
|
+
const int64_t i10 = tgpig.x;
|
3848
|
+
const int64_t i11 = tgpig.y;
|
3849
|
+
|
3850
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
3851
|
+
|
3852
|
+
const int64_t i02 = i11;
|
3853
|
+
|
3854
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
3855
|
+
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
3856
|
+
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
3857
|
+
}
|
3858
|
+
}
|
3859
|
+
|
3860
|
+
|
3777
3861
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
3778
3862
|
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
3779
3863
|
#define BLOCK_SIZE_K 32
|
@@ -3792,12 +3876,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
3792
3876
|
device float * dst,
|
3793
3877
|
constant int64_t & ne00,
|
3794
3878
|
constant int64_t & ne02,
|
3795
|
-
constant
|
3796
|
-
constant
|
3879
|
+
constant uint64_t & nb01,
|
3880
|
+
constant uint64_t & nb02,
|
3797
3881
|
constant int64_t & ne12,
|
3798
|
-
constant
|
3799
|
-
constant
|
3800
|
-
constant
|
3882
|
+
constant uint64_t & nb10,
|
3883
|
+
constant uint64_t & nb11,
|
3884
|
+
constant uint64_t & nb12,
|
3801
3885
|
constant int64_t & ne0,
|
3802
3886
|
constant int64_t & ne1,
|
3803
3887
|
constant uint & r2,
|
@@ -3918,18 +4002,143 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
3918
4002
|
}
|
3919
4003
|
}
|
3920
4004
|
|
4005
|
+
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
|
4006
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
4007
|
+
void kernel_mul_mm_id_impl(
|
4008
|
+
device const uchar * src0,
|
4009
|
+
device const uchar * src1,
|
4010
|
+
thread short * src1ids,
|
4011
|
+
device float * dst,
|
4012
|
+
constant int64_t & ne00,
|
4013
|
+
constant int64_t & ne02,
|
4014
|
+
constant uint64_t & nb01,
|
4015
|
+
constant uint64_t & nb02,
|
4016
|
+
constant int64_t & ne12,
|
4017
|
+
constant uint64_t & nb10,
|
4018
|
+
constant uint64_t & nb11,
|
4019
|
+
constant uint64_t & nb12,
|
4020
|
+
constant int64_t & ne0,
|
4021
|
+
int64_t ne1,
|
4022
|
+
constant uint & r2,
|
4023
|
+
constant uint & r3,
|
4024
|
+
threadgroup uchar * shared_memory,
|
4025
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4026
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
4027
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4028
|
+
|
4029
|
+
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
4030
|
+
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
4031
|
+
|
4032
|
+
const uint r0 = tgpig.y;
|
4033
|
+
const uint r1 = tgpig.x;
|
4034
|
+
const uint im = tgpig.z;
|
4035
|
+
|
4036
|
+
if (r1 * BLOCK_SIZE_N >= ne1) return;
|
4037
|
+
|
4038
|
+
// if this block is of 64x32 shape or smaller
|
4039
|
+
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
4040
|
+
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
4041
|
+
|
4042
|
+
// a thread shouldn't load data outside of the matrix
|
4043
|
+
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
4044
|
+
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
4045
|
+
|
4046
|
+
simdgroup_half8x8 ma[4];
|
4047
|
+
simdgroup_float8x8 mb[2];
|
4048
|
+
simdgroup_float8x8 c_res[8];
|
4049
|
+
for (int i = 0; i < 8; i++){
|
4050
|
+
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
4051
|
+
}
|
4052
|
+
|
4053
|
+
short il = (tiitg % THREAD_PER_ROW);
|
4054
|
+
|
4055
|
+
const uint i12 = im%ne12;
|
4056
|
+
const uint i13 = im/ne12;
|
4057
|
+
|
4058
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
4059
|
+
ushort offset1 = il/nl;
|
4060
|
+
|
4061
|
+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
4062
|
+
device const float * y = (device const float *)(src1
|
4063
|
+
+ nb12 * im
|
4064
|
+
+ nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
|
4065
|
+
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
4066
|
+
|
4067
|
+
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
4068
|
+
// load data and store to threadgroup memory
|
4069
|
+
half4x4 temp_a;
|
4070
|
+
dequantize_func(x, il, temp_a);
|
4071
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4072
|
+
|
4073
|
+
for (int i = 0; i < 16; i++) {
|
4074
|
+
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
4075
|
+
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
4076
|
+
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
4077
|
+
}
|
4078
|
+
|
4079
|
+
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
4080
|
+
|
4081
|
+
il = (il + 2 < nl) ? il + 2 : il % 2;
|
4082
|
+
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
4083
|
+
y += BLOCK_SIZE_K;
|
4084
|
+
|
4085
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4086
|
+
|
4087
|
+
// load matrices from threadgroup memory and conduct outer products
|
4088
|
+
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
4089
|
+
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
4090
|
+
|
4091
|
+
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
4092
|
+
for (int i = 0; i < 4; i++) {
|
4093
|
+
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
4094
|
+
}
|
4095
|
+
simdgroup_barrier(mem_flags::mem_none);
|
4096
|
+
for (int i = 0; i < 2; i++) {
|
4097
|
+
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
4098
|
+
}
|
4099
|
+
|
4100
|
+
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
4101
|
+
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
4102
|
+
|
4103
|
+
for (int i = 0; i < 8; i++){
|
4104
|
+
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
4105
|
+
}
|
4106
|
+
}
|
4107
|
+
}
|
4108
|
+
|
4109
|
+
{
|
4110
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4111
|
+
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
4112
|
+
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
4113
|
+
for (int i = 0; i < 8; i++) {
|
4114
|
+
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
4115
|
+
}
|
4116
|
+
|
4117
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4118
|
+
|
4119
|
+
device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
|
4120
|
+
if (sgitg == 0) {
|
4121
|
+
for (int i = 0; i < n_rows; i++) {
|
4122
|
+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
4123
|
+
*(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
4124
|
+
}
|
4125
|
+
}
|
4126
|
+
}
|
4127
|
+
}
|
4128
|
+
}
|
4129
|
+
|
3921
4130
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
3922
4131
|
kernel void kernel_mul_mm(device const uchar * src0,
|
3923
4132
|
device const uchar * src1,
|
3924
4133
|
device float * dst,
|
3925
4134
|
constant int64_t & ne00,
|
3926
4135
|
constant int64_t & ne02,
|
3927
|
-
constant
|
3928
|
-
constant
|
4136
|
+
constant uint64_t & nb01,
|
4137
|
+
constant uint64_t & nb02,
|
3929
4138
|
constant int64_t & ne12,
|
3930
|
-
constant
|
3931
|
-
constant
|
3932
|
-
constant
|
4139
|
+
constant uint64_t & nb10,
|
4140
|
+
constant uint64_t & nb11,
|
4141
|
+
constant uint64_t & nb12,
|
3933
4142
|
constant int64_t & ne0,
|
3934
4143
|
constant int64_t & ne1,
|
3935
4144
|
constant uint & r2,
|
@@ -3964,20 +4173,20 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
|
|
3964
4173
|
kernel void kernel_mul_mm_id(
|
3965
4174
|
device const uchar * ids,
|
3966
4175
|
device const uchar * src1,
|
3967
|
-
device
|
3968
|
-
constant
|
4176
|
+
device float * dst,
|
4177
|
+
constant uint64_t & nbi1,
|
3969
4178
|
constant int64_t & ne00,
|
3970
4179
|
constant int64_t & ne02,
|
3971
|
-
constant
|
3972
|
-
constant
|
4180
|
+
constant uint64_t & nb01,
|
4181
|
+
constant uint64_t & nb02,
|
3973
4182
|
constant int64_t & ne12,
|
3974
4183
|
constant int64_t & ne13,
|
3975
|
-
constant
|
3976
|
-
constant
|
3977
|
-
constant
|
4184
|
+
constant uint64_t & nb10,
|
4185
|
+
constant uint64_t & nb11,
|
4186
|
+
constant uint64_t & nb12,
|
3978
4187
|
constant int64_t & ne0,
|
3979
4188
|
constant int64_t & ne1,
|
3980
|
-
constant
|
4189
|
+
constant uint64_t & nb1,
|
3981
4190
|
constant uint & r2,
|
3982
4191
|
constant uint & r3,
|
3983
4192
|
constant int & idx,
|
@@ -3993,18 +4202,28 @@ kernel void kernel_mul_mm_id(
|
|
3993
4202
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3994
4203
|
uint tiitg[[thread_index_in_threadgroup]],
|
3995
4204
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3996
|
-
device const uchar *
|
4205
|
+
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
3997
4206
|
|
3998
|
-
|
4207
|
+
// expert id
|
4208
|
+
const int32_t id = tgpig.z/(ne12*ne13);
|
3999
4209
|
|
4000
4210
|
tgpig.z = tgpig.z%(ne12*ne13);
|
4001
4211
|
|
4002
|
-
|
4212
|
+
// row indices of src1 for expert id
|
4213
|
+
int64_t _ne1 = 0;
|
4214
|
+
short src1ids[512];
|
4003
4215
|
|
4004
|
-
|
4005
|
-
|
4006
|
-
|
4007
|
-
|
4216
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
4217
|
+
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
4218
|
+
src1ids[_ne1++] = i1;
|
4219
|
+
}
|
4220
|
+
}
|
4221
|
+
|
4222
|
+
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
4223
|
+
src0s[id],
|
4224
|
+
src1,
|
4225
|
+
src1ids,
|
4226
|
+
dst,
|
4008
4227
|
ne00,
|
4009
4228
|
ne02,
|
4010
4229
|
nb01,
|
@@ -4014,7 +4233,7 @@ kernel void kernel_mul_mm_id(
|
|
4014
4233
|
nb11,
|
4015
4234
|
nb12,
|
4016
4235
|
ne0,
|
4017
|
-
|
4236
|
+
_ne1,
|
4018
4237
|
r2,
|
4019
4238
|
r3,
|
4020
4239
|
shared_memory,
|
@@ -4070,12 +4289,12 @@ typedef void (mat_mm_t)(
|
|
4070
4289
|
device float * dst,
|
4071
4290
|
constant int64_t & ne00,
|
4072
4291
|
constant int64_t & ne02,
|
4073
|
-
constant
|
4074
|
-
constant
|
4292
|
+
constant uint64_t & nb01,
|
4293
|
+
constant uint64_t & nb02,
|
4075
4294
|
constant int64_t & ne12,
|
4076
|
-
constant
|
4077
|
-
constant
|
4078
|
-
constant
|
4295
|
+
constant uint64_t & nb10,
|
4296
|
+
constant uint64_t & nb11,
|
4297
|
+
constant uint64_t & nb12,
|
4079
4298
|
constant int64_t & ne0,
|
4080
4299
|
constant int64_t & ne1,
|
4081
4300
|
constant uint & r2,
|
@@ -4103,20 +4322,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
4103
4322
|
typedef void (mat_mm_id_t)(
|
4104
4323
|
device const uchar * ids,
|
4105
4324
|
device const uchar * src1,
|
4106
|
-
device
|
4107
|
-
constant
|
4325
|
+
device float * dst,
|
4326
|
+
constant uint64_t & nbi1,
|
4108
4327
|
constant int64_t & ne00,
|
4109
4328
|
constant int64_t & ne02,
|
4110
|
-
constant
|
4111
|
-
constant
|
4329
|
+
constant uint64_t & nb01,
|
4330
|
+
constant uint64_t & nb02,
|
4112
4331
|
constant int64_t & ne12,
|
4113
4332
|
constant int64_t & ne13,
|
4114
|
-
constant
|
4115
|
-
constant
|
4116
|
-
constant
|
4333
|
+
constant uint64_t & nb10,
|
4334
|
+
constant uint64_t & nb11,
|
4335
|
+
constant uint64_t & nb12,
|
4117
4336
|
constant int64_t & ne0,
|
4118
4337
|
constant int64_t & ne1,
|
4119
|
-
constant
|
4338
|
+
constant uint64_t & nb1,
|
4120
4339
|
constant uint & r2,
|
4121
4340
|
constant uint & r3,
|
4122
4341
|
constant int & idx,
|
@@ -4152,8 +4371,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
4152
4371
|
kernel void kernel_mul_mv_id_f32_f32(
|
4153
4372
|
device const char * ids,
|
4154
4373
|
device const char * src1,
|
4155
|
-
device
|
4156
|
-
constant
|
4374
|
+
device float * dst,
|
4375
|
+
constant uint64_t & nbi1,
|
4157
4376
|
constant int64_t & ne00,
|
4158
4377
|
constant int64_t & ne01,
|
4159
4378
|
constant int64_t & ne02,
|
@@ -4169,7 +4388,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
4169
4388
|
constant uint64_t & nb12,
|
4170
4389
|
constant int64_t & ne0,
|
4171
4390
|
constant int64_t & ne1,
|
4172
|
-
constant
|
4391
|
+
constant uint64_t & nb1,
|
4173
4392
|
constant uint & r2,
|
4174
4393
|
constant uint & r3,
|
4175
4394
|
constant int & idx,
|
@@ -4196,7 +4415,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
4196
4415
|
kernel_mul_mv_f32_f32_impl(
|
4197
4416
|
src0[id],
|
4198
4417
|
src1 + bid*nb11,
|
4199
|
-
|
4418
|
+
dst + bid*ne0,
|
4200
4419
|
ne00,
|
4201
4420
|
ne01,
|
4202
4421
|
ne02,
|
@@ -4221,8 +4440,8 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
4221
4440
|
kernel void kernel_mul_mv_id_f16_f32(
|
4222
4441
|
device const char * ids,
|
4223
4442
|
device const char * src1,
|
4224
|
-
device
|
4225
|
-
constant
|
4443
|
+
device float * dst,
|
4444
|
+
constant uint64_t & nbi1,
|
4226
4445
|
constant int64_t & ne00,
|
4227
4446
|
constant int64_t & ne01,
|
4228
4447
|
constant int64_t & ne02,
|
@@ -4238,7 +4457,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
4238
4457
|
constant uint64_t & nb12,
|
4239
4458
|
constant int64_t & ne0,
|
4240
4459
|
constant int64_t & ne1,
|
4241
|
-
constant
|
4460
|
+
constant uint64_t & nb1,
|
4242
4461
|
constant uint & r2,
|
4243
4462
|
constant uint & r3,
|
4244
4463
|
constant int & idx,
|
@@ -4265,7 +4484,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
4265
4484
|
kernel_mul_mv_f16_f32_impl(
|
4266
4485
|
src0[id],
|
4267
4486
|
src1 + bid*nb11,
|
4268
|
-
|
4487
|
+
dst + bid*ne0,
|
4269
4488
|
ne00,
|
4270
4489
|
ne01,
|
4271
4490
|
ne02,
|
@@ -4290,8 +4509,8 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
4290
4509
|
kernel void kernel_mul_mv_id_q8_0_f32(
|
4291
4510
|
device const char * ids,
|
4292
4511
|
device const char * src1,
|
4293
|
-
device
|
4294
|
-
constant
|
4512
|
+
device float * dst,
|
4513
|
+
constant uint64_t & nbi1,
|
4295
4514
|
constant int64_t & ne00,
|
4296
4515
|
constant int64_t & ne01,
|
4297
4516
|
constant int64_t & ne02,
|
@@ -4307,7 +4526,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
4307
4526
|
constant uint64_t & nb12,
|
4308
4527
|
constant int64_t & ne0,
|
4309
4528
|
constant int64_t & ne1,
|
4310
|
-
constant
|
4529
|
+
constant uint64_t & nb1,
|
4311
4530
|
constant uint & r2,
|
4312
4531
|
constant uint & r3,
|
4313
4532
|
constant int & idx,
|
@@ -4334,7 +4553,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
4334
4553
|
kernel_mul_mv_q8_0_f32_impl(
|
4335
4554
|
src0[id],
|
4336
4555
|
(device const float *) (src1 + bid*nb11),
|
4337
|
-
|
4556
|
+
dst + bid*ne0,
|
4338
4557
|
ne00,
|
4339
4558
|
ne01,
|
4340
4559
|
ne02,
|
@@ -4353,8 +4572,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
4353
4572
|
kernel void kernel_mul_mv_id_q4_0_f32(
|
4354
4573
|
device const char * ids,
|
4355
4574
|
device const char * src1,
|
4356
|
-
device
|
4357
|
-
constant
|
4575
|
+
device float * dst,
|
4576
|
+
constant uint64_t & nbi1,
|
4358
4577
|
constant int64_t & ne00,
|
4359
4578
|
constant int64_t & ne01,
|
4360
4579
|
constant int64_t & ne02,
|
@@ -4370,7 +4589,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
4370
4589
|
constant uint64_t & nb12,
|
4371
4590
|
constant int64_t & ne0,
|
4372
4591
|
constant int64_t & ne1,
|
4373
|
-
constant
|
4592
|
+
constant uint64_t & nb1,
|
4374
4593
|
constant uint & r2,
|
4375
4594
|
constant uint & r3,
|
4376
4595
|
constant int & idx,
|
@@ -4397,7 +4616,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
4397
4616
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4398
4617
|
src0[id],
|
4399
4618
|
(device const float *) (src1 + bid*nb11),
|
4400
|
-
|
4619
|
+
dst + bid*ne0,
|
4401
4620
|
ne00,
|
4402
4621
|
ne01,
|
4403
4622
|
ne02,
|
@@ -4416,8 +4635,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
4416
4635
|
kernel void kernel_mul_mv_id_q4_1_f32(
|
4417
4636
|
device const char * ids,
|
4418
4637
|
device const char * src1,
|
4419
|
-
device
|
4420
|
-
constant
|
4638
|
+
device float * dst,
|
4639
|
+
constant uint64_t & nbi1,
|
4421
4640
|
constant int64_t & ne00,
|
4422
4641
|
constant int64_t & ne01,
|
4423
4642
|
constant int64_t & ne02,
|
@@ -4433,7 +4652,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
4433
4652
|
constant uint64_t & nb12,
|
4434
4653
|
constant int64_t & ne0,
|
4435
4654
|
constant int64_t & ne1,
|
4436
|
-
constant
|
4655
|
+
constant uint64_t & nb1,
|
4437
4656
|
constant uint & r2,
|
4438
4657
|
constant uint & r3,
|
4439
4658
|
constant int & idx,
|
@@ -4460,7 +4679,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
4460
4679
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4461
4680
|
src0[id],
|
4462
4681
|
(device const float *) (src1 + bid*nb11),
|
4463
|
-
|
4682
|
+
dst + bid*ne0,
|
4464
4683
|
ne00,
|
4465
4684
|
ne01,
|
4466
4685
|
ne02,
|
@@ -4479,8 +4698,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
4479
4698
|
kernel void kernel_mul_mv_id_q5_0_f32(
|
4480
4699
|
device const char * ids,
|
4481
4700
|
device const char * src1,
|
4482
|
-
device
|
4483
|
-
constant
|
4701
|
+
device float * dst,
|
4702
|
+
constant uint64_t & nbi1,
|
4484
4703
|
constant int64_t & ne00,
|
4485
4704
|
constant int64_t & ne01,
|
4486
4705
|
constant int64_t & ne02,
|
@@ -4496,7 +4715,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
4496
4715
|
constant uint64_t & nb12,
|
4497
4716
|
constant int64_t & ne0,
|
4498
4717
|
constant int64_t & ne1,
|
4499
|
-
constant
|
4718
|
+
constant uint64_t & nb1,
|
4500
4719
|
constant uint & r2,
|
4501
4720
|
constant uint & r3,
|
4502
4721
|
constant int & idx,
|
@@ -4523,7 +4742,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
4523
4742
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4524
4743
|
src0[id],
|
4525
4744
|
(device const float *) (src1 + bid*nb11),
|
4526
|
-
|
4745
|
+
dst + bid*ne0,
|
4527
4746
|
ne00,
|
4528
4747
|
ne01,
|
4529
4748
|
ne02,
|
@@ -4542,8 +4761,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
4542
4761
|
kernel void kernel_mul_mv_id_q5_1_f32(
|
4543
4762
|
device const char * ids,
|
4544
4763
|
device const char * src1,
|
4545
|
-
device
|
4546
|
-
constant
|
4764
|
+
device float * dst,
|
4765
|
+
constant uint64_t & nbi1,
|
4547
4766
|
constant int64_t & ne00,
|
4548
4767
|
constant int64_t & ne01,
|
4549
4768
|
constant int64_t & ne02,
|
@@ -4559,7 +4778,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
4559
4778
|
constant uint64_t & nb12,
|
4560
4779
|
constant int64_t & ne0,
|
4561
4780
|
constant int64_t & ne1,
|
4562
|
-
constant
|
4781
|
+
constant uint64_t & nb1,
|
4563
4782
|
constant uint & r2,
|
4564
4783
|
constant uint & r3,
|
4565
4784
|
constant int & idx,
|
@@ -4586,7 +4805,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
4586
4805
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
4587
4806
|
src0[id],
|
4588
4807
|
(device const float *) (src1 + bid*nb11),
|
4589
|
-
|
4808
|
+
dst + bid*ne0,
|
4590
4809
|
ne00,
|
4591
4810
|
ne01,
|
4592
4811
|
ne02,
|
@@ -4605,8 +4824,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
4605
4824
|
kernel void kernel_mul_mv_id_q2_K_f32(
|
4606
4825
|
device const char * ids,
|
4607
4826
|
device const char * src1,
|
4608
|
-
device
|
4609
|
-
constant
|
4827
|
+
device float * dst,
|
4828
|
+
constant uint64_t & nbi1,
|
4610
4829
|
constant int64_t & ne00,
|
4611
4830
|
constant int64_t & ne01,
|
4612
4831
|
constant int64_t & ne02,
|
@@ -4622,7 +4841,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
4622
4841
|
constant uint64_t & nb12,
|
4623
4842
|
constant int64_t & ne0,
|
4624
4843
|
constant int64_t & ne1,
|
4625
|
-
constant
|
4844
|
+
constant uint64_t & nb1,
|
4626
4845
|
constant uint & r2,
|
4627
4846
|
constant uint & r3,
|
4628
4847
|
constant int & idx,
|
@@ -4649,7 +4868,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
4649
4868
|
kernel_mul_mv_q2_K_f32_impl(
|
4650
4869
|
src0[id],
|
4651
4870
|
(device const float *) (src1 + bid*nb11),
|
4652
|
-
|
4871
|
+
dst + bid*ne0,
|
4653
4872
|
ne00,
|
4654
4873
|
ne01,
|
4655
4874
|
ne02,
|
@@ -4668,8 +4887,8 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
4668
4887
|
kernel void kernel_mul_mv_id_q3_K_f32(
|
4669
4888
|
device const char * ids,
|
4670
4889
|
device const char * src1,
|
4671
|
-
device
|
4672
|
-
constant
|
4890
|
+
device float * dst,
|
4891
|
+
constant uint64_t & nbi1,
|
4673
4892
|
constant int64_t & ne00,
|
4674
4893
|
constant int64_t & ne01,
|
4675
4894
|
constant int64_t & ne02,
|
@@ -4685,7 +4904,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
4685
4904
|
constant uint64_t & nb12,
|
4686
4905
|
constant int64_t & ne0,
|
4687
4906
|
constant int64_t & ne1,
|
4688
|
-
constant
|
4907
|
+
constant uint64_t & nb1,
|
4689
4908
|
constant uint & r2,
|
4690
4909
|
constant uint & r3,
|
4691
4910
|
constant int & idx,
|
@@ -4712,7 +4931,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
4712
4931
|
kernel_mul_mv_q3_K_f32_impl(
|
4713
4932
|
src0[id],
|
4714
4933
|
(device const float *) (src1 + bid*nb11),
|
4715
|
-
|
4934
|
+
dst + bid*ne0,
|
4716
4935
|
ne00,
|
4717
4936
|
ne01,
|
4718
4937
|
ne02,
|
@@ -4731,8 +4950,8 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
4731
4950
|
kernel void kernel_mul_mv_id_q4_K_f32(
|
4732
4951
|
device const char * ids,
|
4733
4952
|
device const char * src1,
|
4734
|
-
device
|
4735
|
-
constant
|
4953
|
+
device float * dst,
|
4954
|
+
constant uint64_t & nbi1,
|
4736
4955
|
constant int64_t & ne00,
|
4737
4956
|
constant int64_t & ne01,
|
4738
4957
|
constant int64_t & ne02,
|
@@ -4748,7 +4967,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
4748
4967
|
constant uint64_t & nb12,
|
4749
4968
|
constant int64_t & ne0,
|
4750
4969
|
constant int64_t & ne1,
|
4751
|
-
constant
|
4970
|
+
constant uint64_t & nb1,
|
4752
4971
|
constant uint & r2,
|
4753
4972
|
constant uint & r3,
|
4754
4973
|
constant int & idx,
|
@@ -4775,7 +4994,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
4775
4994
|
kernel_mul_mv_q4_K_f32_impl(
|
4776
4995
|
src0[id],
|
4777
4996
|
(device const float *) (src1 + bid*nb11),
|
4778
|
-
|
4997
|
+
dst + bid*ne0,
|
4779
4998
|
ne00,
|
4780
4999
|
ne01,
|
4781
5000
|
ne02,
|
@@ -4794,8 +5013,8 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
4794
5013
|
kernel void kernel_mul_mv_id_q5_K_f32(
|
4795
5014
|
device const char * ids,
|
4796
5015
|
device const char * src1,
|
4797
|
-
device
|
4798
|
-
constant
|
5016
|
+
device float * dst,
|
5017
|
+
constant uint64_t & nbi1,
|
4799
5018
|
constant int64_t & ne00,
|
4800
5019
|
constant int64_t & ne01,
|
4801
5020
|
constant int64_t & ne02,
|
@@ -4811,7 +5030,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
4811
5030
|
constant uint64_t & nb12,
|
4812
5031
|
constant int64_t & ne0,
|
4813
5032
|
constant int64_t & ne1,
|
4814
|
-
constant
|
5033
|
+
constant uint64_t & nb1,
|
4815
5034
|
constant uint & r2,
|
4816
5035
|
constant uint & r3,
|
4817
5036
|
constant int & idx,
|
@@ -4838,7 +5057,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
4838
5057
|
kernel_mul_mv_q5_K_f32_impl(
|
4839
5058
|
src0[id],
|
4840
5059
|
(device const float *) (src1 + bid*nb11),
|
4841
|
-
|
5060
|
+
dst + bid*ne0,
|
4842
5061
|
ne00,
|
4843
5062
|
ne01,
|
4844
5063
|
ne02,
|
@@ -4857,8 +5076,8 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
4857
5076
|
kernel void kernel_mul_mv_id_q6_K_f32(
|
4858
5077
|
device const char * ids,
|
4859
5078
|
device const char * src1,
|
4860
|
-
device
|
4861
|
-
constant
|
5079
|
+
device float * dst,
|
5080
|
+
constant uint64_t & nbi1,
|
4862
5081
|
constant int64_t & ne00,
|
4863
5082
|
constant int64_t & ne01,
|
4864
5083
|
constant int64_t & ne02,
|
@@ -4874,7 +5093,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
4874
5093
|
constant uint64_t & nb12,
|
4875
5094
|
constant int64_t & ne0,
|
4876
5095
|
constant int64_t & ne1,
|
4877
|
-
constant
|
5096
|
+
constant uint64_t & nb1,
|
4878
5097
|
constant uint & r2,
|
4879
5098
|
constant uint & r3,
|
4880
5099
|
constant int & idx,
|
@@ -4901,7 +5120,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
4901
5120
|
kernel_mul_mv_q6_K_f32_impl(
|
4902
5121
|
src0[id],
|
4903
5122
|
(device const float *) (src1 + bid*nb11),
|
4904
|
-
|
5123
|
+
dst + bid*ne0,
|
4905
5124
|
ne00,
|
4906
5125
|
ne01,
|
4907
5126
|
ne02,
|