node-llama-cpp 2.8.3 → 2.8.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/llama/binariesGithubRelease.json +1 -1
- package/llama/gitRelease.bundle +0 -0
- package/llamaBins/linux-arm64/llama-addon.node +0 -0
- package/llamaBins/linux-armv7l/llama-addon.node +0 -0
- package/llamaBins/linux-x64/llama-addon.node +0 -0
- package/llamaBins/mac-arm64/ggml-metal.metal +1243 -340
- package/llamaBins/mac-arm64/llama-addon.node +0 -0
- package/llamaBins/mac-x64/ggml-metal.metal +1243 -340
- package/llamaBins/mac-x64/llama-addon.node +0 -0
- package/llamaBins/win-x64/llama-addon.node +0 -0
- package/package.json +1 -1
|
@@ -59,26 +59,26 @@ kernel void kernel_add(
|
|
|
59
59
|
constant int64_t & ne01,
|
|
60
60
|
constant int64_t & ne02,
|
|
61
61
|
constant int64_t & ne03,
|
|
62
|
-
constant
|
|
63
|
-
constant
|
|
64
|
-
constant
|
|
65
|
-
constant
|
|
62
|
+
constant uint64_t & nb00,
|
|
63
|
+
constant uint64_t & nb01,
|
|
64
|
+
constant uint64_t & nb02,
|
|
65
|
+
constant uint64_t & nb03,
|
|
66
66
|
constant int64_t & ne10,
|
|
67
67
|
constant int64_t & ne11,
|
|
68
68
|
constant int64_t & ne12,
|
|
69
69
|
constant int64_t & ne13,
|
|
70
|
-
constant
|
|
71
|
-
constant
|
|
72
|
-
constant
|
|
73
|
-
constant
|
|
70
|
+
constant uint64_t & nb10,
|
|
71
|
+
constant uint64_t & nb11,
|
|
72
|
+
constant uint64_t & nb12,
|
|
73
|
+
constant uint64_t & nb13,
|
|
74
74
|
constant int64_t & ne0,
|
|
75
75
|
constant int64_t & ne1,
|
|
76
76
|
constant int64_t & ne2,
|
|
77
77
|
constant int64_t & ne3,
|
|
78
|
-
constant
|
|
79
|
-
constant
|
|
80
|
-
constant
|
|
81
|
-
constant
|
|
78
|
+
constant uint64_t & nb0,
|
|
79
|
+
constant uint64_t & nb1,
|
|
80
|
+
constant uint64_t & nb2,
|
|
81
|
+
constant uint64_t & nb3,
|
|
82
82
|
constant int64_t & offs,
|
|
83
83
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
84
84
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -109,26 +109,26 @@ kernel void kernel_mul(
|
|
|
109
109
|
constant int64_t & ne01,
|
|
110
110
|
constant int64_t & ne02,
|
|
111
111
|
constant int64_t & ne03,
|
|
112
|
-
constant
|
|
113
|
-
constant
|
|
114
|
-
constant
|
|
115
|
-
constant
|
|
112
|
+
constant uint64_t & nb00,
|
|
113
|
+
constant uint64_t & nb01,
|
|
114
|
+
constant uint64_t & nb02,
|
|
115
|
+
constant uint64_t & nb03,
|
|
116
116
|
constant int64_t & ne10,
|
|
117
117
|
constant int64_t & ne11,
|
|
118
118
|
constant int64_t & ne12,
|
|
119
119
|
constant int64_t & ne13,
|
|
120
|
-
constant
|
|
121
|
-
constant
|
|
122
|
-
constant
|
|
123
|
-
constant
|
|
120
|
+
constant uint64_t & nb10,
|
|
121
|
+
constant uint64_t & nb11,
|
|
122
|
+
constant uint64_t & nb12,
|
|
123
|
+
constant uint64_t & nb13,
|
|
124
124
|
constant int64_t & ne0,
|
|
125
125
|
constant int64_t & ne1,
|
|
126
126
|
constant int64_t & ne2,
|
|
127
127
|
constant int64_t & ne3,
|
|
128
|
-
constant
|
|
129
|
-
constant
|
|
130
|
-
constant
|
|
131
|
-
constant
|
|
128
|
+
constant uint64_t & nb0,
|
|
129
|
+
constant uint64_t & nb1,
|
|
130
|
+
constant uint64_t & nb2,
|
|
131
|
+
constant uint64_t & nb3,
|
|
132
132
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
133
133
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
134
134
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -158,26 +158,26 @@ kernel void kernel_div(
|
|
|
158
158
|
constant int64_t & ne01,
|
|
159
159
|
constant int64_t & ne02,
|
|
160
160
|
constant int64_t & ne03,
|
|
161
|
-
constant
|
|
162
|
-
constant
|
|
163
|
-
constant
|
|
164
|
-
constant
|
|
161
|
+
constant uint64_t & nb00,
|
|
162
|
+
constant uint64_t & nb01,
|
|
163
|
+
constant uint64_t & nb02,
|
|
164
|
+
constant uint64_t & nb03,
|
|
165
165
|
constant int64_t & ne10,
|
|
166
166
|
constant int64_t & ne11,
|
|
167
167
|
constant int64_t & ne12,
|
|
168
168
|
constant int64_t & ne13,
|
|
169
|
-
constant
|
|
170
|
-
constant
|
|
171
|
-
constant
|
|
172
|
-
constant
|
|
169
|
+
constant uint64_t & nb10,
|
|
170
|
+
constant uint64_t & nb11,
|
|
171
|
+
constant uint64_t & nb12,
|
|
172
|
+
constant uint64_t & nb13,
|
|
173
173
|
constant int64_t & ne0,
|
|
174
174
|
constant int64_t & ne1,
|
|
175
175
|
constant int64_t & ne2,
|
|
176
176
|
constant int64_t & ne3,
|
|
177
|
-
constant
|
|
178
|
-
constant
|
|
179
|
-
constant
|
|
180
|
-
constant
|
|
177
|
+
constant uint64_t & nb0,
|
|
178
|
+
constant uint64_t & nb1,
|
|
179
|
+
constant uint64_t & nb2,
|
|
180
|
+
constant uint64_t & nb3,
|
|
181
181
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
182
182
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
183
183
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
|
|
|
205
205
|
device const float4 * src0,
|
|
206
206
|
device const float4 * src1,
|
|
207
207
|
device float4 * dst,
|
|
208
|
-
constant
|
|
208
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
209
209
|
uint tpig[[thread_position_in_grid]]) {
|
|
210
210
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
|
211
211
|
}
|
|
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
|
|
|
214
214
|
device const float4 * src0,
|
|
215
215
|
device const float4 * src1,
|
|
216
216
|
device float4 * dst,
|
|
217
|
-
constant
|
|
217
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
218
218
|
uint tpig[[thread_position_in_grid]]) {
|
|
219
219
|
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
|
220
220
|
}
|
|
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
|
|
|
223
223
|
device const float4 * src0,
|
|
224
224
|
device const float4 * src1,
|
|
225
225
|
device float4 * dst,
|
|
226
|
-
constant
|
|
226
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
227
227
|
uint tpig[[thread_position_in_grid]]) {
|
|
228
228
|
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
|
229
229
|
}
|
|
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
|
|
|
307
307
|
constant int64_t & ne01,
|
|
308
308
|
constant int64_t & ne02,
|
|
309
309
|
constant int64_t & ne03,
|
|
310
|
-
constant
|
|
311
|
-
constant
|
|
312
|
-
constant
|
|
313
|
-
constant
|
|
310
|
+
constant uint64_t & nb00,
|
|
311
|
+
constant uint64_t & nb01,
|
|
312
|
+
constant uint64_t & nb02,
|
|
313
|
+
constant uint64_t & nb03,
|
|
314
314
|
constant int64_t & ne10,
|
|
315
315
|
constant int64_t & ne11,
|
|
316
316
|
constant int64_t & ne12,
|
|
317
317
|
constant int64_t & ne13,
|
|
318
|
-
constant
|
|
319
|
-
constant
|
|
320
|
-
constant
|
|
321
|
-
constant
|
|
318
|
+
constant uint64_t & nb10,
|
|
319
|
+
constant uint64_t & nb11,
|
|
320
|
+
constant uint64_t & nb12,
|
|
321
|
+
constant uint64_t & nb13,
|
|
322
322
|
constant int64_t & ne0,
|
|
323
323
|
constant int64_t & ne1,
|
|
324
324
|
constant int64_t & ne2,
|
|
325
325
|
constant int64_t & ne3,
|
|
326
|
-
constant
|
|
327
|
-
constant
|
|
328
|
-
constant
|
|
329
|
-
constant
|
|
326
|
+
constant uint64_t & nb0,
|
|
327
|
+
constant uint64_t & nb1,
|
|
328
|
+
constant uint64_t & nb2,
|
|
329
|
+
constant uint64_t & nb3,
|
|
330
330
|
uint3 tpig[[thread_position_in_grid]]) {
|
|
331
331
|
int64_t i3 = tpig.z;
|
|
332
332
|
int64_t i2 = tpig.y;
|
|
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
|
846
846
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
|
847
847
|
//Note: This is a template, but strictly speaking it only applies to
|
|
848
848
|
// quantizations where the block size is 32. It also does not
|
|
849
|
-
//
|
|
849
|
+
// guard against the number of rows not being divisible by
|
|
850
850
|
// N_DST, so this is another explicit assumption of the implementation.
|
|
851
851
|
template<typename block_q_type, int nr, int nsg, int nw>
|
|
852
852
|
void mul_vec_q_n_f32_impl(
|
|
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
|
920
920
|
device const float * src1,
|
|
921
921
|
device float * dst,
|
|
922
922
|
constant int64_t & ne00,
|
|
923
|
-
constant int64_t & ne01
|
|
924
|
-
constant int64_t & ne02
|
|
925
|
-
constant
|
|
926
|
-
constant
|
|
927
|
-
constant
|
|
928
|
-
constant int64_t &
|
|
929
|
-
constant
|
|
930
|
-
constant
|
|
923
|
+
constant int64_t & ne01,
|
|
924
|
+
constant int64_t & ne02,
|
|
925
|
+
constant uint64_t & nb00,
|
|
926
|
+
constant uint64_t & nb01,
|
|
927
|
+
constant uint64_t & nb02,
|
|
928
|
+
constant int64_t & ne10,
|
|
929
|
+
constant int64_t & ne11,
|
|
930
|
+
constant int64_t & ne12,
|
|
931
|
+
constant uint64_t & nb10,
|
|
932
|
+
constant uint64_t & nb11,
|
|
933
|
+
constant uint64_t & nb12,
|
|
934
|
+
constant int64_t & ne0,
|
|
935
|
+
constant int64_t & ne1,
|
|
936
|
+
constant uint & r2,
|
|
937
|
+
constant uint & r3,
|
|
931
938
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
932
939
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
933
940
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
|
939
946
|
device const float * src1,
|
|
940
947
|
device float * dst,
|
|
941
948
|
constant int64_t & ne00,
|
|
942
|
-
constant int64_t & ne01
|
|
943
|
-
constant int64_t & ne02
|
|
944
|
-
constant
|
|
945
|
-
constant
|
|
946
|
-
constant
|
|
947
|
-
constant int64_t &
|
|
948
|
-
constant
|
|
949
|
-
constant
|
|
949
|
+
constant int64_t & ne01,
|
|
950
|
+
constant int64_t & ne02,
|
|
951
|
+
constant uint64_t & nb00,
|
|
952
|
+
constant uint64_t & nb01,
|
|
953
|
+
constant uint64_t & nb02,
|
|
954
|
+
constant int64_t & ne10,
|
|
955
|
+
constant int64_t & ne11,
|
|
956
|
+
constant int64_t & ne12,
|
|
957
|
+
constant uint64_t & nb10,
|
|
958
|
+
constant uint64_t & nb11,
|
|
959
|
+
constant uint64_t & nb12,
|
|
960
|
+
constant int64_t & ne0,
|
|
961
|
+
constant int64_t & ne1,
|
|
962
|
+
constant uint & r2,
|
|
963
|
+
constant uint & r3,
|
|
950
964
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
951
965
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
952
966
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
|
958
972
|
device const float * src1,
|
|
959
973
|
device float * dst,
|
|
960
974
|
constant int64_t & ne00,
|
|
961
|
-
constant int64_t & ne01
|
|
962
|
-
constant int64_t & ne02
|
|
963
|
-
constant
|
|
964
|
-
constant
|
|
965
|
-
constant
|
|
966
|
-
constant int64_t &
|
|
967
|
-
constant
|
|
968
|
-
constant
|
|
975
|
+
constant int64_t & ne01,
|
|
976
|
+
constant int64_t & ne02,
|
|
977
|
+
constant uint64_t & nb00,
|
|
978
|
+
constant uint64_t & nb01,
|
|
979
|
+
constant uint64_t & nb02,
|
|
980
|
+
constant int64_t & ne10,
|
|
981
|
+
constant int64_t & ne11,
|
|
982
|
+
constant int64_t & ne12,
|
|
983
|
+
constant uint64_t & nb10,
|
|
984
|
+
constant uint64_t & nb11,
|
|
985
|
+
constant uint64_t & nb12,
|
|
986
|
+
constant int64_t & ne0,
|
|
987
|
+
constant int64_t & ne1,
|
|
988
|
+
constant uint & r2,
|
|
989
|
+
constant uint & r3,
|
|
969
990
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
970
991
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
971
992
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
|
977
998
|
device const float * src1,
|
|
978
999
|
device float * dst,
|
|
979
1000
|
constant int64_t & ne00,
|
|
980
|
-
constant int64_t & ne01
|
|
981
|
-
constant int64_t & ne02
|
|
982
|
-
constant
|
|
983
|
-
constant
|
|
984
|
-
constant
|
|
985
|
-
constant int64_t &
|
|
986
|
-
constant
|
|
987
|
-
constant
|
|
1001
|
+
constant int64_t & ne01,
|
|
1002
|
+
constant int64_t & ne02,
|
|
1003
|
+
constant uint64_t & nb00,
|
|
1004
|
+
constant uint64_t & nb01,
|
|
1005
|
+
constant uint64_t & nb02,
|
|
1006
|
+
constant int64_t & ne10,
|
|
1007
|
+
constant int64_t & ne11,
|
|
1008
|
+
constant int64_t & ne12,
|
|
1009
|
+
constant uint64_t & nb10,
|
|
1010
|
+
constant uint64_t & nb11,
|
|
1011
|
+
constant uint64_t & nb12,
|
|
1012
|
+
constant int64_t & ne0,
|
|
1013
|
+
constant int64_t & ne1,
|
|
1014
|
+
constant uint & r2,
|
|
1015
|
+
constant uint & r3,
|
|
988
1016
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
989
1017
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
990
1018
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
1071
1099
|
constant int64_t & ne00,
|
|
1072
1100
|
constant int64_t & ne01,
|
|
1073
1101
|
constant int64_t & ne02,
|
|
1102
|
+
constant uint64_t & nb00,
|
|
1103
|
+
constant uint64_t & nb01,
|
|
1104
|
+
constant uint64_t & nb02,
|
|
1074
1105
|
constant int64_t & ne10,
|
|
1106
|
+
constant int64_t & ne11,
|
|
1075
1107
|
constant int64_t & ne12,
|
|
1108
|
+
constant uint64_t & nb10,
|
|
1109
|
+
constant uint64_t & nb11,
|
|
1110
|
+
constant uint64_t & nb12,
|
|
1076
1111
|
constant int64_t & ne0,
|
|
1077
1112
|
constant int64_t & ne1,
|
|
1078
|
-
constant uint & r2
|
|
1079
|
-
constant uint & r3
|
|
1113
|
+
constant uint & r2,
|
|
1114
|
+
constant uint & r3,
|
|
1080
1115
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1081
1116
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1082
1117
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
|
|
|
1182
1217
|
constant uint64_t & nb12,
|
|
1183
1218
|
constant int64_t & ne0,
|
|
1184
1219
|
constant int64_t & ne1,
|
|
1185
|
-
constant uint & r2
|
|
1186
|
-
constant uint & r3
|
|
1220
|
+
constant uint & r2,
|
|
1221
|
+
constant uint & r3,
|
|
1187
1222
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1188
1223
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1189
1224
|
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
|
|
|
1209
1244
|
constant uint64_t & nb12,
|
|
1210
1245
|
constant int64_t & ne0,
|
|
1211
1246
|
constant int64_t & ne1,
|
|
1212
|
-
constant uint & r2
|
|
1213
|
-
constant uint & r3
|
|
1247
|
+
constant uint & r2,
|
|
1248
|
+
constant uint & r3,
|
|
1214
1249
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1215
1250
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1216
1251
|
|
|
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
|
|
|
1346
1381
|
constant uint64_t & nb12,
|
|
1347
1382
|
constant int64_t & ne0,
|
|
1348
1383
|
constant int64_t & ne1,
|
|
1349
|
-
constant uint & r2
|
|
1350
|
-
constant uint & r3
|
|
1384
|
+
constant uint & r2,
|
|
1385
|
+
constant uint & r3,
|
|
1351
1386
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1352
1387
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1353
1388
|
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
|
|
|
1452
1487
|
constant uint64_t & nb12,
|
|
1453
1488
|
constant int64_t & ne0,
|
|
1454
1489
|
constant int64_t & ne1,
|
|
1455
|
-
constant uint & r2
|
|
1456
|
-
constant uint & r3
|
|
1490
|
+
constant uint & r2,
|
|
1491
|
+
constant uint & r3,
|
|
1457
1492
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1458
1493
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1459
1494
|
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
|
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
|
1478
1513
|
constant uint64_t & nb12,
|
|
1479
1514
|
constant int64_t & ne0,
|
|
1480
1515
|
constant int64_t & ne1,
|
|
1481
|
-
constant uint & r2
|
|
1482
|
-
constant uint & r3
|
|
1516
|
+
constant uint & r2,
|
|
1517
|
+
constant uint & r3,
|
|
1483
1518
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1484
1519
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
|
1485
1520
|
|
|
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
|
|
|
1543
1578
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
|
1544
1579
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
|
1545
1580
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
|
1546
|
-
|
|
1581
|
+
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
|
1582
|
+
|
|
1547
1583
|
const int64_t k = i3*ne3 + i2;
|
|
1548
1584
|
|
|
1549
1585
|
float m_k;
|
|
@@ -2410,21 +2446,18 @@ typedef struct {
|
|
|
2410
2446
|
} block_q6_K;
|
|
2411
2447
|
// 210 bytes / block
|
|
2412
2448
|
|
|
2413
|
-
|
|
2414
|
-
|
|
2415
|
-
|
|
2416
|
-
|
|
2417
|
-
|
|
2418
|
-
|
|
2419
|
-
|
|
2420
|
-
|
|
2421
|
-
|
|
2422
|
-
|
|
2423
|
-
|
|
2424
|
-
|
|
2425
|
-
}
|
|
2426
|
-
return r;
|
|
2427
|
-
}
|
|
2449
|
+
typedef struct {
|
|
2450
|
+
half d;
|
|
2451
|
+
uint16_t qs[QK_K/8];
|
|
2452
|
+
} block_iq2_xxs;
|
|
2453
|
+
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
|
|
2454
|
+
|
|
2455
|
+
typedef struct {
|
|
2456
|
+
half d;
|
|
2457
|
+
uint16_t qs[QK_K/8];
|
|
2458
|
+
uint8_t scales[QK_K/32];
|
|
2459
|
+
} block_iq2_xs;
|
|
2460
|
+
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
|
2428
2461
|
|
|
2429
2462
|
//====================================== dot products =========================
|
|
2430
2463
|
|
|
@@ -2584,14 +2617,21 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
2584
2617
|
device const float * src1,
|
|
2585
2618
|
device float * dst,
|
|
2586
2619
|
constant int64_t & ne00,
|
|
2587
|
-
constant int64_t & ne01
|
|
2588
|
-
constant int64_t & ne02
|
|
2589
|
-
constant
|
|
2590
|
-
constant
|
|
2591
|
-
constant
|
|
2592
|
-
constant int64_t &
|
|
2593
|
-
constant
|
|
2594
|
-
constant
|
|
2620
|
+
constant int64_t & ne01,
|
|
2621
|
+
constant int64_t & ne02,
|
|
2622
|
+
constant uint64_t & nb00,
|
|
2623
|
+
constant uint64_t & nb01,
|
|
2624
|
+
constant uint64_t & nb02,
|
|
2625
|
+
constant int64_t & ne10,
|
|
2626
|
+
constant int64_t & ne11,
|
|
2627
|
+
constant int64_t & ne12,
|
|
2628
|
+
constant uint64_t & nb10,
|
|
2629
|
+
constant uint64_t & nb11,
|
|
2630
|
+
constant uint64_t & nb12,
|
|
2631
|
+
constant int64_t & ne0,
|
|
2632
|
+
constant int64_t & ne1,
|
|
2633
|
+
constant uint & r2,
|
|
2634
|
+
constant uint & r3,
|
|
2595
2635
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2596
2636
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2597
2637
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2841,14 +2881,21 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
2841
2881
|
device const float * src1,
|
|
2842
2882
|
device float * dst,
|
|
2843
2883
|
constant int64_t & ne00,
|
|
2844
|
-
constant int64_t & ne01
|
|
2845
|
-
constant int64_t & ne02
|
|
2846
|
-
constant
|
|
2847
|
-
constant
|
|
2848
|
-
constant
|
|
2849
|
-
constant int64_t &
|
|
2850
|
-
constant
|
|
2851
|
-
constant
|
|
2884
|
+
constant int64_t & ne01,
|
|
2885
|
+
constant int64_t & ne02,
|
|
2886
|
+
constant uint64_t & nb00,
|
|
2887
|
+
constant uint64_t & nb01,
|
|
2888
|
+
constant uint64_t & nb02,
|
|
2889
|
+
constant int64_t & ne10,
|
|
2890
|
+
constant int64_t & ne11,
|
|
2891
|
+
constant int64_t & ne12,
|
|
2892
|
+
constant uint64_t & nb10,
|
|
2893
|
+
constant uint64_t & nb11,
|
|
2894
|
+
constant uint64_t & nb12,
|
|
2895
|
+
constant int64_t & ne0,
|
|
2896
|
+
constant int64_t & ne1,
|
|
2897
|
+
constant uint & r2,
|
|
2898
|
+
constant uint & r3,
|
|
2852
2899
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2853
2900
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
2854
2901
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -2984,8 +3031,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
2984
3031
|
constant uint & r2,
|
|
2985
3032
|
constant uint & r3,
|
|
2986
3033
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2987
|
-
uint
|
|
2988
|
-
uint
|
|
3034
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3035
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
2989
3036
|
|
|
2990
3037
|
const int ix = tiisg/4; // 0...7
|
|
2991
3038
|
const int it = tiisg%4; // 0...3
|
|
@@ -2994,7 +3041,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
2994
3041
|
const int r0 = tgpig.x;
|
|
2995
3042
|
const int r1 = tgpig.y;
|
|
2996
3043
|
const int im = tgpig.z;
|
|
2997
|
-
const int first_row =
|
|
3044
|
+
const int first_row = r0 * N_DST;
|
|
2998
3045
|
const int ib_row = first_row * nb;
|
|
2999
3046
|
|
|
3000
3047
|
const uint i12 = im%ne12;
|
|
@@ -3060,7 +3107,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
3060
3107
|
for (int row = 0; row < N_DST; ++row) {
|
|
3061
3108
|
all_sum = simd_sum(sumf[row]);
|
|
3062
3109
|
if (tiisg == 0) {
|
|
3063
|
-
dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
|
|
3110
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
3064
3111
|
}
|
|
3065
3112
|
}
|
|
3066
3113
|
}
|
|
@@ -3072,14 +3119,21 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
3072
3119
|
device const float * src1,
|
|
3073
3120
|
device float * dst,
|
|
3074
3121
|
constant int64_t & ne00,
|
|
3075
|
-
constant int64_t & ne01
|
|
3076
|
-
constant int64_t & ne02
|
|
3077
|
-
constant
|
|
3078
|
-
constant
|
|
3079
|
-
constant
|
|
3080
|
-
constant int64_t &
|
|
3081
|
-
constant
|
|
3082
|
-
constant
|
|
3122
|
+
constant int64_t & ne01,
|
|
3123
|
+
constant int64_t & ne02,
|
|
3124
|
+
constant uint64_t & nb00,
|
|
3125
|
+
constant uint64_t & nb01,
|
|
3126
|
+
constant uint64_t & nb02,
|
|
3127
|
+
constant int64_t & ne10,
|
|
3128
|
+
constant int64_t & ne11,
|
|
3129
|
+
constant int64_t & ne12,
|
|
3130
|
+
constant uint64_t & nb10,
|
|
3131
|
+
constant uint64_t & nb11,
|
|
3132
|
+
constant uint64_t & nb12,
|
|
3133
|
+
constant int64_t & ne0,
|
|
3134
|
+
constant int64_t & ne1,
|
|
3135
|
+
constant uint & r2,
|
|
3136
|
+
constant uint & r3,
|
|
3083
3137
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3084
3138
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
3085
3139
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3271,14 +3325,21 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
3271
3325
|
device const float * src1,
|
|
3272
3326
|
device float * dst,
|
|
3273
3327
|
constant int64_t & ne00,
|
|
3274
|
-
constant int64_t & ne01
|
|
3275
|
-
constant int64_t & ne02
|
|
3276
|
-
constant
|
|
3277
|
-
constant
|
|
3278
|
-
constant
|
|
3279
|
-
constant int64_t &
|
|
3280
|
-
constant
|
|
3281
|
-
constant
|
|
3328
|
+
constant int64_t & ne01,
|
|
3329
|
+
constant int64_t & ne02,
|
|
3330
|
+
constant uint64_t & nb00,
|
|
3331
|
+
constant uint64_t & nb01,
|
|
3332
|
+
constant uint64_t & nb02,
|
|
3333
|
+
constant int64_t & ne10,
|
|
3334
|
+
constant int64_t & ne11,
|
|
3335
|
+
constant int64_t & ne12,
|
|
3336
|
+
constant uint64_t & nb10,
|
|
3337
|
+
constant uint64_t & nb11,
|
|
3338
|
+
constant uint64_t & nb12,
|
|
3339
|
+
constant int64_t & ne0,
|
|
3340
|
+
constant int64_t & ne1,
|
|
3341
|
+
constant uint & r2,
|
|
3342
|
+
constant uint & r3,
|
|
3282
3343
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3283
3344
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
3284
3345
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3398,14 +3459,21 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
3398
3459
|
device const float * src1,
|
|
3399
3460
|
device float * dst,
|
|
3400
3461
|
constant int64_t & ne00,
|
|
3401
|
-
constant int64_t & ne01
|
|
3402
|
-
constant int64_t & ne02
|
|
3403
|
-
constant
|
|
3404
|
-
constant
|
|
3405
|
-
constant
|
|
3406
|
-
constant int64_t &
|
|
3407
|
-
constant
|
|
3408
|
-
constant
|
|
3462
|
+
constant int64_t & ne01,
|
|
3463
|
+
constant int64_t & ne02,
|
|
3464
|
+
constant uint64_t & nb00,
|
|
3465
|
+
constant uint64_t & nb01,
|
|
3466
|
+
constant uint64_t & nb02,
|
|
3467
|
+
constant int64_t & ne10,
|
|
3468
|
+
constant int64_t & ne11,
|
|
3469
|
+
constant int64_t & ne12,
|
|
3470
|
+
constant uint64_t & nb10,
|
|
3471
|
+
constant uint64_t & nb11,
|
|
3472
|
+
constant uint64_t & nb12,
|
|
3473
|
+
constant int64_t & ne0,
|
|
3474
|
+
constant int64_t & ne1,
|
|
3475
|
+
constant uint & r2,
|
|
3476
|
+
constant uint & r3,
|
|
3409
3477
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3410
3478
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
3411
3479
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
@@ -3413,53 +3481,542 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
3413
3481
|
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
3414
3482
|
}
|
|
3415
3483
|
|
|
3416
|
-
|
|
3484
|
+
// ======================= "True" 2-bit
|
|
3485
|
+
|
|
3486
|
+
constexpr constant static uint64_t iq2xxs_grid[256] = {
|
|
3487
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
|
3488
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
|
|
3489
|
+
0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
|
|
3490
|
+
0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
|
|
3491
|
+
0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
|
|
3492
|
+
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
|
|
3493
|
+
0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
|
|
3494
|
+
0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
|
|
3495
|
+
0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
|
|
3496
|
+
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
|
|
3497
|
+
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
|
|
3498
|
+
0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
|
|
3499
|
+
0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
|
|
3500
|
+
0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
|
|
3501
|
+
0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
|
|
3502
|
+
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
|
|
3503
|
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
|
|
3504
|
+
0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
|
|
3505
|
+
0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
|
|
3506
|
+
0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
|
|
3507
|
+
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
|
|
3508
|
+
0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
|
|
3509
|
+
0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
|
|
3510
|
+
0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
|
|
3511
|
+
0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
|
|
3512
|
+
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
|
|
3513
|
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
|
|
3514
|
+
0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
|
|
3515
|
+
0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
|
|
3516
|
+
0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
|
|
3517
|
+
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
|
|
3518
|
+
0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
|
|
3519
|
+
0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
|
|
3520
|
+
0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
|
|
3521
|
+
0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
|
|
3522
|
+
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
|
|
3523
|
+
0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
|
|
3524
|
+
0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
|
|
3525
|
+
0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
|
|
3526
|
+
0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
|
|
3527
|
+
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
|
|
3528
|
+
0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
|
|
3529
|
+
0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
|
|
3530
|
+
0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
|
|
3531
|
+
0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
|
|
3532
|
+
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
|
|
3533
|
+
0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
|
|
3534
|
+
0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
|
|
3535
|
+
0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
|
|
3536
|
+
0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
|
|
3537
|
+
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
|
|
3538
|
+
0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
|
|
3539
|
+
0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
|
|
3540
|
+
0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
|
|
3541
|
+
0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
|
|
3542
|
+
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
|
|
3543
|
+
0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
|
|
3544
|
+
0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
|
|
3545
|
+
0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
|
|
3546
|
+
0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
|
|
3547
|
+
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
|
|
3548
|
+
0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
|
|
3549
|
+
0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
|
|
3550
|
+
0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
|
|
3551
|
+
};
|
|
3417
3552
|
|
|
3418
|
-
|
|
3419
|
-
|
|
3420
|
-
|
|
3421
|
-
|
|
3422
|
-
|
|
3423
|
-
|
|
3424
|
-
|
|
3425
|
-
|
|
3553
|
+
constexpr constant static uint64_t iq2xs_grid[512] = {
|
|
3554
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
|
3555
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
|
3556
|
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
|
3557
|
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
|
3558
|
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
|
3559
|
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
|
|
3560
|
+
0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
|
|
3561
|
+
0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
|
|
3562
|
+
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
|
|
3563
|
+
0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
|
|
3564
|
+
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
|
|
3565
|
+
0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
|
|
3566
|
+
0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
|
|
3567
|
+
0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
|
|
3568
|
+
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
|
|
3569
|
+
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
|
|
3570
|
+
0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
|
|
3571
|
+
0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
|
|
3572
|
+
0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
|
|
3573
|
+
0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
|
|
3574
|
+
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
|
|
3575
|
+
0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
|
|
3576
|
+
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
|
|
3577
|
+
0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
|
|
3578
|
+
0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
|
|
3579
|
+
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
|
|
3580
|
+
0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
|
|
3581
|
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
|
|
3582
|
+
0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
|
|
3583
|
+
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
|
|
3584
|
+
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
|
|
3585
|
+
0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
|
|
3586
|
+
0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
|
|
3587
|
+
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
|
|
3588
|
+
0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
|
|
3589
|
+
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
|
|
3590
|
+
0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
|
|
3591
|
+
0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
|
|
3592
|
+
0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
|
|
3593
|
+
0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
|
|
3594
|
+
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
|
|
3595
|
+
0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
|
|
3596
|
+
0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
|
|
3597
|
+
0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
|
|
3598
|
+
0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
|
|
3599
|
+
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
|
|
3600
|
+
0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
|
|
3601
|
+
0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
|
|
3602
|
+
0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
|
|
3603
|
+
0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
|
|
3604
|
+
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
|
|
3605
|
+
0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
|
|
3606
|
+
0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
|
|
3607
|
+
0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
|
|
3608
|
+
0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
|
|
3609
|
+
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
|
|
3610
|
+
0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
|
|
3611
|
+
0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
|
|
3612
|
+
0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
|
|
3613
|
+
0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
|
|
3614
|
+
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
|
|
3615
|
+
0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
|
|
3616
|
+
0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
|
|
3617
|
+
0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
|
|
3618
|
+
0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
|
|
3619
|
+
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
|
|
3620
|
+
0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
|
|
3621
|
+
0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
|
|
3622
|
+
0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
|
|
3623
|
+
0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
|
|
3624
|
+
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
|
|
3625
|
+
0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
|
|
3626
|
+
0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
|
|
3627
|
+
0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
|
|
3628
|
+
0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
|
|
3629
|
+
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
|
|
3630
|
+
0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
|
|
3631
|
+
0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
|
|
3632
|
+
0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
|
|
3633
|
+
0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
|
|
3634
|
+
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
|
|
3635
|
+
0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
|
|
3636
|
+
0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
|
|
3637
|
+
0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
|
|
3638
|
+
0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
|
|
3639
|
+
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
|
|
3640
|
+
0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
|
|
3641
|
+
0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
|
|
3642
|
+
0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
|
|
3643
|
+
0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
|
|
3644
|
+
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
|
|
3645
|
+
0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
|
|
3646
|
+
0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
|
|
3647
|
+
0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
|
|
3648
|
+
0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
|
|
3649
|
+
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
|
|
3650
|
+
0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
|
|
3651
|
+
0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
|
|
3652
|
+
0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
|
|
3653
|
+
0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
|
|
3654
|
+
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
|
|
3655
|
+
0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
|
|
3656
|
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
|
|
3657
|
+
0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
|
|
3658
|
+
0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
|
|
3659
|
+
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
|
|
3660
|
+
0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
|
|
3661
|
+
0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
|
|
3662
|
+
0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
|
|
3663
|
+
0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
|
|
3664
|
+
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
|
|
3665
|
+
0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
|
|
3666
|
+
0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
|
|
3667
|
+
0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
|
|
3668
|
+
0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
|
|
3669
|
+
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
|
|
3670
|
+
0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
|
|
3671
|
+
0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
|
|
3672
|
+
0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
|
|
3673
|
+
0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
|
|
3674
|
+
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
|
|
3675
|
+
0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
|
|
3676
|
+
0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
|
|
3677
|
+
0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
|
|
3678
|
+
0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
|
|
3679
|
+
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
|
|
3680
|
+
0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
|
|
3681
|
+
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
|
3682
|
+
};
|
|
3426
3683
|
|
|
3427
|
-
|
|
3428
|
-
|
|
3429
|
-
|
|
3430
|
-
|
|
3431
|
-
|
|
3432
|
-
|
|
3433
|
-
|
|
3684
|
+
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
|
3685
|
+
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
|
3686
|
+
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
3687
|
+
160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
|
|
3688
|
+
48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
|
|
3689
|
+
192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
|
|
3690
|
+
80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
|
|
3691
|
+
96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
|
|
3692
|
+
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
|
|
3693
|
+
};
|
|
3434
3694
|
|
|
3435
|
-
|
|
3436
|
-
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
|
3437
|
-
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
3438
|
-
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
|
3439
|
-
const float d2 = d1 / 256.f;
|
|
3440
|
-
const float md = -8.h * xb->d;
|
|
3441
|
-
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
3442
|
-
const ushort mask1 = mask0 << 8;
|
|
3695
|
+
constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
|
3443
3696
|
|
|
3444
|
-
|
|
3445
|
-
|
|
3446
|
-
|
|
3447
|
-
|
|
3448
|
-
|
|
3697
|
+
void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
3698
|
+
device const void * src0,
|
|
3699
|
+
device const float * src1,
|
|
3700
|
+
device float * dst,
|
|
3701
|
+
constant int64_t & ne00,
|
|
3702
|
+
constant int64_t & ne01,
|
|
3703
|
+
constant int64_t & ne02,
|
|
3704
|
+
constant int64_t & ne10,
|
|
3705
|
+
constant int64_t & ne12,
|
|
3706
|
+
constant int64_t & ne0,
|
|
3707
|
+
constant int64_t & ne1,
|
|
3708
|
+
constant uint & r2,
|
|
3709
|
+
constant uint & r3,
|
|
3710
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3711
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3712
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3713
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3449
3714
|
|
|
3450
|
-
|
|
3451
|
-
|
|
3452
|
-
|
|
3453
|
-
const
|
|
3454
|
-
const float d2 = d1 / 256.f;
|
|
3455
|
-
const float m = xb->m;
|
|
3456
|
-
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
3457
|
-
const ushort mask1 = mask0 << 8;
|
|
3715
|
+
const int nb = ne00/QK_K;
|
|
3716
|
+
const int r0 = tgpig.x;
|
|
3717
|
+
const int r1 = tgpig.y;
|
|
3718
|
+
const int im = tgpig.z;
|
|
3458
3719
|
|
|
3459
|
-
|
|
3460
|
-
|
|
3461
|
-
|
|
3462
|
-
|
|
3720
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
3721
|
+
const int ib_row = first_row * nb;
|
|
3722
|
+
|
|
3723
|
+
const uint i12 = im%ne12;
|
|
3724
|
+
const uint i13 = im/ne12;
|
|
3725
|
+
|
|
3726
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
3727
|
+
|
|
3728
|
+
device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
|
|
3729
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
3730
|
+
|
|
3731
|
+
float yl[32];
|
|
3732
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
3733
|
+
|
|
3734
|
+
const int nb32 = nb * (QK_K / 32);
|
|
3735
|
+
|
|
3736
|
+
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
|
3737
|
+
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
|
|
3738
|
+
{
|
|
3739
|
+
int nval = 4;
|
|
3740
|
+
int pos = (32*sgitg + tiisg)*nval;
|
|
3741
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
|
|
3742
|
+
nval = 2;
|
|
3743
|
+
pos = (32*sgitg + tiisg)*nval;
|
|
3744
|
+
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
3745
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3746
|
+
}
|
|
3747
|
+
|
|
3748
|
+
#if QK_K == 256
|
|
3749
|
+
const int ix = tiisg;
|
|
3750
|
+
|
|
3751
|
+
device const float * y4 = y + 32 * ix;
|
|
3752
|
+
|
|
3753
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
3754
|
+
|
|
3755
|
+
for (int i = 0; i < 32; ++i) {
|
|
3756
|
+
yl[i] = y4[i];
|
|
3757
|
+
}
|
|
3758
|
+
|
|
3759
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
3760
|
+
const int ib = ib32 % (QK_K / 32);
|
|
3761
|
+
|
|
3762
|
+
device const block_iq2_xxs * xr = x + ibl;
|
|
3763
|
+
device const uint16_t * q2 = xr->qs + 4 * ib;
|
|
3764
|
+
device const half * dh = &xr->d;
|
|
3765
|
+
|
|
3766
|
+
for (int row = 0; row < N_DST; row++) {
|
|
3767
|
+
|
|
3768
|
+
const float db = dh[0];
|
|
3769
|
+
device const uint8_t * aux8 = (device const uint8_t *)q2;
|
|
3770
|
+
const uint32_t aux32 = q2[2] | (q2[3] << 16);
|
|
3771
|
+
const float d = db * (0.5f + (aux32 >> 28));
|
|
3772
|
+
|
|
3773
|
+
float sum = 0;
|
|
3774
|
+
for (int l = 0; l < 4; ++l) {
|
|
3775
|
+
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
|
|
3776
|
+
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
|
|
3777
|
+
for (int j = 0; j < 8; ++j) {
|
|
3778
|
+
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
3779
|
+
}
|
|
3780
|
+
}
|
|
3781
|
+
sumf[row] += d * sum;
|
|
3782
|
+
|
|
3783
|
+
dh += nb*sizeof(block_iq2_xxs)/2;
|
|
3784
|
+
q2 += nb*sizeof(block_iq2_xxs)/2;
|
|
3785
|
+
}
|
|
3786
|
+
|
|
3787
|
+
y4 += 32 * 32;
|
|
3788
|
+
}
|
|
3789
|
+
#else
|
|
3790
|
+
// TODO
|
|
3791
|
+
#endif
|
|
3792
|
+
|
|
3793
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
3794
|
+
all_sum = simd_sum(sumf[row]);
|
|
3795
|
+
if (tiisg == 0) {
|
|
3796
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
|
3797
|
+
}
|
|
3798
|
+
}
|
|
3799
|
+
}
|
|
3800
|
+
|
|
3801
|
+
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
|
|
3802
|
+
kernel void kernel_mul_mv_iq2_xxs_f32(
|
|
3803
|
+
device const void * src0,
|
|
3804
|
+
device const float * src1,
|
|
3805
|
+
device float * dst,
|
|
3806
|
+
constant int64_t & ne00,
|
|
3807
|
+
constant int64_t & ne01,
|
|
3808
|
+
constant int64_t & ne02,
|
|
3809
|
+
constant uint64_t & nb00,
|
|
3810
|
+
constant uint64_t & nb01,
|
|
3811
|
+
constant uint64_t & nb02,
|
|
3812
|
+
constant int64_t & ne10,
|
|
3813
|
+
constant int64_t & ne11,
|
|
3814
|
+
constant int64_t & ne12,
|
|
3815
|
+
constant uint64_t & nb10,
|
|
3816
|
+
constant uint64_t & nb11,
|
|
3817
|
+
constant uint64_t & nb12,
|
|
3818
|
+
constant int64_t & ne0,
|
|
3819
|
+
constant int64_t & ne1,
|
|
3820
|
+
constant uint & r2,
|
|
3821
|
+
constant uint & r3,
|
|
3822
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3823
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3824
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3825
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3826
|
+
|
|
3827
|
+
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
3828
|
+
}
|
|
3829
|
+
|
|
3830
|
+
void kernel_mul_mv_iq2_xs_f32_impl(
|
|
3831
|
+
device const void * src0,
|
|
3832
|
+
device const float * src1,
|
|
3833
|
+
device float * dst,
|
|
3834
|
+
constant int64_t & ne00,
|
|
3835
|
+
constant int64_t & ne01,
|
|
3836
|
+
constant int64_t & ne02,
|
|
3837
|
+
constant int64_t & ne10,
|
|
3838
|
+
constant int64_t & ne12,
|
|
3839
|
+
constant int64_t & ne0,
|
|
3840
|
+
constant int64_t & ne1,
|
|
3841
|
+
constant uint & r2,
|
|
3842
|
+
constant uint & r3,
|
|
3843
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3844
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3845
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3846
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3847
|
+
|
|
3848
|
+
const int nb = ne00/QK_K;
|
|
3849
|
+
const int r0 = tgpig.x;
|
|
3850
|
+
const int r1 = tgpig.y;
|
|
3851
|
+
const int im = tgpig.z;
|
|
3852
|
+
|
|
3853
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
3854
|
+
const int ib_row = first_row * nb;
|
|
3855
|
+
|
|
3856
|
+
const uint i12 = im%ne12;
|
|
3857
|
+
const uint i13 = im/ne12;
|
|
3858
|
+
|
|
3859
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
3860
|
+
|
|
3861
|
+
device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
|
|
3862
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
3863
|
+
|
|
3864
|
+
float yl[32];
|
|
3865
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
3866
|
+
|
|
3867
|
+
const int nb32 = nb * (QK_K / 32);
|
|
3868
|
+
|
|
3869
|
+
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
|
3870
|
+
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
|
|
3871
|
+
{
|
|
3872
|
+
int nval = 8;
|
|
3873
|
+
int pos = (32*sgitg + tiisg)*nval;
|
|
3874
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
|
|
3875
|
+
nval = 2;
|
|
3876
|
+
pos = (32*sgitg + tiisg)*nval;
|
|
3877
|
+
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
3878
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3879
|
+
}
|
|
3880
|
+
|
|
3881
|
+
#if QK_K == 256
|
|
3882
|
+
const int ix = tiisg;
|
|
3883
|
+
|
|
3884
|
+
device const float * y4 = y + 32 * ix;
|
|
3885
|
+
|
|
3886
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
3887
|
+
|
|
3888
|
+
for (int i = 0; i < 32; ++i) {
|
|
3889
|
+
yl[i] = y4[i];
|
|
3890
|
+
}
|
|
3891
|
+
|
|
3892
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
3893
|
+
const int ib = ib32 % (QK_K / 32);
|
|
3894
|
+
|
|
3895
|
+
device const block_iq2_xs * xr = x + ibl;
|
|
3896
|
+
device const uint16_t * q2 = xr->qs + 4 * ib;
|
|
3897
|
+
device const uint8_t * sc = xr->scales + ib;
|
|
3898
|
+
device const half * dh = &xr->d;
|
|
3899
|
+
|
|
3900
|
+
for (int row = 0; row < N_DST; row++) {
|
|
3901
|
+
|
|
3902
|
+
const float db = dh[0];
|
|
3903
|
+
const uint8_t ls1 = sc[0] & 0xf;
|
|
3904
|
+
const uint8_t ls2 = sc[0] >> 4;
|
|
3905
|
+
const float d1 = db * (0.5f + ls1);
|
|
3906
|
+
const float d2 = db * (0.5f + ls2);
|
|
3907
|
+
|
|
3908
|
+
float sum1 = 0, sum2 = 0;
|
|
3909
|
+
for (int l = 0; l < 2; ++l) {
|
|
3910
|
+
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
|
|
3911
|
+
const uint8_t signs = shared_signs[(q2[l] >> 9)];
|
|
3912
|
+
for (int j = 0; j < 8; ++j) {
|
|
3913
|
+
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
3914
|
+
}
|
|
3915
|
+
}
|
|
3916
|
+
for (int l = 2; l < 4; ++l) {
|
|
3917
|
+
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
|
|
3918
|
+
const uint8_t signs = shared_signs[(q2[l] >> 9)];
|
|
3919
|
+
for (int j = 0; j < 8; ++j) {
|
|
3920
|
+
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
3921
|
+
}
|
|
3922
|
+
}
|
|
3923
|
+
sumf[row] += d1 * sum1 + d2 * sum2;
|
|
3924
|
+
|
|
3925
|
+
dh += nb*sizeof(block_iq2_xs)/2;
|
|
3926
|
+
q2 += nb*sizeof(block_iq2_xs)/2;
|
|
3927
|
+
sc += nb*sizeof(block_iq2_xs);
|
|
3928
|
+
}
|
|
3929
|
+
|
|
3930
|
+
y4 += 32 * 32;
|
|
3931
|
+
}
|
|
3932
|
+
#else
|
|
3933
|
+
// TODO
|
|
3934
|
+
#endif
|
|
3935
|
+
|
|
3936
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
3937
|
+
all_sum = simd_sum(sumf[row]);
|
|
3938
|
+
if (tiisg == 0) {
|
|
3939
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
|
3940
|
+
}
|
|
3941
|
+
}
|
|
3942
|
+
}
|
|
3943
|
+
|
|
3944
|
+
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
|
|
3945
|
+
kernel void kernel_mul_mv_iq2_xs_f32(
|
|
3946
|
+
device const void * src0,
|
|
3947
|
+
device const float * src1,
|
|
3948
|
+
device float * dst,
|
|
3949
|
+
constant int64_t & ne00,
|
|
3950
|
+
constant int64_t & ne01,
|
|
3951
|
+
constant int64_t & ne02,
|
|
3952
|
+
constant uint64_t & nb00,
|
|
3953
|
+
constant uint64_t & nb01,
|
|
3954
|
+
constant uint64_t & nb02,
|
|
3955
|
+
constant int64_t & ne10,
|
|
3956
|
+
constant int64_t & ne11,
|
|
3957
|
+
constant int64_t & ne12,
|
|
3958
|
+
constant uint64_t & nb10,
|
|
3959
|
+
constant uint64_t & nb11,
|
|
3960
|
+
constant uint64_t & nb12,
|
|
3961
|
+
constant int64_t & ne0,
|
|
3962
|
+
constant int64_t & ne1,
|
|
3963
|
+
constant uint & r2,
|
|
3964
|
+
constant uint & r3,
|
|
3965
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
3966
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3967
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
3968
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3969
|
+
|
|
3970
|
+
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
3971
|
+
}
|
|
3972
|
+
|
|
3973
|
+
//============================= templates and their specializations =============================
|
|
3974
|
+
|
|
3975
|
+
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
3976
|
+
template <typename type4x4>
|
|
3977
|
+
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
3978
|
+
float4x4 temp = *(((device float4x4 *)src));
|
|
3979
|
+
for (int i = 0; i < 16; i++){
|
|
3980
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
|
3981
|
+
}
|
|
3982
|
+
}
|
|
3983
|
+
|
|
3984
|
+
template <typename type4x4>
|
|
3985
|
+
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
|
3986
|
+
half4x4 temp = *(((device half4x4 *)src));
|
|
3987
|
+
for (int i = 0; i < 16; i++){
|
|
3988
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
|
3989
|
+
}
|
|
3990
|
+
}
|
|
3991
|
+
|
|
3992
|
+
template <typename type4x4>
|
|
3993
|
+
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
|
3994
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
3995
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
|
3996
|
+
const float d2 = d1 / 256.f;
|
|
3997
|
+
const float md = -8.h * xb->d;
|
|
3998
|
+
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
3999
|
+
const ushort mask1 = mask0 << 8;
|
|
4000
|
+
|
|
4001
|
+
for (int i=0;i<8;i++) {
|
|
4002
|
+
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
|
4003
|
+
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
|
4004
|
+
}
|
|
4005
|
+
}
|
|
4006
|
+
|
|
4007
|
+
template <typename type4x4>
|
|
4008
|
+
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
|
4009
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
|
4010
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
|
4011
|
+
const float d2 = d1 / 256.f;
|
|
4012
|
+
const float m = xb->m;
|
|
4013
|
+
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
4014
|
+
const ushort mask1 = mask0 << 8;
|
|
4015
|
+
|
|
4016
|
+
for (int i=0;i<8;i++) {
|
|
4017
|
+
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
|
|
4018
|
+
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
|
|
4019
|
+
}
|
|
3463
4020
|
}
|
|
3464
4021
|
|
|
3465
4022
|
template <typename type4x4>
|
|
@@ -3523,7 +4080,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
|
3523
4080
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
|
3524
4081
|
const half d = xb->d;
|
|
3525
4082
|
|
|
3526
|
-
for (int i=0;i<16;i++) {
|
|
4083
|
+
for (int i = 0; i < 16; i++) {
|
|
3527
4084
|
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
|
3528
4085
|
}
|
|
3529
4086
|
}
|
|
@@ -3565,8 +4122,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
|
3565
4122
|
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
|
3566
4123
|
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
|
3567
4124
|
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
|
3568
|
-
|
|
3569
|
-
const
|
|
4125
|
+
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
|
4126
|
+
const float ml = 4.f * dl;
|
|
3570
4127
|
|
|
3571
4128
|
il = (il/2) & 3;
|
|
3572
4129
|
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
@@ -3633,7 +4190,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
|
3633
4190
|
uint8_t ul = 1 << (il/2);
|
|
3634
4191
|
il = il & 3;
|
|
3635
4192
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
3636
|
-
const float d = il < 2 ? xb->d : xb->d / 16.
|
|
4193
|
+
const float d = il < 2 ? xb->d : xb->d / 16.f;
|
|
3637
4194
|
const float min = xb->dmin;
|
|
3638
4195
|
const float dl = d * sc[0];
|
|
3639
4196
|
const float ml = min * sc[1];
|
|
@@ -3666,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
|
3666
4223
|
#if QK_K == 256
|
|
3667
4224
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
|
3668
4225
|
qh = qh + 32*(il/8) + 16*(il&1);
|
|
3669
|
-
|
|
4226
|
+
float sc = scales[(il%2) + 2 * ((il/2))];
|
|
3670
4227
|
il = (il/2) & 3;
|
|
3671
4228
|
#else
|
|
3672
4229
|
ql = ql + 16 * (il&1);
|
|
3673
|
-
|
|
4230
|
+
float sc = scales[il];
|
|
3674
4231
|
#endif
|
|
3675
4232
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
3676
4233
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
|
3677
|
-
const
|
|
3678
|
-
const
|
|
3679
|
-
const
|
|
4234
|
+
const float coef = il>1 ? 1.f/16.f : 1.f;
|
|
4235
|
+
const float ml = d_all * sc * 32.f;
|
|
4236
|
+
const float dl = d_all * sc * coef;
|
|
3680
4237
|
for (int i = 0; i < 16; ++i) {
|
|
3681
4238
|
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
|
3682
4239
|
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
|
@@ -3684,6 +4241,52 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
|
3684
4241
|
}
|
|
3685
4242
|
}
|
|
3686
4243
|
|
|
4244
|
+
template <typename type4x4>
|
|
4245
|
+
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
|
|
4246
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
4247
|
+
const float d = xb->d;
|
|
4248
|
+
const int ib32 = il/2;
|
|
4249
|
+
il = il%2;
|
|
4250
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
4251
|
+
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
|
|
4252
|
+
device const uint16_t * q2 = xb->qs + 4*ib32;
|
|
4253
|
+
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
|
|
4254
|
+
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
|
|
4255
|
+
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
|
|
4256
|
+
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
|
|
4257
|
+
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
|
4258
|
+
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
|
|
4259
|
+
for (int i = 0; i < 8; ++i) {
|
|
4260
|
+
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
4261
|
+
}
|
|
4262
|
+
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
|
4263
|
+
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
|
|
4264
|
+
for (int i = 0; i < 8; ++i) {
|
|
4265
|
+
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
4266
|
+
}
|
|
4267
|
+
}
|
|
4268
|
+
|
|
4269
|
+
template <typename type4x4>
|
|
4270
|
+
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
|
|
4271
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
4272
|
+
const float d = xb->d;
|
|
4273
|
+
const int ib32 = il/2;
|
|
4274
|
+
il = il%2;
|
|
4275
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
4276
|
+
device const uint16_t * q2 = xb->qs + 4*ib32;
|
|
4277
|
+
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
|
|
4278
|
+
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
|
|
4279
|
+
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
|
|
4280
|
+
for (int i = 0; i < 8; ++i) {
|
|
4281
|
+
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
4282
|
+
}
|
|
4283
|
+
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
|
|
4284
|
+
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
|
|
4285
|
+
for (int i = 0; i < 8; ++i) {
|
|
4286
|
+
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
4287
|
+
}
|
|
4288
|
+
}
|
|
4289
|
+
|
|
3687
4290
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
3688
4291
|
kernel void kernel_get_rows(
|
|
3689
4292
|
device const void * src0,
|
|
@@ -3764,48 +4367,212 @@ kernel void kernel_get_rows_f16(
|
|
|
3764
4367
|
const int64_t i10 = tgpig.x;
|
|
3765
4368
|
const int64_t i11 = tgpig.y;
|
|
3766
4369
|
|
|
3767
|
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
4370
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
4371
|
+
|
|
4372
|
+
const int64_t i02 = i11;
|
|
4373
|
+
|
|
4374
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
4375
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
4376
|
+
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
4377
|
+
}
|
|
4378
|
+
}
|
|
4379
|
+
|
|
4380
|
+
kernel void kernel_get_rows_i32(
|
|
4381
|
+
device const void * src0,
|
|
4382
|
+
device const char * src1,
|
|
4383
|
+
device int32_t * dst,
|
|
4384
|
+
constant int64_t & ne00,
|
|
4385
|
+
constant uint64_t & nb01,
|
|
4386
|
+
constant uint64_t & nb02,
|
|
4387
|
+
constant int64_t & ne10,
|
|
4388
|
+
constant uint64_t & nb10,
|
|
4389
|
+
constant uint64_t & nb11,
|
|
4390
|
+
constant uint64_t & nb1,
|
|
4391
|
+
constant uint64_t & nb2,
|
|
4392
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4393
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
4394
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
4395
|
+
const int64_t i10 = tgpig.x;
|
|
4396
|
+
const int64_t i11 = tgpig.y;
|
|
4397
|
+
|
|
4398
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
4399
|
+
|
|
4400
|
+
const int64_t i02 = i11;
|
|
4401
|
+
|
|
4402
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
4403
|
+
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
4404
|
+
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
4405
|
+
}
|
|
4406
|
+
}
|
|
4407
|
+
|
|
4408
|
+
|
|
4409
|
+
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
|
4410
|
+
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
4411
|
+
#define BLOCK_SIZE_K 32
|
|
4412
|
+
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
|
4413
|
+
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
|
4414
|
+
#define THREAD_PER_BLOCK 128
|
|
4415
|
+
#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
|
|
4416
|
+
#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
|
|
4417
|
+
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
|
|
4418
|
+
#define SG_MAT_ROW 8
|
|
4419
|
+
|
|
4420
|
+
// each block_q contains 16*nl weights
|
|
4421
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
4422
|
+
void kernel_mul_mm_impl(device const uchar * src0,
|
|
4423
|
+
device const uchar * src1,
|
|
4424
|
+
device float * dst,
|
|
4425
|
+
constant int64_t & ne00,
|
|
4426
|
+
constant int64_t & ne02,
|
|
4427
|
+
constant uint64_t & nb01,
|
|
4428
|
+
constant uint64_t & nb02,
|
|
4429
|
+
constant int64_t & ne12,
|
|
4430
|
+
constant uint64_t & nb10,
|
|
4431
|
+
constant uint64_t & nb11,
|
|
4432
|
+
constant uint64_t & nb12,
|
|
4433
|
+
constant int64_t & ne0,
|
|
4434
|
+
constant int64_t & ne1,
|
|
4435
|
+
constant uint & r2,
|
|
4436
|
+
constant uint & r3,
|
|
4437
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
4438
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4439
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
4440
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4441
|
+
|
|
4442
|
+
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
|
4443
|
+
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
|
4444
|
+
|
|
4445
|
+
const uint r0 = tgpig.y;
|
|
4446
|
+
const uint r1 = tgpig.x;
|
|
4447
|
+
const uint im = tgpig.z;
|
|
4448
|
+
|
|
4449
|
+
// if this block is of 64x32 shape or smaller
|
|
4450
|
+
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
|
4451
|
+
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
|
4452
|
+
|
|
4453
|
+
// a thread shouldn't load data outside of the matrix
|
|
4454
|
+
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
|
4455
|
+
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
|
4456
|
+
|
|
4457
|
+
simdgroup_half8x8 ma[4];
|
|
4458
|
+
simdgroup_float8x8 mb[2];
|
|
4459
|
+
simdgroup_float8x8 c_res[8];
|
|
4460
|
+
for (int i = 0; i < 8; i++){
|
|
4461
|
+
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
|
4462
|
+
}
|
|
4463
|
+
|
|
4464
|
+
short il = (tiitg % THREAD_PER_ROW);
|
|
4465
|
+
|
|
4466
|
+
const uint i12 = im%ne12;
|
|
4467
|
+
const uint i13 = im/ne12;
|
|
4468
|
+
|
|
4469
|
+
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
|
4470
|
+
ushort offset1 = il/nl;
|
|
4471
|
+
|
|
4472
|
+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
|
4473
|
+
device const float * y = (device const float *)(src1
|
|
4474
|
+
+ nb12 * im
|
|
4475
|
+
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
|
4476
|
+
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
|
4477
|
+
|
|
4478
|
+
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
|
4479
|
+
// load data and store to threadgroup memory
|
|
4480
|
+
half4x4 temp_a;
|
|
4481
|
+
dequantize_func(x, il, temp_a);
|
|
4482
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4483
|
+
|
|
4484
|
+
#pragma unroll(16)
|
|
4485
|
+
for (int i = 0; i < 16; i++) {
|
|
4486
|
+
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
|
4487
|
+
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
|
4488
|
+
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
|
4489
|
+
}
|
|
4490
|
+
|
|
4491
|
+
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
|
4492
|
+
|
|
4493
|
+
il = (il + 2 < nl) ? il + 2 : il % 2;
|
|
4494
|
+
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
|
4495
|
+
y += BLOCK_SIZE_K;
|
|
4496
|
+
|
|
4497
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4498
|
+
|
|
4499
|
+
// load matrices from threadgroup memory and conduct outer products
|
|
4500
|
+
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
|
4501
|
+
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
|
4502
|
+
|
|
4503
|
+
#pragma unroll(4)
|
|
4504
|
+
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
|
4505
|
+
#pragma unroll(4)
|
|
4506
|
+
for (int i = 0; i < 4; i++) {
|
|
4507
|
+
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
|
4508
|
+
}
|
|
4509
|
+
simdgroup_barrier(mem_flags::mem_none);
|
|
4510
|
+
#pragma unroll(2)
|
|
4511
|
+
for (int i = 0; i < 2; i++) {
|
|
4512
|
+
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
|
4513
|
+
}
|
|
4514
|
+
|
|
4515
|
+
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
|
4516
|
+
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
|
3768
4517
|
|
|
3769
|
-
|
|
4518
|
+
#pragma unroll(8)
|
|
4519
|
+
for (int i = 0; i < 8; i++){
|
|
4520
|
+
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
|
4521
|
+
}
|
|
4522
|
+
}
|
|
4523
|
+
}
|
|
3770
4524
|
|
|
3771
|
-
|
|
3772
|
-
|
|
3773
|
-
|
|
4525
|
+
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
|
4526
|
+
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
|
4527
|
+
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
|
4528
|
+
for (int i = 0; i < 8; i++) {
|
|
4529
|
+
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
|
4530
|
+
}
|
|
4531
|
+
} else {
|
|
4532
|
+
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
4533
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4534
|
+
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
|
4535
|
+
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
|
4536
|
+
for (int i = 0; i < 8; i++) {
|
|
4537
|
+
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
|
4538
|
+
}
|
|
4539
|
+
|
|
4540
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4541
|
+
|
|
4542
|
+
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
|
4543
|
+
if (sgitg == 0) {
|
|
4544
|
+
for (int i = 0; i < n_rows; i++) {
|
|
4545
|
+
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
|
4546
|
+
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
|
4547
|
+
}
|
|
4548
|
+
}
|
|
4549
|
+
}
|
|
3774
4550
|
}
|
|
3775
4551
|
}
|
|
3776
4552
|
|
|
3777
|
-
|
|
3778
|
-
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
3779
|
-
#define BLOCK_SIZE_K 32
|
|
3780
|
-
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
|
3781
|
-
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
|
3782
|
-
#define THREAD_PER_BLOCK 128
|
|
3783
|
-
#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
|
|
3784
|
-
#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
|
|
3785
|
-
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
|
|
3786
|
-
#define SG_MAT_ROW 8
|
|
3787
|
-
|
|
3788
|
-
// each block_q contains 16*nl weights
|
|
4553
|
+
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
|
|
3789
4554
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
3790
|
-
void
|
|
3791
|
-
|
|
3792
|
-
|
|
3793
|
-
|
|
3794
|
-
|
|
3795
|
-
|
|
3796
|
-
|
|
3797
|
-
|
|
3798
|
-
|
|
3799
|
-
|
|
3800
|
-
|
|
3801
|
-
|
|
3802
|
-
|
|
3803
|
-
|
|
3804
|
-
|
|
3805
|
-
|
|
3806
|
-
|
|
3807
|
-
|
|
3808
|
-
|
|
4555
|
+
void kernel_mul_mm_id_impl(
|
|
4556
|
+
device const uchar * src0,
|
|
4557
|
+
device const uchar * src1,
|
|
4558
|
+
thread short * src1ids,
|
|
4559
|
+
device float * dst,
|
|
4560
|
+
constant int64_t & ne00,
|
|
4561
|
+
constant int64_t & ne02,
|
|
4562
|
+
constant uint64_t & nb01,
|
|
4563
|
+
constant uint64_t & nb02,
|
|
4564
|
+
constant int64_t & ne12,
|
|
4565
|
+
constant uint64_t & nb10,
|
|
4566
|
+
constant uint64_t & nb11,
|
|
4567
|
+
constant uint64_t & nb12,
|
|
4568
|
+
constant int64_t & ne0,
|
|
4569
|
+
int64_t ne1,
|
|
4570
|
+
constant uint & r2,
|
|
4571
|
+
constant uint & r3,
|
|
4572
|
+
threadgroup uchar * shared_memory,
|
|
4573
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4574
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
4575
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3809
4576
|
|
|
3810
4577
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
|
3811
4578
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
|
@@ -3814,6 +4581,8 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3814
4581
|
const uint r1 = tgpig.x;
|
|
3815
4582
|
const uint im = tgpig.z;
|
|
3816
4583
|
|
|
4584
|
+
if (r1 * BLOCK_SIZE_N >= ne1) return;
|
|
4585
|
+
|
|
3817
4586
|
// if this block is of 64x32 shape or smaller
|
|
3818
4587
|
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
|
3819
4588
|
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
|
@@ -3840,7 +4609,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3840
4609
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
|
3841
4610
|
device const float * y = (device const float *)(src1
|
|
3842
4611
|
+ nb12 * im
|
|
3843
|
-
+ nb11 *
|
|
4612
|
+
+ nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
|
|
3844
4613
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
|
3845
4614
|
|
|
3846
4615
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
|
@@ -3849,7 +4618,6 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3849
4618
|
dequantize_func(x, il, temp_a);
|
|
3850
4619
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3851
4620
|
|
|
3852
|
-
#pragma unroll(16)
|
|
3853
4621
|
for (int i = 0; i < 16; i++) {
|
|
3854
4622
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
|
3855
4623
|
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
|
@@ -3868,14 +4636,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3868
4636
|
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
|
3869
4637
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
|
3870
4638
|
|
|
3871
|
-
#pragma unroll(4)
|
|
3872
4639
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
|
3873
|
-
#pragma unroll(4)
|
|
3874
4640
|
for (int i = 0; i < 4; i++) {
|
|
3875
4641
|
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
|
3876
4642
|
}
|
|
3877
4643
|
simdgroup_barrier(mem_flags::mem_none);
|
|
3878
|
-
#pragma unroll(2)
|
|
3879
4644
|
for (int i = 0; i < 2; i++) {
|
|
3880
4645
|
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
|
3881
4646
|
}
|
|
@@ -3883,21 +4648,13 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3883
4648
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
|
3884
4649
|
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
|
3885
4650
|
|
|
3886
|
-
#pragma unroll(8)
|
|
3887
4651
|
for (int i = 0; i < 8; i++){
|
|
3888
4652
|
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
|
3889
4653
|
}
|
|
3890
4654
|
}
|
|
3891
4655
|
}
|
|
3892
4656
|
|
|
3893
|
-
|
|
3894
|
-
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
|
3895
|
-
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
|
3896
|
-
for (int i = 0; i < 8; i++) {
|
|
3897
|
-
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
|
3898
|
-
}
|
|
3899
|
-
} else {
|
|
3900
|
-
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
4657
|
+
{
|
|
3901
4658
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3902
4659
|
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
|
3903
4660
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
|
@@ -3907,11 +4664,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
|
3907
4664
|
|
|
3908
4665
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3909
4666
|
|
|
3910
|
-
device float * C = dst + (BLOCK_SIZE_M * r0) +
|
|
4667
|
+
device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
|
|
3911
4668
|
if (sgitg == 0) {
|
|
3912
4669
|
for (int i = 0; i < n_rows; i++) {
|
|
3913
4670
|
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
|
3914
|
-
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
|
4671
|
+
*(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
|
3915
4672
|
}
|
|
3916
4673
|
}
|
|
3917
4674
|
}
|
|
@@ -3924,12 +4681,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
3924
4681
|
device float * dst,
|
|
3925
4682
|
constant int64_t & ne00,
|
|
3926
4683
|
constant int64_t & ne02,
|
|
3927
|
-
constant
|
|
3928
|
-
constant
|
|
4684
|
+
constant uint64_t & nb01,
|
|
4685
|
+
constant uint64_t & nb02,
|
|
3929
4686
|
constant int64_t & ne12,
|
|
3930
|
-
constant
|
|
3931
|
-
constant
|
|
3932
|
-
constant
|
|
4687
|
+
constant uint64_t & nb10,
|
|
4688
|
+
constant uint64_t & nb11,
|
|
4689
|
+
constant uint64_t & nb12,
|
|
3933
4690
|
constant int64_t & ne0,
|
|
3934
4691
|
constant int64_t & ne1,
|
|
3935
4692
|
constant uint & r2,
|
|
@@ -3964,20 +4721,20 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
|
|
|
3964
4721
|
kernel void kernel_mul_mm_id(
|
|
3965
4722
|
device const uchar * ids,
|
|
3966
4723
|
device const uchar * src1,
|
|
3967
|
-
device
|
|
3968
|
-
constant
|
|
4724
|
+
device float * dst,
|
|
4725
|
+
constant uint64_t & nbi1,
|
|
3969
4726
|
constant int64_t & ne00,
|
|
3970
4727
|
constant int64_t & ne02,
|
|
3971
|
-
constant
|
|
3972
|
-
constant
|
|
4728
|
+
constant uint64_t & nb01,
|
|
4729
|
+
constant uint64_t & nb02,
|
|
3973
4730
|
constant int64_t & ne12,
|
|
3974
4731
|
constant int64_t & ne13,
|
|
3975
|
-
constant
|
|
3976
|
-
constant
|
|
3977
|
-
constant
|
|
4732
|
+
constant uint64_t & nb10,
|
|
4733
|
+
constant uint64_t & nb11,
|
|
4734
|
+
constant uint64_t & nb12,
|
|
3978
4735
|
constant int64_t & ne0,
|
|
3979
4736
|
constant int64_t & ne1,
|
|
3980
|
-
constant
|
|
4737
|
+
constant uint64_t & nb1,
|
|
3981
4738
|
constant uint & r2,
|
|
3982
4739
|
constant uint & r3,
|
|
3983
4740
|
constant int & idx,
|
|
@@ -3993,18 +4750,28 @@ kernel void kernel_mul_mm_id(
|
|
|
3993
4750
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
3994
4751
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
3995
4752
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
3996
|
-
device const uchar *
|
|
4753
|
+
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
3997
4754
|
|
|
3998
|
-
|
|
4755
|
+
// expert id
|
|
4756
|
+
const int32_t id = tgpig.z/(ne12*ne13);
|
|
3999
4757
|
|
|
4000
4758
|
tgpig.z = tgpig.z%(ne12*ne13);
|
|
4001
4759
|
|
|
4002
|
-
|
|
4760
|
+
// row indices of src1 for expert id
|
|
4761
|
+
int64_t _ne1 = 0;
|
|
4762
|
+
short src1ids[512];
|
|
4003
4763
|
|
|
4004
|
-
|
|
4005
|
-
|
|
4006
|
-
|
|
4007
|
-
|
|
4764
|
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
4765
|
+
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
|
4766
|
+
src1ids[_ne1++] = i1;
|
|
4767
|
+
}
|
|
4768
|
+
}
|
|
4769
|
+
|
|
4770
|
+
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
|
4771
|
+
src0s[id],
|
|
4772
|
+
src1,
|
|
4773
|
+
src1ids,
|
|
4774
|
+
dst,
|
|
4008
4775
|
ne00,
|
|
4009
4776
|
ne02,
|
|
4010
4777
|
nb01,
|
|
@@ -4014,7 +4781,7 @@ kernel void kernel_mul_mm_id(
|
|
|
4014
4781
|
nb11,
|
|
4015
4782
|
nb12,
|
|
4016
4783
|
ne0,
|
|
4017
|
-
|
|
4784
|
+
_ne1,
|
|
4018
4785
|
r2,
|
|
4019
4786
|
r3,
|
|
4020
4787
|
shared_memory,
|
|
@@ -4059,6 +4826,8 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
|
|
|
4059
4826
|
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
4060
4827
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
4061
4828
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4829
|
+
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4830
|
+
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
4062
4831
|
|
|
4063
4832
|
//
|
|
4064
4833
|
// matrix-matrix multiplication
|
|
@@ -4070,12 +4839,12 @@ typedef void (mat_mm_t)(
|
|
|
4070
4839
|
device float * dst,
|
|
4071
4840
|
constant int64_t & ne00,
|
|
4072
4841
|
constant int64_t & ne02,
|
|
4073
|
-
constant
|
|
4074
|
-
constant
|
|
4842
|
+
constant uint64_t & nb01,
|
|
4843
|
+
constant uint64_t & nb02,
|
|
4075
4844
|
constant int64_t & ne12,
|
|
4076
|
-
constant
|
|
4077
|
-
constant
|
|
4078
|
-
constant
|
|
4845
|
+
constant uint64_t & nb10,
|
|
4846
|
+
constant uint64_t & nb11,
|
|
4847
|
+
constant uint64_t & nb12,
|
|
4079
4848
|
constant int64_t & ne0,
|
|
4080
4849
|
constant int64_t & ne1,
|
|
4081
4850
|
constant uint & r2,
|
|
@@ -4095,6 +4864,8 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
4095
4864
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
4096
4865
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
4097
4866
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4867
|
+
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4868
|
+
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
4098
4869
|
|
|
4099
4870
|
//
|
|
4100
4871
|
// indirect matrix-matrix multiplication
|
|
@@ -4103,20 +4874,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
4103
4874
|
typedef void (mat_mm_id_t)(
|
|
4104
4875
|
device const uchar * ids,
|
|
4105
4876
|
device const uchar * src1,
|
|
4106
|
-
device
|
|
4107
|
-
constant
|
|
4877
|
+
device float * dst,
|
|
4878
|
+
constant uint64_t & nbi1,
|
|
4108
4879
|
constant int64_t & ne00,
|
|
4109
4880
|
constant int64_t & ne02,
|
|
4110
|
-
constant
|
|
4111
|
-
constant
|
|
4881
|
+
constant uint64_t & nb01,
|
|
4882
|
+
constant uint64_t & nb02,
|
|
4112
4883
|
constant int64_t & ne12,
|
|
4113
4884
|
constant int64_t & ne13,
|
|
4114
|
-
constant
|
|
4115
|
-
constant
|
|
4116
|
-
constant
|
|
4885
|
+
constant uint64_t & nb10,
|
|
4886
|
+
constant uint64_t & nb11,
|
|
4887
|
+
constant uint64_t & nb12,
|
|
4117
4888
|
constant int64_t & ne0,
|
|
4118
4889
|
constant int64_t & ne1,
|
|
4119
|
-
constant
|
|
4890
|
+
constant uint64_t & nb1,
|
|
4120
4891
|
constant uint & r2,
|
|
4121
4892
|
constant uint & r3,
|
|
4122
4893
|
constant int & idx,
|
|
@@ -4143,6 +4914,8 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
4143
4914
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
4144
4915
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
4145
4916
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4917
|
+
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4918
|
+
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
4146
4919
|
|
|
4147
4920
|
//
|
|
4148
4921
|
// matrix-vector multiplication
|
|
@@ -4152,8 +4925,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
4152
4925
|
kernel void kernel_mul_mv_id_f32_f32(
|
|
4153
4926
|
device const char * ids,
|
|
4154
4927
|
device const char * src1,
|
|
4155
|
-
device
|
|
4156
|
-
constant
|
|
4928
|
+
device float * dst,
|
|
4929
|
+
constant uint64_t & nbi1,
|
|
4157
4930
|
constant int64_t & ne00,
|
|
4158
4931
|
constant int64_t & ne01,
|
|
4159
4932
|
constant int64_t & ne02,
|
|
@@ -4169,7 +4942,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
|
4169
4942
|
constant uint64_t & nb12,
|
|
4170
4943
|
constant int64_t & ne0,
|
|
4171
4944
|
constant int64_t & ne1,
|
|
4172
|
-
constant
|
|
4945
|
+
constant uint64_t & nb1,
|
|
4173
4946
|
constant uint & r2,
|
|
4174
4947
|
constant uint & r3,
|
|
4175
4948
|
constant int & idx,
|
|
@@ -4196,7 +4969,7 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
|
4196
4969
|
kernel_mul_mv_f32_f32_impl(
|
|
4197
4970
|
src0[id],
|
|
4198
4971
|
src1 + bid*nb11,
|
|
4199
|
-
|
|
4972
|
+
dst + bid*ne0,
|
|
4200
4973
|
ne00,
|
|
4201
4974
|
ne01,
|
|
4202
4975
|
ne02,
|
|
@@ -4221,8 +4994,8 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
|
4221
4994
|
kernel void kernel_mul_mv_id_f16_f32(
|
|
4222
4995
|
device const char * ids,
|
|
4223
4996
|
device const char * src1,
|
|
4224
|
-
device
|
|
4225
|
-
constant
|
|
4997
|
+
device float * dst,
|
|
4998
|
+
constant uint64_t & nbi1,
|
|
4226
4999
|
constant int64_t & ne00,
|
|
4227
5000
|
constant int64_t & ne01,
|
|
4228
5001
|
constant int64_t & ne02,
|
|
@@ -4238,7 +5011,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
|
4238
5011
|
constant uint64_t & nb12,
|
|
4239
5012
|
constant int64_t & ne0,
|
|
4240
5013
|
constant int64_t & ne1,
|
|
4241
|
-
constant
|
|
5014
|
+
constant uint64_t & nb1,
|
|
4242
5015
|
constant uint & r2,
|
|
4243
5016
|
constant uint & r3,
|
|
4244
5017
|
constant int & idx,
|
|
@@ -4265,7 +5038,7 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
|
4265
5038
|
kernel_mul_mv_f16_f32_impl(
|
|
4266
5039
|
src0[id],
|
|
4267
5040
|
src1 + bid*nb11,
|
|
4268
|
-
|
|
5041
|
+
dst + bid*ne0,
|
|
4269
5042
|
ne00,
|
|
4270
5043
|
ne01,
|
|
4271
5044
|
ne02,
|
|
@@ -4290,8 +5063,8 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
|
4290
5063
|
kernel void kernel_mul_mv_id_q8_0_f32(
|
|
4291
5064
|
device const char * ids,
|
|
4292
5065
|
device const char * src1,
|
|
4293
|
-
device
|
|
4294
|
-
constant
|
|
5066
|
+
device float * dst,
|
|
5067
|
+
constant uint64_t & nbi1,
|
|
4295
5068
|
constant int64_t & ne00,
|
|
4296
5069
|
constant int64_t & ne01,
|
|
4297
5070
|
constant int64_t & ne02,
|
|
@@ -4307,7 +5080,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
4307
5080
|
constant uint64_t & nb12,
|
|
4308
5081
|
constant int64_t & ne0,
|
|
4309
5082
|
constant int64_t & ne1,
|
|
4310
|
-
constant
|
|
5083
|
+
constant uint64_t & nb1,
|
|
4311
5084
|
constant uint & r2,
|
|
4312
5085
|
constant uint & r3,
|
|
4313
5086
|
constant int & idx,
|
|
@@ -4334,7 +5107,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
4334
5107
|
kernel_mul_mv_q8_0_f32_impl(
|
|
4335
5108
|
src0[id],
|
|
4336
5109
|
(device const float *) (src1 + bid*nb11),
|
|
4337
|
-
|
|
5110
|
+
dst + bid*ne0,
|
|
4338
5111
|
ne00,
|
|
4339
5112
|
ne01,
|
|
4340
5113
|
ne02,
|
|
@@ -4353,8 +5126,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
4353
5126
|
kernel void kernel_mul_mv_id_q4_0_f32(
|
|
4354
5127
|
device const char * ids,
|
|
4355
5128
|
device const char * src1,
|
|
4356
|
-
device
|
|
4357
|
-
constant
|
|
5129
|
+
device float * dst,
|
|
5130
|
+
constant uint64_t & nbi1,
|
|
4358
5131
|
constant int64_t & ne00,
|
|
4359
5132
|
constant int64_t & ne01,
|
|
4360
5133
|
constant int64_t & ne02,
|
|
@@ -4370,7 +5143,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
4370
5143
|
constant uint64_t & nb12,
|
|
4371
5144
|
constant int64_t & ne0,
|
|
4372
5145
|
constant int64_t & ne1,
|
|
4373
|
-
constant
|
|
5146
|
+
constant uint64_t & nb1,
|
|
4374
5147
|
constant uint & r2,
|
|
4375
5148
|
constant uint & r3,
|
|
4376
5149
|
constant int & idx,
|
|
@@ -4397,7 +5170,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
4397
5170
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
4398
5171
|
src0[id],
|
|
4399
5172
|
(device const float *) (src1 + bid*nb11),
|
|
4400
|
-
|
|
5173
|
+
dst + bid*ne0,
|
|
4401
5174
|
ne00,
|
|
4402
5175
|
ne01,
|
|
4403
5176
|
ne02,
|
|
@@ -4416,8 +5189,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
4416
5189
|
kernel void kernel_mul_mv_id_q4_1_f32(
|
|
4417
5190
|
device const char * ids,
|
|
4418
5191
|
device const char * src1,
|
|
4419
|
-
device
|
|
4420
|
-
constant
|
|
5192
|
+
device float * dst,
|
|
5193
|
+
constant uint64_t & nbi1,
|
|
4421
5194
|
constant int64_t & ne00,
|
|
4422
5195
|
constant int64_t & ne01,
|
|
4423
5196
|
constant int64_t & ne02,
|
|
@@ -4433,7 +5206,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
4433
5206
|
constant uint64_t & nb12,
|
|
4434
5207
|
constant int64_t & ne0,
|
|
4435
5208
|
constant int64_t & ne1,
|
|
4436
|
-
constant
|
|
5209
|
+
constant uint64_t & nb1,
|
|
4437
5210
|
constant uint & r2,
|
|
4438
5211
|
constant uint & r3,
|
|
4439
5212
|
constant int & idx,
|
|
@@ -4460,7 +5233,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
4460
5233
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
4461
5234
|
src0[id],
|
|
4462
5235
|
(device const float *) (src1 + bid*nb11),
|
|
4463
|
-
|
|
5236
|
+
dst + bid*ne0,
|
|
4464
5237
|
ne00,
|
|
4465
5238
|
ne01,
|
|
4466
5239
|
ne02,
|
|
@@ -4479,8 +5252,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
4479
5252
|
kernel void kernel_mul_mv_id_q5_0_f32(
|
|
4480
5253
|
device const char * ids,
|
|
4481
5254
|
device const char * src1,
|
|
4482
|
-
device
|
|
4483
|
-
constant
|
|
5255
|
+
device float * dst,
|
|
5256
|
+
constant uint64_t & nbi1,
|
|
4484
5257
|
constant int64_t & ne00,
|
|
4485
5258
|
constant int64_t & ne01,
|
|
4486
5259
|
constant int64_t & ne02,
|
|
@@ -4496,7 +5269,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
4496
5269
|
constant uint64_t & nb12,
|
|
4497
5270
|
constant int64_t & ne0,
|
|
4498
5271
|
constant int64_t & ne1,
|
|
4499
|
-
constant
|
|
5272
|
+
constant uint64_t & nb1,
|
|
4500
5273
|
constant uint & r2,
|
|
4501
5274
|
constant uint & r3,
|
|
4502
5275
|
constant int & idx,
|
|
@@ -4523,7 +5296,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
4523
5296
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
4524
5297
|
src0[id],
|
|
4525
5298
|
(device const float *) (src1 + bid*nb11),
|
|
4526
|
-
|
|
5299
|
+
dst + bid*ne0,
|
|
4527
5300
|
ne00,
|
|
4528
5301
|
ne01,
|
|
4529
5302
|
ne02,
|
|
@@ -4542,8 +5315,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
4542
5315
|
kernel void kernel_mul_mv_id_q5_1_f32(
|
|
4543
5316
|
device const char * ids,
|
|
4544
5317
|
device const char * src1,
|
|
4545
|
-
device
|
|
4546
|
-
constant
|
|
5318
|
+
device float * dst,
|
|
5319
|
+
constant uint64_t & nbi1,
|
|
4547
5320
|
constant int64_t & ne00,
|
|
4548
5321
|
constant int64_t & ne01,
|
|
4549
5322
|
constant int64_t & ne02,
|
|
@@ -4559,7 +5332,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
4559
5332
|
constant uint64_t & nb12,
|
|
4560
5333
|
constant int64_t & ne0,
|
|
4561
5334
|
constant int64_t & ne1,
|
|
4562
|
-
constant
|
|
5335
|
+
constant uint64_t & nb1,
|
|
4563
5336
|
constant uint & r2,
|
|
4564
5337
|
constant uint & r3,
|
|
4565
5338
|
constant int & idx,
|
|
@@ -4586,7 +5359,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
4586
5359
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
4587
5360
|
src0[id],
|
|
4588
5361
|
(device const float *) (src1 + bid*nb11),
|
|
4589
|
-
|
|
5362
|
+
dst + bid*ne0,
|
|
4590
5363
|
ne00,
|
|
4591
5364
|
ne01,
|
|
4592
5365
|
ne02,
|
|
@@ -4605,8 +5378,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
4605
5378
|
kernel void kernel_mul_mv_id_q2_K_f32(
|
|
4606
5379
|
device const char * ids,
|
|
4607
5380
|
device const char * src1,
|
|
4608
|
-
device
|
|
4609
|
-
constant
|
|
5381
|
+
device float * dst,
|
|
5382
|
+
constant uint64_t & nbi1,
|
|
4610
5383
|
constant int64_t & ne00,
|
|
4611
5384
|
constant int64_t & ne01,
|
|
4612
5385
|
constant int64_t & ne02,
|
|
@@ -4622,7 +5395,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
4622
5395
|
constant uint64_t & nb12,
|
|
4623
5396
|
constant int64_t & ne0,
|
|
4624
5397
|
constant int64_t & ne1,
|
|
4625
|
-
constant
|
|
5398
|
+
constant uint64_t & nb1,
|
|
4626
5399
|
constant uint & r2,
|
|
4627
5400
|
constant uint & r3,
|
|
4628
5401
|
constant int & idx,
|
|
@@ -4649,7 +5422,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
4649
5422
|
kernel_mul_mv_q2_K_f32_impl(
|
|
4650
5423
|
src0[id],
|
|
4651
5424
|
(device const float *) (src1 + bid*nb11),
|
|
4652
|
-
|
|
5425
|
+
dst + bid*ne0,
|
|
4653
5426
|
ne00,
|
|
4654
5427
|
ne01,
|
|
4655
5428
|
ne02,
|
|
@@ -4668,8 +5441,8 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
4668
5441
|
kernel void kernel_mul_mv_id_q3_K_f32(
|
|
4669
5442
|
device const char * ids,
|
|
4670
5443
|
device const char * src1,
|
|
4671
|
-
device
|
|
4672
|
-
constant
|
|
5444
|
+
device float * dst,
|
|
5445
|
+
constant uint64_t & nbi1,
|
|
4673
5446
|
constant int64_t & ne00,
|
|
4674
5447
|
constant int64_t & ne01,
|
|
4675
5448
|
constant int64_t & ne02,
|
|
@@ -4685,7 +5458,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
4685
5458
|
constant uint64_t & nb12,
|
|
4686
5459
|
constant int64_t & ne0,
|
|
4687
5460
|
constant int64_t & ne1,
|
|
4688
|
-
constant
|
|
5461
|
+
constant uint64_t & nb1,
|
|
4689
5462
|
constant uint & r2,
|
|
4690
5463
|
constant uint & r3,
|
|
4691
5464
|
constant int & idx,
|
|
@@ -4712,7 +5485,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
4712
5485
|
kernel_mul_mv_q3_K_f32_impl(
|
|
4713
5486
|
src0[id],
|
|
4714
5487
|
(device const float *) (src1 + bid*nb11),
|
|
4715
|
-
|
|
5488
|
+
dst + bid*ne0,
|
|
4716
5489
|
ne00,
|
|
4717
5490
|
ne01,
|
|
4718
5491
|
ne02,
|
|
@@ -4731,8 +5504,8 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
4731
5504
|
kernel void kernel_mul_mv_id_q4_K_f32(
|
|
4732
5505
|
device const char * ids,
|
|
4733
5506
|
device const char * src1,
|
|
4734
|
-
device
|
|
4735
|
-
constant
|
|
5507
|
+
device float * dst,
|
|
5508
|
+
constant uint64_t & nbi1,
|
|
4736
5509
|
constant int64_t & ne00,
|
|
4737
5510
|
constant int64_t & ne01,
|
|
4738
5511
|
constant int64_t & ne02,
|
|
@@ -4748,7 +5521,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
4748
5521
|
constant uint64_t & nb12,
|
|
4749
5522
|
constant int64_t & ne0,
|
|
4750
5523
|
constant int64_t & ne1,
|
|
4751
|
-
constant
|
|
5524
|
+
constant uint64_t & nb1,
|
|
4752
5525
|
constant uint & r2,
|
|
4753
5526
|
constant uint & r3,
|
|
4754
5527
|
constant int & idx,
|
|
@@ -4775,7 +5548,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
4775
5548
|
kernel_mul_mv_q4_K_f32_impl(
|
|
4776
5549
|
src0[id],
|
|
4777
5550
|
(device const float *) (src1 + bid*nb11),
|
|
4778
|
-
|
|
5551
|
+
dst + bid*ne0,
|
|
4779
5552
|
ne00,
|
|
4780
5553
|
ne01,
|
|
4781
5554
|
ne02,
|
|
@@ -4794,8 +5567,8 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
4794
5567
|
kernel void kernel_mul_mv_id_q5_K_f32(
|
|
4795
5568
|
device const char * ids,
|
|
4796
5569
|
device const char * src1,
|
|
4797
|
-
device
|
|
4798
|
-
constant
|
|
5570
|
+
device float * dst,
|
|
5571
|
+
constant uint64_t & nbi1,
|
|
4799
5572
|
constant int64_t & ne00,
|
|
4800
5573
|
constant int64_t & ne01,
|
|
4801
5574
|
constant int64_t & ne02,
|
|
@@ -4811,7 +5584,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
4811
5584
|
constant uint64_t & nb12,
|
|
4812
5585
|
constant int64_t & ne0,
|
|
4813
5586
|
constant int64_t & ne1,
|
|
4814
|
-
constant
|
|
5587
|
+
constant uint64_t & nb1,
|
|
4815
5588
|
constant uint & r2,
|
|
4816
5589
|
constant uint & r3,
|
|
4817
5590
|
constant int & idx,
|
|
@@ -4838,7 +5611,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
4838
5611
|
kernel_mul_mv_q5_K_f32_impl(
|
|
4839
5612
|
src0[id],
|
|
4840
5613
|
(device const float *) (src1 + bid*nb11),
|
|
4841
|
-
|
|
5614
|
+
dst + bid*ne0,
|
|
4842
5615
|
ne00,
|
|
4843
5616
|
ne01,
|
|
4844
5617
|
ne02,
|
|
@@ -4857,8 +5630,8 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
4857
5630
|
kernel void kernel_mul_mv_id_q6_K_f32(
|
|
4858
5631
|
device const char * ids,
|
|
4859
5632
|
device const char * src1,
|
|
4860
|
-
device
|
|
4861
|
-
constant
|
|
5633
|
+
device float * dst,
|
|
5634
|
+
constant uint64_t & nbi1,
|
|
4862
5635
|
constant int64_t & ne00,
|
|
4863
5636
|
constant int64_t & ne01,
|
|
4864
5637
|
constant int64_t & ne02,
|
|
@@ -4874,7 +5647,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
4874
5647
|
constant uint64_t & nb12,
|
|
4875
5648
|
constant int64_t & ne0,
|
|
4876
5649
|
constant int64_t & ne1,
|
|
4877
|
-
constant
|
|
5650
|
+
constant uint64_t & nb1,
|
|
4878
5651
|
constant uint & r2,
|
|
4879
5652
|
constant uint & r3,
|
|
4880
5653
|
constant int & idx,
|
|
@@ -4901,7 +5674,136 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
4901
5674
|
kernel_mul_mv_q6_K_f32_impl(
|
|
4902
5675
|
src0[id],
|
|
4903
5676
|
(device const float *) (src1 + bid*nb11),
|
|
4904
|
-
|
|
5677
|
+
dst + bid*ne0,
|
|
5678
|
+
ne00,
|
|
5679
|
+
ne01,
|
|
5680
|
+
ne02,
|
|
5681
|
+
ne10,
|
|
5682
|
+
ne12,
|
|
5683
|
+
ne0,
|
|
5684
|
+
ne1,
|
|
5685
|
+
r2,
|
|
5686
|
+
r3,
|
|
5687
|
+
tgpig,
|
|
5688
|
+
tiisg,
|
|
5689
|
+
sgitg);
|
|
5690
|
+
}
|
|
5691
|
+
|
|
5692
|
+
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
|
5693
|
+
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
5694
|
+
device const char * ids,
|
|
5695
|
+
device const char * src1,
|
|
5696
|
+
device float * dst,
|
|
5697
|
+
constant uint64_t & nbi1,
|
|
5698
|
+
constant int64_t & ne00,
|
|
5699
|
+
constant int64_t & ne01,
|
|
5700
|
+
constant int64_t & ne02,
|
|
5701
|
+
constant uint64_t & nb00,
|
|
5702
|
+
constant uint64_t & nb01,
|
|
5703
|
+
constant uint64_t & nb02,
|
|
5704
|
+
constant int64_t & ne10,
|
|
5705
|
+
constant int64_t & ne11,
|
|
5706
|
+
constant int64_t & ne12,
|
|
5707
|
+
constant int64_t & ne13,
|
|
5708
|
+
constant uint64_t & nb10,
|
|
5709
|
+
constant uint64_t & nb11,
|
|
5710
|
+
constant uint64_t & nb12,
|
|
5711
|
+
constant int64_t & ne0,
|
|
5712
|
+
constant int64_t & ne1,
|
|
5713
|
+
constant uint64_t & nb1,
|
|
5714
|
+
constant uint & r2,
|
|
5715
|
+
constant uint & r3,
|
|
5716
|
+
constant int & idx,
|
|
5717
|
+
device const char * src00,
|
|
5718
|
+
device const char * src01,
|
|
5719
|
+
device const char * src02,
|
|
5720
|
+
device const char * src03,
|
|
5721
|
+
device const char * src04,
|
|
5722
|
+
device const char * src05,
|
|
5723
|
+
device const char * src06,
|
|
5724
|
+
device const char * src07,
|
|
5725
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5726
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5727
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5728
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5729
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5730
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5731
|
+
|
|
5732
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5733
|
+
|
|
5734
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5735
|
+
|
|
5736
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5737
|
+
|
|
5738
|
+
kernel_mul_mv_iq2_xxs_f32_impl(
|
|
5739
|
+
src0[id],
|
|
5740
|
+
(device const float *) (src1 + bid*nb11),
|
|
5741
|
+
dst + bid*ne0,
|
|
5742
|
+
ne00,
|
|
5743
|
+
ne01,
|
|
5744
|
+
ne02,
|
|
5745
|
+
ne10,
|
|
5746
|
+
ne12,
|
|
5747
|
+
ne0,
|
|
5748
|
+
ne1,
|
|
5749
|
+
r2,
|
|
5750
|
+
r3,
|
|
5751
|
+
shared_values,
|
|
5752
|
+
tgpig,
|
|
5753
|
+
tiisg,
|
|
5754
|
+
sgitg);
|
|
5755
|
+
}
|
|
5756
|
+
|
|
5757
|
+
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
|
5758
|
+
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
5759
|
+
device const char * ids,
|
|
5760
|
+
device const char * src1,
|
|
5761
|
+
device float * dst,
|
|
5762
|
+
constant uint64_t & nbi1,
|
|
5763
|
+
constant int64_t & ne00,
|
|
5764
|
+
constant int64_t & ne01,
|
|
5765
|
+
constant int64_t & ne02,
|
|
5766
|
+
constant uint64_t & nb00,
|
|
5767
|
+
constant uint64_t & nb01,
|
|
5768
|
+
constant uint64_t & nb02,
|
|
5769
|
+
constant int64_t & ne10,
|
|
5770
|
+
constant int64_t & ne11,
|
|
5771
|
+
constant int64_t & ne12,
|
|
5772
|
+
constant int64_t & ne13,
|
|
5773
|
+
constant uint64_t & nb10,
|
|
5774
|
+
constant uint64_t & nb11,
|
|
5775
|
+
constant uint64_t & nb12,
|
|
5776
|
+
constant int64_t & ne0,
|
|
5777
|
+
constant int64_t & ne1,
|
|
5778
|
+
constant uint64_t & nb1,
|
|
5779
|
+
constant uint & r2,
|
|
5780
|
+
constant uint & r3,
|
|
5781
|
+
constant int & idx,
|
|
5782
|
+
device const char * src00,
|
|
5783
|
+
device const char * src01,
|
|
5784
|
+
device const char * src02,
|
|
5785
|
+
device const char * src03,
|
|
5786
|
+
device const char * src04,
|
|
5787
|
+
device const char * src05,
|
|
5788
|
+
device const char * src06,
|
|
5789
|
+
device const char * src07,
|
|
5790
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5791
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5792
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5793
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5794
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5795
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5796
|
+
|
|
5797
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5798
|
+
|
|
5799
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5800
|
+
|
|
5801
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5802
|
+
|
|
5803
|
+
kernel_mul_mv_iq2_xs_f32_impl(
|
|
5804
|
+
src0[id],
|
|
5805
|
+
(device const float *) (src1 + bid*nb11),
|
|
5806
|
+
dst + bid*ne0,
|
|
4905
5807
|
ne00,
|
|
4906
5808
|
ne01,
|
|
4907
5809
|
ne02,
|
|
@@ -4911,6 +5813,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
4911
5813
|
ne1,
|
|
4912
5814
|
r2,
|
|
4913
5815
|
r3,
|
|
5816
|
+
shared_values,
|
|
4914
5817
|
tgpig,
|
|
4915
5818
|
tiisg,
|
|
4916
5819
|
sgitg);
|